# Import

In [None]:
%env CUDA_VISIBLE_DEVICES=2

In [None]:
import torch

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
import pandas as pd
pd.options.display.max_columns = None

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook can still be executed on a machine without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

# Run model

## Load model

In [None]:
%run ../models/report_generation/__init__.py
%run ../models/checkpoint/__init__.py
%run ../utils/files.py
%run ../utils/nlp.py

In [None]:
# Select the run whose checkpoint is loaded in the next cells. Previously used
# runs are kept commented out so they can be swapped back in quickly.
# run_name = '0716_211601_lstm-att_lr0.0001_densenet-121'
# run_name = '0115_175006_h-lstm-att-v2_lr0.001_satt_densenet-121-v2_noes'
# run_name = '0115_064249_h-lstm-att-v2_lr0.001_densenet-121-v2_noes_front'
# run_name = '0401_222625_mimic-cxr_lstm-v2_lr0.0001_mobilenet-v2_size256'

# run_name = '0513_123117' # lstm
# run_name = '0513_145846' # lstm-att
# run_name = '0513_174148' # h-lstm
# run_name = '0513_200618' # h-lstm-att

run_name = '0607_002702'

# NOTE(review): the meaning of the positional `False` and of the 'rg' tag is
# defined by RunId, which is not visible in this notebook; 'rg' presumably
# stands for "report generation" -- confirm.
run_id = RunId(run_name, False, 'rg')

In [None]:
# Load the checkpoint of the selected run onto DEVICE.
compiled_model = load_compiled_model_report_generation(run_id, device=DEVICE)
# Call .eval() on the wrapped model before running it below.
_ = compiled_model.model.eval()
# List which decoder settings were stored in the checkpoint metadata.
compiled_model.metadata['decoder_kwargs'].keys()

In [None]:
# Vocabulary stored alongside the decoder settings in the checkpoint metadata.
VOCAB = compiled_model.metadata['decoder_kwargs']['vocab']
len(VOCAB)

In [None]:
# Reader built from the model's vocabulary; its idx_to_text method is used
# below to turn reports into readable text.
REPORT_READER = ReportReader(VOCAB)

In [None]:
# HIERARCHICAL is read by eval_sample to decide which kwargs reach the model
# and which helper post-processes its output.
decoder_name = compiled_model.metadata['decoder_kwargs']['decoder_name']
HIERARCHICAL = is_decoder_hierarchical(decoder_name)
HIERARCHICAL

## Load data

In [None]:
%run ../datasets/__init__.py
%run ../utils/nlp.py

In [None]:
# Start from the dataset settings saved with the model and override only what
# this notebook needs: a fixed (sorted, unshuffled) sample order and a batch
# size of 2.
model_dataset_kwargs = compiled_model.metadata['dataset_kwargs']
dataset_kwargs = dict(
    model_dataset_kwargs,
    sort_samples=True,
    shuffle=False,
    batch_size=2,
)

# One dataloader per split, created in train / val / test order.
train_dataloader, val_dataloader, test_dataloader = (
    prepare_data_report_generation(dataset_type=split, **dataset_kwargs)
    for split in ('train', 'val', 'test')
)
len(train_dataloader.dataset), len(val_dataloader.dataset), len(test_dataloader.dataset)

## Eval

In [None]:
%run ../training/report_generation/hierarchical.py
%run ../training/report_generation/flat.py
%run ../utils/common.py

In [None]:
def eval_sample(batch, device=DEVICE, free=False, **kwargs):
    """Run the loaded model on one batch and post-process its output.

    Args:
        batch: A batch from one of the dataloaders; only its ``images`` and
            ``reports`` attributes are read.
        device: Device the inputs are moved to before the forward pass.
        free: Forwarded unchanged to the model as ``free``.
        **kwargs: Extra keyword arguments forwarded to the model call. When
            the decoder is not hierarchical, ``max_sentences`` is removed
            first (if present).

    Returns:
        A tuple ``(reports, generated, tup)`` where ``reports`` is the
        ground truth moved to ``device``, ``generated`` is the result of
        ``_flatten_gen_reports(tup[0], tup[1])`` for hierarchical decoders or
        ``_clean_gen_reports(tup[0])`` otherwise, and ``tup`` is the raw
        model output.
    """
    # Prepare inputs
    images = batch.images.to(device)
    reports = batch.reports.to(device)

    # Pass thru model
    if not HIERARCHICAL:
        # pop() with a default instead of `del`, so callers that never passed
        # max_sentences do not hit a KeyError.
        kwargs.pop('max_sentences', None)
    tup = compiled_model.model(images, reports, free=free, **kwargs)

    # Parse outputs
    if HIERARCHICAL:
        generated = _flatten_gen_reports(tup[0], tup[1])
    else:
        generated = _clean_gen_reports(tup[0])

    return reports, generated, tup

