# **Install Packages**

In [None]:
!pip install pytorch-lightning torchvision matplotlib tqdm

# **Import Packages**

In [None]:
import argparse
import datetime
import os
import random
import time
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
# from pytorch_lightning.profiler import SimpleProfiler
from pytorch_lightning.loggers import TensorBoardLogger
import torch.nn as nn
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
plt.style.use('ggplot')
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, f1_score
from tqdm import tqdm
tqdm.pandas()

In [None]:
# Repeats the tqdm import and `tqdm.pandas()` registration already done in the
# import cell above; redundant once that cell has run.
from tqdm import tqdm
tqdm.pandas()

# **Define Model Architectures**

In [None]:
class Linear(pl.LightningModule):
    """Two-layer perceptron for binary classification of flat feature vectors.

    ``forward`` flattens each sample to ``input_dim`` values (2000 by default,
    matching the TF-IDF ``max_features`` used by ``QuoraDataModule``), applies
    ``Linear -> ReLU -> Linear -> sigmoid`` and returns probabilities of shape
    ``(batch, 1)``. Training minimises binary cross-entropy on those probabilities.

    Keyword Args:
        lr: Adam learning rate; falls back to 1e-3 when missing or None.
        input_dim: Flattened input width (default 2000).
        hidden_dim: Hidden layer width (default 128).
        threshold: Probability above which a sample counts as positive when
            computing accuracy (default 0.5).
    """
    # https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    def __init__(self, **kwargs):
        super().__init__()
        lr = kwargs.get('lr')
        self.lr = 1e-3 if lr is None else lr
        self.decision_threshold = kwargs.get('threshold', 0.5)

        input_dim = kwargs.get('input_dim', 2000)
        hidden_dim = kwargs.get('hidden_dim', 128)
        self.l1 = nn.Linear(input_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, 1)
        self.loss = nn.BCELoss()
        self.save_hyperparameters()

        # Metric buffers, reset at the start of every validation / test run.
        self.losses = []
        self.accuracies = []

    def forward(self, x):
        """Return positive-class probabilities of shape ``(batch, 1)``."""
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        return torch.sigmoid(self.l2(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        probs = self(x)
        # Flatten both sides to (batch,) so BCELoss compares matching shapes.
        loss = self.loss(probs.reshape(-1), y.reshape(-1).to(probs.dtype))
        self.log('Training loss', loss)
        return loss

    def on_validation_start(self):
        self.losses = []
        self.accuracies = []

    def validation_step(self, batch, batch_idx):
        x, y = batch
        probs = self(x)
        loss = self.loss(probs.reshape(-1), y.reshape(-1).to(probs.dtype))

        self.accuracies.extend(self.accuracy(probs, y).cpu().tolist())
        self.losses.append(loss.item())
        return loss

    def on_validation_epoch_end(self):
        # Average over every sample / batch seen in this validation run.
        self.log('Validation loss', float(np.mean(self.losses)))
        self.log('Validation Accuracy', float(np.mean(self.accuracies)))

    def on_test_start(self):
        self.accuracies = []

    def test_step(self, batch, batch_idx):
        x, y = batch
        probs = self(x)
        acc = self.accuracy(probs, y)
        self.accuracies.extend(acc.cpu().tolist())
        return acc

    def on_test_epoch_end(self):
        self.log("Test Accuracy", float(np.mean(self.accuracies)))

    def accuracy(self, logits, y):
        """Return per-sample correctness (1.0 / 0.0) as a float32 tensor of shape ``(batch,)``.

        ``logits`` are the sigmoid probabilities from ``forward`` (a single
        column), so a sample is predicted positive when its probability exceeds
        ``self.decision_threshold``. ``argmax`` over one column is always 0 and
        would only measure the share of negative labels.
        """
        preds = (logits.reshape(-1) > self.decision_threshold).to(y.dtype)
        return torch.eq(preds, y.reshape(-1)).to(torch.float32)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
class Conv(pl.LightningModule):
    """Small CNN classifier: two conv-ReLU-maxpool stages followed by a linear head.

    With the defaults it expects single-channel 28x28 images and predicts 10
    classes (the layout that was previously hard-coded). ``forward`` returns raw
    class scores (logits) of shape ``(batch, num_classes)``; training uses
    ``F.cross_entropy`` on them.

    Keyword Args:
        lr: Adam learning rate; falls back to 1e-3 when missing or None.
        in_channels: Input image channels (default 1).
        num_classes: Number of output classes (default 10).
        input_size: Side length of the square input images (default 28).
    """
    def __init__(self, **kwargs):
        super().__init__()
        lr = kwargs.get('lr')
        self.lr = 1e-3 if lr is None else lr
        in_channels = kwargs.get('in_channels', 1)
        num_classes = kwargs.get('num_classes', 10)
        input_size = kwargs.get('input_size', 28)

        # kernel 5 / stride 1 / padding 2 preserves the spatial size; each pool halves it.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Two 2x poolings: 28 -> 14 -> 7 with the default input size.
        pooled = input_size // 2 // 2
        self.out = nn.Linear(32 * pooled * pooled, num_classes)

        self.save_hyperparameters()

        # Metric buffers, reset at the start of every validation / test run.
        self.losses = []
        self.accuracies = []

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the pooled feature maps to (batch, 32 * pooled * pooled) for the head.
        x = x.view(x.size(0), -1)
        return self.out(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('Training loss', loss)
        self.log('Training Accuracy', self.accuracy(logits, y).mean())
        return loss

    def on_validation_start(self):
        self.losses = []
        self.accuracies = []

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # raw scores; cross_entropy applies log-softmax itself
        loss = F.cross_entropy(logits, y)

        self.accuracies.extend(self.accuracy(logits, y).cpu().tolist())
        self.losses.append(loss.item())
        return loss

    def on_validation_epoch_end(self):
        # Average over every sample / batch seen in this validation run.
        self.log('Validation loss', float(np.mean(self.losses)))
        self.log('Validation Accuracy', float(np.mean(self.accuracies)))

    def on_test_start(self):
        self.accuracies = []

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        acc = self.accuracy(logits, y)
        self.accuracies.extend(acc.cpu().tolist())
        return acc

    def on_test_epoch_end(self):
        self.log("Test Accuracy", float(np.mean(self.accuracies)))

    def accuracy(self, logits, y):
        """Return per-sample correctness (1.0 / 0.0) of the argmax class as float32."""
        return torch.eq(torch.argmax(logits, -1), y).to(torch.float32)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# **Define DataModule**

In [None]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Map-style dataset pairing feature rows with labels.

    Args:
        annotations: Positionally indexable labels (e.g. a NumPy array), one per row.
        vectors: Positionally indexable feature rows. SciPy sparse matrices (such
            as ``TfidfVectorizer`` output) are densified per item; dense arrays
            and tensors are used as-is.
        transform: Optional callable applied to each raw feature row.
        target_transform: Optional callable applied to each raw label.

    Each item is ``(features, label)`` as float32 tensors. A sparse matrix row
    keeps its ``(1, n_features)`` shape.
    """

    def __init__(self, annotations, vectors, transform=None, target_transform=None):
        self.labels = annotations
        self.vectors = vectors
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def _to_float_tensor(value):
        """Return ``value`` (array, scalar or tensor) as a new float32 tensor."""
        if isinstance(value, torch.Tensor):
            return value.detach().clone().float()
        return torch.tensor(value).float()

    def __getitem__(self, idx):
        vector = self.vectors[idx]
        label = self.labels[idx]
        if self.transform:
            vector = self.transform(vector)
        if self.target_transform:
            label = self.target_transform(label)
        # Sparse rows must be densified before tensor conversion; dense data needs no step.
        if hasattr(vector, "toarray"):
            vector = vector.toarray()
        return self._to_float_tensor(vector), self._to_float_tensor(label)


class QuoraDataModule(pl.LightningDataModule):
    """Serve TF-IDF features of the question texts in ``<data_dir>/train.csv``.

    ``setup`` reads ``train.csv`` (columns ``question_text`` and ``target``) and
    makes a stratified train/validation split. It then vectorises the text with a
    ``TfidfVectorizer`` fitted on the training split only, so no vocabulary or
    IDF statistics leak from validation rows into training.
    ``test_dataloader`` reuses the validation split.

    Keyword Args:
        data_dir: Directory holding ``train.csv`` (default ``'.'``).
        batch_size: Samples per batch (default 32).
        num_workers: DataLoader worker processes (default 0).
        val_ratio: Fraction of rows held out for validation, strictly between
            0 and 1 (default 0.2).
        max_features: TF-IDF vocabulary size (default 2000). This fixes the
            feature width, so it must match the model's input width.
        seed: ``random_state`` used for the split (default 42).

    Raises:
        ValueError: If ``val_ratio`` is not strictly between 0 and 1.
    """

    def __init__(self, **kwargs):
        # https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
        super().__init__()

        def option(key, default):
            # Treat an explicit None the same as a missing key.
            value = kwargs.get(key)
            return default if value is None else value

        self.data_dir = option('data_dir', '.')
        self.batch_size = option('batch_size', 32)
        self.num_workers = option('num_workers', 0)
        self.val_ratio = option('val_ratio', 0.2)
        self.max_features = option('max_features', 2000)
        self.seed = option('seed', 42)

        # A stratified split needs at least one row on each side.
        if not 0 < self.val_ratio < 1:
            raise ValueError("[!] val_ratio should be in the range (0, 1).")

        self.vectorizer = None
        self.dataset_train = None
        self.dataset_val = None

    def setup(self, stage: str):
        # setup runs again for later stages (e.g. trainer.test after trainer.fit);
        # build the features once and reuse them.
        if self.dataset_train is not None:
            return

        train = pd.read_csv(os.path.join(self.data_dir, "train.csv"))
        questions = train["question_text"].fillna("")
        targets = train["target"]

        # Split the raw text first so the vectorizer only ever sees training rows.
        text_train, text_val, y_train, y_val = train_test_split(
            questions,
            targets,
            test_size=self.val_ratio,
            stratify=targets,
            random_state=self.seed,
        )

        # max_features caps the vocabulary and therefore the feature width.
        self.vectorizer = TfidfVectorizer(max_features=self.max_features)
        X_train = self.vectorizer.fit_transform(text_train)
        X_val = self.vectorizer.transform(text_val)

        self.dataset_train = CustomDataset(y_train.values, X_train)
        self.dataset_val = CustomDataset(y_val.values, X_val)

    def train_dataloader(self):
        # Shuffle so each epoch visits the training rows in a fresh order.
        return DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.dataset_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        # No separate labelled test split is built; evaluate on the validation split.
        return DataLoader(self.dataset_val, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)


# **Define Training Configuration**

In [None]:
# Training configuration shared by the model, the data module and the Trainer.
dict_args = dict(
    model=Linear,                 # LightningModule class to train
    dataloader=QuoraDataModule,   # LightningDataModule class that supplies the data
    load=None,                    # checkpoint path for load_from_checkpoint, or None to start fresh
    resume_from_checkpoint=None,  # currently unused
    data_dir='./',
    batch_size=32,
    epoch=3,                      # passed as Trainer(max_epochs=...)
    num_workers=0,
    val_freq=0.5,                 # passed as Trainer(val_check_interval=...)
    logdir='./logs',              # root folder for checkpoints and TensorBoard logs
    lr=0.001,
    display_freq=64,              # passed as Trainer(log_every_n_steps=...)
    seed=42,                      # passed to pl.seed_everything
    clip_grad_norm=0,             # passed as Trainer(gradient_clip_val=...)
    val_ratio=0.2,                # held-out validation fraction, read by the data module
)

In [None]:
# Trainer setup: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html
# Seed first so model initialisation and everything after it is reproducible.
pl.seed_everything(dict_args['seed'])

# Model: restore from a checkpoint when a path is given, otherwise build a fresh one.
if dict_args['load'] is None:
    model = dict_args['model'](**dict_args)
else:
    model = dict_args['model'].load_from_checkpoint(dict_args['load'], **dict_args)

# Every run gets its own timestamped folder for saved weights.
now = datetime.datetime.now().strftime('%m%d%H%M%S')
weight_save_dir = os.path.join(dict_args["logdir"], 'models', 'state_dict', now)
os.makedirs(weight_save_dir, exist_ok=True)

# Keep the 5 checkpoints with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath=weight_save_dir,
    monitor="Validation loss",
    mode="min",
    save_top_k=5,
    verbose=True,
)

# TODO: Implement early stopping based on validation loss or accuracy --> optional

data_module = dict_args['dataloader'](**dict_args)

logger = TensorBoardLogger(
    save_dir=dict_args['logdir'],
    name='lightning_logs',
    version=now,
    log_graph=True,
)
# NOTE(review): dict_args['resume_from_checkpoint'] is not wired into the Trainer;
# recent Lightning versions take it as trainer.fit(..., ckpt_path=...) — confirm for
# the installed version before enabling.
trainer = pl.Trainer(
    max_epochs=dict_args["epoch"],
    val_check_interval=dict_args['val_freq'],
    log_every_n_steps=dict_args["display_freq"],
    gradient_clip_val=dict_args['clip_grad_norm'],
    deterministic=True,
    callbacks=[checkpoint_callback],
    logger=logger,
)

# **Train the model**

In [None]:
# Train `model` with the data module's train/validation loaders.
trainer.fit(model=model, datamodule=data_module)

# **Test the model**

In [None]:
# Evaluate the best checkpoint kept by the checkpoint callback on the data module's test loader.
trainer.test(model=model, datamodule=data_module, ckpt_path="best")

In [None]:
def threshold_search(y_true, y_proba, thresholds=None):
    """Find the probability cut-off that maximises F1.

    Args:
        y_true: Binary ground-truth labels (list, NumPy array or torch tensor).
        y_proba: Predicted positive-class probabilities, one per sample; any
            shape that flattens to that length (e.g. ``(n,)`` or ``(n, 1)``).
        thresholds: Candidate cut-offs, tried in order; defaults to
            0.00, 0.01, ..., 0.99. A sample is predicted positive when its
            probability is strictly greater than the cut-off.

    Returns:
        ``{'threshold': best_cutoff, 'f1': best_f1}``. Ties keep the earliest
        cut-off; if no cut-off scores above 0, ``threshold`` stays 0.
    """
    def as_flat_array(values):
        # Tensors may require grad or live on a GPU; detach and copy to CPU first.
        if hasattr(values, 'detach'):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    y_true = as_flat_array(y_true)
    y_proba = as_flat_array(y_proba)
    if thresholds is None:
        thresholds = [i * 0.01 for i in range(100)]

    best_threshold = 0
    best_score = 0
    for threshold in tqdm(thresholds):
        # zero_division=0 scores "no positive predictions" as F1 = 0 without a warning.
        score = f1_score(y_true=y_true, y_pred=y_proba > threshold, zero_division=0)
        if score > best_score:
            best_threshold = threshold
            best_score = score
    return {'threshold': best_threshold, 'f1': best_score}

# **Visualize Results**

In [None]:
# Run the trained model over the WHOLE held-out loader (eval mode, no gradients).
# Every label/probability is collected for the metrics cell; up to `max_wrong`
# misclassified samples are kept for the plot cell.
test_dataloader = data_module.test_dataloader()
eval_model = trainer.model
eval_model.eval()
device = next(eval_model.parameters()).device
max_wrong = 36            # the plot cell draws a 6 x 6 grid
decision_threshold = 0.5  # cut-off turning sigmoid probabilities into 0/1 predictions

all_y_s = []
all_y_preds = []
wrong_x_chunks, wrong_pred_chunks, wrong_label_chunks = [], [], []
n_wrong = 0

with torch.no_grad():
    for x, y in test_dataloader:
        probs = eval_model(x.to(device)).cpu()
        all_y_s.extend(y.tolist())
        # (batch, 1) probabilities become one float per sample.
        all_y_preds.extend(probs.squeeze(-1).tolist())

        if n_wrong >= max_wrong:
            continue  # enough examples collected; keep scoring for the metrics
        if probs.dim() > 1 and probs.size(-1) > 1:
            pred = probs.argmax(dim=-1).to(y.dtype)  # multi-class scores
        else:
            # argmax over a single column is always 0, so threshold instead.
            pred = (probs.reshape(-1) > decision_threshold).to(y.dtype)
        wrong = pred != y
        if wrong.any():
            wrong_x_chunks.append(x[wrong])
            wrong_pred_chunks.append(pred[wrong])
            wrong_label_chunks.append(y[wrong])
            n_wrong += int(wrong.sum())

wrong_preds_x = torch.cat(wrong_x_chunks) if wrong_x_chunks else torch.empty(0)
wrong_preds_y = torch.cat(wrong_pred_chunks) if wrong_pred_chunks else torch.empty(0)
wrong_preds_label = torch.cat(wrong_label_chunks) if wrong_label_chunks else torch.empty(0)

print('\n')
print(len(all_y_s))

In [None]:
# Score the collected probabilities: best F1 over the threshold grid, plus ROC AUC.
y_scores = np.array(all_y_preds)
f1 = threshold_search(all_y_s, y_scores)
auc = roc_auc_score(all_y_s, y_scores)

print('\n')
print("[Text Classification using Neural Network]")
print("   1) F1 score: ", f1)
print(f"   2) ROC AUC score: {auc:.2f}")

In [None]:
# Show up to 36 misclassified samples in a 6 x 6 grid, titled with true label and prediction.
n_rows, n_cols = 6, 6
fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 15))
# Never index past what was actually collected (there may be fewer than 36 errors).
n_shown = min(n_rows * n_cols, len(wrong_preds_x), len(wrong_preds_y), len(wrong_preds_label))
for idx, axis in enumerate(ax.flat):
    axis.axis('off')
    if idx >= n_shown:
        continue
    img = np.squeeze(wrong_preds_x[idx].detach().cpu().numpy())
    # A flat feature row (e.g. a (1, 2000) TF-IDF vector) squeezes to 1-D, which imshow
    # rejects; draw it as a one-row strip stretched across the axis instead.
    img = np.atleast_2d(img)
    axis.imshow(img, cmap='gray', aspect='auto' if img.shape[0] == 1 else 'equal')
    axis.set_title(f'Label: {wrong_preds_label[idx].item():g}, Pred: {wrong_preds_y[idx].item():g}')

plt.show()