In [1]:
import pandas as pd
import numpy as np

# Load the CoV-AbDab export and keep conventional antibodies ('Ab') whose heavy- and
# light-chain entries are both different from the 'ND' placeholder
# (presumably "not determined" -- confirm against the database documentation).
full_df = pd.read_csv('CoV-AbDab_080224.csv')

has_heavy_chain = full_df['VHorVHH'] != 'ND'
has_light_chain = full_df['VL'] != 'ND'
is_antibody = full_df['Ab or Nb'] == 'Ab'

# All conditions are combined into a single mask. Previously the light-chain filter was
# applied to `full_df` again, which silently threw away the heavy-chain filter.
ab_df = full_df.loc[
    has_heavy_chain & has_light_chain & is_antibody,
    ['Neutralising Vs', 'Not Neutralising Vs', 'VHorVHH', 'VL', 'Origin', 'Protein + Epitope', 'Name'],
]

# Empty strings instead of NaN so that the later `.split(';')` calls work on every row.
ab_df = ab_df.fillna('')

# The old index is not dropped, so the original row number is kept as an 'index' column.
ab_df = ab_df.reset_index()

In [74]:
# Values of the 'Origin' column that are retained for the analysis.
sources_to_keep = [
    'B-cells; SARS-CoV2_WT Convalescent Patient (Unvaccinated)',
    'B-cells; SARS-CoV1 Human Patient; SARS-CoV2 Vaccinee',
    'B-cells; SARS-CoV2_WT Convalescent Patients',
    'B-cells; SARS-CoV2_WT Vaccinee (BBIBP-CoV)',
    'B-cells; SARS-CoV2_WT Vaccinee',
    'B-cells; SARS-CoV2_WT Human Patient',
    'B-cells; Unvaccinated SARS-CoV2_WT Human Patient',
    'B-cells; SARS-CoV2_Gamma Human Patient',
    'B-cells; SARS-CoV1 Human Patient',
    'B-cells (SARS-CoV2_Beta Human Patient)',
]

# Values of the 'Protein + Epitope' column that are retained (spelling variants included).
binding_to_keep = ['S; RBD', 'S: RBD', 'S; RBD/NTD']

In [75]:
# Keep antibodies whose origin AND annotated target are both on the allow-lists.
origin_selected = ab_df['Origin'].isin(sources_to_keep)
epitope_selected = ab_df['Protein + Epitope'].isin(binding_to_keep)
source_df = ab_df[origin_selected & epitope_selected]

In [76]:
# Show how many retained rows carry each 'Protein + Epitope' value.
source_df['Protein + Epitope'].value_counts()

S; RBD        2207
S: RBD           1
S; RBD/NTD       1
Name: Protein + Epitope, dtype: int64

In [77]:
# Parallel column lists; each position describes one (antibody, antigen) pair.
heavy_chains, light_chains, antigens = [], [], []
target, names, origins = [], [], []

for _, ab_row in source_df.iterrows():
    # Both fields are ';'-separated antigen lists. Antigens from 'Neutralising Vs' get
    # target 1, antigens from 'Not Neutralising Vs' get target 0 (in that order per row).
    labelled_antigen_lists = (
        (ab_row['Neutralising Vs'].split(';'), 1),
        (ab_row['Not Neutralising Vs'].split(';'), 0),
    )

    for antigen_list, label in labelled_antigen_lists:
        for antigen in antigen_list:
            if antigen == '':
                # Empty field (or stray separator): nothing to record.
                continue
            heavy_chains.append(ab_row['VHorVHH'])
            light_chains.append(ab_row['VL'])
            antigens.append(antigen)
            target.append(label)
            names.append(ab_row['Name'])
            origins.append(ab_row['Origin'])

interaction_df = pd.DataFrame(
    {
        'names': names,
        'origins': origins,
        'heavy': heavy_chains,
        'light': light_chains,
        'antigens': antigens,
        'target': target,
    }
)

