In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import torch
from Constants import *
from Preprocessing import *
from Models import *
import copy
from Utils import *
from SymptomPrediction import *

pd.set_option('display.max_rows', 200)



In [3]:
data = DTDataset(use_smote=False)
data.processed_df.T
data.get_input_state(1).shape
# data.processed_df#.shape, len(data.processed_df.index.unique())

(536, 61)

In [4]:
def load_mdasi_stuff():
    model = torch.load('../resources/symptomImputer.pt')
    mdasi = pd.read_excel('../data/mdasi_updated.xlsx').drop('Unnamed: 0',axis=1)
    return model, mdasi

sp, mdasi = load_mdasi_stuff()

In [7]:
mdasi

Unnamed: 0,ID,mdasidate_12_months_arm_6,mdasidate_18to24_months_arm_6,mdasidate_3to6_months_arm_6,mdasidate_60_months_arm_6,mdasidate_6_wks_after_primar_arm_6,mdasidate_baseline_arm_1,mdasidate_end_of_xrt_arm_3,mdasi_pain_12_months_arm_6,mdasi_pain_18to24_months_arm_6,...,os_flag,os_time (days),ICD 10 Coding,Stage Group Summary,Age at Diagnosis (Y),Patient Status Alive/Deceased,overall_survival(m),category_1,category_2,category_3
0,STIEFEL_1445,406.0,714.0,,1868.0,98.0,-36.0,45.0,6.0,3.0,...,0,1868,C10.9,Stage IVA,52.50,Alive,89.0,,,
1,STIEFEL_935,374.0,,,,95.0,-10.0,52.0,1.0,,...,0,374,C09.9,,67.72,Alive,25.0,,,
2,STIEFEL_1123,410.0,631.0,174.0,,93.0,-30.0,43.0,0.0,0.0,...,0,631,C10.9,Stage I,77.25,Alive,42.0,,,
3,STIEFEL_1223,,,,,,-98.0,,,,...,0,39,,,,,,,,
4,STIEFEL_1446,,,212.0,,86.0,-40.0,42.0,,,...,0,212,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1187,STIEFEL_1708,,,154.0,,,-11.0,,,,...,0,154,C09.8,Stage II,74.74,Alive,10.0,,,
1188,STIEFEL_1709,,,,,,-73.0,39.0,,,...,0,41,C09.0,,65.95,Alive,11.0,,,
1189,STIEFEL_1710,,,141.0,,86.0,-10.0,45.0,,,...,0,141,C10.9,Stage II,66.98,Alive,9.0,,,
1190,STIEFEL_1711,,,,,,-54.0,42.0,,,...,0,46,C01,Stage III,67.11,Alive,11.0,,,


In [11]:
def nan_mse_loss(ypred, y):
    #ignores loss in the autoencoder for missing values (-1 here)
    y = torch.flatten(y)
    ypred = torch.flatten(ypred)
    mask = torch.lt(y,-.1)
    out = (ypred[~mask] - y[~mask])**2
    loss = out.mean()
    return loss

def train_symptom_model(mdasi,lr=.0001,patience=1000,epochs=1000000,symptoms = None,save_file='../resources/symptomImputerTemp.pt'):

    mdasi_input=  process_mdasi_input(mdasi)

    train_ids = mdasi_input.reset_index().id.sample(frac=.66,replace=False).values
    test_ids = mdasi_input.drop(train_ids).reset_index().id.values
    
    xtrain = df_to_torch(mdasi_input.loc[train_ids])
    xtest = df_to_torch(mdasi_input.loc[test_ids])

    sdf,_,symptoms = get_symptom_df(mdasi)

    stopred= symptoms
    if symptoms is None:
        stopred = Const.prediction_symptoms

    ytrain = torch.from_numpy(sdf_symptom_array( sdf.loc[train_ids], stopred))
    ytest = torch.from_numpy(sdf_symptom_array( sdf.loc[test_ids], stopred))

    sp = SymptomPredictor(xtrain.shape[1],ytrain.shape[1],max_rating=10)
    sp.fit_normalizer(xtrain)
    optimizer = torch.optim.Adam(sp.parameters(),lr=lr)

    best_loss = 1000000000000000000000000
    steps_since_improvement=0
    patience=patience
    sp.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        ypred = sp(xtrain)

        loss = nan_mse_loss(ypred,ytrain)
        loss.backward()
        optimizer.step()
        ypred_test = sp(xtest)
        val_loss = nan_mse_loss(ypred_test,ytest)
        print(epoch,loss.item(),val_loss.item())
        steps_since_improvement+=1
        if val_loss.item() < best_loss:
            best_loss = val_loss.item()
            steps_since_improvement=0
            torch.save(sp.state_dict(),save_file)
        if steps_since_improvement > patience:
            break
    sp.load_state_dict(torch.load(save_file))
    sp.eval()
    print('best loss',best_loss,epoch-patience)
    return sp

