In [None]:
import numpy as np
import pandas as pd
import time
import pickle
import os
import matplotlib.pyplot as plt
from impala.dbapi import connect

from sklearn.cluster import KMeans
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset

# Run on the GPU when one is available; models and batches below are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Single seed for the notebook's RNGs (numpy is seeded with it in the split cell).
seed = 55
# Seed torch as well, so weight initialisation and the VAE's random sampling
# are reproducible across Restart & Run All (manual_seed also covers CUDA).
torch.manual_seed(seed)

In [None]:
# Batch sizes used by the train / validation DataLoaders.
TRAIN_BATCH_SIZE = 2048
VALID_BATCH_SIZE = 9999
# NOTE: despite its name, this is the fraction of rows kept for TRAINING;
# the remaining rows form the validation set.
validation_split = 0.8

# `df` is not created in this section of the notebook; it must already exist
# in the kernel (TODO: load it explicitly above so Restart & Run All works).
n_rows = df.shape[0]
split = int(n_rows * validation_split)

# Reproducible shuffle of row positions with numpy's legacy global RNG.
np.random.seed(seed)
split_idx = list(range(n_rows))
np.random.shuffle(split_idx)

train_idx = split_idx[:split]
valid_idx = split_idx[split:]
print('train vs val:', len(train_idx), len(valid_idx))

In [None]:
class TX_Dataset(Dataset):
    """Map-style dataset over an indexable container of rows.

    Items are returned exactly as ``self.df[idx]`` yields them; batching and
    tensor conversion are left to the DataLoader's collate function.
    """

    def __init__(self, df):
        # Keep a reference to the backing container (no copy is made).
        self.df = df

    def __len__(self):
        """Number of rows in the backing container."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the row at position ``idx`` unchanged."""
        return self.df[idx]

In [None]:
# Wrap the train/valid row subsets. Assumes `df` supports positional row
# indexing with a list of ints (e.g. a numpy array) — TODO confirm: on a
# pandas DataFrame, df[list] would select columns instead of rows.
dataset_train = TX_Dataset(df[train_idx])
dataset_valid = TX_Dataset(df[valid_idx])

# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only
# runs (recent PyTorch versions warn when pin_memory=True without CUDA).
PIN_MEMORY = device.type == 'cuda'

# Reshuffle the training batches every epoch (plain SGD practice); the seeded
# generator keeps that order reproducible across runs.
train_loader = DataLoader(dataset_train,
                          batch_size=TRAIN_BATCH_SIZE,
                          shuffle=True,
                          generator=torch.Generator().manual_seed(seed),
                          num_workers=1,
                          pin_memory=PIN_MEMORY)

# Validation loss is summed over all batches, so its order does not matter.
valid_loader = DataLoader(dataset_valid,
                          batch_size=VALID_BATCH_SIZE,
                          num_workers=1,
                          pin_memory=PIN_MEMORY)

In [None]:
class VAETX(nn.Module):
    """Fully-connected variational autoencoder used as a feature extractor.

    Layout: input -> hidden (ReLU) -> (mu, logvar) in latent space, and
    latent -> hidden (ReLU) -> input (sigmoid). The defaults reproduce the
    original 464 -> 100 -> 10 architecture. Layers are created in the same
    order as before, so a given torch seed yields the same initial weights.

    Args:
        input_dim: number of features per input sample (default 464).
        hidden_dim: width of the encoder/decoder hidden layer (default 100).
        latent_dim: dimensionality of the latent code (default 10).
    """

    def __init__(self, input_dim=464, hidden_dim=100, latent_dim=10):
        super(VAETX, self).__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # posterior mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # posterior log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return (mu, logvar) for a (batch, input_dim) input."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I).

        Sampling happens in both train and eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to input space; the sigmoid keeps outputs in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        ``x`` is flattened to (-1, input_dim). ``reshape`` is used instead of
        ``view`` so non-contiguous inputs are accepted too.

        Returns:
            (reconstruction, mu, logvar)
        """
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [None]:
def loss_function(recon_x, x, mu, log_var):
    """Negative ELBO of a Bernoulli-decoder VAE, summed over the whole batch.

    Args:
        recon_x: decoder output with values in [0, 1].
        x: original batch. It is reshaped to ``recon_x.shape`` and cast to
            ``recon_x.dtype``, so callers may pass the raw batch (e.g. float64
            or unflattened). Presumably holds values in [0, 1]; BCE is only
            meaningful for such targets — TODO confirm for this data.
        mu: posterior mean from the encoder.
        log_var: posterior log-variance from the encoder.

    Returns:
        Scalar tensor: summed binary cross-entropy plus KL(q(z|x) || N(0, I)).
    """
    # BCE requires matching shape and dtype between input and target.
    target = x.reshape(recon_x.shape).to(recon_x.dtype)
    bce = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL divergence between N(mu, exp(log_var)) and a standard normal.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + kld

In [None]:
# Train the VAE feature extractor. Per-sample train/valid losses are recorded
# each epoch in train_loss_plot / valid_loss_plot.
LEARNING_RATE = 0.01
NUM_EPOCH = 200
# Max gradient L2-norm for clipping; None disables clipping (previous behaviour).
GRAD_CLIP_NORM = None


def _run_epoch(model, loader, optimizer=None, grad_clip_norm=None):
    """Run one pass over `loader` and return the VAE loss summed over all batches.

    With an `optimizer`, the model is put in train mode and updated after each
    batch. Without one, it is put in eval mode and run with autograd disabled.
    Each batch is moved to `device` and cast to float32 once, so the model
    input and the BCE target are the same tensor.
    """
    training = optimizer is not None
    model.train(training)
    total = 0.0
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device).float()
            recon, mu, logvar = model(batch)
            loss = loss_function(recon, batch, mu, logvar)
            if training:
                optimizer.zero_grad()
                loss.backward()
                if grad_clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
                optimizer.step()
            total += loss.item()
    return total


vae_extractor = VAETX().to(device)
optimizer = optim.Adam(vae_extractor.parameters(), lr=LEARNING_RATE)

n_train = len(train_loader.dataset)
n_valid = len(valid_loader.dataset)
train_loss_plot = []
valid_loss_plot = []

beg = time.time()
for epoch in range(NUM_EPOCH):
    train_losses = _run_epoch(vae_extractor, train_loader, optimizer, GRAD_CLIP_NORM)
    valid_losses = _run_epoch(vae_extractor, valid_loader)

    # Normalise summed losses to per-sample values so train/valid are comparable.
    train_loss_plot.append(train_losses / n_train)
    valid_loss_plot.append(valid_losses / n_valid)

    print('Epoch:', epoch,
          'Train Loss: {:.4f}'.format(train_loss_plot[-1]),
          'Valid Loss: {:.4f}'.format(valid_loss_plot[-1]))

print('Total training time: {:.1f}s'.format(time.time() - beg))