In [1]:
from train_utils import generate_model
import json
import torch
from data_util import MimicFullDataset, my_collate_fn
from torch.utils.data import DataLoader

from main import eval_func

In [2]:
# Device string used below when placing the model and running evaluation.
device = "cuda:0"
# Path to the directory that contains the checkpoints of the run under evaluation.
# NOTE(review): this is a machine-specific absolute path; adjust it when running elsewhere.
path_dir = "/home/jovyan/isviridov/ue/ICD-MSMN/output/mimic4_word-0.2_lstm-1-512-0.1_MultiLabelMultiHeadLAATV2-512-xav-0.2-8_max-est1-8-42_ce-0.0-1.0_bsz16-AdamW-20-4000-warm0.0-wd0.01-0.0005-rdrop5.0"

In [3]:
# Load the model that scored best on validation (stored as `best_epoch.pth`).
# `map_location=device` remaps the saved tensors directly onto `device`, so
# loading does not require the GPU index the checkpoint was saved from.
# NOTE(review): the checkpoint is a fully pickled model object, so only load
# files from a trusted source. On recent PyTorch releases, where
# `weights_only` defaults to True, such a checkpoint additionally needs
# `weights_only=False` — confirm against the installed version.
model = torch.load(f"{path_dir}/best_epoch.pth", map_location=device).to(device)
# Batch size used by the evaluation dataloaders.
batch_size = 8

In [4]:
# Build the datasets and dataloaders. The validation split is loaded in order
# to search for the threshold, since the threshold is not saved.
word_embedding_path = "./embedding/processed_full.model"


def _build_split(split_name):
    """Create the dataset for `split_name` and a dataloader that iterates it.

    Returns a `(dataset, dataloader)` pair. The dataloader keeps the dataset
    order (`shuffle=False`) and uses the notebook-level `batch_size`.
    """
    # NOTE(review): 4000 is presumably the maximum/truncation length of a
    # document — confirm against MimicFullDataset's signature.
    dataset = MimicFullDataset(split_name, word_embedding_path, 4000)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=my_collate_fn,
        shuffle=False,
        num_workers=1,
    )
    return dataset, loader


dev_dataset, dev_dataloader = _build_split("dev")
test_dataset, test_dataloader = _build_split("test")

In [5]:
# Validation pass (run mainly to compute the threshold).
# eval_func returns the metrics, a triple of outputs and the threshold.
dev_eval = eval_func(model, dev_dataloader, device, tqdm_bar=True)
dev_metric, (dev_yhat, dev_y, dev_yhat_raw), threshold = dev_eval

100%|██████████| 476/476 [00:24<00:00, 19.06it/s]


In [6]:
# Display the metrics dictionary returned by the validation pass.
dev_metric

{'acc_macro': 0.5530060210235864,
 'prec_macro': 0.6965041509893598,
 'rec_macro': 0.6970153235909379,
 'f1_macro': 0.696759643535654,
 'acc_micro': 0.5854394860084353,
 'prec_micro': 0.740853210559824,
 'rec_micro': 0.7362016937603392,
 'f1_micro': 0.738520127920315,
 'rec_at_5': 0.6943198256700227,
 'prec_at_5': 0.6509329829172142,
 'f1_at_5': 0.6719267521122437,
 'rec_at_8': 0.8324006724039577,
 'prec_at_8': 0.5229960578186597,
 'f1_at_8': 0.6423835331539697,
 'rec_at_15': 0.9391307312923607,
 'prec_at_15': 0.33184406482698214,
 'f1_at_15': 0.49040305162232384,
 'auc_macro': 0.9348519020183065,
 'auc_micro': 0.9559267933019663}

In [7]:
# Test pass, reusing the threshold obtained from the validation pass.
# Per the original note: test_yhat_raw - logits, test_yhat - the values after
# applying the threshold, test_y - ground truth.
test_metric, (test_yhat, test_y, test_yhat_raw), _ = eval_func(
    model, test_dataloader, device, tqdm_bar=True, threshold=threshold
)
# Backward-compatible aliases: this cell originally stored the test outputs
# under the dev_* names, and a later cell still reads them that way.
# NOTE(review): this overwrites the validation outputs; prefer the test_* names.
dev_yhat, dev_y, dev_yhat_raw = test_yhat, test_y, test_yhat_raw

100%|██████████| 921/921 [00:40<00:00, 22.72it/s]


In [8]:
# Display the metrics dictionary returned by the test pass.
test_metric

{'acc_macro': 0.5557136680822725,
 'prec_macro': 0.6986344743225809,
 'rec_macro': 0.6993245556906181,
 'f1_macro': 0.6989793446824755,
 'acc_micro': 0.5884286653517411,
 'prec_micro': 0.7423920369122278,
 'rec_micro': 0.7394020756957692,
 'f1_micro': 0.7408940397350974,
 'rec_at_5': 0.7098881717757207,
 'prec_at_5': 0.6513843648208468,
 'f1_at_5': 0.6793791007083276,
 'rec_at_8': 0.839703137158081,
 'prec_at_8': 0.5148615635179153,
 'f1_at_8': 0.6383318122380617,
 'rec_at_15': 0.9425761879503459,
 'prec_at_15': 0.3242942453854506,
 'f1_at_15': 0.4825624239801493,
 'auc_macro': 0.9371672179444519,
 'auc_micro': 0.9577894110746192}

In [14]:
# Show the first two rows of the raw outputs and of the thresholded outputs.
# NOTE(review): the dev_* names are last assigned by the test-evaluation cell,
# so if the cells ran in the order shown they hold test-set outputs here.
# The execution count of this cell is also non-sequential — re-check with
# Restart & Run All.
dev_yhat_raw[:2], dev_yhat[:2]

(array([[ -2.9810805 ,  -6.0504203 ,  -1.4114794 ,  -3.8998897 ,
          -2.5264597 ,  -5.7022276 ,  -2.4338849 ,  -2.4488735 ,
          -0.7510817 ,  -1.4569505 ,  -2.8358486 ,  -2.1841345 ,
          -2.1629922 ,  -1.3296151 ,  -3.4232244 ,  -3.2867386 ,
          -2.1514978 ,   0.57269865,  -3.9672565 ,  -5.0608745 ,
          -6.9945965 ,  -3.8867693 ,  -4.2377896 ,  -5.3218694 ,
          -2.4139996 ,  -2.8933094 ,  -3.8860736 ,  -1.9988656 ,
          -2.6469536 ,  -1.9725678 ,  -2.6347969 ,  -3.8692875 ,
          -2.7529347 ,  -3.7117763 ,  -3.2451048 ,  -4.0340204 ,
         -11.435334  ,  -4.2765613 ,  -4.187875  ,  -2.872563  ,
          -4.4055853 ,   1.9272723 ,  -4.3311586 ,  -4.244507  ,
          -5.044257  ,  -2.7505763 ,  -2.8848522 ,  -1.2859597 ,
          -5.5899663 ,  -4.55542   ],
        [ -6.3940372 ,  -8.750653  ,  -3.218421  ,  -5.0218964 ,
          -3.7120266 ,  -7.5002027 ,  -3.980398  ,  -2.0434327 ,
           0.902923  ,  -5.5081425 ,  -6.0383754 ,  