# Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=2

In [None]:
import torch
from torch import nn

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
%run ../utils/logging.py
config_logging()

In [None]:
%run -n ../train_report_generation.py

In [None]:
# Device used by every cell below. Prefer the GPU, but fall back to the CPU
# when CUDA is not available so the notebook still runs on machines without
# a usable GPU. To force the CPU, set: DEVICE = torch.device('cpu')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

# Load previous model

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/nlp.py

In [None]:
# Select the report-generation ('rg') run to load. The commented-out lines
# are alternative runs kept for quick switching.
# run_name, debug = '1113_183215', False
# run_name, debug = '0607_002702', False
run_name, debug = '1119_182557', True
run_id = RunId(run_name, debug, task='rg')
run_id

In [None]:
# Load the report-generation model identified by `run_id` onto DEVICE.
compiled_model = load_compiled_model_report_generation(run_id, device=DEVICE)

# List the metadata entries carried by the loaded model.
compiled_model.metadata.keys()

In [None]:
# Recover the vocabulary from the run metadata: use the one stored in the
# dataset kwargs, falling back to the decoder kwargs when that one is
# missing or empty.
meta = compiled_model.metadata
VOCAB = meta['dataset_kwargs'].get('vocab')
if not VOCAB:
    VOCAB = meta['model_kwargs']['decoder_kwargs'].get('vocab')
assert VOCAB is not None

# Reader used below to turn word indices back into text.
REPORT_READER = ReportReader(VOCAB)
len(VOCAB)

In [None]:
_ = compiled_model.model.to(DEVICE)

# Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
BS = 5

# Defaults: only effective for keys the loaded run did not store.
dataset_kwargs = {
    'dataset_name': 'iu-x-ray',
    'hierarchical': False,
    'frontal_only': True,
    'image_size': (256, 256),
}
# Dataset settings saved with the loaded run take precedence over the defaults.
dataset_kwargs.update(compiled_model.metadata['dataset_kwargs'])
# Debug overrides always win over both of the above.
# Other options that can be overridden here: 'sort_samples': True,
# 'shuffle': True, 'num_workers': 1, 'batch_size': BS.
dataset_kwargs.update({
    'max_samples': 100,
    'batch_size': 20,
})

train_dataloader = prepare_data_report_generation(dataset_type='train', **dataset_kwargs)
val_dataloader = prepare_data_report_generation(dataset_type='val', **dataset_kwargs)
len(train_dataloader.dataset)

## Debug hierarchical dataloader

In [None]:
from torch.nn.functional import interpolate

In [None]:
# Advance to the 10th training batch; if the loader yields fewer batches,
# `batch` ends up being the last one and `i` the number of batches seen.
i = 0
for i, batch in enumerate(train_dataloader, start=1):
    if i == 10:
        break

In [None]:
batch.masks.min(), batch.masks.max()

In [None]:
report_reader = ReportReader(train_dataloader.dataset.get_vocab())

In [None]:
item_idx = 0
report = batch.reports[item_idx]
mask = batch.masks[item_idx]
report.size(), mask.size()

In [None]:
# Downsample the masks to 8x8 with nearest-neighbour interpolation.
# A temporary leading dimension is added for `interpolate` (unsqueeze) and
# removed afterwards (squeeze); the input is cast to float for the call and
# the result is cast back to long.
mask = interpolate(mask.unsqueeze(0).float(), (8, 8), mode='nearest').squeeze(0).long()
mask.size()

In [None]:
# One subplot per sentence mask, laid out on a single row; the decoded
# sentence text is printed alongside.
n_sentences = mask.size(0)
n_cols = n_sentences
plt.figure(figsize=(15, 5))

for i_sentence, submask in enumerate(mask):
    lowest, highest = submask.min().item(), submask.max().item()

    # Flag constant masks in the title so they are easy to spot.
    title = f'Sentence {i_sentence}'
    if lowest == highest:
        title = f'{title} (all={lowest})'

    plt.subplot(1, n_cols, i_sentence + 1)
    plt.imshow(submask)
    plt.title(title)

    sentence = report_reader.idx_to_text(report[i_sentence])
    print(f'{i_sentence}: {sentence}')

