# Disease type classification for disease genes.

## Contents
* [Load gene and disease network.](#Load-gene-and-disease-network)
* [Load the training data.](#Load-the-training-data.)
* [Define-the-training process.](#Define-the-training-process.)
* [Train the classificatrion model.](#Train-the-classificatrion-model.)
* [Eval the results.](#Eval-the-results.)

In [1]:
import sys
import logging
import os
import os.path as osp
import torch
import random
import time
import torch.nn.functional as F
import pandas as pd
from torch.nn import Linear
import numpy as np
from IPython.display import display, HTML
from torch_geometric.nn import GCNConv, SAGEConv, GraphConv
from sklearn.metrics import label_ranking_average_precision_score, label_ranking_loss, roc_auc_score
from sklearn.metrics import roc_curve, precision_recall_curve, auc, average_precision_score
from torch_geometric.data import InMemoryDataset, Data
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from tqdm import tqdm
import pickle
import gzip
from livelossplot import PlotLosses
import sklearn

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

In [2]:
##### Expreiment hyperparameters

#  NEGATIVE_SAMPLES determines how the negative examples for training are created.
#  Choose from {'random', 'random_only_disease_gene'}
#    * random: Choose a random (gene, disease) pair which is not in the positive set.
#    * random_only_disease_genes: Like random but gene must be assigned to at least one disease.
NEGATIVE_SAMPLES = 'random'
EXPERIMENT_SLUG = 'final'
device = torch.device('cuda')

## Load gene and disease network

In [3]:
# Define Constants.
HERE = osp.abspath('')
ROOT = osp.join(HERE, '..', '..')
DATA_SOURCE_PATH = osp.join(ROOT, 'data_sources')
GENE_DATASET_ROOT = osp.join(DATA_SOURCE_PATH, 'gene_net_dataset_fn_and_hpo_features')
DISEASE_DATASET_ROOT = osp.join(DATA_SOURCE_PATH, 'disease_net_no_hpo_sim_based')
RESULTS_STORAGE = osp.join(HERE, 'results', EXPERIMENT_SLUG)
MODEL_TMP_STORAGE = osp.join('/', 'var', 'tmp', 'dg_tmp')
sys.path.insert(0, osp.abspath(ROOT))

# Generate the disease and gene network.
from GeneNet import GeneNet
from DiseaseNet import DiseaseNet
from TheModel import TheModel

gene_dataset = GeneNet(
    root=GENE_DATASET_ROOT,
    humannet_version='FN',
    features_to_use='hpo',
    skip_truncated_svd=True
)

disease_dataset = DiseaseNet(
    root=DISEASE_DATASET_ROOT,
    hpo_count_freq_cutoff=40,
    edge_source='feature_similarity',
    feature_source='disease_publications',
    skip_truncated_svd=True,
    svd_components=2048,
    svd_n_iter=12
)
gene_net_data = gene_dataset[0]
disease_net_data = disease_dataset[0]
print(gene_net_data)
print(disease_net_data)
gene_net_data = gene_net_data.to(device)
disease_net_data = disease_net_data.to(device)

Data(edge_attr=[371502, 1], edge_index=[2, 371502], x=[17247, 6209])
Data(edge_attr=[282464], edge_index=[2, 282464], x=[8827, 11460])


## Load the training data.

In [5]:
# Generate training data.
disease_genes = pd.read_table(
    osp.join(DATA_SOURCE_PATH, 'genes_diseases.tsv'),
    names=['EntrezGene ID', 'OMIM ID'],
    sep='\t', 
    low_memory=False, 
    dtype={'EntrezGene ID': pd.Int64Dtype()}
)

disease_id_index_feature_mapping = disease_dataset.load_disease_index_feature_mapping()
gene_id_index_feature_mapping = gene_dataset.load_node_index_mapping()

all_genes = list(gene_id_index_feature_mapping.keys())
all_diseases = list(disease_id_index_feature_mapping.keys())

# 1. generate positive pairs.
# Filter the pairs to only include the ones where the corresponding nodes are available.
# i.e. gene_id should be in all_genes and disease_id should be in all_diseases.
positives = disease_genes[
    disease_genes["OMIM ID"].isin(all_diseases) & disease_genes["EntrezGene ID"].isin(all_genes)
]
covered_diseases = list(set(positives['OMIM ID']))
covered_genes = list(set(positives['EntrezGene ID']))

# 2. Generate negatives.
# Pick equal amount of pairs not in the positives.
negatives_list = []
while len(negatives_list) < len(positives):
    if NEGATIVE_SAMPLES == 'random_only_disease_genes' :
        gene_id = covered_genes[np.random.randint(0, len(covered_genes))]
    else:
        gene_id = all_genes[np.random.randint(0, len(all_genes))]
    disease_id = covered_diseases[np.random.randint(0, len(covered_diseases))]
    if not ((positives['OMIM ID'] == disease_id) & (positives['EntrezGene ID'] == gene_id)).any():
        negatives_list.append([disease_id, gene_id])
negatives = pd.DataFrame(np.array(negatives_list), columns=['OMIM ID', 'EntrezGene ID'])

In [6]:
# Disease classification data preparation.
# Load the disease classes.
GENE_CLASS_LABELS_FILE = osp.join(DATA_SOURCE_PATH, 'extracted_disease_class_assignments.tsv')
# Load the training data
disease_class_training_data = pd.read_csv(GENE_CLASS_LABELS_FILE, sep='\t')
# drop duplicates
unique_labeled_disease_class_genes = disease_class_training_data.drop_duplicates()
gene_id_node_index_df = pd.DataFrame(data=[(gene_id, node_index) for gene_id, node_index in gene_id_index_feature_mapping.items()], columns=['gene_id', 'node_index'])
disease_id_node_index_df = disease_id_node_index_df = pd.DataFrame(
    data=[(disease_id, node_index) for disease_id, node_index in disease_id_index_feature_mapping.items()], columns=['disease_id', 'disease_node_index']
)

# Create the gene index
# Join in the gene node indexes
disease_class_training_data = pd.merge(
    unique_labeled_disease_class_genes, 
    gene_id_node_index_df, 
    left_on='gene_id', 
    right_on='gene_id',
    validate='many_to_many'
)

disease_class_counts = disease_class_training_data['disease_class'].value_counts()
disease_class_target_classes = [
	'Ophthamological',
	'Connective tissue',
	'Endocrine',
	'Skeletal',
	'Metabolic',
	'Cardiovascular',
	'Dermatological',
	'Renal',
	'Hematological',
	'Immunological',
	'Muscular',
	'Developmental'
]


def get_negative_disease_class_data(pos_class, n):
    # n = n // 2
    return disease_class_training_data[disease_class_training_data['disease_class'] != pos_class].sample(n=n, random_state=42)

def get_positive_disease_class_data(pos_class):
    return disease_class_training_data[disease_class_training_data['disease_class'] == pos_class].copy()

def get_disease_class_training_data(pos_class):
    pos = get_positive_disease_class_data(pos_class)
    pos['label'] = 1
    neg = get_negative_disease_class_data(pos_class, len(pos))
    neg['label'] = 0
    data = pd.concat([pos, neg], ignore_index=True)
    x = data.iloc[:,3:].values
    y = data.iloc[:,4:5].values.ravel()
    
    return x, torch.tensor(y), data
disease_class_training_data

Unnamed: 0,disease_id,gene_id,disease_class,node_index
0,27,1836,Bone,1582
1,160,1836,Connective tissue,1582
2,430,1836,Skeletal,1582
3,496,1836,Bone,1582
4,1,1586,Endocrine,5881
...,...,...,...,...
2673,1615,5195,multiple,9456
2674,1615,5824,multiple,218
2675,1615,5825,multiple,167
2676,1615,9409,multiple,6077


## Define the training process.

In [7]:
model = TheModel(
    gene_feature_dim=gene_net_data.x.shape[1],
    disease_feature_dim=disease_net_data.x.shape[1],
    fc_hidden_dim=3000,
    gene_net_hidden_dim=830,
    disease_net_hidden_dim=500
).to(device)

def train_disease_classification(model_parameter_file):
    # Load the pretrained model.
    model.load_state_dict(torch.load(model_parameter_file))

    # Set classification training hyperparameters.
    lr_classification=0.00000347821
    weight_decay_classification=0.5165618
    folds=5
    max_epochs=100
    info_each_epoch = 1
    early_stopping_window=15
    final_disease_class_metrics = dict()
    losses = {
                    'train': [],
                    'val': [],
                    'AUC': 0,
                    'TPR': None,
                    'FPR': None
            }
    for disease_class in disease_class_target_classes:
        for fold in range(folds):
            losses[f'train_disease_class_{disease_class}_{fold}'] = []
            losses[f'val_disease_class_{disease_class}_{fold}'] = []
            final_disease_class_metrics[f'{disease_class}_{fold}'] = {
                'roc_auc': 0,
                'pr_auc': 0,
                'fmax': 0
            }


    torch.save(model.state_dict(), '/var/tmp/dg_tmp/tmp_model_state.ptm')
    class_count = 0
    for disease_class in disease_class_target_classes:
        class_count += 1
        print(f'Evaluate pretrained model on disease class {disease_class} ({class_count}/{len(disease_class_target_classes)})')
        x_disease_class, y_disease_class, _ = get_disease_class_training_data(disease_class)
        optimizer_disease_class = torch.optim.Adam(model.parameters(), lr=lr_classification, weight_decay=weight_decay_classification)
        criterion_disease_class = torch.nn.CrossEntropyLoss()

        kf = KFold(n_splits=folds, shuffle=True, random_state=42)
        fold = -1
        for train_fold_index, test_fold_index in kf.split(x_disease_class):
            fold += 1
            print(f'Starting Fold: {fold}')
            model.load_state_dict(torch.load('/var/tmp/dg_tmp/tmp_model_state.ptm'))
            model.mode = 'Classify'
            # Split into train and validation.
            x_test = x_disease_class[test_fold_index]
            y_test = y_disease_class[test_fold_index].to(device)
            id_tr, id_val = train_test_split(range(x_disease_class[train_fold_index].shape[0]), test_size=0.1, random_state=42)
            x_train = x_disease_class[train_fold_index][id_tr]
            y_train = y_disease_class[train_fold_index][id_tr].to(device)
            x_val = x_disease_class[train_fold_index][id_val]
            y_val = y_disease_class[train_fold_index][id_val].to(device)

            # print(x_test.shape, x_train.shape, x_val.shape)
            # print(y_test.shape, y_train.shape, y_val.shape)

            for epoch in range(max_epochs):
                model.train()

                batch_size = 16
                permutation = torch.randperm(x_train.shape[0])
                # train
                loss_items = []
                for i in range(0, x_train.shape[0], batch_size):
                        # print(f'doing batch {i//batch_size}/{x_train.shape[0]//batch_size}')
                        batch_indices = permutation[i:i+batch_size]
                        batch_x, batch_y = x_train[batch_indices].reshape(-1, 2), y_train[batch_indices]

                        optimizer_disease_class.zero_grad()
                        out = model(gene_net_data, disease_net_data, batch_x)
                        loss = criterion_disease_class(out, batch_y)
                        loss.backward()
                        optimizer_disease_class.step()
                        loss_items.append(loss.item())
                losses[f'train_disease_class_{disease_class}_{fold}'].append(np.mean(loss_items))

                # validation
                with torch.no_grad():
                    model.eval()
                    out = model(gene_net_data, disease_net_data, x_val)
                    loss = criterion_disease_class(out, y_val)
                    losses[f'val_disease_class_{disease_class}_{fold}'].append(loss.item())

                    if epoch % info_each_epoch == 0:
                        print(
                            'Epoch {}, train_loss: {:.4f}, val_loss: {:.4f}'.format(
                                epoch, losses[f'train_disease_class_{disease_class}_{fold}'][epoch], losses[f'val_disease_class_{disease_class}_{fold}'][epoch]
                            )
                        )

                # Early stopping
                if epoch > early_stopping_window:
                    # Stop if validation error did not decrease 
                    # w.r.t. the past early_stopping_window consecutive epochs.
                    last_window_losses = losses[f'val_disease_class_{disease_class}_{fold}'][epoch - early_stopping_window:epoch]
                    if losses[f'val_disease_class_{disease_class}_{fold}'][-1] > max(last_window_losses):
                        print('Early Stopping!')
                        break

            # Test the disease classification model for current fold.
            print(f'Test the model on fold {fold}:')
            with torch.no_grad():
                y_score = model(gene_net_data, disease_net_data, x_test)[:,1].cpu().detach().numpy()
                y = y_test.cpu().detach().numpy()
                final_disease_class_metrics[f'{disease_class}_{fold}']['roc_auc'] = sklearn.metrics.roc_auc_score(y, y_score)
                precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y, y_score)
                final_disease_class_metrics[f'{disease_class}_{fold}']['pr_auc'] = sklearn.metrics.auc(recall, precision)
                final_disease_class_metrics[f'{disease_class}_{fold}']['fmax'] = ((2*precision * recall) / (precision + recall + 0.00001)).max()
                print(final_disease_class_metrics[f'{disease_class}_{fold}'])
    return final_disease_class_metrics

## Train the classificatrion model.
* Run the training for each pretrained model.
* Store results in `disease_classification_results.gz`

In [8]:
import gzip
import pickle

pretrained_folds = 5
all_final_disease_class_metrics = []
results_file = osp.join(RESULTS_STORAGE, f'disease_classification_results.gz')
for i in range(1, pretrained_folds + 1):
    print('############################################################')
    print(f'# START CLASSIFICATION USING PRETRAINED MODEL FROM FOLD: {i} #')
    print('############################################################')
    results = train_disease_classification(osp.join(RESULTS_STORAGE, f'model_fold_{i}.ptm'))
    all_final_disease_class_metrics.append(results)
    with gzip.open(results_file, mode='wb') as f:
        pickle.dump(all_final_disease_class_metrics, f)

############################################################
# START CLASSIFICATION USING PRETRAINED MODEL FROM FOLD: 1 #
############################################################
Evaluate pretrained model on disease class Ophthamological (1/12)
Starting Fold: 0
Epoch 0, train_loss: 0.6949, val_loss: 0.6956
Epoch 1, train_loss: 0.6803, val_loss: 0.6937
Epoch 2, train_loss: 0.6741, val_loss: 0.6877
Epoch 3, train_loss: 0.6623, val_loss: 0.6840
Epoch 4, train_loss: 0.6527, val_loss: 0.6817
Epoch 5, train_loss: 0.6518, val_loss: 0.6765
Epoch 6, train_loss: 0.6384, val_loss: 0.6734
Epoch 7, train_loss: 0.6347, val_loss: 0.6693
Epoch 8, train_loss: 0.6261, val_loss: 0.6625
Epoch 9, train_loss: 0.6247, val_loss: 0.6598
Epoch 10, train_loss: 0.6184, val_loss: 0.6582
Epoch 11, train_loss: 0.6069, val_loss: 0.6544
Epoch 12, train_loss: 0.6068, val_loss: 0.6512
Epoch 13, train_loss: 0.5905, val_loss: 0.6459
Epoch 14, train_loss: 0.5921, val_loss: 0.6417
Epoch 15, train_loss: 0.5803, val_loss:

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 16, train_loss: 0.5826, val_loss: 0.6583
Epoch 17, train_loss: 0.5788, val_loss: 0.6569
Epoch 18, train_loss: 0.5744, val_loss: 0.6562
Epoch 19, train_loss: 0.5628, val_loss: 0.6539
Epoch 20, train_loss: 0.5597, val_loss: 0.6529
Epoch 21, train_loss: 0.5487, val_loss: 0.6514
Epoch 22, train_loss: 0.5515, val_loss: 0.6503
Epoch 23, train_loss: 0.5400, val_loss: 0.6486
Epoch 24, train_loss: 0.5403, val_loss: 0.6479
Epoch 25, train_loss: 0.5320, val_loss: 0.6475
Epoch 26, train_loss: 0.5189, val_loss: 0.6467
Epoch 27, train_loss: 0.5160, val_loss: 0.6459
Epoch 28, train_loss: 0.5123, val_loss: 0.6446
Epoch 29, train_loss: 0.5163, val_loss: 0.6435
Epoch 30, train_loss: 0.5004, val_loss: 0.6434
Epoch 31, train_loss: 0.5017, val_loss: 0.6434
Epoch 32, train_loss: 0.4945, val_loss: 0.6435
Epoch 33, train_loss: 0.4887, val_loss: 0.6418
Epoch 34, train_loss: 0.4812, val_loss: 0.6419
Epoch 35, train_loss: 0.4717, val_loss: 0.6417
Epoch 36, train_loss: 0.4680, val_loss: 0.6423
Epoch 37, tra

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



## Eval the results.
* Load the results.
* Compute the mean of all models and folds.
* Plot comparison between DiGI, SmuDGE and our model.

In [9]:
with gzip.open(results_file, mode='rb') as file:
    all_final_disease_class_metrics = pickle.load(file, encoding='bytes')

In [14]:
ours_roc_auc = []
ours_pr_auc = []
ours_fmax = []


for disease_class in disease_class_target_classes:
    print(f'{disease_class} (support: {disease_class_counts[disease_class]})')
    roc_aucs, pr_aucs, fmaxs = [], [], []
    for pretrained_results_by_model in all_final_disease_class_metrics:
        for fold in range(5):
            roc_aucs.append(pretrained_results_by_model[f'{disease_class}_{fold}']['roc_auc'])
            pr_aucs.append(pretrained_results_by_model[f'{disease_class}_{fold}']['pr_auc'])
            fmaxs.append(pretrained_results_by_model[f'{disease_class}_{fold}']['fmax'])
    
    ours_roc_auc.append(np.mean(roc_aucs))
    ours_pr_auc.append(np.mean(pr_aucs))
    ours_fmax.append(np.mean(fmaxs))
    
    print(f'roc_auc:\t{np.mean(roc_aucs):0.2f}')
    print(f'pr_auc: \t{np.mean(pr_aucs):0.2f}')
    print(f'fmax:   \t{np.mean(fmaxs):0.2f}')
    print()

Ophthamological (support: 170)
roc_auc:	0.86
pr_auc: 	0.87
fmax:   	0.82

Connective tissue (support: 56)
roc_auc:	0.74
pr_auc: 	0.78
fmax:   	0.73

Endocrine (support: 124)
roc_auc:	0.82
pr_auc: 	0.83
fmax:   	0.79

Skeletal (support: 88)
roc_auc:	0.87
pr_auc: 	0.89
fmax:   	0.86

Metabolic (support: 326)
roc_auc:	0.89
pr_auc: 	0.89
fmax:   	0.85

Cardiovascular (support: 117)
roc_auc:	0.74
pr_auc: 	0.78
fmax:   	0.74

Dermatological (support: 102)
roc_auc:	0.88
pr_auc: 	0.91
fmax:   	0.86

Renal (support: 66)
roc_auc:	0.74
pr_auc: 	0.77
fmax:   	0.76

Hematological (support: 189)
roc_auc:	0.85
pr_auc: 	0.87
fmax:   	0.80

Immunological (support: 128)
roc_auc:	0.89
pr_auc: 	0.91
fmax:   	0.85

Muscular (support: 85)
roc_auc:	0.84
pr_auc: 	0.86
fmax:   	0.82

Developmental (support: 57)
roc_auc:	0.70
pr_auc: 	0.69
fmax:   	0.74

