In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import glob

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import Dataset, DataLoader

from mpl_toolkits.mplot3d import Axes3D
%matplotlib notebook

In [2]:
from scripts.AutoEncoder import AutoEncoder, AutoEncoderDataset, Encoder, Decoder
from scripts.utils import train_keys, target_keys, ScaleData

In [3]:
train = "/share/rcifdata/jbarr/UKAEAGroupProject/data/train_data_clipped.pkl"
df_train = pd.read_pickle(train)
df_train = df_train[train_keys]
df_train, scaler = ScaleData(df_train)

df_train.describe()

Unnamed: 0,ane,ate,autor,machtor,x,zeff,gammae,q,smag,alpha,ani1,ati0,normni1,ti_te0,lognustar
count,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0,26715960.0
mean,-7.554751000000001e-17,2.325436e-16,-1.969976e-14,-2.478637e-14,-1.175991e-15,-7.503728e-15,4.034983e-14,-3.315205e-15,-1.910605e-15,4.421118e-16,-5.968948e-16,2.790048e-16,2.490396e-15,1.438155e-14,1.09086e-16
std,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
min,-224.9306,-381.4446,-74.574,-2.143341,-1.921246,-1.297033,-171.0608,-1.27779,-84.74415,-26.7842,-55.09811,-210.9354,-1.799047,-3.488317,-3.060594
25%,-0.03258431,-0.2497779,-0.2199693,-0.5673091,-0.7880002,-0.7826355,0.0001095864,-0.7187233,-0.5207354,-0.314272,-0.1510736,-0.1056034,-0.2791258,-0.1170805,-0.7257788
50%,-0.01714946,-0.1156905,-0.2199693,-0.5673091,0.04251098,-0.1509808,0.0001095864,-0.2630268,-0.276017,-0.2158643,-0.09455876,-0.04192553,-0.2264989,-0.08693051,-0.1606022
75%,0.005464286,0.06980049,0.101768,0.3190794,0.8328257,0.5075416,0.0001095864,0.4315107,0.224457,0.06491163,-0.01128275,0.03915007,-0.06719322,-0.08693051,0.5922162
max,224.9695,358.5519,52.36406,5.67917,1.572247,15.05446,171.0611,24.17821,285.1771,144.0235,61.17973,210.6058,20.36316,17.87184,8.10916


In [16]:
test = "/share/rcifdata/jbarr/UKAEAGroupProject/data/test_data_clipped.pkl"


df_test = pd.read_pickle(test)
target = df_test['target']
df_test_good = df_test[df_test.target == 1]
df_test_good = df_test[train_keys]

df_test_good,_ = ScaleData(df_test_good, scaler)

df_test_all = df_test[train_keys]
df_test_all,_ = ScaleData(df_test_all, scaler)

n = 10_000
df_test_good.describe()

Unnamed: 0,ane,ate,autor,machtor,x,zeff,gammae,q,smag,alpha,ani1,ati0,normni1,ti_te0,lognustar
count,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0
mean,0.0001480916,-0.0009981783,-0.0004765438,0.0008691511,-0.0005182396,-0.0003884115,0.0003116066,-0.0004423686,-6.358724e-06,-0.0001101691,0.001331482,0.0001694135,-0.0003797497,-0.000478921,-0.0005054504
std,0.8061888,1.068407,0.9926565,1.000534,0.9996224,0.9995123,0.887109,1.000412,1.016921,0.9927253,1.009475,1.043631,0.9952746,0.998619,1.000274
min,-224.9306,-381.4446,-74.574,-2.143341,-1.921246,-1.297033,-171.0608,-1.27779,-84.74415,-26.7842,-55.09811,-210.9354,-1.799047,-3.488317,-3.060594
25%,-0.03255359,-0.2498094,-0.2199693,-0.5673091,-0.7880403,-0.7840538,0.0001095864,-0.7192161,-0.5208166,-0.3143393,-0.1510736,-0.1056744,-0.2783585,-0.1182357,-0.7262394
50%,-0.01711142,-0.1158031,-0.2199693,-0.5673091,0.04248436,-0.1518029,0.0001095864,-0.2634447,-0.2759386,-0.2160454,-0.09443357,-0.04194699,-0.2264526,-0.08693051,-0.1609747
75%,0.005470228,0.06947443,0.1019715,0.3217102,0.8318275,0.5075416,0.0001095864,0.4306439,0.2235747,0.06402999,-0.01111474,0.03896573,-0.06671414,-0.08693051,0.5912568
max,224.9695,358.5519,52.36406,5.67917,1.572247,15.05446,171.0611,24.17821,285.1771,144.0235,61.17973,210.6058,20.36316,17.87184,8.10916


## Model 1 - AE trained on inputs that give outputs

In [6]:
path = glob.glob("/share/rcifdata/jbarr/UKAEAGroupProject/logs/AutoEncoder/Run-13/*")[0]

