In [1]:
import pandas as pd
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import make_scorer, accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
import argparse
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_selection import mutual_info_classif
from sklearn.model_selection import cross_validate
import os
import sys
from pathlib import Path
from datetime import datetime, timedelta
# from xgboost import XGBClassifier

from tqdm import tqdm


In [2]:
fdir = Path("/home/ar3/Documents/PYTHON/RNASeqAnalysis_backup")

raw_dataset_fname = "merged.TPM.txt"

n_features = 1000
need_a_log = True
non_zero_median = True

metadata_fname = "metadata_brain1.tsv"

output_dir = Path(fdir/"results/test_on_brain_age")
if (not output_dir.is_dir()):
    output_dir.mkdir()
    
n_threads = 6
cv_threshold = 0.7

exclude_chr = None
gtf = None

dataset_fname = Path(fdir/'merged_df_dataset.csv')


In [6]:
metadata_df = pd.read_csv(fdir/metadata_fname, sep="\t")
metadata_df['Descriptor'] = metadata_df['Descriptor'].astype(str)
metadata_df

Unnamed: 0,Run,Descriptor,gender
0,SRR19147434,40,female
1,SRR19147435,30,male
2,SRR19147436,30,male
3,SRR19147437,30,male
4,SRR19147438,40,male
...,...,...,...
210,SRR19147644,50,male
211,SRR19147645,20,female
212,SRR19147646,50,male
213,SRR19147647,20,female


In [7]:
df_raw = pd.read_csv(fdir/raw_dataset_fname, sep="\t")
df_raw = df_raw.T
df_raw 

Unnamed: 0,ENST00000000233.10,ENST00000000412.8,ENST00000000442.11,ENST00000001008.6,ENST00000001146.7,ENST00000002125.9,ENST00000002165.11,ENST00000002501.11,ENST00000002596.6,ENST00000002829.8,...,MSTRG.9996.9,MSTRG.9997.1,MSTRG.9997.10,MSTRG.9997.11,MSTRG.9997.2,MSTRG.9997.3,MSTRG.9997.4,MSTRG.9997.5,MSTRG.9997.6,MSTRG.9997.9
SRR19147434,32.607071,6.516344,0.000000,0.000000,1.925776,2.297619,0.924489,0.028084,0.734709,0.841659,...,9.845663,3.107049,0.0,1.003726,22.034277,0.073788,0.060702,0.0,1.682322,0.0
SRR19147435,31.636797,5.059402,0.000000,24.925554,1.240503,3.179465,3.767600,3.318774,0.699344,2.028128,...,21.757978,10.480646,0.0,1.479971,3.982772,0.055014,0.071976,0.0,1.218230,0.0
SRR19147436,20.091410,25.899878,0.000000,31.769285,1.113374,1.936185,2.293938,0.000000,0.677500,0.945981,...,14.312212,33.029228,0.0,0.963181,6.348445,0.118533,0.026810,0.0,1.360862,0.0
SRR19147437,43.278019,4.217986,0.000000,40.835411,1.202843,3.431436,3.102470,0.000000,0.847226,1.183343,...,10.454411,20.480751,0.0,3.004943,3.611126,0.115716,5.284467,0.0,2.483386,0.0
SRR19147438,35.062801,9.067139,0.000000,27.680275,1.084882,6.995416,2.060208,0.000000,0.899927,1.236574,...,12.893114,20.962542,0.0,3.714251,2.127948,0.180802,0.200934,0.0,1.111884,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR19147644,27.600822,7.181911,4.369801,16.823191,1.110326,1.101212,1.929206,3.881400,0.391131,1.276104,...,9.676552,8.877145,0.0,1.016816,6.069803,0.053570,0.850967,0.0,1.943223,0.0
SRR19147645,14.374123,16.623938,5.323976,15.089913,1.211220,1.694511,1.864122,1.507361,0.451201,1.113731,...,7.817225,4.729078,0.0,0.000000,8.838813,0.009877,0.022259,0.0,1.065074,0.0
SRR19147646,25.151016,22.907345,6.102655,22.421217,1.192796,2.952276,0.000000,4.445130,0.823767,1.800524,...,12.571911,4.146348,0.0,1.518792,7.946578,0.065910,1.171079,0.0,1.820326,0.0
SRR19147647,18.360344,15.833601,0.000000,7.729653,2.054185,3.576194,0.000000,4.050052,0.522655,2.959796,...,26.704006,4.111988,0.0,0.000000,13.994611,0.020801,0.078161,0.0,1.657206,0.0


