## Import

In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch

In [None]:
import pandas as pd
pd.options.display.max_columns = None

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without a usable CUDA device.
# To force the CPU regardless, use: DEVICE = torch.device('cpu')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

## Evaluate models in subsets

TODO: move this evaluation into a standalone script.

In [None]:
%run -n train_report_generation.py
%run datasets/__init__.py
%run models/checkpoint/__init__.py
%run training/report_generation/flat.py
%run training/report_generation/hierarchical.py
%run models/report_generation/__init__.py

In [None]:
def eval_in_subset(run_name, compiled_model, debug=True, max_n_words=None, max_n_sentences=None,
                   device='cuda'):
    """Evaluate a trained report-generation model on size-limited dataset subsets.

    Builds the train/val/test IU X-ray datasets with the model's vocabulary,
    restricts each one with ``create_report_dataset_subset`` and hands the
    resulting dataloaders to ``evaluate_and_save``.

    Args:
        run_name: name of the run, forwarded to ``evaluate_and_save``.
        compiled_model: object exposing ``.model`` and a ``.metadata`` dict
            with the ``'vocab'`` and ``'decoder_kwargs'`` entries.
        debug: forwarded to ``evaluate_and_save``.
        max_n_words: limit forwarded to ``create_report_dataset_subset``
            (presumably the maximum report length in words -- confirm).
        max_n_sentences: limit forwarded to ``create_report_dataset_subset``
            (presumably the maximum number of sentences -- confirm).
        device: forwarded to ``evaluate_and_save``.

    Raises:
        ValueError: if neither ``max_n_words`` nor ``max_n_sentences`` is set,
            since the results suffix cannot be built without one of them.

    Note:
        Both limits are forwarded to the subset helper, but when both are set
        the suffix only mentions ``max_n_words``.
    """
    # Build the suffix first so an unusable call fails immediately with a clear
    # message, instead of hitting an unbound `suffix` after all the datasets
    # and dataloaders have already been created.
    if max_n_words:
        suffix = f'max-words-{max_n_words}'
    elif max_n_sentences:
        suffix = f'max-sentences-{max_n_sentences}'
    else:
        raise ValueError(
            'eval_in_subset requires max_n_words or max_n_sentences to be set'
        )

    # Create datasets
    vocab = compiled_model.metadata['vocab']
    train_dataset = IUXRayDataset('train', vocab=vocab)
    val_dataset = IUXRayDataset('val', vocab=vocab)
    test_dataset = IUXRayDataset('test', vocab=vocab)

    # Prepare subsets
    subset_kwargs = {
        'max_n_words': max_n_words,
        'max_n_sentences': max_n_sentences,
    }

    train_subset = create_report_dataset_subset(train_dataset, **subset_kwargs)
    val_subset = create_report_dataset_subset(val_dataset, **subset_kwargs)
    test_subset = create_report_dataset_subset(test_dataset, **subset_kwargs)

    # Decide whether the decoder is hierarchical, which selects the dataloader factory
    decoder_name = compiled_model.metadata['decoder_kwargs']['decoder_name']
    hierarchical = is_decoder_hierarchical(decoder_name)
    if hierarchical:
        create_dataloader = create_hierarchical_dataloader
    else:
        create_dataloader = create_flat_dataloader

    # Create dataloaders
    BS = 50
    train_dataloader = create_dataloader(train_subset, batch_size=BS)
    val_dataloader = create_dataloader(val_subset, batch_size=BS)
    test_dataloader = create_dataloader(test_subset, batch_size=BS)

    evaluate_and_save(run_name,
                      compiled_model.model,
                      train_dataloader,
                      val_dataloader,
                      test_dataloader,
                      hierarchical=hierarchical,
                      debug=debug,
                      device=device,
                      suffix=suffix,
                     )

In [None]:
# Word limits to evaluate; each value is passed as `max_n_words` to
# eval_in_subset. Commented-out entries are limits that are currently disabled.
# NOTE(review): the trailing percentages presumably give the share of reports
# that fit within each limit -- confirm.
eval_n_words = [
    20 , # --> 15%
    25 , # --> 26%
    27 , # --> 33%
    33 , # --> 50%
#     39 , # --> 66%
#     41 , # --> 70%
    44 , # --> 75%
#     47 , # --> 80%
#     58 , # --> 90%
    # None, # --> 100%
]

In [None]:
# Sentence limits to evaluate; each value is passed as `max_n_sentences` to
# eval_in_subset. Commented-out entries are limits that are currently disabled.
# NOTE(review): the trailing numbers look like cumulative percentages of
# reports that fit within each limit -- confirm.
eval_n_sentences = [
#     1, # 1.2324835387472564
#     2, # 4.761100793516799
    3, # 25.730204288367382
    4, # 55.10720918453487
    5, # 76.66722944453824
    6, # 89.39726489954415
#     7, # 95.03629917271653
#     8, # 97.6194496032416
#     9, # 98.86881647813608
#     10, # 99.42596657099443
#     11, # 99.71298328549722
#     12, # 99.89869998311667
#     13, # 99.96623332770555
#     17, # 99.98311666385278
#     18, # 100
]

