In [None]:
# Import libraries, key resources
from Bio.SeqUtils import seq3, seq1
from Bio.Seq import Seq
import codecs
from collections import Counter
import csv
from datetime import datetime
import geopandas as gpd
from matplotlib.colors import ListedColormap
from matplotlib.dates import DateFormatter, MonthLocator
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from skbio import DNA
from skbio.alignment import local_pairwise_align_ssw
from skbio import Protein
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
import warnings

# BLOSUM50 amino-acid substitution matrix as a nested dict: blosum50[a][b] is the
# score for aligning residue a against residue b (the table is symmetric).
# Besides the 20 standard residues it covers the ambiguity codes B (D/N), Z (E/Q),
# X (unknown) and the stop symbol '*'.
# Used as `substitution_matrix` for skbio's local_pairwise_align_ssw in getSeqAligned.
blosum50 = \
    {
        '*': {'*': 1, 'A': -5, 'C': -5, 'B': -5, 'E': -5, 'D': -5, 'G': -5,
              'F': -5, 'I': -5, 'H': -5, 'K': -5, 'M': -5, 'L': -5,
              'N': -5, 'Q': -5, 'P': -5, 'S': -5, 'R': -5, 'T': -5,
              'W': -5, 'V': -5, 'Y': -5, 'X': -5, 'Z': -5},
        'A': {'*': -5, 'A': 5, 'C': -1, 'B': -2, 'E': -1, 'D': -2, 'G': 0,
              'F': -3, 'I': -1, 'H': -2, 'K': -1, 'M': -1, 'L': -2,
              'N': -1, 'Q': -1, 'P': -1, 'S': 1, 'R': -2, 'T': 0, 'W': -3,
              'V': 0, 'Y': -2, 'X': -1, 'Z': -1},
        'C': {'*': -5, 'A': -1, 'C': 13, 'B': -3, 'E': -3, 'D': -4,
              'G': -3, 'F': -2, 'I': -2, 'H': -3, 'K': -3, 'M': -2,
              'L': -2, 'N': -2, 'Q': -3, 'P': -4, 'S': -1, 'R': -4,
              'T': -1, 'W': -5, 'V': -1, 'Y': -3, 'X': -1, 'Z': -3},
        'B': {'*': -5, 'A': -2, 'C': -3, 'B': 6, 'E': 1, 'D': 6, 'G': -1,
              'F': -4, 'I': -4, 'H': 0, 'K': 0, 'M': -3, 'L': -4, 'N': 5,
              'Q': 0, 'P': -2, 'S': 0, 'R': -1, 'T': 0, 'W': -5, 'V': -3,
              'Y': -3, 'X': -1, 'Z': 1},
        'E': {'*': -5, 'A': -1, 'C': -3, 'B': 1, 'E': 6, 'D': 2, 'G': -3,
              'F': -3, 'I': -4, 'H': 0, 'K': 1, 'M': -2, 'L': -3, 'N': 0,
              'Q': 2, 'P': -1, 'S': -1, 'R': 0, 'T': -1, 'W': -3, 'V': -3,
              'Y': -2, 'X': -1, 'Z': 5},
        'D': {'*': -5, 'A': -2, 'C': -4, 'B': 6, 'E': 2, 'D': 8, 'G': -1,
              'F': -5, 'I': -4, 'H': -1, 'K': -1, 'M': -4, 'L': -4, 'N': 2,
              'Q': 0, 'P': -1, 'S': 0, 'R': -2, 'T': -1, 'W': -5, 'V': -4,
              'Y': -3, 'X': -1, 'Z': 1},
        'G': {'*': -5, 'A': 0, 'C': -3, 'B': -1, 'E': -3, 'D': -1, 'G': 8,
              'F': -4, 'I': -4, 'H': -2, 'K': -2, 'M': -3, 'L': -4, 'N': 0,
              'Q': -2, 'P': -2, 'S': 0, 'R': -3, 'T': -2, 'W': -3, 'V': -4,
              'Y': -3, 'X': -1, 'Z': -2},
        'F': {'*': -5, 'A': -3, 'C': -2, 'B': -4, 'E': -3, 'D': -5,
              'G': -4, 'F': 8, 'I': 0, 'H': -1, 'K': -4, 'M': 0, 'L': 1,
              'N': -4, 'Q': -4, 'P': -4, 'S': -3, 'R': -3, 'T': -2, 'W': 1,
              'V': -1, 'Y': 4, 'X': -1, 'Z': -4},
        'I': {'*': -5, 'A': -1, 'C': -2, 'B': -4, 'E': -4, 'D': -4,
              'G': -4, 'F': 0, 'I': 5, 'H': -4, 'K': -3, 'M': 2, 'L': 2,
              'N': -3, 'Q': -3, 'P': -3, 'S': -3, 'R': -4, 'T': -1,
              'W': -3, 'V': 4, 'Y': -1, 'X': -1, 'Z': -3},
        'H': {'*': -5, 'A': -2, 'C': -3, 'B': 0, 'E': 0, 'D': -1, 'G': -2,
              'F': -1, 'I': -4, 'H': 10, 'K': 0, 'M': -1, 'L': -3, 'N': 1,
              'Q': 1, 'P': -2, 'S': -1, 'R': 0, 'T': -2, 'W': -3, 'V': -4,
              'Y': 2, 'X': -1, 'Z': 0},
        'K': {'*': -5, 'A': -1, 'C': -3, 'B': 0, 'E': 1, 'D': -1, 'G': -2,
              'F': -4, 'I': -3, 'H': 0, 'K': 6, 'M': -2, 'L': -3, 'N': 0,
              'Q': 2, 'P': -1, 'S': 0, 'R': 3, 'T': -1, 'W': -3, 'V': -3,
              'Y': -2, 'X': -1, 'Z': 1},
        'M': {'*': -5, 'A': -1, 'C': -2, 'B': -3, 'E': -2, 'D': -4,
              'G': -3, 'F': 0, 'I': 2, 'H': -1, 'K': -2, 'M': 7, 'L': 3,
              'N': -2, 'Q': 0, 'P': -3, 'S': -2, 'R': -2, 'T': -1, 'W': -1,
              'V': 1, 'Y': 0, 'X': -1, 'Z': -1},
        'L': {'*': -5, 'A': -2, 'C': -2, 'B': -4, 'E': -3, 'D': -4,
              'G': -4, 'F': 1, 'I': 2, 'H': -3, 'K': -3, 'M': 3, 'L': 5,
              'N': -4, 'Q': -2, 'P': -4, 'S': -3, 'R': -3, 'T': -1,
              'W': -2, 'V': 1, 'Y': -1, 'X': -1, 'Z': -3},
        'N': {'*': -5, 'A': -1, 'C': -2, 'B': 5, 'E': 0, 'D': 2, 'G': 0,
              'F': -4, 'I': -3, 'H': 1, 'K': 0, 'M': -2, 'L': -4, 'N': 7,
              'Q': 0, 'P': -2, 'S': 1, 'R': -1, 'T': 0, 'W': -4, 'V': -3,
              'Y': -2, 'X': -1, 'Z': 0},
        'Q': {'*': -5, 'A': -1, 'C': -3, 'B': 0, 'E': 2, 'D': 0, 'G': -2,
              'F': -4, 'I': -3, 'H': 1, 'K': 2, 'M': 0, 'L': -2, 'N': 0,
              'Q': 7, 'P': -1, 'S': 0, 'R': 1, 'T': -1, 'W': -1, 'V': -3,
              'Y': -1, 'X': -1, 'Z': 4},
        'P': {'*': -5, 'A': -1, 'C': -4, 'B': -2, 'E': -1, 'D': -1,
              'G': -2, 'F': -4, 'I': -3, 'H': -2, 'K': -1, 'M': -3,
              'L': -4, 'N': -2, 'Q': -1, 'P': 10, 'S': -1, 'R': -3,
              'T': -1, 'W': -4, 'V': -3, 'Y': -3, 'X': -1, 'Z': -1},
        'S': {'*': -5, 'A': 1, 'C': -1, 'B': 0, 'E': -1, 'D': 0, 'G': 0,
              'F': -3, 'I': -3, 'H': -1, 'K': 0, 'M': -2, 'L': -3, 'N': 1,
              'Q': 0, 'P': -1, 'S': 5, 'R': -1, 'T': 2, 'W': -4, 'V': -2,
              'Y': -2, 'X': -1, 'Z': 0},
        'R': {'*': -5, 'A': -2, 'C': -4, 'B': -1, 'E': 0, 'D': -2, 'G': -3,
              'F': -3, 'I': -4, 'H': 0, 'K': 3, 'M': -2, 'L': -3, 'N': -1,
              'Q': 1, 'P': -3, 'S': -1, 'R': 7, 'T': -1, 'W': -3, 'V': -3,
              'Y': -1, 'X': -1, 'Z': 0},
        'T': {'*': -5, 'A': 0, 'C': -1, 'B': 0, 'E': -1, 'D': -1, 'G': -2,
              'F': -2, 'I': -1, 'H': -2, 'K': -1, 'M': -1, 'L': -1, 'N': 0,
              'Q': -1, 'P': -1, 'S': 2, 'R': -1, 'T': 5, 'W': -3, 'V': 0,
              'Y': -2, 'X': -1, 'Z': -1},
        'W': {'*': -5, 'A': -3, 'C': -5, 'B': -5, 'E': -3, 'D': -5,
              'G': -3, 'F': 1, 'I': -3, 'H': -3, 'K': -3, 'M': -1, 'L': -2,
              'N': -4, 'Q': -1, 'P': -4, 'S': -4, 'R': -3, 'T': -3,
              'W': 15, 'V': -3, 'Y': 2, 'X': -1, 'Z': -2},
        'V': {'*': -5, 'A': 0, 'C': -1, 'B': -3, 'E': -3, 'D': -4, 'G': -4,
              'F': -1, 'I': 4, 'H': -4, 'K': -3, 'M': 1, 'L': 1, 'N': -3,
              'Q': -3, 'P': -3, 'S': -2, 'R': -3, 'T': 0, 'W': -3, 'V': 5,
              'Y': -1, 'X': -1, 'Z': -3},
        'Y': {'*': -5, 'A': -2, 'C': -3, 'B': -3, 'E': -2, 'D': -3,
              'G': -3, 'F': 4, 'I': -1, 'H': 2, 'K': -2, 'M': 0, 'L': -1,
              'N': -2, 'Q': -1, 'P': -3, 'S': -2, 'R': -1, 'T': -2, 'W': 2,
              'V': -1, 'Y': 8, 'X': -1, 'Z': -2},
        'X': {'*': -5, 'A': -1, 'C': -1, 'B': -1, 'E': -1, 'D': -1,
              'G': -1, 'F': -1, 'I': -1, 'H': -1, 'K': -1, 'M': -1,
              'L': -1, 'N': -1, 'Q': -1, 'P': -1, 'S': -1, 'R': -1,
              'T': -1, 'W': -1, 'V': -1, 'Y': -1, 'X': -1, 'Z': -1},
        'Z': {'*': -5, 'A': -1, 'C': -3, 'B': 1, 'E': 5, 'D': 1, 'G': -2,
              'F': -4, 'I': -3, 'H': 0, 'K': 1, 'M': -1, 'L': -3, 'N': 0,
              'Q': 4, 'P': -1, 'S': 0, 'R': 0, 'T': -1, 'W': -2, 'V': -3,
              'Y': -2, 'X': -1, 'Z': 5}}

In [None]:
# Import Classifier Resources

import sys, os, pickle
from IPython.display import clear_output
from tqdm import tqdm
sys.path.append('/home/sj/ml/lib/python3.10/site-packages/') # *Change path as required*

from scipy.stats import percentileofscore
import csv

import numpy as np
import pandas as pd
import importlib
import random

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
# Plots stuff
import matplotlib as mpl
from matplotlib import patches
from pandas.plotting import table

import sklearn
from sklearn import metrics
from sklearn.model_selection import train_test_split, KFold
from skbio import DNA, Protein

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import L1Loss
import copy
from scipy.optimize import fsolve
import math

# import keras
# from keras.models import Sequential
# from keras.layers import Dense, Dropout
# from keras import regularizers

def overlap_seqs(list1, list2):
    """Return the items of ``list1`` that also occur in ``list2``.

    Order and duplicates from ``list1`` are preserved; membership is tested
    with ``in`` against ``list2`` as given.
    """
    return [item for item in list1 if item in list2]

def flatten_list(listoflist):
    """Concatenate a sequence of sequences into a single flat list (one level deep)."""
    return [element for sublist in listoflist for element in sublist]

curr_int = np.int16


