In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sys
import os

# Load & preprocess mutation data

In [None]:
# Gene-level mutation table; its first column (the sample barcode) becomes the index
df_brca = pd.read_csv(
    "data/tcga_brca_mutations_by_gene.csv",
    index_col=0,
)

In [None]:
# Peek at the first rows of the raw mutation table
df_brca.head(n=5)

In [None]:
# Mutation data is expected to be binary {0, 1}, so 64-bit integers are wasteful.
# A plain int -> int8 downcast may silently wrap values outside the int8 range,
# so check the observed range first instead of risking corrupted values.
mut_min, mut_max = df_brca.min().min(), df_brca.max().max()
assert np.iinfo(np.int8).min <= mut_min and mut_max <= np.iinfo(np.int8).max, \
    f"mutation values outside the int8 range: [{mut_min}, {mut_max}]"
df = df_brca.astype('int8')

# Prepare categories

We want to know whether each sample comes from tumor tissue or from normal tissue.

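The TCGA documentation (https://docs.gdc.cancer.gov/Encyclopedia/pages/TCGA_Barcode/) explains that this information is encoded in the sample barcode. The first two digits of the fourth barcode field give the sample type: codes `01`–`09` are tumor, `10`–`19` are normal tissue, and `20`–`29` are control samples.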

For instance, in sample `TCGA-05-4244-01A-01R-1107-07` the fourth field is `01A`, whose code `01` means Tumor. In sample `TCGA-91-6829-11A-01R-1858-07` the fourth field is `11A`, whose code `11` means Normal.



In [None]:
def to_tumor_normal(barcode):
    """Return True if a TCGA barcode denotes a tumor sample, False otherwise.

    The fourth dash-separated field of the barcode holds the sample type
    (e.g. ``01A``). A leading ``0`` (tumor codes ``01``-``09``) means tumor;
    anything else (e.g. ``11A``, normal tissue) is reported as non-tumor.

    Raises IndexError if the barcode has fewer than four fields.
    """
    # Only the first four fields matter, so stop splitting after them.
    sample_type = barcode.split('-', 4)[3]
    return sample_type[0] == '0'

# Label each sample from its barcode: True for tumor, False otherwise
tn = list(map(to_tumor_normal, df.index.values))
y = np.array(tn)

# Sanity check: (number of tumor samples, number of non-tumor samples).
# We expect only tumor data, i.e. the second count should be zero.
sum(y), sum(1 - y)

### Load clinical table

In [None]:
# Clinical annotations (tab-separated). low_memory=False makes pandas infer
# each column's dtype from the whole file rather than chunk by chunk.
df_clinical = pd.read_csv(
    "data/tcga_brca_clinical_data.tsv",
    sep="\t",
    low_memory=False,
)

In [None]:
# Preview the clinical table (first rows only, rather than dumping the whole frame)
df_clinical.head()

In [None]:
# Keep only the sample ID and overall-survival columns under short names,
# indexed by sample ID so the table can be joined with the mutation data.
cli = (
    df_clinical
    .rename(columns={'Sample ID': 'sample_id',
                     'Overall Survival (Months)': 'os',
                     'Overall Survival Status': 'os_status'})
    [['sample_id', 'os', 'os_status']]
    .set_index('sample_id')
)

In [None]:
# Flag samples lacking either the survival time or the survival status,
# and show how many there are.
to_remove = cli[['os', 'os_status']].isna().any(axis=1)
int(to_remove.sum())

In [None]:
# Drop the incomplete samples flagged above
cli = cli[~to_remove].copy()

In [None]:
# Number of samples per overall-survival status
cli['os_status'].value_counts()

In [None]:
# Keep only deceased patients: their survival time is an observed event,
# whereas for living patients it is right-censored (only a lower bound).
# NOTE: this discards the censored samples instead of modeling them.
# endswith() also tolerates a coded prefix such as "1:DECEASED" (used by some
# cBioPortal exports); na=False treats non-string entries as not deceased.
cli = cli.loc[cli.os_status.str.endswith('DECEASED', na=False)].copy()

In [None]:
# Survival months are the only clinical column needed from here on
cli = cli.loc[:, ['os']].copy()
cli

### Intersect clinical and mutation data

In [None]:
# Inner join on the index: keep only samples present in both tables.
# The clinical 'os' column comes first, followed by the gene columns.
df = cli.join(df, how='inner')
# An empty result usually means the two tables use different sample-ID formats.
assert len(df) > 0, "no sample IDs shared between clinical and mutation tables"

In [None]:
# First rows of the joined table ('os' followed by the gene columns)
df.head(n=5)

In [None]:
# Are there samples without any mutation?
# Exclude 'os' (survival months): it is not a mutation indicator, and including
# it would make every sample with a positive survival time look mutated.
count_mut_per_sample = df.drop(columns='os').sum(axis=1)
(count_mut_per_sample == 0).sum()

In [None]:
# Are there genes with zero or a low number of mutations?
count_mut_per_gene = df.sum(axis=0)
keep = count_mut_per_gene > 3
# 'os' is the survival target, not a gene: always keep it regardless of its
# column sum, and leave it out of the reported number of genes kept.
keep['os'] = True
keep.drop('os').sum()

In [None]:
# Keep only the columns selected by `keep`: genes with more than 3 mutations
# (i.e. at least 4). `keep` is expected to select the 'os' column as well,
# since later cells use it as the regression target.
keep_names = count_mut_per_gene.loc[keep].index
df = df.loc[:, keep_names].copy()
df.head()

In [None]:
# Smallest mutation count among the remaining genes ('os' is not a gene)
df.drop(columns='os').sum(axis=0).min()

# Create dataset for model training

In [None]:
# Feature matrix: one row per sample, one column per gene. Select by name
# (everything except 'os') instead of relying on 'os' being the first column.
x = df.drop(columns='os').to_numpy()
x

In [None]:
# Regression target: overall survival in months, selected by column name
y = df['os'].to_numpy()
y

Make sure the dimensions of `x`, `y` and the gene list match.

In [None]:
# Gene names, in the same order as the columns of `x`
genes = [col for col in df.columns if col != 'os']
# Features, target and gene names must line up for training and for
# mapping feature importances back to genes.
assert x.shape == (len(y), len(genes)), (x.shape, y.shape, len(genes))
x.shape, y.shape, len(genes)

# Create regressors and analyze feature importance

In [None]:
from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor, GradientBoostingRegressor
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import GridSearchCV

def fit_cv(model, features=None, target=None, cv=5,
           n_estimators_grid=(1, 3, 5, 10, 20, 30, 40, 50, 60, 80, 90, 100)):
    """Find the best ``n_estimators`` for ``model`` by grid-search cross-validation.

    Parameters
    ----------
    model : estimator
        Estimator with an ``n_estimators`` parameter. GridSearchCV works on
        clones of it, so the object passed in is not fitted.
    features, target : array-like, optional
        Training data. Default to the notebook-level ``x`` and ``y``.
    cv : int or cross-validation splitter, default 5
        Cross-validation strategy passed to GridSearchCV.
    n_estimators_grid : sequence of int
        Candidate values for ``n_estimators``.

    Returns
    -------
    int
        The best ``n_estimators`` according to GridSearchCV (also printed).
    """
    # Fall back to the notebook globals so existing calls fit_cv(model) still work.
    if features is None:
        features = x
    if target is None:
        target = y
    param_grid = [{'n_estimators': list(n_estimators_grid)}]
    gs = GridSearchCV(model, param_grid, cv=cv)
    gs.fit(features, target)
    n = gs.best_params_['n_estimators']
    print(f"Best 'n_estimators'= {n}")
    return n

def importance(model, features=None, target=None, gene_names=None, n_top=10):
    """Fit ``model`` and print the genes with the highest feature importance.

    Parameters
    ----------
    model : estimator
        Estimator that exposes ``feature_importances_`` after ``fit``.
        It is fitted in place.
    features, target : array-like, optional
        Training data. Default to the notebook-level ``x`` and ``y``.
    gene_names : sequence of str, optional
        One name per feature column. Defaults to the notebook-level ``genes``.
    n_top : int, default 10
        Number of top-ranked genes to print.

    Returns
    -------
    None
        Results are printed only, so a trailing call does not echo a Series.
    """
    # Fall back to the notebook globals so existing calls importance(model) still work.
    if features is None:
        features = x
    if target is None:
        target = y
    if gene_names is None:
        gene_names = genes
    model.fit(features, target)
    top = pd.Series(model.feature_importances_, gene_names).sort_values(ascending=False)
    print('Top genes:')
    print(top.head(n_top))

In [None]:
# Random forest: tune n_estimators by cross-validation, then refit with the
# best value and list the most important genes.
n = fit_cv(RandomForestRegressor(n_jobs=-1, random_state=42))
rf = RandomForestRegressor(random_state=42, n_jobs=-1, n_estimators=n)
importance(rf)

In [None]:
# Gradient boosting: tune n_estimators by cross-validation, then refit with
# the best value and list the most important genes.
n = fit_cv(GradientBoostingRegressor(random_state=42))
gb = GradientBoostingRegressor(random_state=42, n_estimators=n)
importance(gb)

In [None]:
# Extra trees: tune n_estimators by cross-validation, then refit with the
# best value and list the most important genes.
n = fit_cv(ExtraTreesRegressor(n_jobs=-1, random_state=42))
et = ExtraTreesRegressor(random_state=42, n_jobs=-1, n_estimators=n)
importance(et)