In [None]:
# Runs to evaluate in the loop below; commented-out entries are skipped.
run_names = [
#     '0717_041434_lstm_lr0.0001_densenet-121',
    '0716_211601_lstm-att_lr0.0001_densenet-121', # 33 and 34 are still missing
#     '0717_015057_h-lstm_lr0.0001_densenet-121',
#     '0716_234501_h-lstm-att_lr0.0001_densenet-121',
]
# Forwarded both to the model loader and to eval_in_subset.
debug = False

In [None]:
# For every selected run: load the model once, then evaluate it on each
# word-limited subset followed by each sentence-limited subset.
# NOTE(review): `tqdm` is not imported in the visible cells; presumably one of
# the %run scripts brings it into scope -- confirm.
for run_name in run_names:
    compiled_model = load_compiled_model_report_generation(run_name,
                                                           debug=debug,
                                                           multiple_gpu=True,
                                                           device=DEVICE)
    # Subsets limited by number of words (sentence limit disabled)
    for n_words in tqdm(eval_n_words):
        eval_in_subset(run_name,
                       compiled_model,
                       max_n_words=n_words,
                       max_n_sentences=None,
                       debug=debug,
                       device=DEVICE,
                      )
    # Subsets limited by number of sentences (word limit disabled)
    for n_sentences in tqdm(eval_n_sentences):
        eval_in_subset(run_name,
                       compiled_model,
                       max_n_words=None,
                       max_n_sentences=n_sentences,
                       debug=debug,
                       device=DEVICE,
                      )

## Debug chexpert-labeler

In [None]:
%run ../utils/files.py
%run ../metrics/__init__.py
%run ../metrics/report_generation/chexpert.py
# %run -n ../eval_report_generation_chexpert_labeler.py

In [None]:
# Identify the run whose outputs are inspected below.
# NOTE(review): the positional arguments presumably are (name, debug, task) --
# confirm against the RunId definition.
run_id = RunId('0428_130903', False, 'rg')
run_id.full_name

In [None]:
# Load the report-generation outputs saved for this run and preview them.
df = load_rg_outputs(run_id, free=False)
print(len(df))
df.head()

In [None]:
%%time

# Run the labeler over the loaded outputs (timed by the cell magic above) and
# preview the resulting dataframe.
df = apply_labeler_to_df(df, caller_id='eval-notebook', dataset_name='mini-mimic')
df.head()

## Debug MIRQI

In [None]:
import pandas as pd
import re

In [None]:
%run -n ../eval_report_generation_mirqi.py

### Load MIRQI output

In [None]:
# Read the MIRQI output CSV, replacing missing cells with empty strings, and
# preview the first rows.
df = pd.read_csv('~/software/MIRQI/testing2.csv').fillna('')
df.head()

In [None]:
# Convert the ground-truth and generated attribute columns with the helper
# from eval_report_generation_mirqi.py, then compare their lengths.
attributes_gt = _attributes_to_list(df['attributes-gt'])
attributes_gen = _attributes_to_list(df['attributes-gen'])
len(attributes_gt), len(attributes_gen)

In [None]:
# Display the raw ground-truth attribute column for inspection.
df['attributes-gt']

In [None]:
%run -n ../eval_report_generation_mirqi.py

In [None]:
# Score the generated attributes against the ground truth with MIRQI_v2.
scores = MIRQI_v2(attributes_gt, attributes_gen)

In [None]:
# Display the 'MIRQI-v2-attr-p' entry of the scores.
scores['MIRQI-v2-attr-p']

In [None]:
# Show the ground-truth and generated attributes of one sample side by side.
idx = 2
attributes_gt[idx], attributes_gen[idx]

## MIRQI Examples

In [None]:
%run -n ../eval_report_generation_mirqi.py
# %run -n ~/software/MIRQI/evaluate.py

### MIRQI original definition