def convert_number(seqs):
    """Encode already-aligned amino-acid strings as an integer matrix.

    Each residue is replaced by its index in the 21-symbol alphabet
    'ACDEFGHIKLMNPQRSTVWY-' (gap last). The result has one row per sequence,
    dtype ``curr_int`` and C order; unknown symbols raise KeyError.
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY-'
    index_of = {symbol: position for position, symbol in enumerate(alphabet)}
    encoded_rows = [[index_of[residue] for residue in sequence] for sequence in seqs[0:]]
    return np.array(encoded_rows, dtype=curr_int, order="c")

def uniqueIndexes(l):
    """Return, in increasing order, the index of the first occurrence of each distinct item in ``l``."""
    first_position = {}
    for position, item in enumerate(l):
        # setdefault keeps the earliest position seen for each item
        first_position.setdefault(item, position)
    return list(first_position.values())

def convert_letter(seqs_n):
    """Decode integer-encoded sequences back into amino-acid strings.

    This is the inverse of ``convert_number``: code k maps to the k-th symbol of
    'ACDEFGHIKLMNPQRSTVWY-'. Codes outside 0-20 raise KeyError.

    Parameters
    ----------
    seqs_n : sequence of int, or sequence of sequences of int
        Either a single encoded sequence (1-D) or a batch of sequences (2-D).

    Returns
    -------
    list of str
        One decoded string per sequence. A single 1-D sequence yields a
        one-element list, and an empty input yields an empty list.
    """
    aa = 'ACDEFGHIKLMNPQRSTVWY-'
    aadictinv = dict(enumerate(aa))
    if len(seqs_n) == 0:
        return []
    # An integer scalar as the first element means a single 1-D sequence. The old
    # check compared against np.int16 only, so 1-D int64 arrays or plain int lists
    # were treated as batches and crashed.
    if isinstance(seqs_n[0], (int, np.integer)):
        return [''.join(aadictinv[code] for code in seqs_n)]
    return [''.join(aadictinv[code] for code in seq) for seq in seqs_n]

# add some functions for independent site models
def loglikelihood_indip_model(fields, logZ, seqs):
    """Score encoded sequences under an independent-site model.

    ``fields[i, a]`` is the (presumably log-) weight of residue ``a`` at
    position ``i``, and ``seqs`` holds one integer-encoded sequence per row.
    Returns one value per sequence: the sum of its per-position fields minus
    ``logZ``.
    """
    site_index = np.arange(len(fields))
    per_site_fields = fields[site_index, seqs]  # broadcasts to (n_seqs, n_sites)
    return per_site_fields.sum(axis=1) - logZ

def add_pseudocount(fields, n):
    """Add a pseudocount of ``1/n`` to every entry, then renormalise each row to sum to 1."""
    pseudocount = 1 / n
    normalised_rows = []
    for row in fields:
        shifted = row + pseudocount
        normalised_rows.append(shifted / np.sum(shifted))
    return np.array(normalised_rows)

def build_model(dims, dropout_prob=0.5, input_dim=None):
    """Build a fully connected network that outputs one score per input.

    The layer layout is ``Flatten -> Linear(dims[0], dims[1])``. Then, for each
    further width, ``LeakyReLU -> Linear`` is appended. A ``Dropout`` follows
    every Linear except the first and the last. As a result, networks with zero
    or one hidden layer contain no dropout at all.

    Parameters
    ----------
    dims : sequence of int
        Layer widths: the input size, then the hidden widths, then the output
        size, which must be 1.
    dropout_prob : float
        Probability used by every Dropout layer.
    input_dim : int, optional
        Expected value of ``dims[0]``. When omitted, the notebook globals
        ``A * L`` are used if they are defined (the original behaviour);
        otherwise the input size is not checked.

    Returns
    -------
    torch.nn.Sequential

    Raises
    ------
    ValueError
        If ``dims`` has fewer than two entries, ``dims[0]`` differs from the
        expected input size, or ``dims[-1] != 1``.
    """
    if input_dim is None:
        try:
            input_dim = A * L  # notebook globals: one-hot alphabet size x peptide length
        except NameError:
            input_dim = None
    if len(dims) < 2:
        raise ValueError(f"dims needs at least an input and an output size, got {list(dims)!r}")
    if input_dim is not None and dims[0] != input_dim:
        raise ValueError(f"dims[0] must equal the input size {input_dim}, got {dims[0]}")
    if dims[-1] != 1:
        raise ValueError(f"the network must output a single score, got dims[-1]={dims[-1]}")

    layers = [torch.nn.Flatten(), torch.nn.Linear(dims[0], dims[1])]
    last = len(dims) - 1
    for l in range(2, len(dims)):
        layers.append(torch.nn.LeakyReLU())
        layers.append(torch.nn.Linear(dims[l - 1], dims[l]))
        # Dropout after intermediate Linears only (never after the output layer).
        # NOTE: the layer order is kept as-is so Sequential indices, and hence
        # saved state_dict keys, stay compatible.
        if l < last:
            layers.append(torch.nn.Dropout(p=dropout_prob))
    return torch.nn.Sequential(*layers)

def getAllModels(A, L, dropout_prob=0.5):
    """Return the full grid of candidate classifier architectures.

    Entry k is ``build_model([A * L, *hidden, 1], dropout_prob)`` for the k-th
    hidden layout below. The first entry is a plain perceptron with no hidden
    layer. Models are built in this fixed order, so callers can index into the
    list (e.g. index 7 is the single 128-unit hidden layer).
    """
    hidden_layouts = (
        (),  # Perceptron (Parsimonious)
        (2,), (4,), (8,), (16,), (32,), (64,), (128,),
        (16, 8), (32, 8), (64, 8), (128, 8),
        (32, 16), (64, 16), (128, 16),
        (32, 16, 8), (64, 16, 8),
        (64, 32, 16), (128, 32, 16), (128, 64, 16), (128, 64, 32),
        (128, 64, 32, 8), (128, 64, 32, 16),
    )
    input_size = A * L
    return [build_model([input_size, *hidden, 1], dropout_prob) for hidden in hidden_layouts]

def getModelsBest(A, L, dropout_prob=0.5):
    """Return the shortlisted architectures: a perceptron and a single 64-unit hidden layer."""
    input_size = A * L
    perceptron = build_model([input_size, 1], dropout_prob)  # Perceptron (Parsimonious)
    one_hidden_layer = build_model([input_size, 64, 1], dropout_prob)
    return [perceptron, one_hidden_layer]

def getNumInputFeatures(X_train, X_test):
    """Count the distinct characters occurring in all training and test peptides.

    This is presumably the alphabet size used for one-hot encoding.

    Parameters
    ----------
    X_train, X_test : iterable of str
        Peptide sequences.

    Returns
    -------
    int
        Number of distinct characters across both collections (0 if both are empty).
    """
    # A set replaces the original repeated string concatenation, which was
    # quadratic, followed by np.unique; the count is identical.
    letters = set()
    for peptides in (X_train, X_test):
        for peptide in peptides:
            letters.update(peptide)
    return len(letters)

def random9mer(num, length=9):
    """Generate ``num`` random peptides drawn uniformly from the 20 standard amino acids.

    Parameters
    ----------
    num : int
        Number of peptides to generate.
    length : int, optional
        Residues per peptide (default 9, the previously hard-coded value).

    Returns
    -------
    list of str

    Residues are drawn one at a time with ``random.choice`` from the module-level
    ``random`` generator, so ``random.seed`` makes the output reproducible. With
    the default length, the draws match the previous implementation.
    """
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    return ["".join(random.choice(amino_acids) for _residue in range(length))
            for _peptide in range(num)]

In [None]:
# Load and store SARS-CoV-2 genome and variant genomes from GISAID file

#source for genome and hand-typed coordinates
#https://www.ncbi.nlm.nih.gov/nuccore/NC_045512
fold_out = "Mutation Analysis/TREE_sel/" # *Change path as required*
fasta_file = fold_out + "full.fasta" # *Download full mutations fasta file from GISAID (login and click downloads, download 'allprot' file)*
# NOTE(review): 266-13483 look like the ORF1a span of NC_045512; ORF1ab
# continues (via the ribosomal frameshift) to 21555 — confirm which is intended.
coordinates={"ORF1ab":[266,13483],
             "S":[21563,25384],
             "ORF3a":[25393,26220],
             "E":[26245,26472],
             "M":[26523,27191],
             "ORF6":[27202,27387],
             "ORF7a":[27394,27759],
             "ORF7b":[27756,27887],
             "ORF8":[27894,28259],
             "N":[28274,29533],
             "ORF10": [29558,29674]
             }

# Import and store reference genome
# Each record is a '>' header with '|'-separated metadata followed by its protein
# sequence. Only records of the reference accession EPI_ISL_402124 are kept, and
# parsing stops at the first header with any other accession. In other words, the
# reference records are assumed to come first in the file; this avoids scanning
# the whole (large) file an extra time.

ref = []
accepted = False  # no header seen yet: ignore sequence lines before the first '>'
with open(fasta_file, 'r') as fasta:
    for line in fasta:
        line = line.strip()
        if not line:
            continue  # blank lines carry no data (and used to crash on seq[-1])
        if line.startswith('>'):
            accepted = False
            metadata = line[1:].split('|')
            accID = metadata[3]
            if accID != 'EPI_ISL_402124':
                break
            gene = metadata[0].upper()
            date = metadata[2]
            country = metadata[10]
            if accID and gene and date and country:
                ref.append([gene])
                accepted = True
        elif accepted:
            seq = line
            if seq.endswith('*'):
                seq = seq[:-1]  # drop the trailing stop symbol
            if len(ref[-1]) == 1:
                ref[-1].append(seq)
            else:
                # Continuation line of a wrapped (multi-line) FASTA record
                ref[-1][-1] += seq

# Create a DataFrame from the collected data
columns = ['gene', 'seq']
ref_df = pd.DataFrame(ref, columns=columns)

# Import and store all variant sequences
valid_genes = ['NSP1', 'NSP2', 'NSP3', 'NSP4', 'NSP5', 'NSP6', 'NSP7', 'NSP8', 'NSP9',
                  'NSP10', 'NSP11', 'NSP12', 'NSP13', 'NSP14', 'NSP15', 'NSP16', 'SPIKE',
                  'NS3', 'E', 'M', 'NS6', 'NS7A', 'NS7B', 'NS8', 'N', 'NS9B', 'NS9C']

# Accepted passage labels, including misspellings presumably present in the source metadata
valid_passages = ['original', 'orginal', 'orignal']
def is_valid_metadata(metadata):
    """Return True if a GISAID FASTA header (already split on '|') passes all filters.

    Requirements:
      * exactly 11 fields;
      * field 0 (gene, case-insensitive) is in ``valid_genes``;
      * field 2 (date) has a year starting with '20' and integer month and day
        fields that are both >= 1;
      * field 3 (accession) is not the reference EPI_ISL_402124;
      * field 4 (passage, case-insensitive) is in ``valid_passages``;
      * field 6 (host, case-insensitive) is 'human'.

    Partial or malformed dates such as '2020', '2020-03' or '2020-XX-01' are
    rejected instead of raising IndexError/ValueError and aborting the parse.
    """
    if len(metadata) != 11:
        return False

    gene = metadata[0].upper()
    if gene not in valid_genes:
        return False

    date = metadata[2]
    dates = date.split('-')
    if dates[0][0:2] != '20' or len(dates) < 3:
        return False
    try:
        month, day = int(dates[1]), int(dates[2])
    except ValueError:
        return False
    if month < 1 or day < 1:
        return False

    accID = metadata[3].upper()
    if accID == 'EPI_ISL_402124':
        return False

    passage = metadata[4].lower()
    if passage not in valid_passages:
        return False

    animal = metadata[6].lower()
    if animal != 'human':
        return False

    return True

data = []
i = 0
labs = []
with open(fasta_file, 'r') as fasta:
    accID = ''
    gene = ''
    date = ''
    country = ''
    accepted = False  # no header seen yet: ignore sequence lines before the first '>'

    for line in fasta:
        line = line.strip()
        if not line:
            continue  # blank lines carry no data (and used to crash on seq[-1])
        if line.startswith('>'):
            accepted = False
            metadata = line[1:].split('|')
            if is_valid_metadata(metadata):
                accID = metadata[3]
                gene = metadata[0].upper()
                date = metadata[2]
                country = metadata[10]
                # Field -3 is stored as the lab name. It is decoded with the
                # 'unicode_escape' codec unless it contains a literal backslash-S.
                # NOTE(review): the reason for that exception is not visible here — confirm.
                # (r'\S' is the same two characters the old '\S' literal produced.)
                if r'\S' in metadata[-3]:
                    lab = metadata[-3]
                else:
                    lab = codecs.decode(metadata[-3], 'unicode_escape')
                if accID and gene and date and country:
                    data.append([accID, gene, date, country, lab])
                    accepted = True
                    i += 1
        elif accepted:
            seq = line
            if seq.endswith('*'):
                seq = seq[:-1]  # drop the trailing stop symbol
            if len(data[-1]) == 5:
                data[-1].append(seq)
            else:
                # Continuation line of a wrapped (multi-line) FASTA record
                data[-1][-1] += seq

# Create a DataFrame from the collected data
columns = ['accID', 'gene', 'date', 'country', 'lab', 'seq']
var_df = pd.DataFrame(data, columns=columns)

# Keep only accessions with exactly len(valid_genes) (= 27) accepted protein
# records; this removes a small proportion (~0.1%) of rows.
n_proteins = len(valid_genes)
accIDs = var_df['accID']
countsOfAccIDs = Counter(accIDs)
var_df = var_df[var_df['accID'].map(lambda x: countsOfAccIDs[x]) == n_proteins]

print(len(var_df) // n_proteins, 'total variants.')

113563 total variants.


In [None]:
# Load and store PRIME epitopes, and convert to sort by NSP
# The curated lists below are grouped by the SARS-CoV-2 ORF/protein they were
# reported for. That grouping key is not used for matching: the loop that follows
# searches every epitope in every reference protein of ref_df, then re-keys the
# hits by the reference protein names in ref_df['gene'].

epitopes = {}
# Experimentally verified PRIME epitopes
# NOTE(review): in this commented-out set the 'E' and 'M' lists are identical,
# which looks like a copy-paste slip — confirm before re-enabling.
# epitopes['ORF3A'] = ['LYLYALVYF', 'FTSDYYQLY', 'LPFGWLIV', 'SASKIITLK', 'QSASKIITLK']
# epitopes['N'] = ['DLSPRWYFYY', 'LSPRWYFYY']
# epitopes['ORF1AB'] = ['EYADVFHLYL']
# epitopes['S'] = ['ECSNLLLQY', 'LPPAYTNSF', 'QYIKWPWYIW', 'YFPLQSYGF']
# epitopes['E'] = ['YFIASFRLF', 'LWLLWPVTL', 'QWNLVIGFLF', 'RFLYIIKLI', 'NRNRFLYII', 'ATSRTLSYYK']
# epitopes['M'] = ['YFIASFRLF', 'LWLLWPVTL', 'QWNLVIGFLF', 'RFLYIIKLI', 'NRNRFLYII', 'ATSRTLSYYK']

# All PRIME epitopes
epitopes['SPIKE'] = ['CNDPFLGVYY','STECSNLLLQY','CVADYSVLY','NIDGYFKIY','ECSNLLLQY','TSNQVAVLY','VADYSVLY','KIADYNYKL','RLQSLQTYV','YLQPRTFLL','VLNDILSRL','KLNDLCFTNV','RLNEVAKNL','TLDSKTQSL','ALNTLVKQL','SIIAYTMSL','RLDKVEAEV','NLNESLIDL','KLQDVVNQN','LLFNKVTLA','HLMSFPQSA','LITGRLQSL','FIAGLIAIV','GVYFASTEK','SVYAWNRKR','AQALNTLVK','GVYYHKNNK','SVLNDILSR','TLKSFTVEK','ASANLAATK','KSTNLVKNK','GSFCTQLNR','VVLSFELLH','YLQPRTFLLK','GTHWFVTQR','ASVYAWNRK','STGSNVFQT','NSASFSTFK','ASFSTFKCY','QYIKWPWYI','NYNYLYRLF','QYIKWPWYIW','PYRVVVLSF','SPRRARSVA','QPYRVVVLSF','GPKKSTNLV','SLSSTASAL','TPTWRVYST','YQPYRVVVL','LLFNKVTL','INITRFQTL','LVKNKCVNF','DLLFNKVTL','FNATRFASV','FKNLREFVF','LPFNDGVYF','LPFFSNVTW','QPRTFLLKY','LPPAYTNSF','GVVFLHVTY','DPFLGVYY','TEKSNIIRGW','ADAGFIKQY','SETKCTLKSF','EELDKYFKNH','TECSNLLLQY','QEVFAQVKQI','IPTNFTISV','HGVVFLHV','IAIVMVTI','LPLVSSQCV','QPYRVVVL','NATNVVIKV','VVFLHVTYV','TQDLFLPFF','RFDNPVLPF','YFPLQSYGF','LTDEMIAQY','IEDLLFNKV','TTEILPVSM','IADYNYKL','SAPHGVVFL','STECSNLLL','IKDFGGFNF','VRFPNITNL','TRFQTLLAL','FRSSVLHST','GNYNYLYRL','YRVVVLSF','VVFLHVTY','TRFASVYAW','YYPDKVFRS']
epitopes['ORF3a'] = ['FTSDYYQLY','FLCWHTNCY','HSYFTSDYY','TSDYYQLY','GLEAPFLYLY','FLYLYALVY','YLYALVYFL','LLYDANYFL','ALSKGVHFV','TVYSHLLLV','ALLAVFQSA','SASKIITLK','FTIGTVTLK','KRWQLALSK','QSASKIITLK','ASKIITLKK','HVTFFIYNK','ATATIPIQA','LYLYALVYF','YYQLYSTQL','HFVCNLLLL','YFTSDYYQL','LFVTVYSHL','APFLYLYAL','DLFMRIFTI','QSASKIITL','TLKKRWQL','TLKKRWQLA','NPLLYDANY','LLLLFVTVY','LEAPFLYLY','EEHVQIHTI','SEHDYQIGGY','LPFGWLIV','LYDANYFL','QSINFVRII','CRSKNPLLY','KGVHFVCNL','VEHVTFFIY']
epitopes['E'] = ['FLAFVVFLL','SLVKPSFYV','LFLAFVVFL','VFLLVTLAI','YVYSRVKNL','FVVFLLVTL','LAFVVFLLV','LAFVVFLL','IVNSVLLFL','LVKPSFYVY','VKPSFYVY']
epitopes['M'] = ['ATSRTLSYY','SSDNIALLV','LVGLMWLSY','VATSRTLSY','TVATSRTLSY','TVATSRTLSYY','KLLEQWNLV','TLACFVLAA','FLYIIKLIFL','TLACFVLAAV','FVLAAVYRI','ATSRTLSYYK','LVIGAVILR','LFLTWICLL','LYIIKLIFL','LWLLWPVTL','QWNLVIGFLF','YFIASFRLF','LYIIKLIFLW','SYFIASFRL','RFLYIIKLI','LPKEITVAT','IIKLIFLWL','LEQWNLVIGF','EELKKLLEQW','LWPVTLACF','SSDNIALL','QWNLVIGFL','NRFLYIIKL','SSSDNIALL','SRTLSYYKL','NRNRFLYII','ARTRSMWSF']
epitopes['N'] = ['DLSPRWYFYY','LSPRWYFYY','DLSPRWYFY','VTPSGTWLTY','LLLDRLNQL','LQLPQGTTL','NFKDQVILL','SPRWYFYYL','GPQNQRNAP','TPSGTWLTY','MEVTPSGTW','DAALALLL','DAALALLLL','NAAIVLQL','NNAAIVLQL']
epitopes['ORF1ab'] = ['YSVIYLYLTFY','YADVFHLY','YADVFHLYL','DVFHLYLQY','NQEYADVFHLY','YADVFHLYLQY','WLMWLIINL','SLPGVFCGV','TLMNVLTLV','SMWALIISV','GVYSVIYLY','VYSVIYLYL','IYLYLTFYL','VYSVIYLYLTF','EYADVFHLY','VFHLYLQYI','EYADVFHLYL','LPGVYSVIY','HPNQEYADVF','QEYADVFHL','QEYADVFHLY','LPGVYSVI','VYSVIYLY']

# Re-key the curated epitopes by reference protein. For each protein in ref_df,
# collect every epitope (from any list in `epitopes`, in list order) that occurs
# verbatim in its sequence. Proteins with no hits are left out.
prime_epitopes = {}
for index, row in ref_df.iterrows():
    gene_name, gene_sequence = row['gene'], row['seq']
    found_epitopes = [epitope
                      for epitope_list in epitopes.values()
                      for epitope in epitope_list
                      if epitope in gene_sequence]
    if found_epitopes:
        prime_epitopes[gene_name] = found_epitopes

In [None]:
# Functions for mutation detection

def getSeqAligned(epitope, sequence):
    """
    Return the stretch of ``sequence`` that locally aligns best with ``epitope``.

    Both strings are wrapped as scikit-bio ``Protein`` objects (not BioPython Seq)
    and aligned with Striped Smith-Waterman (``local_pairwise_align_ssw``), using
    the module-level ``blosum50`` substitution matrix. The start/end positions
    reported for the second (target) sequence are used to slice the original
    input string.

    Parameters:
    epitope (str): The epitope (query) sequence.
    sequence (str): The larger protein sequence to search.

    Returns:
    str: ``sequence[start:end + 1]`` for the aligned target region (end inclusive).
    """
    query = Protein(epitope)
    target = Protein(sequence)
    _msa, _score, start_end_positions = local_pairwise_align_ssw(
        query, target, substitution_matrix=blosum50)
    target_start, target_end = start_end_positions[1]
    return sequence[target_start:target_end + 1]

def getMutationType(seq1, seq2, subOnly=0):
    """
    Classify the difference between a reference peptide and a variant peptide.

    Args:
        seq1 (str): The reference (original) sequence.
        seq2 (str): The variant sequence.
        subOnly (int): If non-zero, only substitutions (equal-length sequences)
            are classified, and any length change returns -1. Default 0.

    Returns:
        - ``-1`` (int) if the sequences are identical, or if their lengths
          differ while ``subOnly`` is non-zero.
        - ``('substitution', positions)`` for equal-length sequences.
          ``positions`` is a list of ``[pos, residue_in_seq1, residue_in_seq2]``
          with 1-based ``pos``.
        - ``('insertion', [block, index])`` when ``seq2`` is ``seq1`` with one
          contiguous ``block`` inserted and the first/last residues unchanged.
          ``index`` is the 0-based position in ``seq1`` at which the block starts.
          Special cases: a duplicated last residue reports ``len(seq1) - 1``,
          and a duplicated first residue reports ``1``.
        - ``('deletion', [block, index])``: the mirror case, where ``block`` is
          removed from ``seq1`` to give ``seq2``. ``index`` follows the same
          convention, but relative to ``seq2``.
        - ``('insertion', -1)`` / ``('deletion', -1)`` when the length change is
          not such a simple single-block edit. This includes an empty sequence
          and differing terminal residues.
    """
    if seq1 == seq2:
        return -1

    if len(seq1) == len(seq2):
        mutation_positions = [[pos, residue1, residue2]
                              for pos, (residue1, residue2) in enumerate(zip(seq1, seq2), start=1)
                              if residue1 != residue2]
        # Return mutation type, and the mutated positions with their residues pre- and post-mutation
        return 'substitution', mutation_positions

    if subOnly != 0:
        return -1

    if len(seq2) > len(seq1):
        return 'insertion', _single_block_edit(seq1, seq2)
    # A deletion from seq1 is an insertion into seq2, so the same helper applies
    return 'deletion', _single_block_edit(seq2, seq1)


def _single_block_edit(shorter, longer):
    """Find the contiguous block whose insertion into ``shorter`` yields ``longer``.

    Returns ``[block, index]`` as described in getMutationType, or -1 when the
    difference is not a single-block edit with shared first and last residues.
    """
    if not shorter or shorter[0] != longer[0] or shorter[-1] != longer[-1]:
        return -1
    if longer[:-1] == shorter:
        # Extra copy of the last residue; index kept as len(shorter) - 1 for compatibility
        return [longer[-1], len(shorter) - 1]
    if longer[1:] == shorter:
        # Extra copy of the first residue (the old insertion branch tested
        # `seq1[1:] == seq1` here, which was always False)
        return [longer[0], 1]
    extra = len(longer) - len(shorter)
    for i in range(len(shorter)):
        if shorter[i] != longer[i]:
            block = longer[i:i + extra]
            if shorter[i:] == longer[i + extra:]:
                return [block, i]
            return -1
    # shorter is a prefix of longer followed by a multi-residue tail: not a
    # simple edit (the old insertion branch fell through and returned None)
    return -1

# *Use below if data is saved to file (to save having to repeat analysis each time)*
# mutations_detected = np.load('temp/IEDB_peptides_with_mutations.npy')

In [None]:
# # Load mutation_rows if saved file
# with open('temp/mutation_rows.pkl', 'rb') as file:
#     mutation_rows = pickle.load(file)

In [2]:
# # Get all epitopes and mutations - to use with PRIME2.0 server

# epi, mut = [], []
# for gene in mutation_rows.keys():
#     for epitope in mutation_rows[gene]:
#         epi.append(epitope['epitope'])
#         m = epitope['mutation']
#         if 'X' not in m and len(m) >= 8 and len(m) <= 14:
#             mut.append(m)
# epi, mut = np.unique(np.array(epi)), np.unique(np.array(mut))

# with open('epitopes_all.txt', 'w') as epitopes_file:
#     for epitope in epi:
#         epitopes_file.write(epitope + '\n')

# with open('mutations_all.txt', 'w') as mutations_file:
#     for mutation in mut:
#         mutations_file.write(mutation + '\n')

# Steps in between:
#     Go to http://prime.gfellerlab.org/
#     Alleles = A0101,A2501,B0801,B1801,C0702,A0201,A0301,A1101,A2402,A2601,B0702,B0801,B3501,B4402,B5101,C0401,C0501,C0602,C0701,C0702
#     Calculate PRIME2.0 scores on both epitopes_all.txt and mutations_all.txt files.
#     For both downloaded files, open in Excel
#     Delete commented out rows, so first row is table headers.
#     Delete columns after BestAllele.
#     Save, and run code below


# Load in scores calculated from PRIME2.0 server

def _load_prime_scores(path):
    """Read a PRIME2.0 result file (prepared as in the steps above) into a dict.

    The first line is a header and is skipped. Each remaining non-blank line is
    split on ','. Column 0 is the peptide; columns 1-4 are kept as strings under
    the keys '%Rank_bestAllele', 'Score_bestAllele', '%RankBinding_bestAllele'
    and 'BestAllele'. A peptide listed more than once keeps its last row.
    """
    scores_by_peptide = {}
    with open(path, 'r') as file:
        next(file)  # header row
        for line in file:
            values = line.strip().split(',')
            if values == ['']:
                continue  # empty line: previously raised IndexError on values[1]
            scores_by_peptide[values[0]] = {
                '%Rank_bestAllele': values[1],
                'Score_bestAllele': values[2],
                '%RankBinding_bestAllele': values[3],
                'BestAllele': values[4]
            }
    return scores_by_peptide


epitopes_prime_scores = _load_prime_scores('epitopes_all_res.txt')
mutations_prime_scores = _load_prime_scores('mutations_all_res.txt')

In [None]:
# Save/Load IEDB Data

range_len = [8, 9, 10, 11]
# 10 most frequent HLAs used
list_hlas = ['HLA-A*01:01', 'HLA-A*02:01', 'HLA-A*03:01', 'HLA-A*11:01', 'HLA-A*24:02',
             'HLA-B*07:02', 'HLA-B*08:01', 'HLA-B*15:01', 'HLA-B*35:01', 'HLA-B*40:01']
A, L = 20, 9  # one-hot alphabet size (20 amino acids) and peptide length

# # Save the variables, dictionaries, and arrays to a file
# with open('data.pkl', 'wb') as file:
#     data = {
#         'pep_dict_train': pep_dict_train,
#         'pepsP_train': pepsP_train,
#         'pepsN_train': pepsN_train,
#         'pep_dict_test': pep_dict_test,
#         'pepsP_test': pepsP_test,
#         'pepsN_test': pepsN_test
#     }
#     pickle.dump(data, file)

# Load the variables, dictionaries, and arrays from the file.
# pickle can execute arbitrary code while loading: only open trusted, locally produced files.
with open('Classifier/Data/iedb_data.pkl', 'rb') as file:
    data = pickle.load(file)
pep_dict_train, pepsP_train, pepsN_train = data['pep_dict_train'], data['pepsP_train'], data['pepsN_train']
pep_dict_test, pepsP_test, pepsN_test = data['pep_dict_test'], data['pepsP_test'], data['pepsN_test']

def getSplitData(pepsP_train, pepsN_train, pepsP_test, pepsN_test, split=0.25,
                 hlas=None, shuffle_test=False):
    """Build shuffled train/test sets from per-allele peptide lists.

    The training set is every non-SARS-CoV-2 peptide (``peps*_train``) plus the
    first ``split`` fraction of the SARS-CoV-2 positives and negatives
    (``peps*_test``). The remaining SARS-CoV-2 peptides form the test set.

    Parameters
    ----------
    pepsP_train, pepsN_train, pepsP_test, pepsN_test : list of sequences of str
        Immunogenic (P) and non-immunogenic (N) peptides, one entry per allele,
        in the same order as ``hlas``.
    split : float
        Fraction of the SARS-CoV-2 positives (and, separately, negatives) moved
        into training.
    hlas : list of str, optional
        Allele names matching the per-allele lists; defaults to the module-level
        ``list_hlas``.
    shuffle_test : bool, optional
        The per-allele SARS-CoV-2 lists are concatenated in allele order, so by
        default the ``split`` fraction taken for training comes only from the
        first alleles. Set True to randomly permute the SARS-CoV-2 peptides
        (together with their allele labels) before splitting. The default,
        False, reproduces the original split and random-number draws exactly.

    Returns
    -------
    X_train, y_train, hla_train, X_test, y_test, hla_test : numpy.ndarray
        Peptides, labels (1.0 immunogenic / 0.0 non-immunogenic) and allele
        names. Each set is shuffled independently with ``np.random.permutation``.
    """
    if hlas is None:
        hlas = list_hlas

    def _concat_with_alleles(per_allele):
        # Flatten the per-allele lists and build a parallel array of allele names
        peptides = np.concatenate(per_allele)
        alleles = np.concatenate([[hlas[k]] * len(per_allele[k]) for k in range(len(hlas))])
        return peptides, alleles

    pP_train, hP_train = _concat_with_alleles(pepsP_train)
    pN_train, hN_train = _concat_with_alleles(pepsN_train)
    pP_test, hP_test = _concat_with_alleles(pepsP_test)
    pN_test, hN_test = _concat_with_alleles(pepsN_test)

    if shuffle_test:
        order = np.random.permutation(len(pP_test))
        pP_test, hP_test = pP_test[order], hP_test[order]
        order = np.random.permutation(len(pN_test))
        pN_test, hN_test = pN_test[order], hN_test[order]

    Ptest_cutoff = int(split * len(pP_test))
    Ntest_cutoff = int(split * len(pN_test))

    # SARS-CoV-2 peptides moved into training vs. held out for testing
    pP_moved, pP_held = pP_test[:Ptest_cutoff], pP_test[Ptest_cutoff:]
    pN_moved, pN_held = pN_test[:Ntest_cutoff], pN_test[Ntest_cutoff:]

    X_train = np.concatenate([pP_train, pN_train, pP_moved, pN_moved])
    X_test = np.concatenate([pP_held, pN_held])

    # 1 for immunogenic peptides, 0 for non-immunogenic peptides
    y_train = np.concatenate([np.ones(len(pP_train)), np.zeros(len(pN_train)),
                              np.ones(len(pP_moved)), np.zeros(len(pN_moved))])
    y_test = np.concatenate([np.ones(len(pP_held)), np.zeros(len(pN_held))])

    hla_train = np.concatenate([hP_train, hN_train, hP_test[:Ptest_cutoff], hN_test[:Ntest_cutoff]])
    hla_test = np.concatenate([hP_test[Ptest_cutoff:], hN_test[Ntest_cutoff:]])

    # Shuffle the data (train permutation drawn first, as before)
    random_indices_train = np.random.permutation(len(X_train))
    random_indices_test = np.random.permutation(len(X_test))

    X_train = X_train[random_indices_train]
    hla_train = hla_train[random_indices_train]
    y_train = y_train[random_indices_train]

    X_test = X_test[random_indices_test]
    hla_test = hla_test[random_indices_test]
    y_test = y_test[random_indices_test]

    return X_train, y_train, hla_train, X_test, y_test, hla_test

def getSplitDataHLA(HLA, split=0.25):
    """Single-allele counterpart of getSplitData.

    ``pep_dict_train[HLA]`` and ``pep_dict_test[HLA]`` are indexed as
    ``[0]`` = positives and ``[1]`` = negatives. The training set is all
    non-SARS-CoV-2 peptides plus the first ``split`` fraction of the
    SARS-CoV-2 positives and negatives; the rest form the test set. Both sets
    are shuffled with ``np.random.permutation`` (train first, then test).

    Returns X_train, y_train, X_test, y_test with labels 1.0 (immunogenic) / 0.0.
    """
    pos_train, neg_train = pep_dict_train[HLA][0], pep_dict_train[HLA][1]
    pos_test, neg_test = pep_dict_test[HLA][0], pep_dict_test[HLA][1]

    pos_cut = int(split * len(pos_test))
    neg_cut = int(split * len(neg_test))

    pos_moved, pos_held = pos_test[:pos_cut], pos_test[pos_cut:]
    neg_moved, neg_held = neg_test[:neg_cut], neg_test[neg_cut:]

    X_train = np.concatenate([pos_train, neg_train, pos_moved, neg_moved])
    X_test = np.concatenate([pos_held, neg_held])

    y_train = np.concatenate([np.ones(len(pos_train)), np.zeros(len(neg_train)),
                              np.ones(pos_cut), np.zeros(neg_cut)])
    y_test = np.concatenate([np.ones(len(pos_held)), np.zeros(len(neg_held))])

    train_order = np.random.permutation(len(X_train))
    test_order = np.random.permutation(len(X_test))
    return X_train[train_order], y_train[train_order], X_test[test_order], y_test[test_order]

def getPeptideGene(peptide, ref=None):
    """Return the name of the first reference gene whose sequence contains ``peptide``.

    Parameters
    ----------
    peptide : str
        Amino-acid peptide to locate.
    ref : pandas.DataFrame, optional
        Table with 'gene' and 'seq' columns, scanned in row order. Defaults to
        the module-level ``ref_df``.

    Returns
    -------
    str or int
        The 'gene' value of the first row whose 'seq' contains ``peptide``,
        or -1 when no row matches.
    """
    if ref is None:
        ref = ref_df
    # One pass over the rows, instead of re-filtering the whole frame for every
    # gene on every call.
    for gene_name, seq in zip(ref['gene'], ref['seq']):
        if peptide in seq:
            return gene_name
    return -1

# Background score distribution used by percentile_rank: one million random 9-mers.
# NOTE(review): this is a fresh deepcopy of getAllModels(A, L)[7] and no
# load_state_dict call happens here, unlike the cells that load
# 'Classifier/Models/*.pth'. Confirm the model is trained at this point, and that a
# single background is meant to be shared by every model percentile_rank scores.
model = copy.deepcopy(getAllModels(A, L)[7]) #*Change model as required*
random_tensor = torch.nn.functional.one_hot(torch.LongTensor(convert_number(random9mer(1000000))), num_classes=A).type(torch.FloatTensor)
# Inference only: avoid building an autograd graph over the one-million-row batch.
with torch.no_grad():
    random_scores = np.squeeze(model(random_tensor).detach().numpy())

def percentile_rank(X_scores, background=None):
    """Return the %rank of each score relative to a background score distribution.

    The %rank of a score is the percentage of background scores strictly greater
    than it, so a lower %rank means a higher score.

    Parameters
    ----------
    X_scores : iterable of float
        Scores to rank; nested arrays are flattened.
    background : array-like, optional
        Reference scores. Defaults to the module-level ``random_scores``.

    Returns
    -------
    list of numpy.float64
        One %rank in [0, 100] per input score, in input order.
    """
    if background is None:
        background = random_scores
    background = np.ravel(np.asarray(background))
    total = background.size
    # Sort the background once. NaN background values are never "greater than" a
    # score but still count in the denominator, matching an element-wise count.
    ranked = np.sort(background[~np.isnan(background)])
    scores = np.ravel(np.asarray(X_scores))
    # searchsorted(..., side='right') counts background values <= score; the rest
    # are strictly greater. A NaN score lands at the end, giving 0 greater.
    n_greater = ranked.size - np.searchsorted(ranked, scores, side='right')
    return list(100 * n_greater / total)

In [None]:
# Organise all IEDB SARS-CoV-2 reference genome genes:
# pool the train and test splits, then bucket every peptide (with its label and
# HLA allele) under the first reference gene whose sequence contains it.
# Peptides found in no reference gene are recorded as 'Unknown' in all_X_gene.
# NOTE(review): X pools training and test peptides, so anything computed from these
# per-gene buckets (e.g. per-gene AUC) covers both — confirm that is intended.

X_train, y_train, hla_train, X_test, y_test, hla_test = getSplitData(pepsP_train, pepsN_train, pepsP_test, pepsN_test)
X = np.array(list(X_train) + list(X_test))
y = np.array(list(y_train) + list(y_test))
hla = np.array(list(hla_train) + list(hla_test))
X_ref_by_gene, y_ref_by_gene, hla_ref_by_gene, auc_by_gene, all_X_gene = {}, {}, {}, {}, []
for geneName in ref_df['gene']:
    X_ref_by_gene[geneName], y_ref_by_gene[geneName], hla_ref_by_gene[geneName], auc_by_gene[geneName] = [], [], [], []
for peptide, label, allele in zip(X, y, hla):
    geneName = getPeptideGene(peptide)
    if geneName == -1:
        all_X_gene.append('Unknown')
        continue
    all_X_gene.append(geneName)
    X_ref_by_gene[geneName].append(peptide)
    y_ref_by_gene[geneName].append(label)
    hla_ref_by_gene[geneName].append(allele)

In [None]:
# *Use below if data/models are saved to files*
def getComparisonDF(model1, model2):
    """Score the pooled train+test peptides with two models and tabulate them side by side.

    Parameters
    ----------
    model1, model2 : str
        Model identifiers. 'PRIME2.0' loads precomputed scores and %ranks from the
        text files 'prime_X_scores' and 'prime_X_ranks'. 'prtr-pars', 'prtr-deep',
        'parsimonious' and 'deep' build the matching architecture from
        getAllModels(A, L), load weights from 'Classifier/Models/<name>.pth', and
        derive %ranks with percentile_rank.

    Returns
    -------
    pandas.DataFrame
        One row per peptide with columns 'X', 'y', 'gene', 'mutation_exists',
        'model_1_scores', 'model_2_scores', 'model_1_ranks', 'model_2_ranks',
        'model_1_pred' and 'model_2_pred'. A prediction is 1 when the %rank is
        <= 0.5 for 'PRIME2.0' and <= 5 for the other models, else 0.

    Raises
    ------
    ValueError
        If either model identifier is not one of the names above.
    """
    # Position of each trainable model's architecture in getAllModels(A, L).
    classifier_index = {'prtr-pars': 0, 'prtr-deep': 6, 'parsimonious': 0, 'deep': 7}
    model_names = (model1, model2)
    # Validate up front so an unknown name cannot leave the score lists half-filled.
    for model_name in model_names:
        if model_name != 'PRIME2.0' and model_name not in classifier_index:
            raise ValueError(f'Model unrecognised: {model_name!r}')

    classifiers = getAllModels(A, L)
    # NOTE(review): getSplitData shuffles with np.random.permutation, so X's order
    # differs between calls unless the RNG is seeded. The PRIME2.0 scores read from
    # file are assumed to follow the same order as X — confirm.
    X_train, y_train, _, X_test, y_test, _ = getSplitData(pepsP_train, pepsN_train, pepsP_test, pepsN_test)
    X = np.array(list(X_train) + list(X_test))
    X_tensor = torch.nn.functional.one_hot(torch.LongTensor(convert_number(X)), num_classes=A).type(torch.FloatTensor)
    y = np.array(list(y_train) + list(y_test))

    all_X_gene, mutation_exists = [], []
    for pep in X:
        gene_name = getPeptideGene(pep)
        all_X_gene.append('Unknown' if gene_name == -1 else gene_name)
        mutation_exists.append(pep in mutations_detected)

    scores, ranks = [], []
    for model_name in model_names:
        if model_name == 'PRIME2.0':
            scores.append(np.loadtxt('prime_X_scores'))
            ranks.append(np.loadtxt('prime_X_ranks'))
            continue
        model = copy.deepcopy(classifiers[classifier_index[model_name]])
        model.load_state_dict(torch.load('Classifier/Models/' + model_name + '.pth')) # *Change path as required*
        # NOTE(review): model.eval() is not called; if the architecture uses dropout
        # or batch-norm, switch to eval mode before scoring.
        with torch.no_grad():  # inference only
            current_scores = np.squeeze(model(X_tensor).detach().numpy())
        scores.append(current_scores)
        ranks.append(percentile_rank(current_scores))

    # Immunogenic (1) when the %rank is at or below the model's cut-off. The PRIME
    # cut-off applies to 'PRIME2.0', the only PRIME identifier accepted above.
    preds = []
    for model_name, model_ranks in zip(model_names, ranks):
        threshold = 0.5 if model_name == 'PRIME2.0' else 5
        preds.append([1 if rank <= threshold else 0 for rank in model_ranks])

    overlap_df = pd.DataFrame({
        'X': X,
        'y': y,
        'gene': all_X_gene,
        'mutation_exists': mutation_exists,
        'model_1_scores': scores[0],
        'model_2_scores': scores[1],
        'model_1_ranks': ranks[0],
        'model_2_ranks': ranks[1],
        'model_1_pred': preds[0],
        'model_2_pred': preds[1]
    })
    return overlap_df

In [None]:
def allOverlapPlots(overlap_dic, name_X, name_y):
    """Plot how two models' scores and %ranks agree, using a getComparisonDF table.

    For the scores and then for the %ranks, shows a scatter plot, a hexbin
    density plot, overlaid step histograms and overlaid KDEs, and finally prints
    the Pearson correlation coefficient of the scores and of the ranks.

    Parameters
    ----------
    overlap_dic : pandas.DataFrame
        Table with 'model_1_scores', 'model_2_scores', 'model_1_ranks' and
        'model_2_ranks' columns (as returned by getComparisonDF).
    name_X, name_y : str
        Display names for model 1 and model 2.
    """
    # Read the columns from the argument itself, not from a global ``overlap_df``.
    scores_X, scores_y = overlap_dic['model_1_scores'], overlap_dic['model_2_scores']
    ranks_X, ranks_y = overlap_dic['model_1_ranks'], overlap_dic['model_2_ranks']

    # Scores Scatter
    plt.scatter(scores_X, scores_y, marker='x', s=9, alpha = 0.5)
    plt.title(name_X + ' vs ' + name_y + ' Immunogenicity Prediction Scores')
    plt.xlabel(name_X + ' Classifier Scores')
    plt.ylabel(name_y + ' Classifier Scores')
    plt.show()

    # Scores Hexbin
    plt.hexbin(scores_X, scores_y, gridsize=20, cmap='Blues')
    plt.xlabel(name_X + ' Classifier Scores')
    plt.ylabel(name_y + ' Classifier Scores')
    plt.title('Hexbin Plot of Immunogenicity Prediction Scores')
    plt.colorbar(label='Count')
    plt.show()

    # Scores Histogram
    plt.hist(scores_X, histtype='step', bins=50, label= name_X + ' Classifier Scores')
    plt.hist(scores_y, histtype='step', bins=50, label= name_y + ' Classifier Scores')
    plt.xlabel('Immunogenicity Prediction Scores')
    plt.ylabel('Count')
    plt.title('Histogram of ' + name_X + ' vs ' + name_y + ' Classifier Immunogenicity Prediction Scores')
    plt.legend()
    plt.show()

    # Scores KDE Plot
    sns.kdeplot(scores_X, label= name_X + ' Classifier Scores', fill=True)
    sns.kdeplot(scores_y, label= name_y + ' Classifier Scores', fill=True)
    plt.xlabel('Immunogenicity Prediction Scores')
    plt.ylabel('Density')
    plt.title('KDE Plot of ' + name_X + ' vs ' + name_y + ' Classifier Immunogenicity Prediction Scores')
    plt.legend()
    plt.show()

    # Scatter Ranks
    plt.scatter(ranks_X, ranks_y, marker='x', s=9, alpha = 0.5)
    plt.title(name_X + ' vs ' + name_y + ' Classifier Immunogenicity Prediction %Ranks')
    plt.xlabel(name_X + ' Classifier %Rank Score')
    plt.ylabel(name_y + ' Classifier %Rank Score')
    plt.show()

    # Ranks Hexbin
    plt.hexbin(ranks_X, ranks_y, gridsize=20, cmap='Blues')
    plt.xlabel(name_X + ' Classifier %Rank Scores')
    plt.ylabel(name_y + ' Classifier %Rank Scores')
    plt.title('Hexbin Plot of Immunogenicity Prediction %Rank Scores')
    plt.colorbar(label='Count')
    plt.show()

    # Ranks Histogram (legend entries describe %ranks, not raw scores)
    plt.hist(ranks_X, histtype='step', bins=50, label=name_X + ' Classifier %Rank Scores')
    plt.hist(ranks_y, histtype='step', bins=50, label=name_y + ' Classifier %Rank Scores')
    plt.xlabel('Immunogenicity Prediction %Rank Scores')
    plt.ylabel('Count')
    plt.title('Histogram of ' + name_X + ' vs ' + name_y + ' Classifier Immunogenicity Prediction %Ranks')
    plt.legend()
    plt.show()

    # Ranks KDE Plot
    sns.kdeplot(ranks_X, label=name_X + ' Classifier %Rank Scores', fill=True)
    sns.kdeplot(ranks_y, label=name_y + ' Classifier %Rank Scores', fill=True)
    plt.xlabel('Immunogenicity Prediction %Rank Scores')
    plt.ylabel('Density')
    plt.title('KDE Plot of ' + name_X + ' vs ' + name_y + ' Classifier Immunogenicity Prediction %Ranks')
    plt.legend()
    plt.show()

    print(f"Correlation Coefficient (Scores): {np.corrcoef(scores_X, scores_y)[0, 1]}")
    print(f"Correlation Coefficient (Ranks): {np.corrcoef(ranks_X, ranks_y)[0, 1]}")

In [None]:
def mainOverlapPlots(overlap_df, name_X, name_y, unknown_gene_visible = True, mutations_visible = False):
    """Plot %rank agreement by gene, plus ROC and PR curves, for two models.

    Prints the Pearson correlation of the two models' scores and of their
    %ranks. It then draws and saves a joint scatter of the %ranks coloured by
    gene, with marginal distributions. Finally it draws and saves side-by-side
    ROC and precision-recall curves of each model's 0/1 predictions against
    the labels 'y'. Both figures are saved under 'Classifier/Images/'.

    Parameters
    ----------
    overlap_df : pandas.DataFrame
        Table as returned by getComparisonDF.
    name_X, name_y : str
        Display names for model 1 and model 2.
    unknown_gene_visible : bool
        When False, rows whose gene is 'Unknown' are dropped first.
    mutations_visible : bool
        When True, peptides flagged in 'mutation_exists' are circled in black.
    """
    # Use sklearn.metrics via a local name so a notebook-level rebinding of the
    # global ``auc`` (for example to a stored AUC value) cannot shadow the function.
    from sklearn import metrics

    if not unknown_gene_visible:
        overlap_df = overlap_df[overlap_df['gene'] != 'Unknown']
    # One colour per gene actually present, rather than a fixed 27 / 26, so the
    # palette always matches the number of hue levels.
    gene_palette = sns.color_palette('hsv', n_colors=overlap_df['gene'].nunique())

    y = overlap_df['y']
    mutation_exists = overlap_df['mutation_exists']
    scores_X, scores_y = overlap_df['model_1_scores'], overlap_df['model_2_scores']
    ranks_X, ranks_y = overlap_df['model_1_ranks'], overlap_df['model_2_ranks']
    pred_X, pred_y = overlap_df['model_1_pred'], overlap_df['model_2_pred']
    print(f"Correlation Coefficient (Scores): {round(np.corrcoef(scores_X, scores_y)[0, 1],4)}")
    print(f"Correlation Coefficient (Ranks): {round(np.corrcoef(ranks_X, ranks_y)[0, 1],4)}")

    # Marginal histogram of %rank scores, coloured by gene
    joint_plot = sns.jointplot(data=overlap_df, x="model_1_ranks", y="model_2_ranks", space=0, hue='gene', palette=gene_palette, xlim=(-max(ranks_X)*0.1, max(ranks_X)*1.1), ylim=(-max(ranks_y)*0.1, max(ranks_y)*1.1), height=6, ratio=5, edgecolor='None')
    joint_plot.set_axis_labels(name_X + ' Classifier Immunogenicity %Rank Scores', name_y + ' Classifier Immunogenicity %Rank Scores')
    if mutations_visible:
        mutation_exists_points = overlap_df[mutation_exists]
        plt.scatter(mutation_exists_points['model_1_ranks'], mutation_exists_points['model_2_ranks'], facecolors='none', edgecolors='black', linewidths=0.5, s=30, label='Peptides w/\n Mutations')
    plt.legend(title='Gene', loc='center left', bbox_to_anchor=(1.25, 0.5), fontsize=9)
    plt.suptitle(name_X + ' vs ' + name_y + ' Classifier %Rank Scores Marginal Histogram by Gene', y=1.02)
    plt.savefig('Classifier/Images/z.' + name_X + ' vs ' + name_y + ' Classifier %Rank Scores Marginal Histogram.png')
    plt.show()

    # ROC / PR curves and their areas for model 1 and model 2.
    # NOTE(review): these use the thresholded 0/1 predictions, so each curve has a
    # single operating point; passing -ranks instead would trace the full curve.
    fpr_1, tpr_1, _ = metrics.roc_curve(y, pred_X)
    roc_auc_1 = metrics.auc(fpr_1, tpr_1)
    precision_1, recall_1, _ = metrics.precision_recall_curve(y, pred_X)
    pr_auc_1 = metrics.average_precision_score(y, pred_X)
    fpr_2, tpr_2, _ = metrics.roc_curve(y, pred_y)
    roc_auc_2 = metrics.auc(fpr_2, tpr_2)
    precision_2, recall_2, _ = metrics.precision_recall_curve(y, pred_y)
    pr_auc_2 = metrics.average_precision_score(y, pred_y)

    # ROC on the left, precision-recall on the right
    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(10, 4))
    ax_roc.plot(fpr_1, tpr_1, color='orange', lw=2, label=name_X + ' Model (AUC = {:.2f})'.format(roc_auc_1))
    ax_roc.plot(fpr_2, tpr_2, color='blue', lw=2, label=name_y + ' Model (AUC = {:.2f})'.format(roc_auc_2))
    ax_roc.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    ax_roc.set_xlim([0.0, 1.0])
    ax_roc.set_ylim([0.0, 1.05])
    ax_roc.set_xlabel('False Positive Rate')
    ax_roc.set_ylabel('True Positive Rate')
    ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax_roc.legend()
    ax_pr.plot(recall_1, precision_1, color='orange', lw=2, label=name_X + ' Model (AUCPR = {:.2f})'.format(pr_auc_1))
    ax_pr.plot(recall_2, precision_2, color='blue', lw=2, label=name_y + ' Model (AUCPR = {:.2f})'.format(pr_auc_2))
    ax_pr.set_xlim([0.0, 1.0])
    ax_pr.set_ylim([0.0, 1.05])
    ax_pr.set_xlabel('Recall')
    ax_pr.set_ylabel('Precision')
    ax_pr.set_title('Precision-Recall (PR) Curve')
    ax_pr.legend()
    plt.tight_layout()
    plt.savefig('Classifier/Images/z.' + name_X + ' vs ' + name_y + ' Classifier AUC & AUCPR Plots.png')
    plt.show()

In [3]:
# PRIME2.0 vs the Prime-trained deep classifier
mainOverlapPlots(getComparisonDF("PRIME2.0", "prtr-deep"), name_X='PRIME2.0', name_y='Prime-Trained Deep')

In [4]:
# PRIME2.0 vs the Prime-trained parsimonious classifier
mainOverlapPlots(getComparisonDF("PRIME2.0", "prtr-pars"), name_X='PRIME2.0', name_y='Prime-Trained Parsimonious')

In [8]:
# Prime-trained deep vs Prime-trained parsimonious classifier
mainOverlapPlots(getComparisonDF("prtr-deep", "prtr-pars"), name_X='Prime-Trained Deep', name_y='Prime-Trained Parsimonious')

In [7]:
# PRIME2.0 vs the parsimonious classifier
mainOverlapPlots(getComparisonDF("PRIME2.0", "parsimonious"), name_X='PRIME2.0', name_y='Parsimonious')

In [6]:
# PRIME2.0 vs the deep classifier
mainOverlapPlots(getComparisonDF("PRIME2.0", "deep"), name_X='PRIME2.0', name_y='Deep')

In [5]:
# Deep vs parsimonious classifier
mainOverlapPlots(getComparisonDF("deep", "parsimonious"), name_X='Deep', name_y='Parsimonious')

In [None]:
# Detecting all immunogenic peptides through classifier:
# per-gene AUC of the saved deep classifier on the peptides bucketed under each
# reference gene, plus per-gene counts of label-0 and label-1 peptides.
# *Use below if data/models are saved to files*
A, L = 20, 9
classifiers = getAllModels(A, L)
model = copy.deepcopy(classifiers[7])
model.load_state_dict(torch.load('Classifier/Models/deep.pth')) # *Change path as required*
# NOTE(review): model.eval() is not called; if the architecture uses dropout or
# batch-norm, switch to eval mode before scoring.
from sklearn.exceptions import UndefinedMetricWarning
import warnings
gene_count0, gene_count1 = [], []
# A gene whose peptides are all one class makes roc_curve emit an
# UndefinedMetricWarning and return NaN rates; the NaN AUC is then stored as 0.
# Only that warning is silenced, and only inside this block, instead of
# replacing warnings.warn for the whole session.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', UndefinedMetricWarning)
    for geneName in ref_df['gene']:
        gene_labels = np.array(y_ref_by_gene[geneName])
        gene_count0.append(np.sum(gene_labels == 0))
        gene_count1.append(np.sum(gene_labels == 1))
        # Stored as gene_auc so the imported sklearn ``auc`` function is not shadowed.
        if len(X_ref_by_gene[geneName]) == 0:
            gene_auc = 0
        else:
            X_tensor = torch.nn.functional.one_hot(torch.LongTensor(convert_number(X_ref_by_gene[geneName])), num_classes=A).type(torch.FloatTensor)
            with torch.no_grad():  # inference only
                gene_scores = np.ravel(model(X_tensor).detach().numpy())
            fpr, tpr, thresholds = sklearn.metrics.roc_curve(gene_labels, gene_scores)
            gene_auc = sklearn.metrics.auc(fpr, tpr)
        if np.isnan(gene_auc):
            gene_auc = 0
        auc_by_gene[geneName] = gene_auc

In [9]:
# Plot Pan-Allelic Classifier Performance by Gene:
# per-gene AUC as a line on the left axis, per-gene label counts as paired bars
# on a twin right axis; saved under temp/.

bar_width = 0.2
x_positions = np.arange(len(ref_df['gene']))

fig, ax1 = plt.subplots(figsize=(10, 6))
ax2 = ax1.twinx()  # right-hand axis for the counts, sharing the x axis

# Negative-count bars sit just left of each tick, positive-count bars just right.
for _bar_offset, _bar_counts, _bar_colour, _bar_label in (
    (-bar_width / 2, gene_count0, 'lime', 'Count (Neg. in Test)'),
    (bar_width / 2, gene_count1, 'green', 'Count (Pos. in Test)'),
):
    ax2.bar(x_positions + _bar_offset, _bar_counts, width=bar_width, alpha=0.5, color=_bar_colour, label=_bar_label)
ax2.set_ylim(0, max(max(gene_count0), max(gene_count1)))
ax2.set_ylabel('Gene Test Data Count')
ax2.legend(loc='upper right', fontsize=9)

# AUC line on the left axis, one point per gene in ref_df order
ax1.plot(x_positions, list(auc_by_gene.values()), label='Deep Classifier', color='blue', marker='o')
ax1.set_xticks(x_positions)
ax1.set_xticklabels(ref_df['gene'], rotation=45, ha='right', rotation_mode='anchor')
ax1.set(xlabel='Gene', ylabel='AUC', ylim=(0, 1), title='Pan-Allelic Classifier Performance by Gene')
ax1.legend(fontsize=8, loc='center right')

plt.tight_layout(pad=2)
plt.savefig('temp/Pan-Allelic Classifier Performance by Gene', dpi=300)
plt.show()


In [None]:
# Detecting all immunogenic peptides through classifier:
# score every 9-mer of each reference gene with the saved deep classifier, keep
# those scoring above the 97.5th percentile of all reference 9-mer scores, then
# record the substitution variants of those peptides found in var_df sequences.
A, L = 20, 9
classifiers = getAllModels(A, L)
model = copy.deepcopy(classifiers[7])
model.load_state_dict(torch.load('Classifier/Models/deep.pth')) # *Change path as required*
# NOTE(review): model.eval() is not called; if the architecture uses dropout or
# batch-norm, switch to eval mode before scoring.

ref_peptides = {}
ref_imm = {}
my_peptides = {}
peptide_length = 9

for geneName in ref_df['gene']:
    protein_sequence = ref_df[ref_df['gene'] == geneName]['seq'].values[0]
    # Every window that neither contains a stop ('*') nor is immediately followed
    # by one: the one-residue-longer slice covers both conditions.
    ref_peptides[geneName] = [
        protein_sequence[start:start + peptide_length]
        for start in range(len(protein_sequence) - peptide_length + 1)
        if '*' not in protein_sequence[start:start + peptide_length + 1]
    ]

    ref_peptides_tensor = torch.nn.functional.one_hot(torch.LongTensor(convert_number(ref_peptides[geneName])), num_classes=A).type(torch.FloatTensor)
    with torch.no_grad():  # inference only
        ref_imm[geneName] = model(ref_peptides_tensor).detach().numpy()

allScores = []
for geneName in ref_df['gene']:
    allScores.extend(ref_imm[geneName])
# Compute the cut-off once, outside the per-peptide filter.
score_cutoff = np.percentile(allScores, 97.5)
for geneName in ref_df['gene']:
    my_peptides[geneName] = [peptide for peptide, score in zip(ref_peptides[geneName], ref_imm[geneName]) if score > score_cutoff]

# Extract mutations from variants of Deep Classifier epitopes
errors = 0
error_epitopes = []  # epitopes whose alignment raised ValueError; skipped afterwards
mutation_rows = {}
mutation_count = 0
n_genes = len(my_peptides)
for gene_number, (gene, epitopes) in enumerate(my_peptides.items(), start=1):
    mutation_rows[gene] = []
    # Filter the variant table once per gene rather than once per epitope.
    gene_variants = var_df[var_df['gene'] == gene]
    for epitope in epitopes:
        for _, row in gene_variants.iterrows():
            seq = row['seq']
            if epitope in seq or epitope in error_epitopes:
                continue
            try:
                mutation_count += 1
                aligned = getSeqAligned(epitope, seq)
                mut_type = getMutationType(epitope, aligned, subOnly=1)
                if mut_type != -1:
                    # mut_type[1] is stored as-is, including a -1 "no details" value.
                    mutation_rows[gene].append({'epitope': epitope, 'accID': row['accID'], 'mutation': aligned, 'mutation_type': mut_type[0], 'mutation_type_details': mut_type[1]})
            except ValueError:
                error_epitopes.append(epitope)
                errors += 1
    print(gene, "'s  Deep Classifier  mutations  found  -  (", gene_number, "/", n_genes, ")")

# Save the dictionary
with open('temp/mutation_rows.pkl', 'wb') as file:
    pickle.dump(mutation_rows, file)

In [None]:
# Get all epitopes and mutations immunogenicity scores from classifier:
# collect the unique epitopes and the unique mutated peptides that contain no
# 'X' and are 8-14 residues long, then score each one with the deep classifier.

epi, mut = [], []
for gene_records in mutation_rows.values():
    for record in gene_records:
        epi.append(record['epitope'])
        mutated = record['mutation']
        if 'X' not in mutated and 8 <= len(mutated) <= 14:
            mut.append(mutated)
epi = np.unique(np.array(epi))
mut = np.unique(np.array(mut))


def score_immunogenicity(sequence):
    """Return the deep classifier's scalar score for one peptide string."""
    one_hot = torch.nn.functional.one_hot(torch.LongTensor(convert_number([sequence])), num_classes=A).type(torch.FloatTensor)
    return model(one_hot).detach().numpy()[0][0]