In [78]:
# Canonical variant group -> every raw antigen spelling found in the data that is folded
# into it (weak-binding annotations, typos and sub-lineages included).
grouping_dict = {
    'SARS-CoV2_WT': [
        'SARS-CoV2_WT',
        'SARS-CoV2_WT (weak)',
        'SARS-CoV2_WT and SARS-CoV1',
        'SARS-CoV2_WT, SARS-CoV2_Delta',
        'SARS-CoV2_WT_Delta (weak)',
        'SARS-CoV2_WT (weak) , SARS-CoV2_Delta (weak)',
    ],
    'SARS-CoV2_Alpha': ['SARS-CoV2_Alpha', 'SARS-CoV2_Alpha (weak)'],
    'SARS-CoV2_Beta': ['SARS-CoV2_Beta', 'SARS-CoV2_Beta (weak)', 'SRAS-CoV2_Beta'],
    'SARS-CoV2_Delta': ['SARS-CoV2_Delta', 'SARS-CoV2_Delta (weak)'],
    'SARS-CoV2_Epsilon': ['SARS-CoV2_Epsilon', 'SARS-CoV2_Epsilon (weak)'],
    'SARS-CoV2_Gamma': ['SARS_CoV2_Gamma', 'SARS-CoV2_Gamma', 'SARS-CoV2_Gamma (weak)'],
    'SARS-CoV2_Eta': ['SARS-CoV2_Eta', 'SARS-CoV2_Eta (weak)'],
    'SARS-CoV2_Iota': ['SARS-CoV2_Iota', 'SARS-CoV2_Iota (weak)'],
    'SARS-CoV2_Lambda': ['SARS-CoV2_Lambda', 'SARS-CoV2_Lambda (weak)'],
    'SARS-CoV2_Kappa': ['SARS-CoV2_Kappa', 'SARS-CoV2_Kappa (weak)'],
    'SARS-CoV2_Omicron-BA1': [
        'SARS-CoV2_Omicron-BA1',
        'SARS-CoV2_Omicron-BA1 (weak)',
        'SARS-CoV2_Omicron_BA1',
        'SARS-CoV2_Omicron_BA1.1',
        'SARS-CoV2_Omicron-BA1.1 (weak)',
    ],
    'SARS-CoV2_Omicron-BA2': [
        'SARS-CoV2_Omicron-BA2',
        'SARS-CoV2_Omicron-BA2 (weak)',
        'SARS_COV2_Omicron-BA2',
        'SARS-CoV2_Omicron-BA2.12.1',
        'SARS-CoV2_Omicron-BA2.12.1 (weak)',
        'SARS-CoV2_Omicron-BA2.75',
        'SARS-CoV2_Omicron-BA2.38',
        'SARS-CoV2_Omicron-BA2.38 (weak)',
        'SARS-CoV2_Omicron-BA2.75.1',
        'SARS-CoV2_Omicron-BA2.75.5',
        'SARS-CoV2_Omicron-BA2.75.5 (weak)',
    ],
    'SARS-CoV2_Omicron-BA4': [
        'SARS-CoV2_Omicron-BA4',
        'SARS-CoV2_Omicron-BA4 (weak)',
        'SARS-CoV2_Omicron-BA4/BA',
        'SARS-CoV2_Omicron-BA4.6',
        'SARS-CoV2_Omicron-BA4.6 (weak)',
        'SARS-CoV2_Omicron-BA4.7',
        'SARS-CoV2_Omicron-BA4.7 (weak)',
    ],
    'SARS-CoV2_Omicron-BA5': [
        'SARS-CoV2_Omicron-BA5',
        'SARS-CoV2_Omicron-BA5 (weak)',
        'SARS-CoV2_Omicron-BA5.9',
        'SARS-CoV2_Omicron-BA5.9 (weak)',
    ],
    'SARS-CoV2_Omicron-XBB': ['SARS-CoV2_Omicron-XBB'],
}

# Variant groups held out as the test set.
test_groups = [
    'SARS-CoV2_Omicron-BA1',
    'SARS-CoV2_Omicron-BA2',
    'SARS-CoV2_Omicron-BA4',
    'SARS-CoV2_Omicron-BA5',
]