In [None]:
def MIRQI(gt_list, cand_list, pos_weight=0.8, attribute_weight=0.3, verbose=False):
    """Score candidate reports against ground-truth reports, pair by pair.

    Each report is a list of entities. For every entity, index 1 is the value
    used to pair a candidate entity with a ground-truth one, index 2 is
    compared against 'NEGATIVE', and index 3 is a '/'-separated attribute
    string. A candidate entity is always paired with the FIRST ground-truth
    entity that has the same value at index 1; candidates without such a
    counterpart do not affect the hit counters.

    Args:
        gt_list: list of ground-truth reports.
        cand_list: list of candidate reports, aligned with ``gt_list``.
        pos_weight: weight of the non-negative part of the score; the negative
            part gets ``1 - pos_weight`` and is only mixed in when there is at
            least one true negative.
        attribute_weight: share of a true positive that depends on the
            attribute overlap; the remainder is granted for the match itself.
        verbose: unused.

    Returns:
        Dict with keys 'MIRQI-r' (recall), 'MIRQI-p' (precision) and
        'MIRQI-f' (F-score), each holding one score per report pair.
    """
    eps = 0.000001  # keeps the ratios defined when a denominator would be zero

    recalls = []
    precisions = []
    f_scores = []

    for gt_entities, cand_entities in zip(gt_list, cand_list):
        n_pos_gt = sum(1 for entity in gt_entities if entity[2] != 'NEGATIVE')
        n_pos_cand = 0
        tp = fp = tn = fn = 0.0

        for cand in cand_entities:
            cand_is_negative = cand[2] == 'NEGATIVE'
            if not cand_is_negative:
                n_pos_cand += 1

            # Only the first ground-truth entity with the same label counts.
            gt_match = next((gt for gt in gt_entities if gt[1] == cand[1]), None)
            if gt_match is None:
                continue
            gt_is_negative = gt_match[2] == 'NEGATIVE'

            if cand_is_negative:
                if gt_is_negative:
                    tn = tn + 1     # true negative hit
                else:
                    fn = fn + 1     # false negative hit
            elif gt_is_negative:
                fp = fp + 1         # false positive hit
            else:
                # True positive, keyword part.
                tp = tp + 1.0 - attribute_weight
                gt_attributes = gt_match[3]
                if gt_attributes != '':
                    # True positive, attribute part: fraction of ground-truth
                    # attributes found inside the candidate's attributes.
                    wanted = gt_attributes.split('/')
                    n_hits = sum(1 for attribute in wanted if attribute in cand[3])
                    tp = tp + n_hits / len(wanted) * attribute_weight

        # Score for positive/uncertain mentions.
        if n_pos_gt == 0 or n_pos_cand == 0:
            # Perfect only when both sides have no such mentions at all.
            score_r = score_p = 1.0 if n_pos_gt == n_pos_cand else 0.0
        else:
            score_r = tp / (tp + fn + eps)
            score_p = tp / (tp + fp + eps)

        # Blend in the score for negative mentions.
        if tn != 0:
            neg_weight = 1.0 - pos_weight
            score_r = score_r * pos_weight + tn / (tn + fp + eps) * neg_weight
            score_p = score_p * pos_weight + tn / (tn + fn + eps) * neg_weight

        total = score_r + score_p
        if total != 0.0:
            score_f = 2 * (score_r * score_p) / total
        else:
            score_f = 0.0

        recalls.append(score_r)
        precisions.append(score_p)
        f_scores.append(score_f)

    return {
        'MIRQI-r': recalls,
        'MIRQI-p': precisions,
        'MIRQI-f': f_scores,
    }

### Robust matching

#### Repeated nodes with different attributes

In [None]:
# Ground truth and generated report contain the same two effusion entities,
# listed in a different order. The original MIRQI pairs every generated entity
# with the FIRST ground-truth entity sharing the same label ('right/present'),
# so the 'left/pleural' candidate finds none of its attributes there.
report_gt = "right effusion with mild atelectasis. left effusion is also present."
entities_gt = [
    ['effusion', 'Pleural Effusion', 'POSITIVE', 'right/present'],
    ['effusion', 'Pleural Effusion', 'POSITIVE', 'left/pleural'],
]
report_gen = report_gt
entities_gen = [
    ['effusion', 'Pleural Effusion', 'POSITIVE', 'left/pleural'],
    ['effusion', 'Pleural Effusion', 'POSITIVE', 'right/present'],
]
# Show the scores of both implementations side by side.
{
    **MIRQI([entities_gt], [entities_gen]),
    **MIRQI_v2([entities_gt], [entities_gen]),
}

#### GT nodes matched twice

In [None]:
# The generated report has an extra 'left' effusion that the ground truth does
# not contain. The original MIRQI never marks a ground-truth entity as used, so
# both generated entities are paired with the single ground-truth entity.
report_gt = "right pleural effusion."
entities_gt = [
    ['effusion', 'Pleural Effusion', 'POSITIVE', 'right'],
]
report_gen = "right pleural effusion. left pleural effusion"
entities_gen = [
    ['effusion', 'Pleural Effusion', 'POSITIVE', 'right'],
    ['effusion', 'Pleural Effusion', 'POSITIVE', 'left'],
]
# Show the scores of both implementations side by side.
{
    **MIRQI([entities_gt], [entities_gen]),
    **MIRQI_v2([entities_gt], [entities_gen]),
}