# Deep-COVID: COVID-19 Epitope Vaccine Design - A Deep Learning Based Framework

In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import torch
import math
from time import time
import torch.distributions as tdis
import matplotlib.pyplot as plt
from torch import nn
from torch import optim
from torch.utils.data import TensorDataset, RandomSampler, BatchSampler, DataLoader

In [None]:
def ReadTxtName(rootdir):                        ##FASTA Reader
    """Return all lines of the text file at *rootdir*.

    Newline characters are stripped from both ends of every line; nothing
    else is filtered, so header lines and blank lines are returned too
    (blank lines come back as empty strings).
    """
    with open(rootdir, 'r') as handle:
        # Iterating the handle yields the same lines as repeated readline().
        return [raw_line.strip('\n') for raw_line in handle]
def read_from_file_with_enter(filename):
    """Read the sequences of a FASTA-like file, joining multi-line records.

    Lines starting with '>' are headers and are discarded. The sequence
    lines that follow a header are concatenated into one string. A record
    ends at a blank line, at the next header, or at the end of the file.
    Records with no sequence characters are not emitted.

    Parameters
    ----------
    filename : str
        Path of the file to read.

    Returns
    -------
    list of str
        One concatenated sequence per record, in file order.
    """
    samples = []
    sample = ""
    # The context manager closes the handle even if reading fails.
    with open(filename, 'r') as fr:
        for line in fr:
            # Strip only the line terminator, so a final line without a
            # trailing newline keeps its last residue.
            line = line.rstrip('\n')
            if line.startswith('>') or not line:
                # A header or a blank line terminates the current record.
                if sample:
                    samples.append(sample)
                sample = ""
                continue
            sample += line
    # Flush the last record when the file does not end with a blank line.
    if sample:
        samples.append(sample)
    return samples

#### Define the Z-descriptor and ACC transformation for protein annotation

In [None]:
#Z-descriptor:
# Module-level copies of the descriptor returned by the most recent Z() call.
# Z() has always published its result through these names, so they are kept.
z1 = 0
z2 = 0
z3 = 0

# (z1, z2, z3) descriptor values for each one-letter amino-acid code.
_Z_TABLE = {
    "A": (0.07, -1.73, 0.09),
    "V": (-2.69, -2.53, -1.29),
    "L": (-4.19, -1.03, -0.98),
    "I": (-4.44, -1.68, -1.03),
    "P": (-1.22, 0.88, 2.23),
    "F": (-4.92, 1.30, 0.45),
    "W": (-4.75, 3.65, 0.85),
    "M": (-2.49, -0.27, -0.41),
    "K": (2.84, 1.41, -3.14),
    "R": (2.88, 2.52, -3.44),
    "H": (2.41, 1.74, 1.11),
    "G": (2.23, -5.36, 0.30),
    "S": (1.96, -1.63, 0.57),
    "T": (0.92, -2.09, -1.40),
    "C": (0.71, -0.97, 4.13),
    "Y": (-1.39, 2.32, 0.01),
    "N": (3.22, 1.45, 0.84),
    "Q": (2.18, 0.53, -1.14),
    "D": (3.64, 1.13, 2.36),
    "E": (3.08, 0.39, -0.07),
}


def Z(antigen):
    """Return the (z1, z2, z3) descriptor of a single residue.

    Parameters
    ----------
    antigen : str
        One-letter, upper-case amino-acid code.

    Returns
    -------
    tuple
        The three descriptor values of the residue. A code that is not one
        of the 20 tabulated letters yields (0, 0, 0), so an unknown residue
        can never inherit the values of whatever residue was looked up
        before it.

    Side effects
    ------------
    The module-level names z1, z2 and z3 are set to the returned values.
    """
    global z1, z2, z3
    z1, z2, z3 = _Z_TABLE.get(antigen, (0, 0, 0))
    return z1, z2, z3

##ACC Transformation: transform the dataset to the same length
# Descriptor index pairs (j, k) in the column order used for every lag:
# the three auto terms first, then the six cross terms.
_ACC_PAIRS = (
    (0, 0), (1, 1), (2, 2),
    (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1),
)


def _acc_term(description, j, k, lag):
    """Mean of description[i][j] * description[i + lag][k] over all valid i."""
    span = len(description) - lag
    if span <= 0:
        # Sequence too short for this lag: there are no pairs to average.
        return 0
    total = sum(description[i][j] * description[i + lag][k] for i in range(span))
    return total / span


