In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.model_selection import ShuffleSplit, GridSearchCV,train_test_split
from sklearn.model_selection import cross_validate

from sksurv.datasets import load_veterans_lung_cancer
from sksurv.column import encode_categorical
from sksurv.metrics import concordance_index_censored
from sksurv.svm import FastSurvivalSVM,FastKernelSurvivalSVM
from sksurv.kernels import clinical_kernel
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
def score_survival_model(model, X, y):
    """Scorer callable: concordance index of ``model`` on ``(X, y)``.

    Parameters
    ----------
    model : fitted estimator exposing ``predict(X)``.
    X : feature matrix forwarded unchanged to ``model.predict``.
    y : record-like object indexed by the fields ``'Status'`` (event
        indicator) and ``'Survival_in_days'`` (observed time).

    Returns
    -------
    The first element of the tuple returned by
    ``concordance_index_censored`` (the concordance index itself).
    """
    risk_scores = model.predict(X)
    # Only the leading entry of the returned tuple is needed; the rest is discarded.
    c_index, *_ = concordance_index_censored(
        y['Status'], y['Survival_in_days'], risk_scores
    )
    return c_index

In [63]:
# Positional (0-based) column selections used with ``DataFrame.iloc`` to build
# the different feature sets. Layout implied by the variable names in this cell
# (NOTE(review): confirm against the header of the input table):
#   0        -> "eln"
#   1..84    -> "gen"
#   85..153  -> "cyto"
#   154..160 -> "clin"
#   161..162 -> "demo"  (162 is the column removed by the "*_without_age" sets)
#   163..179 -> "comp"

# --- Feature sets that include the "comp" columns ---
all_features = [*range(180)]
clin_demo_comp = [*range(154, 180)]
clin_demo_cyto_gen_comp = [*range(1, 180)]
comp = [*range(163, 180)]
cyto_comp = [*range(85, 154), *range(163, 180)]
cyto_gen_comp = [*range(1, 154), *range(163, 180)]
eln_clin_demo_comp = [0, *range(154, 180)]
eln_cyto_comp = [0, *range(85, 154), *range(163, 180)]
eln_cyto_gen_comp = [*range(154), *range(163, 180)]
eln_gen_comp = [*range(85), *range(163, 180)]
gen_comp = [*range(1, 85), *range(163, 180)]
clin_comp = [*range(154, 161), *range(163, 180)]
clin_cyto_comp = [*range(85, 161), *range(163, 180)]
clin_gen_comp = [*range(1, 85), *range(154, 161), *range(163, 180)]
eln_clin_comp = [0, *range(154, 161), *range(163, 180)]

# --- Same idea, with column 162 (age) left out ---
all_features_without_age = [*range(162), *range(163, 180)]
clin_demo_comp_without_age = [*range(154, 162), *range(163, 180)]
clin_demo_cyto_gen_comp_without_age = [*range(1, 162), *range(163, 180)]
eln_clin_demo_comp_without_age = [0, *range(154, 162), *range(163, 180)]

# --- Feature sets without the "comp" columns ---
eln_clin_gen = [*range(85), *range(154, 161)]
eln_demo_gen = [*range(85), 161, 162]
eln_clin_demo_cyto_gen = [*range(163)]
eln_clin_demo_cyto = [0, *range(85, 163)]

eln_clin_demo_gen = [*range(85), *range(154, 163)]
eln_clin_demo = [0, *range(154, 163)]
eln_clin = [0, *range(154, 161)]
eln_cyto_gen = [*range(154)]
clin_demo_cyto_gen = [*range(1, 163)]
clin_demo_cyto = [*range(85, 163)]
clin_demo_gen = [*range(1, 85), *range(154, 163)]
clin_demo = [*range(154, 163)]
cyto_gen = [*range(1, 154)]
cyto = [*range(85, 154)]
gen = [*range(1, 85)]
clin_gen = [*range(1, 85), *range(154, 161)]
clin_cyto = [*range(85, 161)]
demo_gen = [*range(1, 85), 161, 162]
demo_cyto = [*range(85, 154), 161, 162]

# --- Without "comp" and without column 162 (age), plus a few small combos ---
eln_demo_gen_without_age = [*range(85), 161]
eln_clin_demo_cyto_gen_without_age = [*range(162)]
eln_clin_demo_cyto_without_age = [0, *range(85, 162)]
eln_clin_demo_gen_without_age = [*range(85), *range(154, 162)]
eln_clin_demo_without_age = [0, *range(154, 162)]
clin_demo_cyto_gen_without_age = [*range(1, 162)]
clin_demo_cyto_without_age = [*range(85, 162)]
clin_demo_gen_without_age = [*range(1, 85), *range(154, 162)]
clin_demo_without_age = [*range(154, 162)]
demo_gen_without_age = [*range(1, 85), 161]
demo_cyto_without_age = [*range(85, 154), 161]
gen_age = [*range(1, 85), 162]
eln_comp = [0, *range(163, 180)]
eln_age = [0, 162]
eln_gen = [0, *range(1, 85)]
eln_cyto = [0, *range(85, 154)]

