## a-AlphaBio homework 
### Miscellaneous experiments (probably messy)

In [None]:
%load_ext autoreload

In [None]:
%autoreload
# import libraries
import numpy as np
import pickle as pk
import pandas as pd
import math
import os
import yaml
import sys
import pytorch_lightning as pl
# import sys
# sys.path.append("../")

import matplotlib.pyplot as plt
%matplotlib inline

import torch

In [None]:
# Quick sanity check: how long is this sample sequence fragment?
a = (
    'TNYYMYWVRQAPGQGLEWMGGINPSNGGTNFNEKFKNRVTLTTDSSTTTAYMELKSLQFDDTAVYYCARRDYRFDMGFD'
)
print(len(a))

In [None]:
# Read the YAML config and split it into its model / training sections
config_path = './config/tform_mlp_params_v3.yaml'
with open(config_path, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Report the parse error, then re-raise: continuing would either hit a
        # NameError below or silently reuse a stale `config` from an earlier run.
        print(exc)
        raise

model_config = config['model_params']
train_config = config['train_params']

print(model_config)
print(train_config)
# Seed every RNG that Lightning knows about so runs are reproducible
pl.seed_everything(config['seed'])


In [None]:
import math
import random
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np

class scFv_Dataset_v3(Dataset):
    """
        Dataset class for scFv sequence, Kd data
        version 3: this class only uses a fixed window of each sequence, starting
        at `start_seq_idx` and `block_size` residues long
        (e.g. positions 25 - 110, which corresponds to a block length = 85).

        Every item is a tuple (dix, gix, vix, Kd, name):
            dix:  LongTensor [block_size], residue token ids
            gix:  LongTensor [block_size], residue group ids
            vix:  LongTensor [block_size], bucketed per-position variability
            Kd:   FloatTensor [1], affinity (0 in inference mode)
            name: 'none' in training/test mode, the 'description_a' column in inference mode

        Args:
            config: dict with configuration parameters
                    ('seq_flip_prob' and 'seq_mask_prob' are read when regularize is True)
            block_size: number of residues in the window returned for each sequence
            start_seq_idx: index of the first residue of the window
            csv_file_path: path to the csv file
            skiprows: number of rows to skip at the beginning of the file
            inference: if True, the dataset is used for inference
            regularize: if True, the dataset is used for training and data augmentation/regularization is applied
    """
    def __init__(self, config, block_size, start_seq_idx, csv_file_path, skiprows=0, inference=False, regularize=False):
        super().__init__()
        self.config = config
        self.block_size = block_size
        self.start_seq_idx = start_seq_idx  # the window always starts on this index (i.e 25)
        self.inference = inference
        self.regularize = regularize  # sequence flipping, masking etc...
        print('reading the data from:', csv_file_path)
        self.df = pd.read_csv(csv_file_path, skiprows=skiprows)

        # 20 naturally occuring amino acids in human proteins plus MASK token,
        # 'X' is a special token for unknown amino acids, CLS token is for classification, and PAD for padding
        self.chars = ['CLS', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M',
                      'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 'X', 'MASK', 'PAD']
        self.first_aa = self.chars[1]
        self.last_aa = self.chars[20]
        print('vocabulary:', self.chars)
        data_size, vocab_size = self.df.shape[0], len(self.chars)
        print('data has %d rows, %d vocab size (unique).' % (data_size, vocab_size))
        self.vocab_size = vocab_size

        # aa groups: group name for each position in the self.chars array above
        # NOTE(review): 'N' and 'Q' are tagged 'neg' here, although asparagine and
        # glutamine are usually classed as polar/uncharged. Left unchanged so the
        # encodings seen by existing models stay stable -- confirm the intent.
        self.aa_groups = ['none', 'nonpolar', 'nonpolar', 'neg', 'neg', 'nonpolar', 'nonpolar',
                          'pos', 'nonpolar', 'pos', 'nonpolar', 'nonpolar', 'neg',
                          'nonpolar', 'neg', 'pos', 'polar', 'polar', 'nonpolar',
                          'nonpolar', 'polar', 'none', 'none', 'none']

        self.groups = ['none', 'nonpolar', 'neg', 'pos', 'polar']
        print('aa groups:', self.aa_groups)

        # The relative variability frequence for each amino acid position in the scFv sequences over the entire clean_3 dataset
        # This fixed-array is 247 (246 residues + an extra 0.0000 added for holdout set):
        # 29 leading zeros, 79 measured values (positions 29..107), 139 trailing zeros.
        #
        # A better way to do this would have been to pre-computed for each dataset separately.
        raw_pos_variability = torch.tensor(
            [0.0000] * 29
            + [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.7778,
               0.9444, 0.9444, 1.0000, 1.0000, 1.0000, 0.9444, 1.0000, 1.0000, 0.8889,
               0.5556, 0.5000, 0.2222, 0.3333, 0.2778, 0.6111, 0.4444, 0.5556, 1.0000,
               0.8333, 0.8333, 0.8333, 0.9444, 0.7778, 0.8333, 0.6111, 1.0000, 0.8889,
               0.3333, 0.9444, 0.8889, 0.1111, 0.3889, 0.9444, 0.2778, 0.9444, 0.8333,
               0.5000, 1.0000, 1.0000, 1.0000, 0.9444, 0.6111, 0.6111, 0.7778, 0.2778,
               0.8889, 0.3889, 0.9444, 1.0000, 0.3889, 0.9444, 1.0000, 0.9444, 0.2222,
               0.7778, 0.5556, 0.8889, 0.2222, 0.7778, 0.6111, 0.6667, 0.8333, 0.8333,
               1.0000, 1.0000, 0.8889, 0.8333, 0.8333, 0.9444, 0.7222, 0.9444, 0.9444]
            + [0.0000] * 139)

        # Map the raw_pos_variability into 10 buckets based on value
        var_buckets = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1.0]

        def get_bucket(n):
            """Return the index of the first bucket edge that is >= n."""
            assert (n >= 0 and n <= 1), "make sure n is in the range [0,1.0]"
            for i, k in enumerate(var_buckets):
                if n <= k:
                    return i
            return 0  # not reachable for n in [0, 1]; kept as a defensive default

        # assign each variability value to a bucket (indexed by absolute sequence position)
        self.enc_pos_variability = [get_bucket(v) for v in raw_pos_variability]

        print('a sample of some raw variability values:', raw_pos_variability[60:65])
        print('and their corresponding encoded values :', self.enc_pos_variability[60:65])

        # encoding and decoding residues
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}
        # encoding decoding groups
        # groupitos maps a residue index to its group name; groupstoi is lossy
        # (group names repeat, so only the last residue index per group survives).
        self.groupstoi = {group: i for i, group in enumerate(self.aa_groups)}
        self.groupitos = {i: group for i, group in enumerate(self.aa_groups)}
        self.gtoi = {group: i for i, group in enumerate(self.groups)}
        self.itog = {i: group for i, group in enumerate(self.groups)}

    def encode_aa(self, aa_name):
        """Return the token id for a residue / special-token name."""
        return self.stoi[aa_name]

    def decode_aa(self, aa_idx):
        """Return the residue / special-token name for a token id."""
        assert aa_idx < self.vocab_size, 'aa_idx is out of bounds'
        return self.itos[aa_idx]

    # Get the group encoding for a given amino acid
    def encode_aa_to_group(self, aa_name):
        """Return the group id (index into self.groups) of a residue name."""
        aa_idx = self.stoi[aa_name]
        group_name = self.groupitos[aa_idx]
        return self.gtoi[group_name]

    def decode_group(self, group_idx):
        """Return the group name for a group id."""
        return self.itog[group_idx]

    def get_vocab_size(self):
        return self.vocab_size

    def get_block_size(self):
        return self.block_size

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        """
            Returns sequence encoding, group encoding, position-variability encoding,
            affinity and name for row `idx`.
        """
        seq = self.df.loc[idx, 'sequence_a']

        # apologies: next couple lines are overly dataset-specific
        if not self.inference:  # training or test mode
            Kd = self.df.loc[idx, 'Kd']
            assert not math.isnan(Kd), 'Kd is nan'
            name = 'none'
        else:
            Kd = 0  # inference mode - Kd is not available
            name = self.df.loc[idx, 'description_a']

        assert Kd >= 0.0, 'affinity cannot be negative'

        # take the fixed window [start_seq_idx, start_seq_idx + block_size) of the sequence
        assert len(seq) > self.block_size, 'sequence is shorter than block_size'
        start_idx = self.start_seq_idx
        chunk = seq[start_idx:start_idx + self.block_size]

        # encode residues, residues' groups, and position variability
        dix = torch.tensor([self.stoi[s] for s in chunk], dtype=torch.long)
        gix = torch.tensor([self.encode_aa_to_group(s) for s in chunk], dtype=torch.long)
        # The variability table is indexed by absolute sequence position, so it has
        # to be sliced with the same offset as the residues; positions that fall
        # beyond the end of the table count as zero variability.
        pos_var = list(self.enc_pos_variability[start_idx:start_idx + len(chunk)])
        pos_var += [0] * (len(chunk) - len(pos_var))
        vix = torch.tensor(pos_var, dtype=torch.long)

        # some sequence-level regularization & augmentation can be done here
        if self.regularize:
            # occasionally flip the aa sequences back-to-front as a regularization technique.
            # One draw decides for all three streams so they stay aligned position by position.
            if random.random() < self.config['seq_flip_prob']:
                dix = torch.flip(dix, [0])
                gix = torch.flip(gix, [0])
                vix = torch.flip(vix, [0])

            # mask a small perentage of the amino acids with the MASK token
            # acts like a dropout
            if self.config['seq_mask_prob'] > 0.0:
                num_2_mask = max(0, int(round((dix.shape[0]) * self.config['seq_mask_prob'])))
                masked_idx = torch.randperm((dix.shape[0]), dtype=torch.long)[:num_2_mask]
                dix[masked_idx] = self.stoi['MASK']
                gix[masked_idx] = self.gtoi['none']
                vix[masked_idx] = 0  # ?? not sure about this...

        # pad the end with PAD tokens if necessary
        if dix.shape[0] < self.block_size:
            dix = torch.cat((dix, torch.tensor([self.stoi['PAD']] * (self.block_size - len(dix)), dtype=torch.long)))
            gix = torch.cat((gix, torch.tensor([self.gtoi['none']] * (self.block_size - len(gix)), dtype=torch.long)))
            vix = torch.cat((vix, torch.tensor([0] * (self.block_size - len(vix)), dtype=torch.long)))

        return dix, gix, vix, torch.tensor([Kd], dtype=torch.float32), name


