In [None]:
# Colab-only setup: mount Google Drive at /content/drive so files stored there
# can be reached through the local filesystem.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Switch the working directory to the project folder on Drive so that the
# relative paths used later in this notebook (./data, runs/, tex/figs) resolve there.
%cd /content/drive/MyDrive/vaccine/
import sys
# Put the same folder on the import path so the local `data`, `model` and
# `utils` modules imported further down can be found.
sys.path.append('/content/drive/MyDrive/vaccine/')

In [None]:
# Re-import edited modules automatically before each cell runs (autoreload mode 2),
# so changes to the local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import os
import socket
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from datetime import datetime
from scipy.cluster.hierarchy import dendrogram, linkage

In [None]:
from data import Tokenizer, Tokenizer2, BaselineDataset, EpitopeDataset, AntigenDataset, NpyDataset, EpitopeRawDataset
from model import RNN
from utils import predict, plot_roc_curve, plot_representations

In [None]:
RANDOM_SEED = 42


def set_seed(seed):
    """Seed the PyTorch, Python and NumPy random generators with ``seed``.

    Also switches cuDNN to its deterministic mode and turns off its
    benchmark auto-tuner.

    Args:
        seed: integer seed handed unchanged to each generator.
    """
    # Each library keeps its own generator, so every one is seeded in turn.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

In [None]:
def train_eval(model, dataloader, criterion, optimizer=None, scheduler=None, is_train=True, max_grad_norm=0.5):
    """Run one full pass over ``dataloader``, either training or evaluating ``model``.

    Args:
        model: module called as ``model(inputs)``; predictions are the argmax
            of its output along dim 1.
        dataloader: iterable yielding ``(inputs, labels)`` batches and exposing
            a ``dataset`` attribute whose length is used for averaging.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: optimizer to step after every batch; required when
            ``is_train`` is true, ignored otherwise.
        scheduler: optional scheduler stepped once per batch, right after the
            optimizer (training only).
        is_train: ``True`` to train (gradients on, ``model.train()``),
            ``False`` to evaluate (gradients off, ``model.eval()``).
        max_grad_norm: gradient-norm clipping threshold applied before each
            optimizer step.

    Returns:
        ``(loss, accuracy)`` as plain Python floats, each divided by
        ``len(dataloader.dataset)``. The per-batch loss is weighted by the
        batch size, which gives a per-sample mean when ``criterion`` returns
        a batch mean.

    Raises:
        ValueError: if ``is_train`` is true but no ``optimizer`` was given.
    """
    if is_train and optimizer is None:
        raise ValueError('train_eval requires an optimizer when is_train=True')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if is_train:
        model.train()
    else:
        model.eval()
    total_loss = 0.0
    total_correct = 0
    progress_bar = tqdm(dataloader, ascii=True)

    # Used as a context manager so the caller's grad mode is restored on exit;
    # a bare call would leave autograd globally disabled after an eval pass.
    with torch.set_grad_enabled(is_train):
        for batch_idx, batch in enumerate(progress_bar):
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            batch_loss = loss.item()
            total_loss += batch_loss * len(labels)
            progress_bar.set_description_str('Batch: {:d}, Loss: {:.4f}'.format((batch_idx+1), batch_loss))

            predictions = torch.argmax(outputs, dim=1)
            # .item() keeps the running count a Python int, so the returned
            # accuracy is a float rather than a (possibly GPU) tensor.
            total_correct += predictions.eq(labels).sum().item()

    num_samples = len(dataloader.dataset)
    return total_loss / num_samples, total_correct / num_samples

In [None]:
# Hyper-parameters for this run. `lr` and `l2` feed the optimizer (learning
# rate and weight decay), `batch_size` feeds the data loaders, and the
# remaining entries are forwarded to the RNN constructor.
hparams = dict(
    lr=0.005,
    l2=0.0001,
    batch_size=1024,
    emb_size=32,
    kernel_size=5,
    hidden_size=128,
    dropout=0.2,
    pooling=True,
)

In [None]:
# Reseed first: random_split below draws from the global torch generator,
# so this keeps the train/validation split identical across runs.
set_seed(RANDOM_SEED)

