In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
import pandas as pd
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import

In [3]:
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [4]:
from tqdm.notebook import tqdm
from PIL import Image
import os

In [5]:
from torchinfo import summary

### read extracted feartures

In [6]:
ex_feat = np.load("extract_features.npy")

In [7]:
## ex_feat.shape

In [8]:
step = 5

def sliding_window(datas,steps=1,width=step):
    win_set=[]
    for i in tqdm(np.arange(0,len(datas),steps)):
        temp=datas[i:i+width]
        if temp.shape[0] == width:
            win_set.append(temp)
    return np.array(win_set)

In [9]:
data_input = sliding_window(ex_feat,steps=1,width=step)
data_input = torch.tensor(data_input, dtype=torch.float32)

  0%|          | 0/4046 [00:00<?, ?it/s]

In [10]:
data_input.shape

torch.Size([4042, 5, 32])

### build the model

In [11]:
class Lstm_encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size=32, hidden_size=16)
        self.lstm2 = nn.LSTM(input_size=16, hidden_size=16)
        
    def forward(self, x):
        #reshape x to fit the input requirement of lstm
        x = x.permute(1, 0, 2)
        output, hn = self.lstm1(x)
        output, (hidden, cell) = self.lstm2(output)
        # output include all timestep, while hidden just include the last timestep.
        hidden = hidden.repeat((output.shape[0], 1, 1))
        return hidden

In [12]:
class Lstm_decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size=16, hidden_size=32)
    
    def forward(self, x):
        # not need to reshape
        
        output, hn = self.lstm1(x)
        #reshape output
        output = output.permute(1, 0, 2)
        ## output = torch.flip(output, dims=[1])
        return output

In [13]:
class net(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.Lstm_encoder = args[0]
        self.Lstm_decoder = args[1]
    
    def forward(self, x):
        output = self.Lstm_encoder(x)
        output = self.Lstm_decoder(output)
        return output

In [14]:
model = net(Lstm_encoder(), Lstm_decoder())

### train the model

In [15]:
def train(model, device, train_loader, optimizer, epoch):
    
    model.train() #trian model
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)

        ##calculate loss
        #loss = 0
        #for i in range(data.shape[0]):
            #loss += F.mse_loss(output[i], data[i], reduction='sum')
        #loss /= data.shape[0]
        loss = F.mse_loss(output, data)
        loss.backward()
        optimizer.step()
        # print result every 10 batch
        if batch_idx % 10 == 0:
            print('Train Epoch: {} ... Batch: {} ... Loss: {:.8f}'.format(epoch, batch_idx, loss))

In [16]:
def test(model, device, test_loader):
    model.eval() #evaluate model
    test_loss = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            output = model(data)
            #calculate sum loss
            test_loss += F.mse_loss(output, data, reduction='sum').item()
    
        test_loss /= len(test_loader.dataset)
        print('------------------- Test set: Average loss: {:.4f} ... Samples: {}'.format(test_loss, len(test_loader.dataset)))

### train test split

In [17]:
train_window_, val_window_ = train_test_split(data_input, test_size=0.2, random_state=2022)

In [18]:
train_loader = torch.utils.data.DataLoader(train_window_, batch_size=16,shuffle=True)
test_loader = torch.utils.data.DataLoader(val_window_, batch_size=16,shuffle=False)

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [20]:
model = model.to(device)

In [24]:
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [25]:
epochs = 100

In [26]:
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