# Inverse lookup: raw antigen spelling -> canonical variant group.
reversed_dict = {
    raw_name: canonical_group
    for canonical_group, raw_names in grouping_dict.items()
    for raw_name in raw_names
}

In [79]:
def _canonical_group(raw_antigen):
    """Return the canonical variant group for a raw antigen label, or pd.NA if unmapped."""
    return reversed_dict.get(raw_antigen, pd.NA)


interaction_df['groups'] = interaction_df['antigens'].map(_canonical_group)

In [80]:
# Discard antigen labels that could not be assigned to a variant group. The check is
# restricted to 'groups' so a missing value in an unrelated column can never silently
# remove a row here.
interaction_df = interaction_df.dropna(subset=['groups'])

In [81]:
from Bio import SeqIO

# Read every FASTA record of covid_variants.fasta into an in-memory list.
records = list(SeqIO.parse("covid_variants.fasta", "fasta"))

In [82]:
# FASTA record id -> sequence as a plain string.
variant_dict = dict((fasta_record.id, str(fasta_record.seq)) for fasta_record in records)

In [83]:
# Attach the sequence of each row's variant group; a group with no entry in
# variant_dict raises KeyError.
interaction_df['covid_seq'] = [variant_dict[group] for group in interaction_df['groups']]

In [84]:
# Keep a single row per (antibody name, variant group) pair; pandas keeps the first
# occurrence by default.
# NOTE(review): rows are appended neutralising-first when the table is built, so if an
# antibody has both a positive and a negative label inside one group the positive row
# presumably wins -- confirm this is the intended tie-break.
interaction_df = interaction_df.drop_duplicates(subset=['names', 'groups'])

In [85]:
# Rows whose variant group is held out go to the test split, everything else to train.
in_test_group = interaction_df['groups'].isin(test_groups)
train_interaction_df = interaction_df[~in_test_group]
test_interaction_df = interaction_df[in_test_group]

In [86]:
# Write the training split to disk (the DataFrame index is written as the first column).
train_interaction_df.to_csv('processed_data_train.csv')

In [87]:
# Write the held-out split to disk; it is read back with index_col=0 further down.
test_interaction_df.to_csv('processed_data_test.csv')

In [2]:
# Load the array stored in best_embs_full.npy; below it is used as the per-row
# prediction score.
# NOTE(review): presumably one model output per row of processed_data_test.csv, in the
# same row order -- confirm against the code that produced the file.
test_embs = np.load('best_embs_full.npy')

In [3]:
# Reload the held-out split; the first CSV column is restored as the index.
test_interaction_df = pd.read_csv('processed_data_test.csv', index_col=0)

In [4]:
# Store the loaded scores next to the rows they are evaluated against.
# NOTE(review): assumes test_embs has exactly one value per row, in row order -- confirm.
test_interaction_df['prob'] = test_embs

In [12]:
# Persist the test rows together with their 'prob' scores.
test_interaction_df.to_csv('covid_results.csv')

In [5]:
# Pull the evaluation columns out of the test frame as plain Python lists.
targets, labels, names, origins = (
    test_interaction_df[column].tolist()
    for column in ('target', 'groups', 'names', 'origins')
)

In [6]:
from sklearn.metrics import precision_recall_curve, auc

def calculate_auprc_per_label(y_true, y_pred, label_types):
    """Compute the area under the precision-recall curve separately for every label.

    Args:
        y_true: Iterable of ground-truth values, one per sample.
        y_pred: Iterable of prediction scores, one per sample.
        label_types: Iterable of group labels, one per sample; samples sharing a label
            are evaluated together.

    Returns:
        Dict mapping each label with more than 3 samples to a tuple
        ``(auprc, counts)``. ``auprc`` is the trapezoidal area under the
        precision-recall curve and ``counts`` is the count array returned by
        ``np.unique(..., return_counts=True)`` for that label's true values.
        Labels with 3 or fewer samples are omitted.
    """
    # label -> (true values, predicted scores), in order of first appearance
    samples_by_label = {}
    for label, truth, score in zip(label_types, y_true, y_pred):
        truths, scores = samples_by_label.setdefault(label, ([], []))
        truths.append(truth)
        scores.append(score)

    auprc_per_label = {}
    for label, (truths, scores) in samples_by_label.items():
        if len(truths) <= 3:
            # Too few samples to evaluate this label.
            continue
        precision, recall, _ = precision_recall_curve(truths, scores)
        area = auc(recall, precision)
        value_counts = np.unique(truths, return_counts=True)[1]
        auprc_per_label[label] = (area, value_counts)
    return auprc_per_label


