In [None]:
# Mount Google Drive at /content/drive; every project path configured below
# lives under that mount point. The google.colab import means this cell only
# runs on a Colab kernel.
import os
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Project directory layout. All paths are plain strings built by concatenation
# from `basedir`, which sits inside the Drive mount.
basedir = "/content/drive/My Drive/MSc Research"
notebook_directory = basedir + "/notebooks"
data_directory = basedir + "/data"
pdb_directory = basedir + "/pdb"                 # read by get_protein_view
installs_directory = basedir + "/installs"
alignment_directory = data_directory + "/alignments"   # holds the pre/, msa/, signalp/ and msa2/ sub-folders used below
tmp_dir = alignment_directory + "/tmp/"          # note the trailing slash, unlike the other paths
signal_p_directory = installs_directory + "/signalp-5.0b"   # run_signalp expects a bin/signalp executable in here
# NOTE(review): presumably needed for OpenGL-based rendering on the headless runtime - confirm it is still required.
os.environ['LIBGL_ALWAYS_INDIRECT'] = '1'

In [None]:
# Python packages used by the notebook (versions are not pinned).
# NOTE(review): catboost is installed here but is not imported by the import cell below - confirm it is still needed.
!apt-get update
!pip install biopython
!pip install logomaker
!pip install py3Dmol
!pip install catboost

In [None]:
# System packages: the MUSCLE aligner plus build/runtime libraries.
# The working directory is switched to the installs folder first.
# NOTE(review): the muscle install has no -y flag, unlike the line after it - verify it does not wait for confirmation.
os.chdir(installs_directory)
!apt-get install muscle
!apt-get install -y cmake libfftw3-dev libomp-dev



In [None]:

# Style-checking extension for notebook cells.
# NOTE(review): pycodestyle_magic is not installed by the pip cell above - confirm it is available on the runtime.
%load_ext pycodestyle_magic

In [None]:

import re
import csv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, roc_curve, precision_recall_curve, auc, roc_auc_score
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import logomaker
from sklearn.metrics import classification_report, accuracy_score
from sklearn import svm
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import GridSearchCV
from sklearn.feature_selection import SelectFpr, f_classif
from sklearn.svm import SVC
from io import StringIO
from Bio import AlignIO, SeqIO
import logging
import seaborn as sns
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
import cProfile
from IPython.utils import io
import tqdm.notebook
import os
import py3Dmol
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
import subprocess
import shlex
import shutil
import stat
from collections import Counter
from ast import literal_eval
import io
import base64
from PIL import Image
from sklearn.ensemble import RandomForestClassifier
import fnmatch
from sklearn.model_selection import cross_val_score
import xgboost as xgb
from sklearn.linear_model import LogisticRegression
from imblearn.over_sampling import SMOTE



In [None]:
# Make the data directory the working directory for the cells that follow.
os.chdir(data_directory)


In [None]:
#method to align sequences from fasta files
def align_sequences(input_fasta, input_fasta_2 = '', mode = 'msa'):
  """Run MUSCLE on a FASTA file and return the alignment it writes to stdout.

  Args:
    input_fasta: path of the FASTA file handed to MUSCLE via ``-in``.
    input_fasta_2: accepted for interface compatibility; not used.
    mode: only ``'msa'`` is supported.

  Returns:
    The captured stdout text, or None when the mode is invalid, the MUSCLE
    executable cannot be started, or MUSCLE exits with a non-zero status.
  """
  if mode != 'msa':
    print("Invalid mode.")
    return None

  cmd = ['muscle', '-diags', '-maxiters', '3', '-sv', '-in', input_fasta]

  try:
    result = subprocess.run(cmd, capture_output=True, text=True)
  except OSError as e:
    # Raised when the executable is missing or not runnable. subprocess.run
    # only raises CalledProcessError with check=True, so the exit status is
    # inspected explicitly below instead.
    print(f"Error running MUSCLE: {e}")
    return None

  if result.returncode != 0:
    print(f"Error running MUSCLE (exit code {result.returncode}): {result.stderr}")
    return None

  # Anything on stderr after a successful exit is reported but is not treated
  # as a failure.
  if result.stderr:
    print("Error: " + result.stderr)
  return result.stdout


In [None]:
#method to get the sequences as a file
def get_as_file(contents, file):
    """Write ``contents`` to ``file`` (overwriting it) and return the path.

    Args:
        contents: text to write.
        file: destination path.

    Returns:
        ``file`` when the write succeeded, otherwise None, so callers can tell
        a failed write apart from a produced file.
    """
    print('file: ' + file)
    try:
        with open(file, 'w') as f:
            f.write(contents)
    except Exception as e:
        print(f"Error writing alignment to file: {e}")
        return None
    return file

In [None]:
def clean_sequence(sequence):
    """Return a Seq in which every 'X', 'B', 'J', 'Z', '.' and '*' is replaced by the gap symbol '-'."""
    # One translation table performs all six single-character substitutions.
    ambiguous_to_gap = str.maketrans({symbol: '-' for symbol in 'XBJZ.*'})
    return Seq(str(sequence).translate(ambiguous_to_gap))


In [None]:
def preprocess_fasta(input_fasta):
    """Clean ambiguous characters in every record of a FASTA file before alignment.

    The cleaned records are written to ``input_fasta + '.tmp'`` and that file
    then replaces ``input_fasta``, so the input is rewritten in place.

    Returns:
        ``input_fasta`` (the same path that was passed in).
    """
    scratch_path = input_fasta + '.tmp'

    def _cleaned(records):
        # Rebind each record's sequence to its cleaned form.
        for rec in records:
            rec.seq = clean_sequence(rec.seq)
            yield rec

    # Materialise all records before writing, then swap the files.
    all_records = list(_cleaned(SeqIO.parse(input_fasta, "fasta")))
    SeqIO.write(all_records, scratch_path, "fasta")
    os.replace(scratch_path, input_fasta)
    return input_fasta

In [None]:
def process_alignment(input_fasta, output_fasta, mode='msa', input_fasta_2 = ''):
    """Clean a FASTA file, align it and store the alignment.

    Args:
        input_fasta: FASTA path; it is cleaned in place by preprocess_fasta.
        output_fasta: path the alignment text is written to.
        mode: forwarded to align_sequences.
        input_fasta_2: forwarded to align_sequences.

    Returns:
        Whatever get_as_file returns for the written alignment, or None when
        align_sequences produced no output.
    """
    cleaned_fasta = preprocess_fasta(input_fasta)
    alignment_text = align_sequences(cleaned_fasta, input_fasta_2=input_fasta_2, mode=mode)

    # Empty or missing alignment text counts as a failed alignment.
    if not alignment_text:
        print('Error: sequence alignment failed.')
        return None

    print('Alignment successful.')
    return get_as_file(alignment_text, output_fasta)

In [None]:
#method for multiple sequence alignment
def msa(input_files, output_files, mode='msa'):
  """Align several FASTA files, one output file per input file.

  Args:
    input_files: list of FASTA paths to align.
    output_files: list of destination paths, matched to ``input_files`` by position.
    mode: alignment mode forwarded to process_alignment for every file.

  Returns:
    A list with one entry per input file: the value returned by
    process_alignment, or None when processing that file raised.
  """
  results = []
  for position, source in enumerate(input_files):
    out = None
    try:
      # Indexing output_files inside the try keeps a too-short output list a
      # logged per-file failure rather than an abort of the whole batch.
      out = process_alignment(source, output_files[position], mode=mode)
    except Exception as e:
      logging.error(f"Error processing {source}: {e}")
    results.append(out)
  return results

In [None]:
#method for calculating conservation scores of alignments
def calculate_conservation_score(fasta_file):
    """
    Return one conservation score per alignment column.

    A column's score is the count of its most frequent symbol divided by the
    number of symbols in the column. Gap characters are counted like any
    other symbol.
    """
    alignment = AlignIO.read(fasta_file, "fasta")

    def _top_fraction(column):
        # most_common(1) yields [(symbol, count)] for the dominant symbol
        _, top_count = Counter(column).most_common(1)[0]
        return top_count / len(column)

    return [
        _top_fraction(alignment[:, col_idx])
        for col_idx in range(alignment.get_alignment_length())
    ]


