# 1. Load Data

In [1]:
import numpy as np
from clinnet.data_loader_pnet import PNETData

In [2]:
# Configure the PNETData loader. Section 4 reuses these same settings as `data_params`
# for cross-validation, so keep the two in sync.
data = PNETData(
    data_type=['mut_important', 'cnv_del', 'cnv_amp'],
    cnv_levels=3,
    mut_binary=True,
    balanced_data=False,
    combine_type='union',
    use_coding_genes_only=True,
)

In [3]:
# Pull one split from the loader: feature matrices and labels for train / validation / test,
# plus the gene list, the per-gene data types and a class-weight mapping.
# NOTE(review): the argument 1 is presumably a fold index — confirm in PNETData.get_kf.
(
    x_train, y_train,
    x_valid, y_valid,
    x_test, y_test,
    genes, gene_status, class_weight,
) = data.get_kf(1)

In [4]:
# Sanity-check the split: matrix shapes, label counts per class, and the feature metadata.
# Labels are padded to 8 characters so the columns line up (e.g. "x_test:  ").
for split_name, split_features in (("x_train", x_train), ("x_valid", x_valid), ("x_test", x_test)):
    print(f"{split_name + ':':<8} {split_features.shape}")

for split_name, split_labels in (("y_train", y_train), ("y_valid", y_valid), ("y_test", y_test)):
    print(f"{split_name + ':':<8} {split_labels.shape}, {np.unique(split_labels, return_counts=True)}")

print(f"genes: {genes.shape}")
print(f"gene_status: {gene_status}")
print(f"class_weight: {class_weight}")

x_train: (606, 47451)
x_valid: (202, 47451)
x_test:  (203, 47451)
y_train: (606, 1), (array([0, 1], dtype=int64), array([406, 200], dtype=int64))
y_valid: (202, 1), (array([0, 1], dtype=int64), array([136,  66], dtype=int64))
y_test:  (203, 1), (array([0, 1], dtype=int64), array([136,  67], dtype=int64))
genes: (15817,)
gene_status: ['mut_important' 'cnv_del' 'cnv_amp']
class_weight: {0: 0.745575221238938, 1: 1.518018018018018}


# 2. Model

In [5]:
from clinnet.model import CLinNET, load_clinnet

In [6]:
# Build the CLinNET model for the prostate epithelial tissue. Per the training log, artifacts
# are written under result/<saving_dir>/<tissue>/.
# NOTE(review): n_hids declares 5 hidden layers per ontology, while the per-layer
# regularisation lists below have 10 entries. Presumably only the first 5 are used;
# confirm in CLinNET.
clinnet_model = CLinNET(
    genes,
    gene_status,
    tissue='epithelial cell of prostate',
    saving_dir='PNETDataset_Run2',
    n_hids={'GO': 5, 'Reactome': 5},
    ex_source='fantom',
    w_regs={'Diag': 0.005, 'GO': [0.01] * 10, 'Reactome': [0.003] * 10},
    w_regs_outcome={'GS': 0.03, 'Diag': 0.001, 'GO': [0] * 10, 'Reactome': [0] * 10},
    learning_rate=0.001,
    drop_rate=[0.4, 0.3, 0.1, 0.0, 0.0, 0.0],
    verbose=True,
)

In [7]:
# Train on the training split, monitoring the validation split.
# NOTE(review): the class_weight returned by data.get_kf(1) is not passed here. Training is
# unweighted despite the 406/200 label imbalance reported for y_train; confirm this is intentional.
clinnet_model.train(
    x_train, y_train,
    x_valid, y_valid,
    class_weight=None,
    batch_size=80,
    epochs=300,
    verbose=2,
)

Epoch 1/300
8/8 - 10s - loss: 308.0401 - Input_Dense_loss: 18.8515 - GS_Dense_loss: 37.3006 - Diag_Dense_loss: 0.8559 - Hidd1_Dense_GO_loss: 0.9650 - Hidd2_Dense_GO_loss: 1.0271 - Hidd3_Dense_GO_loss: 1.0989 - Hidd4_Dense_GO_loss: 1.1976 - Hidd5_Dense_GO_loss: 0.7123 - Hidd1_Dense_Reactome_loss: 1.0464 - Hidd2_Dense_Reactome_loss: 1.0891 - Hidd3_Dense_Reactome_loss: 1.1186 - Hidd4_Dense_Reactome_loss: 1.1166 - Hidd5_Dense_Reactome_loss: 1.0767 - Input_Dense_f1: 0.5511 - Input_Dense_accuracy: 0.6568 - Input_Dense_auc: 0.6596 - Input_Dense_recall: 0.6250 - Input_Dense_precision: 0.4845 - GS_Dense_f1: 0.4821 - GS_Dense_accuracy: 0.6040 - GS_Dense_auc_1: 0.5876 - GS_Dense_recall_1: 0.5450 - GS_Dense_precision_1: 0.4225 - Diag_Dense_f1: 0.4486 - Diag_Dense_accuracy: 0.5363 - Diag_Dense_auc_2: 0.5685 - Diag_Dense_recall_2: 0.5600 - Diag_Dense_precision_2: 0.3672 - Hidd1_Dense_GO_f1: 0.4727 - Hidd1_Dense_GO_accuracy: 0.3168 - Hidd1_Dense_GO_auc_3: 0.4030 - Hidd1_Dense_GO_recall_3: 0.9550 - Hi



INFO:tensorflow:Assets written to: result\PNETDataset_Run2\epithelial cell of prostate\saved_model\weights\assets


INFO:tensorflow:Assets written to: result\PNETDataset_Run2\epithelial cell of prostate\saved_model\weights\assets


In [8]:
# Report metrics on the test split, combining the model's outputs with converge_method='average'.
# NOTE(review): the validation split is presumably used to pick the reported 'Threshold';
# confirm in CLinNET.evaluate.
clinnet_model.evaluate(
    x_valid, y_valid,
    x_test, y_test,
    converge_method='average',
)



{'Threshold': 0.6899999999999997,
 'Accuracy': 0.8817733990147784,
 'Precision': 0.8412698412698413,
 'Sensitivity_recall': 0.7910447761194029,
 'Specificity': 0.9264705882352942,
 'F1_score': 0.8153846153846154,
 'AUC': 0.92328797190518}

In [9]:
# Persist the model's predictions for every split.
clinnet_model.save_predictions(
    x_train=x_train, y_train=y_train,
    x_valid=x_valid, y_valid=y_valid,
    x_test=x_test, y_test=y_test,
)



In [10]:
# Optional: replace `clinnet_model` with the copy saved by the training cell above
# (e.g. to run the interpretability sections without retraining).
# The path follows the result/<saving_dir>/<tissue> layout shown in the training log
# (saving_dir='PNETDataset_Run2', tissue='epithelial cell of prostate'). The previously
# commented path pointed at a 'brain' directory, which this notebook never writes.
# NOTE(review): confirm load_clinnet expects this directory rather than its saved_model/ subfolder.
RELOAD_SAVED_MODEL = False
if RELOAD_SAVED_MODEL:
    clinnet_model = load_clinnet('result/PNETDataset_Run2/epithelial cell of prostate')

# 3. Interpretability

## 3.1 SHAP

In [11]:
from clinnet.shap import SHAP

In [12]:
# Create the SHAP explainer for the trained model. train_n_sample / test_n_sample presumably
# cap how many samples from each split are used; confirm in clinnet.shap.SHAP.
# NOTE(review): this variable shadows the third-party `shap` package name. Later cells use it
# (shap.graph, shap.interpret_dir, shap.sv_norm), so rename all uses together if changed.
shap = SHAP(
    clinnet_model,
    train_n_sample=600,
    test_n_sample=200,
)

In [13]:
# Compute layer-wise SHAP values (per the method name) for the underlying model,
# using the training split as background and the test split for explanations (presumed).
shap.get_layer_shap(
    clinnet_model.model,
    x_train, x_test,
    y_train, y_test,
)

In [14]:
# Write the SHAP results to CSV (per the method name).
shap.save_shap_csv()

# Plot export is off by default so the cell's behavior is unchanged. Set it to True on a
# fresh kernel if the image cell below reports missing PNGs.
# NOTE(review): that cell reads images from .../interpretability/SHAP/. Presumably
# save_shap_plot() is what writes them; confirm in clinnet.shap.SHAP.
SAVE_SHAP_PLOTS = False
if SAVE_SHAP_PLOTS:
    shap.save_shap_plot()

In [16]:
from base64 import b64encode
from pathlib import Path

from IPython.display import display, HTML

# Show three SHAP summary plots side by side.
# Paths are built with pathlib rather than Windows-only "\..." literals. Those literals contain
# invalid escape sequences (SyntaxWarning on Python 3.12+) and break on other platforms.
# NOTE(review): this directory could likely be derived from shap.interpret_dir; confirm.
SHAP_PLOT_DIR = Path("result", "PNETDataset_Run2", "epithelial cell of prostate", "interpretability", "SHAP")
SHAP_PLOT_NAMES = ["Diag.png", "GS.png", "Hidd1_Reactome.png"]


def _embedded_png_tag(path: Path) -> str:
    """Return an <img> tag with the PNG at `path` embedded as a base64 data URI.

    Embedding (instead of a relative src URL) keeps the figure visible when the saved
    notebook is viewed outside the Jupyter server (GitHub, nbviewer, exported HTML).
    """
    encoded = b64encode(path.read_bytes()).decode("ascii")
    return (
        f'<img src="data:image/png;base64,{encoded}" alt="{path.stem}" '
        'style="width: 30%; margin-right: 5px;" />'
    )


shap_plot_paths = [SHAP_PLOT_DIR / name for name in SHAP_PLOT_NAMES]
missing_plots = [path for path in shap_plot_paths if not path.is_file()]

if missing_plots:
    # Fail soft with an actionable message instead of rendering broken <img> placeholders.
    print("SHAP plot images not found; generate them first (e.g. via shap.save_shap_plot()):")
    for path in missing_plots:
        print(f"  {path}")
else:
    html_code = (
        '<div style="display: flex; justify-content: space-around;">\n'
        + "\n".join(_embedded_png_tag(path) for path in shap_plot_paths)
        + "\n</div>"
    )
    display(HTML(html_code))


## 3.2 Sankey

In [17]:
from clinnet.sankey import Sankey

In [18]:
# Build the Sankey view from the SHAP explainer's graph, output directory and normalised
# SHAP values, labelled with the per-gene data types.
sankey = Sankey(
    shap.graph,
    shap.interpret_dir,
    sv_norm=shap.sv_norm,
    gene_status=gene_status,
)

In [None]:
# Render the Sankey diagram. use_abb=True presumably abbreviates node labels; confirm in clinnet.sankey.
sankey.plot_sankey(
    use_abb=True,
)

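# 4. Cross Validation

This section is self-contained: it re-imports its dependencies and passes the `PNETData` loader class and its parameters directly to `CV`, so it can be run in a fresh kernel without sections 1–3.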

In [1]:
from clinnet.data_loader_pnet import PNETData
from clinnet.cross_validation import CV

In [2]:
# Cross-validation settings.
# - data_params: the same PNETData options as section 1, plus n_split (number of CV splits).
# - model_params["build"]: mirrors the keyword arguments passed to CLinNET(...) in section 2.
# - model_params["train"]: mirrors the keyword arguments passed to .train(...) in section 2.
# NOTE(review): these hyperparameters differ from the single run in section 2
# (learning_rate 1e-4 vs 1e-3, lighter dropout and regularisation); confirm this is intended.
data_params = dict(
    data_type=['mut_important', 'cnv_del', 'cnv_amp'],
    cnv_levels=3,
    mut_binary=True,
    balanced_data=False,
    combine_type='union',
    use_coding_genes_only=True,
    n_split=5,
)

model_params = dict(
    build=dict(
        n_hids=dict(GO=5, Reactome=5),
        ex_source='fantom',
        w_regs=dict(Diag=0.001, GO=[0.001] * 10, Reactome=[0.003] * 10),
        w_regs_outcome=dict(GS=0.001, Diag=0.0001, GO=[0] * 10, Reactome=[0] * 10),
        learning_rate=0.0001,
        drop_rate=[0.2, 0.1, 0.0, 0.0, 0.0, 0.0],
        verbose=True,
    ),
    train=dict(
        batch_size=80,
        epochs=300,
        verbose=2,
    ),
)

In [3]:
# Set up cross-validation: CV receives the loader class and its parameters rather than a
# pre-built dataset, and writes its results under the 'PNET_CV' saving directory.
cv = CV(
    data_class=PNETData,
    data_params=data_params,
    model_params=model_params,
    tissue='epithelial cell of prostate',
    saving_dir='PNET_CV',
)

In [None]:
# Run every CV split. Each split presumably trains a model with model_params["train"]
# (300 epochs), so expect this cell to take a long time.
cv.run_cross_validation()

In [None]:
# Combine the per-split results. Kept as the last expression so the cell displays the return value.
cv.aggregate_result()