In [30]:
from tdc.single_pred import Epitope
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, precision_recall_curve
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, precision_score, recall_score, accuracy_score
from copy import deepcopy
torch.manual_seed(1)
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt  
from sklearn.metrics import roc_curve, auc
from scipy import interp
from sklearn.metrics import roc_auc_score 

In [2]:
# Load the 'IEDB_Jespersen' epitope dataset through the TDC Epitope loader.
data =  Epitope(name = 'IEDB_Jespersen')
# get_split() is called with its defaults; the result is indexed later by
# the keys 'train', 'valid' and 'test'.
split = data.get_split()

Downloading...
100%|██████████| 2.18M/2.18M [00:00<00:00, 7.34MiB/s]
Loading...
Done!


In [3]:
# Pull the three partitions out of the split mapping, in train/valid/test order.
train_data, valid_data, test_data = (
    split[part] for part in ('train', 'valid', 'test')
)

In [13]:
# Name of the column holding the antigen sequence; data2vocab and
# standardize_data read this module-level name.
X = 'Antigen'

In [23]:
def data2vocab(data, antigen_col = None):
	"""Collect the residue vocabulary and the positive-label ratio of one split.

	Args:
		data: object indexable by column name whose antigen column yields
			sequences and whose 'Y' column yields, per row, a collection of
			positive position indices (assumed to be a pandas DataFrame
			produced by TDC -- TODO confirm).
		antigen_col: name of the antigen column. When None, the module-level
			name ``X`` is used, which is what existing callers rely on.

	Returns:
		A tuple ``(vocab_set, positive_ratio)``: the set of distinct symbols
		found in the antigen sequences, and the number of listed positive
		positions divided by the total sequence length. The ratio is 0.0
		when the split contains no residues at all.

	Raises:
		AssertionError: if a row lists a position that is not smaller than
			the length of its antigen.
	"""
	column = X if antigen_col is None else antigen_col
	vocab_set = set()
	total_length, positive_num = 0, 0
	# Iterate over values rather than looking rows up by label, so a split
	# whose index is not 0..n-1 is handled as well.
	for antigen, Y in zip(data[column], data['Y']):
		# In-place update avoids rebuilding the whole set for every row.
		vocab_set.update(antigen)
		# A row without any positive position has an empty Y; max() would
		# raise on it, so only check the bound when there is something to check.
		assert len(Y) == 0 or len(antigen) > max(Y), \
			"positive position index lies outside the antigen sequence"
		total_length += len(antigen)
		positive_num += len(Y)
	if total_length == 0:
		# Empty split (or only empty sequences): avoid dividing by zero.
		return vocab_set, 0.0
	return vocab_set, positive_num / total_length

In [24]:
# Per split: the set of symbols seen in the antigen sequences and the ratio
# of listed positive positions to total sequence length, as returned by data2vocab.
train_vocab, train_positive_ratio = data2vocab(train_data)
valid_vocab, valid_positive_ratio = data2vocab(valid_data)
test_vocab, test_positive_ratio = data2vocab(test_data)

In [27]:
# Merge the per-split vocabularies into one vocabulary shared by all splits.
vocab_set = train_vocab.union(valid_vocab, test_vocab)
# Sort before freezing the vocabulary into a list: the iteration order of a
# set of strings can change between interpreter runs (hash randomization),
# which would silently permute the one-hot columns from one kernel session
# to the next and make results and saved weights non-reproducible.
vocab_lst = sorted(vocab_set)
# logger

In [32]:
def onehot(idx, length):
	"""Return a list of ``length`` zeros with a single 1 stored at ``idx``."""
	vec = [0] * length
	vec[idx] = 1
	return vec

def zerohot(length):
	"""Return a list made of ``length`` zeros (the all-zero padding vector)."""
	return [0] * length