Train Epoch: 1 ... Batch: 0 ... Loss: 8124.04003906
Train Epoch: 1 ... Batch: 10 ... Loss: 6601.99365234
Train Epoch: 1 ... Batch: 20 ... Loss: 9207.70996094
Train Epoch: 1 ... Batch: 30 ... Loss: 7069.19873047
Train Epoch: 1 ... Batch: 40 ... Loss: 7608.19677734
Train Epoch: 1 ... Batch: 50 ... Loss: 7527.05029297
Train Epoch: 1 ... Batch: 60 ... Loss: 7438.39404297
Train Epoch: 1 ... Batch: 70 ... Loss: 7318.56494141
Train Epoch: 1 ... Batch: 80 ... Loss: 7073.29833984
Train Epoch: 1 ... Batch: 90 ... Loss: 8958.78125000
Train Epoch: 1 ... Batch: 100 ... Loss: 8402.92578125
Train Epoch: 1 ... Batch: 110 ... Loss: 7716.16796875
Train Epoch: 1 ... Batch: 120 ... Loss: 8945.07714844
Train Epoch: 1 ... Batch: 130 ... Loss: 6334.67626953
Train Epoch: 1 ... Batch: 140 ... Loss: 8864.29492188
Train Epoch: 1 ... Batch: 150 ... Loss: 7453.33935547
Train Epoch: 1 ... Batch: 160 ... Loss: 7787.67529297
Train Epoch: 1 ... Batch: 170 ... Loss: 7178.70019531
Train Epoch: 1 ... Batch: 180 ... Loss:

Train Epoch: 8 ... Batch: 20 ... Loss: 9437.02246094
Train Epoch: 8 ... Batch: 30 ... Loss: 6657.64990234
Train Epoch: 8 ... Batch: 40 ... Loss: 7728.15380859
Train Epoch: 8 ... Batch: 50 ... Loss: 6782.71875000
Train Epoch: 8 ... Batch: 60 ... Loss: 7021.23437500
Train Epoch: 8 ... Batch: 70 ... Loss: 7406.45312500
Train Epoch: 8 ... Batch: 80 ... Loss: 7340.95703125
Train Epoch: 8 ... Batch: 90 ... Loss: 7539.25927734
Train Epoch: 8 ... Batch: 100 ... Loss: 8308.68066406
Train Epoch: 8 ... Batch: 110 ... Loss: 7915.03271484
Train Epoch: 8 ... Batch: 120 ... Loss: 8590.68261719
Train Epoch: 8 ... Batch: 130 ... Loss: 8989.40332031
Train Epoch: 8 ... Batch: 140 ... Loss: 7280.27880859
Train Epoch: 8 ... Batch: 150 ... Loss: 9407.39355469
Train Epoch: 8 ... Batch: 160 ... Loss: 7264.98437500
Train Epoch: 8 ... Batch: 170 ... Loss: 8098.30175781
Train Epoch: 8 ... Batch: 180 ... Loss: 7899.45019531
Train Epoch: 8 ... Batch: 190 ... Loss: 7563.84765625
Train Epoch: 8 ... Batch: 200 ... Lo

Train Epoch: 15 ... Batch: 0 ... Loss: 7814.58154297
Train Epoch: 15 ... Batch: 10 ... Loss: 7184.00781250
Train Epoch: 15 ... Batch: 20 ... Loss: 8043.36474609
Train Epoch: 15 ... Batch: 30 ... Loss: 6789.99951172
Train Epoch: 15 ... Batch: 40 ... Loss: 7734.59521484
Train Epoch: 15 ... Batch: 50 ... Loss: 7928.39941406
Train Epoch: 15 ... Batch: 60 ... Loss: 8078.75537109
Train Epoch: 15 ... Batch: 70 ... Loss: 7537.05322266
Train Epoch: 15 ... Batch: 80 ... Loss: 6976.68750000
Train Epoch: 15 ... Batch: 90 ... Loss: 8080.70068359
Train Epoch: 15 ... Batch: 100 ... Loss: 7440.23437500
Train Epoch: 15 ... Batch: 110 ... Loss: 7821.74365234
Train Epoch: 15 ... Batch: 120 ... Loss: 7403.46044922
Train Epoch: 15 ... Batch: 130 ... Loss: 7004.55468750
Train Epoch: 15 ... Batch: 140 ... Loss: 7478.53271484
Train Epoch: 15 ... Batch: 150 ... Loss: 7066.39697266
Train Epoch: 15 ... Batch: 160 ... Loss: 7633.39404297
Train Epoch: 15 ... Batch: 170 ... Loss: 7975.81787109
Train Epoch: 15 ... B

