Debug training processes

## Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch
from torch import nn

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
%run -n ../train_report_generation.py

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook stays usable on machines without CUDA.
# To force the CPU, replace the expression below with torch.device('cpu').
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

## Load previous model

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/nlp.py

In [None]:
# Identify the run to restore.
# NOTE(review): the positional arguments look like (run name, debug flag,
# task) — confirm against the RunId definition.
run_id = RunId('0504_053805', False, 'rg')
# Load the compiled report-generation model for that run onto DEVICE
compiled_model = load_compiled_model_report_generation(run_id, device=DEVICE)

# List the metadata entries available on the loaded model
compiled_model.metadata.keys()

In [None]:
# Reuse the vocabulary recorded in the loaded model's dataset settings
VOCAB = compiled_model.metadata['dataset_kwargs']['vocab']
# Reader built on that vocabulary; later cells call its idx_to_text to decode reports
REPORT_READER = ReportReader(VOCAB)
len(VOCAB)

## Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
# Small batch size: this notebook only inspects a handful of samples
BS = 2

# Settings shared by the train and val dataloaders. The vocabulary comes from
# the loaded model rather than being rebuilt from the data.
dataset_kwargs = dict(
    dataset_name='iu-x-ray',
    hierarchical=True,
    max_samples=None,
    batch_size=BS,
    frontal_only=True,
    image_size=(256, 256),
    sort_samples=False,
    shuffle=True,
    masks=True,
    vocab=VOCAB,
    # num_workers=0,
)

train_dataloader = prepare_data_report_generation(
    dataset_type='train',
    **dataset_kwargs,
)
val_dataloader = prepare_data_report_generation(
    dataset_type='val',
    **dataset_kwargs,
)
len(train_dataloader.dataset)

### Debug hierarchical dataloader

In [None]:
from torch.nn.functional import interpolate

In [None]:
# Walk through the training loader and stop on the 10th batch, which then
# stays available as `batch` for the cells below.
i = 0  # keeps `i` defined even when the loader yields nothing
for i, batch in enumerate(train_dataloader, start=1):
    if i == 10:
        break

In [None]:
# Value range of the masks in the current batch
batch.masks.min(), batch.masks.max()

In [None]:
# Reader built from the training dataset's vocabulary; used below (idx_to_text)
# to turn word indices back into text
report_reader = ReportReader(train_dataloader.dataset.get_vocab())

In [None]:
# Pick one sample of the batch and take its report and its masks
item_idx = 0
report = batch.reports[item_idx]
mask = batch.masks[item_idx]
# Show both sizes side by side
report.size(), mask.size()

In [None]:
# Add a leading axis, shrink the masks to 8x8 with nearest-neighbour sampling,
# then drop that axis again and convert back to integer values.
mask = interpolate(
    mask.unsqueeze(0).float(),
    size=(8, 8),
    mode='nearest',
)
mask = mask.squeeze(0).long()
mask.size()

In [None]:
# One panel per sentence: draw the sentence's mask and print the sentence text,
# so each mask can be compared with the sentence it belongs to.
n_sentences = mask.size(0)
n_cols = n_sentences

plt.figure(figsize=(15, 5))

for i_sentence, submask in enumerate(mask):
    lowest = submask.min().item()
    highest = submask.max().item()

    title = f'Sentence {i_sentence}'
    if lowest == highest:
        # Constant mask: report its single value in the panel title
        title += f' (all={lowest})'

    plt.subplot(1, n_cols, i_sentence + 1)
    plt.imshow(submask)
    plt.title(title)

    sentence = report_reader.idx_to_text(report[i_sentence])
    print(f'{i_sentence}: {sentence}')

plt.show()

## Create model

Run this section only if a model was not already loaded in "Load previous model" above.

In [None]:
%run ../models/classification/__init__.py
%run ../models/report_generation/cnn_to_seq.py
%run ../models/checkpoint/__init__.py
%run ../losses/optimizers.py

### Load CNN

In [None]:
# Run whose classification model is reused as the CNN
cnn_run_name = '0706_134245_covid-kaggle_tfs-small_lr1e-06'
debug_run = True

# Load the compiled classification model for that run onto DEVICE ...
compiled_cnn = load_compiled_model_classification(cnn_run_name,
                                                  debug=debug_run,
                                                  device=DEVICE)
