In [1]:
import os
import sys
import json

import random
import numpy as np
import scipy as sp
import pandas as pd

from collections import Counter, defaultdict

import seaborn as sns
import matplotlib.pyplot as plt

import sklearn.metrics

In [2]:
def compute_matrices(m1, m2, num_reports):
    """Pairwise 2x2 contingency counts and association scores between two item sets.

    ``m1`` (num_reports x n1) and ``m2`` (num_reports x n2) are report-by-item
    matrices sharing the report (row) axis; entries are presumably 0/1
    indicators -- TODO confirm for prediction inputs. For each pair (i, j) of an
    m1 item and an m2 item:

        A[i, j]  reports with both i and j
        B[i, j]  reports with i but not j
        C[i, j]  reports with j but not i
        D[i, j]  reports with neither     (A + B + C + D == num_reports)

    Args:
        m1: array of shape (num_reports, n1).
        m2: array of shape (num_reports, n2).
        num_reports: total number of reports, used to derive D.

    Returns:
        dict of (n1, n2) arrays: the counts 'A', 'B', 'C', 'D', plus
        'PRR' = (A/B) / (C/D) and 'Tc' = A / (A + B + C) (Tanimoto coefficient).
        Zero B, C or D are replaced by 1 only inside the PRR denominators, and a
        pair whose union A + B + C is empty gets Tc = 0; the returned counts are
        the unclamped values.
    """
    # A boolean matmul would OR the products together instead of counting them.
    if m1.dtype == bool:
        m1 = m1.astype(np.int64)
    if m2.dtype == bool:
        m2 = m2.astype(np.int64)

    A = m1.T @ m2
    # Per-item totals: m1's as a column, m2's as a row, both broadcast against A.
    B = m1.sum(0).reshape(-1, 1) - A
    C = m2.sum(0) - A
    # Derived from the true counts so the four cells always add up to num_reports.
    D = num_reports - (A + B + C)

    # Zero-guard only the denominators; the counts themselves are returned as-is.
    B_safe = np.where(B == 0, 1, B)
    C_safe = np.where(C == 0, 1, C)
    D_safe = np.where(D == 0, 1, D)

    # NOTE(review): (A/B)/(C/D) == AD/BC is the reporting odds ratio (ROR); the
    # textbook PRR is (A/(A+B)) / (C/(C+D)). The formula is kept unchanged so
    # results stay comparable with earlier runs -- confirm which is intended.
    PRR = (A / B_safe) / (C_safe / D_safe)

    # Tanimoto: |i AND j| / |i OR j|; an empty union scores 0 rather than NaN.
    union = A + B + C
    Tc = A / np.where(union == 0, 1, union)

    return {
        'A': A,
        'B': B,
        'C': C,
        'D': D,
        'PRR': PRR,
        'Tc': Tc
    }

def build_dataframe(matrices, ordered_m1, m1_name, ordered_m2, m2_name, m12label=None, m22label=None, minA=10):
    """Flatten a dict of (n1, n2) pair matrices into one long-format DataFrame.

    Only pairs with ``matrices['A'] >= minA`` are kept, in row-major order.

    Columns, in order: ``m1_name`` (item from ``ordered_m1``), optionally
    ``f"{m1_name}_label"`` (looked up in ``m12label``, falling back to the row
    index), ``m2_name`` and optionally ``f"{m2_name}_label"`` built the same way
    from ``ordered_m2`` / ``m22label``, then one column per key of ``matrices``.
    """
    keep = matrices['A'] >= minA
    row_idx, col_idx = np.where(keep)

    columns = {}
    axes = (
        (m1_name, ordered_m1, m12label, row_idx),
        (m2_name, ordered_m2, m22label, col_idx),
    )
    for axis_name, ordered, labels, positions in axes:
        columns[axis_name] = [ordered[p] for p in positions]
        if labels is not None:
            # Unknown items fall back to their positional index.
            columns[f"{axis_name}_label"] = [labels.get(ordered[p], p) for p in positions]

    columns.update({key: values[keep] for key, values in matrices.items()})
    return pd.DataFrame(columns)

