In [None]:
# Run all the cells up to the Examples Heading to load the algorithm and sample data

In [1]:
import os
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import gzip

from copy import deepcopy
from matplotlib import pyplot as plt
import Levenshtein as lev

import torch
from torch.nn import parallel

<h3>Import Data</h3>

In [2]:
# Calibrated data
with open("./calibrated_input.pkl", 'rb') as fin:
    f, inputs1_credences, inputs2_credences = pickle.load(fin)
    exec(f)
    inputs1_credences = Compressed(*inputs1_credences).makeFrame()
    inputs2_credences = Compressed(*inputs2_credences).makeFrame()

100%|██████████| 779/779 [00:01<00:00, 746.84it/s] 
100%|██████████| 1200/1200 [00:02<00:00, 444.29it/s]
100%|██████████| 440/440 [00:00<00:00, 2871.92it/s]
100%|██████████| 1018459/1018459 [00:02<00:00, 341249.11it/s]
100%|██████████| 537/537 [00:00<00:00, 1246.77it/s]
100%|██████████| 920/920 [00:01<00:00, 574.82it/s] 
100%|██████████| 502/502 [00:00<00:00, 969.52it/s]
100%|██████████| 634346/634346 [00:01<00:00, 357391.74it/s]


In [3]:
inputs1_credences