peptides_classifier_scores = {peptide: {'Immunogenicity Score': score_immunogenicity(peptide)} for peptide in epi}
mutations_classifier_scores = {mutation: {'Immunogenicity Score': score_immunogenicity(mutation)} for mutation in mut}

In [None]:
# Remove mutation data whose mutated peptide contains 'X' or is not 8-14 residues long.
# Each gene's list is rebuilt in place: calling list.remove() on the list being
# iterated skips the element after each removal, leaving invalid rows behind.

n_genes = len(mutation_rows)
for gene_number, gene in enumerate(list(mutation_rows), start=1):
    mutation_rows[gene][:] = [
        row for row in mutation_rows[gene]
        if 'X' not in row['mutation'] and 8 <= len(row['mutation']) <= 14
    ]
    print(gene, "'s  mutation data cleaned  -  (", gene_number, "/", n_genes, ")")

NSP1 's  mutation data cleaned  -  ( 1 / 27 )
NSP2 's  mutation data cleaned  -  ( 2 / 27 )
NSP3 's  mutation data cleaned  -  ( 3 / 27 )
NSP4 's  mutation data cleaned  -  ( 4 / 27 )
NSP5 's  mutation data cleaned  -  ( 5 / 27 )
NSP6 's  mutation data cleaned  -  ( 6 / 27 )
NSP7 's  mutation data cleaned  -  ( 7 / 27 )
NSP8 's  mutation data cleaned  -  ( 8 / 27 )
NSP9 's  mutation data cleaned  -  ( 9 / 27 )
NSP10 's  mutation data cleaned  -  ( 10 / 27 )
NSP11 's  mutation data cleaned  -  ( 11 / 27 )
NSP12 's  mutation data cleaned  -  ( 12 / 27 )
NSP13 's  mutation data cleaned  -  ( 13 / 27 )
NSP14 's  mutation data cleaned  -  ( 14 / 27 )
NSP15 's  mutation data cleaned  -  ( 15 / 27 )
NSP16 's  mutation data cleaned  -  ( 16 / 27 )
SPIKE 's  mutation data cleaned  -  ( 17 / 27 )
NS3 's  mutation data cleaned  -  ( 18 / 27 )
E 's  mutation data cleaned  -  ( 19 / 27 )
M 's  mutation data cleaned  -  ( 20 / 27 )
NS6 's  mutation data cleaned  -  ( 21 / 27 )
NS7A 's  mutation data