tokenizer = Tokenizer(max_len=40)
train_dataset_ = EpitopeDataset(
    'Positive_train.txt',
    'Negative_train.txt',
    tokenizer=tokenizer,
    data_dir='./data',
)
test_dataset = EpitopeDataset(
    'Positive_test.txt',
    'Negative_test.txt',
    tokenizer=tokenizer,
    data_dir='./data',
)

# Hold out a fixed number of training examples for validation.
valid_size = 1000
train_dataset, valid_dataset = random_split(
    train_dataset_,
    [len(train_dataset_) - valid_size, valid_size],
)

# random_split returns Subset wrappers, so the collate function is reached
# through `.dataset`; the test set is used directly.
train_loader = DataLoader(
    train_dataset,
    batch_size=hparams['batch_size'],
    shuffle=True,
    collate_fn=train_dataset.dataset.collate_fn,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=hparams['batch_size'],
    shuffle=False,
    collate_fn=valid_dataset.dataset.collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=hparams['batch_size'],
    shuffle=False,
    collate_fn=test_dataset.collate_fn,
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RNN(
    tokenizer,
    emb_size=hparams['emb_size'],
    kernel_size=hparams['kernel_size'],
    hidden_size=hparams['hidden_size'],
    dropout=hparams['dropout'],
    pooling=hparams['pooling'],
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=hparams['lr'], weight_decay=hparams['l2'], amsgrad=True)
# Attach only the scheduler that the training loop actually steps.
# Building several schedulers on one optimizer is not harmless: CyclicLR's
# constructor overwrites the optimizer's learning rate with its base_lr, so an
# unused CyclicLR(optimizer, 0.0001, ...) would make training start at 0.0001
# instead of hparams['lr']. To try StepLR, ReduceLROnPlateau or CyclicLR,
# replace this line (and the matching step() call) rather than adding another.
scheduler2 = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 30], gamma=0.5)
criterion = nn.CrossEntropyLoss()

In [None]:
# One TensorBoard run directory per execution, named by timestamp and host.
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join('runs', current_time + '_' + socket.gethostname())
writer = SummaryWriter(log_dir=log_dir)
model_path = os.path.join(log_dir, 'lstm.pt')
# Start from the worst possible values so the first epoch always writes a
# checkpoint; otherwise the reload after the loop fails when validation
# accuracy never rises above 0.
best_valid_loss = float('inf')
best_valid_acc = float('-inf')

try:
    for epoch_idx in range(50):
        train_loss, train_acc = train_eval(model, train_loader, criterion, optimizer)
        valid_loss, valid_acc = train_eval(model, valid_loader, criterion, is_train=False)
        # train_eval may hand back 0-d tensors; normalise to Python floats once
        # so comparison, printing and logging all see plain numbers.
        train_loss, train_acc = float(train_loss), float(train_acc)
        valid_loss, valid_acc = float(valid_loss), float(valid_acc)
        # Epoch-level schedule: stepped once per epoch, after the optimizer updates.
        scheduler2.step()

        print("Epoch {}".format(epoch_idx))
        print("Training Loss: {:.4f}. Validation Loss: {:.4f}. ".format(train_loss, valid_loss))
        print("Training Accuracy: {:.4f}. Validation Accuracy: {:.4f}. ".format(train_acc, valid_acc))
        writer.add_scalars('loss', {'train': train_loss, 'valid': valid_loss}, epoch_idx)
        writer.add_scalars('accuracy', {'train': train_acc, 'valid': valid_acc}, epoch_idx)

        # Model selection is by validation accuracy; the loss is recorded alongside it.
        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            best_valid_loss = valid_loss
            # Saves the whole pickled module (not just a state_dict); the later
            # cells rely on torch.load returning a ready-to-use model.
            torch.save(model, model_path)

    # Evaluate the best checkpoint, loaded onto the current device like every
    # other torch.load(model_path, ...) call in this notebook.
    model = torch.load(model_path, map_location=device)
    _, test_acc = train_eval(model, test_loader, criterion, is_train=False)
    test_acc = float(test_acc)
    print("Test Accuracy: {:.4f}. ".format(test_acc))
    writer.add_hparams(hparams, {'hparam/accuracy': test_acc})
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

