In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import Dataset, DataLoader

from mpl_toolkits.mplot3d import Axes3D
%matplotlib notebook

In [3]:
from scripts.AutoEncoder import AutoEncoder, AutoEncoderDataset, Encoder, Decoder
from scripts.utils import train_keys, target_keys, ScaleData

In [4]:
train = "/share/rcifdata/jbarr/UKAEAGroupProject/data/QLKNN_train_data.pkl"
df_train = pd.read_pickle(train)
df_train = df_train[train_keys]
df_train, scaler = ScaleData(df_train)

df_train.describe()

Unnamed: 0,ane,ate,autor,machtor,x,zeff,gammae,q,smag,alpha,ani1,ati0,normni1,ti_te0,lognustar
count,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0,17672640.0
mean,-4.2445710000000006e-17,-2.204141e-15,-1.845331e-14,9.872596e-16,1.443307e-15,5.8251e-16,-2.163914e-14,1.1166e-15,-1.048156e-16,8.099105e-16,-8.453134e-16,9.617863e-16,1.647093e-15,1.539637e-14,1.514649e-15
std,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
min,-13.84711,-6.585977,-18.84677,-2.065468,-1.685289,-1.270876,-26.0521,-1.226404,-2.207143,-4.272697,-15.42753,-7.61473,-1.81076,-3.966175,-3.129009
25%,-0.4040953,-0.5670599,-0.2922102,-0.5312093,-0.9628142,-0.7966956,0.04273997,-0.7364278,-0.6534275,-0.4755039,-0.3499717,-0.5471105,-0.2686078,-0.1080478,-0.7260516
50%,-0.2005211,-0.261672,-0.2922102,-0.5312093,-0.1396549,-0.1542101,0.04273997,-0.2877295,-0.3921313,-0.3348355,-0.1956296,-0.2503222,-0.2116415,-0.1080478,-0.1348429
75%,0.09590093,0.2096182,-0.02717016,0.1510355,0.9351353,0.5066278,0.04273997,0.4397159,0.2817471,0.08098745,0.04703473,0.2638433,-0.09624531,-0.1080478,0.5857065
max,13.78378,10.87966,29.07871,5.549705,1.679477,15.29337,13.10198,25.23304,6.818529,18.60037,14.47752,12.82571,20.74059,20.26225,6.969653


In [5]:
full_test = "/share/rcifdata/jbarr/UKAEAGroupProject/data/test_data.pkl"
test = "/share/rcifdata/jbarr/UKAEAGroupProject/data/QLKNN_test_data.pkl"

df_full_test = pd.read_pickle(full_test)
target = df_full_test['Target']

df_full_test = df_full_test[train_keys]
df_full_test,_ = ScaleData(df_full_test, scaler)

df_test = pd.read_pickle(test)
df_test = df_test[train_keys]
df_test,_ = ScaleData(df_test, scaler)

n = 10_000
df_test.describe()

Unnamed: 0,ane,ate,autor,machtor,x,zeff,gammae,q,smag,alpha,ani1,ati0,normni1,ti_te0,lognustar
count,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0,2209081.0
mean,0.001294028,0.001000047,0.0009622261,-0.001467884,0.001920083,0.0007137137,0.0008494952,0.001516619,0.001549922,0.0007416861,0.0006172975,0.001285639,-7.268616e-05,0.0009756333,0.0001759007
std,1.000815,1.001248,1.009336,0.9986958,0.9999536,1.00076,0.9926095,1.002432,1.001713,1.001943,0.9989907,1.0017,0.9965351,1.002741,1.001463
min,-13.53957,-6.585977,-18.84677,-2.065468,-1.685289,-1.270876,-26.0521,-1.226404,-2.207143,-4.272697,-15.42753,-7.61473,-1.81076,-3.966175,-3.129009
25%,-0.4033442,-0.5666384,-0.2922102,-0.5312093,-0.9626737,-0.795746,0.04273997,-0.7356344,-0.6529561,-0.4753175,-0.3498381,-0.5466094,-0.2686078,-0.1080478,-0.7270171
50%,-0.1995219,-0.2610066,-0.2922102,-0.5312093,-0.1381496,-0.1535345,0.04273997,-0.2869522,-0.3904461,-0.3346827,-0.1952888,-0.2497795,-0.2116415,-0.1080478,-0.1353043
75%,0.09664207,0.2101919,-0.02768643,0.1465245,0.9354394,0.50761,0.04273997,0.4407709,0.2830114,0.08240846,0.04752986,0.2651249,-0.09624531,-0.1080478,0.5869724
max,13.72443,10.87741,29.07871,5.549705,1.679477,15.29337,13.10198,25.23304,6.818529,18.60037,14.47752,12.82307,20.74059,20.26225,6.969653


## Model 1 - AE trained on inputs that give outputs