def build_merged_dataframe(drugs, reactions, indications, num_reports, drug_names, reaction_names, indication_names, minA):
    """Build pairwise association tables and a confounding-by-indication view.

    Returns a 4-tuple of DataFrames:
        drug_rxn: drug x reaction pairs with the ``compute_matrices`` statistics.
        ind_rxn: indication x reaction pairs with the same statistics.
        ind_ing: indication x drug pairs with the same statistics.
        ind_reduced: one row per (drug, reaction, indication) triple obtained by
            joining the three tables, keeping only PRR_ing_rxn (drug-reaction),
            PRR_ind_ing (indication-drug) and PRR_ind_rxn (indication-reaction).

    Each pairwise table keeps only pairs co-occurring in at least ``minA`` reports.
    """
    def pairwise(rows, row_names, row_label, cols, col_names, col_label):
        # Statistics for every (row item, column item) pair, flattened to a DataFrame.
        stats = compute_matrices(rows, cols, num_reports)
        return build_dataframe(stats, row_names, row_label, col_names, col_label, minA=minA)

    drug_rxn = pairwise(drugs, drug_names, 'drug', reactions, reaction_names, 'reaction')
    ind_rxn = pairwise(indications, indication_names, 'indication', reactions, reaction_names, 'reaction')
    ind_ing = pairwise(indications, indication_names, 'indication', drugs, drug_names, 'drug')

    # Join drug-reaction with indication-drug through the shared drug; their
    # overlapping statistic columns get _ing_rxn / _ind_ing suffixes.
    drug_ind = pd.merge(drug_rxn, ind_ing, on='drug', suffixes=('_ing_rxn', '_ind_ing'))
    # Attach indication-reaction. Its statistic columns do not clash with the
    # already-suffixed ones, so its PRR arrives unsuffixed and is renamed here.
    triples = pd.merge(drug_ind, ind_rxn, on=['indication', 'reaction'], suffixes=('', '_ind_rxn'))
    triples = triples.rename(columns={'PRR': 'PRR_ind_rxn'})

    ind_reduced = triples[['drug', 'reaction', 'indication', 'PRR_ing_rxn', 'PRR_ind_ing', 'PRR_ind_rxn']]
    return drug_rxn, ind_rxn, ind_ing, ind_reduced

In [77]:
def _load_json(path):
    """Parse the JSON file at ``path``; the file handle is closed on return."""
    with open(path) as handle:
        return json.load(handle)


def _print_accuracy(label, pred, actual, include_ones=True):
    """Print the fraction of entries where ``pred == actual``.

    With ``include_ones`` also print, on an indented "Ones" line, the fraction of
    entries with ``actual == 1`` that ``pred`` also sets to 1.
    """
    print(f"{label}: {(pred == actual).sum() / pred.size}")
    if include_ones:
        positives = actual == 1
        print(f"   Ones: {(pred[positives] == 1).sum() / positives.sum()}")


def _attach_ground_truth(pairs_df, drug_rxn_factors, drug_rxn_truth):
    """Add 'factor' and 'truth' columns to ``pairs_df`` in place, keyed by (drug, reaction)."""
    keys = list(zip(pairs_df['drug'], pairs_df['reaction']))
    pairs_df['factor'] = [drug_rxn_factors[d][r] for d, r in keys]
    pairs_df['truth'] = [drug_rxn_truth[d][r] for d, r in keys]