In [None]:
def to_text(reports):
    """Decode each report in ``reports`` with ``REPORT_READER.idx_to_text``.

    Returns a list with one decoded entry per input report, in order.
    """
    return list(map(REPORT_READER.idx_to_text, reports))

def print_result(reports, generated):
    """Print ground-truth and generated reports as text and return both.

    Both arguments are decoded with ``to_text``. The ground-truth entries are
    printed first, then a dashed separator, then the generated entries; every
    entry is printed on its own tab-indented line.

    Returns:
        ``(ground_truth_texts, generated_texts)`` as lists of decoded reports.
    """
    gt_texts = to_text(reports)
    gen_texts = to_text(generated)

    def _show(title, texts):
        # Title line followed by one indented line per report.
        print(title)
        for text in texts:
            print(f'\t{text}')

    _show('GROUND TRUTH:', gt_texts)
    print('-' * 20)
    _show('GENERATED:', gen_texts)
    return gt_texts, gen_texts

### Check stops array

In [None]:
# Keep a single iterator over the validation dataloader so that each
# next(iter_dataloader) call below yields the following batch.
iter_dataloader = iter(val_dataloader)

In [None]:
# Binary cross-entropy used below to compare predicted and ground-truth stops.
# Referenced through `torch.nn` because only `torch` is imported explicitly in
# this notebook; a bare `nn` would rely on a name leaked by a %run script.
loss_fn = torch.nn.BCELoss()

In [None]:
# Advance to the next validation batch; re-run this cell to step forward.
# for _ in range(0):
batch = next(iter_dataloader)

In [None]:
# Run the model on the current batch with free=False; max_sentences and
# max_words are forwarded to the model call by eval_sample (max_sentences only
# for hierarchical decoders). `out` keeps the raw model output for the cells
# below.
gt, gen, out = eval_sample(batch, free=False, max_sentences=100, max_words=40)
gt_str, gen_str = print_result(gt, gen)

In [None]:
# out[1] is the second element of the raw model output, the same element that
# eval_sample hands to _flatten_gen_reports for hierarchical decoders.
# NOTE(review): presumably the predicted stop probabilities -- confirm.
stops = out[1].detach().cpu()
# Print predicted and ground-truth sizes side by side to compare them.
print(stops.size())
print(batch.stops.size())
# Boolean view of the predictions thresholded at 0.5.
stops > 0.5

In [None]:
# Loss between the predicted stops (moved to CPU above) and the batch targets.
# NOTE(review): assumes batch.stops is a CPU float tensor with the same shape
# as `stops` -- confirm against the sizes printed in the previous cell.
l = loss_fn(stops, batch.stops)
l

### Plot attentions

In [None]:
from captum.attr import visualization
from skimage.color import gray2rgb
from skimage.transform import resize

In [None]:
# Third element of the raw model output, detached and moved to CPU.
# NOTE(review): presumably the attention scores. squeeze(0) only drops the
# leading dimension when its size is 1, but the dataloaders above are built
# with batch_size=2 -- confirm the layout expected by the cells below.
scores = out[2].detach().squeeze(0).cpu()
scores.size()

In [None]:
# NOTE(review): `image` is not defined in any cell visible in this notebook;
# it presumably refers to a single image taken from `batch.images` -- confirm
# and define it before running this cell.
# permute(1, 2, 0) moves the first axis to the end (assumes a 3-D tensor with
# channels first -- TODO confirm).
image_color = image.detach().permute(1, 2, 0).cpu().numpy()
image_color = arr_to_range(image_color)

image_color.shape

In [None]:
# Take one entry of `scores`, convert it to RGB and resize it to the shape of
# image_color so both arrays can be passed to the visualization below.
# NOTE(review): if `scores` still carries a batch dimension, index 1 selects a
# sample rather than a sentence -- confirm.
sentence_idx = 1
heatmap = scores[sentence_idx].numpy()
heatmap = gray2rgb(heatmap)
heatmap = resize(heatmap, image_color.shape)
heatmap.shape