In [8]:
df = df_raw.copy(deep=True)
df.reset_index(inplace=True)
df = df.rename(columns={'index': 'Run'})
df.set_index('Run', inplace=True)
df

Unnamed: 0_level_0,ENST00000000233.10,ENST00000000412.8,ENST00000000442.11,ENST00000001008.6,ENST00000001146.7,ENST00000002125.9,ENST00000002165.11,ENST00000002501.11,ENST00000002596.6,ENST00000002829.8,...,MSTRG.9996.9,MSTRG.9997.1,MSTRG.9997.10,MSTRG.9997.11,MSTRG.9997.2,MSTRG.9997.3,MSTRG.9997.4,MSTRG.9997.5,MSTRG.9997.6,MSTRG.9997.9
Run,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SRR19147434,32.607071,6.516344,0.000000,0.000000,1.925776,2.297619,0.924489,0.028084,0.734709,0.841659,...,9.845663,3.107049,0.0,1.003726,22.034277,0.073788,0.060702,0.0,1.682322,0.0
SRR19147435,31.636797,5.059402,0.000000,24.925554,1.240503,3.179465,3.767600,3.318774,0.699344,2.028128,...,21.757978,10.480646,0.0,1.479971,3.982772,0.055014,0.071976,0.0,1.218230,0.0
SRR19147436,20.091410,25.899878,0.000000,31.769285,1.113374,1.936185,2.293938,0.000000,0.677500,0.945981,...,14.312212,33.029228,0.0,0.963181,6.348445,0.118533,0.026810,0.0,1.360862,0.0
SRR19147437,43.278019,4.217986,0.000000,40.835411,1.202843,3.431436,3.102470,0.000000,0.847226,1.183343,...,10.454411,20.480751,0.0,3.004943,3.611126,0.115716,5.284467,0.0,2.483386,0.0
SRR19147438,35.062801,9.067139,0.000000,27.680275,1.084882,6.995416,2.060208,0.000000,0.899927,1.236574,...,12.893114,20.962542,0.0,3.714251,2.127948,0.180802,0.200934,0.0,1.111884,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR19147644,27.600822,7.181911,4.369801,16.823191,1.110326,1.101212,1.929206,3.881400,0.391131,1.276104,...,9.676552,8.877145,0.0,1.016816,6.069803,0.053570,0.850967,0.0,1.943223,0.0
SRR19147645,14.374123,16.623938,5.323976,15.089913,1.211220,1.694511,1.864122,1.507361,0.451201,1.113731,...,7.817225,4.729078,0.0,0.000000,8.838813,0.009877,0.022259,0.0,1.065074,0.0
SRR19147646,25.151016,22.907345,6.102655,22.421217,1.192796,2.952276,0.000000,4.445130,0.823767,1.800524,...,12.571911,4.146348,0.0,1.518792,7.946578,0.065910,1.171079,0.0,1.820326,0.0
SRR19147647,18.360344,15.833601,0.000000,7.729653,2.054185,3.576194,0.000000,4.050052,0.522655,2.959796,...,26.704006,4.111988,0.0,0.000000,13.994611,0.020801,0.078161,0.0,1.657206,0.0


