In [1]:
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.distributions as dist
import torch
import copy
import matplotlib.pyplot as plt
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from scripts.utils import ScaleData, train_keys
from scripts.AutoEncoder import AutoEncoderDataset
from tqdm import tqdm

import itertools
import seaborn as sns
#%matplotlib notebook
from tqdm import tqdm

In [2]:
# do we have a gpu 
gpu = torch.cuda.is_available()

if gpu:
    device = torch.cuda.current_device()
    print(f'GPU device: {device}')
else: 
    print('No GPU')

GPU device: 0


In [3]:
train_data_path = "/share/rcifdata/jbarr/UKAEAGroupProject/data/train_data_clipped.pkl"
train_data = AutoEncoderDataset(train_data_path, columns = train_keys, train = True)
train_data.data = train_data.data.sample(10_000)
train_data.scale()

train_loader = DataLoader(train_data, shuffle = True, batch_size = 1024)

valid_data_path = "/share/rcifdata/jbarr/UKAEAGroupProject/data/valid_data_clipped.pkl"
valid_data = AutoEncoderDataset(valid_data_path, columns = train_keys, train = True)
valid_data.data = valid_data.data.sample(10_000)
valid_data.scale()

valid_loader = DataLoader(valid_data, shuffle = True, batch_size = 1024)

In [4]:
test = "/share/rcifdata/jbarr/UKAEAGroupProject/data/test_data_clipped.pkl"

df_test = pd.read_pickle(test)
df_test = df_test.sample(10_000)
target = df_test['target']
df_test_good = df_test[df_test.target == 1]
df_test_good = df_test_good[train_keys]

df_test_good,_ = ScaleData(df_test_good)

df_test_bad = df_test[df_test.target == 0]
df_test_bad = df_test_bad[train_keys]
df_test_bad,_ = ScaleData(df_test_bad)

df_test_good.describe()

Unnamed: 0,ane,ate,autor,machtor,x,zeff,gammae,q,smag,alpha,ani1,ati0,normni1,ti_te0,lognustar
count,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0,6634.0
mean,-5.0431970000000003e-17,-3.1746960000000003e-17,-1.349873e-16,-5.432294e-16,-1.272138e-16,2.694391e-16,7.082902e-16,7.581113e-18,-9.311130000000001e-17,-3.052109e-17,-2.6134340000000003e-17,-2.6274500000000003e-17,1.6919440000000002e-17,-8.095641e-16,-9.462166000000001e-17
std,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075
min,-6.297388,-5.020761,-14.85119,-1.839907,-1.646459,-1.28504,-18.37735,-1.228424,-2.069795,-1.53592,-15.4973,-6.300319,-1.786372,-3.111402,-2.588042
25%,-0.3982774,-0.5573673,-0.2905987,-0.5242054,-0.9668236,-0.799245,0.03722918,-0.7421419,-0.6425405,-0.4631523,-0.3389445,-0.5483892,-0.2689037,-0.1006135,-0.7183725
50%,-0.1920931,-0.2603034,-0.2905987,-0.5242054,-0.1373695,-0.1578197,0.03722918,-0.2888421,-0.3918279,-0.3263465,-0.1827159,-0.252114,-0.2104291,-0.1006135,-0.1362269
75%,0.1142332,0.1936652,-0.03188221,0.1067301,0.9292663,0.5101999,0.03722918,0.4294962,0.2581441,0.07170666,0.05613432,0.2770294,-0.09568391,-0.1006135,0.5715206
max,13.17529,8.528607,26.94902,5.408396,1.636621,15.61337,12.94859,7.053241,6.123863,16.06846,14.64891,10.44813,19.4695,20.36157,6.697918


In [5]:
data_good = torch.from_numpy(df_test_good.values).float()
data_bad = torch.from_numpy(df_test_bad.values).float()
#data_good_batch = next(iter(data_good_loader))
# with torch.no_grad():
#     outputs_good = encoder.forward(data_good).sample()
#     outputs_bad = encoder.forward(data_bad).sample()
# #    outputs_good = encoder.forward(data_good).sample().detach().numpy()
    
# plt.figure()
# plt.scatter(outputs_good[:,0], outputs_good[:,1])
# plt.scatter(outputs_bad[:,0], outputs_bad[:,1])

In [None]:
AE_output = decoder.forward(encoder.forward(data_good).sample()).sample().detach().numpy()
df_ae_output = pd.DataFrame(AE_output, columns = train_keys)
df_ae_output['AE'] = 'Outputs'

df_test_tmp = df_test_good
df_test_tmp['AE'] = 'Inputs'

In [None]:
df_compare = pd.concat([df_ae_output, df_test_tmp], ignore_index=True)
df_compare_sample = df_compare.sample(10_000)

In [None]:
for i in train_keys:
    plt.figure()
    x_min = df_compare_sample[i].quantile(0.1)
    x_max = df_compare_sample[i].quantile(0.9)
    sns.histplot(data = df_compare_sample, x = i, hue = "AE", binrange = (x_min, x_max), bins = 100);
    plt.xlabel(i)