sp = train_symptom_model(mdasi)
sp

0 8.880846559215314 8.80382490949617
1 8.86329206454978 8.799484444329673
2 8.839032613750003 8.786960527114
3 8.814365990341567 8.784115963554452
4 8.818710866092076 8.764999696401325
5 8.801119630710977 8.768903891792625
6 8.779019870445351 8.742751070960308
7 8.741162005074354 8.740702418869974
8 8.746627676648961 8.73208593603798
9 8.733825815757251 8.719594384768374
10 8.714151742048562 8.705037888986228
11 8.68046658338315 8.694096420684227
12 8.672988816429557 8.679369698202068
13 8.640732952139324 8.676568007561919
14 8.64250673049328 8.675213046813228
15 8.603232141813825 8.637774137977527
16 8.591030239797336 8.647544129776124
17 8.596009967956203 8.647135264848131
18 8.585201933544308 8.632169064596654
19 8.567278625665896 8.632807148133207
20 8.558737630324721 8.624079615617484
21 8.555694395861483 8.626445780521955
22 8.5184956065038 8.604049789420564
23 8.53213804496052 8.596067508618871
24 8.506128602727255 8.60986465731528
25 8.47894929452407 8.608408429572343
26 8.4891

224 7.165025671630236 7.971207038644015
225 7.127388806133201 7.979372422596367
226 7.127726977119538 7.953447860235786
227 7.139747639388836 7.975456192940144
228 7.133932403907365 7.9837743848770515
229 7.077776559917548 7.982040371462535
230 7.106445351123383 7.961525252395544
231 7.064968525418617 7.931596054162196
232 7.075136738888653 7.978003520071235
233 7.051178108345696 7.994902989175238
234 7.112699746118657 7.954837939453794
235 7.048366965119098 7.971523225815744
236 7.060284232071706 7.951846949773249
237 7.060002359075383 7.937777454756375
238 7.046689953266229 7.9734873382562705
239 7.051483388682941 7.961313252286921
240 7.031776835911727 7.949130125181928
241 7.023304726669551 7.950919775768332
242 7.076433459638803 7.961049520125786
243 7.01225952466099 7.958339736150319
244 7.030521336098477 7.929680075223304
245 7.019974068399139 7.922310388847531
246 7.016101710459802 7.945167707874945
247 7.0097373913335606 7.9337464141246965
248 6.993917321894337 7.8992967662892

429 5.852181416366583 7.691248855081872
430 5.860743493194952 7.730590360901968
431 5.812915262139611 7.761397076768242
432 5.831672709060357 7.753016682330119
433 5.814771455048039 7.6889035191001796
434 5.81521280579164 7.687721395930187
435 5.7921971923421465 7.754199313141395
436 5.834975077607032 7.7134249042238725
437 5.803273965427777 7.715175292086178
438 5.799697907528922 7.814677277941495
439 5.800818425641484 7.725488807423776
440 5.7526962625255145 7.717438110973254
441 5.778638143572385 7.757594306532205
442 5.787014877752027 7.690822527234549
443 5.76496257472196 7.727990575195364
444 5.755869685210238 7.718285199010489
445 5.794254031626878 7.735383991995844
446 5.765402615812 7.686730270447302
447 5.757303431927384 7.729957610130547
448 5.707600318411946 7.696545638083589
449 5.714595173377362 7.74367140186035
450 5.71875244379536 7.732350087633984
451 5.756011997876249 7.730287828350975
452 5.694001630760377 7.736276044392673
453 5.685847479214097 7.744636946368561
454

