<h3>Imports and utilities</h3>

In [124]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [142]:
class resBlock(nn.Module):
    """Residual block made of two 256-channel 1-D convolutions.

    Both convolutions use kernel size 3, stride 1, padding 1 and no bias, and
    each is followed by batch normalisation. The block input is added to the
    second normalised output before the final ReLU, so the output has the
    same shape as the input.
    """

    def __init__(self):
        super(resBlock, self).__init__()
        channels = 256
        # Stage 1: convolution -> batch-norm -> ReLU.
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(channels)
        self.relu1 = nn.ReLU(inplace=True)
        # Stage 2: convolution -> batch-norm; its ReLU is applied after the skip connection.
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu1(hidden)
        residual = self.bn2(self.conv2(hidden))
        # The addition allocates a new tensor, so the in-place ReLU below
        # never overwrites the caller's ``x``.
        return self.relu2(residual + x)

class PUFFIN(nn.Module):
    """Convolutional encoder that maps a 40-channel sequence to 256 features.

    A 40 -> 256 convolution (kernel 3, stride 1, padding 1, no bias) with
    batch-norm and ReLU is followed by five ``resBlock`` modules and an
    average pool of width 9.
    """

    def __init__(self):
        super(PUFFIN, self).__init__()
        self.conv = nn.Conv1d(40, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm1d(256)
        self.relu = nn.ReLU(inplace=True)
        residual_stack = [resBlock() for _ in range(5)]
        self.resBlocks = nn.Sequential(*residual_stack)
        self.avgpool = nn.AvgPool1d(9, stride = 1, padding = 0)

    def forward(self, x):
        """Encode ``x`` and return the pooled features reshaped to ``(-1, 256)``."""
        stem = self.relu(self.bn(self.conv(x)))
        features = self.resBlocks(stem)
        pooled = self.avgpool(features)
        # NOTE(review): with a width-9, stride-1, unpadded pool the pooled
        # length is 1 only for length-9 inputs; for any other length this
        # reshape no longer yields one row per input - confirm callers always
        # pass 9-residue sequences.
        return pooled.view((-1, 256))
    
class GaussianPredictor(nn.Module):
    """PUFFIN encoder with two linear heads producing a mean and a std.

    Dropout (p=0.2) is applied to the 256-dimensional embedding before both
    heads. The std head is passed through softplus, so it is non-negative.
    """

    def __init__(self):
        super(GaussianPredictor, self).__init__()
        self.embed = PUFFIN()
        self.dropout = nn.Dropout(0.2)

        # One scalar output per head.
        self.extractMean = nn.Linear(256, 1)
        self.extractStd = nn.Linear(256, 1)

    def forward(self, x):
        """Return ``(mean, std)``; both heads read the same dropout sample."""
        features = self.dropout(self.embed(x))

        mean = self.extractMean(features)
        raw_std = self.extractStd(features)
        return mean, F.softplus(raw_std)
    
class CategoricalPredictor(nn.Module):
    """PUFFIN encoder with a linear head producing a distribution over 8 classes.

    Dropout (p=0.2) is applied to the 256-dimensional embedding before the
    head; the 8 outputs are normalised with a softmax along dimension 1.
    """

    def __init__(self):
        super(CategoricalPredictor, self).__init__()
        self.embed = PUFFIN()
        self.dropout = nn.Dropout(0.2)

        self.extractCategory = nn.Linear(256, 8)

    def forward(self, x):
        """Return the per-row class probabilities for ``x``."""
        features = self.dropout(self.embed(x))

        logits = self.extractCategory(features)
        return F.softmax(logits, dim = 1)

In [143]:
def load_categorical_model(allele, split, initialization):
    """Load one trained ``CategoricalPredictor`` checkpoint from ``./models``.

    Args:
        allele: Value substituted after ``DR40`` in the checkpoint file name.
        split: Value substituted after ``split`` in the file name.
        initialization: Value substituted after ``init`` in the file name.

    Returns:
        The model with the checkpoint weights loaded, in eval mode except for
        its dropout layer, which is left in training mode. The model is on
        the CPU; callers move it to another device if needed.
    """
    fn = "./models/Categorical_DR40{}_split{}_init{}.pt".format(allele, split, initialization)
    model = CategoricalPredictor()
    # map_location="cpu": the freshly built model lives on the CPU, so the
    # checkpoint tensors are copied there anyway. Mapping explicitly lets
    # GPU-saved checkpoints load on CPU-only machines and avoids staging the
    # weights on the GPU first.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    state = torch.load(fn, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    # Re-enable dropout only (batch-norm stays in eval mode) so repeated
    # forward passes through the same model give different samples.
    model.dropout.train()
    return model

def load_gaussian_model(allele, split, initialization):
    """Load one trained ``GaussianPredictor`` checkpoint from ``./models``.

    Args:
        allele: Value substituted after ``DR40`` in the checkpoint file name.
        split: Value substituted after ``split`` in the file name.
        initialization: Value substituted after ``init`` in the file name.

    Returns:
        The model with the checkpoint weights loaded, in eval mode except for
        its dropout layer, which is left in training mode. The model is on
        the CPU; callers move it to another device if needed.
    """
    fn = "./models/Gaussian_DR40{}_split{}_init{}.pt".format(allele, split, initialization)
    model = GaussianPredictor()
    # map_location="cpu": the freshly built model lives on the CPU, so the
    # checkpoint tensors are copied there anyway. Mapping explicitly lets
    # GPU-saved checkpoints load on CPU-only machines and avoids staging the
    # weights on the GPU first.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    state = torch.load(fn, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    # Re-enable dropout only (batch-norm stays in eval mode) so repeated
    # forward passes through the same model give different samples.
    model.dropout.train()
    return model

def runCategoricalEnsemble(ens, x, n_samples = 50):
    """Run an ensemble of categorical models and summarise the predictions.

    Every model is called ``n_samples`` times on ``x``. Each call must return
    per-row probabilities over 8 classes (as ``CategoricalPredictor`` does);
    the classes are scored as the integers 0..7.

    Args:
        ens: Iterable of callables ``model(x)`` returning a ``(N, 8)`` tensor
            of probabilities on the same device as ``x``.
        x: Input batch; only ``x.size(0)`` and ``x.device`` are read here.
        n_samples: Forward passes per model (default 50).

    Returns:
        ``(mns, vrs)`` as 1-D NumPy arrays of length ``N``: the mean of the
        per-pass expected classes, and the total variance (the average
        within-pass variance plus the variance of the per-pass means).

    Raises:
        ValueError: If no forward pass is performed (empty ensemble or
            ``n_samples < 1``).
    """
    with torch.no_grad():
        # Class values 0..7 and their squares, used for the first and second
        # moments. Allocated on the input's device instead of forcing CUDA.
        exm1 = torch.arange(8, dtype=torch.float32, device=x.device)
        exm2 = exm1 ** 2
        mns = []
        vrs = torch.zeros((x.size(0), 1), device=x.device)
        for model in ens:
            for _ in range(n_samples):
                out = model(x)
                mom1 = torch.sum(out*exm1, dim = 1).view((-1,1))
                mom2 = torch.sum(out*exm2, dim = 1).view((-1,1))
                mns.append( mom1 )
                # Var = E[c^2] - E[c]^2 for this pass, accumulated across passes.
                vrs = vrs + mom2 - (mom1**2)
        # Normalise by the number of passes actually performed rather than a
        # hard-coded 1000 (which is only right for 20 models x 50 samples).
        n_passes = len(mns)
        if n_passes == 0:
            raise ValueError("ensemble produced no predictions: need at least one model and n_samples >= 1")
        mns = torch.cat(mns, dim = 1)
        vrs = ( (vrs/n_passes).view(-1) + torch.var(mns, dim = 1, unbiased=False) ).cpu().numpy()
        mns = torch.mean(mns, dim = 1).cpu().numpy()
    return mns, vrs

def runGaussianEnsemble(ens, x, n_samples = 50):
    """Run an ensemble of Gaussian models and summarise the predictions.

    Every model is called ``n_samples`` times on ``x``. Each call must return
    a ``(mean, std)`` pair of ``(N, 1)`` tensors (as ``GaussianPredictor``
    does).

    Args:
        ens: Iterable of callables ``model(x)`` returning ``(mean, std)`` on
            the same device as ``x``.
        x: Input batch; only ``x.size(0)`` and ``x.device`` are read here.
        n_samples: Forward passes per model (default 50).

    Returns:
        ``(mns, vrs)`` as 1-D NumPy arrays of length ``N``: the mean of the
        predicted means, and the total variance (the average predicted
        ``std**2`` plus the variance of the predicted means).

    Raises:
        ValueError: If no forward pass is performed (empty ensemble or
            ``n_samples < 1``).
    """
    with torch.no_grad():
        mns = []
        # Accumulator on the input's device instead of forcing CUDA.
        vrs = torch.zeros((x.size(0), 1), device=x.device)
        for model in ens:
            for _ in range(n_samples):
                mean, std = model(x)
                mns.append( mean )
                vrs = vrs + (std ** 2)
        # Normalise by the number of passes actually performed rather than a
        # hard-coded 1000 (which is only right for 20 models x 50 samples).
        n_passes = len(mns)
        if n_passes == 0:
            raise ValueError("ensemble produced no predictions: need at least one model and n_samples >= 1")
        mns = torch.cat(mns, dim = 1)
        vrs = ( (vrs/n_passes).view(-1) + torch.var(mns, dim = 1, unbiased=False) ).cpu().numpy()
        mns = torch.mean(mns, dim = 1).cpu().numpy()
    return mns, vrs

<h3>Running PUFFIN and generating anchor substitutions</h3>

In [149]:
# Each function below loads an ensemble of trained models and returns it together with the function used to run it

def getEnsemble_Categorical(allele):
    """Load the categorical ensemble for ``allele`` onto the GPU.

    One model is loaded per (split, initialization) pair, for splits 1..10
    and initializations 1 and 2, in split-major order.

    Returns:
        ``(ens, runCategoricalEnsemble)``: the list of loaded models and the
        function to run them with.
    """
    ens = [
        load_categorical_model(allele, split, init).cuda()
        for split in range(1, 11)
        for init in (1, 2)
    ]
    return ens, runCategoricalEnsemble

def getEnsemble_Gaussian(allele):
    """Load the Gaussian ensemble for ``allele`` onto the GPU.

    One model is loaded per (split, initialization) pair, for splits 1..10
    and initializations 1 and 2, in split-major order.

    Returns:
        ``(ens, runGaussianEnsemble)``: the list of loaded models and the
        function to run them with.
    """
    ens = [
        load_gaussian_model(allele, split, init).cuda()
        for split in range(1, 11)
        for init in (1, 2)
    ]
    return ens, runGaussianEnsemble

# Get Categorical predictor for DR401
#getEnsemble_Categorical(1)
# Get Categorical predictor for DR402
#getEnsemble_Categorical(2)
# Get Gaussian predictor for DR401
#getEnsemble_Gaussian(1)
# Get Gaussian predictor for DR402
#getEnsemble_Gaussian(2)

In [114]:
# Load the per-residue embedding table; embed() uses it to convert sequences into tensors

# Lookup table built from ./aa_embedding.txt: the first comma-separated field
# of each line is the key, the remaining fields are parsed as floats and
# stored as a tuple.
aaEmbedding = {}
with open("./aa_embedding.txt", 'rt') as fin:
    for line in fin:
        fields = line.rstrip('\n').split(',')
        key = fields[0]
        aaEmbedding[key] = tuple(float(value) for value in fields[1:])
        
def embed(seqs):
    """Turn each sequence into a tensor of its per-character embeddings.

    Every character of a sequence is looked up in ``aaEmbedding``; the
    resulting (length, embedding-size) tensor is transposed so the embedding
    dimension comes first.

    Args:
        seqs: Iterable of sequences whose characters are keys of
            ``aaEmbedding``.

    Returns:
        A list with one tensor per input sequence.
    """
    return [
        torch.tensor([aaEmbedding[residue] for residue in seq]).permute(1,0)
        for seq in seqs
    ]

In [156]:
# Example of how to run PUFFIN
def example_use():
    """Run the DR401 categorical ensemble on two example peptides and print the result."""
    peptides = ["QMCPGDGRP", "FRVSSTLQA"]

    # Embed the sequences and stack them into a single batch on the GPU
    t = torch.stack(embed(peptides)).cuda()

    # Load the ensemble together with the function that runs it
    ens, run_ens = getEnsemble_Categorical(1)

    # Predict and report the per-sequence mean and variance
    mean, var = run_ens(ens, t)
    print ("Mean: {}\nVar: {}".format(mean, var))
    
# Run this to test if we can load the ensemble and run predictions
#example_use()

In [154]:
def generateLandscape(s):
    """Enumerate every variant of ``s`` with positions 0, 3, 5 and 8 replaced.

    Each of the four positions takes all 20 letters of the amino-acid
    alphabet while ``s[1:3]``, ``s[4:5]`` and ``s[6:8]`` are kept, giving
    20**4 = 160000 strings. Position 0 varies slowest and position 8 fastest.
    """
    aa = "ACDEFGHIKLMNPQRSTVWY"
    # Segments of the input that are copied unchanged into every variant.
    keep_1_2, keep_4, keep_6_7 = s[1:3], s[4:5], s[6:8]
    return [
        p0 + keep_1_2 + p3 + keep_4 + p5 + keep_6_7 + p8
        for p0 in aa
        for p3 in aa
        for p5 in aa
        for p8 in aa
    ]

# Generate proposals (PE)
def propose_subst_PE(seq, ensemble, batch_size = 100):
    """Propose the 10 substitutions of ``seq`` with the highest predicted mean.

    All sequences from ``generateLandscape(seq)`` are scored in batches by
    the ensemble and ranked by predicted mean; ``seq`` itself is excluded
    from the result.

    Args:
        seq: Starting sequence.
        ensemble: ``(ens, run_f)`` pair as returned by the ``getEnsemble_*``
            functions; ``run_f(ens, batch)`` must return ``(means, variances)``.
        batch_size: Number of sequences scored per call to ``run_f``.

    Returns:
        List of up to 10 sequences, best first.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    ens, run_f = ensemble
    subst = generateLandscape(seq)
    total = len(subst)
    means = []
    for start in range(0, total, batch_size):
        emb = torch.stack(embed( subst[start: start+batch_size] )).cuda()
        mns, _ = run_f(ens, emb)
        means.append(mns)
        # Report against the real landscape size and clamp the last batch so
        # the counter never exceeds the total.
        print ("Progress: {}/{}".format(min(start + batch_size, total), total))
    scores = np.concatenate(means)
    best = np.argsort(-scores)
    # Take 11 so that 10 remain even if the starting sequence is among them.
    proposals = [subst[j] for j in best[:11]]
    if seq in proposals:
        proposals.remove(seq)
    return proposals[:10]
    
# Generate proposals (UCB)
def propose_subst_UCB(seq, ensemble, batch_size = 100):
    """Propose the 10 substitutions of ``seq`` with the highest upper confidence bound.

    All sequences from ``generateLandscape(seq)`` are scored in batches by
    the ensemble and ranked by ``mean + sqrt(variance)``; ``seq`` itself is
    excluded from the result.

    Args:
        seq: Starting sequence.
        ensemble: ``(ens, run_f)`` pair as returned by the ``getEnsemble_*``
            functions; ``run_f(ens, batch)`` must return ``(means, variances)``.
        batch_size: Number of sequences scored per call to ``run_f``.

    Returns:
        List of up to 10 sequences, best first.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    ens, run_f = ensemble
    subst = generateLandscape(seq)
    total = len(subst)
    means = []
    variances = []
    for start in range(0, total, batch_size):
        emb = torch.stack(embed( subst[start: start+batch_size] )).cuda()
        mns, vrs = run_f(ens, emb)
        means.append(mns)
        variances.append(vrs)
        # Report against the real landscape size and clamp the last batch so
        # the counter never exceeds the total.
        print ("Progress: {}/{}".format(min(start + batch_size, total), total))
    # Upper confidence bound: predicted mean plus one standard deviation.
    scores = np.concatenate(means) + np.sqrt(np.concatenate(variances))
    best = np.argsort(-scores)
    # Take 11 so that 10 remain even if the starting sequence is among them.
    proposals = [subst[j] for j in best[:11]]
    if seq in proposals:
        proposals.remove(seq)
    return proposals[:10]

# Example: generate anchor substitutions for QMCPGDGRP that optimize for DR401 binding using the Gaussian model
# Adjust batch_size to address memory limitations/runtime
#propose_subst_UCB("QMCPGDGRP", getEnsemble_Gaussian(1))