Train Epoch: 21 ... Batch: 190 ... Loss: 7197.08740234
Train Epoch: 21 ... Batch: 200 ... Loss: 8315.19238281
------------------- Test set: Average loss: 1218619.0507 ... Samples: 809
Train Epoch: 22 ... Batch: 0 ... Loss: 7831.58203125
Train Epoch: 22 ... Batch: 10 ... Loss: 7723.28369141
Train Epoch: 22 ... Batch: 20 ... Loss: 8742.20996094
Train Epoch: 22 ... Batch: 30 ... Loss: 7547.58300781
Train Epoch: 22 ... Batch: 40 ... Loss: 7762.93750000
Train Epoch: 22 ... Batch: 50 ... Loss: 8126.70166016
Train Epoch: 22 ... Batch: 60 ... Loss: 8041.93310547
Train Epoch: 22 ... Batch: 70 ... Loss: 7403.96875000
Train Epoch: 22 ... Batch: 80 ... Loss: 6053.90771484
Train Epoch: 22 ... Batch: 90 ... Loss: 7206.35156250
Train Epoch: 22 ... Batch: 100 ... Loss: 7668.54541016
Train Epoch: 22 ... Batch: 110 ... Loss: 7303.29052734
Train Epoch: 22 ... Batch: 120 ... Loss: 8398.14746094
Train Epoch: 22 ... Batch: 130 ... Loss: 8478.76464844
Train Epoch: 22 ... Batch: 140 ... Loss: 7193.12060547
Tr

Train Epoch: 28 ... Batch: 160 ... Loss: 8467.97949219
Train Epoch: 28 ... Batch: 170 ... Loss: 7261.39404297
Train Epoch: 28 ... Batch: 180 ... Loss: 8108.70166016
Train Epoch: 28 ... Batch: 190 ... Loss: 8130.60937500
Train Epoch: 28 ... Batch: 200 ... Loss: 8002.89404297
------------------- Test set: Average loss: 1218615.5847 ... Samples: 809
Train Epoch: 29 ... Batch: 0 ... Loss: 8570.84082031
Train Epoch: 29 ... Batch: 10 ... Loss: 7179.55615234
Train Epoch: 29 ... Batch: 20 ... Loss: 8232.60058594
Train Epoch: 29 ... Batch: 30 ... Loss: 7638.56103516
Train Epoch: 29 ... Batch: 40 ... Loss: 7058.30468750
Train Epoch: 29 ... Batch: 50 ... Loss: 7703.33349609
Train Epoch: 29 ... Batch: 60 ... Loss: 7505.48388672
Train Epoch: 29 ... Batch: 70 ... Loss: 7931.75927734
Train Epoch: 29 ... Batch: 80 ... Loss: 8002.72656250
Train Epoch: 29 ... Batch: 90 ... Loss: 7640.08447266
Train Epoch: 29 ... Batch: 100 ... Loss: 10170.22558594
Train Epoch: 29 ... Batch: 110 ... Loss: 7035.53271484
T

Train Epoch: 35 ... Batch: 110 ... Loss: 7866.81591797
Train Epoch: 35 ... Batch: 120 ... Loss: 9184.68164062
Train Epoch: 35 ... Batch: 130 ... Loss: 7140.65332031
Train Epoch: 35 ... Batch: 140 ... Loss: 7313.20410156
Train Epoch: 35 ... Batch: 150 ... Loss: 8277.00585938
Train Epoch: 35 ... Batch: 160 ... Loss: 7295.03613281
Train Epoch: 35 ... Batch: 170 ... Loss: 8241.69824219
Train Epoch: 35 ... Batch: 180 ... Loss: 8793.54199219
Train Epoch: 35 ... Batch: 190 ... Loss: 6981.63525391
Train Epoch: 35 ... Batch: 200 ... Loss: 7031.50488281
------------------- Test set: Average loss: 1218611.9604 ... Samples: 809
Train Epoch: 36 ... Batch: 0 ... Loss: 9329.10156250
Train Epoch: 36 ... Batch: 10 ... Loss: 6997.87988281
Train Epoch: 36 ... Batch: 20 ... Loss: 7351.46875000
Train Epoch: 36 ... Batch: 30 ... Loss: 8475.75195312
Train Epoch: 36 ... Batch: 40 ... Loss: 7037.12353516
Train Epoch: 36 ... Batch: 50 ... Loss: 8583.38574219
Train Epoch: 36 ... Batch: 60 ... Loss: 8700.85937500