635 4.758221891091522 7.486862384365681
636 4.730761094925398 7.5005604285633884
637 4.748023629130283 7.46428563439566
638 4.717307909889811 7.48328134334635
639 4.69069759223059 7.49286108391135
640 4.709639121029587 7.5143347350341285
641 4.730453731981596 7.525758815585663
642 4.726274315264201 7.496169732609822
643 4.739577309508265 7.451687029262673
644 4.696898959590144 7.5360624817455255
645 4.682870095028927 7.451950163845363
646 4.718700868793459 7.47283543764622
647 4.693554204497099 7.457464155945159
648 4.665416398717682 7.426751671551498
649 4.677913881535094 7.438242560910746
650 4.685713264025322 7.5352501895161215
651 4.669380319765855 7.427358172373925
652 4.675341230828627 7.504855360570425
653 4.657248869330743 7.44671523838085
654 4.705197064113562 7.485038900734593
655 4.600696560155642 7.4308956402895525
656 4.664970891698426 7.536354093733182
657 4.6204005168802915 7.522303139071148
658 4.621164885376633 7.5115993554802705
659 4.665915235113991 7.490876352451695

841 3.9441031485413878 7.268666417607344
842 3.904385014431062 7.302236829475914
843 3.956448393456544 7.3496486587820975
844 3.889948820456471 7.325937720603265
845 3.9174980251537512 7.288643706409386
846 3.9410366939515655 7.216847688432096
847 3.913698852395535 7.287924111683667
848 3.8815053458484443 7.1738116847061235
849 3.8954760031863738 7.159608487328235
850 3.9048135180967725 7.184632685564008
851 3.8624333528294614 7.190225950510776
852 3.8725076716842453 7.157437389672389
853 3.8916553678544195 7.196405402320949
854 3.935592674289214 7.1786037757861845
855 3.904651479781133 7.131591177280138
856 3.8625128603041063 7.1949164874666085
857 3.852004124528483 7.15871167768179
858 3.8903437908478335 7.210111469679551
859 3.867051273705047 7.2039063272068695
860 3.8947429257467596 7.246572048951749
861 3.8443140624146803 7.262778259407655
862 3.8804231623425047 7.20655430326366
863 3.822939758988427 7.240756718815186
864 3.8707951275227614 7.325428339505753
865 3.858419151173066 

1042 3.4892642551870203 7.089618978071387
1043 3.4781724274528436 7.130405122412983
1044 3.478056125386408 7.156677820102598
1045 3.485348544523904 7.197487967791132
1046 3.4154994376973233 7.126685298760985
1047 3.473769941267166 7.188474808052476
1048 3.447523226333479 7.0865711895698125
1049 3.4313631638931827 7.1993392987337
1050 3.479340955517394 7.081748839958366
1051 3.4664377245306786 7.05522557721547
1052 3.52406259405229 7.157795021182732
1053 3.4506064866355572 7.069798472931524
1054 3.461595038162679 7.201855558321517
1055 3.4167484007812443 7.15038706442061
1056 3.4789081173333285 7.138582214663998
1057 3.5076230506746966 7.125966607766559
1058 3.3962340045418635 7.145635846626063
1059 3.440192493964115 7.1505036315333355
1060 3.417311887365005 7.100633023431803
1061 3.4181699345410204 7.188712575636832
1062 3.4205701971990226 7.185644225961345
1063 3.4904076057705735 7.013057362322675
1064 3.422076594879382 7.134120707928739
1065 3.44493035989648 7.128106233133142
1066 3.

1240 3.1837973381085547 6.949357620489268
1241 3.167355434221425 7.078912458194214
1242 3.207565746392926 7.1108767541512705
1243 3.193942286436466 6.965881452295398
1244 3.2147016182272994 7.1016903607062245
1245 3.172438911561419 7.152318441401497
1246 3.184246632318985 7.007766497898253
1247 3.190353788892991 7.042865364370103
1248 3.167118167799665 7.077878172193608
1249 3.180067026000012 7.1203103809090615
1250 3.180310025229727 7.032371062081705
1251 3.1805472182675163 6.975231731513344
1252 3.1971096663908827 7.056078636234149
1253 3.140714707057241 7.119169022482916
1254 3.1723131039987686 7.019089135404232
1255 3.136583814953281 7.0705419722915455
1256 3.1810822657689166 7.006642799576983
1257 3.1789220989926945 7.115776686132056
1258 3.1441214697501025 6.935251175993226
1259 3.169549054779986 7.033098838263414
1260 3.1708417392732877 6.996082066961847
1261 3.1874385995622836 7.045118686792668
1262 3.1528055901600136 7.011609599243953
1263 3.167580714599969 6.933095854540028
1

