# Final Assignment - Implementing VAT
Implementing stage 1 and evaluating as requested in stage 4.

> `By: Yuval Rehsef`

> `ID: 314805045`

## Import Libraries
**`note: set the path to the datasets folder here`**

In [9]:
import torch
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.metrics import roc_auc_score, precision_recall_curve, roc_curve, precision_score, auc, confusion_matrix, make_scorer, accuracy_score
from scipy.stats import uniform, norm
from skorch import NeuralNetClassifier
import skorch
from skorch.utils import to_tensor
from skorch.callbacks import Checkpoint
import pandas as pd
from fastprogress import progress_bar, master_bar
from pathlib import Path
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from datetime import datetime

# Folder that holds the dataset CSV files; change this to point at your copy.
data_path = Path('datasets')

## Load Datasets
`Loading the datasets and one-hot encoding the labels. Feature scaling with sklearn's StandardScaler is applied later, inside the pipeline.`

In [10]:
# Load every CSV under `data_path` into `dfs`, keyed by the file stem.
# The last column of each file is taken as the label and one-hot encoded;
# all preceding columns are kept as the feature frame.
dfs = {}
# Sorted so the dataset order does not depend on the filesystem listing order.
for f in progress_bar(sorted(data_path.glob('*.csv'))):
    df = pd.read_csv(f)
    X = df[df.columns[:-1]]
    y = df[df.columns[-1]]
    if y.dtype != np.int64:
        # only one dataset's label is not an integer ('N'/'P'), deal with it
        y = y.replace({'N': 0, 'P': 1})
    # `replace` may leave an object-dtype column, which cannot be used to
    # index the identity matrix below, so force an integer dtype explicitly.
    y = y.astype(np.int64)
    y = np.eye(y.max() + 1)[y.values]  # one-hot: row i is the indicator of label i
    dfs[f.stem] = (X, y)

## VAT Implementation
`Implementing the VAT loss and additional necessary functions.`

