In [None]:
# IDEA: use the ProtTrans pipeline as a feature extractor (instead of a learned embedding layer) and build the model on top of it.
# TODO: check whether our input format makes sense for that pipeline.
# Reference: https://github.com/agemagician/ProtTrans/blob/master/Embedding/PyTorch/Basic/ProtBert.ipynb

In [None]:
import os
import csv
import random
import time
import numpy as np
from collections import Counter
from itertools import chain

from typing import List, Tuple

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

from torchtext.vocab import build_vocab_from_iterator

In [None]:
def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and all CUDA devices), and switches cuDNN into deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only inherited by processes started after this point; the hash seed of
    # the already-running interpreter cannot be changed retroactively.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats the purpose of seeding, so it has to be disabled.
    torch.backends.cudnn.benchmark = False


seed_everything(1234)

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# def read_data(path):
#     with open(path, 'r') as csvfile:
#         train_data = list(csv.reader(csvfile))[1:] # skip col name
#         sents, lbls = [], []
#         for i in range(0, len(train_data), 16):
#             s, l = zip(*train_data[i:i+16])
#             sents.append(s)
#             lbls.append(l)
#     return sents, lbls

def read_data(path):
    """Read a two-column CSV file into parallel lists of sequences and labels.

    The first row is treated as the column header and skipped. Completely
    blank rows (for example a trailing empty line) are ignored.

    Args:
        path: Path to a CSV file whose data rows have exactly two fields:
            the sequence and its label.

    Returns:
        A ``(sents, lbls)`` tuple of two equally long lists of strings; the
        labels are returned exactly as they appear in the file.

    Raises:
        ValueError: If a non-blank data row does not have exactly two fields.
    """
    sents, lbls = [], []
    # newline='' is what the csv module expects, so that it performs its own
    # newline handling (needed for quoted fields containing line breaks).
    with open(path, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip the column names; tolerate an empty file
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; they carry no data.
                continue
            s, l = row
            sents.append(s)
            lbls.append(l)
    return sents, lbls

In [None]:
# Load the training sequences and their labels (as strings) from disk.
sents, lbl = read_data(path='../data/n_train.csv')

In [None]:
# Each sequence string is iterated element by element, so the vocabulary
# holds one entry per distinct character plus the '<UNK>' special token.
vocab = build_vocab_from_iterator(
    sents,
    specials=['<UNK>'],
)
# Anything not present in the vocabulary is looked up as '<UNK>'.
vocab.set_default_index(vocab['<UNK>'])

def encode_text(x):
    """Convert a sequence into a list of vocabulary indices.

    Args:
        x: Iterable of tokens; for a string, each character is one token.

    Returns:
        Whatever the module-level ``vocab`` returns for ``list(x)``, i.e. one
        index per element of ``x``.
    """
    return vocab(list(x))

In [None]:
class CleavageDataset(Dataset):
    """Map-style dataset that pairs each sequence with its label.

    Both containers are stored as given and indexed with the same key.
    """

    def __init__(self, seq, lbl):
        """Keep references to the sequence and label containers."""
        self.seq, self.lbl = seq, lbl

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.lbl)

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at ``idx``."""
        sequence = self.seq[idx]
        label = self.lbl[idx]
        return sequence, label

In [None]:
class CleavageBatch:
    """Container for one collated mini-batch.

    Attributes:
        seq: int64 tensor built from the encoded sequences of the batch.
        lbl: long tensor with one integer label per sample.
    """

    def __init__(self, batch: List[Tuple[str, str]]):
        """Encode the sequences and parse the labels of ``batch`` into tensors.

        Args:
            batch: List of ``(sequence, label)`` pairs, label given as a string.
        """
        # Transpose [(seq, lbl), ...] into (all sequences, all labels).
        columns = tuple(zip(*batch))
        raw_sequences, raw_labels = columns[0], columns[1]
        # NOTE: torch.tensor needs rows of equal length, so every sequence in
        # a batch has to encode to the same number of indices.
        encoded = list(map(encode_text, raw_sequences))
        targets = list(map(int, raw_labels))
        self.seq = torch.tensor(encoded, dtype=torch.int64)
        self.lbl = torch.tensor(targets, dtype=torch.long)

    def pin_memory(self):
        """Replace both tensors with pinned-memory copies and return ``self``."""
        self.seq, self.lbl = self.seq.pin_memory(), self.lbl.pin_memory()
        return self
    
def collate_wrapper(batch):
    """Collate function: wrap a list of samples in a ``CleavageBatch``."""
    collated = CleavageBatch(batch)
    return collated

In [None]:
# Wrap the raw sequences/labels in a dataset and batch them with the custom
# collate function; batches are shuffled and loaded by two worker processes.
data = CleavageDataset(seq=sents, lbl=lbl)
loader = DataLoader(
    dataset=data,
    batch_size=8,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_wrapper,
    pin_memory=True,
)

In [None]:
# Sanity check: print the tensors of the first batch only, then stop.
for sample in loader:
    print(sample.seq, sample.lbl, sep='\n')
    break

In [None]:
class QuadBiLSTM(nn.Module):
    """Four stacked bidirectional LSTMs over embedded tokens with a two-layer head.

    Pipeline: embedding -> BiLSTM x4 (dropout between consecutive LSTMs) ->
    concatenated final forward/backward hidden states of the last LSTM ->
    ``fc1`` -> ReLU -> ``fc2`` -> one raw logit per sequence.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding.
        rnn_size1: Hidden size (per direction) of the first LSTM.
        rnn_size2: Hidden size (per direction) of the second LSTM.
        rnn_size3: Hidden size (per direction) of the third LSTM.
        rnn_size4: Hidden size (per direction) of the fourth LSTM.
        dropout: Dropout probability applied between consecutive LSTMs.
        hidden_size: Width of the hidden fully connected layer.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_size1, rnn_size2, rnn_size3, rnn_size4, dropout, hidden_size):
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

        # nn.LSTM's own `dropout` argument only acts between the layers of a
        # multi-layer LSTM. On the single-layer LSTMs used here it is ignored
        # (PyTorch emits a UserWarning), so dropout is applied explicitly
        # between the four recurrent layers in forward() instead.
        self.dropout = nn.Dropout(dropout)

        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=rnn_size1,
            bidirectional=True,
            batch_first=True
        )

        # Each following LSTM consumes both directions of the previous one,
        # hence the doubled input size.
        self.lstm2 = nn.LSTM(
            input_size=2*rnn_size1,
            hidden_size=rnn_size2,
            bidirectional=True,
            batch_first=True
        )

        self.lstm3 = nn.LSTM(
            input_size=2*rnn_size2,
            hidden_size=rnn_size3,
            bidirectional=True,
            batch_first=True
        )

        self.lstm4 = nn.LSTM(
            input_size=2*rnn_size3,
            hidden_size=rnn_size4,
            bidirectional=True,
            batch_first=True
        )

        self.fc1 = nn.Linear(rnn_size4 * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, seq):
        """Compute one raw (pre-sigmoid) logit per input sequence.

        Args:
            seq: Integer tensor of token indices, shape ``(batch, seq_len)``.

        Returns:
            Float tensor of shape ``(batch,)`` holding unnormalised logits.
        """
        embedded = self.embedding(seq)  # (batch, seq_len, embedding_dim)

        out, _ = self.lstm1(embedded)
        out, _ = self.lstm2(self.dropout(out))
        out, _ = self.lstm3(self.dropout(out))
        _, (h_n, _) = self.lstm4(self.dropout(out))

        # For a single-layer bidirectional LSTM, h_n is (2, batch, rnn_size4):
        # index 0 is the forward direction's final state, index 1 the backward's.
        pooled = torch.cat((h_n[0], h_n[1]), dim=1)  # (batch, 2 * rnn_size4)

        hidden = torch.relu(self.fc1(pooled))
        return self.fc2(hidden).squeeze(-1)