In [9]:
def filter_by_non_zero_median(df):
    print(df.shape)

    if (df.median() == 0).any():
        cols_to_drop = df.columns[df.median() == 0]
        print(len(cols_to_drop),
              " features will be removed, due to a zero median value")
        df = df.drop(columns=cols_to_drop)
        print("Current dataset size: ", df.shape)
        return df

    print("Zero median columns aren't found")
    print('Dataset shape: ', df.shape)
    return df

df = filter_by_non_zero_median(df).astype('float32')
df

(215, 380224)
234141  features will be removed, due to a zero median value
Current dataset size:  (215, 146083)


Unnamed: 0_level_0,ENST00000000233.10,ENST00000000412.8,ENST00000000442.11,ENST00000001008.6,ENST00000001146.7,ENST00000002125.9,ENST00000002165.11,ENST00000002501.11,ENST00000002596.6,ENST00000002829.8,...,MSTRG.9995.2,MSTRG.9996.5,MSTRG.9996.8,MSTRG.9996.9,MSTRG.9997.1,MSTRG.9997.11,MSTRG.9997.2,MSTRG.9997.3,MSTRG.9997.4,MSTRG.9997.6
Run,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SRR19147434,32.607071,6.516344,0.000000,0.000000,1.925776,2.297619,0.924489,0.028084,0.734709,0.841659,...,28.696077,2.778211,0.000000,9.845663,3.107049,1.003726,22.034277,0.073788,0.060702,1.682322
SRR19147435,31.636797,5.059402,0.000000,24.925554,1.240503,3.179465,3.767600,3.318774,0.699344,2.028128,...,3.271569,2.012823,0.273076,21.757978,10.480646,1.479971,3.982772,0.055014,0.071976,1.218230
SRR19147436,20.091410,25.899878,0.000000,31.769285,1.113374,1.936185,2.293938,0.000000,0.677500,0.945981,...,27.078547,4.232101,1.969063,14.312212,33.029228,0.963181,6.348445,0.118533,0.026810,1.360862
SRR19147437,43.278019,4.217986,0.000000,40.835411,1.202843,3.431436,3.102470,0.000000,0.847226,1.183343,...,0.000000,3.075821,1.090033,10.454411,20.480751,3.004943,3.611126,0.115716,5.284467,2.483386
SRR19147438,35.062801,9.067139,0.000000,27.680275,1.084882,6.995416,2.060208,0.000000,0.899927,1.236574,...,20.197245,0.000000,0.000000,12.893114,20.962542,3.714251,2.127948,0.180802,0.200934,1.111884
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR19147644,27.600822,7.181911,4.369801,16.823191,1.110326,1.101212,1.929206,3.881400,0.391131,1.276104,...,23.135386,5.800527,0.166172,9.676552,8.877145,1.016816,6.069803,0.053570,0.850967,1.943223
SRR19147645,14.374123,16.623938,5.323976,15.089913,1.211220,1.694511,1.864122,1.507361,0.451201,1.113731,...,22.134399,2.125863,0.079778,7.817225,4.729078,0.000000,8.838813,0.009877,0.022259,1.065074
SRR19147646,25.151016,22.907345,6.102655,22.421217,1.192796,2.952276,0.000000,4.445130,0.823767,1.800524,...,3.985238,1.097599,0.377483,12.571911,4.146348,1.518792,7.946578,0.065910,1.171079,1.820326
SRR19147647,18.360344,15.833601,0.000000,7.729653,2.054185,3.576194,0.000000,4.050052,0.522655,2.959796,...,14.703023,0.000000,0.000000,26.704006,4.111988,0.000000,13.994611,0.020801,0.078161,1.657206


In [10]:
# df = log_a_table(df)  # zeros are replaced by 0.000001
numerical_cols = df.iloc[:1].select_dtypes(include=[np.number]).columns
df = df.replace(0, 1e-6)

df = np.log(df)
df

