<a href="https://colab.research.google.com/github/rcurrie/pancan-gtex/blob/master/infer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Infer
Infer a sample's primary site and whether it is tumor or normal tissue from its kallisto expression quantification, using a trained neural network. Explain the classification with the SHAP package.

In [None]:
import sys
import os
import json
import requests
import numpy as np
import pandas as pd
!pip install -q tables
import tensorflow as tf
from IPython.core.display import display, HTML

# Work from a scratch data directory so every file downloaded below is referenced
# by a short relative path. The location is a single config constant instead of a
# string repeated twice.
DATA_DIR = os.path.expanduser("~/data/pancan-gtex")
os.makedirs(DATA_DIR, exist_ok=True)
os.chdir(DATA_DIR)

In [None]:
# Set up an anonymous connection to the PRP S3/CEPH bucket that hosts the training
# set and the trained model, then list its contents for provenance.
import boto3
from botocore.handlers import disable_signing

bucket_name = "stuartlab"
endpoint = "s3.nautilus.optiputer.net"

# Set so that Tensorflow can pull from the PRP S3/CEPH storage cluster
os.environ["S3_ENDPOINT"] = endpoint

session = boto3.session.Session()
# Build the resource from the session created above. Requests are left unsigned
# (disable_signing) so the bucket is read without credentials.
resource = session.resource("s3", endpoint_url="https://{}".format(endpoint))
resource.meta.client.meta.events.register("choose-signer.s3.*", disable_signing)
bucket = resource.Bucket(bucket_name)


def print_bucket_listing(prefix):
    """Print last-modified time, ETag (quotes stripped) and key for every object
    in ``bucket`` whose key starts with ``prefix``."""
    for obj in bucket.objects.filter(Prefix=prefix):
        print(obj.last_modified.isoformat(), obj.e_tag.strip('"'), obj.key)


# Output last-modified dates and ETags for the dataset and the trained model
print("Dataset:")
print_bucket_listing("pancan-gtex")

print("Trained Model:")
print_bucket_listing("rcurrie/pancan-gtex/models")

In [None]:
# Load the background samples and the Ensembl-to-HUGO mapping table used for the
# explanation. Each file is downloaded once into the working directory and reused
# on later runs.
DATASET_URL = "https://s3.nautilus.optiputer.net/stuartlab/pancan-gtex"