def _plot_diagnostics(split, ind_rxn_high, corrected, drug_rxn_obs, drug_rxn_corr, statistic):
    """Draw the 2x2 diagnostic figure used by ``evaluate_dataset``.

    Top row: drug-reaction PRR (observed, then corrected) against drug-indication
    PRR for the strongly indication-linked triples, annotated with Spearman rho.
    Bottom row: ROC and precision-recall curves of ``statistic`` for the observed
    vs corrected drug-reaction pairs.
    """
    # Both scatter panels share the observed y-range so they can be compared directly.
    ymin = 0.5 * ind_rxn_high['PRR_ing_rxn'].min()
    ymax = 2 * ind_rxn_high['PRR_ing_rxn'].max()
    scatter_panels = (
        (ind_rxn_high['PRR_ind_ing'], ind_rxn_high['PRR_ing_rxn'], 'Observed PRR(Drug, Rxn)'),
        (corrected['PRR_ind_ing_observed'], corrected['PRR_ing_rxn_corrected'], 'Corrected PRR(Drug, Rxn)'),
    )
    scored = (('Observed', drug_rxn_obs), ('Corrected', drug_rxn_corr))

    plt.figure(figsize=(10, 6))
    for position, (x, y, ylabel) in enumerate(scatter_panels, start=1):
        rho, pval = sp.stats.spearmanr(x, y)
        plt.subplot(2, 2, position)
        plt.scatter(x, y, alpha=0.4, label=f"rho={rho:.2f}, P={pval:.2e}")
        plt.xscale('log')
        plt.yscale('log')
        # An empty selection yields NaN limits, which plt.ylim rejects.
        if np.isfinite(ymin) and np.isfinite(ymax):
            plt.ylim(ymin, ymax)
        plt.xlabel('PRR(Drug, Indication)')
        plt.ylabel(ylabel)
        plt.legend()
        sns.despine()

    plt.subplot(2, 2, 3)
    for name, pairs_df in scored:
        fpr, tpr, _ = sklearn.metrics.roc_curve(pairs_df['truth'], pairs_df[statistic])
        plt.plot(fpr, tpr, label=name)
    plt.ylabel('TPR')
    plt.xlabel('FPR')
    plt.legend()
    sns.despine()

    plt.subplot(2, 2, 4)
    for name, pairs_df in scored:
        precision, recall, _ = sklearn.metrics.precision_recall_curve(pairs_df['truth'], pairs_df[statistic])
        plt.plot(recall, precision, label=name)
    plt.ylabel('Precision')
    plt.xlabel('Recall')
    plt.legend()
    sns.despine()

    # The split name is shown only as a figure-level title: a plt.title() issued
    # before the first plt.subplot() can leave an extra full-figure axes behind the grid.
    plt.suptitle(split)
    plt.tight_layout()


