In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import Dataset, DataLoader

from mpl_toolkits.mplot3d import Axes3D
%matplotlib notebook

In [2]:
from scripts.AutoEncoder import AutoEncoder, AutoEncoderDataset
from scripts.utils import train_keys, target_keys

In [3]:
# Pickled evaluation sets: the full test set and the QLKNN test set.
full_test = "/share/rcifdata/jbarr/UKAEAGroupProject/data/test_data.pkl"
test = "/share/rcifdata/jbarr/UKAEAGroupProject/data/QLKNN_test_data.pkl"

# Number of rows used by the plotting / sampling cells.
n = 50_000

# Full set: pull the 'Target' column aside, then narrow to the input features.
# The right-hand side is evaluated on the raw frame before either name is bound.
df_full_test = pd.read_pickle(full_test)
target, df_full_test = df_full_test['Target'], df_full_test[train_keys]

# QLKNN set: only the input features are kept.
df_test = pd.read_pickle(test)[train_keys]

## Model 1 - AE trained on inputs that give outputs

In [4]:
# Checkpoint written by Run-1 of the autoencoder training.
path = "/share/rcifdata/jbarr/UKAEAGroupProject/logs/AutoEncoder/Run-1/Run-1/experiment_name=0-epoch=24-val_loss=0.30.ckpt"

# Rebuild the module from the checkpoint; the extra keyword arguments are
# passed through to the model's constructor.
model = AutoEncoder.load_from_checkpoint(
    path, n_input=15, batch_size=2048, epochs=100, learning_rate=0.002
)

# The model is only used for inference in this notebook, so put it in eval
# mode explicitly. This is a no-op unless the network contains layers whose
# behaviour differs between training and evaluation (dropout, batch-norm, ...).
model.eval()
encoder = model.encoder

### Evaluate on inputs that give outputs

In [5]:
# Feature matrix of df_test as a float32 tensor.
data_test = torch.from_numpy(df_test.values).float()

# Inference only: disable autograd so no graph is built for the whole set,
# and call the module itself (not .forward) so registered hooks still run.
with torch.no_grad():
    outputs_test = encoder(data_test).numpy()

In [6]:
# 3-D scatter of the first n encoded points, one axis per encoder output column.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
latent_x, latent_y, latent_z = (outputs_test[:n, dim] for dim in range(3))
ax.scatter(latent_x, latent_y, latent_z)
fig.show()

<IPython.core.display.Javascript object>

In [7]:
# 2-D view of the first n encoded points: columns 0 and 1 on the axes,
# column 2 mapped to colour.
plt.figure()

shown = outputs_test[:n]
sc = plt.scatter(shown[:, 0], shown[:, 1], c=shown[:, 2])
plt.colorbar(sc)
plt.show()

<IPython.core.display.Javascript object>

### Plot input and output distributions

In [8]:
# Run the test inputs through the whole autoencoder (inference only, so
# autograd is disabled and the module is called directly rather than .forward).
with torch.no_grad():
    AE_output = model(data_test).numpy()
df_ae_output = pd.DataFrame(AE_output, columns=train_keys)
df_ae_output['AE'] = 'Outputs'

# Build the labelled input frame as a NEW object. A plain `df_test_tmp = df_test`
# only aliases df_test, so adding the 'AE' column would also add it to df_test
# and make a later `df_test.values` -> tensor conversion fail on the string column.
df_test_tmp = df_test.assign(AE='Inputs')

In [22]:
# Stack autoencoder outputs and inputs into one long frame; the 'AE' column
# tells the two groups apart.
df_compare = pd.concat([df_ae_output, df_test_tmp], ignore_index=True)

# Draw a reproducible random subset for plotting. Capping at the frame length
# avoids the ValueError sample() raises when asked for more rows than exist.
df_compare_sample = df_compare.sample(n=min(n, len(df_compare)), random_state=0)

In [23]:
# One figure per input feature: overlaid histograms split by the 'AE' label.
for feature in train_keys:
    plt.figure()
    hist_ax = sns.histplot(data=df_compare_sample, x=feature, hue="AE")
    hist_ax.set_xlabel(feature)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Evaluate on all inputs

In [29]:
# Feature matrix of df_full_test as a float32 tensor.
data_test_full = torch.from_numpy(df_full_test.values).float()

# Inference only: disabling autograd avoids building a graph for the whole
# set; the module is called directly (not .forward) so registered hooks run.
with torch.no_grad():
    outputs_test_full = encoder(data_test_full).numpy()

In [None]:
# 3-D scatter of the first n encodings of the full test set.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
shown_full = outputs_test_full[:n]
ax.scatter(shown_full[:, 0], shown_full[:, 1], shown_full[:, 2])
fig.show()

In [None]:
# 2-D view of the first n encodings of the full test set: columns 0 and 1 on
# the axes, column 2 mapped to colour.
plt.figure()
full_x, full_y, full_colour = (outputs_test_full[:n, dim] for dim in range(3))
sc = plt.scatter(full_x, full_y, c=full_colour)
plt.colorbar(sc)
plt.show()

## Model 2 - AE trained on all inputs

In [None]:
# NOTE(review): this is the same Run-1 checkpoint file that the "Model 1"
# section loads, although this section is titled "AE trained on all inputs".
# Confirm which checkpoint was intended; as written both models are identical.
path = "/share/rcifdata/jbarr/UKAEAGroupProject/logs/AutoEncoder/Run-1/Run-1/experiment_name=0-epoch=24-val_loss=0.30.ckpt"

# Rebuild the module from the checkpoint; the extra keyword arguments are
# passed through to the model's constructor.
model_full = AutoEncoder.load_from_checkpoint(
    path, n_input=15, batch_size=2048, epochs=100, learning_rate=0.002
)