In [None]:

# Build the v3 dataset from the training CSV named in the config
dataset = scFv_Dataset_v3(
    train_config,
    model_config['block_size'],
    model_config['start_seq_idx'],
    train_config['train_data_path'],
    regularize=train_config['sequence_regularize'],
)

# Peek at the first sample: three encodings plus the affinity and the name
x, x2, x3, kd, name = dataset[0]
print('x shape:', x.shape, ', x2 shape:', x2.shape, ', x3 shape:', x3.shape, ', kd:', kd, ', name:', name)
print('x.dtype:', x.dtype)

In [None]:
from torch.utils.data import DataLoader
# Wrap the v3 dataset built above in a DataLoader.
# NOTE(review): the dataset cell reads its settings from train_config
# (= config['train_params']); confirm 'batch_size' really lives at the top
# level of the YAML config rather than under train_params.
train_loader = DataLoader(dataset, batch_size=config['batch_size'])

In [None]:
# Pull one batch from the loader.
# scFv_Dataset_v3.__getitem__ returns five items (residue ids, group ids,
# variability ids, Kd, name), so a collated batch has five fields as well.
it = iter(train_loader)
x, x2, x3, kd, name = next(it)
print('x.dtype:', x.dtype, ', x.shape:', x.shape)

In [None]:
import math
import random
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
# from datasets.scFv_dataset import scFv_Dataset