In [6]:
path = "/share/rcifdata/jbarr/UKAEAGroupProject/logs/AutoEncoder/Run-8/experiment_name=0-epoch=146-val_loss=0.25.ckpt"

model = AutoEncoder.load_from_checkpoint(path, n_input = 15, batch_size = 2048, epochs = 100, learning_rate = 0.002)
encoder = model.encoder

### Evaluate on inputs that give outputs

In [7]:
data_test = torch.from_numpy(df_test.values).float()
outputs_test = encoder.forward(data_test).detach().numpy()

In [8]:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(outputs_test[:n,0], outputs_test[:n,1],outputs_test[:n,2])
fig.show()

<IPython.core.display.Javascript object>

In [9]:
plt.figure()

sc = plt.scatter(outputs_test[:n,0], outputs_test[:n,1], c = outputs_test[:n,2])
plt.colorbar(sc)
plt.show()

<IPython.core.display.Javascript object>

### Plot  input and output distributions

In [10]:
AE_output = model.forward(data_test).detach().numpy()
df_ae_output = pd.DataFrame(AE_output, columns = train_keys)
df_ae_output['AE'] = 'Outputs'

df_test_tmp = df_test
df_test_tmp['AE'] = 'Inputs'

In [11]:
df_compare = pd.concat([df_ae_output, df_test_tmp], ignore_index=True)
df_compare_sample = df_compare.sample(n)

In [12]:
for i in train_keys:
    plt.figure()
    x_min = df_compare_sample[i].quantile(0.1)
    x_max = df_compare_sample[i].quantile(0.9)
    sns.histplot(data = df_compare_sample, x = i, hue = "AE", binrange = (x_min, x_max), bins = 100);
    plt.xlabel(i)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Evaluate on all inputs

In [13]:
data_test_full = torch.from_numpy(df_full_test.values).float()
outputs_test_full = encoder.forward(data_test_full).detach().numpy()

In [14]:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1],outputs_test_full[:n,2])
fig.show()

<IPython.core.display.Javascript object>

In [25]:
plt.figure()
#sc = plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c = outputs_test_full[:n,2])
#plt.colorbar(sc)
plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c =target[:n])

plt.show()

<IPython.core.display.Javascript object>

In [18]:
plt.figure()
#sc = plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c = outputs_test_full[:n,2])
#plt.colorbar(sc)
plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c =target[:n])
plt.xlim(-25,1)
plt.ylim(-25,10)
plt.show()

<IPython.core.display.Javascript object>

In [23]:
outputs_test_full[:n,1].min()

-24566556.0

In [27]:
plt.figure()
#sc = plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c = outputs_test_full[:n,2])
#plt.colorbar(sc)
plt.scatter(outputs_test_full[:n,0], outputs_test_full[:n,1], c =target[:n])
plt.show()

<IPython.core.display.Javascript object>

In [42]:
target.shape

(3339495,)

In [43]:
outputs_test_full.shape

(3339495, 3)

In [40]:
idx_left = np.where(outputs_test_full[:,0] < -1.5e9)
idx_right = np.where(outputs_test_full[:,0] > -1.5e9)

data_left = outputs_test_full[idx_left[0]]
targets_left = target[idx_left[0]]
data_right = outputs_test_full[idx_right[0]]
targets_right = target[idx_right[0]]

KeyError: '[4869, 22044, 44585, 226239, 345802, 373775, 405272, 428239, 429597, 434337, 523114, 544315, 596171, 600930, 713854, 782902, 949919, 1056473, 1251270, 1300350, 1388498, 1437497, 1673729, 1730665, 1758429, 1804495, 1825842, 1913540, 2039310, 2150569, 2151426, 2426257, 2478516, 2555026, 2637977, 2654335, 2720329, 2801054, 2803980, 2885797, 2895495, 2930903, 2981620, 3012138, 3015915, 3018684, 3065757, 3113319, 3215403, 3241123] not in index'

In [39]:
idx_left[0]

array([   4869,   22044,   44585,  129819,  135325,  226239,  345802,
        373775,  405272,  428239,  429597,  434337,  523114,  544315,
        596171,  600930,  713854,  782902,  867447,  949919, 1056473,
       1251270, 1300350, 1388498, 1437497, 1673729, 1730665, 1758429,
       1804495, 1825842, 1913540, 1923978, 2039310, 2150569, 2151426,
       2249049, 2426257, 2478516, 2555026, 2637977, 2650468, 2654335,
       2720329, 2801054, 2803980, 2820356, 2885797, 2895495, 2930903,
       2981620, 3012138, 3015915, 3018684, 3065757, 3113319, 3214974,
       3215403, 3241123, 3268321])

In [36]:
outputs_test_full[[0,1,2]]

array([[-1.3052657 ,  0.02925883,  0.96055824],
       [-2.1639023 ,  0.60354245,  1.4047456 ],
       [-1.3148432 ,  0.6713982 ,  1.0025367 ]], dtype=float32)