In [None]:
import os
import csv
import pickle
import random
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchtext.vocab import build_vocab_from_iterator

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune import ExperimentAnalysis

In [None]:
def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Covers Python's `random`, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for deterministic runs.

    Args:
        seed: the seed value applied to all generators.
    """
    random.seed(seed)
    # only affects interpreters started after this point (e.g. worker
    # processes); the hash seed of the running interpreter is already fixed
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # seeds every visible GPU; silently ignored when CUDA is unavailable
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode auto-tunes the convolution algorithm per input shape and
    # may select different kernels between runs, which defeats determinism
    torch.backends.cudnn.benchmark = False


seed_everything(1234)

In [None]:
def read_data(path):
    """Read a two-column CSV file (with a header row) into parallel lists.

    Args:
        path: location of the CSV file; the first row is treated as column
            names and discarded.

    Returns:
        A tuple `(sents, lbls)` of two equally long lists holding the first
        and second column of every data row, as strings.

    Raises:
        ValueError: if a non-empty row does not have exactly two fields.
    """
    sents, lbls = [], []
    # newline='' is what the csv module requires so that line breaks inside
    # quoted fields are passed through to the parser untouched
    with open(path, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip col name (no-op on an empty file)
        for row in reader:
            if not row:
                # blank lines (e.g. trailing newline at end of file) carry no data
                continue
            s, l = row
            sents.append(s)
            lbls.append(l)
    return sents, lbls

class CleavageDataset(Dataset):
    """Map-style dataset that pairs each sequence with its label by position."""

    def __init__(self, seq, lbl):
        # both containers are stored as given and indexed in parallel
        self.seq, self.lbl = seq, lbl

    def __len__(self):
        # the label container defines how many samples exist
        return len(self.lbl)

    def __getitem__(self, idx):
        sample = self.seq[idx]
        target = self.lbl[idx]
        return sample, target
    
def collate_batch(batch):
    """Turn a list of (sequence, label) samples into two batched tensors.

    Sequences are converted to index lists with the module-level `encode_text`
    and stacked into an int64 tensor; labels are parsed with `int` and stored
    as a float tensor.

    Returns:
        `(seq, lbl)` -- the encoded sequences and the labels of the batch.
    """
    # transpose [(s1, l1), (s2, l2), ...] into ((s1, s2, ...), (l1, l2, ...))
    columns = tuple(zip(*batch))
    seq = torch.tensor(list(map(encode_text, columns[0])), dtype=torch.int64)
    lbl = torch.tensor(list(map(int, columns[1])), dtype=torch.float)
    return seq, lbl

def regularized_acc(train_acc, dev_acc, threshold=0.005):
    """
    Development accuracy, zeroed out once the model overfits.

    The gap `train_acc - dev_acc` is compared against `threshold`: a gap
    strictly below it yields `dev_acc` unchanged, anything else yields 0.
    """
    gap = train_acc - dev_acc
    if gap < threshold:
        return dev_acc
    return 0

In [None]:
class QuadBiLSTM(nn.Module):
    """Binary classifier: embedding -> four stacked BiLSTMs -> mean pooling -> MLP.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        rnn_size1..rnn_size4: hidden size (per direction) of each BiLSTM layer.
        hidden_size: width of the fully connected layer before the output.
        dropout: dropout probability applied after the embedding and after
            the hidden fully connected layer.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_size1, rnn_size2, rnn_size3, rnn_size4, hidden_size, dropout):
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

        self.dropout = nn.Dropout(dropout)

        # every layer is bidirectional, so its output feature size is twice
        # its hidden size; that doubled size is the next layer's input size
        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=rnn_size1,
            bidirectional=True,
            batch_first=True
        )

        self.lstm2 = nn.LSTM(
            input_size=2*rnn_size1,
            hidden_size=rnn_size2,
            bidirectional=True,
            batch_first=True
        )

        self.lstm3 = nn.LSTM(
            input_size=2*rnn_size2,
            hidden_size=rnn_size3,
            bidirectional=True,
            batch_first=True
        )

        self.lstm4 = nn.LSTM(
            input_size=2*rnn_size3,
            hidden_size=rnn_size4,
            bidirectional=True,
            batch_first=True
        )

        self.fc1 = nn.Linear(rnn_size4 * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, seq):
        """Return one raw logit per sequence.

        Args:
            seq: integer token indices of shape (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size,) holding unnormalised scores.
        """
        # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        embedded = self.dropout(self.embedding(seq))

        # -> (batch_size, seq_len, 2*rnn_size1)
        out, _ = self.lstm1(embedded)

        # -> (batch_size, seq_len, 2*rnn_size2)
        out, _ = self.lstm2(out)

        # -> (batch_size, seq_len, 2*rnn_size3)
        out, _ = self.lstm3(out)

        # -> (batch_size, seq_len, 2*rnn_size4)
        out, _ = self.lstm4(out)

        # average over the sequence dimension -> (batch_size, 2*rnn_size4)
        pooled = torch.mean(out, dim=1)

        # -> (batch_size, hidden_size)
        out = self.dropout(F.relu(self.fc1(pooled)))

        # (batch_size, 1) -> (batch_size,); only the trailing dimension is
        # squeezed so that a batch holding a single sample keeps its batch
        # dimension instead of collapsing to a 0-dim tensor
        out = self.fc2(out).squeeze(-1)
        return out # shape: (batch_size)

