In [1]:
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns; sns.set_style("whitegrid"); sns.set_palette("tab10")
import torch
import torchvision
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch import nn, optim
from torch.utils.data import DataLoader

from data.data import Data
from models.vae import VAE
from trainer import Trainer

In [2]:
# Load the per-row sensor readings and the per-sequence labels file.
data = pd.read_csv("data/playground/train.csv")
labels = pd.read_csv("data/playground/train_labels.csv")
# Left join attaches each sequence's label columns to all of its rows.
# validate="many_to_one" raises a MergeError if a `sequence` appears more than
# once in the labels file, which would otherwise silently duplicate sensor rows.
data = pd.merge(data, labels, on="sequence", how="left", validate="many_to_one")

In [3]:
def _sequence_features(grp: pd.DataFrame, sensor_cols: list) -> np.ndarray:
    """Return one flat feature vector for a single sequence.

    Layout, in order:
      * every sensor reading, flattened row by row (the group's rows are the
        outer axis, sensors the inner axis);
      * pairwise correlations between sensors, strict upper triangle (i < j)
        in row-major order;
      * per-sensor mean;
      * per-sensor standard deviation of first differences;
      * the ``state`` value of the group's first row, as the final element.

    NaNs (e.g. an undefined correlation for a constant sensor) are kept in
    place, so every sequence yields a vector of the same length.
    """
    readings = grp.loc[:, sensor_cols]
    # Select the upper triangle explicitly instead of NaN-masking the lower
    # triangle and relying on DataFrame.stack() to drop NaNs: that default
    # changes in newer pandas (future_stack), which would leak the masked
    # entries into the vector and change its length.
    upper = np.triu_indices(len(sensor_cols), k=1)
    corr_upper = readings.corr().to_numpy()[upper]
    return np.concatenate([
        readings.to_numpy().ravel(),  # same row-major order as .stack()
        corr_upper,
        readings.mean().to_numpy(),
        readings.diff().std().to_numpy(),
        [grp["state"].iloc[0]],
    ])


sensor_cols = [c for c in data.columns if c.startswith("sensor")]
# One (n_features, 1) column vector per sequence, in groupby order.
d = [
    _sequence_features(grp, sensor_cols).reshape(-1, 1)
    for _, grp in data.groupby("sequence")
]

In [4]:
# Stack the per-sequence vectors into a table and drop any row containing NaN.
# The final column is the label; every column before it is a feature.
df = pd.DataFrame([np.squeeze(vec) for vec in d]).dropna()
X, y = df.iloc[:, :-1], df.iloc[:, -1]

In [5]:
# Hold out 30% of sequences for evaluation. A fixed random_state makes the
# split (and everything computed from it) reproducible on Restart & Run All.
SPLIT_SEED = 42
train_X, test_X, train_y, test_y = train_test_split(
    X, y, test_size=0.3, random_state=SPLIT_SEED
)
# Fit the scaler on the training split only, then apply it to the test split,
# so test-set statistics do not leak into preprocessing.
scaler = StandardScaler()
train_X = scaler.fit_transform(train_X)
test_X = scaler.transform(test_X)

In [6]:
# Wrap the scaled feature matrices and their label arrays in the project's
# Data class (presumably a torch Dataset — TODO confirm in data/data.py).
train, test = (
    Data(train_X, train_y.values),
    Data(test_X, test_y.values),
)

In [7]:
# Derive the input width from the engineered feature matrix instead of the
# hard-coded 884, so changing the feature set cannot desync the model.
# The second VAE argument is presumably the latent size — TODO confirm in models/vae.py.
LATENT_DIM = 32
torch.manual_seed(0)  # reproducible weight initialisation
vae = VAE(train_X.shape[1], LATENT_DIM)
t = Trainer(vae, train, test)

In [8]:
# NOTE(review): the recorded run went to NaN at epoch 2 and was interrupted,
# so `losses` was never assigned and the plotting cell below cannot run.
# Causes are unverified from here; worth checking the large loss scale
# (~1.9e5 per epoch), the learning rate, and numerical stability inside the VAE.
N_EPOCHS = 30
losses = t.fit(N_EPOCHS)

Epoch: 0, Train: 189586.41, Test: 187419.9
Epoch: 1, Train: 187191.51, Test: 186873.21
Epoch: 2, Train: nan, Test: nan
Epoch: 3, Train: nan, Test: nan


KeyboardInterrupt: 

In [None]:
# Loss curves returned by Trainer.fit. Assumes `losses` is a pandas object
# indexed by epoch (its .plot() draws onto the current matplotlib axes) —
# TODO confirm against trainer.py.
losses.plot()
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("VAE train / test loss")
plt.show()

In [None]:
# Reconstruct a single training example to inspect VAE quality by eye.
# The label goes into `y_sample`: assigning to `y` here would clobber the
# label Series built from the feature table and break re-running the split.
x, y_sample = train[2]
with torch.no_grad():  # inference only; no autograd graph needed
    z = vae.encode(x)
    x_hat = vae.decode(z)

In [None]:
# Input feature vector for train[2]. These are the StandardScaler-transformed
# features, assuming Data returns rows unchanged — TODO confirm.
plt.plot(x.detach().numpy())
plt.xlabel("feature index")
plt.ylabel("standardized value")
plt.title("Input x (train[2])")
plt.show()

In [None]:
# VAE reconstruction of train[2]; compare against the input plot above it.
plt.plot(x_hat.detach().numpy())
plt.xlabel("feature index")
plt.ylabel("standardized value")
plt.title("Reconstruction x_hat (train[2])")
plt.show()

In [None]:
# Collect a fixed number of training batches to visualise in latent space.
# Loop variables use distinct names so the single-example `x` / `y` globals
# from earlier cells are not overwritten.
# NOTE(review): `X` / `Y` still rebind names used for the feature table
# earlier; they are kept because the encoding cell reads `X` and `Y`.
N_VIS_BATCHES = 12
feature_batches = []
label_batches = []
for xb, yb in islice(t.train_loader, N_VIS_BATCHES):
    feature_batches.append(xb.flatten(1))  # keep batch dim, flatten the rest
    label_batches.append(yb)
X = torch.cat(feature_batches)
Y = torch.cat(label_batches)

In [None]:
# Encode the collected batches without building an autograd graph, then move
# results to host memory as numpy arrays for plotting.
with torch.no_grad():
    Z = vae.encode(X).cpu().numpy()
y = Y.detach().cpu().numpy()

In [None]:
# Scatter the first two encoder output dimensions, coloured by batch label.
# NOTE(review): only two of the encoder's output dims are shown; a projection
# such as PCA would summarise all of them.
ax = sns.scatterplot(x=Z[:, 0], y=Z[:, 1], hue=y)
ax.set(xlabel="z[0]", ylabel="z[1]", title="VAE latent space (first two dims)")
plt.show()