model = AutoEncoder.load_from_checkpoint(path, n_input = 15, batch_size = 2048, epochs = 150, learning_rate = 0.0025)
encoder = model.encoder

### Evaluate on inputs that give outputs

In [7]:
data_good = torch.from_numpy(df_test_good.values).float()
outputs_good = encoder.forward(data_good).detach().numpy()

In [8]:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(outputs_good[:n,0], outputs_good[:n,1],outputs_good[:n,2])
fig.show()

<IPython.core.display.Javascript object>

In [12]:
plt.figure()

sc = plt.scatter(outputs_good[:n,0], outputs_good[:n,1], c = outputs_good[:n,2])
plt.xlim(-20,20)
plt.ylim(-50,0)
plt.colorbar(sc)
plt.show()

<IPython.core.display.Javascript object>

### Plot  input and output distributions

In [13]:
AE_output = model.forward(data_good).detach().numpy()
df_ae_output = pd.DataFrame(AE_output, columns = train_keys)
df_ae_output['AE'] = 'Outputs'

df_test_tmp = df_test_good
df_test_tmp['AE'] = 'Inputs'

In [14]:
df_compare = pd.concat([df_ae_output, df_test_tmp], ignore_index=True)
df_compare_sample = df_compare.sample(n)

In [15]:
for i in train_keys:
    plt.figure()
    x_min = df_compare_sample[i].quantile(0.1)
    x_max = df_compare_sample[i].quantile(0.9)
    sns.histplot(data = df_compare_sample, x = i, hue = "AE", binrange = (x_min, x_max), bins = 100);
    plt.xlabel(i)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Evaluate on all inputs

In [17]:
data_test_full = torch.from_numpy(df_test_all.values).float()
outputs_test_full = encoder.forward(data_test_full).detach().numpy()

In [18]:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1],outputs_test_full[:n,2])
fig.show()

<IPython.core.display.Javascript object>

In [21]:
plt.figure()
#sc = plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c = outputs_test_full[:n,2])
#plt.colorbar(sc)
plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c = target[:n])
plt.colorbar()

plt.show()

<IPython.core.display.Javascript object>

In [18]:
plt.figure()
#sc = plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c = outputs_test_full[:n,2])
#plt.colorbar(sc)
plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c =target[:n])
plt.xlim(-25,1)
plt.ylim(-25,10)
plt.show()

<IPython.core.display.Javascript object>

In [23]:
outputs_test_full[:n,1].min()

-24566556.0

In [27]:
plt.figure()
#sc = plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c = outputs_test_full[:n,2])
#plt.colorbar(sc)
plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c =target[:n])
plt.show()

<IPython.core.display.Javascript object>

In [42]:
target.shape

(3339495,)

In [43]:
outputs_test_full.shape

(3339495, 3)

In [40]:
idx_left = np.where(outputs_test_full[:,0] < -1.5e9)
idx_right = np.where(outputs_test_full[:,0] > -1.5e9)

data_left = outputs_test_full[idx_left[0]]
targets_left = target[idx_left[0]]
data_right = outputs_test_full[idx_right[0]]
targets_right = target[idx_right[0]]

KeyError: '[4869, 22044, 44585, 226239, 345802, 373775, 405272, 428239, 429597, 434337, 523114, 544315, 596171, 600930, 713854, 782902, 949919, 1056473, 1251270, 1300350, 1388498, 1437497, 1673729, 1730665, 1758429, 1804495, 1825842, 1913540, 2039310, 2150569, 2151426, 2426257, 2478516, 2555026, 2637977, 2654335, 2720329, 2801054, 2803980, 2885797, 2895495, 2930903, 2981620, 3012138, 3015915, 3018684, 3065757, 3113319, 3215403, 3241123] not in index'

In [39]:
idx_left[0]

array([   4869,   22044,   44585,  129819,  135325,  226239,  345802,
        373775,  405272,  428239,  429597,  434337,  523114,  544315,
        596171,  600930,  713854,  782902,  867447,  949919, 1056473,
       1251270, 1300350, 1388498, 1437497, 1673729, 1730665, 1758429,
       1804495, 1825842, 1913540, 1923978, 2039310, 2150569, 2151426,
       2249049, 2426257, 2478516, 2555026, 2637977, 2650468, 2654335,
       2720329, 2801054, 2803980, 2820356, 2885797, 2895495, 2930903,
       2981620, 3012138, 3015915, 3018684, 3065757, 3113319, 3214974,
       3215403, 3241123, 3268321])

In [36]:
outputs_test_full[[0,1,2]]

array([[-1.3052657 ,  0.02925883,  0.96055824],
       [-2.1639023 ,  0.60354245,  1.4047456 ],
       [-1.3148432 ,  0.6713982 ,  1.0025367 ]], dtype=float32)