In [None]:
# Load immunogenicity scores of peptides and their unique mutations into dataframe.
# Each distinct mutated peptide (no 'X', 8-14 residues) is recorded once, with the
# date, country and lab of the first variant it was seen in, both classifier
# scores, and an evasion score comparing the mutated and original peptide scores.

count=0
mutations_data = []
recorded_mutations = []
recorded_set = set()  # O(1) membership test mirroring recorded_mutations
for gene in mutation_rows.keys():
    gene_filtered_df = var_df[var_df['gene'] == gene]
    for pep in mutation_rows[gene]:
        mutation = pep['mutation']
        variant = pep['accID']
        if mutation in recorded_set or 'X' in mutation or not 8 <= len(mutation) <= 14:
            continue
        filtered_df = gene_filtered_df[gene_filtered_df['accID'] == variant]
        date = filtered_df['date'].values[0]
        location = filtered_df['country'].values[0]
        lab = filtered_df['lab'].values[0]
        recorded_mutations.append(mutation)
        recorded_set.add(mutation)
        peptide = pep['epitope']
        score_peptide = float(peptides_classifier_scores[peptide]['Immunogenicity Score'])
        score_mutation = float(mutations_classifier_scores[mutation]['Immunogenicity Score'])
        if score_peptide == 0:
            # Ratio undefined: NaN rather than a ZeroDivisionError.
            evasion_score = np.nan
        elif score_mutation / score_peptide > 0:
            evasion_score = np.log(score_mutation / score_peptide)
        elif score_mutation > score_peptide:
            evasion_score = np.log(abs(score_mutation) / (score_peptide + 2 * score_mutation))
        else:
            # score_peptide > 0 >= score_mutation: no formula is defined for this
            # case. NaN is set explicitly so a previous row's value is never reused.
            evasion_score = np.nan
        mutations_data.append([variant, date, location, lab, gene, peptide, mutation, score_peptide, score_mutation, evasion_score])
        count+=1
        if count%500==0:
            print('---', count, 'gene mutations recorded.')