# Inference only below, so select eval mode explicitly (no-op unless the
# network has layers that behave differently in training mode).
model_full.eval()
encoder_full = model_full.encoder

### Evaluate on inputs that give outputs

In [None]:
# Encode the test subset with encoder_full. Autograd is disabled because this
# is inference only, and the module is called directly so hooks are honoured.
with torch.no_grad():
    outputs_encoder_full_test = encoder_full(data_test).numpy()

In [None]:
# 3-D scatter of the first n points encoded by encoder_full.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
shown_subset = outputs_encoder_full_test[:n]
ax.scatter(shown_subset[:, 0], shown_subset[:, 1], shown_subset[:, 2])
fig.show()

### Evaluate on all inputs

In [None]:
# Encode the full test set with encoder_full. Autograd is disabled so no graph
# is kept for the whole set; the module is called directly so hooks run.
with torch.no_grad():
    outputs_encoder_full_test_full = encoder_full(data_test_full).numpy()

In [None]:
# One 100-bin histogram per latent column, each over a fixed display window.
# NOTE(review): the window limits are hard-coded, presumably hand-picked for
# this checkpoint — confirm they still cover the data.
latent_windows = ((0, 75), (-150, 10), (0, 5))
for column, window in enumerate(latent_windows):
    plt.figure()
    plt.hist(outputs_encoder_full_test_full[:n, column], bins=100, range=window)

In [None]:
# 3-D scatter of the first n encodings of the full test set, coloured by the
# first n values of the 'Target' column.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
shown_all = outputs_encoder_full_test_full[:n]
ax.scatter(shown_all[:, 0], shown_all[:, 1], shown_all[:, 2], c=target[:n])

fig.show()

## Model with 2D Latent Space

In [15]:
# Checkpoint written by Run-3 of the autoencoder training.
path = "/share/rcifdata/jbarr/UKAEAGroupProject/logs/AutoEncoder/Run-3/Run-3/experiment_name=0-epoch=27-val_loss=0.35.ckpt"

# Rebuild the module with a two-dimensional latent space; the extra keyword
# arguments are passed through to the model's constructor.
model_2d = AutoEncoder.load_from_checkpoint(
    path,
    n_input=15,
    latent_dims=2,
    batch_size=2048,
    epochs=100,
    learning_rate=0.002,
)

# Inference only below, so select eval mode explicitly (no-op unless the
# network has layers that behave differently in training mode).
model_2d.eval()
encoder_2d = model_2d.encoder

In [24]:
# Encode the test subset with the 2-D encoder. Autograd is disabled because
# this is inference only; the module is called directly so hooks are honoured.
with torch.no_grad():
    outputs_2d = encoder_2d(data_test).numpy()

In [26]:
# Scatter of the first n points in the two-column encoding of the test subset.
plt.figure()
subset_x, subset_y = outputs_2d[:n, 0], outputs_2d[:n, 1]
plt.scatter(subset_x, subset_y)

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7fa067187d60>

In [30]:
# Encode the full test set with the 2-D encoder. Autograd is disabled so no
# graph is kept for the whole set; the module is called directly so hooks run.
with torch.no_grad():
    outputs_2d_full = encoder_2d(data_test_full).numpy()

In [32]:
# Scatter of the first n points in the two-column encoding of the full test set.
plt.figure()
all_x, all_y = outputs_2d_full[:n, 0], outputs_2d_full[:n, 1]
plt.scatter(all_x, all_y)

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7fa067dc29d0>

In [35]:
# Re-read the unfiltered full test set (all columns, including 'Target').
# Uses the `full_test` path defined in the data-loading cell instead of
# repeating the same hard-coded location a second time.
df = pd.read_pickle(full_test)

In [37]:
# Summary statistics (count, mean, std, min, quartiles, max) for every numeric
# column of the unfiltered full test set.
df.describe()

Unnamed: 0,Ane,Ate,Autor,Machtor,x,Zeff,gammaE,q,smag,alpha,Ani1,Ati0,normni1,Ti_Te0,logNustar,Target
count,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0,3339495.0
mean,3.346844,8.01166,0.5459562,0.1053272,0.5346982,1.844889,5077.907,2.212637,1.029978,0.3244506,3.600052,7.453991,0.01903933,1.019185,-0.4784231,0.6616653
std,75.44098,23.99218,2.469088,0.185476,0.2665847,0.6504638,22298060.0,1.22805,1.771504,0.8370681,19.75473,47.16084,0.0105326,0.2216106,0.4944099,0.4731431
min,-21045.07,-8557.695,-184.9448,-0.2921602,0.02246824,1.001057,-4299727000.0,0.6446359,-146.5966,-22.25995,-1074.658,-9524.554,4.731857e-06,0.2451745,-1.990947,0.0
25%,0.2867085,2.42435,0.0,0.0,0.3246776,1.334894,0.0,1.33031,0.1227123,0.05949192,0.6175906,2.670995,0.01609759,0.9930528,-0.8371349,0.0
50%,1.731745,5.433598,0.0,0.0,0.5461664,1.746351,0.0,1.88979,0.5492964,0.1423735,1.725996,5.550784,0.01664689,1.0,-0.557739,1.0
75%,3.844875,9.594195,0.8007807,0.1648038,0.7566727,2.17544,0.0,2.741815,1.419462,0.3785337,3.356488,9.207165,0.01833734,1.0,-0.1859302,1.0
max,21055.37,8059.684,130.7951,1.157954,0.9541317,11.6423,4299727000.0,31.89302,497.8162,121.7655,1200.819,9524.554,0.2345388,4.985357,3.529977,1.0
