# Train and Optimize

Train a neural network on TCGA+GTEX gene expression to classify tissue/disease

Try out a few hyper parameter frameworks:

* [Ray Tune](https://ray.readthedocs.io/en/latest/tune.html) (tried a few times but didn't manage to get it happy with Keras and cost functions)
* [Polyaxon](https://github.com/polyaxon/polyaxon)
* [Talos](https://github.com/autonomio/talos)

In [1]:
import sys
import os
import json
import pandas as pd
import numpy as np

# fix random seed for reproducibility
np.random.seed(42)

## Load and Wrangle Data

In [2]:
X = pd.read_hdf(os.path.expanduser("~/data/pancan_gtex.h5"), "samples")
Y = pd.read_hdf(os.path.expanduser("~/data/pancan_gtex.h5"), "labels")
print("Loaded {} samples with {} features and {} labels".format(X.shape[0], X.shape[1], Y.shape[1]))

Loaded 17964 samples with 42326 features and 42 labels


In [3]:
# Prune X to only KEGG pathway genes
# with open(os.path.expanduser("~/data/msigdb/c2.cp.kegg.v6.2.symbols.gmt")) as f:
#     subset_of_genes = list(set().union(*[line.strip().split("\t")[2:] for line in f.readlines()]))
# print("Pruning to only include KEGG pathway genes")

# Prune X to only Cosmic Cancer Genes
print("Pruning to only COSMIC genes")
subset_of_genes = pd.read_table("../data/cosmic_260818.tsv")["Gene Symbol"].values
    
pruned_X = X.drop(labels=(set(X.columns) - set(subset_of_genes)), axis=1)

# Order must match dataframe so we can use this as labels for shap
genes = list(pruned_X.columns.values)
print("Pruned expression to only include", len(genes), "genes")

# Create a multi-label one-hot for tumor/normal and primary site
from sklearn.preprocessing import LabelEncoder

primary_site_encoder = LabelEncoder()
Y["primary_site_value"] = pd.Series(
    primary_site_encoder.fit_transform(Y["_primary_site"]), index=Y.index, dtype='int32')

tumor_normal_encoder = LabelEncoder()
Y["tumor_normal_value"] = pd.Series(
    tumor_normal_encoder.fit_transform(Y["tumor_normal"]), index=Y.index, dtype='int32')

from keras.utils import np_utils
Y_onehot = np.append(
    Y["tumor_normal_value"].values.reshape(Y.shape[0],-1), 
    np_utils.to_categorical(Y["primary_site_value"]), axis=1)

Pruning to only COSMIC genes
Pruned expression to only include 700 genes


Using TensorFlow backend.


In [4]:
# Split into training and test sets strattified on primary site
from sklearn.model_selection import StratifiedShuffleSplit
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

for train_index, test_index in split.split(X.values, Y.primary_site_value):
    X_train = pruned_X.values[train_index]
    X_test = pruned_X.values[test_index]
    Y_train = Y.iloc[train_index]
    Y_test = Y.iloc[test_index]
    Y_onehot_train = Y_onehot[train_index]
    Y_onehot_test = Y_onehot[test_index]
    
print("Train:", X_train.shape, "Test:", X_test.shape)

Train: (14371, 700) Test: (3593, 700)


In [5]:
# import matplotlib.pyplot as plt

# # Lets see how big each class is based on primary site
# plt.hist(Y_train.primary_site_value.values, alpha=0.5, label='Train')
# plt.hist(Y_test.primary_site_value.values, alpha=0.5, label='Test')
# plt.legend(loc='upper right')
# plt.title("Primary site distribution between train and test sets")
# plt.show()

# # Lets see how big each class is based tumor/normal
# plt.hist(Y_train.tumor_normal_value.values, alpha=0.5, label='Train')
# plt.hist(Y_test.tumor_normal_value.values, alpha=0.5, label='Test')
# plt.legend(loc='upper right')
# plt.title("Tumor/normal distribution between train and test sets")
# plt.show()

## Build and Train Model

In [6]:
# Only latest support functional keras models...
!pip uninstall talos -y
!pip install --upgrade git+https://github.com/autonomio/talos@master

# !pip install --upgrade talos==0.1.9.5
import talos

Uninstalling talos-0.3:
  Successfully uninstalled talos-0.3
[33mYou are using pip version 9.0.3, however version 18.0 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m
Collecting git+https://github.com/autonomio/talos@master
  Cloning https://github.com/autonomio/talos (to master) to /tmp/pip-ebh02bqf-build
Installing collected packages: talos
  Running setup.py install for talos ... [?25ldone
[?25hSuccessfully installed talos-0.3
[33mYou are using pip version 9.0.3, however version 18.0 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [31]:
from keras.models import Model
from keras.layers import Input, BatchNormalization, Dense, Dropout
from keras.callbacks import EarlyStopping
from keras import regularizers

params = {
    "width": (16, 32, 64)
}

def create_model(input_shape, output_shape, params):
    inputs = Input(shape=(input_shape,))

    x = BatchNormalization()(inputs)

    x = Dense(64, activation="relu")(x)
    x = Dropout(0.5)(x)
    
    x = Dense(64, activation="relu")(x)
    x = Dropout(0.5)(x)
    
    outputs = Dense(output_shape, activation="sigmoid")(x)
    
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["acc"])
    return model

def train_model(X_train, y_train, X_validate, y_validate, params):
    model = create_model(X_train.shape[1], y_train.shape[1], params)
#     model.summary()
    callbacks = [EarlyStopping(monitor="acc", min_delta=0.05, patience=2, verbose=2, mode="max")]
    history = model.fit(X_train, y_train, validation_data=[X_validate, y_validate],
                        epochs=10, batch_size=128, shuffle="batch", callbacks=callbacks)
    return history, model


hypteropt = talos.Scan(x=X_train, y=Y_onehot_train,
                       x_val=X_test, y_val=Y_onehot_test,
                       params=params,
                       dataset_name="pancan_gtex",
                       experiment_no="1",
                       model=train_model,
                       grid_downsample=1,
                       experimental_functional_support=True)

  0%|          | 0/16 [00:00<?, ?it/s]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


  6%|▋         | 1/16 [00:05<01:29,  5.95s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 12%|█▎        | 2/16 [00:11<01:22,  5.89s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 19%|█▉        | 3/16 [00:17<01:17,  5.99s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 25%|██▌       | 4/16 [00:23<01:10,  5.87s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 31%|███▏      | 5/16 [00:29<01:05,  5.94s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 38%|███▊      | 6/16 [00:35<00:59,  5.96s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 44%|████▍     | 7/16 [00:42<00:55,  6.13s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 50%|█████     | 8/16 [00:48<00:48,  6.10s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 56%|█████▋    | 9/16 [00:54<00:42,  6.14s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 62%|██████▎   | 10/16 [01:01<00:38,  6.37s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 69%|██████▉   | 11/16 [01:07<00:31,  6.38s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 75%|███████▌  | 12/16 [01:14<00:25,  6.43s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 81%|████████▏ | 13/16 [01:21<00:19,  6.66s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 88%|████████▊ | 14/16 [01:28<00:13,  6.68s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


 94%|█████████▍| 15/16 [01:34<00:06,  6.71s/it]

Train on 14371 samples, validate on 3593 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 00004: early stopping


100%|██████████| 16/16 [01:41<00:00,  6.77s/it]

Scan Finished!





In [32]:
report = talos.Reporting("pancan_gtex_1.csv")

In [33]:
report.table()

Unnamed: 0,val_acc,width
14,0.992044,17
4,0.991723,30
12,0.991626,18
15,0.991493,29
1,0.991445,16
6,0.991378,25
13,0.9913,21
11,0.991257,24
0,0.991233,20
7,0.991166,22