Unnamed: 0_level_0,ENST00000000233.10,ENST00000000412.8,ENST00000000442.11,ENST00000001008.6,ENST00000001146.7,ENST00000002125.9,ENST00000002165.11,ENST00000002501.11,ENST00000002596.6,ENST00000002829.8,...,MSTRG.9995.2,MSTRG.9996.5,MSTRG.9996.8,MSTRG.9996.9,MSTRG.9997.1,MSTRG.9997.11,MSTRG.9997.2,MSTRG.9997.3,MSTRG.9997.4,MSTRG.9997.6
Run,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SRR19147434,3.484529,1.874313,-13.815511,-13.815511,0.655329,0.831873,-0.078514,-3.572555,-0.308281,-0.172380,...,3.356760,1.021807,-13.815511,2.287031,1.133673,0.003719,3.092599,-2.606559,-2.801779,0.520175
SRR19147435,3.454321,1.621248,-13.815511,3.215894,0.215517,1.156713,1.326438,1.199595,-0.357613,0.707113,...,1.185270,0.699538,-1.298005,3.079981,2.349530,0.392023,1.381978,-2.900168,-2.631423,0.197399
SRR19147436,3.000292,3.254238,-13.815511,3.458500,0.107395,0.660720,0.830270,-13.815511,-0.389346,-0.055533,...,3.298742,1.442699,0.677558,2.661113,3.497393,-0.037514,1.848210,-2.132564,-3.618980,0.308118
SRR19147437,3.767645,1.439358,-13.815511,3.709550,0.184688,1.232979,1.132199,-13.815511,-0.165788,0.168344,...,-13.815511,1.123572,0.086208,2.347024,3.019485,1.100259,1.284020,-2.156616,1.664772,0.909623
SRR19147438,3.557141,2.204657,-13.815511,3.320720,0.081471,1.945255,0.722807,-13.815511,-0.105442,0.212345,...,3.005546,-13.815511,-13.815511,2.556693,3.042737,1.312177,0.755158,-1.710353,-1.604779,0.106056
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR19147644,3.317846,1.971566,1.474717,2.822758,0.104654,0.096411,0.657109,1.356196,-0.938713,0.243812,...,3.141363,1.757949,-1.794732,2.269706,2.183480,0.016676,1.803326,-2.926766,-0.161382,0.664348
SRR19147645,2.665430,2.810844,1.672220,2.714027,0.191628,0.527394,0.622790,0.410360,-0.795842,0.107716,...,3.097133,0.754178,-2.528507,2.056330,1.553730,-13.815511,2.179153,-4.617546,-3.805009,0.063044
SRR19147646,3.224898,3.131458,1.808724,3.110008,0.176300,1.082576,-13.815511,1.491809,-0.193868,0.588078,...,1.382597,0.093125,-0.974230,2.531465,1.422228,0.417915,2.072741,-2.719465,0.157926,0.599016
SRR19147647,2.910193,2.762134,-13.815511,2.045064,0.719879,1.274299,-13.815511,1.398730,-0.648834,1.085120,...,2.688053,-13.815511,-13.815511,3.284814,1.413907,-13.815511,2.638672,-3.872754,-2.548984,0.505133


In [11]:
def filter_by_cv(df, threshold):
    cv = df.std() / df.mean()
    # print(cv)
    low_cv_cols = cv[cv < threshold].index

    if len(low_cv_cols) > 0:
        print(f"{len(low_cv_cols)} features have coefficient of variation below {threshold} and will be removed.")
        df = df.drop(columns=low_cv_cols)
    else:
        print("No features found with coefficient of variation below the threshold.")
    print(f"Current amount of features is {len(df.columns)}")
    return df

cv_threshold = 0.7
df = filter_by_cv(df, cv_threshold)
df

120857 features have coefficient of variation below 0.7 and will be removed.
Current amount of features is 25226