#--------------------------------------------------------
# Simple wrapper Dataset to turn output from the scFv dataset
# into a B&W image for use in a CNN model
#--------------------------------------------------------
class CNN_Dataset_BGR(Dataset):
    """
    Wraps an scFv dataset and renders each encoded sequence as a 3-channel
    bit-plane image, returned together with the affinity value (`kd`) supplied
    by the wrapped dataset.

    Channels: residue codes, residue-group codes, per-position mutation frequency.
    Every value is written out as 8 bits, one image cell per bit.
    """
    def __init__(self, config, csv_file_path, transform=None, skiprows=0, inference=False):
        super().__init__()
        self.scFv_dataset = scFv_Dataset(config, csv_file_path, skiprows, inference)
        self.config = config
        self.img_shape = config['image_shape']
        self.transform = transform

        vocab = self.scFv_dataset.chars
        # group name of each vocabulary entry, in vocabulary order
        vocab_groups = ['none', 'nonpolar', 'nonpolar', 'neg', 'neg', 'nonpolar', 'nonpolar', 'pos',
                        'nonpolar', 'pos', 'nonpolar', 'nonpolar', 'neg', 'nonpolar', 'neg', 'pos',
                        'polar', 'polar', 'nonpolar', 'nonpolar', 'polar', 'none', 'none', 'none']

        # for VIT, since the residue encodings are spread over 8-bits,
        # give every group a code whose set bits are spread across the byte
        group_encodings = {
            'none': 0b11001100,
            'polar': 0b00110011,
            'nonpolar': 0b01100110,
            'pos': 0b01010101,
            'neg': 0b10101010,
        }

        print('group_encodings:', group_encodings)

        # token id -> 8-bit group code
        token_ids = self.scFv_dataset.stoi
        self.i_to_grp = {token_ids[ch]: group_encodings[grp] for ch, grp in zip(vocab, vocab_groups)}

        # The relative mutation frequence for each amino acid position in the scFv sequences over the entire clean_3 dataset
        # This fixed-array is 241 elements long (the last 5 residues of the 246 residue sequences were clipped for the VIT model):
        # 29 leading zeros, 79 measured values, 133 trailing zeros.
        self.rel_mutation_freq = torch.tensor(
            [0.0000] * 29
            + [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.7778,
               0.9444, 0.9444, 1.0000, 1.0000, 1.0000, 0.9444, 1.0000, 1.0000, 0.8889,
               0.5556, 0.5000, 0.2222, 0.3333, 0.2778, 0.6111, 0.4444, 0.5556, 1.0000,
               0.8333, 0.8333, 0.8333, 0.9444, 0.7778, 0.8333, 0.6111, 1.0000, 0.8889,
               0.3333, 0.9444, 0.8889, 0.1111, 0.3889, 0.9444, 0.2778, 0.9444, 0.8333,
               0.5000, 1.0000, 1.0000, 1.0000, 0.9444, 0.6111, 0.6111, 0.7778, 0.2778,
               0.8889, 0.3889, 0.9444, 1.0000, 0.3889, 0.9444, 1.0000, 0.9444, 0.2222,
               0.7778, 0.5556, 0.8889, 0.2222, 0.7778, 0.6111, 0.6667, 0.8333, 0.8333,
               1.0000, 1.0000, 0.8889, 0.8333, 0.8333, 0.9444, 0.7222, 0.9444, 0.9444]
            + [0.0000] * 133)

        # scale the frequencies to 0..255 and round up to whole numbers
        self.mutation_freq_encoded = torch.ceil(self.rel_mutation_freq * 255).to(torch.long)
        print('min mutation freq:', torch.min(self.mutation_freq_encoded), 'max mutation freq:', torch.max(self.mutation_freq_encoded))
        print('some mutation_freq_encoded values:', self.mutation_freq_encoded[60:70])

    def get_vocab_size(self):
        """Vocabulary size of the wrapped scFv dataset."""
        return self.scFv_dataset.vocab_size

    def get_block_size(self):
        """Block size as given in the config."""
        return self.config['block_size']

    def __len__(self):
        return len(self.scFv_dataset)

    def _bin(self, x):
        """Format x as a zero-padded 8-character binary string."""
        return format(x, '08b')

    def _encode_channel(self, x, shape):
        """Unroll the values of x into bits and lay the first shape[0]*shape[1] of them out as a float matrix."""
        bits = []
        for val in x:
            bits.extend(int(b) for b in self._bin(val))  # one integer per bit
        n_cells = shape[0] * shape[1]
        return torch.tensor(bits[:n_cells], dtype=torch.float32).reshape(shape)

    def __getitem__(self, idx):
        """Returns the normalized image, kd and the three raw channels for sample idx."""
        dix, kd = self.scFv_dataset[idx]

        # channel 1: the residue encoding
        ch_1 = self._encode_channel(dix, self.img_shape)
        print('ch_1:', ch_1.shape, torch.min(ch_1), torch.max(ch_1))

        # channel 2: the residue group encoding
        group_codes = [self.i_to_grp[token] for token in dix.tolist()]
        dix_grp = torch.tensor(group_codes, dtype=torch.long)
        ch_2 = self._encode_channel(dix_grp, self.img_shape)
        print('ch_2:', ch_2.shape, torch.min(ch_2), torch.max(ch_2))   

        # channel 3: the mutation frequency; anything not an amino acid gets a zero.
        ch3_in = torch.zeros_like(dix)
        # First aa is always position 1 (0 is a CLS token)
        n_freq = len(self.mutation_freq_encoded)
        ch3_in[1:n_freq + 1] = self.mutation_freq_encoded
        ch_3 = self._encode_channel(ch3_in, self.img_shape)
        print('ch_3:', ch_3.shape, torch.min(ch_3), torch.max(ch_3))   

        # stack the 3 bit-planes into a bgr image with values 0 / 255
        bgr_img = torch.stack((ch_1, ch_2, ch_3), dim=0) * 255

        if self.transform:
            bgr_img = self.transform(bgr_img)

        # Normalize image [-1, 1]
        bgr_img = (bgr_img - 127.5) / 127.5

        return bgr_img, kd, ch_1, ch_2, ch_3