# Name -> column-index list for every feature set defined above (54 entries,
# in definition order). Values are the very same list objects as the variables.
dict_features_type_final_comp = {
    "all_features": all_features,
    "clin_demo_comp": clin_demo_comp,
    "clin_demo_cyto_gen_comp": clin_demo_cyto_gen_comp,
    "comp": comp,
    "cyto_comp": cyto_comp,
    "cyto_gen_comp": cyto_gen_comp,
    "eln_clin_demo_comp": eln_clin_demo_comp,
    "eln_cyto_comp": eln_cyto_comp,
    "eln_cyto_gen_comp": eln_cyto_gen_comp,
    "eln_gen_comp": eln_gen_comp,
    "gen_comp": gen_comp,
    "clin_comp": clin_comp,
    "clin_cyto_comp": clin_cyto_comp,
    "clin_gen_comp": clin_gen_comp,
    "eln_clin_comp": eln_clin_comp,
    "all_features_without_age": all_features_without_age,
    "clin_demo_comp_without_age": clin_demo_comp_without_age,
    "clin_demo_cyto_gen_comp_without_age": clin_demo_cyto_gen_comp_without_age,
    "eln_clin_demo_comp_without_age": eln_clin_demo_comp_without_age,
    "eln_clin_gen": eln_clin_gen,
    "eln_demo_gen": eln_demo_gen,
    "eln_clin_demo_cyto_gen": eln_clin_demo_cyto_gen,
    "eln_clin_demo_cyto": eln_clin_demo_cyto,
    "eln_clin_demo_gen": eln_clin_demo_gen,
    "eln_clin_demo": eln_clin_demo,
    "eln_clin": eln_clin,
    "eln_cyto_gen": eln_cyto_gen,
    "clin_demo_cyto_gen": clin_demo_cyto_gen,
    "clin_demo_cyto": clin_demo_cyto,
    "clin_demo_gen": clin_demo_gen,
    "clin_demo": clin_demo,
    "cyto_gen": cyto_gen,
    "cyto": cyto,
    "gen": gen,
    "clin_gen": clin_gen,
    "clin_cyto": clin_cyto,
    "demo_gen": demo_gen,
    "demo_cyto": demo_cyto,
    "eln_demo_gen_without_age": eln_demo_gen_without_age,
    "eln_clin_demo_cyto_gen_without_age": eln_clin_demo_cyto_gen_without_age,
    "eln_clin_demo_cyto_without_age": eln_clin_demo_cyto_without_age,
    "eln_clin_demo_gen_without_age": eln_clin_demo_gen_without_age,
    "eln_clin_demo_without_age": eln_clin_demo_without_age,
    "clin_demo_cyto_gen_without_age": clin_demo_cyto_gen_without_age,
    "clin_demo_cyto_without_age": clin_demo_cyto_without_age,
    "clin_demo_gen_without_age": clin_demo_gen_without_age,
    "clin_demo_without_age": clin_demo_without_age,
    "demo_gen_without_age": demo_gen_without_age,
    "demo_cyto_without_age": demo_cyto_without_age,
    "gen_age": gen_age,
    "eln_comp": eln_comp,
    "eln_age": eln_age,
    "eln_gen": eln_gen,
    "eln_cyto": eln_cyto,
}

# Subset actually iterated by the experiment: the first 17 feature sets above,
# in the same order and sharing the same list objects.
dicts = {
    name: dict_features_type_final_comp[name]
    for name in (
        "all_features",
        "clin_demo_comp",
        "clin_demo_cyto_gen_comp",
        "comp",
        "cyto_comp",
        "cyto_gen_comp",
        "eln_clin_demo_comp",
        "eln_cyto_comp",
        "eln_cyto_gen_comp",
        "eln_gen_comp",
        "gen_comp",
        "clin_comp",
        "clin_cyto_comp",
        "clin_gen_comp",
        "eln_clin_comp",
        "all_features_without_age",
        "clin_demo_comp_without_age",
    )
}

In [64]:
# Permutation-importance experiment for a linear survival SVM.
#
# For every feature set in `dicts` and every seed j in 0..24:
#   1. tune `alpha` of a FastSurvivalSVM with a 5-split ShuffleSplit grid search
#      on an 80% training split,
#   2. compute the reference concordance index on the held-out 20%,
#   3. for 4 repeats and for each feature, shuffle that single column of the
#      test set and record the resulting concordance index.
# One row per (feature set, seed, repeat, feature) ends up in `df`.
df_final = pd.read_table("prognosis_comp_final.tsv")
ci=[]