Train Epoch: 42 ... Batch: 80 ... Loss: 6779.08300781
Train Epoch: 42 ... Batch: 90 ... Loss: 7448.05029297
Train Epoch: 42 ... Batch: 100 ... Loss: 9780.87011719
Train Epoch: 42 ... Batch: 110 ... Loss: 8187.89550781
Train Epoch: 42 ... Batch: 120 ... Loss: 7828.19921875
Train Epoch: 42 ... Batch: 130 ... Loss: 8823.18750000
Train Epoch: 42 ... Batch: 140 ... Loss: 7261.08691406
Train Epoch: 42 ... Batch: 150 ... Loss: 8055.06494141
Train Epoch: 42 ... Batch: 160 ... Loss: 9217.72949219
Train Epoch: 42 ... Batch: 170 ... Loss: 7621.24707031
Train Epoch: 42 ... Batch: 180 ... Loss: 7647.70166016
Train Epoch: 42 ... Batch: 190 ... Loss: 8286.01855469
Train Epoch: 42 ... Batch: 200 ... Loss: 8726.26074219
------------------- Test set: Average loss: 1218604.5822 ... Samples: 809
Train Epoch: 43 ... Batch: 0 ... Loss: 6254.96191406
Train Epoch: 43 ... Batch: 10 ... Loss: 7422.30468750
Train Epoch: 43 ... Batch: 20 ... Loss: 8236.83886719
Train Epoch: 43 ... Batch: 30 ... Loss: 7765.9921875

Train Epoch: 49 ... Batch: 30 ... Loss: 7224.82958984
Train Epoch: 49 ... Batch: 40 ... Loss: 6540.15185547
Train Epoch: 49 ... Batch: 50 ... Loss: 6776.25634766
Train Epoch: 49 ... Batch: 60 ... Loss: 8035.90234375
Train Epoch: 49 ... Batch: 70 ... Loss: 7471.69628906
Train Epoch: 49 ... Batch: 80 ... Loss: 7619.15771484
Train Epoch: 49 ... Batch: 90 ... Loss: 7393.28759766
Train Epoch: 49 ... Batch: 100 ... Loss: 8228.04785156
Train Epoch: 49 ... Batch: 110 ... Loss: 7641.31884766
Train Epoch: 49 ... Batch: 120 ... Loss: 7822.41259766
Train Epoch: 49 ... Batch: 130 ... Loss: 8553.10449219
Train Epoch: 49 ... Batch: 140 ... Loss: 8154.59082031
Train Epoch: 49 ... Batch: 150 ... Loss: 8086.69384766
Train Epoch: 49 ... Batch: 160 ... Loss: 7594.45019531
Train Epoch: 49 ... Batch: 170 ... Loss: 7138.25097656
Train Epoch: 49 ... Batch: 180 ... Loss: 7844.48583984
Train Epoch: 49 ... Batch: 190 ... Loss: 7538.97656250
Train Epoch: 49 ... Batch: 200 ... Loss: 7243.65478516
-----------------

Train Epoch: 56 ... Batch: 60 ... Loss: 10101.72167969
Train Epoch: 56 ... Batch: 70 ... Loss: 7524.38623047
Train Epoch: 56 ... Batch: 80 ... Loss: 8026.13281250
Train Epoch: 56 ... Batch: 90 ... Loss: 8413.14550781
Train Epoch: 56 ... Batch: 100 ... Loss: 7455.92333984
Train Epoch: 56 ... Batch: 110 ... Loss: 6591.57177734
Train Epoch: 56 ... Batch: 120 ... Loss: 6529.67529297
Train Epoch: 56 ... Batch: 130 ... Loss: 10300.84570312
Train Epoch: 56 ... Batch: 140 ... Loss: 7680.92675781
Train Epoch: 56 ... Batch: 150 ... Loss: 8254.27636719
Train Epoch: 56 ... Batch: 160 ... Loss: 7768.20800781
Train Epoch: 56 ... Batch: 170 ... Loss: 6259.30615234
Train Epoch: 56 ... Batch: 180 ... Loss: 8557.51855469
Train Epoch: 56 ... Batch: 190 ... Loss: 8091.00390625
Train Epoch: 56 ... Batch: 200 ... Loss: 8833.10253906
------------------- Test set: Average loss: 1218595.6316 ... Samples: 809
Train Epoch: 57 ... Batch: 0 ... Loss: 7337.73291016
Train Epoch: 57 ... Batch: 10 ... Loss: 7820.56201