In [None]:
from torch.utils.data import DataLoader
# from datasets.scFv_dataset import scFv_Dataset as dataset
from torchvision.transforms.v2 import Resize, Compose, ToDtype, RandomHorizontalFlip, RandomVerticalFlip 

# train_transforms = Compose([ToDtype(torch.float32, scale=False),
#                             RandomHorizontalFlip(p=0.25),
#                             RandomVerticalFlip(p=0.25)])

# Build the image dataset from the training CSV and record its vocabulary size in the config
train_data_path = config['train_data_path']
train_dataset = CNN_Dataset_BGR(config, train_data_path)  # , train_transforms)
print(len(train_dataset))
config['vocab_size'] = train_dataset.get_vocab_size()
print('config[vocab_size]:', config['vocab_size'], ', config[block_size]:', config['block_size'])

# Shuffled loader; batch size and worker count come from the config
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    pin_memory=True,
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
)


In [None]:
# Fetch one sample together with its three individual bit-plane channels
img, kd, ch_1, ch_2, ch_3 = train_dataset[105]
print(img.dtype)
# move the channel axis last: (C, H, W) -> (H, W, C)
rgb_img = img.permute(1, 2, 0)

print('img shape:', img.shape, ', kd:', kd)
# plt.imshow(rgb_img) #, cmap='gray')