plt.show()

# Create model

Only needed if a model was not already loaded in the sections above.

In [None]:
%run ../models/classification/__init__.py
%run ../models/report_generation/cnn_to_seq.py
%run ../models/checkpoint/__init__.py
%run ../losses/optimizers.py

## Load CNN

In [None]:
run_name = '0611_155356'
run_id = RunId(run_name, task='cls', debug=False)
run_id

In [None]:
compiled_cnn = load_compiled_model(run_id, device=DEVICE)
cnn = compiled_cnn.model

### or new CNN

In [None]:
# Build a new CNN instead of loading one from a checkpoint, and move it to
# DEVICE. The trailing comment lists alternative model names
# (presumably also accepted by `create_cnn` — confirm).
cnn = create_cnn('mobilenet-v2', # resnet-50 # densenet-121
                 labels=[],
                 imagenet=True,
                 freeze=False,
                ).to(DEVICE)

## Create decoder

In [None]:
# Configuration passed to `create_decoder`; the decoder is built from it and
# moved to DEVICE.
decoder_kwargs = {
    'decoder_name': 'h-lstm-att-v2',
    'vocab': VOCAB,
    'embedding_size': 100,
    'embedding_kwargs': { 'pretrained': 'radglove' },
    'hidden_size': 100,
    # Feature size is read from the CNN defined in the cells above.
    'features_size': cnn.features_size,
    'teacher_forcing': True,
    'dropout_recursive': 0,
    'dropout_out': 0,
    'double_bias': False,
}
decoder = create_decoder(**decoder_kwargs).to(DEVICE)

## CNN-2-seq

In [None]:
model = CNN2Seq(cnn, decoder).to(DEVICE)

In [None]:
optimizer = create_optimizer(model, custom_lr={ 'word_embedding': 0.05 }, lr=0.0001)
optimizer

# Debug rolling-average

In [None]:
# Train the previously loaded `compiled_model` for 11 epochs on the small
# debug dataloaders.
# NOTE(review): `run_id` is whatever was assigned last in the notebook; if the
# "Load CNN" cell was executed it now refers to the 'cls' run rather than the
# 'rg' run that `compiled_model` was loaded from — confirm before running.
trainer, validator = train_model(
    run_id, compiled_model, train_dataloader, val_dataloader, n_epochs=11,
    medical_correctness=False,
    print_metrics=['bleu1', 'ciderD', 'rougeL'],
    checkpoint_metric=['bleu1', 'bleu2'],
    tb_kwargs={'scalars': False},
    lambda_att=0,
    device=DEVICE,
)

# Debug att-supervision loss

In [None]:
import torch.nn.functional as F

In [None]:
%run ../losses/out_of_target.py

In [None]:
# Advance to the 200th training batch; if the loader yields fewer batches,
# `batch` ends up being the last one and `i` the number of batches seen.
i = 0
for i, batch in enumerate(train_dataloader, start=1):
    if i == 200:
        break

In [None]:
batch.stops.size()

In [None]:
batch.stops

In [None]:
target = batch.masks
target.size()

In [None]:
target = F.interpolate(target.float(), (16, 16), mode='nearest') # .long()
target.size()

In [None]:
# Fake model output for exercising the losses: uniform random values in
# [0, 1) sharing the first two dimensions of `target`, with a 16x16 map.
shape = target.size()[:2]
output = torch.rand(*shape, 16, 16)
# Alternative kept for reference: all-ones scores normalised with a softmax
# over the flattened map, then reshaped back to 16x16.
# output = torch.ones(*target.size())
# output = output.view(*shape, -1)
# output = torch.softmax(output, dim=-1)
# output = output.view(*shape, 16, 16)
output.size()

In [None]:
loss = OutOfTargetSumLoss()
x = loss(output, target)
x.item()

