# Influenza Enet Predictions
- Predicting dominant strains using Emergenet

In [3]:
import os 
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

## Create Enet Models
- Truncate HA at 565 amino acids; only HA models are needed
- Give Enet the same name as the data file
    
### Running Processes

Computations are done in:
- `enet_train.py`
- `run_enet_train.sh`

To run, open a terminal in this directory and execute:

`chmod ugo+rwx run_enet_train.sh`

`./run_enet_train.sh`

## Predictions
E-Centroid: $$x_{*}^{t+\delta} = \arg\min_{y\in \bigcup_{\tau\leq t}H^{\tau}} \sum_{x \in {H^t}} \theta(x,y) - |H^t|A \ln\omega_y$$
- $x_{*}^{t+\delta}$ is the dominant strain in the upcoming flu season at time $t+\delta$
- $H^t$ is the sequence population at time $t$
- $\theta(x,y)$ is the e-distance between $x$ and $y$ in their respective Enets
- $A = \frac{1-\alpha}{\sqrt{8}N^2}$, where $\alpha$ is a fixed significance level and $N$ is the sequence length considered
- $\omega_y$ is the membership degree of sequence $y$ (it enters the objective as $\ln\omega_y$)
- **Predict dominant strain based on HA data only** 
    - Then take corresponding NA strain
- Perform MeanShift clustering on the q-distance matrix for $H^t$
    - Cluster on $H^t$, then compute $x_{*}^{t+\delta}$ for the top ten largest clusters (we will use the largest two clusters for our predictions)

### Running Processes

Computations are done in:
- `enet_predictions.py`
- `run_enet_predictions.sh`

To run, open a terminal in this directory and execute:

`chmod ugo+rwx run_enet_predictions.sh`

`./run_enet_predictions.sh`

## Aggregate Predictions

In [4]:
# Locations used by this notebook, relative to its working directory.
PRED_DIR = 'results/enet_predictions/seasonal_predictions/'
DATA_DIR = 'raw_data/merged/'
DM_DIR = 'results/enet_predictions/distance_matrices/'

# Dataset identifiers; each is used both as a sub-directory of PRED_DIR and
# as a file-name prefix (presumably hemisphere + subtype -- confirm).
FILES = ['north_h1n1', 'north_h3n2', 'south_h1n1', 'south_h3n2']

# Two-digit start years of the seasons covered: 03 through 23 inclusive,
# i.e. 21 seasons. Change these two values to extend or shrink the range.
FIRST_SEASON = 3
LAST_SEASON = 23
_SEASON_STARTS = range(FIRST_SEASON, LAST_SEASON + 1)

# 'north' labels span two consecutive years, zero-padded: '03_04' ... '23_24'.
NORTH_YEARS = [str(start).zfill(2) + '_' + str(start + 1).zfill(2)
               for start in _SEASON_STARTS]

# 'south' labels are a single zero-padded year: '03' ... '23'.
SOUTH_YEARS = [str(start).zfill(2) for start in _SEASON_STARTS]

In [5]:
# Aggregate the per-season prediction files of every dataset into one wide
# table per dataset: one row per season, one group of columns per cluster.

N_CLUSTERS = 3   # number of leading rows taken from each prediction file
MISSING = -1     # placeholder written when a prediction is unavailable

# (output column prefix, column name in the per-season prediction file)
FIELD_MAP = [('name', 'name'),
             ('cluster_size', 'cluster_size'),
             ('ha_acc', 'acc'),
             ('ha_seq', 'sequence'),
             ('na_acc', 'acc_na'),
             ('na_seq', 'sequence_na')]

# Output column order: 'season', then all fields of cluster 0, cluster 1, ...
PRED_COLUMNS = ['season'] + [prefix + '_' + str(j)
                             for j in range(N_CLUSTERS)
                             for prefix, _ in FIELD_MAP]

for FILE in FILES:
    # Datasets whose name starts with 'north' use the 'YY_YY' season labels;
    # every other dataset uses the single-year labels.
    YEARS = NORTH_YEARS if FILE[:5] == 'north' else SOUTH_YEARS

    # Collect one dict per season and build the DataFrame once at the end,
    # instead of re-copying the frame with pd.concat on every iteration.
    records = []
    for season in YEARS:
        # Start from an all-placeholder row so that a missing file, or a file
        # with fewer than N_CLUSTERS rows, still yields a complete record.
        record = dict.fromkeys(PRED_COLUMNS, MISSING)
        record['season'] = season

        pred_path = PRED_DIR + FILE + '/' + FILE + '_' + season + '_predictions.csv'
        if os.path.isfile(pred_path):
            df = pd.read_csv(pred_path)
            # Guard against files with fewer rows than N_CLUSTERS, which
            # would otherwise raise IndexError on .values[j].
            for j in range(min(N_CLUSTERS, len(df))):
                for prefix, source_col in FIELD_MAP:
                    record[prefix + '_' + str(j)] = df[source_col].values[j]

        records.append(record)

    # Enet recommendation accession, name, sequence
    pred_df = pd.DataFrame(records, columns=PRED_COLUMNS)
    pred_df.to_csv('results/enet_predictions/' + FILE + '_predictions.csv', index=False)