In [None]:
import pandas as pd
import numpy as np
import os

import pickle
import time

from tqdm.notebook import tqdm
import multiprocessing as mp

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch import optim
# Compute device used by the notebook: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

# Random seed; the sampling cell passes it to np.random.seed before shuffling the split indices.
seed = 66

In [None]:
class EarlyStopping:
    """Track the best validation loss and signal when training should stop.

    Call the instance once per epoch with the current validation loss and the
    model. Whenever the loss is at least as good as the best seen so far, the
    model's ``state_dict`` is written to ``path`` and the patience counter is
    reset. After ``patience`` consecutive calls without improvement, ``stop``
    (also readable as ``early_stop``) becomes True.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``stop`` is set.
        path: file the best ``state_dict`` is saved to.
    """

    def __init__(self, patience=30, path='temp4earlystop.pt'):
        self.patience = patience
        self.path = path
        self.counter = 0          # consecutive calls without improvement
        self.current_lead = None  # best (lowest) validation loss seen so far
        self.stop = False

    @property
    def early_stop(self):
        """Read-only alias of ``stop`` so callers checking ``early_stop`` work as well."""
        return self.stop

    def __call__(self, val_loss, model):
        """Record ``val_loss`` for this epoch and checkpoint ``model`` on improvement.

        Args:
            val_loss: validation loss of the current epoch (lower is better).
            model: module whose ``state_dict`` is saved when the loss improves.
        """
        # `<=` (rather than `not >`) makes a NaN loss count as "no improvement",
        # so a diverged epoch can never overwrite the best checkpoint.
        if self.current_lead is None or val_loss <= self.current_lead:
            self.current_lead = val_loss
            self.update_lead(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.stop = True

    def update_lead(self, val_loss, model):
        """Save the current weights of ``model`` to ``self.path``."""
        torch.save(model.state_dict(), self.path)

In [None]:
class DAN_Dataset(Dataset):
    """Map-style dataset over rows that store the label in their last position.

    Item ``idx`` is returned as ``(row[:-1], row[-1])``: everything but the
    last entry of the row, and the last entry.
    """

    def __init__(self, X):
        # Indexable container of rows; kept by reference, not copied.
        self.X = X

    def __len__(self):
        """Number of rows in the wrapped container."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored in row ``idx``."""
        row = self.X[idx]
        return row[:-1], row[-1]

In [None]:
## Sampling for imbalance class set
# NOTE(review): `X` is not defined in any visible cell. This cell assumes it is a
# 2-D torch tensor whose last column holds the 0/1 label -- confirm where it is loaded.
validation_split = 0.85  # fraction of each class kept for TRAINING; the remainder is validation

# Naming caveat: in this cell "neg" denotes rows with label 1 and "pos" rows with
# label 0 (see data_pos / data_neg below). Names are kept because other cells may use them.
neg_count = int(X[:,-1].sum())
pos_count = int(len(X) - neg_count)

split4pos = int(pos_count*validation_split)
split4neg = int(neg_count*validation_split)

indices4pos = list(range(pos_count))
indices4neg = list(range(neg_count))

# Seed both RNGs: numpy drives the shuffles below, torch drives the weighted
# samplers (and any torch-based weight initialisation that runs afterwards).
np.random.seed(seed)
torch.manual_seed(seed)
np.random.shuffle(indices4pos)
np.random.shuffle(indices4neg)

trainIdx4pos, validIdx4pos = indices4pos[:split4pos], indices4pos[split4pos:]
trainIdx4neg, validIdx4neg = indices4neg[:split4neg], indices4neg[split4neg:]

# The inverse-frequency weights below divide by these lengths; fail with a clear
# message instead of a ZeroDivisionError when a class is missing from a split.
if not (trainIdx4pos and validIdx4pos and trainIdx4neg and validIdx4neg):
    raise ValueError('each class needs at least one training and one validation row')

data_pos = X[X[:,-1] == 0]
data_neg = X[X[:,-1] == 1]

data_train = torch.cat([data_pos[trainIdx4pos], data_neg[trainIdx4neg]])
data_valid = torch.cat([data_pos[validIdx4pos], data_neg[validIdx4neg]])


TRAIN_BATCH_SIZE = 512
VALID_BATCH_SIZE = 9999
TRAIN_SAMPLES_PER_EPOCH = 240000  # rows drawn by the training sampler per epoch
VALID_SAMPLES_PER_EPOCH = 7000    # rows drawn by the validation sampler per epoch

# Index 0 -> weight for label 0, index 1 -> weight for label 1 (inverse class size),
# so both classes are drawn equally often in expectation.
train_label_weights = [1/len(trainIdx4pos), 1/len(trainIdx4neg)]
valid_label_weights = [1/len(validIdx4pos), 1/len(validIdx4neg)]

# Vectorised lookup of each row's class weight (replaces a per-row Python loop);
# the result is still a plain list of floats.
train_weights = torch.tensor(train_label_weights, dtype=torch.double)[data_train[:, -1].long().cpu()].tolist()
valid_weights = torch.tensor(valid_label_weights, dtype=torch.double)[data_valid[:, -1].long().cpu()].tolist()

# NOTE(review): WeightedRandomSampler samples with replacement by default, so the
# validation rows are re-drawn every epoch and the validation loss is itself noisy.
train_sampler = WeightedRandomSampler(train_weights, TRAIN_SAMPLES_PER_EPOCH)
valid_sampler = WeightedRandomSampler(valid_weights, VALID_SAMPLES_PER_EPOCH)

dataset_train = DAN_Dataset(data_train)
dataset_valid = DAN_Dataset(data_valid)

train_loader = DataLoader(dataset_train, 
                          batch_size=TRAIN_BATCH_SIZE, 
                          sampler=train_sampler,
                          num_workers=4)

valid_loader = DataLoader(dataset_valid, 
                          batch_size=VALID_BATCH_SIZE,
                          sampler=valid_sampler,                          
                          num_workers=4)

In [None]:
class DAN(nn.Module):
    """Mean-pooled embedding bag followed by a small MLP with a sigmoid output.

    Args:
        input_dim: number of rows in the embedding table (size of the 'vocabulary').
        embedding_dim: width of each embedding vector.
        linear_dim1: output width of the first linear layer.
        linear_dim2: width of the four middle linear layers.
        linear_dim3: width of the last hidden layer.
        num_class: number of output units.
        dropout: drop probability shared by all dropout layers.

    The defaults reproduce the previously hard-coded architecture.
    """

    def __init__(self, input_dim, embedding_dim=30, linear_dim1=100,
                 linear_dim2=25, linear_dim3=5, num_class=1, dropout=0.2):
        super().__init__()

        self.embedding = nn.EmbeddingBag(input_dim, embedding_dim, mode='mean')
        # NOTE(review): dropout1 is constructed but never applied in forward();
        # it is kept as-is so that the model's behaviour does not change.
        self.dropout1 = nn.Dropout(dropout)
        self.bn1 = nn.BatchNorm1d(embedding_dim)
        self.lin1 = nn.Linear(embedding_dim, linear_dim1)

        self.dropout2 = nn.Dropout(dropout)
        self.bn2 = nn.BatchNorm1d(linear_dim1)
        self.lin2 = nn.Linear(linear_dim1, linear_dim2)
        self.lin2_1 = nn.Linear(linear_dim2, linear_dim2)
        self.lin2_2 = nn.Linear(linear_dim2, linear_dim2)
        self.lin2_3 = nn.Linear(linear_dim2, linear_dim2)

        self.lin_1_1 = nn.Linear(linear_dim2, linear_dim3)
        self.dropout_1 = nn.Dropout(dropout)
        self.bn_1 = nn.BatchNorm1d(linear_dim3)
        self.lin_1 = nn.Linear(linear_dim3, num_class)

    def forward(self, x):
        """Return sigmoid outputs for a batch of index bags.

        Args:
            x: integer index tensor passed to ``nn.EmbeddingBag`` without offsets,
               i.e. a 2-D tensor with one bag of indices per row.

        Returns:
            Tensor with ``num_class`` columns and one row per bag; values lie in
            the open interval (0, 1) because of the final sigmoid.
        """
        # Mean of the embeddings in each bag, then batch-normalised.
        x = self.embedding(x)
        x = self.bn1(x)

        x = F.leaky_relu(self.lin1(x))
        x = self.bn2(x)
        x = self.dropout2(x)

        x = F.leaky_relu(self.lin2(x))
        x = F.leaky_relu(self.lin2_1(x))
        x = F.leaky_relu(self.lin2_2(x))
        x = F.leaky_relu(self.lin2_3(x))

        x = F.leaky_relu(self.lin_1_1(x))
        x = self.bn_1(x)
        x = self.dropout_1(x)

        x = self.lin_1(x)
        x = torch.sigmoid(x)
        return x

def init_weights(m):
    """Initialise an ``nn.Linear`` module in place; leave every other module untouched.

    Intended to be passed to ``Module.apply``. Weights are drawn with
    ``kaiming_normal_`` and the bias, when the layer has one, is set to 0.01.

    Args:
        m: the module visited by ``apply``.
    """
    # Exact type match (not isinstance) so that subclasses of nn.Linear are
    # skipped, exactly as before.
    if type(m) is nn.Linear:
        torch.nn.init.kaiming_normal_(m.weight)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

In [None]:
LEARNING_RATE = 0.01
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
loss_func = nn.BCELoss()

In [None]:
input_dim = 81 # length of 'vocabulary'
model = DAN(input_dim=input_dim)
model.apply(init_weights)
model = model.to(device)
model.train()

patience = 30
early_stopping = EarlyStopping(patience=patience)

train_loss_plot = []
valid_loss_plot = []
NUM_EPOCH = 150
for epoch in range(NUM_EPOCH):

    model.train()
    train_losses = 0
    valid_losses = 0
    for tx, label in train_loader:
        tx = tx.to(device)
        label = label.to(device)
        
        optimizer.zero_grad()
        
        label_hat = model(tx.long())
        loss = loss_func(label_hat.squeeze(), label.float())
        loss.backward()
        optimizer.step()
        
        train_losses += loss.item()
        
    with torch.no_grad():
        model.eval()
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            y_hat = model(x.long())
            val_loss = loss_func(y_hat.squeeze(), y.float())
            valid_losses += val_loss.item()
            
    train_loss_plot.append(train_losses/len(train_loader))
    valid_loss_plot.append(valid_losses/len(valid_loader))
    
    print('Epoch:', epoch, 
          'Train Loss: {:.4f}'.format(train_losses/len(train_loader)),
          'Valid Loss: {:.4f}'.format(valid_losses/len(valid_loader)),
          'Valid ROC: {:.4f}'.format(roc_auc_score(y.cpu(), y_hat.squeeze().detach().cpu().numpy())))
    
    early_stopping(valid_losses/len(valid_loader), model)
    if early_stopping.early_stop:
        print('topping at', epoch)
        break

# torch.save(model.state_dict(), 'DAN')