print('A total of ', count, 'gene unique mutations recorded.')
diff_df = pd.DataFrame(mutations_data, columns=['variant', 'date', 'location', 'lab', 'gene', 'peptide', 'mutation', 'score_peptide', 'score_mutation', 'evasion_score'])
filtered_diff_df = diff_df.drop_duplicates(subset=['peptide', 'mutation', 'evasion_score'], keep='first')

--- 500 gene mutations recorded.
--- 1000 gene mutations recorded.
--- 1500 gene mutations recorded.
--- 2000 gene mutations recorded.
A total of  2103 gene unique mutations recorded.


In [None]:
# Need to run this for ~an hour~ to record total mutations observed from peptides identified from deep classifier

# # Load immunogenicity scores of peptides and their unique mutations into dataframe

# count, total=0,0
# mutations_data = []
# for gene in mutation_rows.keys():
#     gene_filtered_df = var_df[var_df['gene'] == gene]
#     for pep in mutation_rows[gene]:
#         total +=1
# progress_bar = tqdm(total=total, desc='Progress')
# for gene in mutation_rows.keys():
#     gene_filtered_df = var_df[var_df['gene'] == gene]
#     for pep in mutation_rows[gene]:
#         progress_bar.update(1)
#         mutation = pep['mutation']
#         variant = pep['accID']
#         if 'X' not in mutation and len(mutation) >= 8 and len(mutation) <= 14:
#             filtered_df = gene_filtered_df[gene_filtered_df['accID'] == variant]
#             date = filtered_df['date'].values[0]
#             location = filtered_df['country'].values[0]
#             lab = filtered_df['lab'].values[0]
#             gene = gene
#             peptide = pep['epitope']
#             score_peptide = float((peptides_classifier_scores)[peptide]['Immunogenicity Score'])
#             score_mutation = float((mutations_classifier_scores)[mutation]['Immunogenicity Score'])
#             if score_mutation / score_peptide <= 0:
#                 if score_mutation > score_peptide:
#                     evasion_score = np.log(abs(score_mutation) / (score_peptide + 2 * score_mutation))
#             else:
#                 evasion_score = np.log(score_mutation / score_peptide)
#             mutations_data.append([variant, date, location, lab, gene, peptide, mutation, score_peptide, score_mutation, evasion_score])
# #             count+=1
# #             if count%500==0:
# #                 print('---', count, 'gene mutations recorded.')