Unnamed: 0_level_0,ENST00000001008.6,ENST00000001146.7,ENST00000005257.7,ENST00000005259.9,ENST00000005286.8,ENST00000005386.8,ENST00000009105.5,ENST00000009589.8,ENST00000011691.6,ENST00000012443.9,...,MSTRG.9921.7,MSTRG.9952.7,MSTRG.9967.8,MSTRG.9981.2,MSTRG.9990.1,MSTRG.9995.1,MSTRG.9995.2,MSTRG.9997.1,MSTRG.9997.2,MSTRG.9997.6
Run,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SRR19147434,-13.815511,0.655329,3.042406,1.065409,1.266377,2.597472,-0.111890,0.084826,0.790239,2.619894,...,0.583503,2.154768,-0.805602,-0.049208,0.042149,3.662991,3.356760,1.133673,3.092599,0.520175
SRR19147435,3.215894,0.215517,-13.815511,1.691719,2.378844,2.602901,1.652716,1.091419,1.396960,2.440174,...,0.224819,0.709243,-0.252296,4.652913,0.286916,2.194798,1.185270,2.349530,1.381978,0.197399
SRR19147436,3.458500,0.107395,3.312534,0.367056,-0.514103,3.017072,-1.296143,0.277070,1.087991,3.071359,...,0.563291,1.632670,0.571868,4.247298,-0.091506,4.561057,3.298742,3.497393,1.848210,0.308118
SRR19147437,3.709550,0.184688,3.303985,1.951370,-13.815511,2.870057,3.214494,-0.107841,1.168506,2.693609,...,0.634680,1.960360,2.120228,-0.173830,1.319177,4.006963,-13.815511,3.019485,1.284020,0.909623
SRR19147438,3.320720,0.081471,3.141080,-1.423519,-13.815511,2.440230,2.055452,-0.535474,0.755853,2.689106,...,0.353791,2.705550,2.028284,4.693324,1.544473,3.910051,3.005546,3.042737,0.755158,0.106056
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR19147644,2.822758,0.104654,2.676821,2.029124,1.918129,1.979167,-0.778282,2.913702,1.095942,2.196946,...,-0.959365,2.662783,1.019860,3.892203,0.835619,3.322631,3.141363,2.183480,1.803326,0.664348
SRR19147645,2.714027,0.191628,2.489765,1.247616,1.583653,1.723282,-2.952903,1.354266,1.375931,1.927883,...,-0.899230,-0.245448,0.757098,-2.049223,1.503681,-1.110847,3.097133,1.553730,2.179153,0.063044
SRR19147646,3.110008,0.176300,-13.815511,1.184247,2.329522,2.030454,-1.001764,-0.221836,1.381422,1.739521,...,-0.059158,-1.656178,-0.614569,3.863475,0.383602,-13.815511,1.382597,1.422228,2.072741,0.599016
SRR19147647,2.045064,0.719879,-0.886084,1.807603,2.550754,2.436864,1.895173,1.789700,1.264843,2.173031,...,0.526209,1.951159,-0.247303,1.539385,1.060566,1.477525,2.688053,1.413907,2.638672,0.505133


In [12]:
dataset = pd.merge(df, metadata_df, on='Run', how='inner')
dataset.to_csv(output_dir/'merged_df_dataset.csv')
dataset

