# 06 Region classifier

**Objectives:**
* train a classifier to predict Regions (`-->` multiclass)
* explain classifications in different regions with LIME

In [1]:
import pandas as pd
import numpy as np
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from boruta import BorutaPy

from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, ShuffleSplit

import lime 
import sklearn.datasets
from lime.lime_tabular import LimeTabularExplainer

import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
pandas2ri.activate()

sns.set(style="whitegrid")

---

## Load the data

In [2]:
wd = '/media/tmo/data/work/datasets/02_ST'

logcpm_path = wd + '/ashley_21.03.2018/logcpm_merge_20180212.pickle'
meta_path = wd + '/meta/meta.parquet'

In [3]:
%%time
meta_df = pd.read_parquet(meta_path)

CPU times: user 122 ms, sys: 274 ms, total: 396 ms
Wall time: 193 ms


In [4]:
%%time
logcpm_df = pickle.load(open(logcpm_path, "rb"))

logcpm_df.index.name = 'spot_UID'
logcpm_df.reset_index(inplace=True)
logcpm_df.rename(columns={'sampleID': 'slide_ID'}, inplace=True)

CPU times: user 16 s, sys: 6.89 s, total: 22.9 s
Wall time: 22.9 s


In [5]:
st_df = logcpm_df.merge(meta_df, how='inner', on=['spot_UID', 'slide_ID'])

In [6]:
st_df['slide_ID'] = st_df['slide_ID'].astype('category', copy=False)
st_df['GT'] = st_df['GT'].astype('category', copy=False)
st_df['age'] = st_df['age_GT'].astype('category', copy=False)
st_df['age_GT'] = st_df['age_GT'].astype('category', copy=False)

In [7]:
n_genes = 46454
gene_columns = st_df.columns[1:n_genes+1]

In [8]:
expression_df = st_df[gene_columns]

In [9]:
assert expression_df.shape == (10327, 46454)

In [108]:
region_levels_df = pd.read_csv('region_levels.csv', sep=',', keep_default_na=False)

In [110]:
genotype_df = st_df[['GT']]

In [111]:
region_df = st_df[['Region_predict']].copy()

In [112]:
region_df = region_df.merge(region_levels_df, on='Region_predict')

In [113]:
region_cat_df = region_df['Region_predict'].astype('category', copy=False)
region_level1_df = region_df['Level_01'].astype('category', copy=False)
region_level2_df = region_df['Level_02'].astype('category', copy=False)

---



In [120]:
region_df['Level_01'].value_counts()

CX    4104
BS    3028
NA    1651
HP    1544
Name: Level_01, dtype: int64

In [119]:
region_df['Level_02'].value_counts()

TH       1931
FB       1155
HPd      1122
HY       1097
AUD       860
OLF       749
CTXsp     709
PTL       585
SSp       532
NA        496
HPs       422
COM       295
RSP       262
ENTI      112
Name: Level_02, dtype: int64

---
## Extract *all-relevant* feature set: Region

In [25]:
boruta_rf = RandomForestClassifier(n_jobs=-1, n_estimators=1000, max_features='sqrt', max_depth=5)

def train_feature_selector(X_df=expression_df,  # the transcriptome expression vectors
                           y_df=region_df,      # the Region column
                           estimator=boruta_rf, verbose=2, seed=42):  # boruta parameters
    feature_selector = BorutaPy(estimator=estimator, verbose=verbose, random_state=seed, n_estimators='auto')    
    
    X = X_df.as_matrix()
    y = y_df.values.ravel()
    feature_selector.fit(X, y)
    
    return feature_selector

In [None]:
Region_feature_selector = train_feature_selector(y_df = region_df)

In [27]:
Region_features = list(gene_columns[Region_feature_selector.support_])

In [30]:
pd.DataFrame(Region_features).to_csv('06_region_features.txt', index=None, header=None)

## Extract *all-relevant* feature set: Level 01

In [None]:
level1_feature_selector = train_feature_selector(y_df = region_level1_df)

In [124]:
Level1_features = list(gene_columns[level1_feature_selector.support_])

In [126]:
pd.DataFrame(Level1_features).to_csv('06_level1_features.txt', index=None, header=None)

## Extract *all-relevant* feature set: Level 02

In [None]:
level2_feature_selector = train_feature_selector(y_df = region_level2_df)