def ACC (dataset):
    """Convert variable-length sequences into fixed-length feature rows.

    Each residue is mapped to its three descriptor values with Z(). For
    every lag l = 1..5 the nine auto/cross covariance terms

        sum_i  z_j(i) * z_k(i + l) / (n - l)

    are computed, where n is the sequence length and (j, k) runs over
    _ACC_PAIRS. A term is 0 when the sequence is not longer than the lag.

    Parameters
    ----------
    dataset : sequence of str
        Sequences written in one-letter residue codes.

    Returns
    -------
    list of list
        One row of 46 values per sequence: columns 0-44 hold the terms
        (nine per lag, lags in increasing order) and column 45 is a label
        slot initialised to 0 for the caller to fill in.
    """
    max_lag = 5

    ACCN = []
    for sequence in dataset:
        description = [Z(residue) for residue in sequence]
        row = []
        for lag in range(1, max_lag + 1):
            # Every term starts from its own zero accumulator and pairs
            # residue i with residue i + lag.
            for j, k in _ACC_PAIRS:
                row.append(_acc_term(description, j, k, lag))
        row.append(0)  # label column, set by the caller
        ACCN.append(row)

    return ACCN

In [None]:
##Prepare T-cell Epitope dataset:
# Positive epitopes: multi-line records joined into one string each.
file = r'./PositiveT.txt'
datasetT1 = read_from_file_with_enter(file)

# Negative epitopes: the file is read line by line and only the lines at odd
# indices (1, 3, 5, ...) are kept, for the first 774 such lines.
file = r'./NegativeT.txt'
dataset21 = ReadTxtName(file)
datasetT2 = []
for i in range(774):
    datasetT2.append(dataset21[2 * i + 1])

##datasetT1 for positive T-cell epitopes and datasetT2 for negative T-cell epitopes

In [None]:
##Prepare B-cell Epitope dataset:
# Both files are read line by line; every line becomes one list entry.
datasetB1 = ReadTxtName(r'./PositiveB.txt')

file = r'./NegativeB.txt'
datasetB2 = ReadTxtName(file)

##datasetB1 for positive B-cell epitopes and datasetB2 for negative B-cell epitopes

#### Train the Protective Antigen Prediction Tool

In [None]:
##Prepare Protective Antigen dataset:
# Every line of the file becomes one entry of datasetA.
file = r'./Antigen.txt'
datasetA = ReadTxtName(file)

##Protein annotation:
# ACC() returns one 46-column row per sequence; column 45 is the label.
# The first 300 rows are labelled 1 (protective antigen), the rest 0.
target = ACC(datasetA)
for i, row in enumerate(target):
    row[45] = 1 if i < 300 else 0


class Lambda(nn.Module):
    """Module wrapper that applies a stored callable in forward()."""

    def __init__(self, func):
        nn.Module.__init__(self)
        # Callable invoked on every forward pass.
        self.func = func

    def forward(self, x):
        """Return the result of calling the wrapped function on *x*."""
        wrapped = self.func
        return wrapped(x)


def preprocess(x):
    """View *x* as a tensor of shape (-1, 1, 45, 1).

    The first dimension is inferred from the number of elements, so the
    element count of *x* must be a multiple of 45.
    """
    target_shape = (-1, 1, 45, 1)
    return x.view(*target_shape)



