## Todos

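**TODO:** The model isn't exactly reproducible. The reloaded model gives different results from the in-memory model (e.g. validation recall@30 of 0.7075 before saving vs. 0.7106 after reloading, below).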

As of 24 September, my best uncertainty model according to recall@30 is logged in wandb under the project "final_model_p1_pre_uncertainty_tunning", with the run name "golden-oath-84".

The config is as follows:
- type: gru
- hidden_size: 100
- lr: 0.01
- layers: 1
- batch_size: 64
- dropout: 0.1
- epochs: 15


In this notebook I retrain this model and save it to disk (its weights as a `state_dict` and its hyperparameters as JSON) so it can be reloaded for later experiments.

In [1]:
# wandb run name of the model being reproduced; also used to name the saved files.
model_name = "golden-oath-84"

In [2]:
import json
import os
import random

import numpy as np
import pandas as pd
import torch
import wandb
from sklearn.model_selection import ParameterGrid, ParameterSampler
from torch.utils.data import Dataset, DataLoader, random_split

from mourga_variational.variational_rnn import VariationalRNN
from rnn_utils import DiagnosesDataset, split_dataset, MYCOLLATE
from rnn_utils import train_one_epoch, eval_model

# Reproducibility

In [3]:
# Reproducibility: seed every RNG that can influence data splitting,
# shuffling, weight initialisation and dropout.
SEED = 546

random.seed(SEED)                 # stdlib RNG, in case rnn_utils uses it (unverified)
np.random.seed(SEED)
torch.manual_seed(SEED)           # default generator
torch.cuda.manual_seed_all(SEED)  # every CUDA device, not just the current one

# Force deterministic cuDNN kernels; trades some speed for repeatability.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

<torch._C.Generator at 0x127a56730>

# Create dataset

In [4]:
# Diagnosis-code grouping passed to DiagnosesDataset, and the mini-batch size
# used by every DataLoader in this notebook.
grouping = "ccs"
batch_size = 64

In [5]:
dataset = DiagnosesDataset('data/model_data.json', grouping)

# Both fractions are meant relative to the FULL dataset. The validation split is
# carved out of what remains after removing the test set, so rescale it:
# 0.15 / (1 - 0.15) of the remainder == 0.15 of the whole (assuming split_dataset's
# second argument is the fraction of its input that goes to the second split,
# which the printed lengths below are consistent with).
test_size = 0.15
eval_size = 0.15
eval_size_corrected = eval_size / (1 - test_size)

whole_train_dataset, test_dataset = split_dataset(dataset, test_size)
train_dataset, val_dataset = split_dataset(whole_train_dataset, eval_size_corrected)

len(whole_train_dataset)
len(train_dataset)
len(val_dataset)
len(test_dataset)

# Loaders get their own `*_dataloader` names so `whole_train_dataset` keeps
# referring to the dataset split rather than being rebound to a DataLoader.
whole_train_dataloader = DataLoader(whole_train_dataset, batch_size=batch_size, collate_fn=MYCOLLATE(dataset), shuffle=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=MYCOLLATE(dataset), shuffle=True)
# batch_size for evaluation loaders is arbitrary and doesn't affect total validation speed
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=MYCOLLATE(dataset))
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=MYCOLLATE(dataset))

6375

5250

1125

1124

# Define model

## Hyperparameters

In [7]:
# Size of each code vector, read from axis 2 of one training batch
# (assumes 'sequence' is (batch, seq_len, n_codes) — TODO confirm against MYCOLLATE).
# NOTE(review): creating an iterator over the shuffled train loader presumably
# draws from torch's global RNG, so this line also shifts the training shuffles.
input_size = (
    next(iter(train_dataloader))
    ['target_sequences']['sequence']
    .shape[2]
)
n_labels = input_size  # output layer sized to the same code vocabulary as the input

# Architecture
rnn_type = 'GRU'
hidden_size = 100
num_layers = 1

# Optimisation / regularisation
lr = 0.01
dropout = 0.1

In [8]:
# Training settings. n_labels is defined once, next to the other model
# hyperparameters, so it is not reassigned here.
epochs = 15
# Multi-label objective: an independent sigmoid + binary cross-entropy per output label.
criterion = torch.nn.BCEWithLogitsLoss()

# Train

In [9]:
# One dropout rate feeds all three dropout arguments
# (input / weight / output going by the i/w/o suffixes — TODO confirm in VariationalRNN).
model = VariationalRNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    n_labels=n_labels,
    rnn_type=rnn_type,
    dropouti=dropout,
    dropoutw=dropout,
    dropouto=dropout,
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Epochs are 1-indexed. `epoch` intentionally stays bound after the loop:
# the evaluation cells pass it to eval_model.
for epoch in range(1, epochs + 1):
    loss = train_one_epoch(model, train_dataloader, epoch, criterion, optimizer)

# Eval

Confirm that the retrained model achieves good results.

In [10]:
# NOTE(review): whether eval_model switches the model to eval mode (disabling
# dropout) isn't visible here; if it doesn't, these metrics are stochastic.
train_results = eval_model(model, train_dataloader, dataset, criterion, epoch, 'train_last')
val_results = eval_model(model, val_dataloader, dataset, criterion, epoch, 'validation')

# Headline numbers: losses plus mean recall@k on each patient's last admission.
train_last_adm = train_results['last adm']
val_last_adm = val_results['last adm']
res = {
    'train_loss': train_results['loss'],
    'train_recall@30': train_last_adm['recall30']['mean'],
    'val_loss': val_results['loss'],
    'recall@10': val_last_adm['recall10']['mean'],
    'recall@20': val_last_adm['recall20']['mean'],
    'recall@30': val_last_adm['recall30']['mean'],
}
res

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean, casting='unsafe',
  ret = ret.dtype.type(ret / rcount)


{'train_loss': 0.03398568955489567,
 'train_recall@30': 0.7559982857687191,
 'val_loss': 0.035415327936410904,
 'recall@10': 0.43259607308928094,
 'recall@20': 0.603509778626304,
 'recall@30': 0.7074953808560517}

## Save model weights

In [17]:
# Scratch directory (not referenced elsewhere in the visible notebook).
# exist_ok keeps the cell re-runnable instead of raising FileExistsError.
os.makedirs('test', exist_ok=True)

In [18]:
# Artifacts live under models/<model_name>/.
models_base_path = 'models'
model_path = os.path.join(models_base_path, model_name)

# makedirs also creates models_base_path; exist_ok makes the cell safe to re-run.
os.makedirs(model_path, exist_ok=True)

# Save only the state_dict; the architecture must be rebuilt from the saved
# hyperparameters before these weights can be loaded.
weights_save_path = os.path.join(model_path, "_".join([model_name, 'weights']))
torch.save(model.state_dict(), weights_save_path)

## Save model hyperparameters

In [19]:
# VariationalRNN constructor kwargs, so the model can be rebuilt with
# VariationalRNN(**params). The same `dropout` rate is used for all three
# dropout arguments.
params = {
    'input_size': input_size,
    'hidden_size': hidden_size,
    'num_layers': num_layers,
    'n_labels': n_labels,
    'rnn_type': rnn_type,
    'dropouti': dropout,
    'dropouto': dropout,
    'dropoutw': dropout,
}

hypp_save_path = os.path.join(model_path, "_".join([model_name, 'hypp.json']))
with open(hypp_save_path, "w") as hypp_file:
    json.dump(params, hypp_file)

# Test it out

### Read weights and hyperparameters

In [20]:
#hyperparameters
with open(hypp_save_path,'r') as f:
    params_loaded = json.load(f)
    
# weights
weights = torch.load(weights_save_path)

### Create model and load weights

In [21]:
# Rebuild the architecture from the saved kwargs, then restore the parameters.
new_model = VariationalRNN(**params_loaded)
# Reuse the state_dict already read from disk rather than deserialising the file again.
# NOTE(review): a freshly constructed nn.Module is in training mode, so dropout
# stays active unless eval_model calls .eval(). This is a plausible cause of the
# differing metrics noted in the TODO — confirm.
new_model.load_state_dict(weights)

<All keys matched successfully>

### Evaluate it

In [22]:
# Re-run validation with the reloaded model; its 'last adm' recall figures
# should match `res` from the in-memory model.
new_val_results = eval_model(
    new_model, val_dataloader, dataset, criterion, epoch, 'validation'
)
new_val_results

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean, casting='unsafe',
  ret = ret.dtype.type(ret / rcount)


{'name': 'validation',
 'epoch': 15,
 'loss': 0.035255526608890954,
 '1 adm': {'recall10': {'mean': nan, 'std': nan, 'n': 0},
  'recall20': {'mean': nan, 'std': nan, 'n': 0},
  'recall30': {'mean': nan, 'std': nan, 'n': 0}},
 '2 adm': {'recall10': {'mean': nan, 'std': nan, 'n': 0},
  'recall20': {'mean': nan, 'std': nan, 'n': 0},
  'recall30': {'mean': nan, 'std': nan, 'n': 0}},
 '3 adm': {'recall10': {'mean': nan, 'std': nan, 'n': 0},
  'recall20': {'mean': nan, 'std': nan, 'n': 0},
  'recall30': {'mean': nan, 'std': nan, 'n': 0}},
 '>3 adm': {'recall10': {'mean': nan, 'std': nan, 'n': 0},
  'recall20': {'mean': nan, 'std': nan, 'n': 0},
  'recall30': {'mean': nan, 'std': nan, 'n': 0}},
 'last adm': {'recall10': {'mean': 0.43573282622446574,
   'std': 0.19247979720688113,
   'n': 1125},
  'recall20': {'mean': 0.6066858725790196,
   'std': 0.18259942654418368,
   'n': 1125},
  'recall30': {'mean': 0.7106074628941526,
   'std': 0.16737796949498387,
   'n': 1125}}}