# ... and keep only its `.model` attribute as the CNN
cnn = compiled_cnn.model

### Or create a new CNN

In [None]:
# Build a CNN from scratch instead of loading one, and move it to DEVICE.
# The inline comment lists other architecture names — presumably also accepted
# by create_cnn; confirm.
# NOTE(review): imagenet=True presumably starts from ImageNet-pretrained
# weights and freeze=False presumably leaves them trainable — confirm against
# create_cnn.
cnn = create_cnn('mobilenet-v2', # resnet-50 # densenet-121
                 labels=[],
                 imagenet=True,
                 freeze=False,
                ).to(DEVICE)

### Create decoder

In [None]:
# Decoder configuration. The vocabulary is the one taken from the loaded model,
# and features_size is read from the CNN created above.
# NOTE(review): the meaning of the remaining options (pretrained 'radglove'
# embeddings, teacher forcing, dropout values, double_bias) is defined by
# create_decoder — confirm there.
decoder_kwargs = {
    'decoder_name': 'h-lstm-att-v2',
    'vocab': VOCAB,
    'embedding_size': 100,
    'embedding_kwargs': { 'pretrained': 'radglove' },
    'hidden_size': 100,
    'features_size': cnn.features_size,
    'teacher_forcing': True,
    'dropout_recursive': 0,
    'dropout_out': 0,
    'double_bias': False,
}
# Build the decoder and move it to DEVICE
decoder = create_decoder(**decoder_kwargs).to(DEVICE)

### CNN-2-seq

In [None]:
# Combine the CNN and the decoder into one CNN2Seq model and move it to DEVICE
model = CNN2Seq(cnn, decoder).to(DEVICE)

In [None]:
# Optimizer for the whole model with base lr 1e-4.
# NOTE(review): custom_lr presumably assigns lr 0.05 to the parameters matching
# 'word_embedding' — confirm against create_optimizer.
optimizer = create_optimizer(model, custom_lr={ 'word_embedding': 0.05 }, lr=0.0001)
optimizer

## Debug att-supervision loss

In [None]:
import torch.nn.functional as F

In [None]:
%run ../losses/out_of_target.py

In [None]:
# Walk through the training loader and stop on the 200th batch, which then
# stays available as `batch` for the cells below.
i = 0  # keeps `i` defined even when the loader yields nothing
for i, batch in enumerate(train_dataloader, start=1):
    if i == 200:
        break

In [None]:
# Size of the stop flags; they are indexed per (sample, sentence) further below
batch.stops.size()

In [None]:
# Raw stop flags of the current batch
batch.stops

In [None]:
# Use the batch's masks as the target for the loss experiments below
target = batch.masks
target.size()

In [None]:
# Downsample the masks to 16x16 with nearest-neighbour sampling. The result is
# left as float; the conversion back to long is disabled.
target = F.interpolate(target.float(), (16, 16), mode='nearest') # .long()
target.size()

In [None]:
# First two dimensions of the target; a later cell unpacks them as
# (n_samples, n_sentences)
shape = target.size()[:2]
# Random values in [0, 1) standing in for a model output, with a 16x16 map per
# (sample, sentence) like the downsampled target
output = torch.rand(*shape, 16, 16)
# Disabled alternative: a constant output normalised with a softmax over each
# flattened 16x16 map
# output = torch.ones(*target.size())
# output = output.view(*shape, -1)
# output = torch.softmax(output, dim=-1)
# output = output.view(*shape, 16, 16)
output.size()

In [None]:
# Evaluate OutOfTargetSumLoss on the stand-in output against the downsampled masks
loss = OutOfTargetSumLoss()
x = loss(output, target)
# Show the loss as a plain Python number
x.item()

In [None]:
# Element-wise binary cross-entropy (no reduction) so individual positions can
# be inspected. Note this rebinds `loss` from the loss object to a tensor.
loss = F.binary_cross_entropy(output, target.float(), reduction='none')
loss

In [None]:
# Keep the loss only where the target is 0 and the stop flag is 0.
# The stops are unsqueezed twice so they broadcast over the last two
# (spatial) dimensions of the target.
l = loss[(target == 0) & (batch.stops.unsqueeze(-1).unsqueeze(-1) == 0)]
l