In [None]:
loss = F.binary_cross_entropy(output, target.float(), reduction='none')
loss

In [None]:
# Keep only the loss entries where the target is 0 and the corresponding stop
# flag is 0. The two unsqueeze calls append trailing dimensions so that
# `batch.stops` broadcasts against `target`.
l = loss[(target == 0) & (batch.stops.unsqueeze(-1).unsqueeze(-1) == 0)]
l

In [None]:
torch.tensor([]).sum()

In [None]:
for report in batch.reports:
    print(REPORT_READER.idx_to_text(report.view(-1)))

In [None]:
# Grid of target masks: one row per sample, one column per sentence.
# For each cell, also print the stop flag and the mask's value range.
n_samples, n_sentences = shape
plt_index = 1
for i_sample in range(n_samples):
    for j_sentence in range(n_sentences):
        stop_flag = batch.stops[i_sample][j_sentence]
        mask = target[i_sample][j_sentence]
        print(stop_flag, mask.min(), mask.max())

        plt.subplot(n_samples, n_sentences, plt_index)
        plt.imshow(mask)
        plt_index += 1

In [None]:
# Full-resolution masks used as the "original" side of the comparison below.
# Synthetic alternative:
# target = (torch.rand(1, 1, 256, 256) > 0.5).long()
# The bare name `masks` is not defined by any cell of this notebook and
# raised a NameError; the masks live on the current batch.
target = batch.masks
target.size()

In [None]:
target2 = interpolate(target.float(), size=(16, 16), mode='nearest')
target2.size()

In [None]:
# Side-by-side comparison: for every (sample, sentence) pair draw the
# original mask from `target` and, next to it, its 16x16 nearest-neighbour
# downsampled version from `target2`.
batch_size, n_sentences = target.size()[:2]

n_rows = batch_size
n_cols = n_sentences * 2  # two panels (original + downsampled) per sentence

plt.figure(figsize=(15, 8))

plot_index = 1
for idx1 in range(batch_size):
    for idx2 in range(n_sentences):
        plt.subplot(n_rows, n_cols, plot_index)
        plt.imshow(target[idx1][idx2])
        plt.title(f'Original - {idx1},{idx2}')
        plot_index += 1

        plt.subplot(n_rows, n_cols, plot_index)
        # Plot the downsampled mask. This panel previously showed `output`
        # (the random tensor from the loss experiment), so the "Downsampled"
        # title did not match the image.
        plt.imshow(target2[idx1][idx2])
        plt.title(f'Downsampled - {idx1},{idx2}')
        plot_index += 1

# Debug h-reports

## Organ-by-sentence metric

In [None]:
%run ../training/report_generation/hierarchical.py

In [None]:
# Generate reports for the first validation batch only (the loop exits after
# one iteration) and flatten both generated and ground-truth reports.
for batch in val_dataloader:
    images = batch.images.to(DEVICE)
    reports = batch.reports.to(DEVICE)

    # Inference only: no gradients are needed for free generation.
    with torch.no_grad():
        output = compiled_model.model(
            images,
            reports,
            free=True,
            max_words=100,
            max_sentences=100,
        )

    gen_words, gen_stops, gen_scores, gen_topics = output

    gen_reports = _flatten_gen_reports(gen_words, gen_stops, threshold=0.5)
    gt_reports = _flatten_gt_reports(reports)
    break

In [None]:
%run ../metrics/report_generation/organ_by_sentence.py

In [None]:
m = OrganBySentence(VOCAB)
m.reset()

In [None]:
m.update((gen_reports, gt_reports))
m.compute()

In [None]:
def print_report(r):
    """Print report `r` one sentence per line, decoded with REPORT_READER."""
    sentences = sentence_iterator(r)
    for sentence in sentences:
        text = REPORT_READER.idx_to_text(sentence)
        print(text)

In [None]:
# Print every generated report followed by its ground truth, separated by a
# thin rule, with a thick rule between report pairs.
for generated, ground_truth in zip(gen_reports, gt_reports):
    print_report(generated)
    print('-' * 30)
    print_report(ground_truth)
    print('=' * 50)