# Structured survival target built from the `os_status` / `os` columns; it does
# not depend on the feature set or the seed, so it is built once.
y = np.array(list(zip(df_final.os_status, df_final.os)),dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

# Rows are accumulated in a plain list and converted to a DataFrame once at the
# end: DataFrame.append inside the loop is quadratic and no longer exists in
# pandas >= 2.0.
records = []
for key,item in dicts.items():
    # Column selection depends only on the feature set, not on the seed.
    x = df_final.iloc[:,item]
    features = x.columns
    for j in range(25):
        estimator = FastSurvivalSVM(max_iter=1000, tol=1e-6, random_state=j)
        param_grid = {'alpha': 10. ** np.array([-6,-5.5,-5,-4.5,-2.5,-1,0]),'optimizer':["avltree"]}
        cv = ShuffleSplit(n_splits=5,random_state=j)
        # NOTE(review): `iid` was removed from GridSearchCV in newer scikit-learn
        # releases; it is kept here so the scoring behaviour on the version this
        # notebook was run with is unchanged. Drop it when upgrading.
        gcv = GridSearchCV(estimator, param_grid, scoring=score_survival_model,
                           n_jobs=50, iid=False, refit=True,
                           cv=cv)
        X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=j)
        gcv = gcv.fit(X_train,y_train)
        ref_ci = concordance_index_censored(y_test['Status'], y_test['Survival_in_days'], gcv.predict(X_test))[0]
        print(ref_ci)
        for i in range(4):
            for feature in features:
                # Work on a fresh copy every time. Assigning `X_test` directly
                # only aliases it, so each shuffle would be written back into
                # X_test and pile up across features and repeats; the score
                # would then reflect many permuted columns, not just `feature`.
                X_test_permuted = X_test.copy()
                X_test_permuted[feature] = np.random.RandomState(seed=i).permutation(X_test[feature].values)
                permuted_ci = concordance_index_censored(y_test['Status'], y_test['Survival_in_days'], gcv.predict(X_test_permuted))[0]
                records.append({'feature': feature, 'ref_CI': ref_ci, 'permuted_CI': permuted_ci, 'algo':'SVM_optimized', 'model': key})

# Explicit column list keeps the expected schema even when no rows were produced.
df = pd.DataFrame(records, columns=["feature","ref_CI","permuted_CI","algo","model"])

0.7000666518551433
0.7002752264217365
0.739146051974013
0.7234100482070501
0.7227515686730188
0.7140024455533891
0.7441734060250688
0.7278137853670696
0.7137123437235695
0.7233658188733064
0.7424926620004516
0.7194971537001897
0.7109888907985813
0.7369951584711287
0.7286878114371589
0.7028952242255394
0.7023152395316654
0.7198844023440636
0.7194383347677792
0.7062489959839358
0.7133946944440014
0.727200292154975
0.717731165056171
0.705044305977393
0.7272472327158152
0.7024629447424382
0.6998911895541972
0.7306190654672664
0.7204952033184387
0.7175392583247568
0.7156383224825672
0.7435804522919309
0.7269880344010968
0.697765612265297
0.7145867997781475


  self.best_estimator_.fit(X, y, **fit_params)


0.7299132342031416
0.7169196710942441
0.7082858080342495
0.7328651641943418
0.7291465633156687
0.7084039067323066
0.6904824841011926
0.7140402986272778
0.7112619935609321
0.7015421686746988
0.7218810317599579
0.7171076657481491
0.7110692120355094
0.6921282524971036
0.7153202791204077
0.7041292411210207
0.6961948347041316
0.7355072463768116
0.7194221560242797
0.7231665615351416
0.7132423411216497
0.7458863834763559
0.7233266857783871
0.7116517246151717
0.7262974407733143
0.7410895719769055
0.7157653383934219
0.715410696707864
0.73772216120899
0.7266471565293048
0.7036075538600352
0.698645685942189
0.7153728827165449
0.7234547830799146
0.7062650602409638
0.7137456332051875
0.7239467481159324
0.7196794720716474
0.7030247048877477
0.725119187053564
0.6222902846986385
0.6280842960924249
0.6501983383308346
0.6227037588686558
0.6128946582118787
0.637190587924254
0.6391794179170853
0.6134083260625701
0.6263436313029571
0.6136122335789557
0.6570896364867916
0.6526644528779254
0.6236933797909407

  self.best_estimator_.fit(X, y, **fit_params)


0.6331055071097494


  self.best_estimator_.fit(X, y, **fit_params)