1454 2.978744302908825 7.194636626256748
1455 3.0167526030253913 6.978440235002777
1456 3.0392216089909314 7.023444217039544
1457 3.012624808703082 6.927826359476179
1458 2.97708030796708 7.0511840932975405
1459 2.9540332991045806 6.97766172307462
1460 3.028735447817113 7.058704403246135
1461 2.9741465357249846 7.056250705281832
1462 2.9685000978532634 6.938082881627602
1463 2.987227863966992 6.940452206373042
1464 2.9836172535325236 6.847097187958776
1465 3.025377810554995 7.0224409668173475
1466 2.9824217744007337 7.133049294222946
1467 3.013005282756377 7.056587604322099
1468 3.0294435756445592 7.04249136810888
1469 2.9948620313085668 7.058147351459461
1470 2.9669616224051896 7.029999251313115
1471 3.0086132696640377 6.966114013354255
1472 2.967183828685921 7.0053998719841974
1473 2.9839252074780496 6.992236191229003
1474 3.0043161184025666 6.977175686815407
1475 2.9805476187157462 7.064811028732951
1476 2.9891252005598754 7.108007321387666
1477 2.995527849421785 7.046993382459565
1

1661 2.8598629559041453 6.972553791480207
1662 2.8400248115192874 7.0910586383720196
1663 2.871182006401931 7.072385266784903
1664 2.8877617003190608 7.044963744704578
1665 2.8894782323060455 7.1068709311207625
1666 2.8896884295790204 7.162041262871679
1667 2.9072463434608 7.131484663599375
1668 2.8895659477467412 7.004338629111565
1669 2.843462272837858 7.0997807916399935
1670 2.815087842011259 7.037987030437424
1671 2.87051614018733 7.019626579206541
1672 2.8254450838504432 7.027055454166792
1673 2.8372825930049146 7.072298113248816
1674 2.8667121412282452 6.968418359671663
1675 2.896900393706603 6.9979236558472
1676 2.8675774716976625 7.047847495955738
1677 2.8217570408591457 6.967746345365867
1678 2.821601617805951 7.10070058469951
1679 2.864844040584469 7.116978910577219
1680 2.8686320599505164 6.985716752518732
1681 2.883715104900773 7.1011624997422045
1682 2.843516703310812 7.045437607308656
1683 2.853180962022825 7.047844963331299
1684 2.8683917038116356 7.0127217846240075
1685

1870 2.7584997166417238 7.061883717253726
1871 2.754264259564022 7.054942436761074
1872 2.717938437029658 7.111180603514591
1873 2.7458095821275545 7.022174759717635
1874 2.75874779610605 6.993596043997647
1875 2.7396463444737362 7.008997336917495
1876 2.75255453460039 7.096545117142757
1877 2.729388385520554 7.085380483148279
1878 2.7411706879439874 7.125327595778785
1879 2.758337604835568 7.119488347036
1880 2.733167034529981 7.045461409137789
1881 2.7259149127260995 7.104714657905552
1882 2.764413093428859 7.114406123522731
1883 2.795646651165473 7.14571825453536
1884 2.7174148037082264 7.1673487670338965
1885 2.7493686247582763 7.200123007204222
1886 2.782011909272806 7.106341334978976
1887 2.7457243995880156 7.084252339179218
1888 2.759098124225018 7.09379801255688
1889 2.7255981373850067 7.002517649288622
1890 2.7277019644461764 7.068393519982332
1891 2.73471377635747 7.127113785697394
1892 2.76186215720529 6.994064309064033
1893 2.742307483391388 7.015124005629245
1894 2.7270312

