In [1]:
import os
import sys
import csv
import pickle
import random
import numpy as np
from time import time
from tqdm import tqdm

from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler

from tokenizers import ByteLevelBPETokenizer

In [2]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


seed_everything(1234)

# avoids parallelism errors when both tokenizers and torch dataloaders use multiprocessing 
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

In [3]:
# def read_data(path):
#     with open(path, 'r') as csvfile:
#         train_data = list(csv.reader(csvfile))[1:] # skip col name
#         sents, lbls = [], []
#         for i in range(0, len(train_data), 16):
#             s, l = zip(*train_data[i:i+16])
#             sents.append(s)
#             lbls.append(l)
#     return sents, lbls


def read_data(path):
    with open(path, "r") as csvfile:
        train_data = list(csv.reader(csvfile))[1:]  # skip col name
        sents, lbls = [], []
        for s, l in train_data:
            sents.append(s)
            lbls.append(l)
    return sents, lbls


# number of trainable parameters in model
def get_total_model_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [4]:
class CleavageDataset(Dataset):
    def __init__(self, seq, lbl):
        self.seq = seq
        self.lbl = lbl

    def __getitem__(self, idx):
        return self.seq[idx], self.lbl[idx]

    def __len__(self):
        return len(self.lbl)


class BucketSampler(Sampler):
    def __init__(self, seqs, batch_size):

        # pair each sequence with their *tokenized* length
        indices = [(idx, len(tokenizer.encode(s).ids)) for idx, s in enumerate(seqs)]
        random.shuffle(indices)

        idx_pools = []
        # generate pseudo-random batches of (arbitrary) size batch_size * 100
        # each batch of size batch_size * 100 is sorted in itself by seq length
        for i in range(0, len(indices), batch_size * 100):
            idx_pools.extend(
                sorted(indices[i : i + batch_size * 100], key=lambda x: x[1])
            )

        # filter only indices
        self.idx_pools = [x[0] for x in idx_pools]

    def __iter__(self):
        return iter(self.idx_pools)

    def __len__(self):
        return len(self.idx_pools)


class CleavageBatch:
    def __init__(self, batch: List[Tuple[str, str]]):
        ordered_batch = list(zip(*batch))
        self.seq = torch.tensor(
            [s.ids for s in tokenizer.encode_batch(ordered_batch[0])], dtype=torch.int64
        )
        self.lbl = torch.tensor([int(l) for l in ordered_batch[1]], dtype=torch.float)
        self.lengths = torch.tensor([self.seq.shape[1]] * self.seq.shape[0], dtype=torch.int64)

    def pin_memory(self):
        self.seq = self.seq.pin_memory()
        self.lbl = self.lbl.pin_memory()
        return self


def collate_wrapper(batch):
    return CleavageBatch(batch)

In [5]:
class QuadBiLSTM(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_dim,
        rnn_size1,
        rnn_size2,
        rnn_size3,
        rnn_size4,
        hidden_size,
        dropout,
    ):
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0
        )

        self.dropout = nn.Dropout(dropout)

        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=rnn_size1,
            bidirectional=True,
            batch_first=True,
        )

        self.lstm2 = nn.LSTM(
            input_size=2 * rnn_size1,
            hidden_size=rnn_size2,
            bidirectional=True,
            batch_first=True,
        )

        self.lstm3 = nn.LSTM(
            input_size=2 * rnn_size2,
            hidden_size=rnn_size3,
            bidirectional=True,
            batch_first=True,
        )

        self.lstm4 = nn.LSTM(
            input_size=2 * rnn_size3,
            hidden_size=rnn_size4,
            bidirectional=True,
            batch_first=True,
        )

        self.fc1 = nn.Linear(rnn_size4 * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, seq, lengths):
        # input shape: (batch_size, seq_len=10)
        embedded = self.dropout(self.embedding(seq))

        packed_embeddings = pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )

        # input shape: (batch_size, seq_len, embedding_dim)
        out, _ = self.lstm1(packed_embeddings)

        # input shape: (batch_size, seq_len, 2*rnn_size1)
        out, _ = self.lstm2(out)

        # input shape: (batch_size, seq_len, 2*rnn_size2)
        out, _ = self.lstm3(out)

        # input shape: (batch_size, seq_len, 2*rnn_size3)
        out, _ = self.lstm4(out)

        unpacked_output, _ = pad_packed_sequence(out, batch_first=True, padding_value=0)

        # input shape: (batch_size, seq_len, 2*hidden_size)
        pooled = torch.mean(unpacked_output, dim=1)

        # input shape; (batch_size, 2*hidden_size)
        out = self.dropout(F.relu(self.fc1(pooled)))

        # input shape: (batch_size, hidden_size)
        # output shape: (batch_size)
        return self.fc2(out).squeeze()

