## Import

In [None]:
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
import matplotlib
# Use a white figure background for every plot produced in this notebook
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
# Device used for the model and every input tensor in this notebook.
# CUDA_VISIBLE_DEVICES is set in the first cell, which restricts which
# physical GPU 'cuda' can refer to.
DEVICE = torch.device('cuda')
DEVICE

## Load model

In [None]:
%run ../models/report_generation/__init__.py
%run ../models/checkpoint/__init__.py

In [None]:
# Name of the training run to load; it is passed to
# load_compiled_model_report_generation in the next cell.
# Alternative runs, kept commented out for quick switching:
# run_name = '0716_211601_lstm-att_lr0.0001_densenet-121'
# run_name = '0115_175006_h-lstm-att-v2_lr0.001_satt_densenet-121-v2_noes'
run_name = '0115_064249_h-lstm-att-v2_lr0.001_densenet-121-v2_noes_front'
# Forwarded as `debug=` to the loader
debug = False

In [None]:
# Restore the model saved for `run_name` and place it on DEVICE
compiled_model = load_compiled_model_report_generation(
    run_name, debug=debug, device=DEVICE,
)

# Switch to eval mode; binding the result to `_` keeps the cell from echoing it
_ = compiled_model.model.eval()

# Display the decoder configuration stored with the run
compiled_model.metadata['decoder_kwargs']

In [None]:
# Vocabulary stored in the run's metadata; reused below when building the datasets
VOCAB = compiled_model.metadata['vocab']
len(VOCAB)

## Load data

In [None]:
%run ../datasets/iu_xray.py
%run ../utils/nlp.py

In [None]:
# Options shared by the three splits; the vocabulary comes from the loaded run
dataset_kwargs = dict(
    max_samples=None,
    frontal_only=True,
    image_size=(512, 512),
    vocab=VOCAB,
)

# Build train / val / test in that order with identical options
train_dataset, val_dataset, test_dataset = (
    IUXRayDataset(dataset_type=split_name, **dataset_kwargs)
    for split_name in ('train', 'val', 'test')
)

# Number of samples per split
tuple(len(split) for split in (train_dataset, val_dataset, test_dataset))

## Eval

In [None]:
from captum.attr import visualization
from skimage.color import rgb2gray, gray2rgb
from skimage.transform import resize

In [None]:
%run ../training/report_generation/hierarchical.py
%run ../utils/common.py
%run ../utils/nlp.py

In [None]:
# Index -> text decoder built from the run's vocabulary.
# eval_sample (next cell) reads this module-level name.
report_reader = ReportReader(compiled_model.metadata['vocab'])

In [None]:
def eval_sample(compiled_model, image, report,
                show=True, device=DEVICE, free=False, **kwargs):
    """Run the model on one (image, report) pair and decode its output to text.

    Args:
        compiled_model: object exposing `.model` (the callable network) and
            `.metadata['decoder_kwargs']['decoder_name']`.
        image: image tensor without a batch dimension.
        report: ground-truth report as word indices.
        show: when True, print the ground-truth and generated texts.
        device: device the inputs are moved to before the forward pass.
        free: forwarded to the model as `free=`.
        **kwargs: extra keyword arguments forwarded to the model.
            `max_sentences` is dropped for non-hierarchical decoders.

    Returns:
        Tuple `(original_report, generated_report, tup)` where the first two
        are the decoded texts and `tup` is the raw model output.

    Note:
        Decoding uses the module-level `report_reader`.
    """
    decoder_name = compiled_model.metadata['decoder_kwargs']['decoder_name']
    is_hierarchical = decoder_name.startswith('h-')

    # Prepare inputs: add a batch dimension of 1 and move to the target device
    images = image.unsqueeze(0).to(device)
    if is_hierarchical:
        reports = split_sentences_and_pad(report)
    else:
        reports = torch.tensor(report)
    reports = reports.unsqueeze(0).to(device)

    # Pass thru model
    if not is_hierarchical:
        # Flat decoders are not given `max_sentences`; use pop() with a
        # default so callers that never passed it do not hit a KeyError.
        kwargs.pop('max_sentences', None)
    tup = compiled_model.model(images, reports, free=free, **kwargs)

    # Parse outputs
    if is_hierarchical:
        generated = _flatten_gen_reports(tup[0], tup[1])
    else:
        # Highest-scoring entry along the last axis
        generated = tup[0].argmax(dim=-1)
    generated = generated.squeeze(0).cpu()

    # Decode both reports with the module-level reader
    original_report = report_reader.idx_to_text(report)
    generated_report = report_reader.idx_to_text(generated)
    if show:
        print('GROUND TRUTH:\n', original_report)
        print('-'*20)
        print('GENERATED:\n', generated_report)

    return original_report, generated_report, tup