In [None]:
#method for calculating scores of alignments
def calculate_alignment_score(fasta_file):
    """
    Return the mean per-column alignment score.

    Each column contributes (number of symbols equal to the column's first
    symbol) minus (number of '-' symbols); the total is divided by the number
    of columns.
    """
    alignment = AlignIO.read(fasta_file, "fasta")
    n_columns = alignment.get_alignment_length()

    running_total = 0
    for col_idx in range(n_columns):
        column = alignment[:, col_idx]
        reference_symbol = column[0]
        running_total += column.count(reference_symbol) - column.count('-')

    return running_total / n_columns

In [None]:
#method to combine sequences from multiple files in a single file
def combine_sequences(files, combined_file, overwrite=True):
    """
    Concatenate several FASTA files into one (e.g. the avian and human files
    of one subtype).

    Args:
        files: paths of the files to concatenate, in order.
        combined_file: path of the file to create.
        overwrite: when True (default) the combined file is rebuilt from
            scratch, so re-running the call does not duplicate sequences.
            Pass False to append to an existing file instead.
    """
    mode = 'w' if overwrite else 'a'
    with open(combined_file, mode) as outfile:
        for file in files:
            with open(file, 'r') as infile:
                contents = infile.read()
            outfile.write(contents)
            # Keep the next file's first header on its own line when this
            # file does not end with a newline.
            if contents and not contents.endswith('\n'):
                outfile.write('\n')

In [None]:
# File names of the per-subtype FASTA files; the same names are reused in
# every alignment sub-folder (pre/, msa/, signalp/, msa2/).
h3, h5, h7 = "h3.fasta", "h5.fasta", "h7.fasta"

In [None]:
# Per-host input files for each subtype.
h3_human, h5_human, h7_human = "h3_human.fasta", "h5_human.fasta", "h7_human.fasta"
h3_avian, h5_avian, h7_avian = "h3_avian.fasta", "h5_avian.fasta", "h7_avian.fasta"

In [None]:
# Combined (human + avian) unaligned file per subtype.
h3_pre = alignment_directory + "/pre/" + h3
h5_pre = alignment_directory + "/pre/" + h5
h7_pre = alignment_directory + "/pre/" + h7

In [None]:
#combine sequences for all hosts in a single subtype file for alignment
combine_sequences([alignment_directory + "/pre/" + h3_human, alignment_directory + "/pre/" + h3_avian], h3_pre)
combine_sequences([alignment_directory + "/pre/" + h5_human, alignment_directory + "/pre/" + h5_avian], h5_pre)
combine_sequences([alignment_directory + "/pre/" + h7_human, alignment_directory + "/pre/" + h7_avian], h7_pre)

In [None]:
# Destination paths of the first multiple sequence alignment.
h3_aligned = alignment_directory + "/msa/" + h3
h5_aligned = alignment_directory + "/msa/" + h5
h7_aligned = alignment_directory + "/msa/" + h7

In [None]:
#if no alignment files are available in the directory, use code in this cell
# The three path variables are rebound to the values msa() returns; msa() puts
# None in the list for any file whose processing raised, so a failed alignment
# leaves None here instead of a path.
h3_aligned, h5_aligned, h7_aligned = msa([alignment_directory+ "/pre/" + h3, alignment_directory+ "/pre/" + h5, alignment_directory+ "/pre/" + h7], [h3_aligned, h5_aligned, h7_aligned])


In [None]:
#method to test alignment accuracy based on conservation and alignment scores
def test_alignment_accuracy(alignment_file):
  scores = calculate_conservation_score(alignment_file)

  average_score = sum(scores) / len(scores)
  print(f"\nAverage Conservation Score: {average_score:.2f}")

  alignment_score = calculate_alignment_score(alignment_file)
  print(f"\nAlignment Score: {alignment_score:.2f}")


In [None]:
#method to check if fasta is aligned
def is_fasta_aligned(input_fasta, dir):
  """Return True when all records of ``dir/input_fasta`` have the same sequence length.

  The file is read with AlignIO.read.
  """
  fasta_path = os.path.join(dir, input_fasta)
  with open(fasta_path, 'r') as handle:
    records = AlignIO.read(handle, 'fasta')
    distinct_lengths = {len(rec.seq) for rec in records}
  return len(distinct_lengths) == 1


In [None]:
#method for signal peptide detection
def run_signalp(fasta_dir, fasta_file):
  """Run SignalP on a FASTA file and return the contents of its GFF3 output.

  The FASTA file is copied next to the signalp executable
  (``signal_p_directory/bin``), signalp is started with that folder as its
  working directory, and the ``output.gff3`` file it writes there is read
  back. The staged copy and the GFF3 file are removed again whether or not
  the run succeeds, and the notebook's own working directory is never changed.

  Args:
    fasta_dir: directory containing the FASTA file.
    fasta_file: file name of the FASTA file inside ``fasta_dir``.

  Returns:
    The text of ``output.gff3``.

  Raises:
    RuntimeError: when signalp exits with a non-zero status. This replaces
      the earlier ``exit(1)``, which terminates the session instead of
      reporting an error.
  """
  signalp_bin = os.path.join(signal_p_directory, 'bin')
  staged_fasta = os.path.join(signalp_bin, fasta_file)
  gff_path = os.path.join(signalp_bin, "output.gff3")

  # Make sure the binary is executable for the current user.
  os.chmod(os.path.join(signalp_bin, "signalp"), stat.S_IRWXU)

  shutil.copy(os.path.join(fasta_dir, fasta_file), staged_fasta)

  # Argument list instead of a quoted command string: file names containing
  # quotes or spaces cannot break the invocation.
  command = ['./signalp', '-fasta', fasta_file, '-format', 'short', '-gff3', '-prefix', 'output']

  try:
    # cwd= runs signalp inside its bin folder without touching os.getcwd().
    result = subprocess.run(command, cwd=signalp_bin, capture_output=True, text=True)

    if result.returncode != 0:
      print("SignalP did not run successfully.")
      print(result.stderr)
      raise RuntimeError(f"SignalP exited with status {result.returncode}")

    with open(gff_path, 'r') as f:
      return f.read()
  finally:
    # Clean up on success and on failure alike.
    for leftover in (staged_fasta, gff_path):
      if os.path.exists(leftover):
        os.remove(leftover)


In [None]:
#get signal peptide proteins for all subtypes
# Each call returns the GFF3 text SignalP produced for the file of that name
# in the msa/ folder.
h3_signal = run_signalp(alignment_directory + "/msa/", h3)
h5_signal = run_signalp(alignment_directory + "/msa/", h5)
h7_signal = run_signalp(alignment_directory + "/msa/", h7)

In [None]:
#remove n terminals of proteins based on detected signal peptide indices
def remove_n_terminal(signalp_output, input_fasta, output_fasta):
    """Cut the detected signal peptide off the start of each sequence.

    Args:
        signalp_output: tab-separated SignalP text. Lines starting with '#'
            and lines with fewer than 9 fields are ignored; field 1 is the
            sequence id, fields 4 and 5 the start and end positions.
        input_fasta: FASTA file to read.
        output_fasta: FASTA file to write.

    Returns:
        ``output_fasta``.

    Records whose id has no entry in ``signalp_output`` are written unchanged;
    the others keep everything from the end position onwards.
    """
    cleavage_by_id = {}
    for gff_line in signalp_output.strip().split('\n'):
        if gff_line.startswith('#'):
            continue
        fields = gff_line.split('\t')
        if len(fields) < 9:
            continue
        region_start = int(fields[3]) - 1
        region_end = int(fields[4])
        cleavage_by_id[fields[0]] = (region_start, region_end)

    with open(input_fasta, "r") as handle:
        records = list(SeqIO.parse(handle, "fasta"))

    # Only the end position is needed for cropping.
    for rec in records:
        span = cleavage_by_id.get(rec.id)
        if span is not None:
            rec.seq = rec.seq[span[1]:]

    with open(output_fasta, "w") as output_handle:
        SeqIO.write(records, output_handle, "fasta")

    return output_fasta

In [None]:
# Destination paths for the sequences with the signal peptide removed.
h3_cropped = alignment_directory + "/signalp" + "/" + h3
h5_cropped = alignment_directory + "/signalp" + "/" + h5
h7_cropped= alignment_directory + "/signalp" + "/" + h7