Train Epoch: 63 ... Batch: 20 ... Loss: 7701.35156250
Train Epoch: 63 ... Batch: 30 ... Loss: 6540.34228516
Train Epoch: 63 ... Batch: 40 ... Loss: 7721.25488281
Train Epoch: 63 ... Batch: 50 ... Loss: 7677.81884766
Train Epoch: 63 ... Batch: 60 ... Loss: 8284.95312500
Train Epoch: 63 ... Batch: 70 ... Loss: 6722.64208984
Train Epoch: 63 ... Batch: 80 ... Loss: 8420.89453125
Train Epoch: 63 ... Batch: 90 ... Loss: 7509.73779297
Train Epoch: 63 ... Batch: 100 ... Loss: 7208.76123047
Train Epoch: 63 ... Batch: 110 ... Loss: 7312.08154297
Train Epoch: 63 ... Batch: 120 ... Loss: 7387.56738281
Train Epoch: 63 ... Batch: 130 ... Loss: 7457.87841797
Train Epoch: 63 ... Batch: 140 ... Loss: 7184.14306641
Train Epoch: 63 ... Batch: 150 ... Loss: 7547.10937500
Train Epoch: 63 ... Batch: 160 ... Loss: 7976.29052734
Train Epoch: 63 ... Batch: 170 ... Loss: 7123.32519531
Train Epoch: 63 ... Batch: 180 ... Loss: 7363.87988281
Train Epoch: 63 ... Batch: 190 ... Loss: 7893.06103516
Train Epoch: 63 ..

------------------- Test set: Average loss: 1218589.2880 ... Samples: 809
Train Epoch: 70 ... Batch: 0 ... Loss: 7802.03125000
Train Epoch: 70 ... Batch: 10 ... Loss: 7526.74169922
Train Epoch: 70 ... Batch: 20 ... Loss: 7311.58349609
Train Epoch: 70 ... Batch: 30 ... Loss: 7208.10693359
Train Epoch: 70 ... Batch: 40 ... Loss: 6754.04052734
Train Epoch: 70 ... Batch: 50 ... Loss: 8392.19824219
Train Epoch: 70 ... Batch: 60 ... Loss: 6613.04687500
Train Epoch: 70 ... Batch: 70 ... Loss: 7507.75341797
Train Epoch: 70 ... Batch: 80 ... Loss: 8323.35449219
Train Epoch: 70 ... Batch: 90 ... Loss: 8017.76562500
Train Epoch: 70 ... Batch: 100 ... Loss: 6174.18505859
Train Epoch: 70 ... Batch: 110 ... Loss: 8024.10253906
Train Epoch: 70 ... Batch: 120 ... Loss: 8914.28613281
Train Epoch: 70 ... Batch: 130 ... Loss: 8829.53320312
Train Epoch: 70 ... Batch: 140 ... Loss: 7691.30468750
Train Epoch: 70 ... Batch: 150 ... Loss: 8533.78906250
Train Epoch: 70 ... Batch: 160 ... Loss: 7365.09765625
Tr

Train Epoch: 76 ... Batch: 190 ... Loss: 7140.09375000
Train Epoch: 76 ... Batch: 200 ... Loss: 8166.90478516
------------------- Test set: Average loss: 1218589.9172 ... Samples: 809
Train Epoch: 77 ... Batch: 0 ... Loss: 7244.70019531
Train Epoch: 77 ... Batch: 10 ... Loss: 7309.28906250
Train Epoch: 77 ... Batch: 20 ... Loss: 6587.58203125
Train Epoch: 77 ... Batch: 30 ... Loss: 7585.24072266
Train Epoch: 77 ... Batch: 40 ... Loss: 8999.35351562
Train Epoch: 77 ... Batch: 50 ... Loss: 8001.02343750
Train Epoch: 77 ... Batch: 60 ... Loss: 7147.15332031
Train Epoch: 77 ... Batch: 70 ... Loss: 7775.35009766
Train Epoch: 77 ... Batch: 80 ... Loss: 6612.90185547
Train Epoch: 77 ... Batch: 90 ... Loss: 7846.79687500
Train Epoch: 77 ... Batch: 100 ... Loss: 7583.17285156
Train Epoch: 77 ... Batch: 110 ... Loss: 7700.50488281
Train Epoch: 77 ... Batch: 120 ... Loss: 6741.53271484
Train Epoch: 77 ... Batch: 130 ... Loss: 8041.05810547
Train Epoch: 77 ... Batch: 140 ... Loss: 8387.77636719
Tr

