In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import entropy

In [12]:
# KL Divergence for each feature
def kl_divergence(feature_name, dataset):
    neg_feature_data = dataset[dataset['label'] == 0][feature_name] # 负样本数据
    pos_feature_data = dataset[dataset['label'] == 1][feature_name] # 正样本数据

    min_value = min(neg_feature_data.min(), pos_feature_data.min())
    max_value = max(neg_feature_data.max(), pos_feature_data.max())
    # min_value = pos_feature_data.min() if neg_feature_data.min() > pos_feature_data.min() else neg_feature_data.min()
    # max_value = neg_feature_data.max() if neg_feature_data.max() > pos_feature_data.max() else pos_feature_data.max()
    # c = 0
    # if(c==0):
    #     plt.hist(neg_feature_data, bins = 10, range=(min_value, max_value), density = True, alpha = 0.5, label = 'neg', color = 'red')
    #     plt.hist(pos_feature_data, bins = 10, range=(min_value, max_value), density = True, alpha = 0.5, label = 'pos', color = 'blue')
    #     plt.legend()
    #     plt.show()
    #     c = 1
    neg_hist, neg_bin = np.histogram(neg_feature_data, bins = 10, range=(min_value, max_value), density = True)
    pos_hist, pos_bin = np.histogram(pos_feature_data, bins = 10, range=(min_value, max_value), density = True)
    pos_hist[pos_hist == 0] = 1e-7
    neg_hist[neg_hist == 0] = 1e-7
    return entropy(neg_hist, pos_hist, base = 2)


In [13]:
def get_all_kl_divergence(df):
    features =[]
    feature_kl = []
    size = len(df.columns)
    i = 0
    # print(df[df.columns[19944]])
    for feature in df.columns: 
        if feature == 'label' or feature == 'file_id':
            continue
        features.append(feature)
        feature_kl.append(kl_divergence(feature, df))
        i = i + 1
        print('process:{} / {}'.format(i, feature))
    kl_feature_dataset = pd.DataFrame()
    kl_feature_dataset['feature_name'] = features
    kl_feature_dataset['kl_divergence'] = feature_kl
    final = kl_feature_dataset.sort_values('kl_divergence', ascending=False)
    return final

In [16]:
# Read the preprocessed data and calculate the KL divergence for each feature
df = pd.read_csv('TCGA_Labeled_Selected_GOA.csv', index_col = 0)
dd = get_all_kl_divergence(df)
choosed_features = dd[dd['kl_divergence'] > 3].reset_index(drop = True)
choosed_features


process:1 / OR4F29
process:2 / OR4F16
process:3 / SAMD11
process:4 / NOC2L
process:5 / KLHL17
process:6 / PERM1
process:7 / HES4
process:8 / ISG15
process:9 / RNF223
process:10 / C1orf159
process:11 / TTLL10
process:12 / TNFRSF18
process:13 / PUSL1
process:14 / INTS11
process:15 / TAS1R3
process:16 / DVL1
process:17 / MRPL20
process:18 / ANKRD65
process:19 / VWA1
process:20 / ATAD3B
process:21 / ATAD3A
process:22 / TMEM240
process:23 / FNDC10
process:24 / MIB2
process:25 / MMP23B
process:26 / CDK11B
process:27 / SLC35E2B
process:28 / CDK11A
process:29 / CALML6
process:30 / GABRD
process:31 / PRKCZ
process:32 / FAAP20
process:33 / SKI
process:34 / RER1
process:35 / PLCH2
process:36 / PANK4
process:37 / HES5
process:38 / PRXL2B
process:39 / TTC34
process:40 / ARHGEF16
process:41 / MEGF6
process:42 / TPRG1L
process:43 / WRAP73
process:44 / DFFB
process:45 / AJAP1
process:46 / NPHP4
process:47 / RPL22
process:48 / RNF207
process:49 / ICMT
process:50 / HES3
process:51 / GPR153
process:52 / 

Unnamed: 0,feature_name,kl_divergence
0,GYPE,5.867323
1,ADRB2,5.825714
2,CD5L,5.746667
3,ANGPT4,5.723198
4,GPM6A,5.692274
...,...,...
79,EPAS1,3.036145
80,ADRB1,3.013379
81,ARHGEF15,3.011971
82,ACTN2,3.004149


In [17]:
# Filtering the choosed 134 features from the original dataset
index_columns = np.append(choosed_features['feature_name'].values, ['label'])
df_choosed = df[index_columns]
df_choosed.to_csv('TCGA_Labeled_GOA_KL.csv')

df_choosed.columns


Index(['GYPE', 'ADRB2', 'CD5L', 'ANGPT4', 'GPM6A', 'NCKAP5', 'CD300LG', 'SGCG',
       'SH3GL3', 'CA4', 'CAVIN2', 'RTKN2', 'HBM', 'ACADL', 'TEK', 'AGER',
       'WNT3A', 'MYZAP', 'ITLN2', 'FHL5', 'ST8SIA6', 'LGI3', 'CCM2L', 'BTNL9',
       'ADH1B', 'FHL1', 'PTPN21', 'EMP2', 'CNTN6', 'CLDN18', 'HSPA12B',
       'LIMS2', 'SPOCK2', 'EMCN', 'S1PR1', 'STXBP6', 'FABP4', 'SOX7', 'VEGFD',
       'RXFP1', 'SFTPC', 'ACVRL1', 'SLC39A8', 'LDB2', 'OTC', 'CALCRL', 'KANK3',
       'RAMP3', 'SIRPB1', 'JPH4', 'SLC6A4', 'RS1', 'MGAT3', 'TAL1', 'ALKAL2',
       'LRRC36', 'GRIA1', 'SCUBE1', 'ANGPTL7', 'GRK5', 'VEPH1', 'PRKG2',
       'ANOS1', 'PRX', 'FAM110D', 'EFCC1', 'TENT5B', 'DPEP2', 'LAMP3',
       'TMEM100', 'SEMA3G', 'CLEC1A', 'CCDC141', 'SSMEM1', 'ABCA8', 'DKK2',
       'SLC4A1', 'CAV2', 'GIMAP8', 'EPAS1', 'ADRB1', 'ARHGEF15', 'ACTN2',
       'ARHGAP31', 'label'],
      dtype='object')