In [None]:
#if mature proteins not available in hx_cropped directories, run cell to cleave proteins
# hX_aligned are the values returned by the msa() cell above, so an entry is
# None if that alignment failed.
remove_n_terminal(h3_signal, h3_aligned, h3_cropped)
remove_n_terminal(h5_signal, h5_aligned, h5_cropped)
remove_n_terminal(h7_signal, h7_aligned, h7_cropped)

In [None]:
# Destination paths for the second alignment, made from the cropped sequences.
h3_realigned = alignment_directory + "/msa2/" + h3
h5_realigned = alignment_directory + "/msa2/" + h5
h7_realigned = alignment_directory + "/msa2/" + h7


In [None]:
#realign cleaved proteins
# The return value is not captured; the hX_realigned paths are used directly below.
msa([h3_cropped, h5_cropped, h7_cropped], [h3_realigned, h5_realigned, h7_realigned])

In [None]:
#method to extract subtype and year details from the header of a given sample
def get_header_details(header):
  """Split a '|'-separated sample header into host, subtype and year.

  The header is read as ``date|subtype|host[|...]``. The subtype has every
  '_', '/' and 'A' removed; the year is the text before the first '-' of the
  date.

  Returns:
    ``(host, subtype, year)``. Each value that could not be extracted is an
    empty list: all three for an empty header or one with fewer than three
    fields, and the year alone when the date contains no '-'.
  """
  host, subtype, year = [], [], []

  if len(header) <= 0:
    print("header empty.")
    return host, subtype, year

  fields = header.split("|")
  if len(fields) < 3:
    print("Invalid header format.")
    return host, subtype, year

  date = fields[0].strip()
  subtype = fields[1].strip()
  for unwanted in ('_', '/', 'A'):
    subtype = subtype.replace(unwanted, '')
  if '-' in date:
    year = date.split('-')[0]
  host = fields[2].strip()
  return host, subtype, year




In [None]:
#method to add the host sequence data provided in sequence_data to the dataframe given
def get_df_from_sequences(df, fasta_file):
    """Append one row per usable record of an aligned FASTA file to ``df``.

    Args:
        df: frame to extend. If it lacks any of the columns 'label', 'year',
            'subtype', 'sequence' it is replaced by an empty frame with
            exactly those columns.
        fasta_file: path of the aligned FASTA file; a falsy value returns
            ``df`` as is.

    Returns:
        A new frame: ``df`` followed by the rows read from the file.

    Records whose id does not contain an ``EPI_ISL_<digits>`` isolate id are
    skipped. The text before the isolate id is parsed with get_header_details
    and the host is converted with get_label.
    """
    required_columns = ['label', 'year', 'subtype', 'sequence']

    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        print("Dataframe missing necessary columns. Adding...")
        df = pd.DataFrame(columns=required_columns)
    if not fasta_file:
        print("Fasta file not found.")
        return df

    #each sample stores sequence data following an isolate id in the header
    isolate_id_pattern = r"(.*?)(EPI_ISL_\d+)([\w-]*)"

    rows = []
    with open(fasta_file, 'r') as file:
        for record in AlignIO.read(file, 'fasta'):
            found = re.findall(isolate_id_pattern, record.id, re.DOTALL)
            if not found:
                continue
            header, isolate_id, _ = found[0]
            sequence = str(record.seq).strip().replace("\n", "")
            if len(sequence) == 0:
                print("Sequence not found for " + isolate_id)
                continue
            host, subtype, year = get_header_details(header.strip())
            rows.append((get_label(host), year, subtype, sequence))

    parsed_rows = pd.DataFrame(rows, columns=required_columns)
    return pd.concat([df, parsed_rows], ignore_index=True)

In [None]:
# Known mutations within the binding region, taken from
# 'Equivalent amino acid numbering for subtypes currently circulating in humans or have pandemic potential.'
# by Burke & Smith (2014).
# Each tuple is (original residue, mutated residue, H3 index). A mutated value
# of the form 'X or Y' lists alternative residues (see encode_sequences).
# NOTE(review): the name `data` is rebound to a NumPy array by the
# correlation-matrix cell further down.
data = [
        ('H', 'Y', 110),
        ('S', 'N', 126),
        ('S', 'P', 128),
        ('S', 'A', 137),
        ('A', 'V', 138),
        ('G', 'R', 143),
        ('I', 'T', 155),
        ('N', 'D', 158),
        ('T', 'A', 160),
        ('N', 'K', 186),
        ('D', 'G', 187),
        ('E', 'G', 190),
        ('T', 'I', 192),
        ('K', 'R', 193),
        ('Q', 'R or H', 196),
        ('V','I', 214),
        ('Q', 'L', 226),
        ('S', 'N', 227),
        ('G', 'S', 228),
        ('P', 'S', 239)]

known_mutations_df = pd.DataFrame(data, columns = ['original', "mutated", 'H3 index'])


In [None]:
# Corresponding indices across the H3, H5 and H7 subtypes (Burke & Smith, 2014).
# Each tuple is (H3 index, H5 index, H7 index).
# NOTE(review): the row for H3 index 197 has no matching entry in the
# known-mutations list above - confirm whether that is intentional.
numbering_data = [
 (110, 103, 100),
 (126, 121, 116),
 (128, 123, 118),
 (137, 133, 127),
 (138, 134, 128),
 (143, 139, 132),
 (155, 151, 144),
 (158, 154, 147),
 (160, 156, 151),
 (186, 182, 177),
 (187, 183, 178),
 (190, 186, 181),
 (192, 188, 183),
 (193, 189, 184),
 (196, 192, 187),
 (197, 193, 188),
 (214, 210, 205),
 (226, 222, 217),
 (227, 223, 218),
 (228, 224, 219),
 (239, 235, 230)]

df_numbering_scheme = pd.DataFrame(data = numbering_data, columns = ['H3', 'H5', 'H7'])



In [None]:
#method to get the index of a subtype based on its h3 index
def get_subtype_index(subtype, h3_index):
  """Translate an H3 index into the equivalent index of another subtype.

  Only the first two characters of ``subtype`` (upper-cased) are used, so
  'h5n1' is looked up as 'H5'.

  Returns:
    ``h3_index`` itself for H3; otherwise the value from the matching column
    of ``df_numbering_scheme``. -1 (after printing a message) when the subtype
    is not a column of the scheme or the H3 index is not listed in it.
  """
  prefix = subtype.upper()[:2]
  if prefix == "H3":
    return h3_index

  if prefix not in df_numbering_scheme.columns:
    print("Invalid subtype.")
    return -1

  matching = df_numbering_scheme.loc[df_numbering_scheme['H3'] == h3_index, prefix]
  if matching.empty:
    print("Invalid H3 index.")
    return -1
  return matching.iloc[0]

In [None]:
def get_min_max(df_numbering_scheme):
  min_value = df_numbering_scheme.min().min()
  max_value = df_numbering_scheme.max().max()
  return min_value, max_value


In [None]:
#get label from host
def get_label(host):
  """Map a host name to the binary class label.

  Args:
    host: host name; surrounding whitespace and letter case are ignored.
      Non-string values, such as the empty-list placeholder that
      get_header_details returns for malformed headers, are accepted.

  Returns:
    1 when the host is 'human', otherwise 0.
  """
  if not isinstance(host, str):
    # A missing host cannot be 'human'; return 0 rather than crash on .strip().
    return 0
  return 1 if host.strip().lower() == 'human' else 0

In [None]:
#getting sequence data as one big dataframe that contains information regarding label (host), year, subtype and sequence
# Starts from an empty frame; each call returns a new frame extended with the
# records of one re-aligned subtype file.
df_unprocessed = pd.DataFrame()
df_unprocessed = get_df_from_sequences(df_unprocessed, h3_realigned)
df_unprocessed = get_df_from_sequences(df_unprocessed, h5_realigned)
df_unprocessed = get_df_from_sequences(df_unprocessed, h7_realigned)

In [None]:
#method to eliminate duplicate sequences
def drop_duplicate_sequences(df):
  """Return ``df`` without rows that repeat an earlier (label, subtype, sequence) combination.

  The first occurrence of each combination is kept, and a summary of how many
  rows were removed is printed.
  """
  identity_columns = ['label', 'subtype', 'sequence']
  unique_rows = df.drop_duplicates(subset=identity_columns, keep='first')
  total_rows = len(df)
  print(f"{total_rows - len(unique_rows)} / {total_rows} duplicates removed")
  return unique_rows