Iteration: 	1 / 100
Confirmed: 	0
Tentative: 	46454
Rejected: 	0
Iteration: 	2 / 100
Confirmed: 	0
Tentative: 	46454
Rejected: 	0
Iteration: 	3 / 100
Confirmed: 	0
Tentative: 	46454
Rejected: 	0
Iteration: 	4 / 100
Confirmed: 	0
Tentative: 	46454
Rejected: 	0
Iteration: 	5 / 100
Confirmed: 	0
Tentative: 	46454
Rejected: 	0
Iteration: 	6 / 100
Confirmed: 	0
Tentative: 	46454
Rejected: 	0
Iteration: 	7 / 100
Confirmed: 	0
Tentative: 	46454
Rejected: 	0
Iteration: 	8 / 100
Confirmed: 	0
Tentative: 	4886
Rejected: 	41568


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	9 / 100
Confirmed: 	1274
Tentative: 	3612
Rejected: 	41568


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	10 / 100
Confirmed: 	1274
Tentative: 	3612
Rejected: 	41568


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	11 / 100
Confirmed: 	1274
Tentative: 	3612
Rejected: 	41568


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	12 / 100
Confirmed: 	1357
Tentative: 	2762
Rejected: 	42335


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	13 / 100
Confirmed: 	1357
Tentative: 	2762
Rejected: 	42335


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	14 / 100
Confirmed: 	1357
Tentative: 	2762
Rejected: 	42335


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	15 / 100
Confirmed: 	1357
Tentative: 	2762
Rejected: 	42335


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	16 / 100
Confirmed: 	1371
Tentative: 	2357
Rejected: 	42726


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	17 / 100
Confirmed: 	1371
Tentative: 	2357
Rejected: 	42726


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	18 / 100
Confirmed: 	1371
Tentative: 	2357
Rejected: 	42726


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	19 / 100
Confirmed: 	1377
Tentative: 	2030
Rejected: 	43047


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	20 / 100
Confirmed: 	1377
Tentative: 	2030
Rejected: 	43047


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	21 / 100
Confirmed: 	1377
Tentative: 	2030
Rejected: 	43047


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	22 / 100
Confirmed: 	1377
Tentative: 	1806
Rejected: 	43271


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	23 / 100
Confirmed: 	1377
Tentative: 	1806
Rejected: 	43271


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	24 / 100
Confirmed: 	1377
Tentative: 	1806
Rejected: 	43271


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	25 / 100
Confirmed: 	1377
Tentative: 	1806
Rejected: 	43271


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	26 / 100
Confirmed: 	1378
Tentative: 	1606
Rejected: 	43470


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	27 / 100
Confirmed: 	1378
Tentative: 	1606
Rejected: 	43470


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	28 / 100
Confirmed: 	1378
Tentative: 	1606
Rejected: 	43470


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	29 / 100
Confirmed: 	1378
Tentative: 	1392
Rejected: 	43684


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	30 / 100
Confirmed: 	1378
Tentative: 	1392
Rejected: 	43684


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	31 / 100
Confirmed: 	1378
Tentative: 	1392
Rejected: 	43684


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	32 / 100
Confirmed: 	1378
Tentative: 	1208
Rejected: 	43868


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	33 / 100
Confirmed: 	1378
Tentative: 	1208
Rejected: 	43868


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	34 / 100
Confirmed: 	1378
Tentative: 	989
Rejected: 	44087


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	35 / 100
Confirmed: 	1378
Tentative: 	989
Rejected: 	44087


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	36 / 100
Confirmed: 	1378
Tentative: 	989
Rejected: 	44087


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	37 / 100
Confirmed: 	1378
Tentative: 	813
Rejected: 	44263


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	38 / 100
Confirmed: 	1378
Tentative: 	813
Rejected: 	44263


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	39 / 100
Confirmed: 	1378
Tentative: 	813
Rejected: 	44263


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	40 / 100
Confirmed: 	1378
Tentative: 	660
Rejected: 	44416


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	41 / 100
Confirmed: 	1378
Tentative: 	660
Rejected: 	44416


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	42 / 100
Confirmed: 	1378
Tentative: 	660
Rejected: 	44416


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	43 / 100
Confirmed: 	1379
Tentative: 	558
Rejected: 	44517


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	44 / 100
Confirmed: 	1379
Tentative: 	558
Rejected: 	44517


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	45 / 100
Confirmed: 	1379
Tentative: 	558
Rejected: 	44517


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	46 / 100
Confirmed: 	1380
Tentative: 	476
Rejected: 	44598


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	47 / 100
Confirmed: 	1380
Tentative: 	476
Rejected: 	44598


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	48 / 100
Confirmed: 	1380
Tentative: 	476
Rejected: 	44598


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	49 / 100
Confirmed: 	1380
Tentative: 	432
Rejected: 	44642


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	50 / 100
Confirmed: 	1380
Tentative: 	432
Rejected: 	44642


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	51 / 100
Confirmed: 	1380
Tentative: 	381
Rejected: 	44693


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	52 / 100
Confirmed: 	1380
Tentative: 	381
Rejected: 	44693


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	53 / 100
Confirmed: 	1380
Tentative: 	381
Rejected: 	44693


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	54 / 100
Confirmed: 	1381
Tentative: 	343
Rejected: 	44730


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	55 / 100
Confirmed: 	1381
Tentative: 	343
Rejected: 	44730


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	56 / 100
Confirmed: 	1381
Tentative: 	343
Rejected: 	44730


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	57 / 100
Confirmed: 	1381
Tentative: 	293
Rejected: 	44780


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	58 / 100
Confirmed: 	1381
Tentative: 	293
Rejected: 	44780


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	59 / 100
Confirmed: 	1381
Tentative: 	274
Rejected: 	44799


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	60 / 100
Confirmed: 	1381
Tentative: 	274
Rejected: 	44799


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	61 / 100
Confirmed: 	1381
Tentative: 	274
Rejected: 	44799


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	62 / 100
Confirmed: 	1381
Tentative: 	252
Rejected: 	44821


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	63 / 100
Confirmed: 	1381
Tentative: 	252
Rejected: 	44821


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	64 / 100
Confirmed: 	1381
Tentative: 	252
Rejected: 	44821


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	65 / 100
Confirmed: 	1381
Tentative: 	234
Rejected: 	44839


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	66 / 100
Confirmed: 	1381
Tentative: 	234
Rejected: 	44839


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	67 / 100
Confirmed: 	1381
Tentative: 	216
Rejected: 	44857


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	68 / 100
Confirmed: 	1381
Tentative: 	216
Rejected: 	44857


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	69 / 100
Confirmed: 	1381
Tentative: 	216
Rejected: 	44857


  hits = np.where(cur_imp[0] > imp_sha_max)[0]


Iteration: 	70 / 100
Confirmed: 	1381
Tentative: 	207
Rejected: 	44866


In [None]:
Level2_features = list(gene_columns[level2_feature_selector.support_])

In [None]:
pd.DataFrame(Level2_features).to_csv('06_level2_features.txt', index=None, header=None)

# Multiclass training

* http://scikit-learn.org/stable/modules/multiclass.html