Train Epoch: 83 ... Batch: 200 ... Loss: 6941.96582031
------------------- Test set: Average loss: 1218590.6069 ... Samples: 809
Train Epoch: 84 ... Batch: 0 ... Loss: 7861.55029297
Train Epoch: 84 ... Batch: 10 ... Loss: 7447.61425781
Train Epoch: 84 ... Batch: 20 ... Loss: 8089.04248047
Train Epoch: 84 ... Batch: 30 ... Loss: 7989.55957031
Train Epoch: 84 ... Batch: 40 ... Loss: 7309.88134766
Train Epoch: 84 ... Batch: 50 ... Loss: 7462.34716797
Train Epoch: 84 ... Batch: 60 ... Loss: 7143.26318359
Train Epoch: 84 ... Batch: 70 ... Loss: 8801.14550781
Train Epoch: 84 ... Batch: 80 ... Loss: 8419.05273438
Train Epoch: 84 ... Batch: 90 ... Loss: 8169.89306641
Train Epoch: 84 ... Batch: 100 ... Loss: 9148.53613281
Train Epoch: 84 ... Batch: 110 ... Loss: 8688.53710938
Train Epoch: 84 ... Batch: 120 ... Loss: 7855.86083984
Train Epoch: 84 ... Batch: 130 ... Loss: 8083.97363281
Train Epoch: 84 ... Batch: 140 ... Loss: 6728.73779297
Train Epoch: 84 ... Batch: 150 ... Loss: 8012.22656250
Tr

Train Epoch: 90 ... Batch: 200 ... Loss: 7913.73925781
------------------- Test set: Average loss: 1218586.0643 ... Samples: 809
Train Epoch: 91 ... Batch: 0 ... Loss: 8068.35937500
Train Epoch: 91 ... Batch: 10 ... Loss: 8065.48535156
Train Epoch: 91 ... Batch: 20 ... Loss: 8813.37207031
Train Epoch: 91 ... Batch: 30 ... Loss: 8704.62988281
Train Epoch: 91 ... Batch: 40 ... Loss: 8416.86425781
Train Epoch: 91 ... Batch: 50 ... Loss: 7696.29541016
Train Epoch: 91 ... Batch: 60 ... Loss: 8340.49121094
Train Epoch: 91 ... Batch: 70 ... Loss: 7786.51562500
Train Epoch: 91 ... Batch: 80 ... Loss: 7678.86718750
Train Epoch: 91 ... Batch: 90 ... Loss: 6962.84765625
Train Epoch: 91 ... Batch: 100 ... Loss: 7562.48583984
Train Epoch: 91 ... Batch: 110 ... Loss: 7359.86962891
Train Epoch: 91 ... Batch: 120 ... Loss: 8022.27294922
Train Epoch: 91 ... Batch: 130 ... Loss: 8844.78417969
Train Epoch: 91 ... Batch: 140 ... Loss: 7530.49365234
Train Epoch: 91 ... Batch: 150 ... Loss: 8228.26757812
Tr