import math
import pandas as pd
import numpy as np
# from functools import lru_cache
from typing import List, Tuple, Union, Dict, Any

# import torch
# import torch.nn as nn
# import torch.nn.functional as F

from sklearn.metrics import (
    # accuracy_score, 
    precision_recall_curve, 
    # precision_score, 
    # recall_score, 
    average_precision_score, 
    roc_auc_score,
    cohen_kappa_score,
    # f1_score,
    # fbeta_score,
    # top_k_accuracy_score,
    matthews_corrcoef,
    confusion_matrix,
)


def fmax_score(ys: np.ndarray, preds: np.ndarray, beta = 1.0, pos_label = 1):
    """Return the maximum F-beta score over all decision thresholds, and that threshold.

    Radivojac, P. et al. (2013). A Large-Scale Evaluation of Computational Protein Function Prediction. Nature Methods, 10(3), 221-227.

    Args:
        ys: Ground-truth binary labels.
        preds: Prediction scores, same length as ``ys``.
        beta: Weight of recall relative to precision in the F-beta score.
        pos_label: Label value treated as the positive class.

    Returns:
        Tuple ``(fmax, threshold)``: the best F-beta value and the score threshold at
        which it is reached.
    """
    # The scores are passed positionally: the keyword for this argument was renamed
    # from `probas_pred` to `y_score` in newer scikit-learn releases, so either keyword
    # breaks on some versions while the positional form works on all of them.
    precision, recall, thresholds = precision_recall_curve(ys, preds, pos_label=pos_label)

    numerator = (1 + beta**2) * (precision * recall)
    denominator = ((beta**2 * precision) + recall)
    with np.errstate(divide='ignore', invalid='ignore'):
        fbeta = np.divide(numerator, denominator, out=np.zeros_like(numerator), where=(denominator!=0))

    # precision/recall hold one more entry than `thresholds`: the final
    # (precision=1, recall=0) point has no threshold and always scores 0. Dropping it
    # keeps every index valid for `thresholds`. NaNs (undefined precision/recall) are
    # scored 0 so that the maximum and its position always refer to the same entry.
    fbeta = np.nan_to_num(fbeta[:-1], nan=0.0)
    best_idx = int(np.argmax(fbeta))
    return fbeta[best_idx], thresholds[best_idx]


def precision_recall_at_k(y: np.ndarray, preds: np.ndarray, k: int, names: np.ndarray = None):
    """Compute recall@k, precision@k and AP@k for binary classification.

    Args:
        y: Ground-truth labels; must have the same shape as ``preds``.
        preds: Prediction scores.
        k: Number of top-scored samples to consider; must be positive.
        names: Optional array of per-sample names, reordered alongside the scores.

    Returns:
        Tuple ``(recall_k, precision_k, ap_k, (sorted_y, sorted_preds, sorted_names))``.
        The three metrics are NaN when ``k`` exceeds the number of predictions;
        ``sorted_names`` is None when ``names`` is not given.
    """
    assert preds.shape == y.shape
    assert k > 0

    # Ordering that ranks samples from highest to lowest score
    descending = np.argsort(preds.flatten())[::-1]
    sorted_preds, sorted_y = preds[descending], y[descending]
    sorted_names = names[descending] if names is not None else None

    # The k best-scored samples
    topk_preds, topk_y = sorted_preds[:k], sorted_y[:k]

    # recall@k: share of all positives found in the top k; precision@k: share of the top k that are positive
    recall_k = np.sum(topk_y, axis=-1) / np.sum(y, axis=-1)
    precision_k = np.sum(topk_y, axis=-1) / k

    # Average precision restricted to the top k
    ap_k = average_precision_score(topk_y, topk_preds)

    # Asking for more samples than exist makes the @k metrics undefined
    if k > preds.shape[-1]:
        recall_k = precision_k = ap_k = np.nan

    return recall_k, precision_k, ap_k, (sorted_y, sorted_preds, sorted_names)