# print('A total of ', count, 'gene unique mutations recorded.')
# diff_df = pd.DataFrame(mutations_data, columns=['variant', 'date', 'location', 'lab', 'gene', 'peptide', 'mutation', 'score_peptide', 'score_mutation', 'evasion_score'])
# filtered_diff_df = diff_df.drop_duplicates(subset=['peptide', 'mutation', 'evasion_score'], keep='first')

In [10]:
# Plotting number of unique mutations per gene:
# bar height = number of filtered_diff_df rows for that gene, with the count
# written above each bar; saved to temp/count2.png.

geneCount = [int((filtered_diff_df['gene'] == gene).sum()) for gene in mutation_rows]

x_positions = np.arange(len(ref_df['gene']))
plt.figure(figsize=(9, 4))
plt.bar(x_positions, geneCount)

# Count label on top of each bar
for i, count in enumerate(geneCount):
    plt.text(i, count, count, ha='center', va='bottom', fontsize=8)

plt.xticks(x_positions, ref_df['gene'], rotation=45, ha='right', rotation_mode='anchor')
plt.xlabel('Gene')
plt.ylabel('Count')
plt.tight_layout(pad=1)
plt.title(f'Count of Unique Mutated Deep Classifier Peptides by Gene - Total = {sum(geneCount)}')
plt.savefig('temp/count2.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# SPIKE, N, M, NSP3, NSP4, NSP12, NSP13, NS3, NS8
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10036809/

In [None]:
# Gathering Mutation Details Data:
# for each distinct mutated peptide (no 'X', 8-14 residues), take the first entry
# of its mutation_type_details and read it as (position, epitope residue,
# mutation residue). It records
#   residue_mutations[epitope residue] -> residues it was mutated into, and
#   position_changes[peptide length]   -> mutated positions.

recorded_mutations = []
recorded_set = set()
position_changes, residue_mutations = {}, {}
for gene in mutation_rows.keys():
    for row in mutation_rows[gene]:
        mutation = row['mutation']
        if mutation in recorded_set or 'X' in mutation or not 8 <= len(mutation) <= 14:
            continue
        recorded_mutations.append(mutation)
        recorded_set.add(mutation)
        details = row['mutation_type_details']
        if isinstance(details, (int, np.integer)):
            # Integer sentinel (-1): no substitution details to tally.
            continue
        position, epitope_residue, mutation_residue = details[0][0], details[0][1], details[0][2]
        residue_mutations.setdefault(epitope_residue, []).append(mutation_residue)
        position_changes.setdefault(len(mutation), []).append(position)

In [11]:
from matplotlib.path import Path
from matplotlib.patches import PathPatch, FancyArrowPatch

def plot_chord_diagram(residue_mutations):
    """Draw a circular diagram of residue substitutions, save it and show it.

    Parameters
    ----------
    residue_mutations : dict
        Maps each original (epitope) residue to a list of the residues it was
        mutated into, one entry per observed substitution.

    Every residue appearing as a key or as a substitution target becomes a node
    on the unit circle. Each observed substitution i -> j is drawn as a line plus
    an arrow pointing from i to j. Line width grows with the number of such
    substitutions. Opacity is 2*pi times the share of i's substitutions that went
    to j, clipped to 1. The figure is written to 'temp/chord.png'.
    """
    # Nodes: original residues first (dict order), then residues that only ever
    # appear as targets. Indexing by the dict keys alone would raise ValueError
    # for those target-only residues.
    residues = list(dict.fromkeys(
        [*residue_mutations] + [target for targets in residue_mutations.values() for target in targets]
    ))
    node_index = {residue: k for k, residue in enumerate(residues)}
    num_residues = len(residues)

    # counts[i, j] = number of times residue i was substituted by residue j
    counts = np.zeros((num_residues, num_residues))
    for source, targets in residue_mutations.items():
        for target in targets:
            counts[node_index[source], node_index[target]] += 1

    # Per-source share scaled to an angle; rows without outgoing substitutions
    # (target-only residues) stay 0 instead of dividing by zero.
    row_totals = counts.sum(axis=1, keepdims=True)
    angles = np.divide(2 * np.pi * counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0)

    fig, ax = plt.subplots(figsize=(8, 8))
    # Leave room beyond the unit circle so node markers are not clipped at the
    # axes edge (labels sit at radius 1.1).
    ax.set_xlim(-1.25, 1.25)
    ax.set_ylim(-1.25, 1.25)
    ax.axis('off')

    # Nodes evenly spaced on the unit circle
    circle_positions = np.linspace(0, 2 * np.pi, num_residues, endpoint=False)
    x = np.cos(circle_positions)
    y = np.sin(circle_positions)
    for i, label in enumerate(residues):
        ax.plot(x[i], y[i], marker='o', markersize=20, color='skyblue')
        ax.text(1.1 * x[i], 1.1 * y[i], label, fontsize=14, ha='center', va='center')

    max_occurrences = counts.max() if counts.size else 0
    for i in range(num_residues):
        for j in range(num_residues):
            # Zero-count pairs would be fully transparent (alpha 0); skip them.
            if i == j or counts[i, j] == 0:
                continue
            alpha = np.clip(angles[i, j], 0, 1)
            linewidth = 0.5 + 4 * (counts[i, j] / max_occurrences)
            path = Path([(x[i], y[i]), (x[j], y[j])])
            ax.add_patch(PathPatch(path, facecolor='skyblue', edgecolor='gray', alpha=alpha, linewidth=linewidth))
            # annotate draws its arrow from xytext towards xy, so xy is the target j.
            ax.annotate("", xy=(x[j], y[j]), xytext=(x[i], y[i]),
                        arrowprops=dict(arrowstyle="->", color="gray", alpha=alpha, linewidth=linewidth))
    plt.savefig('temp/chord.png', dpi=300, bbox_inches='tight')
    plt.show()


plot_chord_diagram(residue_mutations)

In [12]:
# Plotting matrix of epitope residue mutations:
# heat map with original epitope residues as columns, the residues they were
# mutated into as rows, and substitution counts in the cells.

epitope_residues = sorted(residue_mutations.keys())
mutated_residues = sorted({residue for residues in residue_mutations.values() for residue in residues})

# grid[row, col] = how often epitope residue `col` became mutated residue `row`
grid = np.zeros((len(mutated_residues), len(epitope_residues)))
for col, epitope_residue in enumerate(epitope_residues):
    tally = Counter(residue_mutations[epitope_residue])
    for row, mutated_residue in enumerate(mutated_residues):
        grid[row, col] = tally[mutated_residue]

fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(grid, cmap='Reds')

# Tick labels: epitope residues along the top, mutated residues down the side
ax.set_xticks(np.arange(len(epitope_residues)))
ax.set_yticks(np.arange(len(mutated_residues)))
ax.set_xticklabels(epitope_residues)
ax.xaxis.tick_top()
ax.set_yticklabels(mutated_residues)

ax.set_xlabel('Epitope Residues')
ax.xaxis.set_label_position('top')
ax.set_ylabel('Mutation Residues')

# Write every cell's count at its centre
for (row, col), value in np.ndenumerate(grid):
    ax.text(col, row, int(value), ha='center', va='center', color='black')

ax.set_title('Residue Mutation Matrix')
cbar = fig.colorbar(im, ax=ax, label='Counts')

plt.savefig('temp/gridResidues.png', dpi=300, bbox_inches='tight')
plt.show()


In [13]:
# Plotting Mutation Position Counts:
# one strip per mutated-peptide length, with a circle at each mutated position.
# Circle radius and red intensity grow with that position's share of the
# substitutions; the label reads "position:\ncount".

for i in list(position_changes.keys()):
    positions = position_changes[i]
    unique_nums, counts = np.unique(positions, return_counts=True)
    fig, ax = plt.subplots(figsize=((i/9)*12.25, 1.2))
    ax.set_xlim(0, i+1)
    ax.set_ylim(0, 1)
    for num, count in zip(unique_nums, counts):
        share = count / len(positions)
        radius = share * 2.5
        # Matplotlib rejects alpha outside [0, 1]; share * 3 exceeds 1 whenever a
        # single position holds more than a third of the substitutions.
        red_alpha = min(share * 3, 1.0)
        # Opaque yellow disc with a red overlay whose opacity encodes the share
        ax.add_patch(plt.Circle((num, 0.5), radius, fc='yellow', alpha=1, edgecolor='black'))
        ax.add_patch(plt.Circle((num, 0.5), radius, fc='red', alpha=red_alpha))
        ax.text(num, 0.5, f"{num}:\n{count}", ha='center', va='center', color='black', alpha=1, fontweight='medium')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title('Counts of Position Mutations for ' + str(i) + '-mers Deep Classifier Peptides')
    plt.savefig('temp/'+str(i)+'circ.png', dpi=300, bbox_inches='tight')
    plt.show()

In [14]:
# Create a distribution plot for the Deep Classifier immunogenicity cost of the mutations.
# The 1st-percentile marker uses np.nanpercentile: if any evasion score is NaN,
# np.percentile returns NaN and the marker line would not be placed.

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(filtered_diff_df['evasion_score'], bins=100, label='Evasion Score')
ax.set_xlabel('Mutation Immunogenicity Evasion Score')
ax.set_ylabel('Count')
ax.set_title("Distribution of Unique Mutations' Evasion Score")
ax.axvline(x=0, color='red', linestyle='--', label='Evasion Score = 0')
ax.axvline(x=np.nanpercentile(filtered_diff_df['evasion_score'], 1), color='blue', linestyle='--', label='1st Percentile')
ax.legend()
plt.savefig('temp/evasion1.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Most dangerous mutated peptides: mutations whose evasion score is below the
# 2nd percentile of all non-NaN evasion scores.

evasion_scores = filtered_diff_df['evasion_score']
danger_cutoff = np.percentile(evasion_scores.dropna(), 2)
mutations_dangerous = filtered_diff_df.loc[evasion_scores < danger_cutoff, 'mutation'].tolist()

print('2% most dangerous mutated peptides - according to Evasion Score:')
print(mutations_dangerous)

2% most dangerous mutated peptides - according to Evasion Score:
['MFLSTLMKC', 'LQKKKVNIN', 'LQKEKVTSI', 'IEAMMFTSD', 'DYTVIEVQG', 'SDDYIAING', 'LSGHNLAKH', 'EVQGYKSVN', 'DNLKTLLSL', 'LYGEVITFD', 'LEGEVITFD', 'LKFSPPALQ', 'LKFYPPALQ', 'LKFTPPALQ', 'LKFIPPALQ', 'CERVLNVVC', 'KPLEFGVTS', 'VADAVIKTL', 'IAKYSVKSV', 'VIKNFATSI', 'NFVQMAPIS', 'LDNYYRKDN', 'NSRIKASMP', 'AYFDTWFSQ', 'TIYSLLKDC', 'DTDLTKPYI', 'DAQSFLKPG', 'DATTAYANS', 'ANEYRLYLD', 'TYGVCLFWN', 'SDRVVFVLW', 'NVTGLFKDC', 'NFWNTFIRL', 'LLPLVSIQC', 'LLPLVSSHC', 'LYPLSETKC', 'LPPLSETKC', 'LTGIAVEQD', 'WFHAISGTN', 'MYLNGPQNQ', 'SSRTSTPGS', 'SLRTSTPGS', 'SSRSSTPGS']


In [15]:
# Evasion scores of mutations whose original peptide is among the
# experimentally-verified peptides in X.
known_peptides = set(X)  # O(1) lookups instead of scanning the array per peptide
exp_ver = [pep for pep in filtered_diff_df['peptide'] if pep in known_peptides]


exp_vr_df = filtered_diff_df[filtered_diff_df['peptide'].isin(exp_ver)]


fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(exp_vr_df['evasion_score'], bins=100, label='Evasion Score')
ax.set_xlabel('Experimentally-Verified Mutations Immunogenicity Evasion Score')
ax.set_ylabel('Count')
ax.set_title("Distribution of Experimentally-Verified Mutations' Evasion Score")
ax.axvline(x=0, color='red', linestyle='--', label='Evasion Score = 0')
#ax.axvline(x=np.percentile(filtered_diff_df['evasion_score'], 1), color='blue', linestyle='--', label='1st Percentile')
ax.legend()

plt.show()
# Peptide with the lowest evasion score (displayed as the cell's value)
exp_vr_df['peptide'].tolist()[np.argmin(exp_vr_df['evasion_score'])]

In [None]:
# None of the most dangerous peptides (identified by my deep classifier) are experimentally verified (via IEDB); one is a subset of another.
# Of all the peptides identified by my deep classifier, 69 carry mutations that are experimentally verified.
# One of these has a notably low evasion score: NLWNTFTRL


In [16]:
# Check frequency of most dangerous mutations over time

# Filter the DataFrame to include only rows with peptides in `rank_mutations_dangerous`
most_dangerous_df = var_df[var_df['seq'].str.contains('|'.join(mutations_dangerous))]

# Group the filtered DataFrame by date and count the occurrences
grouped_df = most_dangerous_df.groupby('date').size().reset_index(name='count')

# Get the counts per date from the original DataFrame
var_df_counts = var_df['date'].value_counts()

# Normalise the counts by the total number of occurrences per date in var_df
grouped_df['normalised_count'] = grouped_df.apply(lambda row: row['count'] / var_df_counts[row['date']], axis=1)

# Convert the 'date' column to datetime format
grouped_df['date'] = pd.to_datetime(grouped_df['date'])

# Plot the counts over time
plt.figure(figsize=(14, 5))
plt.plot(grouped_df['date'], grouped_df['normalised_count'])
plt.xlabel('Date')
plt.ylabel('Count')
plt.title('Normalised Count of 2% Most Dangerous Mutations over Time')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig('temp/mut_counts.png', dpi=300, bbox_inches='tight')
plt.show();

In [17]:
# Plot GISAID Submissions of Deep Classifier Peptides by Country

# Only countries with more than this many total submissions get a ratio; others
# are shown as 0 (presumably to avoid noisy ratios from tiny denominators).
MIN_TOTAL_SUBMISSIONS = 50

# Get all country names
locations_deep = filtered_diff_df['location'].tolist()
locations_all = var_df['country'].tolist()

# Normalise truncated names, alternative spellings and territories onto the
# names presumably used in the shapefile's `name` column (merge key below).
# Territories are folded into their sovereign state.
# NOTE: Mauritius is an independent country, not a French territory, so it is
# deliberately NOT mapped to 'France' (that mixed its submissions into France's ratio).
country_name_mapping = {
    'Aruba': 'Netherlands',
    'Bonaire': 'Netherlands',
    'Bosnia and Herzegovina': 'Bosnia and Herz.',
    'Bra': 'Brazil',
    'Braz': 'Brazil',
    'Brazi': 'Brazil',
    'Bulgari': 'Bulgaria',
    'Bulgar': 'Bulgaria',
    'Colo': 'Colombia',
    'Colombi': 'Colombia',
    'Curacao': 'Netherlands',
    'Czech Republic': 'Czechia',
    'Eswatini': 'eSwatini',
    'Esto': 'Estonia',
    'French Guiana': 'France',
    'Germa': 'Germany',
    'Gha': 'Ghana',
    'Ghan': 'Ghana',
    'Gibraltar': 'United Kingdom',
    'Guadeloupe': 'France',
    'Hong Kong': 'China',
    'Martinique': 'France',
    'Mayotte': 'France',
    'P': 'Poland',
    'Pola': 'Poland',
    'Polan': 'Poland',
    'Reunion': 'France',
    'UK': 'United Kingdom',
    'USA': 'United States of America',
}

locations_deep = [country_name_mapping.get(location, location) for location in locations_deep]
locations_all = [country_name_mapping.get(location, location) for location in locations_all]

# Ratio of deep-classifier submissions to all submissions per country.
# Counter keeps first-occurrence order, so countries are listed in the same order
# as before, but each count is now an O(1) lookup instead of a list scan.
deep_counts = Counter(locations_deep)
all_counts = Counter(locations_all)
country_counts = {'Country': [], 'Ratio': []}
for country, count_deep in deep_counts.items():
    count_all = all_counts[country]
    ratio = count_deep / count_all if count_all > MIN_TOTAL_SUBMISSIONS else 0
    country_counts['Country'].append(country)
    country_counts['Ratio'].append(ratio)
gdf = gpd.GeoDataFrame(country_counts)

# Read the world shapefile.
# NOTE: geopandas.datasets is deprecated (removed in geopandas 1.0); when upgrading,
# load a downloaded Natural Earth countries file with gpd.read_file instead.
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))

