In [1]:
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.distributions as dist
import torch
import copy
import matplotlib.pyplot as plt
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback

from torch.utils.data import Dataset, DataLoader
from scripts.utils import ScaleData, train_keys
from scripts.AutoEncoder import Encoder, Decoder, AutoEncoderDataset

import itertools
import seaborn as sns
%matplotlib notebook

In [2]:
device = torch.device("cpu")

In [3]:
def ELBO(encoder, decoder, X):
        # calculate the ELBO loss
        q_z_given_x = encoder.forward(X)

        q_samples = q_z_given_x.rsample()

        ones = torch.ones(2)
        zeros = torch.zeros(2)
        
#         if torch.cuda.is_available():
#             ones.cuda()
#             zeros.cuda()
#             q_samples.cuda()
        
        latent_prior = dist.Normal(zeros, ones)
             
        log_p_z = latent_prior.log_prob(q_samples).sum(-1)

        log_q_z_given_x = q_z_given_x.log_prob(q_samples).sum(-1)

        log_p_x_given_z = decoder.forward(q_samples).log_prob(X).sum(dim=1)
        
        ELBO = log_p_x_given_z + log_p_z - log_q_z_given_x

        return ELBO

In [4]:
encoder = Encoder(2, 15, VAE = True)
decoder = Decoder(2, 15, VAE = True) 

In [5]:
train_data_path = "/share/rcifdata/jbarr/UKAEAGroupProject/data/train_data_clipped.pkl"
train_data = AutoEncoderDataset(train_data_path, columns = train_keys, train = True)
train_data.data = train_data.data.sample(100_000)
train_data.scale()

train_loader = DataLoader(train_data, shuffle = True, batch_size = 2048)

valid_data_path = "/share/rcifdata/jbarr/UKAEAGroupProject/data/valid_data_clipped.pkl"
valid_data = AutoEncoderDataset(valid_data_path, columns = train_keys, train = True)
valid_data.data = valid_data.data.sample(100_000)
valid_data.scale()

valid_loader = DataLoader(valid_data, shuffle = True, batch_size = 2048)

In [6]:
# if torch.cuda.is_available():
#     encoder = encoder.cuda()
#     decoder = decoder.cuda() 

opt_vae = torch.optim.Adam(itertools.chain(encoder.parameters(), decoder.parameters()))
N_epochs = 50 # Note that you may want to run more than 10 epochs!
for epoch in range(N_epochs):
    train_loss = 0.0
    for X in train_loader:
#         if torch.cuda.is_available():
#             X = X.cuda()

        opt_vae.zero_grad()
        loss = -ELBO(encoder, decoder, X).mean()
        loss.backward()
        opt_vae.step()
        train_loss += loss.item() * X.shape[0] / len(train_data)
    print("Epoch %d, train loss = %0.4f" % (epoch, train_loss));

Epoch 0, train loss = 21.6942
Epoch 1, train loss = 21.4322
Epoch 2, train loss = 21.3592
Epoch 3, train loss = 21.3411
Epoch 4, train loss = 21.2885
Epoch 5, train loss = 21.2405
Epoch 6, train loss = 21.1282
Epoch 7, train loss = 20.6620
Epoch 8, train loss = 18.5174
Epoch 9, train loss = 15.5022
Epoch 10, train loss = 12.4470
Epoch 11, train loss = 8.5672
Epoch 12, train loss = 6.8192
Epoch 13, train loss = 6.4929
Epoch 14, train loss = 5.4888
Epoch 15, train loss = 7.1421
Epoch 16, train loss = 5.6924
Epoch 17, train loss = 5.0145
Epoch 18, train loss = 4.3050
Epoch 19, train loss = 6.1908
Epoch 20, train loss = 4.7091
Epoch 21, train loss = 4.0397
Epoch 22, train loss = 4.1267
Epoch 23, train loss = 4.1862
Epoch 24, train loss = 3.3924
Epoch 25, train loss = 6.8176
Epoch 26, train loss = 4.7054
Epoch 27, train loss = 4.0667
Epoch 28, train loss = 3.6480
Epoch 29, train loss = 3.2045
Epoch 30, train loss = 2.8345
Epoch 31, train loss = 2.5108
Epoch 32, train loss = 3.9631
Epoch 33,