def Antigens(antigens, num_epochs=9000, batch_size=128, model=None, lrate=0.001):
    """Fit a network to the label column of *antigens* and return it.

    Parameters
    ----------
    antigens : torch.Tensor
        2-D tensor. Every column except the last is a feature; the last
        column is the target. The default model reshapes the features with
        preprocess(), which requires 45 features per row.
    num_epochs : int
        Number of passes over the data.
    batch_size : int
        Mini-batch size. Incomplete trailing batches are dropped.
    model : nn.Module, optional
        Network to train. When None, the default convolutional stack
        defined below is created.
    lrate : float
        Learning rate for the Adam optimiser.

    Returns
    -------
    nn.Module
        The trained model (the same object as *model* when one was given).

    Raises
    ------
    ValueError
        If training is requested but there are fewer rows than
        *batch_size*, so that dropping incomplete batches leaves nothing
        to train on.
    """
    counter   = antigens.shape[1] - 1
    X         = antigens[:, :counter]       #Protein ACC
    Y         = antigens[:, counter:]       #Whether it's BPA

    if num_epochs > 0 and X.shape[0] < batch_size:
        # With drop_last=True no batch would ever be produced and the loss
        # printed after each epoch would not exist.
        raise ValueError(
            "batch_size (%d) exceeds the number of samples (%d); "
            "no full batch is available for training" % (batch_size, X.shape[0])
        )

    # `is None` rather than truthiness: nn.Sequential defines __len__, so an
    # empty Sequential passed by the caller would otherwise be replaced.
    if model is None:
        # NOTE(review): with the (N, 1, 45, 1) layout from preprocess the
        # feature map reaching AvgPool2d(2) appears to be 6x1, which a 2x2
        # window cannot pool -- confirm this stack runs on real input.
        model = nn.Sequential(                         #Deep Neural Network
                    Lambda(preprocess),
                    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(16, 1, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(),
                    nn.AvgPool2d(2),
                    Lambda(lambda x: x.view(x.size(0), -1)),
                )

        # NOTE(review): model2 is built but never applied, optimised or
        # returned; it looks like an intended classifier head -- confirm
        # whether it should be chained after `model`. It is kept because its
        # construction draws from the random generator before the data
        # loader shuffles.
        model2 = nn.Sequential(
                    nn.Linear(32, 64),
                    nn.Tanh(),
                    nn.Linear(64, 128),
                    nn.Tanh(),
                    nn.Linear(128, 64),
                    nn.Tanh(),
                    nn.Linear(64, 32),
                    nn.Tanh(),
                    nn.Linear(32, 1),
                    nn.Sigmoid()
                )

    opt    = optim.Adam(model.parameters(), lr=lrate)
    td     = TensorDataset(X, Y)
    # Re-iterating a shuffling DataLoader draws a new permutation each epoch.
    loader = DataLoader(td, batch_size, shuffle=True, drop_last=True)
    epoch  = 0

    while epoch < num_epochs:
        for x, y in loader:
            opt.zero_grad()
            result = model(x)
            loss   = ((result - y)**2).mean()   # mean squared error
            loss.backward()
            opt.step()

        epoch += 1

        # Loss of the last mini-batch of this epoch.
        print(loss)
    return model

# Amodel is the protective antigen prediction tool, trained on the labelled
# ACC rows in `target` with the default settings of Antigens().
Amodel = Antigens(antigens=torch.tensor(target))

#### Design the training data for Deep-COVID

In [None]:
##Cartesian Products of the T-cell and B-cell epitopes:
# Every positive B/T epitope pair is joined in both orders.
candidates = []
for b_epitope in datasetB1:
    for t_epitope in datasetT1:
        candidates.append(b_epitope + t_epitope)
        candidates.append(t_epitope + b_epitope)

##Sieve out the protective antigens:
# A candidate is kept when the squared distance between Amodel's output and 1
# is below 0.015.
dataset = []
# Inference only: no autograd graph is needed for each candidate.
with torch.no_grad():
    for candidate in candidates:
        target = ACC([candidate])[0][0:45]
        if (Amodel(torch.tensor(target)) - 1)**2 < 0.015:
            dataset.append(candidate)

##Prepare the same length negative dataset:
Ndataset1 = [b_epitope + t_epitope
             for t_epitope in datasetT2
             for b_epitope in datasetB2]
# Take every 5th negative pair, as many as there are positives, so that the
# two halves of finalset have equal size whatever the sieve let through.
Ndataset = Ndataset1[:5 * len(dataset):5]
if len(Ndataset) < len(dataset):
    raise ValueError(
        "not enough negative pairs: need %d, only %d available at stride 5"
        % (len(dataset), len(Ndataset))
    )

finalset = dataset + Ndataset
##finalset is the training data for Deep-COVID

#### Train Deep-COVID

In [None]:
##Protein Annotation:
# Rows before the midpoint of `target` are labelled 1 and rows after it 0; a
# row exactly at the midpoint keeps the value ACC() put in column 45.
target = ACC(finalset)
for i, row in enumerate(target):
    if 2 * i < len(target):
        row[45] = 1
    elif 2 * i > len(target):
        row[45] = 0

##Train Deep-COVID:
Deep_COVID = Antigens(
    antigens=torch.tensor(target),
    batch_size=8192,
    num_epochs=2000,
    lrate=0.001,
    model=None,
)

### Input an annotated protein sequence to Deep_COVID; it will output whether the sequence is a potential peptide vaccine