2081 2.6861828919678787 7.08540374681895
2082 2.678949536488981 7.180055417665243
2083 2.6433788945640497 7.1105730188365195
2084 2.6440951751539576 7.118520574750493
2085 2.680980033287578 7.190351314035094
2086 2.695280019710294 7.186153989869858
2087 2.6998918812523116 7.172425507903183
2088 2.7120598395690125 7.012699662692563
2089 2.62555423879626 7.142844614084299
2090 2.628275074280302 7.08480773922488
2091 2.6176255633105923 7.1180227400063085
2092 2.6583669825693406 6.987272558738486
2093 2.6475955624222482 7.045942668532074
2094 2.6655108760180215 7.189789611110483
2095 2.679373985877112 7.169782518298988
2096 2.6815420761780326 7.182612246399194
2097 2.665307026033315 7.118333470332897
2098 2.628848625479842 7.090311580760831
2099 2.670215383555614 7.100395252609596
2100 2.6454640182176092 7.132644204503992
2101 2.665288238198815 7.084546374316457
2102 2.698932565250088 7.145218464639322
2103 2.636366253947547 7.08331217072154
2104 2.6333426002522002 7.096302222947639
2105 2

2292 2.5785513860060245 7.082408768495155
2293 2.570184059720892 7.092958816020495
2294 2.617813201349796 7.081028139360288
2295 2.6422366704724642 7.135288775095032
2296 2.57783700016247 7.045620188244975
2297 2.5690576765039377 7.095768560902759
2298 2.634783361461891 7.200955330816318
2299 2.6121243085248134 7.172481349469547
2300 2.585018961436486 7.132305860873658
2301 2.6055269143878466 7.067217056480248
2302 2.6178542252866475 7.15895538958624
2303 2.6125698610833368 6.996224930609632
2304 2.615202873986953 7.044314329827816
2305 2.6168628771295612 7.242920651303299
2306 2.6095745564341657 7.124167824810016
2307 2.5984881959093085 7.195319792109763
2308 2.6241152589528123 7.126783644915884
2309 2.600139075100695 7.120079189222273
2310 2.54440994280331 7.104267770952674
2311 2.613659233193915 7.043133191269906
2312 2.5642761273213432 7.208606875354625
2313 2.5977601527671426 6.964937679719777
2314 2.5955765873189005 7.19530392283824
2315 2.55267804001069 7.0738118416468785
2316 2

SymptomPredictor(
  (layers): ModuleList(
    (0): Linear(in_features=27, out_features=1000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1000, out_features=100, bias=True)
    (3): ReLU()
  )
  (batchnorm): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.4, inplace=False)
  (relu): ReLU()
  (sigmoid): Sigmoid()
  (final_layer): Linear(in_features=100, out_features=112, bias=True)
)

In [12]:
torch.save(sp,'../resources/symptomImputer.pt')

In [None]:
def get_test_patient(d,pid=7,clear_transitions=True):
    tp = d.processed_df.loc[pid].to_dict()
    if clear_transitions:
        for v in Const.primary_disease_states + Const.nodal_disease_states:
            tp[v] = 0
            tp[v + ' 2'] = 0
    return tp 

def get_knn_predictions(fdict,
                        model,
                        mdasi,
                        k=8,
                        ttype=torch.FloatTensor,
                        dates=[0,7,12,27],
                        symptom_subset = None,
                       ):
    

    mdasi_df = process_mdasi_input(mdasi)
    sdf,output_dates,output_symptoms = get_symptom_df(mdasi)
    
    xalt = df_to_torch(mdasi_df)
    
    order = mdasi_df.columns
    xin = torch.tensor([fdict[k] for k in order]).type(ttype).view(1,-1) 

    embeddings = model.get_embedding(xin).cpu().detach().numpy()
    base_embeddings = model.get_embedding(xalt).cpu().detach().numpy()
    mdasi_ids= mdasi_df.index
    dists = cdist(embeddings,base_embeddings)[0]
    
    
    order = np.argsort(dists)[:k]
    dists = dists[order]
    ids = mdasi_ids[order]
    symptoms = sdf.loc[ids]
    res = {'ids': ids.tolist(),'dists': dists.tolist()}
    sentries = {}
    
    if symptom_subset is None:
        symptom_subset = Const.prediction_symptoms
    for sym in symptom_subset:
        if sym == 'core':
            continue
        cols = [c for c in symptoms.columns if sym+'_' in c]
        values = symptoms[cols].values
        entry = {'ratings': values.tolist()}
        means = []
        for cidx in range(values.shape[1]):
            subvals = values[:,cidx]
            subvals = [v for v in subvals if v >= 0]
            means.append(np.mean(subvals))
        entry['means'] = means
        sentries[sym] = entry
    return {'ids': ids.tolist(),'dists':dists.tolist(),'symptoms':sentries}

