### The scFv antibody sequence / binding affinity data from A-Alpha Bio's Nature Communications paper

#### Two datasets:
* antibody_dataset_1: `MITLL_AAlphaBio_Ab_Binding_dataset.csv` (1109000 rows)
* antibody_dataset_2: `MITLL_AAlphaBio_Ab_Binding_dataset2.csv` (1903928 rows)
#### Plan
* pre-train model in MLM mode on dataset_2
* fine-tune on dataset_1

In [None]:
%load_ext autoreload

In [None]:
%autoreload
# import libraries
import numpy as np
import pickle as pk
import pandas as pd
import math
# pd.options.mode.copy_on_write = True # to avoid SettingWithCopyWarning
import os
import yaml
import sys
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from pytorch_lightning.core import LightningModule
import pytorch_lightning as pl
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


%matplotlib inline

In [None]:
#
# Read the config
#
config_path = './config/fab_sequence_data.yaml'
with open(config_path, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Show the parse error, then re-raise: carrying on would only fail on
        # the next statement with an unrelated NameError, because `config`
        # was never bound.
        print(exc)
        raise

# Keep only the 'model_params' section of the YAML document
config = config['model_params']
print(config)

-----------------------------
### The scFv antibody sequence/binding affinity data

In [None]:
# data_path = './data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_1/MITLL_AAlphaBio_Ab_Binding_dataset.csv'
# df_1 = pd.read_csv(data_path)
# print(df_1.shape)
# print(df_1.columns)

In [None]:
# data_path = './data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_2/MITLL_AAlphaBio_Ab_Binding_dataset2.csv'
# df_2 = pd.read_csv(data_path, skiprows=6)
# print(df_2.shape)
# print(df_2.columns)

In [None]:
# print(df_2.columns.to_list())

# s1 = df_2['Sequence'][0]
# s2 = df_2['Sequence'][4]
# diff = [i for i in range(len(s1)) if s1[i] != s2[i]]
# print(diff)

# s1 = df_2['Sequence'][0]
# s2 = df_2['HC'][4]
# print('s1 length:', len(s1), 's2 length:', len(s2))



In [None]:
# row = df_2.iloc[0]
# print(row)

----------------------------------
#### Crafting a dataset for the sequence data

In [None]:
class FABSequenceDataset(Dataset):
    """
    Masked-language-model dataset over the 'Sequence' column of a CSV file.

    Each item is one row's sequence encoded character by character, with a
    random subset of positions replaced by the [MASK] token, together with a
    label tensor holding the original token ids at those positions.

    Args:
        config: mapping providing 'block_size' (number of leading positions
            eligible for masking) and 'mask_prob' (fraction of those
            positions to mask).
        csv_file_path: path of the CSV file to load. The first 6 lines of the
            file are skipped before the header row is read.
    """

    def __init__(self, config, csv_file_path):
        self.config = config
        # Load from the constructor argument so the dataset does not depend
        # on any module-level variable being defined.
        self.df = pd.read_csv(csv_file_path, skiprows=6)

        # The vocabulary is fixed rather than derived from the data:
        # 20 naturally occurring amino acids in human proteins plus MASK token
        self.chars = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '[MASK]']
        print('vocabulary:', self.chars)

        data_size, vocab_size = self.df.shape[0], len(self.chars)
        print('data has %d rows, %d vocab size (unique).' % (data_size, vocab_size))

        # character <-> integer id lookup tables
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}
        self.vocab_size = vocab_size

    def get_vocab_size(self):
        """Return the number of tokens in the vocabulary (including [MASK])."""
        return self.vocab_size

    def get_block_size(self):
        """Return the configured 'block_size'."""
        return self.config['block_size']

    def __len__(self):
        """Return the number of rows read from the CSV file."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """
        Return the (data, mask) pair used for Masked Language Model training.

        Args:
            idx: row label used to look up the 'Sequence' column.

        Returns:
            dix: 1-D long tensor of token ids for the row, with the selected
                positions overwritten by the [MASK] id.
            mask: long tensor of the same shape; holds the original token id
                at each masked position and 0 everywhere else.

        Raises:
            KeyError: if the sequence contains a character that is not in the
                vocabulary.
        """
        chunk = self.df.loc[idx, 'Sequence']

        # encode every character to an integer
        dix = torch.tensor([self.stoi[s] for s in chunk], dtype=torch.long)

        # Positions are drawn from the first block_size tokens, capped at the
        # sequence length so a short sequence cannot be indexed out of range.
        # NOTE(review): tokens beyond block_size are never masked - confirm
        # that block_size is meant to equal the sequence length.
        span = min(self.config['block_size'], len(dix))

        # get number of tokens to mask
        n_pred = max(1, int(round(span * self.config['mask_prob'])))

        # indices of the tokens that will be masked (a random selection of n_pred of the tokens)
        masked_idx = torch.randperm(span, dtype=torch.long)[:n_pred]

        mask = torch.zeros_like(dix)

        # copy the actual tokens to the mask
        # NOTE(review): 0 is also the id of 'A', so a masked 'A' cannot be
        # told apart from an unmasked position by looking at `mask` alone -
        # confirm the consumer of these labels accounts for that.
        mask[masked_idx] = dix[masked_idx]

        # ... and overwrite them with MASK token in the data
        dix[masked_idx] = self.stoi["[MASK]"]

        return dix, mask


In [None]:
# CSV file of antibody_dataset_2, wrapped in the sequence dataset defined above
data_path = './data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_2/MITLL_AAlphaBio_Ab_Binding_dataset2.csv'
dataset = FABSequenceDataset(config, data_path)

In [None]:
# Sanity check: number of rows, then the lengths of the first item's
# token tensor and label tensor.
print(len(dataset))

dix, mask = dataset[0]
print(len(dix))
print()
print(len(mask))

In [None]:
# Batch the dataset in its stored order (shuffling is off), loading with
# 5 worker processes and requesting pinned memory for the batches.
train_loader = DataLoader(
    dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=5,
    pin_memory=True,
)

In [None]:
# Create an iterator over the loader so batches can be pulled one at a time
data_iter = iter(train_loader)


In [None]:
# Pull the next batch and show the shapes of its token and label tensors
batch = next(data_iter)
dix, mask = batch
print(dix.shape, mask.shape, sep='\n')