Unnamed: 0,Run,ENST00000001008.6,ENST00000001146.7,ENST00000005257.7,ENST00000005259.9,ENST00000005286.8,ENST00000005386.8,ENST00000009105.5,ENST00000009589.8,ENST00000011691.6,...,MSTRG.9967.8,MSTRG.9981.2,MSTRG.9990.1,MSTRG.9995.1,MSTRG.9995.2,MSTRG.9997.1,MSTRG.9997.2,MSTRG.9997.6,Descriptor,gender
0,SRR19147434,-13.815511,0.655329,3.042406,1.065409,1.266377,2.597472,-0.111890,0.084826,0.790239,...,-0.805602,-0.049208,0.042149,3.662991,3.356760,1.133673,3.092599,0.520175,40,female
1,SRR19147435,3.215894,0.215517,-13.815511,1.691719,2.378844,2.602901,1.652716,1.091419,1.396960,...,-0.252296,4.652913,0.286916,2.194798,1.185270,2.349530,1.381978,0.197399,30,male
2,SRR19147436,3.458500,0.107395,3.312534,0.367056,-0.514103,3.017072,-1.296143,0.277070,1.087991,...,0.571868,4.247298,-0.091506,4.561057,3.298742,3.497393,1.848210,0.308118,30,male
3,SRR19147437,3.709550,0.184688,3.303985,1.951370,-13.815511,2.870057,3.214494,-0.107841,1.168506,...,2.120228,-0.173830,1.319177,4.006963,-13.815511,3.019485,1.284020,0.909623,30,male
4,SRR19147438,3.320720,0.081471,3.141080,-1.423519,-13.815511,2.440230,2.055452,-0.535474,0.755853,...,2.028284,4.693324,1.544473,3.910051,3.005546,3.042737,0.755158,0.106056,40,male
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
210,SRR19147644,2.822758,0.104654,2.676821,2.029124,1.918129,1.979167,-0.778282,2.913702,1.095942,...,1.019860,3.892203,0.835619,3.322631,3.141363,2.183480,1.803326,0.664348,50,male
211,SRR19147645,2.714027,0.191628,2.489765,1.247616,1.583653,1.723282,-2.952903,1.354266,1.375931,...,0.757098,-2.049223,1.503681,-1.110847,3.097133,1.553730,2.179153,0.063044,20,female
212,SRR19147646,3.110008,0.176300,-13.815511,1.184247,2.329522,2.030454,-1.001764,-0.221836,1.381422,...,-0.614569,3.863475,0.383602,-13.815511,1.382597,1.422228,2.072741,0.599016,50,male
213,SRR19147647,2.045064,0.719879,-0.886084,1.807603,2.550754,2.436864,1.895173,1.789700,1.264843,...,-0.247303,1.539385,1.060566,1.477525,2.688053,1.413907,2.638672,0.505133,20,female


In [None]:
del df

In [13]:
dataset = pd.read_csv(output_dir/'merged_df_dataset.csv', index_col=0)
dataset

Unnamed: 0,Run,ENST00000001008.6,ENST00000001146.7,ENST00000005257.7,ENST00000005259.9,ENST00000005286.8,ENST00000005386.8,ENST00000009105.5,ENST00000009589.8,ENST00000011691.6,...,MSTRG.9967.8,MSTRG.9981.2,MSTRG.9990.1,MSTRG.9995.1,MSTRG.9995.2,MSTRG.9997.1,MSTRG.9997.2,MSTRG.9997.6,Descriptor,gender
0,SRR19147434,-13.815511,0.655329,3.042406,1.065409,1.266377,2.597472,-0.111890,0.084826,0.790239,...,-0.805602,-0.049208,0.042149,3.662991,3.356760,1.133673,3.092599,0.520175,40,female
1,SRR19147435,3.215894,0.215517,-13.815511,1.691719,2.378844,2.602901,1.652716,1.091419,1.396960,...,-0.252296,4.652913,0.286916,2.194798,1.185270,2.349530,1.381978,0.197399,30,male
2,SRR19147436,3.458500,0.107395,3.312534,0.367056,-0.514103,3.017072,-1.296143,0.277070,1.087991,...,0.571868,4.247298,-0.091506,4.561057,3.298742,3.497393,1.848210,0.308118,30,male
3,SRR19147437,3.709550,0.184688,3.303985,1.951370,-13.815511,2.870057,3.214494,-0.107841,1.168506,...,2.120228,-0.173830,1.319177,4.006963,-13.815511,3.019485,1.284020,0.909623,30,male
4,SRR19147438,3.320720,0.081471,3.141080,-1.423519,-13.815511,2.440230,2.055452,-0.535474,0.755853,...,2.028284,4.693324,1.544473,3.910051,3.005546,3.042737,0.755158,0.106056,40,male
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
210,SRR19147644,2.822758,0.104654,2.676821,2.029124,1.918129,1.979167,-0.778282,2.913702,1.095942,...,1.019860,3.892203,0.835619,3.322631,3.141363,2.183480,1.803326,0.664348,50,male
211,SRR19147645,2.714027,0.191628,2.489765,1.247616,1.583653,1.723282,-2.952903,1.354266,1.375931,...,0.757098,-2.049223,1.503681,-1.110847,3.097133,1.553730,2.179153,0.063044,20,female
212,SRR19147646,3.110008,0.176300,-13.815511,1.184247,2.329522,2.030454,-1.001764,-0.221836,1.381422,...,-0.614569,3.863475,0.383602,-13.815511,1.382597,1.422228,2.072741,0.599016,50,male
213,SRR19147647,2.045064,0.719879,-0.886084,1.807603,2.550754,2.436864,1.895173,1.789700,1.264843,...,-0.247303,1.539385,1.060566,1.477525,2.688053,1.413907,2.638672,0.505133,20,female