# Debug show [attend] and tell

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/files.py
%run ../utils/nlp.py
%run ../training/report_generation/flat.py

In [None]:
def reports_ids_to_text(rr):
    """Return the text of every report in `rr`, decoded one by one.

    `reader` is looked up when the function is called, so it may be defined
    in a later cell as long as that cell runs before the first call.
    """
    return [reader.idx_to_text(r) for r in rr]

In [None]:
run_name = '0611_155356'
run_id = RunId(run_name, task='cls', debug=False)
run_id

In [None]:
compiled_cnn = load_compiled_model(run_id, device=DEVICE)
# compiled_cnn.model

In [None]:
vocab = train_dataloader.dataset.get_vocab()
reader = ReportReader(vocab)
len(vocab)

In [None]:
loader = iter(train_dataloader)

In [None]:
batch = next(loader)

In [None]:
batch.images.size(), batch.reports.size()

In [None]:
# Move the batch to the notebook-wide DEVICE (instead of hard-coding CUDA) so
# the tensors land on the same device as the models loaded with
# `device=DEVICE`.
images = batch.images.to(DEVICE)
reports = batch.reports.to(DEVICE)

In [None]:
run_name = '1113_183215'
run_id = RunId(run_name, task='rg', debug=False)
run_id

In [None]:
compiled_model = load_compiled_model(run_id, device=DEVICE, mode='bleu4')
decoder = compiled_model.model.decoder

In [None]:
# image_features = compiled_model.model.cnn.features(images)
image_features = compiled_cnn.model.features(images)
image_features.size()

In [None]:
%run ../models/report_generation/decoder_show_attend_tell.py

In [None]:
# Fresh show-attend-tell decoder, placed on the notebook-wide DEVICE
# (instead of hard-coding CUDA) to match the inputs and the CNN.
decoder = ShowAttendTellDecoder(vocab, 100, 512, 1024).to(DEVICE)
# decoder
# _ = decoder.eval()

In [None]:
words_out, scores_out = decoder(image_features, reports=reports, free=True, max_words=10)
words_out.size(), scores_out.size()

In [None]:
words_out, scores_out = decoder.caption(image_features[:1], beam_size=5, max_words=10)
len(words_out)

In [None]:
reports_ids_to_text(_clean_gen_reports(words_out))

In [None]:
%run ../models/report_generation/decoder_show_tell.py

In [None]:
# Fresh show-and-tell decoder, placed on the notebook-wide DEVICE
# (instead of hard-coding CUDA) to match the inputs and the CNN.
decoder = ShowTellDecoder(vocab, 100, 512, 1024).to(DEVICE)

In [None]:
words_out, = decoder(image_features, reports=reports, free=True, max_words=10)
words_out.size()

In [None]:
reports_ids_to_text(_clean_gen_reports(words_out))

In [None]:
# NOTE(review): this cell unpacks a single value from `decoder.caption(...)`,
# while the following cell unpacks two values from the same call with the
# same arguments; only one of the two can succeed — confirm the return arity.
words_out, = decoder.caption(image_features[:1], beam_size=5, max_words=10, debug=True)
words_out.size()

In [None]:
w, scores = decoder.caption(image_features[:1], beam_size=5, max_words=10, debug=True)
w = torch.stack(w)
w.size()

In [None]:
reports_ids_to_text(_clean_gen_reports(w))

In [None]:
%run ../models/report_generation/cnn_to_seq.py

In [None]:
model = CNN2Seq(compiled_cnn.model, decoder)

In [None]:
%run ../training/report_generation/flat.py

In [None]:
step_fn = get_step_fn_flat(model, training=False, free=True, beam_size=20, max_words=10)

In [None]:
out = step_fn(None, batch)

In [None]:
reports_ids_to_text(out['flat_clean_reports_gen'])

In [None]:
reports_ids_to_text(out['flat_clean_reports_gt'])