# Factorizable Library generation

This notebook serves as a demo for training models for use in generating Stochastically Annealed Product Space (SAPS) libraries as presented by _Dai & Saksena et. al. 2022_. Here, an example of generating a library to optimize ranibizumab affinity selection data is presented. 

## Brief background on Google Colab

Google Colab is a free cloud-based service that combines the interactive execution environment of Jupyter Notebooks with cloud compute. A good introduction to Google Colab can be found [here](https://colab.research.google.com/notebooks/welcome.ipynb).

Google Colab notebooks can be used with GPU acceleration. If you haven't already, make a copy of the notebook to your Drive by going to File->Save A Copy In Drive so that you can edit the notebook. To enable GPU acceleration, go to Runtime->Change runtime type and select GPU under Hardware accelerator. In the top right corner, there should also be a green tick indicating you are connected to a hosted runtime. If not, click to reconnect. More information is available [here](https://colab.research.google.com/notebooks/gpu.ipynb)

Then, if correctly set up, `nvidia-smi` should return information about the available GPU. 

In [4]:
!nvidia-smi

Thu Dec  1 07:06:11 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
!git clone https://github.com/sajivsaksena/BME350.git
%cd 

Cloning into 'FactorizableLibrary'...
remote: Enumerating objects: 412, done.[K
remote: Counting objects: 100% (88/88), done.[K
remote: Compressing objects: 100% (58/58), done.[K
remote: Total 412 (delta 25), reused 87 (delta 25), pack-reused 324[K
Receiving objects: 100% (412/412), 292.15 MiB | 38.12 MiB/s, done.
Resolving deltas: 100% (27/27), done.
Checking out files: 100% (317/317), done.
/content/FactorizableLibrary/library_generation


In [6]:
import numpy as np
from matplotlib import pyplot as plt
from scipy import stats

import random
import math
import pickle
from collections import OrderedDict
import time
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.data
from torch.nn.parameter import Parameter
from torch.nn import init

from tqdm import tqdm
import pprint

In [7]:
torch.cuda.set_device(0)

## Model definition for library scoring
- See TrainModel.ipynb for additional details

In [8]:
class ResBlock(nn.Module):
    def __init__(self, channels, dilation):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
        self.conv1_bn = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
        self.conv2_bn = nn.BatchNorm1d(channels)
        
    def forward(self, x):
        x1 = F.relu( self.conv1_bn( self.conv1(x) ) )
        return self.conv2_bn( self.conv2(x1) ) + x

class EmbedderNet(nn.Module):
    def __init__(self, length, channels, outchannels):
        super(EmbedderNet, self).__init__()
        self.length = length
        
        self.pool = nn.MaxPool1d(2, 2, ceil_mode = True)
        self.conv = nn.Conv1d(40, channels, 1, 1, 0)
        self.conv_bn = nn.BatchNorm1d(channels)
        
        self.block1 = ResBlock(channels, 1)
        self.block2 = ResBlock(channels, 1)
        self.block3 = ResBlock(channels, 1)
        self.block4 = ResBlock(channels, 1)
        self.block5 = ResBlock(channels, 1)
        
        self.embed_1 = nn.Linear( int( channels*math.ceil(math.ceil(self.length/2)/2) ) , 128)
        self.embed_2 = nn.Linear(128, outchannels)

        # deprecated positional encoding
        position = ( (torch.tensor([i for i in range(self.length)]).view(-1,1)
                      - torch.tensor([i for i in range(self.length)]).view(1,-1)) )
        stack = []
        for i in range(-self.length+1, self.length):
            stack.append((position == i).float())
        positioncode = torch.stack(stack).view(1,(2*self.length-1),self.length,self.length)
        
        self.register_buffer("position", positioncode)
        
    def forward(self, x):
        batchsize = x.shape[0]
        x = F.relu( self.conv_bn( self.conv(x) ) )
        
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.pool( self.block4(x) )
        x = self.pool( self.block5(x) ).view(batchsize,-1)
        return self.embed_2( F.relu( self.embed_1(x) ) )
    
class Predictor_Joint(nn.Module):
    def __init__(self):
        super(Predictor_Joint, self).__init__()
        self.embedSeed = EmbedderNet(10, 64, 16)
       
        
    def forward(self, x):
        seed = x
  
        return torch.sum(self.embedSeed(seed), dim = 1)

In [9]:
class NetEnsemble(nn.Module):
    def __init__(self, lst):
        super(NetEnsemble, self).__init__()
        self.nets = nn.ModuleList(lst)
        
    def forward(self, x):
        return torch.stack([net(x) for net in self.nets]).permute((1,0,2))

class Combine(nn.Module):
    def __init__(self):
        super(Combine, self).__init__()
        
    def forward(self, x1, x2):
        return torch.einsum("nij,nij->n",x1,x2)
        
    def getPrefix(self, x1, x2):
        return torch.einsum("nij,ij->n",x1,x2)
    
    def getSuffix(self, x1, x2):
        return torch.einsum("ij,nij->n",x1,x2)

In [10]:
# Pass in a list of models to get an ensemble
def loadModels(fns):
    seedMaps = []

    for fn in fns:
        net = Predictor_Joint()
        net.load_state_dict(torch.load( fn, map_location="cuda:0"))
        seedMaps.append(net.embedSeed)
    seedEnsemble = NetEnsemble(seedMaps)
    combine = Combine()
    seedEnsemble.eval().cuda()
    combine.eval()
    return seedEnsemble, combine

def calibrate(model):
    with open("../model_training/encoding.pkl", 'rb') as fin:
        encoding = pickle.load(fin)
    seed = []
    with open("calibration.txt", 'rt') as fin:
        for line in fin:
            line = line.rstrip('\n')
            line = [encoding[c] for c in line]
            seed.append(np.array(line[:]).T)
         
    
    calibrationset = torch.utils.data.TensorDataset(
        torch.tensor(np.stack(seed), dtype = torch.float32))
    loader = torch.utils.data.DataLoader(calibrationset, batch_size=10000, shuffle=False, num_workers=2, drop_last=False)
    
    outputs = []
    seedModel, suffixModel, combine = model
    with torch.no_grad():
        for seed in tqdm(loader, position = 0, leave = True):
            outputs.append( combine( seedModel(seed.cuda())))
    return np.std(np.concatenate(outputs), ddof = 1)

def getTranslate():
    with open("../model_training/encoding.pkl", 'rb') as fin:
        onehot = pickle.load(fin)
    dic = {}
    for k in onehot:
        dic[k] = torch.tensor(onehot[k]).float().cuda()
    return dic


## Objects for library generation

In [11]:
class runningSum():
    def __init__(self, embedShape):
        self.x = torch.zeros(*embedShape).double()
        self._c = torch.zeros(*embedShape).double()
        
    def add(self, y):
        self.x = self.x + y

class Translator():
    def __init__(self, translate, model):
        self.SymbolToCode = translate
        self.NumberToLetter = [aa for aa in self.SymbolToCode]
        self.NumberToCode = []
        self.SymbolToNumber = {}
        self.CodeToEmbedding = model
        self.pad = -1
        for i, aa in enumerate(self.NumberToLetter):
            self.SymbolToNumber[aa] = i
            self.NumberToCode.append( self.SymbolToCode[aa] )
            if aa == 'J':
                self.pad = i
            
    def toString(self, seq):
        return ''.join([self.NumberToLetter[i] for i in seq])

class Sequence():
    def __init__(self, sequence, translator):
        self.translator = translator
        self.sequence = [self.translator.SymbolToNumber[c] for c in sequence]
        self.code = torch.stack( [self.translator.NumberToCode[c] for c in self.sequence] )
        self.tracker = None
        self.runningSum = None
        self.embedding = None
            
    def register(self, tracker):
        self.tracker = tracker
        for i,j in enumerate(self.sequence):
            self.tracker.counts[i][j] += 1
        self.tracker.seqs.add( self.translator.toString(self.sequence) )
        
    def register2(self, runningSum):
        self.runningSum = runningSum
        self.runningSum.add(self.embedding)
        
    def update(self, position, aa_number, new_embedding):
        self.tracker.seqs.remove( self.translator.toString(self.sequence) )
        self.tracker.counts[position][ self.sequence[position] ] -= 1
        self.sequence[position] = aa_number
        self.tracker.counts[position][ aa_number ] += 1
        self.tracker.seqs.add( self.translator.toString(self.sequence) )
        
        self.code[position] = self.translator.NumberToCode[aa_number]
        with torch.no_grad():
            self.runningSum.add(-self.embedding.double())
            self.embedding = new_embedding
            self.runningSum.add(self.embedding.double())
            
    def isFree(self, position, proposals):
        current = self.sequence[position]
        freeMove = []
        for i in proposals:
            self.sequence[position] = i
            freeMove.append(self.translator.toString(self.sequence) not in self.tracker.seqs)
        self.sequence[position] = current
        return freeMove
    
    # Generate a list of embeddings corresponding to samples that can then be used to calculate deltas upstream
    def mockUpdates(self, position):
        samples = []
        current = self.sequence[position]
        for i in range(len(self.translator.NumberToCode)):
            if i != current and i != self.translator.pad:
                self.sequence[position] = i
                samples.append(i)
        self.sequence[position] = current
        if len(samples) == 0:
            return [], []
        
        seqs = torch.stack( [self.code for i in range(len(samples))] )
        for i,proposal in enumerate(samples):
            seqs[i, position] = self.translator.NumberToCode[proposal]
            
        return seqs, samples

class FrequencyTracker():
    def __init__(self, sequences, length, sigma):
        self.counts = np.zeros( (length, sigma) )
        self.seqs = set()
        for seq in sequences:
            seq.register(self)
        
class SubLibrary():
    def __init__(self, sequences, translate, model, updateParam, length):
        sigma = len(translate)
        
        self.translator = Translator(translate, model)
        self.seqs = [Sequence(seq, self.translator) for seq in sequences]
        self.size = len(self.seqs)
        self.updateParam = updateParam
        
        with torch.no_grad():
            embeds = self.translator.CodeToEmbedding( torch.stack([seq.code for seq in self.seqs]).permute(0,2,1) ).cpu()
            for i, seq in enumerate(self.seqs):
                seq.embedding = embeds[i].detach()
        
        partition = {}
        for seq, update in zip(self.seqs, self.updateParam):
            if update not in partition:
                partition[update] = []
            partition[update].append(seq)
        
        self.frequencyTrackers = {}
        for k in partition:
            self.frequencyTrackers[k] = FrequencyTracker(partition[k], length, sigma)
        
        self.embeddingSum = runningSum(self.seqs[0].embedding.shape)
        for seq in self.seqs:
            seq.register2(self.embeddingSum)
            
    def _entropy(self, arr, z):
        return -np.log(np.maximum(1,arr)/z) * (arr/z)
    
    def dedupe(self, seqindex, position, ls):
        ls2 = []
        proposals = ls[0]
        isFree = self.seqs[seqindex].isFree(position, proposals)
        for arr in ls:
            ls2.append([z for z,x in zip(arr,isFree) if x])
        return ls2
    
    def getEntropyChanges(self, counts, current, proposals):
        x = counts[current]
        z = np.sum(counts)
        h_now = self._entropy(counts, z)
        dh = self._entropy(counts[current]-1, z) - h_now[current]
        deltah = (self._entropy(counts+1, z) - h_now) + dh
        return np.take(deltah, proposals) * z
    
    def proposeEveryChange(self, seqindex, position):
        codes, proposals = self.seqs[seqindex].mockUpdates(position)
        if len(proposals) == 0: return [], None, []
        return codes, self.seqs[seqindex].embedding, proposals
    
    def getDeltaEntropy(self, seqindex, position, proposals):
        deltaEntropy = self.getEntropyChanges(self.seqs[seqindex].tracker.counts[position],
                                              self.seqs[seqindex].sequence[position],
                                              proposals)
        return deltaEntropy
    
    def update(self, seqindex, position, aa_number, new_embedding):
        self.seqs[seqindex].update(position, aa_number, new_embedding)
        
    def getEntropyScores(self):
        report = []
        for k in self.frequencyTrackers:
            counts = self.frequencyTrackers[k].counts
            z = np.sum(counts[0])
            h = np.array( [np.sum( self._entropy(arr, z) ) for arr in counts] )
            report.append( (z,h) )
        return report
    
    def getLib(self):
        return [self.translator.toString(z.sequence) for z in self.seqs]

class Optimizer():
    def __init__(self, s_seed, suffix_seed, updateSeed, seedLength, translate, model, std):
        seedModel, suffixModel, combine = model
        self.seedLib = SubLibrary(s_seed, translate, seedModel, updateSeed, seedLength)
        self.combine = combine
        self.entropyWeight = 0.01
        self.T = 50
        self.std = std
        self.updateSeed = updateSeed
        self.seedLength = seedLength
        
    def sweep(self):
        sweep = [i for i in range(self.seedLength)]
        for s in zip(sweep):
            scol = self.sweepColumn(p, self.updateSeed, self.seedLib)
            for i in range(max(len(pcol), len(scol))):
                if i < len(pcol):
                    self.executeUpdate(True, pcol[i], self.seedLib, s)

            torch.cuda.empty_cache()
                    
    def executeUpdate(self, isSeed, command, lib, otherlib, position):
        seqindex, dEmbed, cEmbed, proposals = command
        with torch.no_grad():
            dScore = (self.combine.getSeed(dEmbed-cEmbed, self.suffixLib.embeddingSum.x.float()) if isSeed
                      else self.combine.getSuffix(self.seedLib.embeddingSum.x.float(), dEmbed-cEmbed)).numpy()
            
        proposals, dEmbed, dScore = lib.dedupe( seqindex, position, (proposals, dEmbed, dScore) )
        dScore = np.array(dScore)
        if len(proposals) == 0: return
        
        dEntropy = lib.getDeltaEntropy(seqindex, position, proposals)
        dScore = (dScore/ (otherlib.size*self.std)) + (dEntropy * self.entropyWeight)
        
        dScore = np.concatenate( ((dScore * self.T), [0]) )
        proposals.append(-1)
        newIndex = random.choices(range(len(proposals)), np.exp(dScore - np.max(dScore)))[0]
        if proposals[newIndex] != -1:
            lib.update(seqindex, position, proposals[newIndex], dEmbed[newIndex].detach())
          
    def sweepColumn(self, position, updates, lib):
        index = 0
        commands = []
        bigTensor = []
        for seqindex, (mn, mx) in enumerate( updates ):
            if mn <= position < mx:
                codeProposals, cEmbed, proposals = lib.proposeEveryChange(seqindex, position)
                if len(proposals) == 0: continue
                
                rng = (index, index + len(proposals))
                commands.append( (seqindex, cEmbed, proposals, rng) )
                bigTensor.append(codeProposals)
                index += len(proposals)
        if len(bigTensor) == 0: return []
        with torch.no_grad():
            outTensor = lib.translator.CodeToEmbedding( torch.cat(bigTensor, dim = 0).permute((0,2,1)).cuda() )
        commands2 = []
        for i, (seqindex, cEmbed, proposals, rng) in enumerate(commands):
            dEmbed = outTensor[rng[0]:rng[1]].detach().cpu()
            commands2.append((seqindex, dEmbed, cEmbed, proposals))
        return commands2
        
    def getScore(self):
        with torch.no_grad():
            score = self.combine(self.seedLib.embeddingSum.x.float().unsqueeze(0))
        h = self.seedLib.getEntropyScores()
        ht = np.sum([z*h for z,h in h2]) * self.seedLib.size
        total = -score + (ht*self.entropyWeight)
        return total
    
    def getLib(self):
        return self.seedLib.getLib()

In [12]:
def getSeed(n):
    aas = list("ACDEFGHIKLMNPQRSTVWYJ")
    sample = np.random.randint(0,20, (n*2, 10) )
    sample = set(sample)
    if len(sample) < n:
        return getSeed(n)
    return list(sample)[:n]


## Generate a factorizable library

In [17]:
# Initialize the ensemble with a list of models
# Here, you can change the path to a folder containing model weights of your choice
# In this example, we optimize ranibizumab affinity
model_dir = "../model_training/trained_weights/reverse_kernel/baculovirus/"
ensembles = loadModels([os.path.join(model_dir, file) for file in os.listdir(model_dir)])

# Determine the standard deviation of the ensemble
std = calibrate(ensembles)

RuntimeError: ignored

In [None]:
# Initialize seed sequences
p,pu = getSeed(10)


# Initialize the optimizer
opt = Optimizer(s_seed=p, 
                updateSeed=pu,  
                seedLength=10,  
                translate=getTranslate(), 
                model=ensembles, 
                std=std)

# Set entropy parameter lambda
opt.entropyWeight = 0.1

In [None]:
# Initialize temperature
opt.T = 1

bestLib = None
bestScore = float("-inf")

# Optimize for 10 iterations
# change the value in the range() function to increase iterations
pbar = tqdm(range(10), position=0, leave=True)
for i in pbar:
    opt.sweep()
    torch.cuda.empty_cache()
    score = opt.getScore()
    if score > bestScore:
        bestLib = opt.getLib()
    pbar.set_description("Score = {:.2f}, Temperature = {:.5f}".format(score, 1/opt.T))
    
    # Lower the temperature every 5 iterations
    if i%5==4:
        opt.T *= 1.1
        
# Optimized segment libraries
seedLibrary = bestLib

Score = 15830.26, Temperature = 0.90909: 100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


In [None]:
pp = pprint.PrettyPrinter(width=150, compact=True)
print("Baculovirus library: \n")
pp.pprint(seedLibrary)


Prefix library: 

['CRQRNCGDPG', 'WIPRIRLWWG', 'SPRSKCRCKA', 'GDQWPWKMRR', 'YFSIWTECDI', 'CAHCRCRHLW', 'GRGLDFQEQA', 'WDYSWGVNPA', 'KYFKCMQKIA', 'LMCQNGCLMH',
 'JFYRLWCVSP', 'JTSQWICGKN', 'JFCHNTYMCH', 'JWQMLQQNSS', 'JLVRNWKCGC', 'JCKQKNHHRE', 'JCQQEYCEFV', 'JKNRNQQPHW', 'JGFWEAYRSW', 'JLIRLQRIKL',
 'JJSSSTWMNW', 'JJMKKCNIRP', 'JJSGHRPQLD', 'JJLIISTMNS', 'JJFCFRKAFM', 'JJCRRQLPMY', 'JJERMIIIKV', 'JJKLKFVKMV', 'JJHGETHFPL', 'JJWHLSRCYK',
 'JJJIKSRKTN', 'JJJVFYYMWC', 'JJJACKNKLP', 'JJJWKNIMTK', 'JJJTCVVIDA', 'JJJLKNKVGN', 'JJJQRRQVNK', 'JJJYHGLCNK', 'JJJRSQKPMT', 'JJJARSRVSF',
 'JJJJMFNAWW', 'JJJJWCQQSF', 'JJJJCCCQCQ', 'JJJJTTFVLI', 'JJJJSKKHNP', 'JJJJEFELSC', 'JJJJPWWADF', 'JJJJINITFR', 'JJJJWLCTKH', 'JJJJWCPSAR',
 'JJJJJWNLWW', 'JJJJJQQYKD', 'JJJJJYKCFL', 'JJJJJDNVKY', 'JJJJJFKNIC', 'JJJJJVPDTK', 'JJJJJLEERC', 'JJJJJFRARV', 'JJJJJILTCF', 'JJJJJYCRHE',
 'JJJJJJVWDS', 'JJJJJJRKNI', 'JJJJJJQNGN', 'JJJJJJRCKV', 'JJJJJJWWCA', 'JJJJJJNYIG', 'JJJJJJCVYA', 'JJJJJJLCKN', 'JJJJJJMCTA', 'JJJJJJWL