# show only the third channel as a grayscale image
plt.imshow(ch_3, cmap='gray')

# summary statistics of the normalized image
print(rgb_img.min(), ', ', rgb_img.max(), ', ', rgb_img.mean(), ', ', rgb_img.std())


----
### Examine other data that may be added as input channels to Vision Transformer.


In [None]:
# Assign every vocabulary token to one of the usual amino-acid groups:
# polar, nonpolar, positively charged, negatively charged, or none (special tokens such as CLS, X, MASK, PAD)
#
# The vocabulary holds the 20 naturally occuring amino acids in human proteins,
# 'X' for unknown amino acids, CLS for classification, MASK for masking and PAD for padding.
chars = ['CLS', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M',
         'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 'X', 'MASK', 'PAD']
# group name of each entry of `chars`, in the same order
groups = ['none', 'nonpolar', 'nonpolar', 'neg', 'neg', 'nonpolar', 'nonpolar', 'pos',
          'nonpolar', 'pos', 'nonpolar', 'nonpolar', 'neg', 'nonpolar', 'neg', 'pos',
          'polar', 'polar', 'nonpolar', 'nonpolar', 'polar', 'none', 'none', 'none']
print('\nvocabulary:', chars)

# token name <-> token id
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))

# for VIT, since the residue encodings are spread over 8-bits,
# give every group a code whose set bits are spread across the byte
group_encodings = {
    'none': 0b00101000,
    'polar': 0b00110011,
    'nonpolar': 0b11001100,
    'pos': 0b01010101,
    'neg': 0b10101010,
}

print('group_encodings:', group_encodings)

# maps amino acid to group code
s_to_grp = {ch: group_encodings[grp] for ch, grp in zip(chars, groups)}
# maps encoded residue (token id) to group code
i_to_grp = {stoi[ch]: code for ch, code in s_to_grp.items()}

print('\ns_to_grp:', s_to_grp)
print('\ni_to_grp:', i_to_grp)




In [None]:
import torch
def _bin(x):
    return format(x, '08b')

def _encode_channel(x, shape=(48,48)):
    d = ''.join([_bin(x[i]) for i in x.numpy()])
    # turn d into a list of integers, one for each bit
    d = [int(x) for x in d]    
    t = torch.tensor(d[:(shape[0]*shape[1])], dtype=torch.float32) # this is for 46,46 matrix
    t = t.reshape(shape)
    # t = t.unsqueeze(0) # add channel dimension
    return t


In [None]:
# Make a heat-map for the variability of each position in the sequence?
# That somehow changes with each sequence?
#