In [None]:
#method to filter sequences for min and max constraints of the mutations
def filter_sequences_for_length(df):
  """Trim every sequence to the index window covered by the numbering scheme.

  The window runs from the smallest to the largest value in
  ``df_numbering_scheme`` (both used as 0-based string indices, upper bound
  included).

  Returns:
    A new frame with the trimmed 'sequence' column. The frame passed in is
    left untouched, so re-running the calling cell cannot trim the same
    sequences twice.
  """
  min_value, max_value = get_min_max(df_numbering_scheme)
  window_start = int(min_value)
  window_stop = int(max_value) + 1
  trimmed = df['sequence'].apply(lambda seq: seq[window_start:window_stop])
  # assign() returns a copy, which also avoids pandas' SettingWithCopyWarning.
  return df.assign(sequence=trimmed)

In [None]:
#filtering dataframe rows and getting new dataframe for processed sequences
# Duplicates are removed first, then every sequence is trimmed to the
# numbering-scheme window.
df_processed = drop_duplicate_sequences(df_unprocessed)
df_processed = filter_sequences_for_length(df_processed)


In [None]:

# `df` is a second name for the same object as df_processed (no copy is made),
# so columns added to `df` later also appear on df_processed.
df = df_processed

In [None]:
#getting the number of positive and negative labels to ensure an even dataset
def get_positive_and_negative(df):
  """Print how many rows of ``df`` carry label 0 (negative) and label 1 (positive).

  Any other label value is reported on its own line as an invalid label and
  is not counted.
  """
  neg_count = 0
  pos_count = 0
  for label in df['label']:
    if label == 1:
      pos_count += 1
    elif label == 0:
      neg_count += 1
    else:
      # str() is required: the label is usually not a string.
      print('Invalid label: ' + str(label))

  print("negatives count: " + str(neg_count))
  print("positives count: " + str(pos_count))

In [None]:
#get logo to better analyse sequences (visually)
def get_logo_for_sequences(sequences):
  """Draw a sequence logo for a collection of equal-length sequences.

  The sequences are converted to a matrix with
  ``logomaker.alignment_to_matrix`` and drawn with ``logomaker.Logo``.
  Nothing is returned.
  """
  freq_mat = logomaker.alignment_to_matrix(sequences)
  # Constructing the Logo draws the figure; the object itself is not needed afterwards.
  logomaker.Logo(freq_mat,
                 fade_below=0.5,
                 shade_below=0.5,
                 color_scheme='skylign_protein',
                 figsize = (100,5))

In [None]:
get_logo_for_sequences(df[df['label'] == 0 & df['subtype'].str.contains('H3')]['sequence'])

In [None]:
get_logo_for_sequences(df[df['label'] == 1 & df['subtype'].str.contains('H3')]['sequence'])

In [None]:
get_logo_for_sequences(df[df['label'] == 0 & df['subtype'].str.contains('H5')]['sequence'])

In [None]:
get_logo_for_sequences(df[df['label'] == 1 & df['subtype'].str.contains('H5')]['sequence'])

In [None]:
get_logo_for_sequences(df[df['label'] == 0 & df['subtype'].str.contains('H7')]['sequence'])

In [None]:
get_logo_for_sequences(df[df['label'] == 1 & df['subtype'].str.contains('H7')]['sequence'])

In [None]:
#method that encodes all sequences for known mutations (multiple hot encoding)
def encode_sequences(df, mutations_df):
  """Encode every sequence as a 0/1 vector over the known mutations.

  Args:
    df: frame with a 'sequence' column (already trimmed to the
      numbering-scheme window) and a 'subtype' column.
    mutations_df: frame whose three columns are, in order, the original
      residue, the mutated residue and the H3 index. A mutated value of
      'Any' always counts as present; 'X or Y' accepts either residue.

  Returns:
    ``(feature_names, encoded_sequences)``: one name per mutation, in the
    order of ``mutations_df``, and one tuple of 0/1 flags per row of ``df``.

  A mutation whose position cannot be resolved for the row's subtype
  (get_subtype_index returned -1, or the position lies before the window)
  stays 0.
  """
  # The trimmed sequences start at the smallest index of the numbering scheme,
  # so positions are shifted by that offset. Computed once, not per row.
  window_start, _ = get_min_max(df_numbering_scheme)
  window_start = int(window_start)

  mutations = list(mutations_df.itertuples(index=False, name=None))

  # Names depend only on the mutation table, so they are built exactly once.
  feature_names = [
      f"{original}_to_{mutated}_at_:{h3_index}"
      for original, mutated, h3_index in mutations
  ]

  encoded_sequences = []
  for _, row in df.iterrows():
    sequence = row['sequence']
    #the subtype prefix (HX) is resolved inside get_subtype_index
    subtype = row['subtype']
    sequence_features = [0] * len(mutations)

    for position, (_original, mutated, h3_index) in enumerate(mutations):
      subtype_index = get_subtype_index(subtype, h3_index)

      if mutated == 'Any':
        sequence_features[position] = 1
        continue

      offset = subtype_index - window_start
      if subtype_index < 0 or offset < 0:
        # Unresolved position: a negative index would silently read from the
        # end of the sequence instead.
        continue

      #handle mutations that state 'X or Y'
      if 'or' in mutated:
        allowed_mutations = [option.strip() for option in mutated.split('or')]
      else:
        allowed_mutations = [mutated]

      if sequence[offset] in allowed_mutations:
        sequence_features[position] = 1

    encoded_sequences.append(tuple(sequence_features))

  return feature_names, encoded_sequences

In [None]:
#getting feature names and matrices for each sequence in df
feature_names, encoded_sequences = encode_sequences(df, known_mutations_df)

In [None]:
#storing feature matrices as new column
# `df` and df_processed are the same object, so the column is added to both names.
df['encoded sequences'] = encoded_sequences

In [None]:
#visualize feature relations in correlation matrix
# NOTE(review): this rebinds `data`, which previously held the list of known
# mutations - re-running the known-mutations cell afterwards restores it.
data = np.vstack(df['encoded sequences'].apply(np.array))

# The H3 index column is passed wrapped in a one-element list.
# NOTE(review): verify the resulting axis labels render as intended.
corr_matrix = pd.DataFrame(data, columns=[known_mutations_df['H3 index']]).corr()

plt.figure(figsize=(8, 6))
sns.heatmap(corr_matrix, annot=True, annot_kws={"size": 5}, cmap='coolwarm', fmt='.2f')
plt.title('Correlation between Mutations')
plt.show()

In [None]:
#method for splitting the data into training and testing sets based on input percentage (for the test set)
def split_data(dataframe, percentage):
    """Shuffle ``dataframe`` (fixed seed 42) and split it into train and test parts.

    Args:
        dataframe: frame to split.
        percentage: fraction of rows intended for the test set (e.g. 0.2).

    Returns:
        ``(train_data, test_data)``.
    """
    df_shuffled = dataframe.sample(frac=1, random_state=42).reset_index(drop=True)
    return jackknife_split(df_shuffled, percentage)


#method applying the split
def jackknife_split(data, percentage):
    """Split ``data`` positionally: the leading rows form the test set, the rest the training set.

    The training set receives ``int((1 - percentage) * len(data))`` rows
    (clamped to the frame length); the test set receives the remaining rows
    at the start of the frame.

    Returns:
        ``(train_data, test_data)``.
    """
    total_rows = len(data)
    train_size = min(max(int((1 - percentage) * total_rows), 0), total_rows)
    # A positive split position is used instead of negative slicing:
    # iloc[:-0] and iloc[-0:] mean "nothing" and "everything", which swapped
    # the two sets whenever the training size came out as 0.
    split_at = total_rows - train_size
    test_data = data.iloc[:split_at]
    train_data = data.iloc[split_at:]
    return train_data, test_data



In [None]:

