In [151]:
from IPython.core.interactiveshell import InteractiveShell
# Show the value of every top-level expression in a cell, not only the last one.
# Several cells below rely on this (e.g. the three consecutive len(...) calls).
InteractiveShell.ast_node_interactivity = "all"

In [152]:
import os
from rnn_utils import DiagnosesDataset, split_dataset, MYCOLLATE, IcareDataset, ICareCOLLATE
from rnn_utils import RNN, train_one_epoch, train_one_epochV2, eval_model, compute_loss, get_prediction_thresholds, outs2df,compute_metrics

import torch
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.model_selection import ParameterGrid, ParameterSampler

import numpy as np
import pandas as pd
import json

from tqdm.notebook import tqdm

from config import Settings; settings = Settings()

from ICDMappings import ICDMappings
icdmap = ICDMappings()

import wandb

idx = pd.IndexSlice

# Reproducibility

In [153]:
# Reproducibility
# Seed NumPy and PyTorch (CPU and current CUDA device) from the project-wide setting.
seed = settings.random_seed

np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Ask cuDNN for deterministic kernels and disable its autotuner so repeated runs
# pick the same algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

<torch._C.Generator at 0x7eff33ab5d30>

# Create dataset

In [154]:
# Code grouping handed to IcareDataset, and the mini-batch size used by every DataLoader below.
grouping = 'ccs'
batch_size=64

In [155]:
# All codes known to the ICD9 -> CCS mapping; passed to every dataset so they share one universe.
ccs_universe = list(icdmap.icd9_3toccs.data.keys())

# The location can be overridden with ICARE_DATASET_FOLDER; the default is the
# original machine-specific path, so existing setups keep working unchanged.
dataset_folder = os.environ.get(
    'ICARE_DATASET_FOLDER',
    '/home/debian/Simao/master-thesis/data/model_ready_dataset/icare2021_diag_A301'
)

# Full dataset; it is also what the collate function is built from below.
dataset = IcareDataset(os.path.join(dataset_folder,'dataset.json'),
                       ccs_universe,
                       grouping
                      )

# The subsets are constructed with the same (path, universe, grouping) arguments as
# the full dataset. Previously `grouping` was passed in the universe position.
train_dataset = IcareDataset(os.path.join(dataset_folder,'train_subset.json'),ccs_universe,grouping)
val_dataset = IcareDataset(os.path.join(dataset_folder,'val_subset.json'),ccs_universe,grouping)
test_dataset = IcareDataset(os.path.join(dataset_folder,'test_subset.json'),ccs_universe,grouping)


# Displayed as cell output (ast_node_interactivity == "all").
len(train_dataset)
len(val_dataset)
len(test_dataset)


# Only the training loader shuffles; all loaders collate with the full dataset.
train_dataloader = DataLoader(train_dataset,batch_size=batch_size,collate_fn=ICareCOLLATE(dataset),shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=batch_size,collate_fn=ICareCOLLATE(dataset)) #batch_size here is arbitrary and doesn't affect total validation speed
test_dataloader = DataLoader(test_dataset,batch_size=batch_size,collate_fn=ICareCOLLATE(dataset))

183967

39422

39422

In [164]:
from sys import getsizeof

# Reference point for the size checks below: a 1000x1000 integer ndarray.
array = np.random.randint(1,10,size=(1000,1000))
getsizeof(array)

8000128

In [166]:
# NOTE(review): sys.getsizeof is shallow - for a container it does not include
# the objects it references, so this understates the real footprint.
getsizeof(dataset.grouping_data)

232

In [169]:
# NOTE(review): shallow size only (see sys.getsizeof docs); nested records are not counted.
getsizeof(dataset.data)

10485856

In [158]:
# Baseline: size of a small Python int object.
a = 3
getsizeof(a)

28

In [160]:
# Number of items in the full dataset.
len(dataset)

262811

# Train

In [6]:
# Read the model width off one training batch: axis 2 of the target tensor.
# (assumes that axis indexes diagnosis labels - TODO confirm against ICareCOLLATE)
input_size = next(iter(train_dataloader))['target_sequences']['sequence'].shape[2]
n_labels = input_size

# Element-wise loss (no reduction); callers decide how to aggregate it.
criterion = torch.nn.BCEWithLogitsLoss(reduction='none')

In [7]:
# Search space: each key lists the candidate values for one model/optimizer setting.
hyperparameters = dict(
    hidden_size=[100, 150],
    num_layers=[1, 2],
    lr=[0.01, 0.02, 0.03],
    model=['rnn', 'gru', 'lstm'],
)

