# Influenza Qnet Predictions 2022-2023
- Predicting dominant strain for the 2022-2023 flu season using Qnet

In [102]:
# basic imports
import os 
import numpy as np
import pandas as pd
import math
import warnings
warnings.filterwarnings('ignore')
import tqdm
from tqdm.notebook import trange, tqdm

# visualization
import seaborn as sns 
import matplotlib.pyplot as plt
%matplotlib inline

# other
from Bio import SeqIO
from collections import Counter
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.manifold import MDS
import Levenshtein as lev

# qnet
from quasinet.qnet import Qnet, qdistance, qdistance_matrix, membership_degree, save_qnet, load_qnet
from quasinet.qseqtools import list_trained_qnets, load_trained_qnet

## Data Sources
- NCBI: https://www.ncbi.nlm.nih.gov/labs/virus/vssi/#/virus?SeqType_s=Protein
- GISAID: https://platform.epicov.org/epi3/cfrontend#586f5f

## Downloading Data
**GISAID (NCBI has few strains for this season): For creating 2021-2022 Qnets:**
1. Download amino acid data from GISAID with the following filters:
    - Host: Human
    - Flu Season: 
        - Northern strains from 10/01/2021 to 05/01/2022
        - Southern strains from 04/01/2021 to 10/01/2021
        - Flu season dates from [CDC](https://www.cdc.gov/flu/school-business/travelersfacts.htm)
    - Segment: HA (4) and NA (6)
2. File names for raw data: HEMISPHERE_SEQUENCE_SEGMENT_SEASON
    - HEMISPHERE: "north" or "south"
    - SEQUENCE: "h1n1" or "h3n2"
    - SEGMENT: "ha" or "na"
    - SEASON: year the season begins in (ex. 21 for 2021-2022)
    
**NCBI (fewer duplicates): For finding centroid (need all strains):**
1. Download amino acid data from NCBI with the following filters:
    - H1N1 HA: [Link](https://www.ncbi.nlm.nih.gov/labs/virus/vssi/#/virus?SeqType_s=Protein&VirusLineage_ss=H1N1%20subtype,%20taxid:114727&HostLineage_ss=Homo%20sapiens%20(human),%20taxid:9606&ProtNames_ss=hemagglutinin&LabHost_s=include&SLen_i=550%20TO%20600&QualNum_i=0&CollectionDate_dr=2000-01-01T00:00:00.00Z%20TO%202022-05-01T23:59:59.00Z)
    - H1N1 NA: [Link](https://www.ncbi.nlm.nih.gov/labs/virus/vssi/#/virus?SeqType_s=Protein&VirusLineage_ss=H1N1%20subtype,%20taxid:114727&HostLineage_ss=Homo%20sapiens%20(human),%20taxid:9606&LabHost_s=include&QualNum_i=0&CollectionDate_dr=2000-01-01T00:00:00.00Z%20TO%202022-05-01T23:59:59.00Z&SLen_i=450%20TO%20500&ProtNames_ss=neuraminidase)
    - H3N2 HA: [Link](https://www.ncbi.nlm.nih.gov/labs/virus/vssi/#/virus?SeqType_s=Protein&HostLineage_ss=Homo%20sapiens%20(human),%20taxid:9606&LabHost_s=include&QualNum_i=0&CollectionDate_dr=2000-01-01T00:00:00.00Z%20TO%202022-05-01T23:59:59.00Z&VirusLineage_ss=H3N2%20subtype,%20taxid:119210&SLen_i=550%20TO%20650&ProtNames_ss=hemagglutinin)
    - H3N2 NA: [Link](https://www.ncbi.nlm.nih.gov/labs/virus/vssi/#/virus?SeqType_s=Protein&HostLineage_ss=Homo%20sapiens%20(human),%20taxid:9606&LabHost_s=include&QualNum_i=0&CollectionDate_dr=2000-01-01T00:00:00.00Z%20TO%202022-05-01T23:59:59.00Z&SLen_i=450%20TO%20500&ProtNames_ss=neuraminidase&VirusLineage_ss=H3N2%20subtype,%20taxid:119210)
    
2. File names for raw data: SEQUENCE_SEGMENT
    - SEQUENCE: "h1n1" or "h3n2"
    - SEGMENT: "ha" or "na"

In [103]:
# folders holding the raw FASTA downloads from each source
NCBI_PATH = 'raw_data/ncbi/'
GISAID_PATH = 'raw_data/gisaid/'

# raw GISAID file names follow HEMISPHERE_SEQUENCE_SEGMENT_SEASON; all northern files
# come first, followed by the southern files in the same subtype/segment order
FILES = ['_'.join((hemisphere, subtype, segment, '21'))
         for hemisphere in ('north', 'south')
         for subtype in ('h1n1', 'h3n2')
         for segment in ('ha', 'na')]

# the three-cluster predictions are only produced for the NA files
FILES_3CLUSTER = [file_name for file_name in FILES if '_na_' in file_name]

# number of leading amino acids kept per sequence for each segment
NA_TRUNC = 469
HA_TRUNC = 566

## Creating New Qnet
- FASTA Header: Isolate name | Type | Segment | Collection date
- Truncate NA at 469 amino acids, HA at 566 amino acids
- Create Qnet from previous season and hemisphere

In [105]:
def parse_fasta(file_name, trunc):
    """Read a FASTA file into a dataframe of fixed-length sequences.

    Records shorter than `trunc` are dropped; the rest are cut to their first
    `trunc` characters and upper-cased.

    input: fasta file name, length to truncate each sequence
    output: dataframe with columns 'name' (text before the first '|' of the
            record id) and 'sequence' (numpy array built from the truncated sequence)
    """
    names, sequences = [], []
    for rec in SeqIO.parse(file_name, 'fasta'):
        if len(rec.seq) >= trunc:
            names.append(rec.id.split('|')[0])
            sequences.append(np.array(rec.seq[:trunc].upper()))
    return pd.DataFrame({'name':names, 'sequence':sequences})


def parse_fasta_withdate(file_name, trunc):
    """Read a FASTA file into a dataframe of fixed-length sequences plus a year column.

    Records shorter than `trunc` are dropped; the rest are cut to their first
    `trunc` characters and upper-cased. Unlike `parse_fasta`, the full record id
    is kept as the name.

    NOTE(review): the year is taken from the first four characters of the third
    '|'-separated field of the record description; this assumes a header layout
    where that field starts with a four-digit year -- confirm against the NCBI export.

    input: fasta file name, length to truncate each sequence
    output: dataframe with columns 'name', 'sequence' and 'year'
    """
    names, sequences, years = [], [], []
    for rec in SeqIO.parse(file_name, 'fasta'):
        if len(rec.seq) >= trunc:
            names.append(rec.id)
            date_field = rec.description.split('|')[2]
            years.append(int(date_field[:4]))
            sequences.append(np.array(rec.seq[:trunc].upper()))
    return pd.DataFrame({'name':names, 'sequence':sequences, 'year':years})


def sequence_array(seq_df, sample_size):
    """Draw a reproducible random sample of sequences and stack them into one array.

    input: dataframe with a 'sequence' column, number of sequences to sample
           (sampling is without replacement with a fixed seed of 42)
    output: numpy array with one row per sampled amino acid sequence
    """
    sampled = seq_df['sequence'].sample(sample_size, random_state = 42)
    return np.array(list(sampled.values))


def train_save_qnet(name, seq_arr, num_nuc):
    """Fit a Qnet on an array of sequences and save it under qnet_models/.

    input: name to call qnet, array of amino acid sequences (one row per sequence),
           number of positions per sequence (features are named 'x0', 'x1', ...)
    output: None; the fitted qnet is saved as 'qnet_models/<name>.joblib'
    """
    myqnet = Qnet(feature_names=['x'+str(i) for i in np.arange(num_nuc)],n_jobs=1)
    myqnet.fit(seq_arr)

    model_path = 'qnet_models/' + name + '.joblib'
    # make sure the output folder exists so a fresh checkout can save its first model
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    save_qnet(myqnet, model_path)

In [None]:
# train and save one qnet per raw GISAID file
for FILE in tqdm(FILES):
    # NA files are truncated to NA_TRUNC residues, all other files to HA_TRUNC
    TRUNC = NA_TRUNC if 'na' in FILE else HA_TRUNC
    seq_df = parse_fasta(GISAID_PATH + FILE + ".fasta", TRUNC)
    # train on at most 1000 sequences
    n_train = min(1000, len(seq_df))
    seq_arr = sequence_array(seq_df, n_train)
    train_save_qnet(FILE, seq_arr, TRUNC)

## Loading Past Qnet
- Show possible Qnets with `list_trained_qnets()`

In [107]:
def load_influenza_qnet(virus, protein, year):
    """Load a pre-trained influenza qnet and attach positional feature names.

    input: virus, protein, year (joined as 'virus;protein;year' to select the model)
    output: qnet whose feature names are 'x0', 'x1', ... up to the truncation
            length for the protein (NA_TRUNC for 'na', HA_TRUNC otherwise)
    """
    model_key = ';'.join([virus, protein, str(year)])
    myqnet = load_trained_qnet('influenza', model_key)

    # add feature names
    n_features = NA_TRUNC if protein == 'na' else HA_TRUNC
    myqnet.feature_names = ['x'+str(i) for i in np.arange(n_features)]
    return myqnet


def make_qnet_dict(years, virus, protein):
    """Load the pre-trained qnet for every requested year.

    input: list of available years, virus, protein
    output: dict mapping each year to its loaded qnet
    """
    return {year: load_influenza_qnet(virus, protein, year) for year in years}

In [305]:
# years for which a pre-trained qnet is loaded, per subtype and protein
H1N1_HA_YEARS = [2000, 2001, 2003, 2005, 2006, 2007, 2008, 2009, 2010,
                 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019]
H1N1_NA_YEARS = [2000, 2001, 2003, 2005, 2006, 2007, 2008, 2009, 2010,
                 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019]
H3N2_HA_YEARS = [2004, 2005, 2007, 2008, 2009, 2010, 2011, 2012,
                 2013, 2014, 2015, 2016, 2017, 2018, 2019]
H3N2_NA_YEARS = [2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
                 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019]

# dict of qnets and years
# qnet_dict['h1n1_na']['qnets'][year] accesses a qnet
# qnet_dict['h1n1_na']['years'] accesses list of years available
qnet_dict = {
    virus + '_' + protein: {'qnets':make_qnet_dict(years, virus, protein), 'years':years}
    for virus, protein, years in (
        ('h1n1', 'ha', H1N1_HA_YEARS),
        ('h1n1', 'na', H1N1_NA_YEARS),
        ('h3n2', 'ha', H3N2_HA_YEARS),
        ('h3n2', 'na', H3N2_NA_YEARS),
    )
}

## Create Dictionary of Sequence Data and Distance Matrices
- `seq_dict[FILE]` (one entry per name in `FILES`) holds the dataframe of sequences for that strain under `data` and the corresponding distance matrix under `dist_matrix`
- If combined north and south data for that strain exceeds 1000, randomly sample 1000 sequences

In [108]:
seq_dict = {}

# FILES lists the northern files first and then the southern files in the same
# subtype/segment order, so the opposite-hemisphere partner of a file sits half a
# list away (derived from len(FILES) rather than hard-coded as 4 and 8)
HALF = len(FILES) // 2

for i in trange(len(FILES)):
    FILE = FILES[i]
    FILE1 = FILES[(i + HALF) % len(FILES)]

    # load this years qnet
    myqnet = load_qnet('qnet_models/' + FILE + '.joblib')

    # adjust trunc
    TRUNC = NA_TRUNC if 'na' in FILE else HA_TRUNC

    # load gisaid data (P^t - population at time t) from both hemispheres
    gisaid_df1 = parse_fasta(GISAID_PATH + FILE + ".fasta", TRUNC)
    gisaid_df2 = parse_fasta(GISAID_PATH + FILE1 + ".fasta", TRUNC)
    combined_df = pd.concat([gisaid_df1, gisaid_df2])

    # keep at most 1000 sequences; the index is reset afterwards so that row labels
    # equal row positions, matching the rows/columns of the distance matrix below
    gisaid_df = (combined_df
                 .sample(min(1000, len(combined_df)), random_state = 42)
                 .reset_index(drop=True))

    # make sequence matrix with gisaid data
    seqs_matrix = np.array(list(gisaid_df['sequence'].values))
    dist_matrix = qdistance_matrix(seqs_matrix, seqs_matrix, myqnet, myqnet)

    # save to dict
    seq_dict[FILE] = {'data':gisaid_df, 'dist_matrix':dist_matrix}

  0%|          | 0/8 [00:00<?, ?it/s]

## Predictions
Q-Centroid: $$\widehat{x}^{t+1} = \operatorname*{arg\,min}_{x\in P} \sum_{y \in P^t} \theta(x,y)$$
- Where $P^t$ is the sequence population at time $t$ and $P = P^t \cup P^{t-1} \cup P^{t-2} \cup \dots \cup P^1$.
- $\theta(x,y)$ is the qdistance between x and y in their respective Qnets

In [138]:
rec_files = []
rec_names = []
rec_seqs = []

# set to True to include sequences from P^t-1 U P^t-2 U ... U P^1 - previous populations
PAST = False

for FILE in tqdm(FILES):
    NAME = FILE[6:13] # ex. 'h1n1_ha'

    # gisaid data (P^t - population at time t), as stored in seq_dict
    gisaid_df = seq_dict[FILE]['data']

    # one candidate centroid per population considered, with its summed qdistance to P^t
    cur_rec_names = []
    cur_rec_seqs = []
    cur_rec_qdist_sums = []

    # loop through available past data
    if PAST:
        # the season qnet, truncation length and current sequence matrix are only
        # needed to compare past populations against P^t, so they are prepared here
        # instead of on every iteration
        myqnet = load_qnet('qnet_models/' + FILE + '.joblib')
        TRUNC = NA_TRUNC if 'na' in FILE else HA_TRUNC
        cur_seqs_matrix = np.array(list(gisaid_df['sequence'].values))

        # load ncbi data (P^t-1 U P^t-2 U ... U P^1 - previous populations)
        ncbi_df = parse_fasta_withdate(NCBI_PATH + NAME + ".fasta", TRUNC)
        # loops through years with qnet available
        for yr in tqdm(qnet_dict[NAME]['years']):
            # filter ncbi df by year and drop the year column
            df = ncbi_df[ncbi_df['year'] == yr].drop(columns = 'year')
            if len(df) == 0:
                continue
            seq_df = df.sample(min(1000, len(df)), random_state = 42)

            # compute qdistance matrix (rows: past sequences, columns: P^t)
            past_seqs_matrix = np.array(list(seq_df['sequence'].values))
            dist_matrix = qdistance_matrix(past_seqs_matrix, cur_seqs_matrix, qnet_dict[NAME]['qnets'][yr], myqnet)

            # compute q-centroid using formula: the past sequence with the smallest
            # summed distance to P^t
            sums = dist_matrix.sum(axis=1)
            cur_min_ind = int(np.argmin(sums))
            best_row = seq_df.iloc[cur_min_ind]

            # save to current results; the score is read at the chosen index so it
            # always belongs to the recorded sequence
            cur_rec_names.append(best_row.values[0])
            cur_rec_seqs.append(best_row.values[1])
            cur_rec_qdist_sums.append(sums[cur_min_ind])

    # find centroid of sequences in P^t
    dist_matrix = seq_dict[FILE]['dist_matrix']
    sums = dist_matrix.sum(axis=1)
    cur_min_ind = int(np.argmin(sums))
    best_row = gisaid_df.iloc[cur_min_ind]
    cur_rec_names.append(best_row.values[0])
    cur_rec_seqs.append(best_row.values[1])
    cur_rec_qdist_sums.append(sums[cur_min_ind])

    # find centroid among current results
    min_ind = np.argmin(cur_rec_qdist_sums)
    rec_name = cur_rec_names[min_ind]
    rec_seq = cur_rec_seqs[min_ind]

    # save results
    rec_files.append(FILE[:13])
    rec_names.append(rec_name)
    rec_seqs.append(rec_seq)

  0%|          | 0/8 [00:00<?, ?it/s]

In [139]:
# collapse each recommended sequence (an array of single residues) into one string
rec_seqs = [''.join(seq) for seq in rec_seqs]

prediction_columns = {'strain':rec_files, 'name':rec_names, 'sequence':rec_seqs}
predictions = pd.DataFrame(prediction_columns)
predictions

Unnamed: 0,strain,name,sequence
0,north_h1n1_ha,A/Netherlands/00068/2022,MKAILVVLLYTFTTANADTLCIGYHANNSTDTVDTVLEKNVTVTHS...
1,north_h1n1_na,A/Lyon/820/2021,MNPNQKIITIGSICMAIGTANLILQIGNIISIWVSHSIQIGNQSQI...
2,north_h3n2_ha,A/Denmark/370/2022,MKTIIALSNILCLVFAQKIPGNDNSTATLCLGHHAVPNGTIVKTIT...
3,north_h3n2_na,A/Michigan/UOM10042819294/2021,MNPNQKIITIGSVSLTISTICFFMQIAILITTVTLHFKQYEFNSPP...
4,south_h1n1_ha,A/Cote_D'Ivoire/1270/2021,MKAILVVLLYTFTTANADTLCIGYHANNSTDTVDTVLEKNVTVTHS...
5,south_h1n1_na,A/Dakar/35/2021,MNPNQKIITIGSICMAIGTANLILQIGNIISIWVSHSIQIGNQSQI...
6,south_h3n2_ha,A/Saint-Martin/00754/2022,MKTIIALSNILCLVFAQKIPGNDNSTATLCLGHHAVPNGTIVKTIT...
7,south_h3n2_na,A/Texas/12723/2022,MNPNQKIITIGSVSLTISTICFFMQIAILITTVTLHFKQYEFNSPP...


In [140]:
# write the single-centroid predictions to disk, creating the output folder if needed
predictions_dir = 'results'
os.makedirs(predictions_dir, exist_ok=True)
predictions.to_csv('results/influenza_qnet_predictions_2022_2023.csv', index=False)

## Multi-Cluster Predictions
- Compute distance matrix between sequences in $P^t$
- Create three clusters, then find the dominant strain of each cluster

In [141]:
def multiple_cluster_predictions(seq_df, qdist_matrix, n_clusters = 3):
    """Cluster a population by qdistance and pick one representative per cluster.

    The distance matrix is embedded in two dimensions with MDS, the embedding is
    partitioned with KMeans, and within each cluster the sequence with the
    smallest median distance to the other cluster members is selected.

    NOTE(review): the representative minimises the *median* within-cluster
    distance, whereas the single-centroid cell minimises the *sum* of distances
    -- confirm the difference is intended.

    input: dataframe of sequences (name in the first column, sequence in the
           second; rows in the same order as the matrix), qdistance matrix,
           number of clusters
    output: recommended names, corresponding sequences (joined into strings),
            one entry per cluster label found
    """
    # wrap the matrix in a dataframe labelled 0..n-1 on both axes
    dm = pd.DataFrame(qdist_matrix,
                      columns=np.arange(0, qdist_matrix.shape[1]),
                      index=np.arange(0, qdist_matrix.shape[0]))

    # 2-D embedding of the precomputed distances, then cluster the embedded points
    coords = MDS(n_components=2, dissimilarity="precomputed", random_state=42).fit_transform(dm)
    labels = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(coords)

    rec_names, rec_seqs = [], []
    for label in np.unique(labels):
        # distances restricted to the members of this cluster
        members = dm.columns[labels == label]
        within = dm.loc[members, members]
        # representative: member with the smallest median distance to its cluster
        best = int(within.median(axis=1).idxmin())
        chosen = seq_df.iloc[best]
        rec_names.append(chosen.values[0])
        rec_seqs.append(''.join(chosen.values[1]))

    return rec_names, rec_seqs

In [142]:
# per-file results for the three-cluster predictions; these names are read by the
# cell that builds predictions_3cluster
rec_files = []
rec_names_0 = []
rec_seqs_0 = []
rec_names_1 = []
rec_seqs_1 = []
rec_names_2 = []
rec_seqs_2 = []

for FILE in tqdm(FILES_3CLUSTER):
    # find centroid for each of 3 clusters; the results are unpacked into loop-local
    # names so the single-centroid rec_names / rec_seqs lists are not overwritten
    cluster_names, cluster_seqs = multiple_cluster_predictions(seq_dict[FILE]['data'], seq_dict[FILE]['dist_matrix'], 3)
    rec_files.append(FILE[:13])
    rec_names_0.append(cluster_names[0])
    rec_seqs_0.append(cluster_seqs[0])
    rec_names_1.append(cluster_names[1])
    rec_seqs_1.append(cluster_seqs[1])
    rec_names_2.append(cluster_names[2])
    rec_seqs_2.append(cluster_seqs[2])

  0%|          | 0/4 [00:00<?, ?it/s]

In [143]:
# one row per file, with a (name, sequence) column pair for each of the three clusters
cluster_columns = {
    'strain':rec_files,
    'name 0':rec_names_0, 'sequence 0':rec_seqs_0,
    'name 1':rec_names_1, 'sequence 1':rec_seqs_1,
    'name 2':rec_names_2, 'sequence 2':rec_seqs_2,
}
predictions_3cluster = pd.DataFrame(cluster_columns)
predictions_3cluster

Unnamed: 0,strain,name 0,sequence 0,name 1,sequence 1,name 2,sequence 2
0,north_h1n1_na,A/Netherlands/10646/2022,MNPNQKIITIGSICMAIGTANLILQIGNTISIWVSHSIQIGNQSQI...,A/Sydney/234/2022,MNPNQKIITIGSICMTIGTANLILQIGNMISIWVSHSIQIGNQSQI...,A/Wisconsin/03/2021,MNTNQRIITIGTVCLIVGIISLLLQIGNIVSLWVSHSIQTRWENHT...
1,north_h3n2_na,A/Maine/02/2022,MNPNQKIITIGSVSLTISTICFFMQIAILITTVTLHFKQYEFNSPP...,A/Michigan/UOM10042819294/2021,MNPNQKIITIGSVSLTISTICFFMQIAILITTVTLHFKQYEFNSPP...,A/Netherlands/10082/2022,MNPNQKIITIGSVSLTISTICFFMQIAILITTVTLHFKQYEFNSPP...
2,south_h1n1_na,A/Switzerland/86136/2022,MNPNQKIITIGSICMAIGTANLILQIGNIISIWVSHSIQIGNQSQI...,A/Wisconsin/04/2021,MNTNQRIITIGTVCLIVGIISLLLQIGNIVSLWVSHSIQTKWENHT...,A/Wisconsin/05/2021,MNTNQRIITIGTVCLIVGIISLLLQIGNIVSLWVSHSIQTKWENHT...
3,south_h3n2_na,A/Congo/313/2021,MNPNQKIITIGSVSLTISTICFFMQIAILITTVTLHFKQYEFNPPP...,A/Texas/12723/2022,MNPNQKIITIGSVSLTISTICFFMQIAILITTVTLHFKQYEFNSPP...,A/Netherlands/00037/2022,MNPNQKIITIGSVSLTISTICFLMQIAILITTVTLHFKQYEFNSPX...


In [144]:
# write the three-cluster predictions to disk, creating the output folder if needed
predictions_3cluster_dir = 'results'
os.makedirs(predictions_3cluster_dir, exist_ok=True)
predictions_3cluster.to_csv('results/influenza_qnet_predictions_3cluster_2022_2023.csv', index=False)