#use SMOTE to oversample synthetic data, balancing h3, h5, and h7 training data
def smote_balance_data(h3_data, h5_data, h7_data):
  """Resample the H5 and H7 training data with SMOTE and stack them under the H3 data.

  For H5 and H7 separately, the 'encoded sequences' and 'label' columns are
  passed through ``SMOTE(random_state=42).fit_resample`` and the result is cut
  to at most ``len(h3_data)`` rows. The resampled rows get the subtype 'H5' or
  'H7' and carry only the columns 'encoded sequences', 'label' and 'subtype'.

  Returns:
    A frame with the H3 rows first, then the H5 rows, then the H7 rows
    (index reset).

  NOTE(review): the cut keeps the leading rows of the resampler output -
  confirm that this yields the intended class and subtype balance.
  """
  target_count = len(h3_data)
  smote = SMOTE(random_state=42)

  def _resample_subtype(subtype_data, subtype_name):
    # one subtype: arrays in, resampled and truncated frame out
    features = np.array(subtype_data['encoded sequences'].tolist())
    labels = np.array(subtype_data['label'].tolist())

    features_resampled, labels_resampled = smote.fit_resample(features, labels)
    features_resampled = features_resampled[:target_count]
    labels_resampled = labels_resampled[:target_count]

    return pd.DataFrame({
        'encoded sequences': list(features_resampled),
        'label': labels_resampled,
        'subtype': [subtype_name] * len(labels_resampled)
    })

  h5_resampled = _resample_subtype(h5_data, 'H5')
  h7_resampled = _resample_subtype(h7_data, 'H7')

  return pd.concat([h3_data, h5_resampled, h7_resampled], ignore_index=True)



In [None]:
# forming the training and testing datasets
# Every H3 row goes into training; H5 and H7 are each split with a 0.2 test
# fraction, so the test set contains H5 and H7 rows only.
h3_train = df[(df['subtype'].str.contains('H3'))]
h5_train, h5_test = split_data(df[(df['subtype'].str.contains('H5'))], 0.2)
h7_train, h7_test = split_data(df[(df['subtype'].str.contains('H7'))], 0.2)

# The H5/H7 training rows are resampled before being stacked under the H3 rows.
train_data = smote_balance_data(h3_train, h5_train, h7_train)

# The test rows are used as they are (no resampling).
test_data = pd.concat([h5_test, h7_test])


# Feature matrices and label vectors. Note that y_test ends up as a plain list,
# while the other three are NumPy arrays.
x_train = np.array(train_data['encoded sequences'].tolist())
y_train = np.array(train_data["label"].tolist())
x_test = np.array(test_data['encoded sequences'].tolist())
y_test = np.array(test_data["label"]).tolist()


In [None]:
# Size of the two sets after resampling and splitting.
print("Number of training sequences: " + str(len(train_data)))
print("Number of testing sequences: " + str(len(test_data)))

In [None]:
#get names of poly features based on mutations
def get_poly_feature_names(poly, x_train, feature_names):
  """Build a table of the polynomial feature names.

  Args:
    poly: fitted polynomial-features transformer (must provide
      ``get_feature_names_out`` and ``transform``).
    x_train: 2-D training matrix that ``poly`` was fitted on.
    feature_names: names of the input features; the first
      ``x_train.shape[1]`` entries are used.

  Returns:
    A frame with one row per polynomial feature. 'name' holds the feature
    name with ' ' and '.' replaced by '&'. 'feature' holds the expanded
    values of the training sample at row index 1, computed from the arguments
    rather than read from a notebook-level variable, so the function no
    longer depends on cell execution order.

  NOTE(review): the purpose of the 'feature' column is unclear from the
  notebook, since only 'name' is read afterwards - confirm whether it can be
  dropped.
  """
  original_feature_names = [feature_names[i] for i in range(x_train.shape[1])]
  poly_feature_names = poly.get_feature_names_out(original_feature_names)
  poly_feature_names = [name.replace(' ', '&').replace('.', '&') for name in poly_feature_names]
  # Expanding only row 1 gives that row of the expanded training matrix when
  # the expansion is applied row by row, as expected for polynomial features.
  second_sample_expanded = poly.transform(x_train[1:2])[0]
  return pd.DataFrame({'feature' : second_sample_expanded, 'name': poly_feature_names})

In [None]:
#transform data
# Expand the 0/1 mutation features with all polynomial terms up to degree 3.
# The transformer is fitted on the training matrix only and then applied to
# the test matrix.
poly = PolynomialFeatures(degree=3)
x_train_poly = poly.fit_transform(x_train)
x_test_poly = poly.transform(x_test)
poly_feature_names = get_poly_feature_names(poly, x_train, feature_names)

In [None]:
#method for selection of poly features with SelectFpr at significance level alpha
def select_from_poly(poly_features, x_train_poly, y_train, alpha=0.05, score_func=f_classif):
  """Fit a ``SelectFpr`` feature selector on the expanded training data.

  Args:
    poly_features: not used by the function (kept for the call signature).
    x_train_poly: expanded training matrix.
    y_train: training labels.
    alpha: significance level handed to SelectFpr.
    score_func: univariate scoring function handed to SelectFpr.

  Returns:
    The fitted selector.

  NOTE(review): per the scikit-learn docs SelectFpr keeps features whose
  p-value is below ``alpha``, not above it as the earlier comment stated.
  """
  return SelectFpr(score_func=score_func, alpha=alpha).fit(x_train_poly, y_train)


In [None]:
#select and transform features
# NOTE: x_train and x_test are rebound to the selected-feature matrices here.
# Re-running the polynomial-expansion cell after this one would expand these
# reduced matrices instead of the original encodings; re-run the cell that
# builds x_train/x_test first.
selector = select_from_poly(poly_feature_names, x_train_poly, y_train)
x_train = selector.transform(x_train_poly)
x_test = selector.transform(x_test_poly)

In [None]:
#view selected features
# Boolean mask over the polynomial features, used to keep the matching rows of
# the name table. The last expression displays the table with a fresh index.
selected_features = selector.get_support()
selected_feature_names = poly_feature_names[selected_features]
selected_feature_names.reset_index()

In [None]:
#construct logistic regression model
# NOTE(review): every model cell below assigns its search grid to the same
# name `param_grid`, so the grid in effect is the one from whichever of these
# cells ran last. The training cell passes it together with xgb_clf.
log_reg = LogisticRegression(random_state=42)
param_grid = {
    'C': [0.001, 0.01, 0.1, 1, 10, 100],
    'penalty': ['l2'],
    'solver': ['lbfgs', 'newton-cholesky'],
    'max_iter': [1000, 2000],
    'class_weight': ['balanced']
}

In [None]:
#construct linear svc model
# NOTE(review): `svm` is also the name the import cell gives to the sklearn
# svm module; this assignment replaces it with the estimator instance.
svm = SVC(kernel='linear', probability=True,  class_weight='balanced', decision_function_shape= 'ovr', random_state= 42)
param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100],
              'tol': [0.0001, 0.001, 0.01],
              'max_iter': [5000],
             }


In [None]:
#construct random forest model
rf = RandomForestClassifier(random_state=42, class_weight = 'balanced')

param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [3, 5, 7],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
}


In [None]:
#construct xgboost model
xgb_clf = xgb.XGBClassifier(objective='binary:logistic', random_state=42)

# Most entries hold a single value, so the search effectively varies only
# booster, max_depth and gamma.
param_grid = {
    'booster': ['gbtree', 'dart'],
    'n_jobs': [-1],
    'n_estimators': [200],
    'max_depth': [5, 7],
    'learning_rate': [0.05],
    'subsample': [0.7],
    'colsample_bytree': [0.7],
    'min_child_weight': [3],
    'gamma': [0.1, 0.2]
}


In [None]:
#method to train and evaluate given model using gridsearch
def train_and_evaluate_model(selected_model, param_grid, x_train, y_train, x_test, y_test):
  """Tune a model with a 5-fold grid search and predict on the test data.

  The search is scored with 'f1_weighted'. The best cross-validation score is
  printed. ``y_test`` is accepted but not used here.

  Returns:
    ``(model, y_pred, y_probs)``: the best estimator found, its predictions
    for ``x_test``, and the second column of its ``predict_proba`` output for
    ``x_test``.
  """
  search = GridSearchCV(selected_model, param_grid, cv=5, scoring='f1_weighted')
  search.fit(x_train, y_train)

  best_model = search.best_estimator_
  predictions = best_model.predict(x_test)
  # column 1 of predict_proba
  positive_probabilities = best_model.predict_proba(x_test)[:, 1]

  print(f"Best Validation Score: {search.best_score_}")

  return best_model, predictions, positive_probabilities