In [6]:
def process(model, loader, criterion, optim=None):
    epoch_loss, num_correct, total = 0, 0, 0

    for batch in tqdm(
        loader,
        desc="Train: " if optim is not None else "Eval: ",
        file=sys.stdout,
        unit="batches",
    ):
        seq, lbl, lengths = batch.seq, batch.lbl, batch.lengths
        seq, lbl = seq.to(device), lbl.to(device)

        scores = model(seq, lengths)
        loss = criterion(scores, lbl)

        if optim is not None:
            optim.zero_grad()
            loss.backward()
            optim.step()

        epoch_loss += loss.item()
        num_correct += ((scores > 0) == lbl).sum()
        total += len(seq)
    return epoch_loss / total, num_correct / total

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vocab_file = "../../preprocessing/bbpe_params/vocab_10k/n_term/vocab.json"
merge_file = "../../preprocessing/bbpe_params/vocab_10k/n_term/merges.txt"

# tokenizer serves as vocab at the same time
tokenizer = ByteLevelBPETokenizer.from_file(vocab_file, merge_file)
tokenizer.enable_padding(pad_token="<PAD>")

# load train and dev data
train_seqs, train_lbl = read_data("../../data/n_train.csv")
dev_seqs, dev_lbl = read_data("../../data/n_val.csv")

In [8]:
NUM_EPOCHS = 10
BATCH_SIZE = 512
VOCAB_SIZE = tokenizer.get_vocab_size()
EMBEDDING_DIM = 300
RNN_SIZE1 = 128
RNN_SIZE2 = 512
RNN_SIZE3 = 256
RNN_SIZE4 = 128
HIDDEN_SIZE = 128
DROPOUT = 0.5
LEARNING_RATE = 1e-4

model = QuadBiLSTM(
    vocab_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    rnn_size1=RNN_SIZE1,
    rnn_size2=RNN_SIZE2,
    rnn_size3=RNN_SIZE3,
    rnn_size4=RNN_SIZE4,
    hidden_size=HIDDEN_SIZE,
    dropout=DROPOUT,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

# create train and dev loader
train_data = CleavageDataset(train_seqs, train_lbl)
train_bucket_sampler = BucketSampler(train_seqs, BATCH_SIZE)
train_sampler = BatchSampler(train_bucket_sampler, BATCH_SIZE, drop_last=True)
train_loader = DataLoader(
    train_data,
    batch_sampler=train_sampler,
    collate_fn=collate_wrapper,
    pin_memory=True,
    num_workers=10,
)

dev_data = CleavageDataset(dev_seqs, dev_lbl)
dev_bucket_sampler = BucketSampler(dev_seqs, BATCH_SIZE)
dev_sampler = BatchSampler(dev_bucket_sampler, BATCH_SIZE, drop_last=True)
dev_loader = DataLoader(
    dev_data,
    batch_sampler=dev_sampler,
    collate_fn=collate_wrapper,
    pin_memory=True,
    num_workers=10,
)


print(f"Total trainable model parameters: {get_total_model_params(model):,}")

Total trainable model parameters: 9,910,209


In [9]:
start = time()
print("Starting Training.")
highest_val_acc = 0
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss, train_acc = process(model, train_loader, criterion, optimizer)

    model.eval()
    with torch.no_grad():
        val_loss, val_acc = process(model, dev_loader, criterion)

    # save current acc, loss
    train_losses.append((epoch, train_loss))
    train_accuracies.append((epoch, train_acc))
    val_losses.append((epoch, val_loss))
    val_accuracies.append((epoch, val_acc))

    if val_acc > highest_val_acc:
        highest_val_acc = val_acc
        path = f"../../params/n_term/quadBiLSTM/acc{val_acc:.4f}_epoch{epoch}.pt"
        torch.save(model.state_dict(), path)

    print(
        f"Training:   [Epoch {epoch:2d}, Loss: {train_loss:8.4f}, Acc: {train_acc:.4f}]"
    )
    print(f"Evaluation: [Epoch {epoch:2d}, Loss: {val_loss:8.4f}, Acc: {val_acc:.4f}]")

print("Finished Training.")
train_time = (time() - start) / 60
print(f"Training took {train_time} minutes.")

Starting Training.
Train: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2235/2235 [00:15<00:00, 143.95batches/s]
Eval: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 279/279 [00:00<00:00, 347.12batches/s]
Training:   [Epoch  1, Loss:   0.0013, Acc: 0.5812]
Evaluation: [Epoch  1, Loss:   0.0013, Acc: 0.5920]
Train: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2235/2235 [00:15<00:00, 144.97batches/s]
Eval: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 279/279 [00:00<00:00, 356.56batches/s]
Training:   [Epoch  2, Lo

In [None]:
# save training stats
lsts = [train_losses, train_accuracies, val_losses, val_accuracies, train_time]
names = [
    "train_losses",
    "train_accuracies",
    "val_losses",
    "val_accuracies",
    "train_time",
]
to_save = dict()
for name, lst in zip(names, lsts):
    to_save[name] = lst

with open(f"../params/n_term/quadBiLSTM/metrics.pkl", "wb") as f:
    pickle.dump(to_save, f, pickle.HIGHEST_PROTOCOL)

print("Finished Saving Details.")