def evaluate_dataset(output_path, dataset_path, split, dataset_idx=0, exampleid=0, show_plots=False, ind_prr_threshold=10):
    """Score model-reconstructed reports against one simulated example's ground truth.

    Prints elementwise reconstruction accuracy, then AUROC/AUPR of the
    drug-reaction PRR computed from the observed reactions and from the predicted
    ("corrected") reactions. Optionally plots confounding-by-indication diagnostics.

    Args:
        output_path: directory holding ``final_predictions_{split}_{dataset_idx}.npy``.
        dataset_path: dataset directory with ``config.json``, ``drugs.npz``,
            ``reactions.npz``, ``drug_rxn_factors.json``, ``drug_rxn_truth.json``
            and per-example files under ``datasets/``.
        split: split name, used in the predictions filename and the figure title.
        dataset_idx: dataset index used in the predictions filename.
        exampleid: example to evaluate. It indexes the predictions array and selects
            ``datasets/{exampleid}_*.npz`` and ``config['dataset_{exampleid}']``.
        show_plots: if True, draw a 2x2 diagnostic figure.
        ind_prr_threshold: indication-reaction PRR above which a triple is treated
            as strongly indication-linked for the confounding analysis.

    Raises:
        FileNotFoundError: if ``config.json`` is missing from ``dataset_path``.
    """
    final_predictions = np.load(os.path.join(output_path, f'final_predictions_{split}_{dataset_idx}.npy'))

    config_path = os.path.join(dataset_path, 'config.json')
    if not os.path.exists(config_path):
        # FileNotFoundError subclasses Exception, so broad existing handlers still catch it.
        raise FileNotFoundError(f"No config file found at: {config_path}")
    config = _load_json(config_path)

    ndrugs = config['ndrugs']
    nreactions = config['nreactions']
    nindications = config['nindications']
    nreports = config['nreports']

    # One row per report; columns are read as [reactions | drugs | indications].
    pred = final_predictions[exampleid].reshape(nreports, nreactions + ndrugs + nindications)
    reactions = pred[:, :nreactions]

    example_dir = os.path.join(dataset_path, 'datasets')
    target_drugs = sp.sparse.load_npz(os.path.join(dataset_path, 'drugs.npz')).toarray()
    target_reactions = sp.sparse.load_npz(os.path.join(dataset_path, 'reactions.npz')).toarray()
    target_indications = sp.sparse.load_npz(os.path.join(example_dir, f'{exampleid}_indications.npz')).toarray()
    target_reactions_observed = sp.sparse.load_npz(os.path.join(example_dir, f'{exampleid}_reactions_observed.npz')).toarray()
    # NOTE(review): the indication columns are scored against all zeros rather than
    # target_indications -- confirm this is intended.
    actual = np.hstack([target_reactions, target_drugs, np.zeros((nreports, nindications))])

    reaction_cols = slice(0, nreactions)
    drug_cols = slice(nreactions, nreactions + ndrugs)
    indication_cols = slice(nreactions + ndrugs, None)

    print("Accuracy")
    _print_accuracy("R, D, I", pred, actual)
    print()
    _print_accuracy("R      ", pred[:, reaction_cols], actual[:, reaction_cols])
    print()
    _print_accuracy("   D   ", pred[:, drug_cols], actual[:, drug_cols])
    print()
    # The indication target is all zeros, so there are no positives to score.
    _print_accuracy("      I", pred[:, indication_cols], actual[:, indication_cols], include_ones=False)
    print()

    drug_names = config['drug_names']
    reaction_names = config['reaction_names']
    indication_names = config[f'dataset_{exampleid}']['indication_names']
    statistic = 'PRR'

    drug_rxn_obs, _, _, observed = build_merged_dataframe(target_drugs, target_reactions_observed, target_indications, nreports, drug_names, reaction_names, indication_names, minA=1)

    drug_rxn_factors = _load_json(os.path.join(dataset_path, 'drug_rxn_factors.json'))
    drug_rxn_truth = _load_json(os.path.join(dataset_path, 'drug_rxn_truth.json'))
    _attach_ground_truth(drug_rxn_obs, drug_rxn_factors, drug_rxn_truth)

    print(f"Observed AUROC: {sklearn.metrics.roc_auc_score(drug_rxn_obs['truth'], drug_rxn_obs[statistic])}")
    print(f"Observed AUPR : {sklearn.metrics.average_precision_score(drug_rxn_obs['truth'], drug_rxn_obs[statistic])}")

    strongly_linked = observed['PRR_ind_rxn'] > ind_prr_threshold
    print(f"Ind PRRs greater than {ind_prr_threshold}: {strongly_linked.sum()}")

    # Confounding by indication only matters where the indication itself is
    # strongly associated with the reaction.
    ind_rxn_high = observed[strongly_linked]

    # Same analysis with the model's reconstructed reactions in place of the observed ones.
    drug_rxn_corr, _, _, corrected = build_merged_dataframe(target_drugs, reactions, target_indications, nreports, drug_names, reaction_names, indication_names, minA=0)
    corrected = corrected.merge(ind_rxn_high, on=['drug', 'reaction', 'indication'], suffixes=('_corrected', '_observed'))
    _attach_ground_truth(drug_rxn_corr, drug_rxn_factors, drug_rxn_truth)

    print(f"Corrected AUROC: {sklearn.metrics.roc_auc_score(drug_rxn_corr['truth'], drug_rxn_corr[statistic])}")
    print(f"Corrected AUPR : {sklearn.metrics.average_precision_score(drug_rxn_corr['truth'], drug_rxn_corr[statistic])}")

    if show_plots:
        _plot_diagnostics(split, ind_rxn_high, corrected, drug_rxn_obs, drug_rxn_corr, statistic)

    
    

