In [1]:
from train_utils import generate_model
import json
import torch
from data_util import MimicFullDataset, my_collate_fn
from torch.utils.data import DataLoader

from main import eval_func

In [2]:
# Directory that holds the checkpoints of the trained model.
model_dir = "/mnt/s3-data/datasets/mimic4/output/model_with_mimic3_embeds/mimic4_word-0.2_lstm-1-512-0.1_MultiLabelMultiHeadLAATV2-512-xav-0.2-8_max-est1-8-42_ce-0.0-1.0_bsz16-AdamW-20-4000-warm0.0-wd0.01-0.0005-rdrop5.0"
# Passed as the second argument to MimicFullDataset below.
# NOTE(review): despite the `_dir` suffix this looks like a path to a single word2vec model file — confirm.
embeds_dir = "/mnt/s3-data/datasets/mimic4/mimic_iv/mimic3_embeds/word2vec_sg0_100.model"
# Device the model is moved to and that is handed to eval_func.
device = "cuda:1"

In [3]:
# Load the checkpoint that scored best on validation.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
# map_location restores the stored tensors directly onto `device` instead of
# the device they were saved from, so loading does not depend on (or allocate
# memory on) the GPU used at training time.
# NOTE(review): PyTorch 2.6+ defaults to weights_only=True, which rejects a
# fully pickled model object like this one; pass weights_only=False there.
model = torch.load(f"{model_dir}/best_epoch.pth", map_location=device).to(device)

# These tensors are moved explicitly — presumably they are plain attributes
# that `.to(device)` on the model does not relocate (TODO confirm).
model.c_word_mask = model.c_word_mask.to(device)
model.c_word_sent = model.c_word_sent.to(device)
model.c_input_word = model.c_input_word.to(device)

# Batch size shared by the dev and test DataLoaders built in the next cell.
batch_size = 8

In [4]:
# Build the dev and test datasets and their loaders.
# The dev split is needed to search for the decision threshold, because the
# threshold is not stored with the checkpoint.
# NOTE(review): 4000 is the third MimicFullDataset argument — presumably the
# maximum (truncated) input length; confirm against data_util.
dev_dataset = MimicFullDataset("dev", embeds_dir, 4000)
test_dataset = MimicFullDataset("test", embeds_dir, 4000)

# Both loaders share the same settings; order is kept fixed (shuffle=False).
loader_options = dict(
    batch_size=batch_size,
    collate_fn=my_collate_fn,
    shuffle=False,
    num_workers=1,
)
dev_dataloader = DataLoader(dev_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

In [5]:
# Evaluate on the dev split; this run is mainly needed to obtain `threshold`,
# which is reused for the test-set evaluation.
dev_metric, (dev_yhat, dev_y, dev_yhat_raw), threshold = eval_func(
    model,
    dev_dataloader,
    device,
    tqdm_bar=True,
)

100%|██████████| 476/476 [00:21<00:00, 21.77it/s]


In [6]:
# Show the metrics returned by eval_func for the dev split.
dev_metric

{'acc_macro': 0.5559106979658223,
 'prec_macro': 0.705887979377018,
 'rec_macro': 0.6936896043688301,
 'f1_macro': 0.699735632850313,
 'acc_micro': 0.5853981096114179,
 'prec_micro': 0.7445703260278003,
 'rec_micro': 0.7325026769200783,
 'f1_micro': 0.7384872052798119,
 'rec_at_5': 0.6918052519826501,
 'prec_at_5': 0.6501971090670171,
 'f1_at_5': 0.6703561602152069,
 'rec_at_8': 0.8310636250156619,
 'prec_at_8': 0.5227332457293036,
 'f1_at_8': 0.6417869556352915,
 'rec_at_15': 0.9392640433737675,
 'prec_at_15': 0.3320192728865528,
 'f1_at_15': 0.49061253418602035,
 'auc_macro': 0.934798444913623,
 'auc_micro': 0.9563674994378357}

In [7]:
# Test-set evaluation, reusing the threshold obtained on the dev split.
#   test_yhat_raw - raw model scores (logits)
#   test_yhat     - predictions after applying the threshold
#   test_y        - ground truth
test_metric, (test_yhat, test_y, test_yhat_raw), _ = eval_func(model, test_dataloader, device, tqdm_bar=True, threshold=threshold)

# Backward-compatible aliases: this cell historically stored the TEST outputs
# under the dev_* names (overwriting the dev-split results), and later cells
# read them through those names. Prefer the test_* names in new code.
dev_yhat, dev_y, dev_yhat_raw = test_yhat, test_y, test_yhat_raw

100%|██████████| 921/921 [00:41<00:00, 22.26it/s]


In [8]:
# Show the metrics returned by eval_func for the test split.
test_metric

{'acc_macro': 0.556983215307773,
 'prec_macro': 0.7091817872654288,
 'rec_macro': 0.6964002490602869,
 'f1_macro': 0.7027329042589135,
 'acc_micro': 0.5892717306186362,
 'prec_micro': 0.7449136426394333,
 'rec_micro': 0.7382403056745992,
 'f1_micro': 0.7415619610741543,
 'rec_at_5': 0.7082975476777916,
 'prec_at_5': 0.6500271444082518,
 'f1_at_5': 0.6779124829149429,
 'rec_at_8': 0.838864421616997,
 'prec_at_8': 0.5144374321389794,
 'f1_at_8': 0.6377634934463228,
 'rec_at_15': 0.9437690388784601,
 'prec_at_15': 0.3246018820123055,
 'f1_at_15': 0.48305933407830975,
 'auc_macro': 0.9364695491775712,
 'auc_micro': 0.9575190884932989}

In [9]:
# Peek at the first two rows of the raw scores and of the thresholded predictions.
# NOTE(review): despite the dev_ prefix, these names were last assigned by the
# test-set evaluation cell (In [7]), so the rows shown come from the test split.
dev_yhat_raw[:2], dev_yhat[:2]

(array([[ -2.3478022 ,  -4.3366065 ,  -0.76810884,  -2.5877743 ,
          -0.53989863,  -3.7219074 ,  -0.75302756,  -1.8592943 ,
           0.03826404,  -1.1498349 ,  -1.9747726 ,  -1.5877777 ,
          -0.7024405 ,  -0.65485   ,  -2.5973842 ,  -2.409341  ,
          -1.9991333 ,   0.3217976 ,  -3.0705009 ,  -3.1525962 ,
          -4.7005205 ,  -1.7631277 ,  -2.820897  ,  -1.1594746 ,
           0.07368386,  -1.8499093 ,  -1.2649512 ,  -1.0763559 ,
          -3.570842  ,  -0.6253984 ,  -1.0826867 ,  -3.2176063 ,
          -3.1571984 ,  -3.0010257 ,  -1.8408563 ,  -2.6844306 ,
         -10.79202   ,  -4.434217  ,  -3.5336676 ,  -2.381419  ,
          -4.1773396 ,   2.1684    ,  -2.3287804 ,  -1.9040662 ,
          -2.5591264 ,  -1.6012051 ,  -0.81751096,  -2.5956612 ,
          -3.3163705 ,  -2.35377   ],
        [ -7.8886933 , -10.15856   ,  -3.5262523 ,  -4.1764245 ,
          -4.5758038 ,  -7.4148903 ,  -3.719347  ,  -3.100707  ,
           1.6389074 ,  -5.475654  ,  -5.169864  ,  