In [72]:
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler, random_split
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import transformers
from nlp import load_dataset
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from argparse import ArgumentParser
import re
from typing import Optional
from tqdm.auto import tqdm
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from src.tools.seq import seq_padding

In [73]:
DIR = Path('../data/AOMP_KMEANS')
TRAIN_DIR = DIR / 'train.csv'
VAL_DIR = DIR / 'val.csv'

TOKENIZER_DIR = Path('../tokenizers/bioseq')

In [74]:
df_train = pd.read_csv(DIR / 'train.csv', index_col=0).reset_index(drop=True)
pp = seq_padding(df_train.peptide[0], 15, split=False, truncate_length=15)
hla = seq_padding(df_train.HLA_sequence[0], 34, split=False, truncate_length=34)
print(pp)
print(hla)

N A Q K L L S E A L E [PAD] [PAD] [PAD] [PAD]
Y F A M Y G E K V A H T H V D T L Y V R C H Y Y T W A V L A Y T W Y


In [34]:
tokenizer = Tokenizer.from_file(str(TOKENIZER_DIR / 'tokenizer-bioseq-onehot.json'))
input = tokenizer.encode(pp, hla)
print(input.tokens)
print(input.ids)
print(input.attention_mask)

['[CLS]', 'N', 'A', 'Q', 'K', 'L', 'L', 'S', 'E', 'A', 'L', 'E', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[SEP]', 'Y', 'F', 'A', 'M', 'Y', 'G', 'E', 'K', 'V', 'A', 'H', 'T', 'H', 'V', 'D', 'T', 'L', 'Y', 'V', 'R', 'C', 'H', 'Y', 'Y', 'T', 'W', 'A', 'V', 'L', 'A', 'Y', 'T', 'W', 'Y', '[SEP]']
[1, 14, 3, 16, 11, 12, 12, 18, 6, 3, 12, 6, 0, 0, 0, 0, 2, 22, 7, 3, 13, 22, 8, 6, 11, 20, 3, 9, 19, 9, 20, 5, 19, 12, 22, 20, 17, 4, 9, 22, 22, 19, 21, 3, 20, 12, 3, 22, 19, 21, 22, 2]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [32]:
tokenizer.enable_padding(length=80)
print(pp)
print(tokenizer.encode(pp).attention_mask)
print(tokenizer.encode(pp).ids)
print(tokenizer.encode(pp).tokens)

N A Q K L L S E A L E [PAD] [PAD] [PAD] [PAD]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 14, 3, 16, 11, 12, 12, 18, 6, 3, 12, 6, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
['[CLS]', 'N', 'A', 'Q', 'K', 'L', 'L', 'S', 'E', 'A', 'L', 'E', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PA

In [18]:
batch_input = tokenizer.encode_batch([[pp, hla], [pp, hla]])
tokenizer.enable_padding(length=80)
print(batch_input[0].tokens)
print(batch_input[1].ids)
print(batch_input[1].attention_mask)

['[CLS]', 'N', 'A', 'Q', 'K', 'L', 'L', 'S', 'E', 'A', 'L', 'E', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[SEP]', 'Y', 'F', 'A', 'M', 'Y', 'G', 'E', 'K', 'V', 'A', 'H', 'T', 'H', 'V', 'D', 'T', 'L', 'Y', 'V', 'R', 'C', 'H', 'Y', 'Y', 'T', 'W', 'A', 'V', 'L', 'A', 'Y', 'T', 'W', 'Y', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
[1, 14, 3, 16, 11, 12, 12, 18, 6, 3, 12, 6, 0, 0, 0, 0, 2, 22, 7, 3, 13, 22, 8, 6, 11, 20, 3, 9, 19, 9, 20, 5, 19, 12, 22, 20, 17, 4, 9, 22, 22, 19, 21, 3, 20, 12, 3, 22, 19, 21, 22, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [88]:
### DataModule
class PHLADataset(Dataset):
    def __init__(self, phlas, labels):
        self.phlas = phlas
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
            'input_ids': torch.LongTensor(self.phlas[idx].ids),
            'attention_mask': torch.LongTensor(self.phlas[idx].attention_mask),
            'token_type_ids': torch.LongTensor(self.phlas[idx].type_ids),
        }

In [76]:
tokenizer = Tokenizer.from_file(str(TOKENIZER_DIR / 'tokenizer-bioseq-onehot.json'))
tokenizer.enable_padding(length=49)
tokenizer.enable_truncation(49)
print(type(tokenizer))

<class 'tokenizers.Tokenizer'>


In [77]:
phla_pairs = [
    [pep, hla]
    for pep, hla in zip(
        df_train.HLA_sequence.apply(lambda x: ' '.join(list(x))).tolist(),
        df_train.peptide.apply(lambda x: ' '.join(list(x))).tolist()
    )
]
phla_pairs

[['Y F A M Y G E K V A H T H V D T L Y V R C H Y Y T W A V L A Y T W Y',
  'N A Q K L L S E A L E'],
 ['Y Y S E Y R N I Y A Q T D E S N L Y L S Y D Y Y T W A E R A Y E W Y',
  'H P R Q S G P G A P N L'],
 ['Y T A M Y L Q N V A Q T D A N T L Y I M Y R D Y T W A V L A Y T W Y',
  'N S D R F I R I Y N K'],
 ['Y H T K Y R E I S T N T Y E S N L Y L R Y N Y Y S L A V L A Y E W Y',
  'K P K K A P K S P A K'],
 ['Y F A M Y Q E N V A Q T D V D T L Y I I Y R D Y T W A E L A Y T W Y',
  'K L E D I L D P I I K'],
 ['Y Y A M Y R N N V A Q T D V D T L Y I R Y H Y Y T W A V W A Y T W Y',
  'K V M S L S Y S L L T P L'],
 ['Y F A M Y Q E N M A H T D A N T L Y I I Y R D Y T W V A R V Y R G Y',
  'K V R S Y T K I V T N'],
 ['Y F A M Y G E K V A H T H V D T L Y V R Y H Y Y T W A V L A Y T W Y',
  'N F R N M F S L G G E'],
 ['Y Y A M Y G E N M A S T Y E N I A Y I V Y D S Y T W A V L A Y L W Y',
  'T F E R L V D V I C E'],
 ['Y F A M Y Q E N M A H T D A N T L Y I I Y R D Y T W V A R V Y R G Y',
  'R L D Q E

In [83]:
labels = df_train.label.tolist()
print(labels[:10])
print(len(labels))

[0, 1, 0, 0, 1, 0, 0, 0, 0, 1]
646945


In [89]:
phlas = tokenizer.encode_batch(phla_pairs)

In [90]:
ds = PHLADataset(phlas, labels)
print(len(ds))

646945


In [91]:
ds[2]

{'labels': tensor(0),
 'input_ids': tensor([ 1, 22, 19,  3, 13, 22, 12, 16, 14, 20,  3, 16, 19,  5,  3, 14, 19, 12,
         22, 10, 13, 22, 17,  5, 22, 19, 21,  3, 20, 12,  3, 22, 19, 21, 22,  2,
         14, 18,  5, 17,  7, 10, 17, 10, 22, 14, 11,  2,  0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0])}

In [106]:
import torch

from transformers import BertForSequenceClassification, BertConfig

print(tokenizer.get_vocab_size())

23


In [107]:


config = BertConfig(vocab_size=23,
                    num_labels=2,
                    hidden_size=512,
                    num_hidden_layers=2,
                    num_attention_heads=4,
                    intermediate_size=1024,
                    hidden_act="gelu",
                    hidden_dropout_prob=0.1,
                    attention_probs_dropout_prob=0.1,
                    max_position_embeddings=256,
                    type_vocab_size=2,
                    initializer_range=0.02,
                    layer_norm_eps=1e-9,
                    pad_token_id=0,
                    position_embedding_type="absolute",
                    use_cache=True,
                    classifier_dropout=None,
                    )

model = BertForSequenceClassification(config)

In [94]:
ds[0]

{'labels': tensor(0),
 'input_ids': tensor([ 1, 22,  7,  3, 13, 22,  8,  6, 11, 20,  3,  9, 19,  9, 20,  5, 19, 12,
         22, 20, 17,  4,  9, 22, 22, 19, 21,  3, 20, 12,  3, 22, 19, 21, 22,  2,
         14,  3, 16, 11, 12, 12, 18,  6,  3, 12,  6,  2,  0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0])}

In [97]:
x = ds[0]

In [101]:
x['labels']

tensor(0)

In [108]:

loss = model(
    input_ids=x['input_ids'],
    attention_mask=x['attention_mask'],
    token_type_ids=x['token_type_ids'],
    labels=x['labels']
).loss

ValueError: not enough values to unpack (expected 2, got 1)

In [None]:
PHLADataModule(
    tokenizer_path='../tokenizers/bioseq/tokenizer-bioseq-onehot.json'
)


class PHLADataModule(pl.LightningDataModule):

    def __init__(self,
                 train_path: str,
                 val_path: str,
                 test_path: str,
                 tokenizer_path: str,
                 max_len: int = 15,
                 batch_size: int = 1024,
                 num_workers: int = 4,
                 *args,
                 **kwargs
                 ):
        super().__init__()
        self.save_hyperparameters()

        self.tokenizer = Tokenizer.from_file(tokenizer_path)

    def prepare_data(self) -> None:
        df_train = pd.read_csv(self.hparams.train_path, index_col=0)
        df_val = pd.read_csv(self.hparams.val_path, index_col=0)
        df_test = pd.read_csv(self.hparams.test_path, index_col=0)

        self.tokenizer()

    def setup(self, stage: Optional[str] = None):
        ## Creating dataframes from the list data.  The reviews are arranged in the order that the first 12,500 belong to positive reviews and the rest 12,500 belong to negative reviews.
        df_train_reviews_clean = pd.DataFrame(self.reviews_train_clean, columns=['reviews'])
        df_train_reviews_clean['target'] = np.where(df_train_reviews_clean.index < 12500, 1, 0)

        df_test_reviews_clean = pd.DataFrame(self.reviews_test_clean, columns=['reviews'])
        df_test_reviews_clean['target'] = np.where(df_test_reviews_clean.index < 12500, 1, 0)

        # Shuffling the rows in both the train and test data.  This is very important before using the data for training.
        df_train_reviews_clean = df_train_reviews_clean.sample(frac=1).reset_index(drop=True)
        df_test_reviews_clean = df_test_reviews_clean.sample(frac=1).reset_index(drop=True)

        # breaking the train data into training and validation
        df_train, df_valid = train_test_split(df_train_reviews_clean, test_size=0.25,
                                              stratify=df_train_reviews_clean['target'])

        self.train = df_train.reset_index(drop=True)
        self.val = df_valid.reset_index(drop=True)
        self.test = df_test_reviews_clean

    def train_dataloader(self):
        return DataLoader(ImdbDataset(notes=self.train['reviews'],
                                      targets=self.train['target'],
                                      tokenizer=self.tokenizer,
                                      max_len=self.hparams.max_len
                                      ),
                          batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers)

    def val_dataloader(self):
        return DataLoader(ImdbDataset(notes=self.val['reviews'],
                                      targets=self.val['target'],
                                      tokenizer=self.tokenizer,
                                      max_len=self.hparams.max_len),
                          batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers)

    def test_dataloader(self):
        return DataLoader(ImdbDataset(notes=self.test['reviews'],
                                      targets=self.test['target'],
                                      tokenizer=self.tokenizer,
                                      max_len=self.hparams.max_len),
                          batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers)

In [None]:
## The main Pytorch Lightning module
class ImdbModel(pl.LightningModule):

    def __init__(self,
                 learning_rate: float = 0.0001 * 8,
                 **kwargs):
        super().__init__()

        self.save_hyperparameters()

        self.num_labels = 2
        config = transformers.DistilBertConfig(dropout=0.1, attention_dropout=0.2)
        self.bert = transformers.DistilBertModel.from_pretrained('distilbert-base-uncased', config=config)

        self.pre_classifier = torch.nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, self.num_labels)
        self.dropout = torch.nn.Dropout(self.bert.config.seq_classif_dropout)

        # relu activation function
        self.relu = torch.nn.ReLU()

    def forward(self, input_ids, attention_mask, labels):

        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        hidden_state = outputs[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = self.relu(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, dim)

        return logits

    def get_outputs(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        hidden_state = outputs[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        return pooled_output

    def training_step(self, batch, batch_nb):
        # batch
        input_ids = batch['input_ids']
        label = batch['label']
        attention_mask = batch['attention_mask']
        #token_type_ids = batch['token_type_ids']
        # fwd
        y_hat = self(input_ids, attention_mask, label)

        # loss
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(y_hat.view(-1, self.num_labels), label.view(-1))
        #loss = F.cross_entropy(y_hat, label)

        # logs
        tensorboard_logs = {'train_loss': loss, 'learn_rate': self.optim.param_groups[0]['lr']}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        # batch
        input_ids = batch['input_ids']
        label = batch['label']
        attention_mask = batch['attention_mask']
        #token_type_ids = batch['token_type_ids']
        # fwd
        y_hat = self(input_ids, attention_mask, label)

        # loss
        #loss = F.cross_entropy(y_hat, label)
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(y_hat.view(-1, self.num_labels), label.view(-1))

        # acc
        a, y_hat = torch.max(y_hat, dim=1)
        val_acc = accuracy_score(y_hat.cpu(), label.cpu())
        val_acc = torch.tensor(val_acc)

        # logs
        tensorboard_logs = {'val_loss': loss, 'val_acc': val_acc}
        # can't log in validation step lossess, accuracy.  It wouldn't log it at every validation step
        return {'val_loss': loss, 'val_acc': val_acc, 'progress_bar': tensorboard_logs}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        # logs
        tensorboard_logs = {'val_loss': avg_loss, 'val_acc': avg_val_acc}
        return {'val_loss': avg_loss, 'progress_bar': tensorboard_logs, 'log': tensorboard_logs}

    def on_batch_end(self):
        #for group in self.optim.param_groups:
        #    print('learning rate', group['lr'])
        # This is needed to use the One Cycle learning rate that needs the learning rate to change after every batch
        # Without this, the learning rate will only change after every epoch
        if self.sched is not None:
            self.sched.step()

    def on_epoch_end(self):
        if self.sched is not None:
            self.sched.step()

    def test_step(self, batch, batch_nb):
        input_ids = batch['input_ids']
        label = batch['label']
        attention_mask = batch['attention_mask']
        #token_type_ids = batch['token_type_ids']
        y_hat = self(input_ids, attention_mask, label)

        # loss
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(y_hat.view(-1, self.num_labels), label.view(-1))

        a, y_hat = torch.max(y_hat, dim=1)
        test_acc = accuracy_score(y_hat.cpu(), label.cpu())

        return {'test_loss': loss, 'test_acc': torch.tensor(test_acc)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_test_acc = torch.stack([x['test_acc'] for x in outputs]).mean()

        tensorboard_logs = {'avg_test_loss': avg_loss, 'avg_test_acc': avg_test_acc}
        return {'avg_test_acc': avg_test_acc, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        optimizer = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.hparams.learning_rate,
                                     eps=1e-08)
        #scheduler = StepLR(optimizer, step_size=1, gamma=0.2)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2e-5, total_steps=2000)
        #scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-7, max_lr=1e-4, cycle_momentum=False,step_size_up=300)

        #scheduler = ReduceLROnPlateau(optimizer, patience=0, factor=0.2)
        self.sched = scheduler
        self.optim = optimizer
        return [optimizer], [scheduler]