0.5929329617684816
0.5990695986406902
0.6152124924619926
0.6290283867251256
0.6541807221389305
0.6159371546629511
0.6096328143155938
0.6385703427079547
0.6382817518488627
0.6157843076156051
0.629003982838426
0.6183979082481579
0.6368254685030481
0.6566492726122707
0.6222949641411853
0.6374035174557998
0.6166258008384086
0.629153277507796
0.606922395161165
0.640700008027615
0.614134072869848
0.6347871485943775
0.6181945795913159
0.6309219481424919
0.6245895199937151
0.5929094780348811
0.6136689816724129
0.6443092011299076
0.6450139213364483
0.6788090329835083
0.6569371706785823
0.6572324956010757
0.6507981096533263
0.6622799071039152
0.642566683285554
0.6511479493764321
0.6547183265985262
0.6817566041995936
0.6758776091081594
0.6622552772612927
0.6778604464106174
0.6657517994146959
0.6314723062067655
0.6287220313155971
0.6599020630970539
0.631682126805011
0.6611726907630522
0.6470034615323262
0.6669765280037183
0.668355723151858
0.629645865297304
0.639312756726439
0.6955597168883105
0.7

  self.best_estimator_.fit(X, y, **fit_params)


0.7368641744347322
0.7174573055028463
0.7104888986109592
0.7363609645934198
0.7316776081626196
0.7022778718756431
0.6923483588077061
0.7234326081721121
0.7138918109081636
0.7061847389558233
0.7156757963917114
0.7246605358387835
0.7131432162777909
0.699862228762877
0.717728764908692
0.61832291236868
0.6381252600249624


  self.best_estimator_.fit(X, y, **fit_params)


0.6676895927036481
0.6353961466391198
0.6257262375087148
0.6446346541524836
0.6399864938316341
0.6446544310108439
0.6373848590628797
0.6302273987798114
0.6470583491920137
0.6560879190385832
0.6330916704426494
0.64967748921097
0.6469508819109389


  self.best_estimator_.fit(X, y, **fit_params)


0.636371551137353
0.6239251784242688
0.6648229910893474
0.6093924325013548
0.6375100401606426
0.6297197275439073
0.6371966402177882
0.6379055699583628
0.6191408084666687


  self.best_estimator_.fit(X, y, **fit_params)


0.6202428281561887
0.6400799822261719
0.6502624251928185
0.6789417791104447
0.6624385400150546
0.6617891172271837
0.6494100928649328
0.6562927214929257
0.6513850803938677
0.6497870180996171
0.6484034545598606
0.6823613843821565
0.6729601518026566
0.6609427977687848
0.6744187845132948
0.6680851063829787
0.6346698747882799
0.6361699811857634
0.6669743919081641
0.6194255841382168
0.6614618473895583
0.6474102314600647
0.6626025032369444
0.6572158064262708
0.6350706077590256
0.6415232847786998
0.6355571777700194
0.6479021986110667
0.6824009870064968
0.6602684219798524
0.6588841671923243
0.6470388975180937
0.653887964686311
0.6537532718434501
0.6460886681326793
0.6505189763093258
0.6811437602812631
0.6722090449082859
0.6566928641739973
0.6700722362294854
0.6660839990508581
0.6339575451537841
0.6357734828106293
0.6631131090952878
0.6266296898409359
0.6608192771084337
0.6453125747738838
0.6611500282195146
0.6572000942729201
0.6341156025926042
0.638042527920289
0.6408893261814835
0.640157455115

  self.best_estimator_.fit(X, y, **fit_params)


0.6716329818628467
0.6855682082268185
0.6743499096551182
0.6505307323793719
0.6479651594384599
0.6483479861618053
0.6652798668672193
0.6991191904047976
0.6730408879065969
0.6788619235749145
0.6721636537889554
0.6825062177787294
0.6977439860401347
0.6675944578572637
0.6784248474764282
0.6947392187852789
0.6925838077166351
0.6745050858580335
0.6968243901684481
0.688839674127976
0.6586991277919364
0.6581095579431842
0.6893152444408767
0.6723566351088585
0.6721606425702811
0.677008725613744
0.6916603034427807
0.6853484170005499
0.6561511726210978
0.6664907042346458
0.6487129844161615
0.6554229206003777
0.6883433283358321
0.6627428370089207
0.654559941569005
0.6698337684655805
0.677400227298931
0.683550417549545
0.6597210475326393
0.6628951747088186
0.6817243492565236
0.6846299810246679
0.6577240980609678
0.6784327677148911
0.6837617654037808
0.6678960948506483
0.6424984062320215
0.6726499157100425
0.661709859424309
0.6645301204819277
0.6723667629089632
0.6868131868131868
0.6631314321627779