(tensor([[0.6949, 0.5643, 0.7047,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 0.9102, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 0.8891, 0.4362, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 0.8572, 0.3063, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 0.6972, 0.4505, 1.0000]]),
 ['AAFATAQEAY',
  'AALCTFLLNK',
  'AALQIPFAM',
  'AANTVIWDYK',
  'AARYMRSLK',
  'AAVDALCEK',
  'AECTIFKDA',
  'AEVQIDRLI',
  'AEVQIDRLIT',
  'AEWFLAYIL',
  'AFGGCVFSY',
  'AGFSLWVYK',
  'AHFPREGVF',
  'AIASEFSSL',
  'AIDAYPLTK',
  'AKYTQLCQY',
  'ALAPNMMVT',
  'ALCTFLLNK',
  'ALLEDEFTPF',
  'ALNNIINNA',
  'ALNTLVKQL',
  'ALQDAYYRA',
  'ALRANSAVK',
  'AMQTMLFTM',
  'AMYTPHTVL',
  'APAHISTI',
  'AQFAPSASA',
  'AQFAPSASAF',
  'AQKFNGLTVL',
  'AQLPAPRTL',
  'AQPCSDKAY',
  'ARLYYDSMSY',
  'ARYMRSLKV',
  'ASAFFGMSR',
  'ASCDAIMTR',
  'ASFDNFKFV',
  'ASFRLFARTR',
  'ASFSTFKCY',
  'ASHMYCSFY',
 

In [12]:
with open("./ensemble_MIRA_input.pkl", 'rb') as fin:
    f, inputs1_MIRA_01, inputs2_MIRA_01 = pickle.load(fin)
    exec(f)
    inputs1_MIRA_01 = Compressed(*inputs1_MIRA_01).makeFrame()
    inputs2_MIRA_01 = Compressed(*inputs2_MIRA_01).makeFrame()

100%|██████████| 779/779 [00:00<00:00, 1030.68it/s]
100%|██████████| 1200/1200 [00:02<00:00, 404.54it/s]
100%|██████████| 440/440 [00:00<00:00, 569.18it/s] 
100%|██████████| 1018459/1018459 [00:03<00:00, 315924.92it/s]
100%|██████████| 537/537 [00:00<00:00, 2717.29it/s]
100%|██████████| 920/920 [00:01<00:00, 694.65it/s] 
100%|██████████| 502/502 [00:00<00:00, 929.96it/s]
100%|██████████| 634346/634346 [00:01<00:00, 389266.15it/s]


In [12]:
# Filter for peptides that are also present in omicron

# Remove non-omicron peptides
def removeNonOmicron(inputs):
    with open("./omicron_peptides.pkl", 'rb') as fin:
        omicron_peptides = pickle.load(fin)
    
    index = [(i,seq) for i,seq in enumerate(inputs[1]) if seq in omicron_peptides]
    newseqs = [z[1] for z in index]
    index = [z[0] for z in index]
    newmatr = inputs[0][index, :]
    
    print ("After filtering: {}, Before filtering: {}".format(newmatr.shape[0], inputs[0].shape[0]))
    return newmatr, newseqs, inputs[2], inputs[3]

inputs1_credences_omicron = removeNonOmicron(inputs1_credences)
inputs1_credences_omicron = removeNonOmicron(inputs2_credences)

<h3>Greedy Algorithm</h3>

In [2]:
# Utility functions

def getThresholdUtility(n):
    return torch.tensor(np.arange(0,n+1,1),dtype = torch.float32)

def getExponentialUtility(p, upto):
    p = p/(1+p)
    d = p ** np.arange(0,upto+1,1)
    
    # Prevent float precision issues
    index = np.argwhere(d < 10 ** -6).reshape(-1)
    if len(index) > 0:
        index = np.min(index)
        d = d[:index]
        
    return torch.tensor(1-d,dtype = torch.float32)

def getMarginalImprovement(utility):
    return torch.cat( (utility[1:] - utility[:-1], torch.zeros(1)) )

In [5]:
#candidates: [candidate, 1 - pMHC hit probability]
#columnIndex: [diplotype, allele in diplotype]
#columnWeights: [diplotype]
#distributions: [dummy, diplotype, distribution]
#marginalImprovement: [improvement (shifting from i to i+1, so last entry should be 0)]

def evaluateCandidates(candidates, columnIndex, columnWeights, distributions, marginalImprovement):
    probabilityOfHit = 1 - torch.prod(candidates[:, columnIndex], dim = 2).unsqueeze(2)
    shiftedMass = distributions * probabilityOfHit
    improvement = torch.sum( shiftedMass * marginalImprovement, dim = 2)
    weightedImprovement = torch.sum( improvement * columnWeights, dim = 1)
    return weightedImprovement

def updateDistribution(newRow, columnIndex, distributions):
    probabilityOfMiss = torch.prod(newRow[columnIndex], dim = 1).reshape(1, -1, 1)
    shiftedMass = distributions * (1-probabilityOfMiss)
    
    convolution = distributions * probabilityOfMiss
    convolution[:,:,1:] += shiftedMass[:,:,:-1]
    convolution[:,:,-1] += shiftedMass[:,:,-1]
    return convolution

def evaluateDesign(candidates, seqs, columnIndex, columnWeights, design, utility):
    distributions = torch.zeros( (1, len(columnIndex), len(utility)) )
    distributions[:, :, 0] = 1
    
    seqToIndex = {}
    for i, seq in enumerate(seqs):
        seqToIndex[seq] = i
        
    for seq in design:
        row = seqToIndex[seq]
        distributions = updateDistribution(candidates[row], columnIndex, distributions)
        
    scores = torch.sum( distributions * utility.view(1,1,-1), dim = 2 ).reshape(-1)
    return torch.sum(scores * columnWeights).numpy()

class evaluateCandidatesModule(torch.nn.Module):
    def __init__(self, columnIndex, columnWeights, marginalImprovement, device):
        super(evaluateCandidatesModule, self).__init__()
        self.device = device
        self.columnIndex = columnIndex
        self.columnWeights = columnWeights.cuda(self.device)
        self.marginalImprovement = marginalImprovement.cuda(self.device)
        
    def updateDistributions(self, distributions):
        self.distributions = distributions.cuda(self.device)
        
    def forward(self, candidates):
        probabilityOfHit = 1 - torch.prod(candidates[:, self.columnIndex], dim = 2).unsqueeze(2)
        shiftedMass = self.distributions * probabilityOfHit
        improvement = torch.sum( shiftedMass * self.marginalImprovement, dim = 2)
        weightedImprovement = torch.sum( improvement * self.columnWeights, dim = 1)
        return weightedImprovement.cpu()

In [6]:
def greedySelectionMulticore(candidates,
                             seqs,
                             columnIndex,
                             columnWeights,
                             designSize,
                             marginalImprovement,
                             threshold,
                             batchSize,
                             devices):
    
    # Set up modules on different devices
    modules = [evaluateCandidatesModule(columnIndex, columnWeights, marginalImprovement, device)
               for device in devices]
    
    # Distribute the computation between devices
    numRows = candidates.shape[0]
    numVertical = (numRows//(len(devices) * batchSize))
    sliceSize = (numRows//(len(devices) * numVertical)) + 1
    slices = []
    z = 0
    for _ in range(numVertical):
        singleSlice = []
        if z*sliceSize >= numRows: break
        for device in devices:
            if z == numVertical * len(devices) - 1:
                singleSlice.append( candidates[z*sliceSize:].cuda(device) )
            else:
                singleSlice.append( candidates[z*sliceSize:(z+1)*sliceSize].cuda(device) )
            z += 1
        slices.append(singleSlice)
    
    # Initialize selected set and score
    selectedSet = []
    score = 0
    selectable = np.ones(numRows)
    
    # Initialize coverage distributions
    distributions = torch.zeros( (1, len(columnIndex), len(marginalImprovement)) )
    distributions[:, :, 0] = 1
    
    for _ in range(designSize):
        # Update distributions in modules
        for module in modules:
            module.updateDistributions(distributions)

        # Compute marginal utilities
        allImprovements = []
        # We need to batch the following vector operations due to space limitations
        for singleSlice in tqdm( slices, position = 0, leave = True):
            improvements = parallel.parallel_apply(modules, singleSlice)
            allImprovements.append(torch.cat(improvements))
        allImprovements = torch.cat(allImprovements).numpy()

        # Argmax
        selection = np.argmax(allImprovements * selectable)
        
        # Add best sequence
        selectedSeq = seqs[selection]
        selectedSet.append(selectedSeq)
        
        # Update score
        delta = allImprovements[selection]
        score += delta
        
        print("Sequence added: {}, Objective: {:.5f}, Delta: {:.5f}".format(
            selectedSeq, score, delta))
        
        # Update distributions for next round
        distributions = updateDistribution(candidates[selection], columnIndex, distributions)
        
        # Remove invalid candidates from consideration
        for i, seq in enumerate(seqs):
            if lev.distance(seq, selectedSeq) <= threshold:
                selectable[i] = 0
    
    torch.cuda.empty_cache()
    return selectedSet

In [7]:
# Number of GPUs available
list(range(torch.cuda.device_count()))

[0, 1, 2, 3, 4, 5, 6, 7]

<h3>Examples</h3>

In [None]:
# Replace inputs#_credences with inputs#_credences_omicron to use the dataset that has been filtered for
# peptides that are only also present in the Omicron strains

# The arguments to greedySelectionMulticore are as follows:
# 1. A tuple consisting of credences for individual allele binding, a list of peptides, a list of diplotypes
#   where each diplotype is given as a list of alleles, and a set of weights for each diplotype
#     We provide "inputs1_credences", "inputs1_credences_omicron", and "inputs1_MIRA_01" as possible inputs
#     which correspond to the credences we derived, those credences without peptides absent from omicron strains,
#     and credences from Liu et al. respectively.
#     We also provide values for MHC2 as "inputs2_credences", "inputs2_credences_omicron", and "inputs2_MIRA_01"
# 2. The number of peptides in the vaccine design
# 3. The marginal improvement of the utility as an array. The ith entry of the array should
#   contain the marginal improvement for going from i to i+1
# 4. The Levenshtein distance threshold. Peptides that are within this threshold of the peptides
#   that have already been selected will not be considered for inclusion
# 5. The batch size. This is required because of memory limitations on the GPU. The larger this value the better.
#   It seems values between 10-50 work relatively well
# 6. The list of devices to use. list(range(torch.cuda.device_count()))) should enumerate all available devices

In [None]:
# Produce an peptide set with 19 peptides with an exponentially decaying utility using continuous valued binding credences

mhc1_design = greedySelectionMulticore(*inputs1_credences,
               19,
               getMarginalImprovement( getExponentialUtility(4, 19) ),
               3,
               10,
               list(range(torch.cuda.device_count())))

In [None]:
# Produce an peptide set with 26 peptides with a threshold utility using 0-1 binding credences

mhc2_design = greedySelectionMulticore(*inputs1_MIRA_01,
               19,
               getMarginalImprovement( getThresholdUtility(8) ),
               3,
               30, # Batch size should be as large as memory allows to optimize parallelization
               list(range(torch.cuda.device_count())))

In [None]:
# Evaluate both designs using continuous binding credences and the exponentially decaying utility function

evaluateDesign(*inputs1_credences, mhc1_design, getExponentialUtility(0.5, 19))\
,evaluateDesign(*inputs2_credences, mhc2_design, getExponentialUtility(0.5, 19))

In [None]:
# Evaluate both designs using 0-1 binding credences and the threshold utility function

evaluateDesign(*inputs1_MIRA_01, mhc1_design, getThresholdUtility(8))\
,evaluateDesign(*inputs2_MIRA_01, mhc2_design, getThresholdUtility(8))