# Merge the world shapefile with the ratio data
merged = world.merge(gdf, left_on='name', right_on='Country', how='left')

# Countries without data get a ratio of zero (rendered white below)
merged['Ratio'] = merged['Ratio'].fillna(0)

# Create the choropleth map
fig, ax = plt.subplots(figsize=(12, 8))

# Colormap: white for the lowest bin, then the reversed autumn colors
cmap_colors = ['white'] + list(plt.cm.autumn_r(range(256)))
cmap = ListedColormap(cmap_colors)

# Plot the countries with ratios
merged.plot(column='Ratio', cmap=cmap, linewidth=0.25, ax=ax, edgecolor='black', legend=True)

# Remove tick markers, tick labels and axis labels
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_title('Ratio of GISAID Submissions of Mutations of Deep Classifier Peptides by Country')

# Save the map as PNG
plt.savefig('temp/map1.png', dpi=300, bbox_inches='tight')

plt.show()


In [18]:
# Bar chart of the per-country ratios held in `country_counts`, labelled like the other plots.
plt.figure(figsize=(12, 3))
plt.bar(country_counts['Country'], country_counts['Ratio'])
plt.xticks(rotation=90)
plt.xlabel('Country')
plt.ylabel('Ratio')
plt.title('Ratio of GISAID Submissions of Mutations of Deep Classifier Peptides by Country')
plt.show()