# Settings shared by every run of the sweep.
meta_parameters = dict(epochs=15)

# Exhaustive grid over the search space; report how many configurations it holds.
params = ParameterGrid(hyperparameters)
print('params:', len(params))

#random_params = ParameterSampler(params.param_grid,n_iter=len(params)-1,random_state=231)
#next(iter(random_params))

params: 36


# Test

In [8]:
from torch import nn
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence
import torch.nn.functional as F

In [9]:
import pandas
import numpy

In [10]:
# Smoke test: instantiate a model, loss and optimizer for the first grid configuration.
config = param_set = next(iter(params))

model = RNN(
    input_size=input_size,
    hidden_size=config['hidden_size'],
    num_layers=config['num_layers'],
    n_labels=n_labels,
    model=config['model'],
)

criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

# Scratch calls kept for reference while debugging:
#batch = next(iter(val_dataloader))

#outs = model(batch['train_sequences']['sequence']).detach().numpy()
#loss = train_one_epoch(model,val_dataloader,1,criterion,optimizer)

#outs,golden = outs2df_dev(model,val_dataloader,dataset,return_golden=True)

#train_metrics = eval_model(model,train_dataloader,dataset,metrics=['roc','f1'])[1].filter(regex='_adm')
#    val_metrics = eval_model(model,val_dataloader,dataset,metrics=['roc','f1'])[1].filter(regex='_adm')

In [11]:
# Hyperparameter sweep (currently in debug mode: stops after the first grid point).
# For each configuration: start a wandb run, build the model, compute per-admission
# metrics, log them, then save the configuration and the model weights.
#
# The grid point itself drives the run, and no index variable is bound, so the
# module-level `idx = pd.IndexSlice` is no longer overwritten by the loop.
for param_set in tqdm(params):
    # Run configuration = grid point + settings shared by all runs.
    config = {**param_set,
              **meta_parameters}

    wandb.init(
        project="icare",
        config=config
    )

    try:
        model = RNN(input_size=input_size,
                    hidden_size=config['hidden_size'],
                    num_layers=config['num_layers'],
                    n_labels=n_labels,
                    model=config['model'])

        optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

        # Training is switched off while the metric code is being debugged, so no
        # `loss` exists for this run and none is logged (logging it used to raise
        # NameError, or silently reuse a stale value from an earlier cell).
        #loss = eval_model(model,val_dataloader,dataset,None,None,only_loss=True)
        #wandb.log({'epoch':0,'loss':loss})

        #for epoch in tqdm(range(1,config['epochs']+1)):
        #    loss = train_one_epochV2(model,val_dataloader,epoch,criterion,optimizer);
        #    wandb.log({'epoch':epoch,'loss':loss})

        # NOTE(review): these "train" outputs are computed on val_dataloader -
        # presumably to keep debugging fast; switch to train_dataloader for real runs.
        # Later cells read outs_train / golden_train, so the names are kept.
        outs_train,golden_train = outs2df(model,val_dataloader,dataset,return_golden=True)
        #thresholds = get_prediction_thresholds(outs_train,(golden_train > 0).astype(int))

        # Keep only the per-admission ('_adm') metrics and prefix them with the split name.
        train_metrics = compute_metrics(outs_train,model_predictions=None,golden=golden_train,metrics=['roc','recall@30','precision@30'])[1].filter(regex='_adm')
        #val_metrics = eval_model(model,test_dataloader,dataset,None,metrics=['roc','recall@30','precision@30'])[1].filter(regex='_adm')
        train_metrics.index = ['train_' + n for n in train_metrics.index]
        #val_metrics.index = ['val_' + n for n in val_metrics.index]

        log = dict()

        log.update(train_metrics.to_dict())
        #log.update(val_metrics.to_dict())

        wandb.log(log)

        # Readable, filesystem-safe name derived from this run's configuration
        # (str(params) was the repr of the whole grid object).
        model_name = '_'.join(f'{key}-{config[key]}' for key in sorted(config))

        # NOTE(review): model_folder is not defined in the cells visible here - confirm
        # it is set (e.g. in the config cell) before running.
        os.makedirs(model_folder, exist_ok=True)

        hypp_save_path = os.path.join(model_folder, model_name+'_hyper_parameters.json')

        # Persist this run's configuration; the ParameterGrid object itself is not
        # JSON-serialisable.
        with open(hypp_save_path, "w") as f:
            json.dump(config, f)

        print('Hyperparameters saved!')

        weights_save_path = os.path.join(model_folder,model_name+"_weights")

        torch.save(model.state_dict(),
                   weights_save_path
                  )
        print('Model saved!')
    finally:
        # Always close the wandb run, including on error or KeyboardInterrupt.
        wandb.finish()

    # Debug: evaluate only the first configuration. Remove to run the full grid.
    break