def get_metrics_binary(preds, ys, k, verbose=False, logger=None, context=None):
    """ Wrapper for getting binary classification metrics. If k is a float, then get top k*100% of predictions.

    Args:
        preds: Array of prediction scores; values are rounded to obtain hard 0/1 predictions.
        ys: Array of ground-truth labels (0/1), same shape as ``preds``.
        k: Cut-off for recall@k / precision@k / ap@k. A float in (0, 1) is interpreted as
            a fraction of the number of samples (at least one sample is always used).
        verbose: If True, emit a one-line summary of all metrics.
        logger: Logger used for the summary; when None the summary is printed.
        context: When equal to "multiclass", Cohen's kappa is added to the metrics.

    Returns:
        Tuple ``(values, names)``: a tuple of metric values and the tuple of matching
        metric names, in the same order.
    """
    if isinstance(k, float) and 0 < k < 1:
        # A small fraction of a small sample would truncate to 0 and trip the k > 0
        # assertion in precision_recall_at_k, so always keep at least one sample.
        k = max(1, int(k * ys.shape[0]))

    # Hard predictions are needed by several metrics; round once.
    hard_preds = np.round(preds)

    # labels=[0, 1] forces a 2x2 matrix even if only one class occurs in both ys and the
    # hard predictions; without it the matrix is 1x1 and the 4-way unpack fails.
    tn, fp, fn, tp = confusion_matrix(ys, hard_preds, labels=[0, 1]).ravel()

    # Any of these ratios can have a zero denominator (e.g. no positives at all); the
    # result is then nan without a RuntimeWarning.
    with np.errstate(divide='ignore', invalid='ignore'):
        specificity = np.divide(tn, (tn + fp))
        recall = np.divide(tp, (tp + fn))
        npv = np.divide(tn, (tn + fn))
        precision = np.divide(tp, (tp + fp))
        f1 = np.divide(2 * precision * recall, (precision + recall))
    accuracy = (tp + tn) / (tn + fn + tp + fp)

    fmax, _ = fmax_score(ys, preds)
    recall_k, precision_k, ap_k, _ = precision_recall_at_k(ys, preds, k)
    auroc_score = roc_auc_score(ys, preds)
    auprc_score = average_precision_score(ys, preds)
    mcc = matthews_corrcoef(ys, hard_preds)

    # NOTE: the key order is part of the contract -- callers zip names and values.
    metrics_dict = {
        "fmax": fmax,
        "mcc": mcc,
        "auroc": auroc_score,
        "auprc": auprc_score,
        "npv": npv,
        "specificity": specificity,
        "f1": f1,
        f"recall@{k}": recall_k,
        f"precision@{k}": precision_k,
        f"ap@{k}": ap_k,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
    }

    if context is not None and context == "multiclass":
        cohen_kappa = cohen_kappa_score(ys, hard_preds)
        metrics_dict["cohen_kappa"] = cohen_kappa

    if verbose:
        metrics_str = ', '.join([f'{key} = {value:.4f}' for key, value in metrics_dict.items()])
        if logger is None:
            print(metrics_str)
        else:
            logger.info(metrics_str)

    return tuple(metrics_dict.values()), tuple(metrics_dict.keys())


def get_metrics_for_indices(preds, ys, indices, k):
    """Compute the binary metrics on the subset of samples selected by ``indices``.

    Returns the ``(metric values, metric names)`` pair produced by get_metrics_binary.
    """
    subset_metrics, subset_metric_names = get_metrics_binary(
        preds[indices], ys[indices], k, verbose=False
    )
    return subset_metrics, subset_metric_names