# maxlength is the fixed number of positions every antigen is zero-padded or truncated to (default 300)
def standardize_data(data, vocab_lst, maxlength = 300, antigen_col = None):
	"""Turn every row of a split into fixed-length (sequence, labels, mask) tensors.

	Args:
		data: object indexable by column name whose antigen column yields
			sequences and whose 'Y' column yields, per row, the positive
			position indices (assumed to be a pandas DataFrame produced by
			TDC -- TODO confirm).
		vocab_lst: list of symbols; a symbol's position in this list is its
			one-hot column.
		maxlength: number of positions every example is padded or truncated to.
		antigen_col: name of the antigen column. When None, the module-level
			name ``X`` is used, which is what existing callers rely on.

	Returns:
		A list with one tuple per row:
		``sequence`` -- float32 tensor (maxlength, len(vocab_lst)), one-hot
		rows for real residues and all-zero rows for padding;
		``labels`` -- float32 tensor (maxlength,), 1 at positive positions;
		``mask`` -- bool tensor (maxlength,), True for real residues.

	Raises:
		ValueError: if a residue inside the kept prefix is missing from vocab_lst.
		IndexError: if a position in Y lies outside its antigen sequence.
	"""
	column = X if antigen_col is None else antigen_col
	vocab_size = len(vocab_lst)
	# Build the symbol -> column lookup once instead of scanning vocab_lst for
	# every residue. setdefault keeps the first occurrence, like list.index.
	residue_to_idx = {}
	for idx, residue in enumerate(vocab_lst):
		residue_to_idx.setdefault(residue, idx)

	standard_data = []
	# Iterate over values rather than row labels so any row index works.
	for antigen, Y in zip(data[column], data['Y']):
		kept = min(len(antigen), maxlength)  # residues that survive truncation
		try:
			residue_ids = [residue_to_idx[s] for s in antigen[:kept]]
		except KeyError as err:
			raise ValueError(f"residue {err.args[0]!r} is not in vocab_lst") from None

		# Padding rows stay all-zero; real residues get a single 1.
		sequence = torch.zeros(maxlength, vocab_size, dtype=torch.float32)
		if residue_ids:
			sequence[torch.arange(kept), torch.tensor(residue_ids, dtype=torch.long)] = 1.0

		labels = torch.zeros(maxlength, dtype=torch.float32)
		positives = []
		for y in Y:
			y = int(y)
			# Reject indices that would otherwise wrap around (negative) or
			# silently mark a padding position as positive.
			if not 0 <= y < len(antigen):
				raise IndexError(
					f"positive position {y} is outside an antigen of length {len(antigen)}"
				)
			# Positions beyond maxlength are dropped together with the truncated tail.
			if y < maxlength:
				positives.append(y)
		if positives:
			labels[torch.tensor(positives, dtype=torch.long)] = 1.0

		mask = torch.zeros(maxlength, dtype=torch.bool)
		mask[:kept] = True

		standard_data.append((sequence, labels, mask))
	return standard_data

In [33]:
# Convert each split into a list of (sequence, labels, mask) tensor triples,
# using standardize_data's default maxlength of 300.
train_data_stand = standardize_data(train_data, vocab_lst)
valid_data_stand = standardize_data(valid_data, vocab_lst)
test_data_stand = standardize_data(test_data, vocab_lst)

In [57]:
class dataset(Dataset):
	"""Dataset over pre-built (sequence, labels, mask) triples."""

	def __init__(self, data):
		"""Split the triples in ``data`` into three parallel lists."""
		# One pass over ``data`` per field, in the order sequence, labels, mask.
		columns = []
		for field in range(3):
			columns.append([sample[field] for sample in data])
		self.sequences, self.labels, self.mask = columns

	def __getitem__(self, index):
		"""Return the (sequence, labels, mask) tuple stored at ``index``."""
		fields = (self.sequences, self.labels, self.mask)
		return tuple(field[index] for field in fields)

	def __len__(self):
		"""Return the number of stored examples (length of the labels list)."""
		sample_count = len(self.labels)
		return sample_count

In [58]:
# Wrap each standardized split in the dataset class.
train_set = dataset(train_data_stand)
valid_set = dataset(valid_data_stand)
test_set = dataset(test_data_stand)
# Batches of 16; only the training loader shuffles. No loader is created
# for valid_set in this cell.
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False)

In [None]:
class RNN(nn.Module):
    def __init__(self, name, hidden_size, input_size, num_layers = 2):
        super(RNN, self).__init__()
        self.name = name 
        self.hidden_size = hidden_size
        self.input_size = input_size 
        self.rnn = nn.LSTM(         # if use nn.RNN(), it hardly learns
            input_size=input_size,
            hidden_size=hidden_size,         # rnn hidden unit
            num_layers=num_layers,           # number of rnn layer
            batch_first=True,       # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)
        )

        self.out = nn.Linear(hidden_size, 1)
        criterion = torch.nn.BCEWithLogitsLoss()  
        self.opt = torch.optim.Adam(self.parameters(), lr=1e-3)
    
    
    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # None represents zero initial hidden state

        # choose r_out at the last time step
        out = self.out(r_out)
        out = out.squeeze(-1)
        return out
    