Train Epoch: 97 ... Batch: 170 ... Loss: 7728.28125000
Train Epoch: 97 ... Batch: 180 ... Loss: 8484.27929688
Train Epoch: 97 ... Batch: 190 ... Loss: 8276.03417969
Train Epoch: 97 ... Batch: 200 ... Loss: 9013.59765625
------------------- Test set: Average loss: 1218586.6910 ... Samples: 809
Train Epoch: 98 ... Batch: 0 ... Loss: 7658.89990234
Train Epoch: 98 ... Batch: 10 ... Loss: 7781.05615234
Train Epoch: 98 ... Batch: 20 ... Loss: 6918.75097656
Train Epoch: 98 ... Batch: 30 ... Loss: 8286.45117188
Train Epoch: 98 ... Batch: 40 ... Loss: 7934.43847656
Train Epoch: 98 ... Batch: 50 ... Loss: 8167.10937500
Train Epoch: 98 ... Batch: 60 ... Loss: 8233.76953125
Train Epoch: 98 ... Batch: 70 ... Loss: 8827.56054688
Train Epoch: 98 ... Batch: 80 ... Loss: 8188.55175781
Train Epoch: 98 ... Batch: 90 ... Loss: 7265.00341797
Train Epoch: 98 ... Batch: 100 ... Loss: 8671.03125000
Train Epoch: 98 ... Batch: 110 ... Loss: 6732.02490234
Train Epoch: 98 ... Batch: 120 ... Loss: 7397.98291016
Tr

In [27]:
## data_input[0][:,:]

In [28]:
## torch.flip(data_input[0], dims=[0])

In [29]:
## data_input[0].shape

In [31]:
with torch.no_grad():
    data = data_input[0:1].to(device)
    output = model(data)

In [32]:
output

tensor([[[-9.9983e-01,  9.9991e-01,  9.9991e-01, -9.9991e-01,  9.9990e-01,
           9.9984e-01,  9.9988e-01, -9.9991e-01,  9.9988e-01,  9.9990e-01,
          -9.9991e-01, -9.9991e-01, -9.9991e-01,  9.9985e-01,  9.9991e-01,
          -9.9984e-01, -9.9991e-01,  9.9991e-01,  9.9991e-01,  9.9985e-01,
          -9.9991e-01,  9.9991e-01, -9.9982e-01,  9.9991e-01,  9.9991e-01,
          -9.9991e-01, -9.9991e-01, -9.9983e-01,  9.9991e-01,  4.1983e-13,
          -9.9991e-01, -9.9991e-01],
         [-9.9877e-01,  9.9933e-01,  9.9933e-01, -9.9933e-01,  9.9928e-01,
           9.9894e-01,  9.9915e-01, -9.9933e-01,  9.9914e-01,  9.9927e-01,
          -9.9931e-01, -9.9933e-01, -9.9933e-01,  9.9888e-01,  9.9933e-01,
          -9.9880e-01, -9.9933e-01,  9.9933e-01,  9.9933e-01,  9.9896e-01,
          -9.9933e-01,  9.9932e-01, -9.9868e-01,  9.9933e-01,  9.9933e-01,
          -9.9933e-01, -9.9931e-01, -9.9880e-01,  9.9931e-01,  4.7505e-13,
          -9.9932e-01, -9.9933e-01],
         [-9.9106e-01,  9.

In [33]:
data_input[0:1]

tensor([[[ -45.8218,   57.4789,   48.6242,  -82.7458,   90.2192,    2.5085,
            18.7802, -135.9315,   29.0594,   66.9606,  -36.3427,  -45.6037,
           -26.2268,   59.3037,  114.4326,  -42.4477, -107.8410,   74.5253,
            41.5111,   25.1215,    7.0498,   36.8380,   16.7537,   67.9338,
            82.8892, -139.4555,  -37.5287,   12.5863,    3.8364,  -35.1793,
           -78.3844,  -91.5303],
         [ -44.9507,   63.4614,   35.0839,  -89.5835,   98.5496,   10.8374,
            24.1075, -142.6458,   29.8852,   66.4066,  -49.6138,  -44.9583,
           -20.1390,   55.7010,  121.8717,  -43.1314, -112.7467,   72.6779,
            43.1624,   24.1007,   -0.3647,   50.5636,    3.0370,   77.0706,
            86.3826, -137.6309,  -33.2479,    6.2407,    8.1310,  -41.9642,
           -78.9873,  -92.0102],
         [ -44.7154,   67.8981,   23.0046,  -95.5193,  106.9143,   17.9277,
            28.2278, -147.4190,   29.9477,   65.4966,  -60.8076,  -44.1688,
           -13.7817,  