In [None]:
# Check what summing an empty tensor gives (relevant when the filter above
# selects no positions)
torch.tensor([]).sum()

In [None]:
# Print every report of the batch as text; view(-1) flattens the report into a
# single sequence of word indices before decoding
for report in batch.reports:
    print(REPORT_READER.idx_to_text(report.view(-1)))

In [None]:
# Grid of downsampled masks: one row per sample, one column per sentence.
n_samples, n_sentences = shape
plt_index = 1  # subplot positions are 1-based
for i_sample in range(n_samples):
    for j_sentence in range(n_sentences):
        mask = target[i_sample, j_sentence]
        
        # Print the stop flag next to the mask's value range for this sentence
        print(batch.stops[i_sample, j_sentence], mask.min(), mask.max())
        plt.subplot(n_samples, n_sentences, plt_index)
        plt.imshow(mask)
        plt_index += 1

In [None]:
# target = (torch.rand(1, 1, 256, 256) > 0.5).long()
# Start again from the masks of the current batch. The name `masks` is never
# defined in this notebook, so read them from `batch`, as the earlier cells do.
target = batch.masks
target.size()

In [None]:
# 16x16 nearest-neighbour downsampled copy of the masks, kept in a separate
# name so it can be compared with the original
target2 = interpolate(target.float(), size=(16, 16), mode='nearest')
target2.size()

In [None]:
# Show every mask next to its downsampled version: one row per sample and two
# panels (original, downsampled) per sentence.
batch_size, n_sentences = target.size()[:2]

n_rows = batch_size
n_cols = n_sentences * 2  # two panels per sentence

plt.figure(figsize=(15, 8))

plot_index = 1  # subplot positions are 1-based
for idx1 in range(batch_size):
    for idx2 in range(n_sentences):
        plt.subplot(n_rows, n_cols, plot_index)
        plt.imshow(target[idx1][idx2])
        plt.title(f'Original - {idx1},{idx2}')
        plot_index += 1

        plt.subplot(n_rows, n_cols, plot_index)
        # Plot the downsampled mask (`target2`); `output` holds the random
        # stand-in values from the loss experiments, not a downsampled mask.
        plt.imshow(target2[idx1][idx2])
        plt.title(f'Downsampled - {idx1},{idx2}')
        plot_index += 1

## Debug h-reports

### Organ-by-sentence metric

In [None]:
%run ../training/report_generation/hierarchical.py

In [None]:
# Generate reports for the first validation batch only (the loop breaks after
# one iteration) and flatten generated and ground-truth reports for the metric.
# NOTE(review): this cell does not set the model's train/eval mode — confirm
# the loaded model is already in eval mode.
for batch in val_dataloader:
    images = batch.images.to(DEVICE)
    reports = batch.reports.to(DEVICE)

    # No gradients are needed for generation.
    # NOTE(review): free=True presumably disables teacher forcing — confirm
    # against the model's forward().
    with torch.no_grad():
        output = compiled_model.model(images, reports, free=True,
                                      max_words=100, max_sentences=100)
    gen_words, gen_stops, gen_scores, gen_topics = output
    
    # The stop outputs are cut with a 0.5 threshold when flattening
    gen_reports = _flatten_gen_reports(gen_words, gen_stops, threshold=0.5)
    gt_reports = _flatten_gt_reports(reports)
    break

In [None]:
%run ../metrics/report_generation/organ_by_sentence.py

In [None]:
# Create the organ-by-sentence metric over the vocabulary and reset it before use
m = OrganBySentence(VOCAB)
m.reset()

In [None]:
# Feed the (generated, ground-truth) reports of the batch and show the result
m.update((gen_reports, gt_reports))
m.compute()

In [None]:
def print_report(r):
    """Print report `r` as text, one sentence per line.

    The report is split with `sentence_iterator` and each sentence is decoded
    with `REPORT_READER.idx_to_text` before printing.
    """
    for sentence in sentence_iterator(r):
        text = REPORT_READER.idx_to_text(sentence)
        print(text)

In [None]:
# Print each generated report followed by its ground truth: a dashed line
# separates the two, a line of '=' closes the pair.
pair_divider = '-'*30
sample_divider = '='*50
for generated, ground_truth in zip(gen_reports, gt_reports):
    print_report(generated)
    print(pair_divider)
    print_report(ground_truth)
    print(sample_divider)