In [None]:
#method to print classification report of a model
def get_classification_report(y_test, y_pred):
  """Print the classification report (4 decimal digits) for the given true and predicted labels."""
  print("Classification Report:\n", classification_report(y_test, y_pred, digits=4))

In [None]:
#train desired model (inputted: xgboost)
# `param_grid` is whatever grid was assigned last; it has to be the xgboost
# grid for this call to be consistent with xgb_clf.
model, y_pred, y_probs = train_and_evaluate_model(xgb_clf, param_grid, x_train, y_train, x_test, y_test)

In [None]:
#get predictive metrics
get_classification_report(y_test, y_pred)

In [None]:
#method to plot precision-recall
def get_precision_recall_curve(y_test, y_probs):
  """Plot precision against recall for the given true labels and predicted scores."""
  curve = precision_recall_curve(y_test, y_probs)
  precision_values, recall_values, _ = curve

  # recall on the x-axis, precision on the y-axis
  plt.plot(recall_values, precision_values, marker='.')
  plt.xlabel('Recall')
  plt.ylabel('Precision')
  plt.title('Precision-Recall Curve')
  plt.show()

In [None]:
# Precision-recall curve of the trained model on the test set.
get_precision_recall_curve(y_test, y_probs)

In [None]:
#method to plot roc and calculate auc
def get_roc_curve(y_test, y_probs):
  """Print the ROC AUC and plot the ROC curve with a dashed diagonal reference line."""
  false_positive_rate, true_positive_rate, _ = roc_curve(y_test, y_probs)

  # Area under the ROC curve, shown both in the output and in the legend
  auc_value = roc_auc_score(y_test, y_probs)
  print(f"AUC: {auc_value:.2f}")

  plt.figure(figsize=(8, 6))
  plt.plot(false_positive_rate, true_positive_rate, color='blue', label=f'AUC = {auc_value:.2f}')
  # diagonal from (0, 0) to (1, 1)
  plt.plot([0, 1], [0, 1], color='red', linestyle='--')
  plt.xlabel('False Positive Rate')
  plt.ylabel('True Positive Rate')
  plt.title('ROC Curve')
  plt.legend()
  plt.show()


In [None]:
# ROC curve and AUC of the trained model on the test set.
get_roc_curve(y_test, y_probs)

In [None]:
#method to print confusion matrix showing TP, TN, FN & FP values
def get_conf_matrix(y_test, y_pred):
  """Print the four confusion-matrix counts and return the matrix.

  scikit-learn lays the binary matrix out as ``[[TN, FP], [FN, TP]]`` (rows
  are true labels, columns are predictions, label 0 first), with label 1
  (human) as the positive class.

  Returns:
    The matrix returned by ``confusion_matrix(y_test, y_pred)``.
  """
  conf_matrix = confusion_matrix(y_test, y_pred)
  true_negatives, false_positives = conf_matrix[0][0], conf_matrix[0][1]
  false_negatives, true_positives = conf_matrix[1][0], conf_matrix[1][1]

  print("Confusion Matrix:\n")
  print("True Positives: " + str(true_positives))
  print("False Positives: " + str(false_positives))
  print("False Negatives: " + str(false_negatives))
  print("True Negatives: " + str(true_negatives))
  return conf_matrix



In [None]:
# Print the confusion-matrix counts and keep the matrix for the heatmap below.
cm = get_conf_matrix(y_test, y_pred)

In [None]:
#method to plot the confusion matrix into a heatmap for better visualization
def plot_conf_matrix(cm):
  """Draw the confusion matrix as a heatmap, with every row scaled to sum to 1."""
  # Divide each row by its own total (rows correspond to the true label).
  row_totals = cm.sum(axis=1)[:, np.newaxis]
  row_normalised = cm.astype('float') / row_totals

  plt.figure(figsize=(8, 6))
  sns.heatmap(row_normalised, annot=True, fmt=".4f", cmap="Blues", cbar=False)
  plt.ylabel('True label')
  plt.xlabel('Predicted label')
  plt.title('Normalized Confusion Matrix')
  plt.show()


In [None]:
# Row-normalised heatmap of the confusion matrix computed above.
plot_conf_matrix(cm)

In [None]:
#method to get weights of a model
def get_coeffs(model):
  """Return the per-feature weights of a fitted model.

  Returns:
    ``model.coef_[0]`` when the model has a ``coef_`` attribute, otherwise
    ``model.feature_importances_`` when that exists, otherwise None.
  """
  if hasattr(model, 'coef_'):
    return model.coef_[0]
  return getattr(model, 'feature_importances_', None)

In [None]:
#print final features and their determined coefficients
# Every selected feature is printed with its weight; features whose absolute
# weight is greater than zero are also collected in features_to_investigate.
# Row i of selected_feature_names is taken to describe weight i of the model.
features_to_investigate = []
for i in range(0, len(get_coeffs(model))):
  if (abs(float(get_coeffs(model)[i])) > 0.0000):
    features_to_investigate.append(selected_feature_names.iloc[i]['name'])
  print("Feature :" +  str(selected_feature_names.iloc[i]['name']) + ", Coeff: " + f"{(get_coeffs(model)[i]):.3f}")


In [None]:
#get py3dmol view of a pdb protein
def get_protein_view(pdb_file_name, color = 'lightgrey'):
  """Load ``<pdb_directory>/<pdb_file_name>.pdb`` into a new py3Dmol view.

  Args:
    pdb_file_name: file name without the '.pdb' extension.
    color: cartoon colour applied to the whole model.

  Returns:
    The py3Dmol view with the model added and the cartoon style set.
  """
  pdb_path = os.path.join(pdb_directory, pdb_file_name + ".pdb")
  with open(pdb_path, 'r') as handle:
    pdb_text = handle.read()

  viewer = py3Dmol.view()
  viewer.addModel(pdb_text, 'pdb')
  viewer.setStyle({'cartoon': {'color': color}})
  return viewer

In [None]:
def normalize(value, min_value, max_value):
    """Scale ``value`` linearly so that ``min_value`` maps to 0 and ``max_value`` to 1.

    When both bounds are equal the range -1..1 is used instead, which avoids a
    division by zero.
    """
    if max_value == min_value:
        lower, upper = -1, 1
    else:
        lower, upper = min_value, max_value
    return (value - lower) / (upper - lower)

In [None]:
#method to visualize antigenic regions on a h3n2 human protein
def highlight_antigenic_regions(view):
  """Colour the antigenic sites A-F on chain A of the given view.

  Each site has one colour and one or more inclusive residue ranges; every
  range is added as its own cartoon style.

  Returns:
    The same view object.
  """
  # (site, colour, inclusive residue ranges)
  antigenic_sites = (
      ("A", "purple", [(122, 130), (154, 160)]),
      ("B", "black", [(155, 165), (189, 196)]),
      ("C", "green", [(50, 58)]),
      ("D", "yellow", [(160, 167), (200, 207)]),
      ("E", "magenta", [(77, 83)]),
      ("F", "orange", [(220, 230)]),
  )

  for _site, site_colour, residue_ranges in antigenic_sites:
    for first_res, last_res in residue_ranges:
      residues = list(range(first_res, last_res + 1))
      view.addStyle({'chain': 'A', 'resi': residues},
                    {'cartoon': {'color': site_colour}})
  return view


In [None]:
#method to visualize mutation residues
def visualize_mutations(view, mutations):
    """Show each listed residue as a pink stick and zoom the view after each one.

    Args:
        view: py3Dmol-style view.
        mutations: iterable of residue numbers.

    Returns:
        The same view object.
    """
    for residue_number in mutations:
        residue_selection = {'resi': [residue_number]}
        stick_style = {'stick': {'color': 'pink'}}
        view.addStyle(residue_selection, stick_style)
        view.zoomTo()
    return view