test_patient = get_test_patient(data,7)
get_knn_predictions(test_patient,sp,mdasi,symptom_subset=['choke','drymouth'])

In [None]:
def get_predictions(data,model,input_cols=Const.mdasi_input_cols,output_symptoms=Const.prediction_symptoms,output_dates=[0, 7, 13, 27]):
    xin = df_to_torch(data.processed_df[input_cols])
    ypred = model(xin).cpu().detach().numpy()
    if output_symptoms is None or output_dates is None:
        return ypred
    results = {}
    i = 0
    width = len(output_dates)
    for symptom in output_symptoms:
        values = ypred[:,i:i+width]
        i += width
        s = values.tolist()
        results[symptom] = s
    return pd.DataFrame(results,index=data.processed_df.index)
get_predictions(data,sp)

In [None]:
def overlap(v1,v2):
    return sorted([vv for vv in v1 if vv in v2])

def get_id_mapped_r01(file = '../data/key_map.xlsx',reverse=False):
    df = pd.read_excel('../data/key_map.xlsx').drop('Unnamed: 0',axis=1)
    df['mdasi_id'] = df['STIEFEL'].apply(lambda x: int(x.replace("STIEFEL_",'')))
    df = df[['mdasi_id','ID']]
    if reverse:
        return df.set_index('mdasi_id').to_dict()['ID']
    idmap = df.set_index('ID').to_dict()['mdasi_id']
    r01 = pd.read_csv('../data/distance_csv.csv').drop('Unnamed: 0',axis=1)
    r01['old_id'] = r01.id
    r01['id'] = r01['id'].apply(lambda x: idmap.get(x,x))
    return r01

camprt = pd.read_csv('../data/camprtdists.csv').rename({'ID':'id'},axis=1)
mdasi = pd.read_csv('../data/MDASI_0909201_surgery_updated.csv')
r01 = get_id_mapped_r01()
camprt_overlap = set(camprt.id.values.astype(int)).intersection(data.processed_df.index.values.astype(int))
r01_overlap = set(r01.id.values.astype(int)).intersection(data.processed_df.index.values.astype(int))
mdasi_overlap = set(mdasi.id.values.astype(int)).intersection(data.processed_df.index.values.astype(int))
len(camprt_overlap), len(r01_overlap), len(camprt_overlap.union(r01_overlap)),len(mdasi_overlap), len(mdasi_overlap.intersection(r01_overlap.union(camprt_overlap)))

In [None]:
def cluster_loss(ytrue,ypred,weights=None):
    if weights is None:
        weights = [1 for i in range(ytrue.shape[1])]
    loss = 0
    nloss = torch.nn.BCELoss()
    for i in range(len(weights)):
        iloss = nloss(ypred[:,i],ytrue[:,i])*weights[i]
        loss += iloss
    return loss

epochs = 100
best_val_loss = 10000
best_loss_metrics = {}
patience = 10
save_file = '../data/models/clusterModel.tar'
optimizer = torch.optim.Adam(imputer.parameters(),lr=.01)
lossF = torch.nn.CrossEntropyLoss()
for epoch in range(epochs):
    ypreds = imputer(xtrain)
    losses = 0
    for i,ypred in enumerate(ypreds):
        losses+= lossF(ypred,ytrain[:,i].long())
    losses.backward()
    optimizer.step()
    yval = imputer(xtest)
    val_loss = 0
    for i, yp2 in enumerate(yval):
        val_loss += loffF(yp2, ytest[:,i].long())
    print
    if loss < best_val_loss

In [None]:
mdasi_input