In [None]:
def split_the_table(dataset, 
                    test_size=0.2, 
                    random_state=42):
    
    dataset.reset_index(inplace=True)
    dataset = dataset.drop(columns=['index'])
    index_column = dataset['Run']
    dataset.drop('Run', axis=1, inplace=True)
    
    X = dataset.drop(['Descriptor', 'gender'], axis=1)
    y = dataset['Descriptor']
    
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, shuffle=False)

    return X_train, X_test, y_train, y_test, index_column

X_train_init, X_test, y_train_init, y_test, index_column = split_the_table(dataset=dataset)


In [None]:
# Gradient_boosting_model(X_train, X_test, y_train,  y_test, output_dir, n_threads=n_threads)

from sklearn.preprocessing import LabelEncoder
label_encoder = LabelEncoder()
y_train_init_encoded = label_encoder.fit_transform(y_train_init)
# print(y_train)

X_train, X_val, y_train, y_val = train_test_split(X_train_init, y_train_init_encoded)

n_threads = 6

CBC = CatBoostClassifier(loss_function='MultiClass',
                         od_pval=0.05,
                         thread_count=n_threads,
                         task_type="CPU",
                         iterations=300,
                         learning_rate=0.03
                         #  devices='0'
                         )

CBC.fit(X_train, y_train, eval_set=(X_val, y_val), 
        verbose=False, 
        use_best_model=True, 
        plot=True, 
        early_stopping_rounds=20)


In [None]:
saved_model_filename = "catboost_model.cbm"
CBC.save_model(fname=output_dir/saved_model_filename, format='cbm', export_parameters=None, pool=None)

In [None]:
# CBC = CatBoostClassifier(loss_function='MultiClass',
#                          od_pval=0.05,
#                          thread_count=n_threads,
#                          task_type="CPU",
#                          iterations=300,
#                          learning_rate=0.03
#                          #  devices='0'
#                          )

# CBC.load_model(output_dir/saved_model_filename, format='cbm',)

In [None]:
pred = CBC.predict(X_test)
pred

In [None]:
report_path = output_dir/"gradient_boosting_classification_report.txt"

report = classification_report(label_encoder.fit_transform(y_test), 
                               pred)

# with open(report_path, "w") as file:
#     file.write(report)
    
report 

In [None]:
sorted_feature_importance = CBC.feature_importances_.argsort()[::-1]   
sorted_feature_importance

In [None]:
feature_importance_df = pd.DataFrame({
    'Feature': X_train.columns,
    'Importance': CBC.feature_importances_
    # 'Importance': XGB.feature_importances_
})
feature_importance_df = feature_importance_df.sort_values(
    by='Importance', ascending=False)

feature_importance_df.to_csv(output_dir/"catboost_feature_importance.csv", index=False, sep=",")

sns.barplot(x=CBC.feature_importances_[sorted_feature_importance[:20]],
            y=X_train.columns[sorted_feature_importance[:20]], orient='h')