In [None]:
#method to highlight HA1, HA2, and RBS regions of a h3n2 human HA protein
def highlight_regions(view):
  """Colour the HA1 (blue), HA2 (red) and RBS (green) residue ranges on chain A.

  Styles are added in that order, so the RBS ranges are added after the HA1
  range that contains them.

  Returns:
    The same view object.
  """
  # (colour, inclusive residue range); the RBS ranges are example values
  coloured_ranges = [
      ('blue', (1, 329)),      # HA1
      ('red', (330, 566)),     # HA2
      ('green', (98, 100)),    # RBS
      ('green', (130, 133)),   # RBS
      ('green', (220, 223)),   # RBS
  ]

  for colour, (first_res, last_res) in coloured_ranges:
    view.addStyle({'chain': 'A', 'resi': list(range(first_res, last_res + 1))},
                  {'cartoon': {'color': colour}})

  return view


In [None]:
#layer all viewings together or visualize one at a time by commenting lines out
# Each helper returns the view it was given, so `view` and `view_linear` refer
# to the same object and the styles accumulate on it.
view = get_protein_view("h3n2_human")
view_linear = visualize_mutations(view, [192, 226, 227])
view_linear = highlight_antigenic_regions(view_linear)
view_linear = highlight_regions(view_linear)
view_linear.show()

In [None]:

#method to visualize data overview
def plot_subtype_distribution(df, title):
  """
  Plots the subtype distribution for a given DataFrame, colored by host.

  Rows are counted per (subtype, label) pair, converted to a percentage of
  each subtype's total, and drawn as one stacked bar per subtype.

  Args:
    df: DataFrame with a 'subtype' column and a 'label' column.
    title: Title shown above the chart.

  Returns:
    None. The figure is displayed with plt.show().
  """
  subtype_counts = df.groupby(['subtype', 'label'])['label'].count().unstack().fillna(0)
  subtype_counts_percent = subtype_counts.div(subtype_counts.sum(axis=1), axis=0) * 100

  # Create the axes explicitly and hand them to pandas. Without ax=, DataFrame.plot
  # opens its own figure, so a preceding plt.figure(figsize=...) is left empty and
  # the requested size is never applied to the chart.
  _, ax = plt.subplots(figsize=(12, 6))
  subtype_counts_percent.plot(kind='bar', stacked=True, color=['skyblue', 'salmon'], ax=ax)

  ax.set_title(title)
  ax.set_xlabel('Subtype')
  ax.set_ylabel('Percentage')
  plt.xticks(rotation=45, ha='right')
  # NOTE(review): the legend text assumes the two 'label' values sort as
  # Avian first, Human second - confirm against the label encoding.
  ax.legend(title='Host', labels=['Avian', 'Human'])

  # Annotate bars with percentages
  for p in ax.patches:
    width = p.get_width()
    height = p.get_height()
    if height == 0:
      # An empty segment has no area to label; skipping avoids a stray "0.0%"
      # printed on the edge of the neighbouring segment.
      continue
    x, y = p.get_xy()
    ax.annotate(f'{height:.1f}%', (x + width / 2, y + height / 2), ha='center', va='center')

  plt.show()

# Training data: reduce each subtype string to its first two characters, then plot
train_subtype_prefix = train_data['subtype'].str[:2]
plot_subtype_distribution(
    train_data.assign(subtype=train_subtype_prefix),
    'Host Distribution in Training Data (Per Subtype)',
)

# Test data: same reduction, same plot
test_subtype_prefix = test_data['subtype'].str[:2]
plot_subtype_distribution(
    test_data.assign(subtype=test_subtype_prefix),
    'Host Distribution in Test Data (Per Subtype)',
)