In [None]:
# Reload the checkpoint written by the training cell onto the current device.
model = torch.load(model_path, map_location=device)
# predict() is a project helper; its result is unpacked here as a (labels, probs) pair per loader.
# Note that train_loader shuffles, so the training pair comes back in shuffled order.
train_labels, train_probs = predict(model, train_loader)
test_labels, test_probs = predict(model, test_loader)
# Build the figure from both splits (presumably train and test ROC curves on one plot) and export it as PDF.
# NOTE: savefig does not create missing folders; tex/figs must already exist.
figure = plot_roc_curve(train_labels, train_probs, test_labels, test_probs)
figure.savefig(os.path.join('tex', 'figs', 'roc.pdf'), bbox_inches='tight')

In [None]:
# Collects, batch by batch, the tensors fed into model.fc1 during prediction.
representations = []


def hook(module, input):
    """Forward pre-hook: stash the first positional input of the hooked module.

    Args:
        module: the module about to run (unused).
        input: tuple of positional inputs to that module.
    """
    representations.append(input[0].detach())


model = torch.load(model_path, map_location=device)
hook_handle = model.fc1.register_forward_pre_hook(hook)
try:
    test_labels, _ = predict(model, test_loader)
finally:
    # Detach the hook once the pass is done: `representations` is rebound to a
    # tensor below, so a still-registered hook would crash on `.append` the
    # next time this model runs a forward pass.
    hook_handle.remove()
# Merge the per-batch captures into one CPU tensor.
representations = torch.cat(representations).cpu()

In [None]:
# Select two fixed windows of 1000 examples each: the first 1000 and the 1000
# starting at the midpoint of the test set.
# NOTE(review): presumably one window per class, which assumes the test set is ordered
# by class with the second class starting at the midpoint and holds at least 1000
# examples on each side — confirm against EpitopeDataset.
indices = list(range(1000)) + list(range(len(test_labels) // 2, len(test_labels) // 2 + 1000))
# Plot the captured fc1 inputs for the selected examples (file name suggests a t-SNE view).
# Indexing with a Python list assumes test_labels is an array/tensor-like container — TODO confirm predict()'s return type.
figure = plot_representations(representations[indices], test_labels[indices])
# NOTE: savefig does not create missing folders; tex/figs must already exist.
figure.savefig(os.path.join('tex', 'figs', 'tsne1.pdf'), bbox_inches='tight')

In [None]:
# Load precomputed test-set representations from disk (file name suggests ESM embeddings — confirm).
# NOTE: torch.load unpickles the file, so only load files from a trusted source.
test_representations = torch.load('./data/test_esm.pkl')
# Reuses `indices` from the earlier cell; labels are taken from the dataset itself here,
# which assumes the stored representations are in the same order as test_dataset — TODO confirm.
figure = plot_representations(test_representations[indices], np.array(test_dataset.labels)[indices])
figure.savefig(os.path.join('tex', 'figs', 'tsne2.pdf'), bbox_inches='tight')

In [None]:
# Load the test features stored as .npy files through the project's NpyDataset helper.
acc_test_dataset = NpyDataset('Positive_test.npy', 'Negative_test.npy')
# Same example subset (`indices` from the earlier cell) and dataset labels as the other
# representation plots; assumes acc_test_dataset.data is ordered like test_dataset — TODO confirm.
figure = plot_representations(acc_test_dataset.data[indices], np.array(test_dataset.labels)[indices])
figure.savefig(os.path.join('tex', 'figs', 'tsne3.pdf'), bbox_inches='tight')

In [None]:
# Hierarchical clustering of the learned token embeddings, drawn as a dendrogram.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load(model_path, map_location=device)

# Drop the first three embedding rows (presumably special tokens — confirm against
# Tokenizer) so that one row remains for each of the 20 letters used as leaf labels.
mat = model.embedding.weight.detach()[3:].cpu().numpy()

# Complete-linkage clustering over the remaining embedding rows.
Z = linkage(mat, 'complete')

fig, ax = plt.subplots()
dn = dendrogram(Z, ax=ax, labels='ACDEFGHIKLMNPQRSTVWY')
fig.savefig(os.path.join('tex', 'figs', 'dg.pdf'), bbox_inches='tight')