### scFv antibody sequence / binding-affinity data from A-Alpha Bio's *Nature Communications* paper

#### Two datasets:
* antibody_dataset_1: MITLL_AAlphaBio_Ab_Binding_dataset.csv (1,109,000 rows)
* antibody_dataset_2: MITLL_AAlphaBio_Ab_Binding_dataset2.csv (1,903,928 rows)



In [None]:
%load_ext autoreload

In [None]:
%autoreload
# import libraries
import numpy as np
import pickle as pk
import pandas as pd
import math
# pd.options.mode.copy_on_write = True # to avoid SettingWithCopyWarning
import os
import yaml
import sys
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from pytorch_lightning.core import LightningModule
import pytorch_lightning as pl
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


%matplotlib inline

In [None]:
#
# Read the config
#
# Only the 'model_params' section is kept; later cells read block_size,
# mask_prob and batch_size from it.
config_path = './config/pretrain_config.yaml'
with open(config_path, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Report the parse error and stop here; continuing would leave `config`
        # undefined and fail on the next line with a confusing NameError.
        print(exc)
        raise

config = config['model_params']
print(config)

----------------
### OAS data

In [None]:
# Load one OAS unit file; skiprows=1 skips its first line (presumably a metadata
# line that precedes the column header — confirm against the file).
df = pd.read_csv('/home/mark/dev/myBERT/data/oas/human_light_sars_covid/1279049_1_Light_Bulk.csv', skiprows=1)
print(df.shape)


# Note: the v_, d_, j_, c_ columns describe the V, D, J and C gene segments:
# variable, diversity, joining and constant.
# Antibodies obtain their diversity through two processes. The first is V(D)J
# (variable, diversity, joining) recombination: during B-cell maturation the cell
# splices out the DNA of all but one gene from each region and joins the remaining
# genes into a single V(D)J segment. Light chains (as in this file) have no D
# segment; they undergo VJ recombination.

In [None]:
# Show every column name of the OAS sample file.
column_index = df.columns
print(column_index)

In [None]:
# Print selected fields of the row labelled 0.
first_row_fields = [
    'sequence',
    'locus',
    'v_frameshift',
    'sequence_alignment',
    'sequence_alignment_aa',
    'fwr1_start',
    'fwr1_end',
]
for field in first_row_fields:
    print(df[field][0])
fwr1_value = df['fwr1'][0]
print(fwr1_value, ', len(fwr1) = ', len(fwr1_value))
print(df['Redundancy'][0])

In [None]:
# Column names containing the substring 'aa' (intended: the amino-acid columns).
cols = df.columns.to_list()
aa_cols = list(filter(lambda name: 'aa' in name, cols))
for name in aa_cols:
    print(name)

In [None]:
# For every 'aa' column show the length and first 10 characters of row 0;
# non-string values (e.g. NaN) are printed as-is.
for col in aa_cols:
    s = df[col][0]
    if isinstance(s, str):
        print(col, ', ', len(s), ', ', s[:10])
    else:
        print(col, ', ', s)

# Positions where the observed alignment differs from the germline alignment.
# zip() stops at the shorter string, so unequal lengths no longer raise IndexError.
s1 = df['sequence_alignment_aa'][0]
s2 = df['germline_alignment_aa'][0]
diff = [i for i, (a, b) in enumerate(zip(s1, s2)) if a != b]
print('diff:', diff)
if diff:
    # Clamp the window start at 0: a negative start would wrap around to the end.
    lo = max(diff[0] - 3, 0)
    print(s1[lo:diff[0]+3], ', ', s2[lo:diff[0]+3])
else:
    print('no differences in the overlapping region')

# NOTE(review): confirm 'fwr1' holds amino acids; if it holds nucleotides this
# membership test is always False (an 'fwr1_aa' column may be the intended one).
s1 = df['sequence_alignment_aa'][0]
s2 = df['fwr1'][0]
print((s2 in s1))

-------------------
### Plan
* focus on just sequence_alignment_aa column (not germline)
    * eliminate duplicates
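* ignore the separate v, d, and j gene columns (their sequences are already part of the full light-chain sequence)
* ignore the fwr1, fwr2, fwr3, fwr4 framework-region columns (framework regions occur on both heavy and light chains; check whether they are already contained in sequence_alignment_aa)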
* ignore cdr sequences (they are already contained in the longer light-chain sequences)

### Result of data extraction
* total rows: 20,306,305; unique sequences (len(seqs)): 18,061,315
* sequence lengths range from 43 (min_len) to 132 (max_len)


In [None]:
# Read all .csv files in a directory,
# extract just the sequence_alignment_aa column data,
# and put it into a set (to remove duplicates).
# There are 175 csv files in the directory.

import glob

files = glob.glob('/home/mark/dev/myBERT/data/oas/human_light_sars_covid/*.csv')
print(len(files), 'csv files to process')
seqs = set()
total_rows = 0
for i, file in enumerate(files):
    # Only one column is needed, so don't parse the others.
    df = pd.read_csv(file, skiprows=1, usecols=['sequence_alignment_aa'])
    total_rows += df.shape[0]
    print(i, ':', os.path.basename(file), ', num rows:', df.shape[0])
    # Missing alignments are read as float NaN; drop them so every member is a str
    # (len() below would fail on a float).
    seqs.update(df['sequence_alignment_aa'].dropna().to_list())
    print('\ttotal rows so far:', total_rows, ', num unique, len(seqs):', len(seqs))

print(len(seqs), 'unique sequences')
# Show a few example sequences.
for i, s in enumerate(seqs):
    if i >= 10:
        break
    print(len(s), ', ', s)

# Length range. min()/max() replace an if/elif scan whose `elif` skipped the max
# update whenever an element lowered min_len (e.g. a single element left max_len=-99).
min_len = min(map(len, seqs), default=None)
max_len = max(map(len, seqs), default=None)
print('min_len:', min_len, ', max_len:', max_len)

# pk.dump(seqs, open('/home/mark/dev/myBERT/data/oas/human_light_sars_covid/unique_seqs.pk', 'wb'))

In [None]:
# Histogram (50 bins) of the lengths of all unique sequences.
lengths = list(map(len, seqs))
plt.hist(lengths, bins=50)
plt.show()


In [None]:
# Keep only sequences that are at least 90 residues long.
seqs_trimmed = {s for s in seqs if len(s) >= 90}

print(len(seqs_trimmed), 'unique sequences')

In [None]:
print(len(seqs_trimmed), 'unique sequences')
# Length range of the trimmed set. min()/max() replace an if/elif scan whose
# `elif` skipped the max update whenever an element lowered min_len, so max_len
# could be wrong (e.g. -99 for a single-element or strictly decreasing input).
min_len = min(map(len, seqs_trimmed), default=None)
max_len = max(map(len, seqs_trimmed), default=None)

print('min_len:', min_len, ', max_len:', max_len)

In [None]:
# Histogram (50 bins) of the lengths of the trimmed sequences.
lengths = list(map(len, seqs_trimmed))
plt.hist(lengths, bins=50)
plt.show()


In [None]:
# Make an 80/20 train/test split.
# Sort first: the iteration order of a set of str changes between interpreter
# runs (str hash randomization), so shuffling list(set) is not reproducible even
# with a fixed seed. A seeded Generator then makes the shuffle itself repeatable.
SPLIT_SEED = 42
data = sorted(seqs_trimmed)
rng = np.random.default_rng(SPLIT_SEED)
rng.shuffle(data)
split = int(0.8 * len(data))
train_data = data[:split]
test_data = data[split:]

# pk.dump(train_data, open('/home/mark/dev/myBERT/data/oas/human_light_sars_covid/train_data.pk', 'wb'))
# pk.dump(test_data, open('/home/mark/dev/myBERT/data/oas/human_light_sars_covid/test_data.pk', 'wb'))

In [None]:
print(len(train_data), ', ', len(test_data))

# Length range of the training sequences. min()/max() replace an if/elif scan
# whose `elif` skipped the max update whenever an element lowered min_len.
min_len = min(map(len, train_data), default=None)
max_len = max(map(len, train_data), default=None)

print('min_len:', min_len, ', max_len:', max_len)

In [None]:
# Sanity-check random block_size windows on a few test sequences.
for seq in test_data[0:3]:
    print(seq, ', len(seq):', len(seq))
    print('\tlen(seq) - config[block_size]:', len(seq) - config['block_size'])
    # Valid starts are 0 .. len(seq) - block_size inclusive; randint's upper bound
    # is exclusive, hence the +1 (this also makes len(seq) == block_size valid).
    n_starts = len(seq) - config['block_size'] + 1
    if n_starts < 1:
        print('\tsequence shorter than block_size; skipped')
        continue
    start_idx = np.random.randint(0, n_starts)
    print('\tstart_idx:', start_idx)
    chunk = seq[start_idx:start_idx + config['block_size']]
    print('\tlen(chunk):', len(chunk), ', chunk:', chunk)


In [None]:

import torch
from torch.utils.data import Dataset
import pickle as pk
import numpy as np

#--------------------------------------------------------
# Dataset for OAS data
#--------------------------------------------------------
class OASSequenceDataset(Dataset):
    """
    Emits masked-language-model samples built from OAS amino-acid sequences.

    Each item is a (dix, mask) pair of LongTensors of length config['block_size']:
    a CLS token followed by a random window of block_size-1 residues, with a
    random subset of the residue positions replaced by the [MASK] token.

    Args:
        config: dict providing 'block_size' and 'mask_prob'.
        pk_file_path: path of a pickle holding an iterable of sequence strings.
    """
    def __init__(self, config, pk_file_path):
        super().__init__()
        self.config = config
        print('reading the data from:', pk_file_path)
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(pk_file_path, 'rb') as fh:
            pk_data = pk.load(fh)
        self.data = list(pk_data)

        # CLS token (for classification), the 20 naturally occurring amino acids
        # in human proteins, 'X' for any unknown residue, and the MASK token.
        self.chars = ['CLS', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 'X', '[MASK]']
        print('vocabulary:', self.chars)

        data_size, vocab_size = len(self.data), len(self.chars)
        print('data has %d rows, %d vocab size (unique).' % (data_size, vocab_size))

        self.stoi = { ch:i for i,ch in enumerate(self.chars) }
        self.itos = { i:ch for i,ch in enumerate(self.chars) }
        self.vocab_size = vocab_size

    def get_vocab_size(self):
        return self.vocab_size

    def get_block_size(self):
        return self.config['block_size']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a (data, mask) pair used for Masked Language Model training.

        `data` is CLS + a random window of block_size-1 residues with n_pred
        residue positions replaced by [MASK]; `mask` holds the original token
        ids at those positions and 0 elsewhere (0 is the CLS id, never masked).

        Raises:
            ValueError: if the sequence is shorter than block_size - 1.
        """
        seq = self.data[idx]
        chunk_len = self.config['block_size'] - 1  # leave room for the CLS token

        # Valid window starts are 0 .. len(seq) - chunk_len inclusive.
        n_starts = len(seq) - chunk_len + 1
        if n_starts < 1:
            raise ValueError(
                'sequence %d has length %d, shorter than block_size - 1 = %d'
                % (idx, len(seq), chunk_len))
        # torch RNG rather than numpy: DataLoader seeds torch separately in each
        # worker, whereas numpy state may be duplicated across workers.
        start_idx = int(torch.randint(n_starts, (1,)).item())
        chunk = seq[start_idx:start_idx + chunk_len]

        # Encode every character; residues outside the vocabulary map to 'X'.
        unk = self.stoi['X']
        dix = torch.tensor(
            [self.stoi['CLS']] + [self.stoi.get(s, unk) for s in chunk],
            dtype=torch.long)

        # number of tokens to mask
        n_pred = max(1, int(round(self.config['block_size']*self.config['mask_prob'])))

        # random residue positions to mask; +1 so the CLS token is never masked
        masked_idx = torch.randperm(chunk_len, dtype=torch.long)[:n_pred] + 1

        mask = torch.zeros_like(dix)
        # copy the actual tokens to the mask ...
        mask[masked_idx] = dix[masked_idx]
        # ... and overwrite them with the MASK token in the data
        dix[masked_idx] = self.stoi["[MASK]"]

        return dix, mask


In [None]:
# Test-split dataset plus an unshuffled loader over it.
test_data_path = '/home/mark/dev/myBERT/data/oas/human_light_sars_covid/test_data.pk'
test_dataset = OASSequenceDataset(config, test_data_path)
print(len(test_dataset))
test_loader = DataLoader(
    test_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=5,
    pin_memory=True,
)


In [None]:
# Show the masked input and its targets for the first test sample.
dix, mask = test_dataset[0]
for shown in (dix, mask):
    print()
    print(shown)

In [None]:
# Pull one batch and show the shapes plus the first sample's input/targets.
iter_ = iter(test_loader)
dix, mask = next(iter_)
print(dix.shape, mask.shape)
first_input, first_target = dix[0], mask[0]
print(first_input)
print()
print(first_target)

In [None]:
# torch tensors hold numbers only, so a list of Python str cannot be converted;
# tokens must be encoded to integer ids first (see stoi below). Catch the error
# so "Run All" does not stop at this experiment.
try:
    torch.tensor(['CLS'])
except (TypeError, ValueError) as err:
    print('torch.tensor cannot hold strings:', err)

In [None]:
# A one-element list containing the CLS token string.
cls_tokens = ["CLS"]
print(cls_tokens)

In [None]:
chars = ['CLS', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 'X', '[MASK]']
print('vocabulary:', chars)

# token -> id and id -> token lookup tables
stoi = {token: pos for pos, token in enumerate(chars)}
itos = dict(enumerate(chars))

# Prepend the CLS token to a toy chunk, then encode every token to its id.
chunk = ['CLS'] + ['A', 'C', 'D', 'E', 'F', 'R']

dix = torch.tensor(list(map(stoi.__getitem__, chunk)), dtype=torch.long)
print(dix)


In [None]:
import torch
from torcheval.metrics.text import Perplexity

# Toy check of torcheval's Perplexity metric: 4 samples, sequence length 1,
# vocabulary size 3. The metric documents its input as unnormalized scores (logits).
metric = Perplexity()
# Named `logits` rather than `input` to avoid shadowing the builtin input().
logits = torch.tensor([[[0.3659, 0.7025, 0.3104]], [[0.0097, 0.6577, 0.1947]], [[0.5659, 0.0025, 0.0104]], [[0.9097, 0.0577, 0.7947]]])
target = torch.tensor([[2], [1], [2], [1]])
print('input shape:', logits.shape)
print('target shape:', target.shape)

metric.update(logits, target)
metric.compute()

-----------------------------
### The scFv antibody sequence/binding affinity data

In [None]:
# data_path = './data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_1/MITLL_AAlphaBio_Ab_Binding_dataset.csv'
# df_1 = pd.read_csv(data_path)
# print(df_1.shape)
# print(df_1.columns)

In [None]:
# Load a-alpha-Bio dataset 2; skiprows=6 skips the lines before the CSV header
# (presumably file metadata — confirm against the file).
data_path = './data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_2/MITLL_AAlphaBio_Ab_Binding_dataset2.csv'
df_2 = pd.read_csv(data_path, skiprows=6)
for info in (df_2.shape, df_2.columns):
    print(info)

In [None]:
# Hold out 10% of the rows as the test set; the remaining rows form the training set.
# random_state makes the split reproducible across kernel restarts.
# drop(test_set.index) relies on df_2 having unique index labels (true for the
# default RangeIndex from read_csv).
SPLIT_SEED = 42
test_set = df_2.sample(frac=0.10, random_state=SPLIT_SEED)
train_set = df_2.drop(test_set.index)

print(test_set.shape)
print(train_set.shape)


In [None]:
# Persist both splits; to_csv writes the DataFrame index as the first column.
for frame, out_path in (
    (test_set, './data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_2/test_set.csv'),
    (train_set, './data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_2/train_set.csv'),
):
    frame.to_csv(out_path)

In [None]:

# pk.dump(train_set, open('./data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_2/train_set.pkl', 'wb'))
# pk.dump(test_set, open('./data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_2/test_set.pkl', 'wb'))

In [None]:
# print(df_2.columns.to_list())

# s1 = df_2['Sequence'][0]
# s2 = df_2['Sequence'][4]
# diff = [i for i in range(len(s1)) if s1[i] != s2[i]]
# print(diff)

# s1 = df_2['Sequence'][0]
# s2 = df_2['HC'][4]
# print('s1 length:', len(s1), 's2 length:', len(s2))



In [None]:
# row = df_2.iloc[0]
# print(row)

----------------------------------
#### Crafting a dataset for the sequence data

In [None]:
class FABSequenceDataset(Dataset):
    """
    Emits (data, mask) pairs for Masked Language Model training, one per row
    of the 'Sequence' column of an a-alpha-Bio antibody CSV file.

    Args:
        config: dict providing 'block_size' and 'mask_prob'.
        csv_file_path: path of the CSV file to read.
        skiprows: number of leading lines to skip before the CSV header
            (default 6, the value this notebook uses for
            MITLL_AAlphaBio_Ab_Binding_dataset2.csv).
    """
    def __init__(self, config, csv_file_path, skiprows=6):
        super().__init__()
        self.config = config
        # Read the file that was passed in, not a notebook-global path.
        self.df = pd.read_csv(csv_file_path, skiprows=skiprows)

        # 20 naturally occurring amino acids in human proteins plus the MASK token
        self.chars = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '[MASK]']
        print('vocabulary:', self.chars)

        data_size, vocab_size = self.df.shape[0], len(self.chars)
        print('data has %d rows, %d vocab size (unique).' % (data_size, vocab_size))

        self.stoi = { ch:i for i,ch in enumerate(self.chars) }
        self.itos = { i:ch for i,ch in enumerate(self.chars) }
        self.vocab_size = vocab_size

    def get_vocab_size(self):
        return self.vocab_size

    def get_block_size(self):
        return self.config['block_size']

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return a (data, mask) pair used for Masked Language Model training.

        `data` is the encoded sequence of row `idx` with n_pred positions, drawn
        from its first block_size tokens, replaced by [MASK]; `mask` holds the
        original token ids at those positions and 0 elsewhere.

        NOTE(review): 0 is also the id of 'A' in this vocabulary, so a masked 'A'
        looks the same in `mask` as an unmasked position — confirm how the loss
        interprets 0 (in OASSequenceDataset id 0 is CLS, which is never masked).
        NOTE(review): the full sequence is returned, not cropped to block_size;
        default DataLoader batching requires every row to have the same length.
        """
        # Positional lookup, so a non-contiguous index (e.g. a sampled split) works.
        chunk = self.df['Sequence'].iloc[idx]

        # encode every character to an integer
        dix = torch.tensor([self.stoi[s] for s in chunk], dtype=torch.long)

        # number of tokens to mask
        n_pred = max(1, int(round(self.config['block_size']*self.config['mask_prob'])))

        # Candidate positions are the first block_size tokens, clamped to the
        # sequence length so a shorter sequence does not raise IndexError.
        n_candidates = min(self.config['block_size'], len(dix))
        masked_idx = torch.randperm(n_candidates, dtype=torch.long)[:n_pred]

        mask = torch.zeros_like(dix)
        # copy the actual tokens to the mask ...
        mask[masked_idx] = dix[masked_idx]
        # ... and overwrite them with the MASK token in the data
        dix[masked_idx] = self.stoi["[MASK]"]

        return dix, mask


In [None]:
# Build the sequence dataset from a-alpha-Bio dataset 2.
data_path = './data/mit-ll/mit-ll-AlphaSeq_Antibody_Dataset-a8f64a9/antibody_dataset_2/MITLL_AAlphaBio_Ab_Binding_dataset2.csv'
dataset = FABSequenceDataset(config, csv_file_path=data_path)

In [None]:
# Dataset size, then the lengths of the first sample's input and targets.
print(len(dataset))

dix, mask = dataset[0]
print(len(dix))
print()
print(len(mask))

In [None]:
# Unshuffled loader over the full FAB dataset.
train_loader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=config['batch_size'],
    pin_memory=True,
    num_workers=5,
)

In [None]:
# Iterator over training batches (starts the loader's workers).
data_iter = train_loader.__iter__()


In [None]:
# Fetch one batch and report the input and target tensor shapes.
batch = next(data_iter)
dix, mask = batch[0], batch[1]
for tensor in (dix, mask):
    print(tensor.shape)