`Link to the official implementation:`
[GitHub link](https://github.com/lyakaap/VAT-pytorch)

In [11]:
@contextlib.contextmanager
def _disable_tracking_bn_stats(model):
    """Temporarily toggle ``track_running_stats`` on every submodule of ``model``.

    On entry the flag is flipped on each submodule that defines it; on exit it
    is flipped back, so modules return to the state they had before the
    ``with`` block.

    :param model: module whose submodules are visited through ``model.apply``
    """

    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True

    model.apply(switch_attr)
    try:
        yield
    finally:
        # Flip back even when the body raises; otherwise the model would be
        # left with the flag inverted for the rest of the run.
        model.apply(switch_attr)


def _l2_normalize(d):
    """Scale each sample of ``d`` to unit L2 norm, in place, and return ``d``.

    The first dimension is treated as the batch; the norm of a sample is taken
    over all of its remaining elements. A small constant is added to the norm
    so that an all-zero sample does not divide by zero.

    :param d: tensor with at least two dimensions; it is modified in place
    :return: the same tensor object ``d``
    """
    # Trailing singleton axes let the per-sample norm broadcast against ``d``.
    trailing_ones = (1,) * (d.dim() - 2)
    flattened = d.view(d.shape[0], -1, *trailing_ones)
    sample_norms = torch.norm(flattened, dim=1, keepdim=True)
    d.div_(sample_norms + 1e-8)
    return d


class VATLoss(nn.Module):
    """Virtual adversarial training (VAT) loss term.

    Given a model and a batch, finds a small input perturbation that changes
    the model's predicted distribution the most, and returns the KL divergence
    between the clean prediction and the prediction on the perturbed input,
    scaled by ``alpha``.
    """

    def __init__(self, alpha=1.0, xi=1e-6, eps=1.0, ip=1):
        """VAT loss
        :param alpha: weight multiplying the returned loss term (default: 1.0)
        :param xi: scale of the probe perturbation used while estimating the
            adversarial direction (default: 1e-6)
        :param eps: scale of the final, unit-normalised adversarial
            perturbation (default: 1.0)
        :param ip: iteration times of computing adv noise (default: 1)
        """
        super(VATLoss, self).__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip
        self.alpha = alpha

    def forward(self, model, x):
        """Return ``alpha * KL(p(x) || p(x + r_adv))`` for the batch ``x``.

        :param model: callable module mapping ``x`` to class logits
        :param x: input batch; the first dimension is the batch dimension
        :return: scalar tensor with the weighted VAT term
        """
        # The clean prediction is a fixed target, so no graph is needed for it.
        with torch.no_grad():
            pred = F.softmax(model(x), dim=1)

        # prepare random unit tensor
        d = torch.rand(x.shape).sub(0.5).to(x.device)
        d = _l2_normalize(d)

        with _disable_tracking_bn_stats(model):
            # calc adversarial direction
            # Gradients w.r.t. ``d`` are required here even if the caller runs
            # under ``torch.no_grad()``.
            with torch.enable_grad():
                for _ in range(self.ip):
                    d.requires_grad_()
                    pred_hat = model(x + self.xi * d)
                    logp_hat = F.log_softmax(pred_hat, dim=1)
                    adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
                    # Differentiate w.r.t. ``d`` only: this neither writes to
                    # nor clears the model parameters' ``.grad``, so gradients
                    # the caller has already accumulated are left intact.
                    grad_d, = torch.autograd.grad(adv_distance, d)
                    d = _l2_normalize(grad_d)

            # calc LDS
            r_adv = d * self.eps
            pred_hat = model(x + r_adv)
            logp_hat = F.log_softmax(pred_hat, dim=1)
            lds = F.kl_div(logp_hat, pred, reduction='batchmean')

        return self.alpha * lds

## Model Class

`Setting up the PyTorch model and the skorch object. skorch is a wrapper that lets a PyTorch model be used as a scikit-learn estimator.`

In [12]:
class VATNet(nn.Module):
    """Small fully connected classifier: input -> 64 -> 16 -> classes.

    ReLU is applied after the first two linear layers. The forward pass
    returns the output of the last linear layer as is (no softmax).
    """

    def __init__(self, input_size, classes):
        """
        :param input_size: number of input features
        :param classes: number of output units
        """
        super().__init__()
        first_hidden, second_hidden = 64, 16
        stack = [
            nn.Linear(input_size, first_hidden),
            nn.ReLU(),
            nn.Linear(first_hidden, second_hidden),
            nn.ReLU(),
            nn.Linear(second_hidden, classes),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw scores."""
        scores = self.model(x)
        return scores

class MySkorch(skorch.NeuralNet):
    """skorch net trained on cross-entropy plus the criterion's extra term.

    Unlike a usual criterion, ``self.criterion_`` is called here as
    ``criterion_(module, x)``; the supervised cross-entropy on ``(y_pred, y)``
    is computed in this class and added to it.
    """

    def _combined_loss(self, x, y):
        """Return ``(loss, y_pred)`` for one batch.

        ``loss`` is the cross-entropy of the module output against ``y`` plus
        the value returned by ``self.criterion_(self.module_, x)``.
        """
        x, y = to_tensor((x.float(), y.long()), device=self.device)
        y_pred = self.module_(x)
        supervised = F.cross_entropy(y_pred, y)
        regulariser = self.criterion_(self.module_, x)
        return supervised + regulariser, y_pred

    def train_step_single(self, x, y, **fit_params):
        """Compute the combined loss in train mode and backpropagate it."""
        self.module_.train()
        loss, y_pred = self._combined_loss(x, y)
        loss.backward()
        return {'loss': loss, 'y_pred': y_pred}

    def validation_step(self, x, y, **fit_params):
        """Compute the combined loss in eval mode, without a backward pass."""
        self.module_.eval()
        loss, y_pred = self._combined_loss(x, y)
        return {'loss': loss, 'y_pred': y_pred}

    def evaluation_step(self, x, training=False):
        """Return softmax probabilities of the module output for ``x``.

        Gradient tracking and the module's train/eval mode both follow
        ``training``.
        """
        self.check_is_fitted()
        inputs = to_tensor(x.float(), device=self.device)
        with torch.set_grad_enabled(training):
            self.module_.train(training)
            scores = self.module_(inputs)
            return F.softmax(scores, dim=1)

    def predict(self, X):
        """Return the index of the largest ``predict_proba`` entry per row."""
        probabilities = self.predict_proba(X)
        return probabilities.argmax(1)

## Create Pipeline
`Define the skorch model & functions to be used by a sklearn pipeline.`

In [13]:
def make_VAT_loss(alpha=1.0, VAT_eps=2.0, VAT_xi=10.0, VAT_ip=1):
    """Build a ``VATLoss`` from the ``VAT_``-prefixed keyword names.

    :param alpha: forwarded as ``alpha``
    :param VAT_eps: forwarded as ``eps``
    :param VAT_xi: forwarded as ``xi``
    :param VAT_ip: forwarded as ``ip``
    :return: the configured ``VATLoss`` instance
    """
    settings = dict(alpha=alpha, xi=VAT_xi, eps=VAT_eps, ip=VAT_ip)
    return VATLoss(**settings)

def my_auc(y_true, y_pred):
    """Return ``roc_auc_score`` with macro averaging and the one-vs-one scheme."""
    auc_options = {'average': 'macro', 'multi_class': 'ovo'}
    return roc_auc_score(y_true, y_pred, **auc_options)

def get_scores(model, x, y):
    """Evaluate a fitted classifier on ``(x, y)`` and return a metrics dict.

    :param model: object exposing ``predict(x)`` (class indices) and
        ``predict_proba(x)`` (one column of scores per class)
    :param x: feature matrix passed to the model unchanged
    :param y: one-hot label matrix, one column per class; the number of
        columns decides between the binary and the multiclass branch
    :return: dict with the keys ``'Accuracy'``, ``'TPR'``, ``'FPR'``,
        ``'Precision'``, ``'AUC'`` and ``'PR Curve'``
    """

    def _mean_rate(hits, totals):
        # Mean of hits/totals over the entries with a non-zero total, so a
        # class without support in this fold does not turn the mean into NaN.
        hits = np.atleast_1d(np.asarray(hits, dtype=float))
        totals = np.atleast_1d(np.asarray(totals, dtype=float))
        defined = totals > 0
        if not defined.any():
            return float('nan')
        return float(np.mean(hits[defined] / totals[defined]))

    y_true = y.argmax(1)
    preds = model.predict(x)
    proba = model.predict_proba(x)

    # Take the class set from the label matrix, not from the labels that
    # happen to occur in this fold; otherwise a multiclass fold that only
    # contains two classes would be scored as a binary problem.
    n_classes = y.shape[1]
    labels = np.arange(n_classes)
    # Passing `labels` keeps the matrix at n_classes x n_classes even when a
    # class is missing from both the true labels and the predictions.
    cm = confusion_matrix(y_true, preds, labels=labels)

    if n_classes == 2:
        tn, fp, fn, tp = cm.ravel()
        auc_res = roc_auc_score(y_true, proba[:, 1], average='macro', multi_class='ovo')
    else:
        # One-vs-rest counts per class.
        tp = np.diag(cm).astype(float)
        fp = cm.sum(axis=0) - tp
        fn = cm.sum(axis=1) - tp
        tn = cm.sum() - (fp + fn + tp)
        auc_res = roc_auc_score(y_true, proba, average='macro', multi_class='ovo', labels=labels)

    # Precision-recall curve over all (sample, class) pairs pooled together.
    precision, recall, _ = precision_recall_curve(y.ravel(), proba.ravel())

    return {'Accuracy': accuracy_score(y_true, preds),
            'TPR': _mean_rate(tp, tp + fn),
            'FPR': _mean_rate(fp, fp + tn),
            'Precision': precision_score(y_true, preds, average='macro'),
            'AUC': auc_res,
            'PR Curve': auc(recall, precision)
    }

# Nested cross-validation splitters: 10 stratified outer folds and 3 stratified
# inner folds, both shuffled with a fixed seed.
outer_cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
inner_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)

# Hyperparameter distributions; scipy's uniform(loc, scale) covers
# [loc, loc + scale].
hypers = dict(
    net__lr=uniform(loc=1e-4, scale=(1e-2 - 1e-4)),    # [1e-4, 1e-2]
    net__criterion__VAT_eps=uniform(loc=1, scale=2),   # [1, 3]
    net__criterion__alpha=uniform(loc=0.5, scale=1),   # [0.5, 1.5]
)

## Evaluation Over the Nested Cross-Validation
`Evaluating the model performance over the 20 datasets, saving and printing the results.`

In [None]:
def _format_duration(delta):
    """Render a timedelta as whole seconds ('12s') or, below one second, as milliseconds ('250ms')."""
    seconds = delta.total_seconds()
    if seconds < 1:
        return f'{seconds * 1000:.0f}ms'
    return f'{int(seconds)}s'


# Column order of the results table; one row is produced per dataset and outer fold.
result_columns = ['Dataset Name', 'Algorithm Name', 'Cross Validation', 'Hyperparameters Values',
                  'Accuracy', 'TPR', 'FPR', 'Precision', 'AUC', 'PR Curve', 'Training Time', 'Inference Time']
scores_dfs = pd.DataFrame(columns=result_columns)
# Rows are collected in a list and the frame is rebuilt from it, because
# DataFrame.append is not available in pandas >= 2.0.
score_rows = []

displayer = display(scores_dfs, display_id=True, clear=True)
for data_name, (X, y) in progress_bar(dfs.items()):
    print(data_name)
    X = X.to_numpy()
    outer_folds = progress_bar(outer_cv.split(X, y.argmax(1)), total=outer_cv.get_n_splits())
    for cv_counter, (train_index, test_index) in enumerate(outer_folds, start=1):
        x_train, x_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        # Labels that start at 1 leave an empty first one-hot column: shift the
        # training labels down by one and drop that column from the test matrix.
        if y_train.argmax(1).min() == 1:
            y_train = y_train.argmax(1) - 1
            y_test = y_test[:, 1:]
        else:
            y_train = y_train.argmax(1)

        net = MySkorch(VATNet, max_epochs=10, criterion=make_VAT_loss, module__input_size=x_train.shape[-1], module__classes=y_train.max() + 1,
                       optimizer=optim.Adam, batch_size=32, train_split=None, device='cpu', verbose=0)
        pipe = Pipeline([
            ('scale', StandardScaler()),
            ('net', net)
        ])

        # Inner loop: random search over `hypers`, scored by accuracy on `inner_cv`.
        rsg = RandomizedSearchCV(pipe, hypers, n_iter=50, n_jobs=-1, scoring=make_scorer(accuracy_score), cv=inner_cv, random_state=0)

        train_start = datetime.now()
        rsg.fit(X=x_train, y=y_train)
        train_time = _format_duration(datetime.now() - train_start)

        # Inference time is measured on 1000 training rows drawn with replacement.
        inference_indexes = np.random.choice(range(len(x_train)), size=(1000,))
        inference_start = datetime.now()
        _ = rsg.predict(x_train[inference_indexes])
        inference_time = _format_duration(datetime.now() - inference_start)

        results = get_scores(rsg, x_test, y_test)
        results['Dataset Name'] = data_name
        results['Algorithm Name'] = 'VAT'
        results['Cross Validation'] = cv_counter
        results['Hyperparameters Values'] = rsg.best_params_
        results['Training Time'] = train_time
        results['Inference Time'] = inference_time

        score_rows.append(results)
        scores_dfs = pd.DataFrame(score_rows, columns=result_columns)
        displayer.update(scores_dfs)

`Saving the results DataFrame.`

In [18]:
# Create the output folder first: `to_csv` does not create missing
# directories, and failing here would lose the results of the whole run.
results_path = Path('results') / 'VAT_results.csv'
results_path.parent.mkdir(parents=True, exist_ok=True)
scores_dfs.to_csv(results_path, index=False)