In [None]:
dataset

In [None]:
def knn_feature_selection(df: pd.DataFrame, n_features):
    
    target = df["Descriptor"]
    label_encoder = LabelEncoder()
    target = label_encoder.fit_transform(target)
    
    feature_matrix = df.drop(["Descriptor", 'gender'], axis=1)
    mi_df = pd.DataFrame(columns=['Transcript', 'Mutual_Information'])

    feature_matrix.drop(columns=['index'], inplace=True)
    feature_matrix.set_index('Run', inplace=True)

    for column in tqdm(feature_matrix.columns):
        mi_score = mutual_info_classif(
            feature_matrix[column].values.reshape(-1, 1), target, random_state=42)[0]
        
        mi_df = pd.concat([mi_df, pd.DataFrame(
            {'Transcript': [column], 'Mutual_Information': [mi_score]})], ignore_index=True)

    mi_df_sorted = mi_df.sort_values(by='Mutual_Information', ascending=False)
    rows_to_drop = mi_df_sorted.loc[mi_df_sorted['Mutual_Information'] < 0.01].index
    mi_df_filtered = mi_df_sorted.drop(rows_to_drop)
    

    mi_df_filtered = mi_df_filtered.head(n_features)

    return mi_df_filtered


mi_df_filtered = knn_feature_selection(dataset, 
                                       n_features)

mi_table_path = output_dir/"mutual_information.csv"
mi_df_filtered.to_csv(mi_table_path, index=False)

mi_df_filtered


In [None]:
def cross_val_n_best_transcripts_model(n, mi_df_filtered, X, y):
    result_df = pd.DataFrame(columns=['number_of_best_transcripts', 
                                      'accuracy',
                                      'precision_macro', 
                                      'recall_macro', 
                                      'f1_macro', 
                                      'n_neighbors'])

    for i in tqdm(range(1, n + 1)):

        selected_transcripts = mi_df_filtered.head(i)['Transcript'].tolist()

        X_selected = X[selected_transcripts]

        best_accuracy = 0
        best_precision_macro = 0
        best_recall_macro = 0
        best_f1_macro = 0
        best_neighbors = 0

        for n_neighbors in range(1, 22, 2):
            knn_model = KNeighborsClassifier(n_neighbors=n_neighbors, algorithm="brute")
            
            scoring = {'accuracy': 'accuracy',
                       'precision_macro': 'precision_macro',
                       'recall_macro': 'recall_macro',
                       'f1_macro': 'f1_macro'}
    
            cv_results = cross_validate(
                knn_model, X_selected, y, cv=5, scoring=scoring)

            accuracy = cv_results['test_accuracy'].mean()
            precision_macro = cv_results['test_precision_macro'].mean()
            recall_macro = cv_results['test_recall_macro'].mean()
            f1_macro = cv_results['test_f1_macro'].mean()

            if f1_macro > best_f1_macro:
                best_accuracy = accuracy
                best_precision_macro = precision_macro
                best_recall_macro = recall_macro
                best_f1_macro = f1_macro
                best_neighbors = n_neighbors

        result_df = pd.concat([result_df, pd.DataFrame({'number_of_best_transcripts': [i],
                                                        'accuracy': [best_accuracy],
                                                        'precision_macro': [best_precision_macro],
                                                        'recall_macro': [best_recall_macro],
                                                        'f1_macro': [best_f1_macro],
                                                        'n_neighbors': [best_neighbors]})], ignore_index=True)
        print(str(i) + " best features knn model has been validated")

    return result_df

X_ = dataset.drop(columns=["Descriptor", 'gender'])
y_ = dataset["Descriptor"]

result_df = cross_val_n_best_transcripts_model(len(mi_df_filtered), mi_df_filtered, 
                                               X_, y_)

result_path = output_dir/"knn_model_validation_result.csv"
result_df.to_csv(result_path, index=False, sep="\t")

result_df

In [None]:
result_df