In [7]:
X_random_batch = next(iter(valid_loader))
X_random_batch.shape

torch.Size([2048, 15])

In [8]:
with torch.no_grad():    
    out = encoder(X_random_batch).sample()

In [9]:
plt.figure()
plt.scatter(out[:,0], out[:,1])

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7f43834632e0>

In [10]:
test = "/share/rcifdata/jbarr/UKAEAGroupProject/data/test_data_clipped.pkl"

df_test = pd.read_pickle(test)
df_test = df_test.sample(10_000)
target = df_test['target']
df_test_good = df_test[df_test.target == 1]
df_test_good = df_test_good[train_keys]

df_test_good,_ = ScaleData(df_test_good)

df_test_bad = df_test[df_test.target == 0]
df_test_bad = df_test_bad[train_keys]
df_test_bad,_ = ScaleData(df_test_bad)

df_test_good.describe()

Unnamed: 0,ane,ate,autor,machtor,x,zeff,gammae,q,smag,alpha,ani1,ati0,normni1,ti_te0,lognustar
count,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0,6640.0
mean,-6.978448e-17,-8.420724000000001e-17,-5.288942e-16,-2.410053e-16,7.967188e-17,2.205398e-17,-4.332545e-16,-1.855736e-17,-7.795805999999999e-19,-7.069312e-17,-2.2551410000000003e-17,1.9056880000000003e-17,1.174429e-16,8.67228e-16,3.2336920000000005e-17
std,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075,1.000075
min,-6.845754,-6.371173,-17.49049,-1.787121,-1.565196,-1.273361,-20.24214,-1.209308,-2.019406,-2.720614,-14.23615,-6.770256,-1.741873,-3.070346,-2.617869
25%,-0.4186228,-0.5633395,-0.2816705,-0.5171418,-0.9698262,-0.8005172,0.05106996,-0.7424203,-0.6533745,-0.4716698,-0.3341631,-0.533503,-0.2736305,-0.09932006,-0.7266158
50%,-0.2113955,-0.2491672,-0.2816705,-0.5171418,-0.1363426,-0.1562986,0.05106996,-0.28975,-0.3836439,-0.3332791,-0.1906041,-0.2451188,-0.2133404,-0.09932006,-0.133651
75%,0.08544694,0.1962401,-0.0324403,0.1024664,0.9326222,0.5068481,0.05106996,0.4446861,0.2867453,0.08562877,0.03469841,0.2664604,-0.1017711,-0.09932006,0.5908324
max,11.15447,8.584925,17.64793,5.511609,1.666036,12.7125,13.67118,7.202677,6.737304,18.13024,13.33185,12.26486,17.33606,20.50938,4.655364


In [11]:
data_good = torch.from_numpy(df_test_good.values).float()
data_bad = torch.from_numpy(df_test_bad.values).float()
#data_good_batch = next(iter(data_good_loader))
with torch.no_grad():
    outputs_good = encoder.forward(data_good).sample()
    outputs_bad = encoder.forward(data_bad).sample()
#    outputs_good = encoder.forward(data_good).sample().detach().numpy()
    
plt.figure()
plt.scatter(outputs_good[:,0], outputs_good[:,1])
plt.scatter(outputs_bad[:,0], outputs_bad[:,1])

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7f438360ebb0>

In [12]:
AE_output = decoder.forward(encoder.forward(data_good).sample()).sample().detach().numpy()
df_ae_output = pd.DataFrame(AE_output, columns = train_keys)
df_ae_output['AE'] = 'Outputs'

df_test_tmp = df_test_good
df_test_tmp['AE'] = 'Inputs'

In [13]:
df_compare = pd.concat([df_ae_output, df_test_tmp], ignore_index=True)
df_compare_sample = df_compare.sample(10_000)

In [14]:
for i in train_keys:
    plt.figure()
    x_min = df_compare_sample[i].quantile(0.1)
    x_max = df_compare_sample[i].quantile(0.9)
    sns.histplot(data = df_compare_sample, x = i, hue = "AE", binrange = (x_min, x_max), bins = 100);
    plt.xlabel(i)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>