def _download_if_missing(url, path):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    The body is streamed to ``path + ".part"`` and only renamed into place once
    complete. An interrupted download therefore never leaves a truncated file that
    the existence check would accept on the next run.

    Raises:
        requests.HTTPError: on a non-2xx response, instead of saving an error page
            under a data file name.
    """
    if os.path.exists(path):
        return
    print("Downloading {}...".format(path))
    partial_path = path + ".part"
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(partial_path, path)


# Save under the same name that is checked and read below. The previous code wrote
# to "pancan-gtex.h5", so read_hdf failed on a fresh run and the file was
# re-downloaded every time.
_download_if_missing(DATASET_URL + "/pancan-gtex-transcript.h5", "pancan-gtex-transcript.h5")

X = pd.read_hdf("pancan-gtex-transcript.h5", "samples")
Y = pd.read_hdf("pancan-gtex-transcript.h5", "labels")
print("Loaded {} samples with {} features and {} labels".format(X.shape[0], X.shape[1], Y.shape[1]))

_download_if_missing(DATASET_URL + "/ensemble-to-hugo.tsv", "ensemble-to-hugo.tsv")
ensemble_to_hugo = pd.read_table("ensemble-to-hugo.tsv", index_col=0)

In [None]:
# Load the training params (label vocabularies) and the trained Keras model from S3
MODELS_URL = "https://s3.nautilus.optiputer.net/stuartlab/rcurrie/pancan-gtex/models"

r = requests.get("{}/params-transcript.json".format(MODELS_URL))
r.raise_for_status()
params = r.json()

r = requests.get("{}/model-transcript.h5".format(MODELS_URL))
r.raise_for_status()

import tempfile
# load_model is given a file path here, so stage the download in a temp file.
# mkstemp is the public, race-free API (the private _get_candidate_names is not).
# The handle is closed before loading so the bytes are flushed to disk, and the
# file is removed even if loading fails.
fd, temp_path = tempfile.mkstemp(suffix=".h5")
try:
    with os.fdopen(fd, "wb") as f:
        f.write(r.content)
    model = tf.keras.models.load_model(temp_path)
finally:
    os.remove(temp_path)

model.summary()

In [None]:
# Prompt for the sample's kallisto abundance file using Colab's upload widget.
# NOTE(review): google.colab is presumably only importable inside Google Colab;
# running this notebook elsewhere needs a local filename instead.
from google.colab import files

# Only the keys of `uploaded` (presumably the uploaded file names) are used below.
uploaded = files.upload()

In [None]:
# Load the sample's kallisto abundance output as a one-row DataFrame (row label
# "tpm") with one float32 column per transcript id (the file's first column).
# To classify a local file instead of an upload, set filename directly, e.g.
#   filename = os.path.expanduser("~/data/samples/TH06_1172_S01/abundance.tsv")
assert uploaded, "No file was uploaded"
filename = list(uploaded.keys())[0]

# Select the tpm column before casting and transposing, so the unused columns are
# never cast. The result matches casting every column and then picking the row.
sample = pd.read_table(filename, index_col=0, engine="c")[["tpm"]].astype(np.float32).T

# TPMs sum to one million by construction. Compare with a tolerance: a float32 sum
# drifts slightly, and int() truncation turned e.g. 999999.9 into a spurious
# failure of the old exact check.
tpm_total = float(sample.iloc[0].sum())
assert np.isclose(tpm_total, 1e6, rtol=1e-4), \
    "TPM column sums to {:.1f}, expected ~1,000,000".format(tpm_total)
sample.head()

In [None]:
# Align the sample to the training feature set: keep only the transcripts the model
# was trained on (X.columns), ordered by transcript id, then predict.
sample_features = sample.filter(X.columns, axis="columns").sort_index(axis="columns")
# filter() silently drops training transcripts absent from the sample, which would
# otherwise surface as an opaque input-shape error from Keras.
missing_transcripts = X.columns.difference(sample_features.columns)
assert missing_transcripts.empty, "Sample is missing {} of {} model transcripts, e.g. {}".format(
    len(missing_transcripts), len(X.columns), list(missing_transcripts[:5]))
prediction = model.predict(sample_features)[0]

In [None]:
# To predict any of the training samples:
# prediction = model.predict(X.loc[["TCGA-ZS-A9CE-01"]])[0]

In [None]:
# Report the tumor/normal call and the most likely primary site for this sample.
# NOTE(review): assumes prediction[0] is a sigmoid score in [0, 1] for the label
# params["tumor_normal"][1], and prediction[1:-1] holds one score per entry of
# params["primary_site"]. Confirm against the model's output layer, in particular
# that the final unit excluded by [1:-1] is not itself a primary site (otherwise
# this is an off-by-one).
tumor_normal_score = float(prediction[0])
tumor_normal_prediction_index = int(round(tumor_normal_score))
# Report confidence in the label actually chosen. For an index-0 call that is
# 1 - score; the raw score would always be <= 0.5 for such calls.
tumor_normal_prediction_value = (
    tumor_normal_score if tumor_normal_prediction_index == 1 else 1.0 - tumor_normal_score)
tumor_normal_prediction_label = params["tumor_normal"][tumor_normal_prediction_index]
display(HTML("<h3>{} with {:.2f} confidence</h3>".format(
    tumor_normal_prediction_label, tumor_normal_prediction_value)))

primary_site_scores = prediction[1:-1]
primary_site_prediction_index = np.argmax(primary_site_scores)
primary_site_prediction_value = primary_site_scores[primary_site_prediction_index]
primary_site_prediction_label = params["primary_site"][primary_site_prediction_index]
display(HTML("<h3>{} with {:.2f} confidence</h3>".format(
    primary_site_prediction_label, primary_site_prediction_value)))