# IRAT vs. Enet Comparison
- **Note: Enet and Qnet are interchangeable; Qnet was the old name**
- Compares the IRAT risk assessment with a risk assessment based on q-distance
- Use both NA and HA segments
- For each strain previously analyzed by IRAT
    - Collect strains one year leading up to month of analysis
    - For example, the "A/swine/Shandong/1207/2016" strain was assessed by IRAT in July 2020, so we use human H1N1 strains circulating from July 1, 2019 through June 30, 2020
    - Note: only the NA sequence of 'A/duck/New York/1996' could be found, so do not use it in the final results
    - For the following strains, use only the upper bound of the date range because of the small sample size
        - H1N2, H5N1, H5N6, H7N7, H9N2
    - Strains with 'Qnet Sample' = -1 have no available human strains
    - Construct an Enet using these strains **if there are at least 30 strains in the population for both NA and HA**
    - Compute the average q-distance between the strain in question and the circulating human strains for both NA and HA
        - Do this 10 times for each strain, with up to 100 samples from the human population each time (capped at half the population)
        - This is to compute variances
    - Average the NA and HA averages (using arithmetic and geometric mean)
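
The one-year collection window described above can be sketched as follows. This is a minimal illustration, not the notebook's own code: `collection_window` is a hypothetical helper, and it assumes the assessment date is written in the `Jul 2020` style used in the 'Dates of Risk Assessment' column.

```python
from datetime import date, datetime, timedelta

def collection_window(assessment_month):
    """Return (start, end) dates covering the 12 months before the assessment month.

    Hypothetical helper: assumes a 'Mon YYYY' string such as 'Jul 2020'.
    """
    # first day of the assessment month, e.g. 'Jul 2020' -> 2020-07-01
    first_of_month = datetime.strptime(assessment_month, '%b %Y').date()
    # window opens one year earlier, on the first day of the same month
    start = first_of_month.replace(year=first_of_month.year - 1)
    # window closes on the last day before the assessment month
    end = first_of_month - timedelta(days=1)
    return start, end

start, end = collection_window('Jul 2020')
print(start, end)
```

For the strains that use only the upper bound, the same `end` value would apply and `start` would simply be ignored.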
    
### Filling in the Table
- To get a risk prediction score for each IRAT strain with **fewer than 30 strains in the population for either NA or HA**
    - Use all human strains that match the H number, i.e. H5NX for H5N6
- List of remaining strains
- For the following strains, only use upper bound of date due to small sample size
    - H5N2, H5N6, H5N8
- To be done
    - A/duck/New York/1996 (missing target strain HA)
    - A/Jiangxi-Donghu/346/2013 (only 5 H10N8 human strains ever recorded, 3 of them are the IRAT target strain itself)
    - A/Bangladesh/0994/2011 (only 12 H9N2 human strains up to Feb 2014)
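
The fallback rule above (widen from the exact subtype to every strain sharing the H number) might look like the sketch below. The `(name, subtype)` tuples and the `select_population` function are illustrative assumptions; the notebook itself works from pre-downloaded GISAID FASTA files.

```python
MIN_STRAINS = 30  # population threshold used throughout this notebook

def select_population(strains, subtype, min_strains=MIN_STRAINS):
    """Pick strains of the exact subtype, widening to HxNX when there are too few.

    `strains` is assumed to be a list of (name, subtype) tuples (hypothetical layout).
    """
    exact = [s for s in strains if s[1] == subtype]
    if len(exact) >= min_strains:
        return exact
    # 'H5N6' -> 'H5N', so that 'H1N' does not accidentally match 'H10N8'
    h_prefix = subtype.split('N')[0] + 'N'
    return [s for s in strains if s[1].startswith(h_prefix)]

# toy population: 10 H5N6, 25 H5N1, 40 H1N1, 5 H10N8
toy = ([('s', 'H5N6')] * 10 + [('s', 'H5N1')] * 25
       + [('s', 'H1N1')] * 40 + [('s', 'H10N8')] * 5)
print(len(select_population(toy, 'H5N6')))  # H5N6 alone is too small, so all H5NX are used
print(len(select_population(toy, 'H1N1')))  # enough exact matches, so no widening
```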

In [1]:
# basic imports
import os 
import numpy as np
import pandas as pd
from scipy.stats import gmean
import warnings
warnings.filterwarnings('ignore')

# other
from Bio import SeqIO
from collections import Counter
import Levenshtein as lev
from tqdm.notebook import trange, tqdm

# enet
from quasinet.qnet import Qnet, qdistance, qdistance_matrix, membership_degree, save_qnet, load_qnet
from quasinet.qseqtools import list_trained_qnets, load_trained_qnet
from zedstat.textable import textable
from emergenet import Enet, save_model, load_model

## Data Sources
- IRAT (CDC): https://www.cdc.gov/flu/pandemic-resources/monitoring/irat-virus-summaries.htm#H1N2variant
- GISAID: https://platform.epicov.org/epi3/cfrontend#586f5f

In [41]:
GISAID_PATH = 'raw_data/gisaid/'
IRAT_PATH = 'results/'
ENET_PATH = 'enet_models/irat_enets/'

NA_TRUNC = 449
HA_TRUNC = 550

df = pd.read_csv(IRAT_PATH + 'irat_data.csv')

## Creating New Enet
- Truncate NA at 449 amino acids, HA at 550 amino acids (to prevent too many strains from being filtered out)
- Construct an Enet using these strains **if there are at least 30 strains in the population** after truncating to match the length of the IRAT strain
- Make sure to include IRAT sequence in training the Enet
- Save Enet as "VIRUS_NAME_na.joblib" or "VIRUS_NAME_ha.joblib"

In [42]:
# input: fasta file name, length to truncate each sequence
# output: dataframe of sequences
def parse_fasta(file_name, trunc):
    acc = []
    seq = []
    for record in SeqIO.parse(file_name, 'fasta'):
        if len(record.seq) < trunc:
            continue
        acc.append(record.id.split('|')[0])
        seq.append(np.array(record.seq[:trunc].upper()))
    df = pd.DataFrame({'name':acc, 'sequence':seq})
    return df


# input: dataframe of sequences, number of samples, IRAT strain sequence
# output: array of amino-acid lists (sampled sequences plus the IRAT sequence)
def sequence_array(seq_df, sample_size, IRAT_strain, random_state = 42):
    seqs = seq_df['sequence'].sample(sample_size, random_state = random_state).values
    seq_lst = []
    for seq in seqs:
        seq_lst.append(seq)
    seq_lst.append(np.array(list(IRAT_strain)))
    return np.array(seq_lst)


# input: name to call enet, array of amino-acid lists, number of positions per sequence
# output: save enet as joblib
def train_save_enet(name, seq_arr, num_nuc):
    myenet = Qnet(feature_names=['x'+str(i) for i in np.arange(num_nuc)],n_jobs=1)
    myenet.fit(seq_arr)
    save_qnet(myenet, ENET_PATH + name + '.joblib')

In [None]:
for i in trange(len(df)):
    STRAIN = df['Influenza Virus'].iloc[i].replace('/',':')
    ha_irat_seq = df['HA Sequence'].iloc[i][:HA_TRUNC]
    na_irat_seq = df['NA Sequence'].iloc[i][:NA_TRUNC]
    
    # skip A/duck/New York/1996, A/Jiangxi-Donghu/346/2013, A/Bangladesh/0994/2011
    if i == 1 or i == 20 or i == 22:
        continue
    
    ha_df = parse_fasta(GISAID_PATH + STRAIN + "_ha.fasta", HA_TRUNC)
    na_df = parse_fasta(GISAID_PATH + STRAIN + "_na.fasta", NA_TRUNC)
    
    # skip if less than 30 sequences available or enet already exists 
    if len(ha_df) < 30 or os.path.exists(ENET_PATH + STRAIN + '_ha.joblib'):
        continue
    if len(na_df) < 30 or os.path.exists(ENET_PATH + STRAIN + '_na.joblib'):
        continue
    
    ha_arr = sequence_array(ha_df, min(1000, len(ha_df)), ha_irat_seq)
    na_arr = sequence_array(na_df, min(1000, len(na_df)), na_irat_seq)
    
    train_save_enet(STRAIN + '_ha', ha_arr, HA_TRUNC)
    train_save_enet(STRAIN + '_na', na_arr, NA_TRUNC)

## Average Qdistance
- Compute average qdistance between IRAT strain and the rest of the strains
- Do this 10 times for each strain, with up to 100 samples from the human population each time (capped at half the population)
- This is to compute variances
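
The repeated-sampling idea can be illustrated without a trained Enet. In the sketch below, `hamming_fraction` is a stand-in distance chosen only so the example runs on its own; it is not the q-distance, and `repeated_mean_distance` is a hypothetical helper. The seeds 42 to 51 and the cap at half the population mirror the cell that follows.

```python
import random
import statistics

def hamming_fraction(a, b):
    """Stand-in distance (NOT q-distance): fraction of differing positions."""
    return sum(x != y for x, y in zip(a, b)) / len(a)

def repeated_mean_distance(target, population, distance,
                           n_repeats=10, sample_size=100, first_seed=42):
    """Mean distance from `target` to a random sample, repeated with different seeds."""
    means = []
    for seed in range(first_seed, first_seed + n_repeats):
        rng = random.Random(seed)
        # draw at most `sample_size` sequences, capped at half the population
        k = min(sample_size, len(population) // 2)
        sample = rng.sample(population, k)
        means.append(statistics.fmean(distance(target, s) for s in sample))
    return means

population = ['MKTA', 'MKTI', 'MRTI', 'MKSA', 'MKTA', 'MRSA']
means = repeated_mean_distance('MKTA', population, hamming_fraction)
# the spread across the 10 repeats is what the variance columns summarize
print(statistics.fmean(means), statistics.pvariance(means))
```

`statistics.pvariance` is the population variance, which matches the default behaviour of `np.var` used later in the notebook.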

In [43]:
ha_sample = []
na_sample = []
avg_qdists_ha_10 = []
avg_qdists_na_10 = []
avg_qdists_both_10 = []
avg_qdists_geom_10 = []

for i in trange(len(df)):
    STRAIN = df['Influenza Virus'].iloc[i].replace('/',':')
    
    # skip if enet doesn't exist or one of the sequences doesn't exist
    if not os.path.exists(ENET_PATH + STRAIN + '_ha.joblib')\
    or not os.path.exists(ENET_PATH + STRAIN + '_na.joblib')\
    or df['HA Sequence'].iloc[i] == '-1' or df['NA Sequence'].iloc[i] == '-1':
        ha_sample.append(-1)
        na_sample.append(-1)
        avg_qdists_ha_10.append(-1)
        avg_qdists_na_10.append(-1)
        avg_qdists_both_10.append(-1)
        avg_qdists_geom_10.append(-1)
        continue
        
    # load enets
    ha_enet = load_qnet(ENET_PATH + STRAIN + '_ha.joblib')
    na_enet = load_qnet(ENET_PATH + STRAIN + '_na.joblib')
    
    # access irat sequences and all sequences
    ha_irat_seq = np.array(list(df['HA Sequence'].iloc[i][:HA_TRUNC]))
    na_irat_seq = np.array(list(df['NA Sequence'].iloc[i][:NA_TRUNC]))
    ha_df = parse_fasta(GISAID_PATH + STRAIN + "_ha.fasta", HA_TRUNC)
    na_df = parse_fasta(GISAID_PATH + STRAIN + "_na.fasta", NA_TRUNC)
    
    avg_qdists_ha = []
    avg_qdists_na = []
    avg_qdists_both = []
    avg_qdists_geom = []
    # repeat 10 times for variance computation
    for j in range(42, 52):
        ha_arr = sequence_array(ha_df, min(100, len(ha_df)//2), ha_irat_seq, random_state=j)
        na_arr = sequence_array(na_df, min(100, len(na_df)//2), na_irat_seq, random_state=j)
        # compute qdistance sum
        ha_qdist_sum = 0
        na_qdist_sum = 0
        num_ha = len(ha_arr)
        num_na = len(na_arr)
        for k in range(len(ha_arr)):
            qdist = qdistance(ha_irat_seq, ha_arr[k], ha_enet, ha_enet)
            if np.isnan(qdist):
                num_ha -= 1
                continue
            ha_qdist_sum += qdist
        for k in range(len(na_arr)):
            qdist = qdistance(na_irat_seq, na_arr[k], na_enet, na_enet)
            if np.isnan(qdist):
                num_na -= 1
                continue
            na_qdist_sum += qdist
        # compute qdistance averages
        avg_qdists_ha.append(ha_qdist_sum/num_ha)
        avg_qdists_na.append(na_qdist_sum/num_na)
        avg_qdists_both.append((ha_qdist_sum + na_qdist_sum)/(num_ha + num_na))
        avg_qdists_geom.append(np.sqrt((ha_qdist_sum/num_ha) * (na_qdist_sum/num_na))) 
    
    ha_sample.append(len(ha_df))
    na_sample.append(len(na_df))
    avg_qdists_ha_10.append(avg_qdists_ha)
    avg_qdists_na_10.append(avg_qdists_na)
    avg_qdists_both_10.append(avg_qdists_both) 
    avg_qdists_geom_10.append(avg_qdists_geom)

df['HA Qnet Sample'] = ha_sample
df['NA Qnet Sample'] = na_sample
df['HA Qdistance'] = avg_qdists_ha_10
df['NA Qdistance'] = avg_qdists_na_10
df['Arithmetic Mean'] = avg_qdists_both_10
df['Geometric Mean'] = avg_qdists_geom_10

# save dataframe as csv
os.makedirs('results', exist_ok=True)
df.to_csv('results/irat_average_qdistances.csv', index=False)  


## A/duck/New York/1996
- A/duck/New York/1996 is missing target strain HA
- Cannot compute score

## A/Bangladesh/0994/2011 & A/Jiangxi-Donghu/346/2013
- Too few human strains exist to train a dedicated Enet for these two strains, so compute the risk score with every Enet trained for the other strains and collect the resulting NA and HA averages
- Take the geometric mean of the resulting NA and HA averages
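
As a small worked example of combining the two segment averages, the sketch below computes both means for one pair of values. `combine_ha_na` is a hypothetical helper; the inputs are the averaged HA and NA q-distances reported for A/swine/Shandong/1207/2016 in this notebook's results table.

```python
import math

def combine_ha_na(avg_ha, avg_na):
    """Return (arithmetic mean, geometric mean) of two non-negative averages."""
    arithmetic = (avg_ha + avg_na) / 2
    geometric = math.sqrt(avg_ha * avg_na)
    return arithmetic, geometric

# averaged HA and NA q-distances for A/swine/Shandong/1207/2016
arithmetic, geometric = combine_ha_na(0.09328, 0.020273)
print(f'arithmetic={arithmetic:.6f}, geometric={geometric:.6f}')
```

The geometric mean is never larger than the arithmetic mean and is pulled toward the smaller of the two values, so a strain needs both segments to be far from the human population to score high. Note that the notebook's 'Arithmetic Mean' column pools the HA and NA distance sums before dividing, which equals the simple mean of the two averages only when both segments contribute the same number of valid distances.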

In [44]:
df = pd.read_csv(IRAT_PATH + 'irat_data.csv')
df_filled = pd.read_csv('results/irat_average_qdistances.csv')

for n in tqdm([20, 22]):
    ha_irat_seq = np.array(list(df['HA Sequence'].iloc[n][:HA_TRUNC]))
    na_irat_seq = np.array(list(df['NA Sequence'].iloc[n][:NA_TRUNC]))
    
    ha_risk = []
    na_risk = []
    both_risk = []
    geom_mean = []

    for i in range(len(df)):
        STRAIN = df['Influenza Virus'].iloc[i].replace('/',':')

        # skip if enet doesn't exist
        if not os.path.exists(ENET_PATH + STRAIN + '_ha.joblib') or not os.path.exists(ENET_PATH + STRAIN + '_na.joblib'):
            continue
        # skip duck enet
        if STRAIN == 'A:duck:New York:1996':
            continue

        # load enets
        ha_enet = load_qnet(ENET_PATH + STRAIN + '_ha.joblib')
        na_enet = load_qnet(ENET_PATH + STRAIN + '_na.joblib')

        # access irat sequences and all sequences
        ha_df = parse_fasta(GISAID_PATH + STRAIN + "_ha.fasta", HA_TRUNC)
        na_df = parse_fasta(GISAID_PATH + STRAIN + "_na.fasta", NA_TRUNC)
        ha_arr = sequence_array(ha_df, min(1000, len(ha_df)), ha_irat_seq)
        na_arr = sequence_array(na_df, min(1000, len(na_df)), na_irat_seq)

        # compute qdistance sum
        ha_qdist_sum = 0
        na_qdist_sum = 0
        num_ha = len(ha_arr)
        num_na = len(na_arr)
        for j in range(len(ha_arr)):
            qdist = qdistance(ha_irat_seq, ha_arr[j], ha_enet, ha_enet)
            if np.isnan(qdist):
                num_ha -= 1
                continue
            ha_qdist_sum += qdist
        for j in range(len(na_arr)):
            qdist = qdistance(na_irat_seq, na_arr[j], na_enet, na_enet)
            if np.isnan(qdist):
                num_na -= 1
                continue
            na_qdist_sum += qdist

        # compute qdistance averages
        ha_risk.append(ha_qdist_sum/num_ha)
        na_risk.append(na_qdist_sum/num_na)
        both_risk.append((ha_qdist_sum + na_qdist_sum)/(num_ha + num_na))
        geom_mean.append(np.sqrt((ha_qdist_sum/num_ha) * (na_qdist_sum/num_na)))
        
    # save to results dataframe
    df_filled.at[n, 'HA Qdistance'] = ha_risk
    df_filled.at[n, 'NA Qdistance'] = na_risk
    df_filled.at[n, 'Arithmetic Mean'] = both_risk
    df_filled.at[n, 'Geometric Mean'] = geom_mean


In [45]:
df_filled.iloc[[20, 22]]

Unnamed: 0,Influenza Virus,Virus Type,Dates of Risk Assessment,Potential Emergence Estimate,Potential Impact Estimate,Summary Risk Score Category,HA Sequence,NA Sequence,HA Qnet Sample,NA Qnet Sample,HA Qdistance,NA Qdistance,Arithmetic Mean,Geometric Mean
20,A/Bangladesh/0994/2011,H9N2,Feb 2014,5.6,5.4,Moderate,METVSLMTILLLVTTSNADKICIGHQSTNSTETVDTLTETNVPVTH...,MNPNQKIIALGSASLTIAIICLLIQIAILATTMTLHFMQNEHTNST...,-1,-1,"[0.042555813282309105, 0.2860528703519202, 0.0...","[0.027112393537369644, 0.054320689076068074, 0...","[0.034834103409839376, 0.17018677971399415, 0....","[0.03396748381927615, 0.12465387691405136, 0.0..."
22,A/Jiangxi-Donghu/346/2013,H10N8,Feb 2014,4.3,6.0,Moderate,MYKIVVIIALLGAVKGLDKICLGHHAVANGTIVKTLTNEQEEVTNA...,MNPNQKIITIGSVSLGLVILNILLHIVSITVTVLVLPGNGNNESCN...,-1,-1,"[0.04238151681917192, 0.277615833158554, 0.032...","[0.01782870236153502, 0.0776012718465409, 0.03...","[0.03010510959035347, 0.17760855250254745, 0.0...","[0.02748831477190638, 0.14677650267614667, 0.0..."


In [46]:
# save dataframe as csv
df_filled.to_csv('results/irat_average_qdistances.csv', index=False)  

## Average Results

In [47]:
import ast

# each distance cell is stored as a stringified list (or -1), so parse it with
# ast.literal_eval instead of eval before aggregating
df_filled = pd.read_csv('results/irat_average_qdistances.csv')
df_filled['Avg. HA Qdistance'] = df_filled['HA Qdistance'].apply(ast.literal_eval).apply(np.mean)
df_filled['Var. HA Qdistance'] = df_filled['HA Qdistance'].apply(ast.literal_eval).apply(np.var)
df_filled['Avg. NA Qdistance'] = df_filled['NA Qdistance'].apply(ast.literal_eval).apply(np.mean)
df_filled['Var. NA Qdistance'] = df_filled['NA Qdistance'].apply(ast.literal_eval).apply(np.var)
df_filled['Avg. Arithmetric Mean'] = df_filled['Arithmetic Mean'].apply(ast.literal_eval).apply(np.mean)
df_filled['Var. Arithmetric Mean'] = df_filled['Arithmetic Mean'].apply(ast.literal_eval).apply(np.var)
df_filled['Avg. Geometric Mean'] = df_filled['Geometric Mean'].apply(ast.literal_eval).apply(np.mean)
df_filled['Var. Geometric Mean'] = df_filled['Geometric Mean'].apply(ast.literal_eval).apply(np.var)
df_filled = df_filled.sort_values(by='Potential Emergence Estimate', ascending=False)
df_filled.to_csv('results/irat_average_qdistances.csv', index=False)
df_filled

Unnamed: 0,Influenza Virus,Virus Type,Dates of Risk Assessment,Potential Emergence Estimate,Potential Impact Estimate,Summary Risk Score Category,HA Sequence,NA Sequence,HA Qnet Sample,NA Qnet Sample,...,Arithmetic Mean,Geometric Mean,Avg. HA Qdistance,Var. HA Qdistance,Avg. NA Qdistance,Var. NA Qdistance,Avg. Arithmetric Mean,Var. Arithmetric Mean,Avg. Geometric Mean,Var. Geometric Mean
0,A/swine/Shandong/1207/2016,H1N1,Jul 2020,7.5,6.9,Moderate,MEARLFVLFCAFTTLKADTICVGYHANNSTDTVDTILEKNVTVTHS...,MNPNQKIITIGSICMTIGIASLILQIGNIISIWISHSIQIENQNQS...,8583,8583,...,"[0.05675103606380652, 0.05683624423196644, 0.0...","[0.04351973803384621, 0.043379545858900644, 0....",0.09328,1.305515e-08,0.020273,1.173676e-08,0.056777,3.112539e-09,0.043487,1.112853e-08
3,A/Ohio/13/2017,H3N2,Jul 2019,6.6,5.8,Moderate,MKTIIALSHILCLVFAQKLPGNDNNMATLCLGHHAVPNGTIVKTIT...,MNPNQKIITIGSVSLIIATICFLMQIAILVTTITLHFKQHNCDSSP...,12389,12388,...,"[0.024309921801497162, 0.023944381329675853, 0...","[0.023614312540517383, 0.02309467144393617, 0....",0.018423,1.920942e-07,0.030189,8.534729e-08,0.024303,6.31153e-08,0.023581,8.367947e-08
18,A/Hong Kong/125/2017,H7N9,May 2017,6.5,7.5,Moderate-High,MNTQILVFALIAIIPTNADKICLGHHAVSNGTKVNTLTERGVEVVN...,MNPNQKILCTSATAITIGAIAVLIGIANLGLNIGLHLKPGCNCSHS...,437,437,...,"[0.015096954599283424, 0.01674120024270414, 0....","[0.011713440103236638, 0.0127559787449982, 0.0...",0.028721,4.439764e-05,0.00571,1.270936e-08,0.017216,1.13378e-05,0.01272,2.685473e-06
19,A/Shanghai/02/2013,H7N9,Apr 2016,6.4,7.2,Moderate-High,MNTQILVFALIAIIPTNADKICLGHHAVSNGTKVNTLTERGVEVVN...,MNPNQKILCTSATAIIIGAIAVLIGMANLGLNIGLHLKPGCNCSHS...,178,178,...,"[0.0044008285297015455, 0.00470131803408878, 0...","[0.004300267277149767, 0.004559548154797955, 0...",0.005488,5.417886e-08,0.003548,1.185785e-08,0.004518,9.606804e-09,0.004411,6.396642e-09
21,A/Anhui-Lujiang/39/2018,H9N2,Jul 2019,6.2,5.9,Moderate,METVSLITILLVATASNADKICIGYQSTNSTETVDTLTENNVPVTH...,MNPNQKITAIGSVSLIIAIICLLMQIAILTTTMTLHFGQKECSNPS...,30,30,...,"[0.07525499471451194, 0.09384468115584135, 0.0...","[0.05324481306835133, 0.06364392174963164, 0.0...",0.026598,1.757813e-05,0.155139,0.001012103,0.090869,0.0003088926,0.064073,0.0001240755
4,A/Indiana/08/2011,H3N2,Dec 2012,6.0,4.5,Moderate,MKTIIAFSCILCLIFAQKLPGSDNSMATLCLGHHAVPNGTLVKTIT...,MNPNQKIITIGSVSLIIATICFLMQIAILVTTVTLHFKQHDYNSPP...,2298,2298,...,"[0.029919477755705103, 0.029752337598051272, 0...","[0.02061295671075932, 0.02049590227994502, 0.0...",0.051676,1.164524e-06,0.008631,2.453624e-06,0.030153,3.560676e-07,0.021035,2.454742e-06
2,A/California/62/2018,H1N2,Jul 2019,5.8,5.7,Moderate,MKVKLMVLLCTFTATYADTICVGYHANNSTDTVDTVLEKNVTVTHS...,MNPNQKIITIGSISLTLAAMCFLMQTAILVTNVTLHFNQCECHYPP...,55,55,...,"[0.06490895308022405, 0.06795093420062281, 0.0...","[0.06436640757832801, 0.06740140489247336, 0.0...",0.104613,0.001269566,0.058658,6.438462e-06,0.081636,0.0003460772,0.077365,0.0002051369
20,A/Bangladesh/0994/2011,H9N2,Feb 2014,5.6,5.4,Moderate,METVSLMTILLLVTTSNADKICIGHQSTNSTETVDTLTETNVPVTH...,MNPNQKIIALGSASLTIAIICLLIQIAILATTMTLHFMQNEHTNST...,-1,-1,...,"[0.034834103409839376, 0.17018677971399415, 0....","[0.03396748381927615, 0.12465387691405136, 0.0...",0.2078,0.01708841,0.182338,0.02159965,0.195142,0.01588109,0.184311,0.01637705
10,A/Sichuan/06681/2021,H5N6,Oct 2021,5.3,6.3,Moderate,MENIVLLLAIVSLVKSDQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,MNPNQKITCISATGVTLSIVSLLIGITNLGLNIGLHYKVSDSTTIN...,45,45,...,"[0.23168966429463134, 0.2169574620772385, 0.21...","[0.1449046891390129, 0.136728309069802, 0.1326...",0.365294,0.00145732,0.048914,2.950898e-05,0.207104,0.0003617637,0.13325,9.234817e-05
8,A/Vietnam/1203/2004,H5N1,Nov 2011,5.2,6.6,Moderate,MEKIVLLFAIVSLVKSDQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,MNPNQKIITIGSICMVTGIVSLMLQIGNMISIWVSHSIHTGNQHQS...,257,243,...,"[0.08820288247283928, 0.07886987661675382, 0.0...","[0.03363047149021574, 0.03285205746614019, 0.0...",0.167335,0.0001016787,0.010313,1.56671e-05,0.088824,4.218056e-05,0.040928,7.636776e-05


In [2]:
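# keep only numeric columns so corr() behaves the same across pandas versions,
# then select the two IRAT estimates by name rather than by position
corr = pd.read_csv('results/irat_average_qdistances.csv')[:22].select_dtypes('number').corr()
corr.loc[['Potential Emergence Estimate', 'Potential Impact Estimate'], ['Avg. Geometric Mean']]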

Unnamed: 0,Avg. Geometric Mean
Potential Emergence Estimate,-0.706739
Potential Impact Estimate,-0.44708