# VAE 2

In [6]:
class LinearVAE(nn.Module):
    def __init__(self):
        super(LinearVAE, self).__init__()
        
 
        # encoder
        self.enc1 = nn.Linear(in_features=15, out_features =150)
        self.enc2 = nn.Linear(in_features=150, out_features=75)
        self.enc3 = nn.Linear(in_features=75, out_features =25)
        
        self.mu = nn.Linear(25,3 )
        self.sigma = nn.Linear(25,3)
 
        # decoder 
        self.dec1 = nn.Linear(in_features = 3, out_features = 25)
        self.dec2 = nn.Linear(in_features = 25, out_features = 75)
        self.dec3 = nn.Linear(in_features = 75, out_features = 150)
        self.dec4 = nn.Linear(150, 15)
        
    def reparameterize(self, mu, log_var):
        """
        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5*log_var) # standard deviation
        eps = torch.randn_like(std) # `randn_like` as we need the same size
        sample = mu + (eps * std) # sampling as if coming from the input space
        return sample
 
    def forward(self, x):
        # encoding
        x = x.float()
        x = F.relu(self.enc1(x.float()))
        x = F.relu(self.enc2(x.float()))
        x = F.relu(self.enc3(x.float()))
        # get `mu` and `log_var`
        mu = self.mu(x) # the first feature values as mean
        log_var = self.sigma(x) # the other feature values as variance
        # get the latent vector through reparameterization
        z = self.reparameterize(mu, log_var)
 
        # decoding
        z = F.relu(self.dec1(z.float()))
        z = F.relu(self.dec2(z.float()))
        z = F.relu(self.dec3(z.float()))
        
        reconstruction = self.dec4(z.float())
        return reconstruction.float(), mu.float(), log_var.float()

In [7]:
batch_size = 1024
lr = 1e-3
epochs = 50

model = LinearVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss(reduction = "sum")

In [8]:
def final_loss(MSE_loss, mu, logvar):
    """
    This function will add the reconstruction loss (BCELoss) and the 
    KL-Divergence.
    KL-Divergence = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    :param bce_loss: recontruction loss
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    """
    MSE = MSE_loss 
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD

In [9]:
def fit(model, dataloader):
    model.train()
    running_loss = 0.0
    for i, data in tqdm(enumerate(dataloader), total=int(len(train_data)/dataloader.batch_size)):
        #data, _ = data
        data = data.to(device)
        #data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        MSE_loss = criterion(reconstruction.float(), data.float())
        loss = final_loss(MSE_loss, mu, logvar)
        running_loss += loss.item()
        print(loss.item())
        loss.backward()
        optimizer.step()
    train_loss = running_loss/len(dataloader.dataset)
    return train_loss

In [10]:
def validate(model, dataloader):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=int(len(valid_data)/dataloader.batch_size)):
            data = data.to(device)
            #data = data.view(data.size(0), -1)
            reconstruction, mu, logvar = model(data)
            MSE_loss = criterion(reconstruction.float(), data.float())
            loss = final_loss(MSE_loss, mu, logvar)
            running_loss += loss.item()
        
    val_loss = running_loss/len(dataloader.dataset)
    return val_loss

In [11]:
train_loss = []
val_loss = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = fit(model, train_loader)
    val_epoch_loss = validate(model, valid_loader)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

Epoch 1 of 50


 33%|███▎      | 3/9 [00:01<00:01,  3.10it/s]

13383.109375
20425.330078125
15250.078125


 56%|█████▌    | 5/9 [00:01<00:00,  4.90it/s]

18296.96484375
13009.90625
14765.318359375


100%|██████████| 9/9 [00:01<00:00,  7.57it/s]

21952.34375
13236.1865234375
12437.962890625


10it [00:01,  5.58it/s]                      


7582.51953125


10it [00:00, 11.69it/s]                      


Train Loss: 15.0340
Val Loss: 15.0127
Epoch 2 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.14it/s]

12102.3369140625
15486.2578125
11516.8203125


 44%|████▍     | 4/9 [00:00<00:00, 10.89it/s]

10724.484375
20176.29296875


 67%|██████▋   | 6/9 [00:00<00:00, 10.82it/s]

12889.8935546875


 89%|████████▉ | 8/9 [00:00<00:00, 10.80it/s]

13653.2021484375
13277.8447265625
13084.6982421875


10it [00:00, 11.03it/s]                      


27193.384765625


10it [00:00, 11.70it/s]                      


Train Loss: 15.0105
Val Loss: 15.0002
Epoch 3 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.12it/s]

13034.240234375
15193.982421875
15617.34765625


 44%|████▍     | 4/9 [00:00<00:00, 10.85it/s]

13887.626953125
11120.5224609375


 67%|██████▋   | 6/9 [00:00<00:00, 10.85it/s]

22489.27734375


 89%|████████▉ | 8/9 [00:00<00:00, 10.86it/s]

10063.388671875
11048.3447265625
14122.45703125


10it [00:00, 11.04it/s]                      


23364.6328125


10it [00:00, 11.66it/s]                      


Train Loss: 14.9942
Val Loss: 14.9808
Epoch 4 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.17it/s]

12587.0380859375
10242.83984375
24438.203125


 44%|████▍     | 4/9 [00:00<00:00, 10.85it/s]

11257.4814453125
15803.4072265625


 67%|██████▋   | 6/9 [00:00<00:00, 10.84it/s]

13635.205078125


 89%|████████▉ | 8/9 [00:00<00:00, 10.81it/s]

18657.931640625
20707.501953125
12242.6318359375


10it [00:00, 11.03it/s]                      


9951.28515625


10it [00:00, 11.83it/s]                      


Train Loss: 14.9524
Val Loss: 14.8431
Epoch 5 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.20it/s]

10895.306640625
19319.1796875
13747.376953125


 44%|████▍     | 4/9 [00:00<00:00, 10.97it/s]

22342.31640625
11944.140625


 67%|██████▋   | 6/9 [00:00<00:00, 10.97it/s]

17374.673828125


 89%|████████▉ | 8/9 [00:00<00:00, 10.96it/s]

13716.64453125
12033.80078125
17929.4609375


10it [00:00, 11.16it/s]                      


7577.62109375


10it [00:00, 11.84it/s]                      


Train Loss: 14.6881
Val Loss: 14.2083
Epoch 6 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.14it/s]

19861.263671875
20156.32421875
13016.07421875


 44%|████▍     | 4/9 [00:00<00:00, 10.88it/s]

12846.0869140625
17324.677734375


 67%|██████▋   | 6/9 [00:00<00:00, 10.80it/s]

13040.33984375


 89%|████████▉ | 8/9 [00:00<00:00, 10.85it/s]

16290.849609375
9698.6962890625
9822.1767578125


10it [00:00, 11.03it/s]                      


7855.427734375


10it [00:00, 11.71it/s]                      


Train Loss: 13.9912
Val Loss: 13.5979
Epoch 7 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.01it/s]

12839.2724609375
17903.724609375
11209.724609375


 44%|████▍     | 4/9 [00:00<00:00, 10.83it/s]

10811.7021484375
13610.25


 67%|██████▋   | 6/9 [00:00<00:00, 10.86it/s]

25890.560546875


 89%|████████▉ | 8/9 [00:00<00:00, 10.89it/s]

11558.3603515625
11526.98828125
10335.162109375


10it [00:00, 11.08it/s]                      


9719.98828125


10it [00:00, 11.81it/s]                      


Train Loss: 13.5406
Val Loss: 13.2891
Epoch 8 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.09it/s]

11143.80859375
10044.4580078125
12470.654296875


 44%|████▍     | 4/9 [00:00<00:00, 10.89it/s]

13261.017578125
9668.9375


 67%|██████▋   | 6/9 [00:00<00:00, 10.84it/s]

23293.515625


 89%|████████▉ | 8/9 [00:00<00:00, 10.85it/s]

10310.3349609375
10003.6416015625
16615.966796875


10it [00:00, 11.07it/s]                      


16405.51171875


10it [00:00, 11.74it/s]                      


Train Loss: 13.3218
Val Loss: 13.2104
Epoch 9 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.12it/s]

12192.193359375
11290.3955078125
12860.875


 44%|████▍     | 4/9 [00:00<00:00, 10.87it/s]

18193.10546875
12185.1484375


 67%|██████▋   | 6/9 [00:00<00:00, 10.88it/s]

11729.8525390625


 89%|████████▉ | 8/9 [00:00<00:00, 10.86it/s]

14662.640625
10770.4609375
10956.3154296875


10it [00:00, 11.08it/s]                      


17359.09375


10it [00:00, 11.85it/s]                      


Train Loss: 13.2200
Val Loss: 13.0875
Epoch 10 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.16it/s]

10045.17578125
18151.15234375
9717.46875


 44%|████▍     | 4/9 [00:00<00:00, 10.98it/s]

16174.1513671875
15554.4794921875


 67%|██████▋   | 6/9 [00:00<00:00, 10.94it/s]

12231.236328125


 89%|████████▉ | 8/9 [00:00<00:00, 10.94it/s]

10010.0478515625
17095.548828125
11260.3994140625


10it [00:00, 11.14it/s]                      


10412.568359375


10it [00:00, 11.84it/s]                      


Train Loss: 13.0652
Val Loss: 12.9382
Epoch 11 of 50


 22%|██▏       | 2/9 [00:00<00:00, 10.99it/s]

11569.9287109375
11341.873046875
17966.13671875


 44%|████▍     | 4/9 [00:00<00:00, 10.92it/s]

10094.8037109375
8499.779296875


 67%|██████▋   | 6/9 [00:00<00:00, 10.90it/s]

15389.3583984375


 89%|████████▉ | 8/9 [00:00<00:00, 10.92it/s]

18386.576171875
10438.658203125
11919.064453125


10it [00:00, 11.16it/s]                      


12761.94921875


10it [00:00, 11.85it/s]                      


Train Loss: 12.8368
Val Loss: 12.6127
Epoch 12 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.19it/s]

16890.322265625
11635.0830078125
8100.8515625


 44%|████▍     | 4/9 [00:00<00:00, 10.98it/s]

17326.865234375
10473.9189453125


 67%|██████▋   | 6/9 [00:00<00:00, 10.95it/s]

9114.7861328125


 89%|████████▉ | 8/9 [00:00<00:00, 10.94it/s]

11139.5556640625
18670.0390625
14513.9951171875


10it [00:00, 11.13it/s]                      


7113.986328125


10it [00:00, 11.81it/s]                      


Train Loss: 12.4979
Val Loss: 12.3304
Epoch 13 of 50


 22%|██▏       | 2/9 [00:00<00:00, 10.87it/s]

9093.951171875
10908.2509765625
17457.5625


 44%|████▍     | 4/9 [00:00<00:00, 10.83it/s]

8879.703125
10287.89453125


 67%|██████▋   | 6/9 [00:00<00:00, 10.80it/s]

12857.3623046875


 89%|████████▉ | 8/9 [00:00<00:00, 10.74it/s]

22704.0546875
11720.8125
9215.9921875


10it [00:00, 10.96it/s]                      


8694.787109375


10it [00:00, 11.76it/s]                      


Train Loss: 12.1820
Val Loss: 12.0465
Epoch 14 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.06it/s]

10672.0263671875
13883.6953125
11852.38671875


 44%|████▍     | 4/9 [00:00<00:00, 10.82it/s]

7617.46142578125
9033.244140625


 67%|██████▋   | 6/9 [00:00<00:00, 10.81it/s]

8619.1865234375


 89%|████████▉ | 8/9 [00:00<00:00, 10.86it/s]

9607.3115234375
13211.1171875
19924.013671875


10it [00:00, 11.06it/s]                      


14947.404296875


10it [00:00, 11.78it/s]                      


Train Loss: 11.9368
Val Loss: 11.8346
Epoch 15 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.19it/s]

14000.71875
17004.462890625
10412.6953125


 44%|████▍     | 4/9 [00:00<00:00, 10.98it/s]

9399.771484375
10762.384765625


 67%|██████▋   | 6/9 [00:00<00:00, 11.14it/s]

16507.064453125


 89%|████████▉ | 8/9 [00:00<00:00, 11.19it/s]

15896.4091796875
8827.2431640625
8711.9599609375


10it [00:00, 11.29it/s]                      


6383.5625


10it [00:00, 11.83it/s]                      


Train Loss: 11.7906
Val Loss: 11.6953
Epoch 16 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.21it/s]

9660.8583984375
11964.708984375
15675.90234375


 44%|████▍     | 4/9 [00:00<00:00, 11.00it/s]

10470.84765625
10757.15625


 67%|██████▋   | 6/9 [00:00<00:00, 11.09it/s]

10204.7587890625


 89%|████████▉ | 8/9 [00:00<00:00, 11.43it/s]

7777.43701171875
17590.033203125
10349.9462890625


10it [00:00, 11.61it/s]                      


11709.3779296875


10it [00:00, 11.86it/s]                      


Train Loss: 11.6161
Val Loss: 11.6088
Epoch 17 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.08it/s]

11489.330078125
11488.5478515625
12519.3076171875


 44%|████▍     | 4/9 [00:00<00:00, 10.85it/s]

14147.076171875
17119.154296875


 67%|██████▋   | 6/9 [00:00<00:00, 10.84it/s]

12307.3544921875


 89%|████████▉ | 8/9 [00:00<00:00, 10.87it/s]

8483.2744140625
8739.3662109375
8507.9921875


10it [00:00, 11.07it/s]                      


10207.08203125


10it [00:00, 10.19it/s]                      


Train Loss: 11.5008
Val Loss: 11.5435
Epoch 18 of 50


 22%|██▏       | 2/9 [00:00<00:00, 10.95it/s]

9162.904296875
15778.421875
8814.0869140625


 44%|████▍     | 4/9 [00:00<00:00, 10.78it/s]

17765.44140625
9737.8896484375


 67%|██████▋   | 6/9 [00:00<00:00, 10.86it/s]

10351.1142578125


 89%|████████▉ | 8/9 [00:00<00:00, 10.90it/s]

12528.736328125
8924.0185546875
11902.4921875


10it [00:00, 11.08it/s]                      


8618.4560546875


10it [00:00, 11.87it/s]                      


Train Loss: 11.3584
Val Loss: 11.4009
Epoch 19 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.16it/s]

18329.021484375
9016.50390625
13951.470703125


 44%|████▍     | 4/9 [00:00<00:00, 10.94it/s]

9874.265625
9217.6328125


 67%|██████▋   | 6/9 [00:00<00:00, 10.94it/s]

7576.4873046875


 89%|████████▉ | 8/9 [00:00<00:00, 10.92it/s]

9070.6318359375
10441.24609375
8413.6552734375


10it [00:00, 11.13it/s]                      


16906.388671875


10it [00:00, 11.77it/s]                      


Train Loss: 11.2797
Val Loss: 11.3086
Epoch 20 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.18it/s]

19908.57421875
10727.7705078125
13115.3935546875


 44%|████▍     | 4/9 [00:00<00:00, 11.01it/s]

8605.5673828125
8027.361328125


 67%|██████▋   | 6/9 [00:00<00:00, 10.97it/s]

17815.13671875


 89%|████████▉ | 8/9 [00:00<00:00, 10.88it/s]

9059.5048828125
8519.712890625
9807.791015625


10it [00:00, 11.13it/s]                      


6343.28662109375


10it [00:00, 11.86it/s]                      


Train Loss: 11.1930
Val Loss: 11.3042
Epoch 21 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.20it/s]

10640.962890625
9844.5478515625
15573.20703125


 44%|████▍     | 4/9 [00:00<00:00, 10.97it/s]

10528.7734375
9737.62109375


 67%|██████▋   | 6/9 [00:00<00:00, 10.96it/s]

8238.802734375


 89%|████████▉ | 8/9 [00:00<00:00, 10.95it/s]

11697.080078125
8994.1748046875
10371.775390625


10it [00:00, 11.12it/s]                      


16215.7001953125


10it [00:00, 11.76it/s]                      


Train Loss: 11.1843
Val Loss: 11.1804
Epoch 22 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.06it/s]

9792.408203125
8815.2841796875
9215.978515625


 44%|████▍     | 4/9 [00:00<00:00, 10.88it/s]

8000.916015625
9586.1123046875


 67%|██████▋   | 6/9 [00:00<00:00, 10.80it/s]

18067.82421875


 89%|████████▉ | 8/9 [00:00<00:00, 10.80it/s]

8896.125
14018.705078125
15534.55859375


10it [00:00, 11.03it/s]                      


7129.1884765625


10it [00:00, 11.73it/s]                      


Train Loss: 10.9057
Val Loss: 11.1003
Epoch 23 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.10it/s]

8831.0302734375
11117.4609375
9297.091796875


 44%|████▍     | 4/9 [00:00<00:00, 10.81it/s]

7858.57275390625
13918.646484375


 67%|██████▋   | 6/9 [00:00<00:00, 10.87it/s]

7383.26171875


 89%|████████▉ | 8/9 [00:00<00:00, 10.90it/s]

9253.791015625
9032.658203125
20895.283203125


10it [00:00, 11.09it/s]                      


10826.22265625


10it [00:00, 11.77it/s]                      


Train Loss: 10.8414
Val Loss: 10.9850
Epoch 24 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.22it/s]

10250.337890625
11427.2431640625
25933.591796875


 44%|████▍     | 4/9 [00:00<00:00, 10.92it/s]

8603.2099609375
8280.880859375


 67%|██████▋   | 6/9 [00:00<00:00, 10.93it/s]

9440.869140625


 89%|████████▉ | 8/9 [00:00<00:00, 10.92it/s]

9500.359375
8445.2880859375
8790.9716796875


10it [00:00, 11.12it/s]                      


6330.4189453125


10it [00:00, 11.71it/s]                      


Train Loss: 10.7003
Val Loss: 10.8713
Epoch 25 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.21it/s]

10153.4091796875
9361.8583984375
8139.4853515625


 44%|████▍     | 4/9 [00:00<00:00, 10.84it/s]

12545.4638671875
15047.30859375


 67%|██████▋   | 6/9 [00:00<00:00, 10.81it/s]

7453.81201171875


 89%|████████▉ | 8/9 [00:00<00:00, 10.89it/s]

15105.6650390625
9849.7158203125
11308.7841796875


10it [00:00, 11.07it/s]                      


6835.0390625


10it [00:00, 11.79it/s]                      


Train Loss: 10.5801
Val Loss: 10.8342
Epoch 26 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.00it/s]

14651.609375
8187.97412109375
10811.4619140625


 44%|████▍     | 4/9 [00:00<00:00, 10.83it/s]

8566.712890625
8461.6533203125


 67%|██████▋   | 6/9 [00:00<00:00, 10.87it/s]

10497.8974609375


 89%|████████▉ | 8/9 [00:00<00:00, 10.87it/s]

10207.11328125
10371.51953125
13436.2578125


10it [00:00, 11.05it/s]                      


9584.2001953125


10it [00:00, 11.71it/s]                      


Train Loss: 10.4776
Val Loss: 10.7279
Epoch 27 of 50


 22%|██▏       | 2/9 [00:00<00:00, 10.99it/s]

8629.3203125
11562.5634765625
10372.6181640625


 44%|████▍     | 4/9 [00:00<00:00, 10.86it/s]

10175.099609375
8342.7177734375


 67%|██████▋   | 6/9 [00:00<00:00, 10.87it/s]

14947.2109375


 89%|████████▉ | 8/9 [00:00<00:00, 10.85it/s]

8549.720703125
8316.259765625
10090.94921875


10it [00:00, 11.02it/s]                      


12284.333984375


10it [00:00, 11.78it/s]                      


Train Loss: 10.3271
Val Loss: 10.5596
Epoch 28 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.14it/s]

7936.1201171875
8401.9853515625
16362.12109375


 44%|████▍     | 4/9 [00:00<00:00, 10.95it/s]

8594.3095703125
7889.1953125


 67%|██████▋   | 6/9 [00:00<00:00, 10.85it/s]

14360.607421875


 89%|████████▉ | 8/9 [00:00<00:00, 10.89it/s]

10214.78125
10266.666015625
11412.525390625


10it [00:00, 11.07it/s]                      


6894.7080078125


10it [00:00, 11.78it/s]                      


Train Loss: 10.2333
Val Loss: 10.3957
Epoch 29 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.15it/s]

11380.7294921875
10070.541015625
8563.8349609375


 44%|████▍     | 4/9 [00:00<00:00, 10.94it/s]

12350.6875
8137.68359375


 67%|██████▋   | 6/9 [00:00<00:00, 10.95it/s]

8060.533203125


 89%|████████▉ | 8/9 [00:00<00:00, 10.92it/s]

8296.7265625
13934.9453125
11365.859375


10it [00:00, 11.12it/s]                      


9048.7060546875


10it [00:00, 11.80it/s]                      


Train Loss: 10.1210
Val Loss: 10.2589
Epoch 30 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.06it/s]

10777.7236328125
7783.63427734375
12363.6669921875


 44%|████▍     | 4/9 [00:00<00:00, 10.81it/s]

8791.912109375
7425.0380859375


 67%|██████▋   | 6/9 [00:00<00:00, 10.85it/s]

8433.408203125


 89%|████████▉ | 8/9 [00:00<00:00, 10.88it/s]

11653.0546875
12775.361328125
12372.2763671875


10it [00:00, 11.08it/s]                      


7352.42724609375


10it [00:00, 11.80it/s]                      


Train Loss: 9.9729
Val Loss: 10.1480
Epoch 31 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.08it/s]

8468.75390625
9354.6953125
8901.779296875


 44%|████▍     | 4/9 [00:00<00:00, 10.93it/s]

8262.4052734375
14353.4677734375


 67%|██████▋   | 6/9 [00:00<00:00, 10.94it/s]

12054.5419921875


 89%|████████▉ | 8/9 [00:00<00:00, 10.96it/s]

10860.3046875
10979.1513671875
8823.078125


10it [00:00, 11.14it/s]                      


5933.9775390625


10it [00:00, 11.76it/s]                      


Train Loss: 9.7992
Val Loss: 10.0175
Epoch 32 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.21it/s]

7782.455078125
11950.05078125
10589.2529296875


 44%|████▍     | 4/9 [00:00<00:00, 10.98it/s]

8459.125
8963.2578125


 67%|██████▋   | 6/9 [00:00<00:00, 10.91it/s]

7887.21435546875


 89%|████████▉ | 8/9 [00:00<00:00, 10.90it/s]

7029.86474609375
11364.083984375
11022.5029296875


10it [00:00, 11.10it/s]                      


12120.9052734375


10it [00:00, 11.79it/s]                      


Train Loss: 9.7169
Val Loss: 9.8834
Epoch 33 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.17it/s]

14054.11328125
9865.427734375
7055.5244140625


 44%|████▍     | 4/9 [00:00<00:00, 10.93it/s]

9568.748046875
10795.7099609375


 67%|██████▋   | 6/9 [00:00<00:00, 10.92it/s]

8438.044921875


 89%|████████▉ | 8/9 [00:00<00:00, 10.91it/s]

7730.189453125
8442.9375
11687.607421875


10it [00:00, 11.10it/s]                      


8243.0830078125


10it [00:00, 11.87it/s]                      


Train Loss: 9.5881
Val Loss: 9.8376
Epoch 34 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.06it/s]

10469.283203125
9875.3984375
7935.236328125


 44%|████▍     | 4/9 [00:00<00:00, 10.91it/s]

12470.728515625
16014.267578125


 67%|██████▋   | 6/9 [00:00<00:00, 10.93it/s]

8555.505859375


 89%|████████▉ | 8/9 [00:00<00:00, 10.97it/s]

7967.892578125
8484.16796875
7373.86279296875


10it [00:00, 11.14it/s]                      


5855.6396484375


10it [00:00, 11.79it/s]                      


Train Loss: 9.5002
Val Loss: 9.7098
Epoch 35 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.08it/s]

11517.515625
8693.9296875
9091.05859375


 44%|████▍     | 4/9 [00:00<00:00, 10.94it/s]

7822.919921875
7629.79248046875


 67%|██████▋   | 6/9 [00:00<00:00, 10.89it/s]

8727.7919921875


 89%|████████▉ | 8/9 [00:00<00:00, 10.86it/s]

14439.220703125
10454.24609375
9973.958984375


10it [00:00, 11.07it/s]                      


5389.1943359375


10it [00:00, 11.74it/s]                      


Train Loss: 9.3740
Val Loss: 9.7090
Epoch 36 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.15it/s]

7055.201171875
11126.80078125
7918.51171875


 44%|████▍     | 4/9 [00:00<00:00, 10.94it/s]

7568.50927734375
10199.494140625


 67%|██████▋   | 6/9 [00:00<00:00, 10.92it/s]

10404.013671875


 89%|████████▉ | 8/9 [00:00<00:00, 10.92it/s]

8901.1474609375
9440.185546875
15646.0771484375


10it [00:00, 11.15it/s]                      


5156.12158203125


10it [00:00, 11.75it/s]                      


Train Loss: 9.3416
Val Loss: 9.5785
Epoch 37 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.03it/s]

7371.1533203125
9162.32421875
10942.8515625


 44%|████▍     | 4/9 [00:00<00:00, 10.77it/s]

7920.50927734375
8268.3984375


 67%|██████▋   | 6/9 [00:00<00:00, 10.82it/s]

9286.298828125


 89%|████████▉ | 8/9 [00:00<00:00, 10.88it/s]

12592.404296875
9485.5048828125
9031.23828125


10it [00:00, 11.06it/s]                      


7894.8828125


10it [00:00, 11.77it/s]                      


Train Loss: 9.1956
Val Loss: 9.5370
Epoch 38 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.20it/s]

7782.92236328125
13276.23046875
7097.9453125


 44%|████▍     | 4/9 [00:00<00:00, 11.00it/s]

8771.76171875
7444.0634765625


 67%|██████▋   | 6/9 [00:00<00:00, 10.99it/s]

12605.630859375


 89%|████████▉ | 8/9 [00:00<00:00, 10.97it/s]

8478.783203125
8284.7529296875
9520.4228515625


10it [00:00, 11.11it/s]                      


7957.33154296875


10it [00:00, 11.73it/s]                      


Train Loss: 9.1220
Val Loss: 9.4021
Epoch 39 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.13it/s]

7669.35009765625
8400.0546875
10218.482421875


 44%|████▍     | 4/9 [00:00<00:00, 10.97it/s]

7179.11669921875
9872.53125


 67%|██████▋   | 6/9 [00:00<00:00, 10.94it/s]

7994.916015625


 89%|████████▉ | 8/9 [00:00<00:00, 10.93it/s]

7933.017578125
7281.03125
18449.296875


10it [00:00, 11.13it/s]                      


5720.39892578125


10it [00:00, 11.80it/s]                      


Train Loss: 9.0718
Val Loss: 9.3907
Epoch 40 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.21it/s]

8535.07421875
9883.4384765625
7718.65771484375


 44%|████▍     | 4/9 [00:00<00:00, 10.94it/s]

8205.16796875
7494.6826171875


 67%|██████▋   | 6/9 [00:00<00:00, 10.94it/s]

8319.2724609375


 89%|████████▉ | 8/9 [00:00<00:00, 10.94it/s]

6717.18115234375
9185.984375
12252.365234375


10it [00:00, 11.15it/s]                      


11604.873046875


10it [00:00, 11.92it/s]                      


Train Loss: 8.9917
Val Loss: 9.3282
Epoch 41 of 50


 22%|██▏       | 2/9 [00:00<00:00, 12.35it/s]

7323.7841796875
9231.01953125
12490.72265625


 44%|████▍     | 4/9 [00:00<00:00, 11.93it/s]

9453.7822265625
9754.490234375


 67%|██████▋   | 6/9 [00:00<00:00, 11.37it/s]

7863.8486328125


 89%|████████▉ | 8/9 [00:00<00:00, 11.21it/s]

7467.87646484375
12162.630859375
8310.18359375


10it [00:00, 11.47it/s]                      


5338.7998046875


10it [00:00, 11.76it/s]                      


Train Loss: 8.9397
Val Loss: 9.2838
Epoch 42 of 50


 22%|██▏       | 2/9 [00:00<00:00, 10.88it/s]

7844.6943359375
13126.517578125


 44%|████▍     | 4/9 [00:00<00:00,  6.96it/s]

6405.984375
8507.5546875


 67%|██████▋   | 6/9 [00:00<00:00,  6.95it/s]

13194.9697265625
7738.9443359375


 89%|████████▉ | 8/9 [00:01<00:00,  8.34it/s]

7590.0029296875
9789.630859375
8022.04248046875


10it [00:01,  8.38it/s]                      


6354.236328125


10it [00:00, 11.83it/s]                      


Train Loss: 8.8575
Val Loss: 9.2371
Epoch 43 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.30it/s]

7304.2548828125
8275.05078125
8124.2607421875


 44%|████▍     | 4/9 [00:00<00:00, 11.06it/s]

10002.205078125
7247.82666015625


 67%|██████▋   | 6/9 [00:00<00:00, 11.00it/s]

10941.1142578125


 89%|████████▉ | 8/9 [00:00<00:00, 10.97it/s]

9001.927734375
7113.609375
8481.923828125


10it [00:00, 11.18it/s]                      


11270.955078125


10it [00:00, 11.73it/s]                      


Train Loss: 8.7763
Val Loss: 9.1740
Epoch 44 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.12it/s]

11576.97265625
7439.228515625
6882.28759765625


 44%|████▍     | 4/9 [00:00<00:00, 10.94it/s]

8236.41796875
7627.8359375


 67%|██████▋   | 6/9 [00:00<00:00, 10.90it/s]

6960.546875


 89%|████████▉ | 8/9 [00:00<00:00, 10.89it/s]

7499.0068359375
10854.1376953125
7431.8828125


10it [00:00, 11.11it/s]                      


12563.03125


10it [00:00, 11.83it/s]                      


Train Loss: 8.7071
Val Loss: 9.2588
Epoch 45 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.16it/s]

10075.568359375
6897.56884765625
7211.8359375


 44%|████▍     | 4/9 [00:00<00:00, 10.89it/s]

7236.53271484375
10202.4970703125


 67%|██████▋   | 6/9 [00:00<00:00, 10.94it/s]

9529.806640625


 89%|████████▉ | 8/9 [00:00<00:00, 10.94it/s]

8806.1611328125
7306.18798828125
7992.0322265625


10it [00:00, 11.15it/s]                      


12422.9892578125


10it [00:00, 12.15it/s]                      


Train Loss: 8.7681
Val Loss: 9.1752
Epoch 46 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.28it/s]

7614.939453125
9605.693359375
8640.806640625


 44%|████▍     | 4/9 [00:00<00:00, 11.04it/s]

9099.7177734375
7595.041015625


 67%|██████▋   | 6/9 [00:00<00:00, 10.97it/s]

8095.01025390625


 89%|████████▉ | 8/9 [00:00<00:00, 10.96it/s]

11089.720703125
8635.455078125
11396.521484375


10it [00:00, 11.19it/s]                      


5863.783203125


10it [00:00, 11.98it/s]                      


Train Loss: 8.7637
Val Loss: 9.3836
Epoch 47 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.19it/s]

8049.6884765625
14057.3388671875
7776.8076171875


 44%|████▍     | 4/9 [00:00<00:00, 10.97it/s]

7496.017578125
7748.763671875


 67%|██████▋   | 6/9 [00:00<00:00, 10.95it/s]

11391.0625


 89%|████████▉ | 8/9 [00:00<00:00, 10.97it/s]

9586.6337890625
10170.234375
6902.05615234375


10it [00:00, 11.24it/s]                      


6176.373046875


10it [00:00, 12.02it/s]                      


Train Loss: 8.9355
Val Loss: 9.1813
Epoch 48 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.16it/s]

10079.8291015625
8896.810546875
7638.4541015625


 44%|████▍     | 4/9 [00:00<00:00, 10.95it/s]

9767.625
8166.42529296875


 67%|██████▋   | 6/9 [00:00<00:00, 10.97it/s]

7281.88037109375


 89%|████████▉ | 8/9 [00:00<00:00, 11.03it/s]

8862.4296875
7027.02490234375
8700.365234375


10it [00:00, 11.18it/s]                      


10513.6123046875


10it [00:00, 11.77it/s]                      


Train Loss: 8.6934
Val Loss: 9.0166
Epoch 49 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.17it/s]

12116.44921875
7264.4599609375
7897.9609375


 44%|████▍     | 4/9 [00:00<00:00, 10.94it/s]

6831.83642578125
8816.1435546875


 67%|██████▋   | 6/9 [00:00<00:00, 10.92it/s]

8061.59716796875


 89%|████████▉ | 8/9 [00:00<00:00, 10.93it/s]

8207.68359375
8310.576171875
10623.615234375


10it [00:00, 11.14it/s]                      


7341.025390625


10it [00:00, 11.81it/s]                      


Train Loss: 8.5471
Val Loss: 8.9488
Epoch 50 of 50


 22%|██▏       | 2/9 [00:00<00:00, 11.06it/s]

10888.0810546875
8902.599609375
7678.1162109375


 44%|████▍     | 4/9 [00:00<00:00, 10.84it/s]

9123.388671875
7667.69384765625


 67%|██████▋   | 6/9 [00:00<00:00, 10.83it/s]

7572.30712890625


 89%|████████▉ | 8/9 [00:00<00:00, 10.78it/s]

8539.0078125
8291.490234375
10765.369140625


10it [00:00, 10.99it/s]                      


4798.15234375


10it [00:00, 11.76it/s]                      

Train Loss: 8.4226
Val Loss: 8.8716





In [None]:
model.to('cpu')
AE_output2,_,_ = model(data_good)
AE_output2 = AE_output2.detach().numpy()
df_ae_output2 = pd.DataFrame(AE_output2, columns = train_keys)
df_ae_output2['AE'] = 'Outputs'

df_test_tmp = df_test_good
df_test_tmp['AE'] = 'Inputs'

df_compare2 = pd.concat([df_ae_output2, df_test_tmp], ignore_index=True)
df_compare_sample2= df_compare2.sample(10_000)

In [None]:
for i in train_keys:
    plt.figure()
    x_min = df_compare_sample2[i].quantile(0.1)
    x_max = df_compare_sample2[i].quantile(0.9)
    sns.histplot(data = df_compare_sample2, x = i, hue = "AE", binrange = (x_min, x_max), bins = 100);
    plt.xlabel(i)