In [None]:
def process(model, loader, criterion, optim=None):
    """Run one pass over `loader`, updating the model when an optimizer is given.

    Args:
        model: module mapping a batch of sequences to one logit per sample.
        loader: iterable yielding `(seq, lbl)` tensor pairs.
        criterion: loss callable taking `(scores, lbl)`. If it exposes a
            `reduction` attribute, that attribute decides how batch losses are
            aggregated; otherwise mean reduction is assumed.
        optim: optional optimizer; when provided, a gradient step is taken
            for every batch. When `None`, no parameters are updated.

    Returns:
        `(loss, acc)` where `loss` is the average loss per sample (float) and
        `acc` is the fraction of samples whose logit sign matches the label
        (0-dim tensor).

    Raises:
        ZeroDivisionError: if `loader` yields no samples.
    """
    epoch_loss, num_correct, total = 0, 0, 0

    # send every batch to wherever the model's parameters live instead of
    # relying on a global device variable
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    # a mean-reduced loss has to be re-weighted by the batch size, otherwise
    # dividing by the number of samples at the end under-reports the loss
    mean_reduced = getattr(criterion, 'reduction', 'mean') == 'mean'

    for seq, lbl in loader:
        seq, lbl = seq.to(target_device), lbl.to(target_device)
        batch_size = len(seq)

        scores = model(seq)
        loss = criterion(scores, lbl)

        if optim is not None:
            optim.zero_grad()
            loss.backward()
            optim.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss * batch_size if mean_reduced else batch_loss
        # a positive logit corresponds to predicting the positive class
        num_correct += ((scores > 0) == lbl).sum()
        total += batch_size
    return epoch_loss / total, num_correct / total

In [None]:
# read the training and development splits (sequences plus their labels)
train_seqs, train_lbl = read_data('../../data/n_train.csv')
dev_seqs, dev_lbl = read_data('../../data/n_val.csv')

# the vocabulary is fitted on the training sequences only
vocab = build_vocab_from_iterator(train_seqs, specials=['<UNK>'])
# tokens missing from the vocabulary are mapped to the <UNK> index
vocab.set_default_index(vocab['<UNK>'])


def encode_text(x):
    """Split `x` into its individual items and look each one up in `vocab`."""
    return vocab(list(x))