In [96]:
# Alternative run directory:
# run_path = os.path.join('..', 'outputs/small_pt0.1_mif2_small_pt0.1_mif5_medium_pt0.1_mif5')
run_path = os.path.join('..', 'outputs', 'all_small_as_train')
# Load the training run's configuration; the `with` block closes the file handle.
with open(os.path.join(run_path, 'config.json')) as config_file:
    run_config = json.load(config_file)
run_config

{'batch_size': 512,
 'hidden_dim': 8,
 'z_dim': 2,
 'epochs': 100,
 'lr': 0.001,
 'gpu': 3,
 'wd': 0.1,
 'dataset': ['./data/small_pt0.02_mif2',
  './data/small_pt0.03_mif2',
  './data/small_pt0.04_mif2',
  './data/small_pt0.05_mif2',
  './data/small_pt0.06_mif2',
  './data/small_pt0.07_mif2',
  './data/small_pt0.08_mif2',
  './data/small_pt0.09_mif2',
  './data/small_pt0.09_mif2_0iF',
  './data/small_pt0.09_mif2_0Vs',
  './data/small_pt0.09_mif2_0Vt',
  './data/small_pt0.09_mif2_1dh',
  './data/small_pt0.09_mif2_1eI',
  './data/small_pt0.09_mif2_1hk',
  './data/small_pt0.09_mif2_3qm',
  './data/small_pt0.09_mif2_3Z6',
  './data/small_pt0.09_mif2_4ZF',
  './data/small_pt0.09_mif2_5Uj',
  './data/small_pt0.09_mif2_5vR',
  './data/small_pt0.09_mif2_68x',
  './data/small_pt0.09_mif2_6eb',
  './data/small_pt0.09_mif2_6WE',
  './data/small_pt0.09_mif2_743',
  './data/small_pt0.09_mif2_7GS',
  './data/small_pt0.09_mif2_7vu',
  './data/small_pt0.09_mif2_a85',
  './data/small_pt0.09_mif2_AbO',

In [97]:
dataset_idx = 1
# Paths stored in the run config are resolved relative to the parent directory ('..').
train_dataset_paths = [os.path.join('..', path) for path in run_config['dataset']]
val_dataset_path = os.path.join('..', run_config['val_dataset'])
output_path = os.path.join('..', run_config['save_dir'])

evaluate_dataset(
    output_path,
    dataset_path=train_dataset_paths[dataset_idx],
    split='train',
    dataset_idx=dataset_idx,
    exampleid=0,
    show_plots=False,
)

(10000, 50) (10000, 25) (10000, 30)
(10000, 50) (10000, 25) (10000, 30)
Accuracy
R, D, I: 0.7648838095238095
   Ones: 0.6236

R      : 0.774948
   Ones: 0.4522

   D   : 0.463788
   Ones: 0.795

      I: 0.9990233333333334

Observed AUROC: 0.7846217788861181
Observed AUPR : 0.2894872886147173
Ind PRRs greatrer than 10: 11
Corrected AUROC: 0.5557203090641328
Corrected AUPR : 0.05574330307299588


In [98]:
# Evaluate example 0 of the validation dataset (dataset_idx left at its default).
evaluate_dataset(
    output_path=output_path,
    dataset_path=val_dataset_path,
    split='val',
    exampleid=0,
    show_plots=False,
)

(10000, 50) (10000, 25) (10000, 30)
(10000, 50) (10000, 25) (10000, 30)
Accuracy
R, D, I: 0.7659161904761905
   Ones: 0.5268

R      : 0.77839
   Ones: 0.2908

   D   : 0.460068
   Ones: 0.7628

      I: 1.0

Observed AUROC: 0.7994796445329908
Observed AUPR : 0.3626603495656726
Ind PRRs greatrer than 10: 52
Corrected AUROC: 0.4830048608042421
Corrected AUPR : 0.11511463418324905