def get_metrics(preds: np.ndarray, ys: np.ndarray, labels: np.ndarray, k: Union[int, float] = 50, task: str = 'multilabel', logger: Any = None, average: str = "macro", verbose: bool = True) -> Tuple[Dict[str, Union[np.ndarray, float]], Union[np.ndarray, int]]:
    """ Wrapper for getting classification metrics.

    Metrics returned: fmax, mcc, AUROC, AUPRC, npv, specificity, f1, recall@k,
    precision@k, ap@k, accuracy, precision, recall.
    NOTE: Cohen's kappa is only produced by get_metrics_binary when it is called with
    context="multiclass"; this wrapper does not pass a context, so that key is not part
    of the result for any task.

    Args:
        preds: Predictions from model for each sample
        ys: Indicators for each sample being positive or negative
        labels: Labels for each sample; for any task other than 'binary' the samples are
            grouped by label and each group is scored as its own binary problem
        k: Number of top predictions to consider for recall@k, precision@k, and ap@k
        task: Type of classification task; 'binary' scores all samples together, any
            other value (e.g. 'multilabel', 'multiclass') scores per label
        logger: Logger to use for printing metrics; when None the summary is printed
        average: Type of averaging to use for the per-label metrics, one of 'macro',
            'weighted', 'micro', or None (no averaging, one value per label)
        verbose: Whether to print metrics or not (no summary is produced when average is None)

    Returns:
        metrics_dict: Dictionary of metrics
        pos_samples: Number of positive samples (per label for non-binary tasks)
    """
    assert average is None or average in ["macro", "weighted", "micro"]

    if task == 'binary':
        metrics, metric_names = get_metrics_binary(preds, ys, k, verbose=verbose, logger=logger)
        pos_samples = sum(ys)

    else:
        # Group sample indices by label: sort so equal labels are contiguous, locate the
        # first position of every distinct label, and split the sorted indices there.
        # ref: https://stackoverflow.com/questions/30003068/how-to-get-a-list-of-all-indices-of-repeated-elements-in-a-numpy-array
        idx_sort = np.argsort(labels)
        sorted_labels = labels[idx_sort]
        _, idx_start = np.unique(sorted_labels, return_index=True)
        indices_grouped_list = np.split(idx_sort, idx_start[1:])

        pos_samples = np.array([sum(ys[indices]) for indices in indices_grouped_list])

        if average == 'micro':
            # Micro average: pool all samples and score them once.
            metrics, metric_names = get_metrics_binary(preds, ys, k, verbose=False)

        else:
            metrics_list = []
            for indices in indices_grouped_list:
                metrics, metric_names = get_metrics_for_indices(preds, ys, indices, k)
                metrics_list.append(metrics)

            if average == 'macro':
                metrics = np.array(metrics_list).mean(axis=0)
            elif average == 'weighted':
                # Weight every label by its number of positive samples.
                metrics = np.array(metrics_list).T @ pos_samples / pos_samples.sum()
            elif average is None:  # output label-stratified scores without averaging
                metrics = np.array(metrics_list).T  # from (n_classes, n_metrics) to (n_metrics, n_classes)

        if verbose and average is not None:
            metrics_str = ', '.join([f'{key} = {value:.4f}' for key, value in zip(metric_names, metrics)])
            # Same convention as get_metrics_binary: fall back to print when no logger
            # is supplied, otherwise verbose=True would silently output nothing.
            if logger is None:
                print(metrics_str)
            else:
                logger.info(metrics_str)

    metrics_dict = dict(zip(metric_names, metrics))
    return metrics_dict, pos_samples


In [7]:
# AUPRC per held-out variant group, with the count of each target value in the group.
calculate_auprc_per_label(targets, test_embs, labels)

{'SARS-CoV2_Omicron-BA2': (0.6606509285766826, array([203, 144])),
 'SARS-CoV2_Omicron-BA1': (0.638117715182259, array([253, 160])),
 'SARS-CoV2_Omicron-BA4': (0.7433847677946852, array([24, 19])),
 'SARS-CoV2_Omicron-BA5': (0.5383124959233341, array([192,  57]))}