In [None]:
NUM_EPOCHS = 5
VOCAB_SIZE = len(vocab)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train(config, checkpoint_dir=None):
    """Trainable passed to `tune.run`: fit one QuadBiLSTM for NUM_EPOCHS epochs.

    After every epoch the model weights are checkpointed and the train/dev
    accuracy, the regularized accuracy and both losses are reported to Tune.

    Args:
        config: hyperparameters for this trial; must provide 'embedding_dim',
            'rnn_size1'..'rnn_size4', 'hidden_size', 'dropout' and 'lr'.
        checkpoint_dir: accepted as part of the trainable signature; this
            function does not restore any state from it.
    """
    # create train and dev loader
    train_data = CleavageDataset(train_seqs, train_lbl)
    train_loader = DataLoader(train_data, batch_size=512, shuffle=True, collate_fn=collate_batch, num_workers=8)

    # evaluation metrics are aggregated over the whole split, so the order of
    # the dev samples is irrelevant and shuffling them is wasted work
    dev_data = CleavageDataset(dev_seqs, dev_lbl)
    dev_loader = DataLoader(dev_data, batch_size=512, shuffle=False, collate_fn=collate_batch, num_workers=8)

    model = QuadBiLSTM(
        vocab_size=VOCAB_SIZE,
        embedding_dim=config['embedding_dim'],
        rnn_size1=config['rnn_size1'],
        rnn_size2=config['rnn_size2'],
        rnn_size3=config['rnn_size3'],
        rnn_size4=config['rnn_size4'],
        hidden_size=config['hidden_size'],
        dropout=config['dropout']
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.BCEWithLogitsLoss()

    # normal train loop
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss, train_acc = process(model, train_loader, criterion, optimizer)

        model.eval()
        with torch.no_grad():
            val_loss, val_acc = process(model, dev_loader, criterion)

        # a distinct name keeps the `checkpoint_dir` argument from being shadowed
        with tune.checkpoint_dir(step=epoch) as epoch_checkpoint_dir:
            checkpoint_path = os.path.join(epoch_checkpoint_dir, "checkpoint.pt")
            torch.save(model.state_dict(), checkpoint_path)

        # report plain Python floats rather than tensors / 0-dim arrays
        train_acc = float(train_acc)
        val_acc = float(val_acc)
        reg_acc = regularized_acc(train_acc, val_acc)
        tune.report(train_acc=train_acc, dev_acc=val_acc, reg_acc=reg_acc, train_loss=train_loss, dev_loss=val_loss)

In [None]:
class TuneReporter(CLIReporter):
    """CLI reporter that prints progress only when more trials have terminated."""

    def __init__(self):
        super().__init__()
        # count of TERMINATED trials observed by the previous should_report call
        self.num_terminated = 0

    def should_report(self, trials, done=False):
        """Reports only on trial termination events."""
        previous = self.num_terminated
        current = sum(1 for trial in trials if trial.status == "TERMINATED")
        self.num_terminated = current
        return current > previous

    def report(self, trials, done, *sys_info):
        """Print the progress string built by the parent reporter."""
        progress = self._progress_str(trials, done, *sys_info)
        print(progress)
    
reporter = TuneReporter()
# add one column per metric, in the order they should appear
for metric_name in ('train_acc', 'dev_acc', 'reg_acc', 'train_loss', 'dev_loss'):
    reporter.add_metric_column(metric=metric_name)

In [None]:
# every architecture hyperparameter is pinned to a single candidate value;
# the learning rate is the only entry drawn from a range
search_space = dict(
    embedding_dim=tune.choice([128]),
    rnn_size1=tune.choice([128]),
    rnn_size2=tune.choice([512]),
    rnn_size3=tune.choice([256]),
    rnn_size4=tune.choice([128]),
    hidden_size=tune.choice([128]),
    dropout=tune.choice([0.5]),
    lr=tune.qloguniform(1e-4, 1e-1, 5e-5),
)

In [None]:
path = '../../params/n_term/quadBiLSTM/'
experiment = 'search'
num_samples = 1000

# trials are compared on the regularized accuracy, higher being better
asha = ASHAScheduler(metric='reg_acc', mode='max')

analysis = tune.run(
    train,
    config=search_space,
    num_samples=num_samples,
    scheduler=asha,
    progress_reporter=reporter,
    # results are written below `path` under the experiment name
    local_dir=path,
    name=experiment,
    sync_config=tune.SyncConfig(syncer=None),
    keep_checkpoints_num=None, # keeps all checkpoints
    checkpoint_score_attr='reg_acc',
    resources_per_trial={'cpu': 16, 'gpu': 1},
)

In [None]:
# join the parts instead of concatenating strings, so the experiment directory
# is resolved correctly whether or not `path` ends with a separator
ana = ExperimentAnalysis(os.path.join(path, experiment))

In [None]:
# tabular view of the experiment results
df = ana.dataframe()
# keep only the columns whose name carries the 'config/' prefix
cols_needed = list(filter(lambda col: col.startswith('config/'), df.columns))

In [None]:
# metrics first, then the config columns; best regularized accuracy on top
report_cols = ['train_acc', 'dev_acc', 'reg_acc', 'training_iteration', *cols_needed]
df[report_cols].sort_values(by='reg_acc', ascending=False)

In [None]:
### a curve that suddenly drops to zero marks the point where the train/dev accuracy gap reached the threshold (default=0.005)
fig, ax = plt.subplots(figsize=(16, 9))
# draw the regularized accuracy of every trial onto the same axes
for trial_df in ana.trial_dataframes.values():
    trial_df.reg_acc.plot(ax=ax, legend=False)
ax.set_xlabel('Epochs')
ax.set_ylabel('Regularized Accuracy')
fig.suptitle('Overview of Hyperparameter Search: QuadBiLSTM (n_term)', fontsize=15, ha='center')
ax.set_title("In cases of sudden accuracy drops to zero, the model started to overfit stronger than the threshold (default=0.005)", fontsize=12, ha='center')
fig.tight_layout()
plt.show()