In [None]:
# Raw per-model feature weights, keyed by model name.
# Each value is a multi-line string of "Feature :<name>, Coeff: <value>" lines;
# parse_model_results() extracts (model, feature, coefficient) rows from them and
# uses the dict keys verbatim as the 'Model' column.
# NOTE(review): the text looks like pasted printed output from earlier training
# cells, and "Coeff" for the tree-based models is presumably a feature importance
# rather than a signed coefficient - confirm against the cells that produced it.
model_results = {
    "Logistic Regression": """
        Feature :H_to_Y_at_:110, Coeff: -0.013
Feature :I_to_T_at_:155, Coeff: 0.018
Feature :T_to_I_at_:192, Coeff: 0.243
Feature :Q_to_L_at_:226, Coeff: 0.049
Feature :S_to_N_at_:227, Coeff: 0.036
Feature :G_to_S_at_:228, Coeff: 0.004
Feature :H_to_Y_at_:110^2, Coeff: -0.013
Feature :H_to_Y_at_:110&N_to_K_at_:186, Coeff: -0.012
Feature :H_to_Y_at_:110&T_to_I_at_:192, Coeff: -0.013
Feature :I_to_T_at_:155^2, Coeff: 0.018
Feature :I_to_T_at_:155&Q_to_L_at_:226, Coeff: 0.018
Feature :T_to_I_at_:192^2, Coeff: 0.243
Feature :Q_to_L_at_:226^2, Coeff: 0.049
Feature :Q_to_L_at_:226&P_to_S_at_:239, Coeff: -0.002
Feature :S_to_N_at_:227^2, Coeff: 0.036
Feature :G_to_S_at_:228^2, Coeff: 0.004
Feature :H_to_Y_at_:110^3, Coeff: -0.013
Feature :H_to_Y_at_:110^2&N_to_K_at_:186, Coeff: -0.012
Feature :H_to_Y_at_:110^2&T_to_I_at_:192, Coeff: -0.013
Feature :H_to_Y_at_:110&N_to_K_at_:186^2, Coeff: -0.012
Feature :H_to_Y_at_:110&N_to_K_at_:186&T_to_I_at_:192, Coeff: -0.012
Feature :H_to_Y_at_:110&T_to_I_at_:192^2, Coeff: -0.013
Feature :I_to_T_at_:155^3, Coeff: 0.018
Feature :I_to_T_at_:155^2&Q_to_L_at_:226, Coeff: 0.018
Feature :I_to_T_at_:155&Q_to_L_at_:226^2, Coeff: 0.018
Feature :T_to_I_at_:192^3, Coeff: 0.243
Feature :Q_to_L_at_:226^3, Coeff: 0.049
Feature :Q_to_L_at_:226^2&P_to_S_at_:239, Coeff: -0.002
Feature :Q_to_L_at_:226&P_to_S_at_:239^2, Coeff: -0.002
Feature :S_to_N_at_:227^3, Coeff: 0.036
Feature :G_to_S_at_:228^3, Coeff: 0.004
    """,
    "SVC with Linear Kernel": """
       Feature :H_to_Y_at_:110, Coeff: 0.000
Feature :I_to_T_at_:155, Coeff: 0.000
Feature :T_to_I_at_:192, Coeff: 0.667
Feature :Q_to_L_at_:226, Coeff: -0.222
Feature :S_to_N_at_:227, Coeff: -0.444
Feature :G_to_S_at_:228, Coeff: 0.173
Feature :H_to_Y_at_:110^2, Coeff: 0.000
Feature :H_to_Y_at_:110&N_to_K_at_:186, Coeff: 0.000
Feature :H_to_Y_at_:110&T_to_I_at_:192, Coeff: 0.000
Feature :I_to_T_at_:155^2, Coeff: 0.000
Feature :I_to_T_at_:155&Q_to_L_at_:226, Coeff: 0.000
Feature :T_to_I_at_:192^2, Coeff: 0.667
Feature :Q_to_L_at_:226^2, Coeff: -0.222
Feature :Q_to_L_at_:226&P_to_S_at_:239, Coeff: -0.222
Feature :S_to_N_at_:227^2, Coeff: -0.444
Feature :G_to_S_at_:228^2, Coeff: 0.173
Feature :H_to_Y_at_:110^3, Coeff: 0.000
Feature :H_to_Y_at_:110^2&N_to_K_at_:186, Coeff: 0.000
Feature :H_to_Y_at_:110^2&T_to_I_at_:192, Coeff: 0.000
Feature :H_to_Y_at_:110&N_to_K_at_:186^2, Coeff: 0.000
Feature :H_to_Y_at_:110&N_to_K_at_:186&T_to_I_at_:192, Coeff: 0.000
Feature :H_to_Y_at_:110&T_to_I_at_:192^2, Coeff: 0.000
Feature :I_to_T_at_:155^3, Coeff: 0.000
Feature :I_to_T_at_:155^2&Q_to_L_at_:226, Coeff: 0.000
Feature :I_to_T_at_:155&Q_to_L_at_:226^2, Coeff: 0.000
Feature :T_to_I_at_:192^3, Coeff: 0.667
Feature :Q_to_L_at_:226^3, Coeff: -0.222
Feature :Q_to_L_at_:226^2&P_to_S_at_:239, Coeff: -0.222
Feature :Q_to_L_at_:226&P_to_S_at_:239^2, Coeff: -0.222
Feature :S_to_N_at_:227^3, Coeff: -0.444
Feature :G_to_S_at_:228^3, Coeff: 0.173
    """,
    "Random Forest": """
        Feature :H_to_Y_at_:110, Coeff: 0.006
Feature :I_to_T_at_:155, Coeff: 0.007
Feature :T_to_I_at_:192, Coeff: 0.189
Feature :Q_to_L_at_:226, Coeff: 0.042
Feature :S_to_N_at_:227, Coeff: 0.058
Feature :G_to_S_at_:228, Coeff: 0.005
Feature :H_to_Y_at_:110^2, Coeff: 0.004
Feature :H_to_Y_at_:110&N_to_K_at_:186, Coeff: 0.002
Feature :H_to_Y_at_:110&T_to_I_at_:192, Coeff: 0.004
Feature :I_to_T_at_:155^2, Coeff: 0.008
Feature :I_to_T_at_:155&Q_to_L_at_:226, Coeff: 0.006
Feature :T_to_I_at_:192^2, Coeff: 0.227
Feature :Q_to_L_at_:226^2, Coeff: 0.038
Feature :Q_to_L_at_:226&P_to_S_at_:239, Coeff: 0.001
Feature :S_to_N_at_:227^2, Coeff: 0.054
Feature :G_to_S_at_:228^2, Coeff: 0.006
Feature :H_to_Y_at_:110^3, Coeff: 0.004
Feature :H_to_Y_at_:110^2&N_to_K_at_:186, Coeff: 0.001
Feature :H_to_Y_at_:110^2&T_to_I_at_:192, Coeff: 0.008
Feature :H_to_Y_at_:110&N_to_K_at_:186^2, Coeff: 0.002
Feature :H_to_Y_at_:110&N_to_K_at_:186&T_to_I_at_:192, Coeff: 0.001
Feature :H_to_Y_at_:110&T_to_I_at_:192^2, Coeff: 0.004
Feature :I_to_T_at_:155^3, Coeff: 0.008
Feature :I_to_T_at_:155^2&Q_to_L_at_:226, Coeff: 0.011
Feature :I_to_T_at_:155&Q_to_L_at_:226^2, Coeff: 0.005
Feature :T_to_I_at_:192^3, Coeff: 0.194
Feature :Q_to_L_at_:226^3, Coeff: 0.040
Feature :Q_to_L_at_:226^2&P_to_S_at_:239, Coeff: 0.001
Feature :Q_to_L_at_:226&P_to_S_at_:239^2, Coeff: 0.002
Feature :S_to_N_at_:227^3, Coeff: 0.056
Feature :G_to_S_at_:228^3, Coeff: 0.005
    """,

    "XGBoost": """
    Feature :H_to_Y_at_:110, Coeff: 0.052
Feature :I_to_T_at_:155, Coeff: 0.041
Feature :T_to_I_at_:192, Coeff: 0.142
Feature :Q_to_L_at_:226, Coeff: 0.040
Feature :S_to_N_at_:227, Coeff: 0.117
Feature :G_to_S_at_:228, Coeff: 0.000
Feature :H_to_Y_at_:110^2, Coeff: 0.018
Feature :H_to_Y_at_:110&N_to_K_at_:186, Coeff: 0.000
Feature :H_to_Y_at_:110&T_to_I_at_:192, Coeff: 0.047
Feature :I_to_T_at_:155^2, Coeff: 0.028
Feature :I_to_T_at_:155&Q_to_L_at_:226, Coeff: 0.045
Feature :T_to_I_at_:192^2, Coeff: 0.200
Feature :Q_to_L_at_:226^2, Coeff: 0.036
Feature :Q_to_L_at_:226&P_to_S_at_:239, Coeff: 0.000
Feature :S_to_N_at_:227^2, Coeff: 0.058
Feature :G_to_S_at_:228^2, Coeff: 0.000
Feature :H_to_Y_at_:110^3, Coeff: 0.000
Feature :H_to_Y_at_:110^2&N_to_K_at_:186, Coeff: 0.000
Feature :H_to_Y_at_:110^2&T_to_I_at_:192, Coeff: 0.000
Feature :H_to_Y_at_:110&N_to_K_at_:186^2, Coeff: 0.000
Feature :H_to_Y_at_:110&N_to_K_at_:186&T_to_I_at_:192, Coeff: 0.000
Feature :H_to_Y_at_:110&T_to_I_at_:192^2, Coeff: 0.000
Feature :I_to_T_at_:155^3, Coeff: 0.030
Feature :I_to_T_at_:155^2&Q_to_L_at_:226, Coeff: 0.000
Feature :I_to_T_at_:155&Q_to_L_at_:226^2, Coeff: 0.000
Feature :T_to_I_at_:192^3, Coeff: 0.060
Feature :Q_to_L_at_:226^3, Coeff: 0.045
Feature :Q_to_L_at_:226^2&P_to_S_at_:239, Coeff: 0.000
Feature :Q_to_L_at_:226&P_to_S_at_:239^2, Coeff: 0.000
Feature :S_to_N_at_:227^3, Coeff: 0.042
Feature :G_to_S_at_:228^3, Coeff: 0.000
"""

}

# Function to parse the model results
def parse_model_results(model_results):
    """Turn raw per-model result text into a long-format DataFrame.

    Args:
        model_results: Mapping of model name to a multi-line string whose
            lines look like ``Feature :<name>, Coeff: <value>``. Lines that
            do not match are ignored.

    Returns:
        DataFrame with columns ``Model``, ``Feature`` and ``Coefficient``,
        one row per matching line whose coefficient is non-zero, in input order.
    """
    line_pattern = re.compile(r'Feature :(.+), Coeff: (.+)')
    rows = []

    for model_name, raw_text in model_results.items():
        for raw_line in raw_text.strip().split('\n'):
            found = line_pattern.search(raw_line)
            if found is None:
                continue

            name_text, value_text = found.group(1), found.group(2)
            coefficient = float(value_text.strip())

            # Zero-weight features carry no information for the comparison.
            if coefficient == 0:
                continue

            rows.append((model_name, name_text.strip(), coefficient))

    return pd.DataFrame(rows, columns=["Model", "Feature", "Coefficient"])

# Build the long-format importance table from the raw text above:
# one row per (model, feature) with a non-zero coefficient,
# in columns "Model", "Feature" and "Coefficient".
importance_df = parse_model_results(model_results)


In [None]:
# Compare feature importances for linear and tree-based models on a common scale.
# Linear models report signed coefficients, so their magnitude (absolute value) is used;
# the other models' values are used as-is. Within each model the values are then
# divided by that model's total, so each model's 'comparison' column sums to 1.

# model_results is keyed "SVC with Linear Kernel"; the older "SVM ..." spelling is kept
# as an alias so the linear SVM is recognised whichever label importance_df carries.
linear_model_names = ['Logistic Regression', 'SVC with Linear Kernel', 'SVM with Linear Kernel']

is_linear_model = importance_df['Model'].isin(linear_model_names)
importance_magnitude = importance_df['Coefficient'].where(
    ~is_linear_model, importance_df['Coefficient'].abs()
)

# Group by model name rather than by row position, so the result does not depend on
# importance_df having a contiguous default index.
per_model_total = importance_magnitude.groupby(importance_df['Model']).transform('sum')

# Add the new comparison column to the DataFrame
importance_df['comparison'] = importance_magnitude / per_model_total

# Kept as a plain list for any later cell that still reads it.
comparison_values = importance_df['comparison'].tolist()


In [None]:
# Visualize the normalized importances in importance_df for each model:
# one group of bars per feature, one bar per model.
plt.figure(figsize=(12, 6))
sns.barplot(x='Feature', y='comparison', hue='Model', data=importance_df)
plt.xticks(rotation=90)  # feature names are long; rotate so they do not overlap
plt.title('Feature Importance Comparison Across Models')
plt.xlabel('Feature')
plt.ylabel('Importance (Normalized)')
plt.tight_layout()
plt.show()