In [None]:
# Evaluate a single training sample with free=True
idx = 20
item = train_dataset[idx]
image, report = item.image, item.report

gt, gen, other = eval_sample(
    compiled_model, image, report,
    free=True, max_sentences=100, max_words=40,
)

In [None]:
# Second element of the raw model output. eval_sample passes it, together with
# the first element, to _flatten_gen_reports
# (presumably per-sentence stop signals — TODO confirm).
stops = other[1].detach().cpu()
# stops = stops[0]
print(stops.size())
stops

In [None]:
# Index of the top-scoring entry along the last axis of the first model output,
# with the leading size-1 dimension dropped
out_words = other[0].detach().squeeze(0).argmax(dim=-1)
out_words.size()

In [None]:
# Third element of the raw model output, leading size-1 dimension dropped.
# It is indexed by sentence and rendered as a heatmap further down
# (presumably attention scores — TODO confirm).
scores = other[2].detach().squeeze(0).cpu()
scores.size()

In [None]:
# Move the first axis to the end (presumably channels-first -> channels-last)
# and convert to a NumPy array for plotting
image_color = image.detach().permute(1, 2, 0).cpu().numpy()
# arr_to_range is defined in one of the %run'd utility files; presumably it
# rescales values into a displayable range — TODO confirm
image_color = arr_to_range(image_color)

image_color.shape

In [None]:
# Choose one generated sentence and turn its score map into an RGB overlay
sentence_idx = 1
heatmap = resize(
    gray2rgb(scores[sentence_idx].numpy()),  # grayscale map -> RGB
    image_color.shape,                       # same shape as the displayed image
)
heatmap.shape

In [None]:
# Decode the words predicted for the sentence chosen by sentence_idx
report_reader.idx_to_text(out_words[sentence_idx])

In [None]:
# Two panels: the input image, and the sentence heatmap blended over it
visualization.visualize_image_attr_multiple(
    heatmap,
    image_color,
    methods=['original_image', 'blended_heat_map'],
    signs=['all', 'positive'],
    cmap='jet',
    show_colorbar=True,
)

## Debug report-reader

In [None]:
%run ../utils/nlp.py

In [None]:
# Rebuild the reader with ignore_pad=True. This rebinds the module-level name
# that eval_sample also uses, so later eval_sample calls pick up this reader.
report_reader = ReportReader(compiled_model.metadata['vocab'], ignore_pad=True)

In [None]:
idx = 20

item = train_dataset[idx]
# Add a batch dimension of 1 and move the image to the model's device
images = item.image.unsqueeze(0).to(DEVICE)
# Prepare the report the same way eval_sample does: split into sentences,
# add the batch dimension and move to DEVICE. With free=False the model
# consumes the reports, so they must sit on the same device as the images.
reports = split_sentences_and_pad(item.report).unsqueeze(0).to(DEVICE)

out = compiled_model.model(images, reports, free=False)
# Top-scoring index along the last axis of the first output, batch dim removed
out = out[0].argmax(dim=-1).squeeze(0).detach().cpu()
out.size()

In [None]:
# Raw predicted indices
out

In [None]:
# Decode the predicted indices with the ignore_pad reader
report_reader.idx_to_text(out)

In [None]:
# Same indices as a plain nested Python list, for eyeballing padding values
out.tolist()