In [None]:
# Decode the entry that the heatmap above belongs to.
# NOTE(review): `out_words` is not defined in any cell visible in this
# notebook; presumably the generated words from the model output -- confirm.
REPORT_READER.idx_to_text(out_words[sentence_idx])

In [None]:
# Draw the original image and the heatmap blended over it, using the 'jet'
# colormap and a colorbar. `signs` lists one sign per entry in `methods`.
visualization.visualize_image_attr_multiple(heatmap,
                                            image_color,
                                            methods=['original_image',
                                                     'blended_heat_map'],
                                            signs=['all', 'positive'],
                                            cmap='jet',
                                            show_colorbar=True,
                                           )

# Check result-generated reports

In [None]:
%run ../metrics/report_generation/writer.py
%run ../utils/files.py

In [None]:
def load_generated_reports(run_id, free=True):
    """Read the CSV of generated reports saved for ``run_id``.

    The file path is resolved with ``_get_outputs_fpath(run_id, free=free)``
    and the file contents are returned as a pandas DataFrame.
    """
    return pd.read_csv(_get_outputs_fpath(run_id, free=free))

In [None]:
# Names of the runs whose saved outputs are compared side by side below.
_COMPARED_RUN_NAMES = (
    '0513_123117', # lstm
    '0513_145846', # lstm-att
    '0513_174148', # h-lstm
    '0513_200618', # h-lstm-att
)
run_ids = [RunId(compared_name, False, 'rg') for compared_name in _COMPARED_RUN_NAMES]

In [None]:
# Placeholder for the combined table that the next cell assigns.
TOTAL_DF = None

In [None]:
def _merge_generated_reports(run_ids):
    """Build one table holding the generated reports of several runs.

    The first run contributes the shared columns (``image_fname``,
    ``filename``, ``dataset_type``, ``ground_truth``). Every run contributes
    one ``gen-<short_clean_name>`` column, joined on ``image_fname`` with an
    outer merge. Rows are sorted by the length of ``ground_truth`` (ascending)
    and re-indexed from 0.

    Working inside a function keeps the loop variables local, so the global
    ``run_id`` of the loaded model is not overwritten, and the table is built
    from scratch on every call, so re-running the cell cannot merge the same
    columns twice.

    Raises:
        ValueError: if ``run_ids`` is empty.
    """
    merged = None
    for run_id in run_ids:
        gen_col_name = f'gen-{run_id.short_clean_name}'
        df = load_generated_reports(run_id).rename(columns={'generated': gen_col_name})

        # Selecting columns explicitly also discards 'epoch' and anything else
        # that is not needed for the comparison.
        if merged is None:
            merged = df[['image_fname', 'filename', 'dataset_type', 'ground_truth', gen_col_name]]
        else:
            merged = merged.merge(
                df[['image_fname', gen_col_name]],
                on='image_fname',
                how='outer',
            )

    if merged is None:
        raise ValueError('run_ids is empty: there are no generated reports to merge')

    merged = merged.sort_values(by='ground_truth', key=lambda col: col.str.len())
    return merged.reset_index(drop=True)


TOTAL_DF = _merge_generated_reports(run_ids)
len(TOTAL_DF)

In [None]:
# Preview the first rows; after the ascending sort by ground-truth length
# these are the samples with the shortest ground-truth reports.
TOTAL_DF.head()

In [None]:
def print_sample(idx):
    """Print the ground truth and every run's generated report for one row.

    Args:
        idx: Positional row index into ``TOTAL_DF`` (used with ``.iloc``).

    Prints a header line with the file name, image file name and dataset
    type, then the ground truth labelled ``GT``, then one section per
    ``gen-*`` column. Each section ends with a dashed separator line.
    """
    sample = TOTAL_DF.iloc[idx]

    def _emit(label, text):
        # Label, report and separator, each on its own line.
        print(label, text, '-'*30, sep='\n')

    print(f"{sample['filename']} {sample['image_fname']} ({sample['dataset_type']})")

    _emit('GT', sample['ground_truth'])

    for column in TOTAL_DF.columns:
        if column.startswith('gen-'):
            _emit(column, sample[column])

In [None]:
# Negative indices count from the end of TOTAL_DF; because the table is sorted
# by ground-truth length in ascending order, this picks a sample near the
# long-report end.
print_sample(-101)