In [8]:
# Full metric set per variant group; average=None returns one value per group.
get_metrics(
    np.array(test_embs),
    np.array(targets),
    np.array(labels),
    k=50,
    task='multiclass',
    average=None,
    logger=None,
    verbose=False,
)

({'fmax': array([0.64987406, 0.66129032, 0.7       , 0.58267717]),
  'mcc': array([0.06195148, 0.09040079, 0.19871523, 0.11654331]),
  'auroc': array([0.75353261, 0.73611111, 0.77192982, 0.78234649]),
  'auprc': array([0.6406075 , 0.66272152, 0.75208117, 0.54591136]),
  'npv': array([0.61407767, 0.5884058 , 0.58974359, 0.77419355]),
  'specificity': array([1.        , 1.        , 0.95833333, 1.        ]),
  'f1': array([0.01242236, 0.02739726, 0.26086957, 0.03448276]),
  'recall@50': array([0.225     , 0.22222222,        nan, 0.47368421]),
  'precision@50': array([0.72, 0.64,  nan, 0.54]),
  'ap@50': array([0.77454618, 0.84837508,        nan, 0.68229611]),
  'accuracy': array([0.61501211, 0.5907781 , 0.60465116, 0.7751004 ]),
  'precision': array([1.  , 1.  , 0.75, 1.  ]),
  'recall': array([0.00625   , 0.01388889, 0.15789474, 0.01754386])},
 array([160, 144,  19,  57]))

In [9]:
# AUPRC per antibody origin, with the count of each target value for that origin.
calculate_auprc_per_label(targets, test_embs, origins)

{'B-cells; SARS-CoV2_WT Convalescent Patient (Unvaccinated)': (0.7638888888888888,
  array([1, 3])),
 'B-cells; SARS-CoV2_WT Convalescent Patients': (0.7541249810180233,
  array([36, 48])),
 'B-cells; SARS-CoV2_WT Vaccinee (BBIBP-CoV)': (0.5212567766441261,
  array([47, 29])),
 'B-cells; SARS-CoV2_WT Vaccinee': (0.914950677987415, array([ 11, 105])),
 'B-cells; SARS-CoV2_WT Human Patient': (0.38736685044959596,
  array([497, 162])),
 'B-cells; Unvaccinated SARS-CoV2_WT Human Patient': (0.25, array([3, 1])),
 'B-cells; SARS-CoV2_Gamma Human Patient': (0.7243936431689466,
  array([37, 13])),
 'B-cells; SARS-CoV1 Human Patient': (0.7285714285714286, array([22,  4])),
 'B-cells (SARS-CoV2_Beta Human Patient)': (0.7629801664660999,
  array([18, 15]))}

In [10]:
# Full metric set per antibody origin; average=None returns one value per origin.
get_metrics(
    np.array(test_embs),
    np.array(targets),
    np.array(origins),
    k=50,
    task='multiclass',
    average=None,
    logger=None,
    verbose=False,
)

({'fmax': array([0.71428571, 0.66666667, 0.69565217, 0.85714286, 0.78504673,
         0.44517185, 0.95022624, 0.57731959, 0.66666667]),
  'mcc': array([ 0.27824334,  0.        ,  0.        ,  0.        ,  0.        ,
          0.13687686,  0.03018233, -0.09070254,  0.        ]),
  'auroc': array([0.75185185, 0.92045455, 0.88357588, 0.33333333, 0.72164352,
         0.63961547, 0.59307359, 0.63609685, 0.66666667]),
  'auprc': array([0.77003896, 0.75      , 0.7378875 , 0.80555556, 0.7588613 ,
         0.38990009, 0.91642961, 0.54309175, 0.5       ]),
  'npv': array([0.58064516, 0.84615385, 0.74      , 0.25      , 0.42857143,
         0.75877863, 0.09565217, 0.61333333, 0.75      ]),
  'specificity': array([1.       , 1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 0.9787234, 1.       ]),
  'f1': array([0.23529412,        nan,        nan,        nan,        nan,
         0.04819277, 0.01886792,        nan,        nan]),
  'recall@50': array([       nan,        na