0it [00:00, ?it/s]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msnovaisg[0m. Use [1m`wandb login --relogin`[0m to force relogin


0it [00:00, ?it/s]

computing roc


  0%|          | 0/270248 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [16]:
# Debug probe: row-wise apply with a constant result, used to see how the
# (pat_id, adm_index) index comes back from DataFrame.apply(axis=1).
outs_train.apply(lambda x: 'hi',axis=1)

pat_id                            adm_index
00070121385D5F499BB0D98F48554EF5  1            hi
                                  2            hi
                                  3            hi
0008F5D602267E1E2F679BA745F38A41  1            hi
                                  2            hi
                                               ..
FFFEF32F9705DCBB22EE8E7CC09E9379  2            hi
FFFFCAF4C44303A609ABA747E6E00CEE  1            hi
                                  2            hi
FFFFCD298CB71236CEB3470483A4C6A1  1            hi
                                  2            hi
Length: 270248, dtype: object

In [19]:
from multiprocessing import Pool
from sklearn.metrics import roc_auc_score

In [None]:
# Displays the whole frame of model outputs (pandas truncates the rendering).
# NOTE(review): consider outs_train.head() to keep the notebook output small.
outs_train

In [9]:
# setup
# Benchmark fixture for the two %%timeit cells below: a 1e8 x 10 frame of zeros.
# NOTE(review): np.zeros defaults to float64, so this allocates roughly 8 GB.

df = pd.DataFrame(np.zeros(shape=(int(1e8),10)))
#df_later = df.copy()
#df_later.loc[int(1e7)] = 1
#df_later = df_later.astype(bool)
#
#df_sooner = df.copy()
#df_sooner.loc[int(1e2)] = 1
#df_sooner = df_sooner.astype(bool)

In [10]:
%%timeit -r 10 -n 1
# Case "sooner": the only True values sit in row 100, near the start of the column.
# NOTE(review): the timed body also copies the frame and casts it to bool, so the
# measurement likely reflects those steps rather than where .any() stops scanning.

df_sooner = df.copy()
df_sooner.loc[int(1e2)] = 1
df_sooner = df_sooner.astype(bool)
df_sooner.iloc[:,0].any()

5.47 s ± 89.3 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [11]:
%%timeit -r 10 -n 1
# Case "later": the only True values sit in row 1e7, much further into the column.
# NOTE(review): the timed body also copies the frame and casts it to bool, so the
# measurement likely reflects those steps rather than where .any() stops scanning.

df_later = df.copy()
df_later.loc[int(1e7)] = 1
df_later = df_later.astype(bool)
df_later.iloc[:,0].any()

5.44 s ± 42.9 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [43]:
# Replace the entry at index 100 with 1: mask() blanks it to NaN, fillna(1) fills it.
# NOTE(review): `array` is a NumPy ndarray where it is defined above, which has no
# .mask/.index - this cell depends on `array` having been rebound to a pandas
# object in a cell that is not part of the notebook any more.
array.mask(array.index == int(1e2)).fillna(1)#.loc[int(1e2)]

0            0.0
1            0.0
2            0.0
3            0.0
4            0.0
            ... 
999999995    0.0
999999996    0.0
999999997    0.0
999999998    0.0
999999999    0.0
Length: 1000000000, dtype: float64

----

In [30]:
def normal_roc_optimized(logits,golden):
    """Compute one ROC AUC per row (admission), vectorised over the whole frame.

    Per-row reference that this replaces:
    ``logits.apply(lambda row: roc_auc_score(golden.loc[row.name], row)
    if any(golden.loc[row.name] == 1) else np.nan, axis=1)``.
    The AUC is obtained from the rank-sum (Mann-Whitney U) identity, which needs
    no Python-level loop over rows and counts tied scores as half a win.

    Parameters
    ----------
    logits : pandas.DataFrame
        Scores, one row per admission and one column per label.
    golden : pandas.DataFrame
        Ground truth with the same row labels (in any order) and the same
        number of columns, matched to ``logits`` by column position. Any value
        greater than zero is treated as "label present".

    Returns
    -------
    pandas.Series
        Named ``'roc_adm'``, indexed like ``logits``. Rows with no positive
        label or no negative label are NaN because the AUC is undefined there.

    Raises
    ------
    AssertionError
        If the two frames do not contain exactly the same (unique) row labels.
    ValueError
        If the two frames have a different number of columns.
    """
    # Same guarantee as the former inner-join check: every row must be present,
    # exactly once, on both sides.
    shared_rows = logits.index.intersection(golden.index)
    assert (len(shared_rows) == logits.shape[0]) and (len(shared_rows) == golden.shape[0]),'oops'

    if logits.shape[1] != golden.shape[1]:
        raise ValueError(
            f'logits has {logits.shape[1]} label columns but golden has {golden.shape[1]}'
        )

    # Align golden rows to the order of logits, then work positionally on columns.
    positives = golden.reindex(logits.index).to_numpy() > 0
    # Average ranks within each row so ties contribute 0.5, as in the trapezoidal ROC.
    ranks = logits.rank(axis=1).to_numpy()

    n_pos = positives.sum(axis=1)
    n_neg = positives.shape[1] - n_pos

    # U statistic of the positive class: rank sum minus its minimum possible value.
    rank_sum = np.where(positives, ranks, 0.0).sum(axis=1)
    u_stat = rank_sum - n_pos * (n_pos + 1) / 2.0
    n_pairs = (n_pos * n_neg).astype(float)

    # AUC = U / (n_pos * n_neg); left as NaN where a class is missing.
    auc = np.divide(u_stat, n_pairs, out=np.full(len(u_stat), np.nan), where=n_pairs > 0)

    return pd.Series(auc, index=logits.index, name='roc_adm')

In [31]:
# NOTE(review): `normal_roc` is not defined in any cell visible in this notebook
# (only `normal_roc_optimized` is) - this call depends on leftover kernel state.
normal_roc(outs_train,golden_train)

270248
270248


Unnamed: 0_level_0,Unnamed: 1_level_0,logits,logits,logits,logits,logits,logits,logits,logits,logits,logits,...,golden,golden,golden,golden,golden,golden,golden,golden,golden,golden
Unnamed: 0_level_1,Unnamed: 1_level_1,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_273,diag_274,diag_275,diag_276,diag_277,diag_278,diag_279,diag_280,diag_281,diag_282
pat_id,adm_index,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
00070121385D5F499BB0D98F48554EF5,1,0.521170,0.487607,0.482274,0.475589,0.552587,0.507825,0.506065,0.470086,0.530118,0.447996,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
00070121385D5F499BB0D98F48554EF5,2,0.509584,0.494575,0.494839,0.497005,0.532272,0.519580,0.491184,0.466202,0.519124,0.462451,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
00070121385D5F499BB0D98F48554EF5,3,0.513210,0.491401,0.484080,0.488004,0.558111,0.507040,0.501613,0.471400,0.531054,0.443273,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0008F5D602267E1E2F679BA745F38A41,1,0.517556,0.494975,0.507667,0.483336,0.539981,0.499710,0.504512,0.480275,0.520952,0.457920,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0008F5D602267E1E2F679BA745F38A41,2,0.515604,0.504581,0.508042,0.504277,0.531513,0.508115,0.496231,0.483177,0.515269,0.446479,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
FFFEF32F9705DCBB22EE8E7CC09E9379,2,0.514425,0.498155,0.501106,0.511381,0.542690,0.518909,0.527698,0.480515,0.523945,0.445996,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
FFFFCAF4C44303A609ABA747E6E00CEE,1,0.515025,0.496486,0.507078,0.495090,0.541965,0.511540,0.508252,0.473659,0.524212,0.460159,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
FFFFCAF4C44303A609ABA747E6E00CEE,2,0.511931,0.491205,0.494301,0.516626,0.528295,0.506213,0.500183,0.484189,0.529203,0.454736,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
FFFFCD298CB71236CEB3470483A4C6A1,1,0.509109,0.493436,0.502984,0.504799,0.543310,0.513838,0.511638,0.478368,0.537129,0.461304,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [None]:
# Look for the first dataset item in which some admission's target list contains
# the same code more than once. `e` and `target` are left bound to that item so
# the next cell can inspect it.
for e in tqdm(dataset):
    target = e['target']
    # A list has a repeated code exactly when its set is shorter than the list.
    # Empty/None entries are skipped; the generator lets any() stop at the first hit.
    if any(t and len(set(t)) < len(t) for t in target):
        print('found one')
        break

  0%|          | 0/262811 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Inspect the target lists of the item the scan above stopped on.
target

[[],
 [258.0, 136.0, 258.0],
 [136.0, 258.0, 47.0, 44.0, 47.0],
 [258.0, 47.0, 44.0, 47.0],
 [47.0, 44.0, 47.0],
 [47.0, 258.0],
 [258.0],
 [],
 []]

In [None]:
# NOTE(review): this rebinds `tqdm`, which the first cell imported from tqdm.notebook.
from tqdm.auto import tqdm

# Registers the progress_* methods on pandas objects (progress_apply is used below).
tqdm.pandas()

In [14]:
# Toy frame for trying progress_apply: 10000 x 1000 random integers in [0, 1e8).
# NOTE(review): rebinds `df`, which held the zero-filled benchmark frame earlier.
df = pd.DataFrame(np.random.randint(0, int(1e8), (10000, 1000)))

In [15]:
# Progress-bar demo: group by column 0 and square every value of each group.
# NOTE(review): this displays the full result; append .head() to keep the output small.
df.groupby(0).progress_apply(lambda x: x**2)

  0%|          | 0/10000 [00:00<?, ?it/s]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,990,991,992,993,994,995,996,997,998,999
0,8861896394421316,2516512275852201,372733605816100,4172308873801744,5427823796622756,762071635866025,292111749490729,9997616142086400,6321334108485124,1271447821372689,...,251821352454400,2339752576838121,3704089009201,3943910838718096,6966029135233081,1289798731957824,1984177646603881,300951681465681,7736057397619600,5319091648002609
1,9155879298676881,6709341313903761,65115732802624,890036469243441,9545029142782225,747917698471056,1806302955388129,1476837230529600,3410535705643264,4293490110523441,...,4681045028294569,2865903717266496,1946334394598400,3823240248510736,2004459901728025,7188269002995600,2243874077724304,6803153467275361,4944864005354529,19988598081321
2,3498268242745600,669675749072836,8038721756248561,1850330143135296,1072158795994896,28563637494016,1021796322871401,5273630994852900,2872036653400996,6139705831875625,...,509823885318400,916343854253584,3164205151607824,5673962587004100,712247208961600,16249839769881,6765721667556049,5029166424891456,30716934105796,9946115385025369
3,5311420841104,26637211265625,1822132330619364,5192115845970244,5280184349425129,363343527111184,1167735767278201,6401713074618436,1904908371305536,1169945083891600,...,8993614065000889,4729322665345041,4174562196165904,1025072216529001,40104241849681,1644630079213521,4089657497184900,2461892336851600,319438981596736,188127202106809
4,7723203564934081,666917970810025,548547784683409,6585170425777969,147144356836,1168733602068516,7887804921289225,2172900609241924,2184455880614569,9670344739843921,...,106726420399104,1490704223710321,2282394441945744,3051335410343236,1148400881668096,5580157529004304,496791675302976,853373183524,5421407382862849,7177021597028416
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,5766018068358601,99489433364721,1637230981469521,4248013433289796,56538564139369,45277368491716,259256820914116,4111923029021284,9518272623873225,3034494608545009,...,1205378755659225,166931852284521,1725772410201889,7003729359936,166618090411969,7454134896878224,6223394012257081,6113684857283136,513185321116836,1289090035210000
9996,196378490547121,384595713876544,3465764145889081,7138363746590244,2037127152357025,69772274880400,62955909394576,8010592788667225,12669869394576,1180895698690569,...,90766015494400,775019887394041,5609292930315369,1449463478112400,5197174711993489,17178894536049,234544201706896,39165355117729,2048348658446596,4330366199248324
9997,2961342014949796,678613180542025,4530780144544900,6008298034556944,1743870706663684,5581324261644121,329371536771216,1148995490810896,7708482306949369,2208631159397776,...,730305169657281,320143452857809,7351274721180625,3698201265614244,9472348087524,1782668910928896,6118404224812441,6129201498473089,3095992769589316,52856942197284
9998,3977470202227344,107320835618929,830945279844,3086961604026624,4711842366265321,1437973421184400,234800795350521,1743325367909904,1327628297249956,587601694807009,...,953331328132096,669732526814481,9854895040286976,1790783162981316,1704794320614025,360751712706225,988775384150401,1905336031524025,197825434591849,713561075454096