In [19]:
# Plot time-series plots of Deep Classifier mutations and global new daily cases

# Reference dates at which each named variant is marked on both plots (defined once).
variant_dates = [['2020-09-15', 'Alpha'], ['2020-12-15', 'Lambda'], ['2020-05-15', 'Beta'],
                 ['2020-11-15', 'Gamma'], ['2020-10-15', 'Delta'], ['2020-03-15', 'Epsilon']]
filtered_variant_dates = variant_dates  # same list; name kept for other cells


def _mark_variants(target_ax, y, linestyle, va):
    """Draw a dashed red vertical line plus a rotated label at each variant date."""
    for date_str, label in variant_dates:
        variant_date = datetime.strptime(date_str, '%Y-%m-%d')
        target_ax.axvline(x=variant_date, linestyle=linestyle, color='red', alpha=0.5)
        target_ax.text(variant_date, y, label, color='red', rotation=90, va=va)


# Submissions per date across all GISAID variant data
all_dates_count = Counter(var_df['date'].tolist())

# Submissions per date for the Deep-Classifier-mutation-filtered data
filtered_dates_count = Counter(filtered_diff_df['date'].tolist())

# Normalise each day's filtered count by that day's total submissions, giving the
# fraction of submissions that carry a Deep Classifier mutation. This replaces an
# unexplained hard-coded division by 27 (the unused `total_samples` variable
# suggested otherwise). Dates absent from the totals are skipped rather than
# raising KeyError.
normalised_date_counts = {
    date: count / all_dates_count[date]
    for date, count in filtered_dates_count.items()
    if all_dates_count.get(date)
}

# Chronological order ('%Y-%m-%d' strings sort chronologically)
sorted_dates = sorted(normalised_date_counts)
sorted_counts = [normalised_date_counts[date] for date in sorted_dates]
dates = [datetime.strptime(date, '%Y-%m-%d') for date in sorted_dates]

# Plot the time-series data
fig, ax = plt.subplots(figsize=(10, 6))
_mark_variants(ax, max(sorted_counts, default=1.0), linestyle='--', va='top')

ax.plot(dates, sorted_counts, label='Normalised Counts', color='orange')
ax.set_xlabel('Date')
ax.set_ylabel('Normalised Count')
ax.set_title('Time-Series Plot of Normalised Frequency of Submitted Deep Classifier Peptide Mutations')

# Monthly x-axis ticks formatted as dates
ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
ax.xaxis.set_major_locator(MonthLocator())
fig.autofmt_xdate(rotation=45)

plt.tight_layout()
ax.legend()
plt.savefig('temp/time_muts.png', dpi=300, bbox_inches='tight')
plt.show()

# Plot Global Daily New COVID-19 Cases Data
# Assumes a header-less CSV of rows `YYYY-MM-DD,<int cases>`.
world_dates, world_cases = [], []
with open('Mutation Analysis/world-covid-daily-cases.csv', 'r', newline='') as file:
    for row in csv.reader(file):
        world_dates.append(datetime.strptime(row[0], '%Y-%m-%d'))
        world_cases.append(int(row[1]))

fig, ax = plt.subplots(figsize=(11.5, 6))
ax.plot(world_dates, world_cases, label='New Daily Cases', color='black')
ax.xaxis.set_major_locator(plt.MaxNLocator(10))

# Variant labels sit at 15% of the y-range, computed after the data is plotted
_mark_variants(ax, ax.get_ylim()[1] * 0.15, linestyle='-.', va='bottom')

fig.autofmt_xdate()
ax.set_xlabel('Date')
ax.set_ylabel('Count')
ax.set_title('Global New Daily Cases by Time')
ax.legend()
plt.savefig('temp/time_world.png', dpi=300, bbox_inches='tight')
plt.show()


In [20]:
# Histogram of submission locations for deep-classifier mutations dated strictly
# between 2020-07-10 and 2020-09-01 (ISO date strings compare chronologically).
in_window = filtered_diff_df['date'].gt('2020-07-10') & filtered_diff_df['date'].lt('2020-09-01')
plt.hist(filtered_diff_df.loc[in_window, 'location']);