Debug training processes

## Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch
from torch import nn

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
%run -n ../train_report_generation.py

In [None]:
# Use the GPU when CUDA is available to this kernel; otherwise fall back to
# the CPU so the notebook still runs on machines without CUDA.
# To force the CPU, assign torch.device('cpu') instead.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

## Load data

In [None]:
%run ../datasets/__init__.py
%run ../utils/nlp.py

In [None]:
BS = 20

# Options shared by the train and val loaders.
dataset_kwargs = dict(
    dataset_name='iu-x-ray',
    hierarchical=True,
    max_samples=None,
    batch_size=BS,
    frontal_only=True,
    image_size=(256, 256),
    sort_samples=False,
    masks=True,
    # num_workers=0,
)

# The vocabulary is taken from the training dataset and handed to the
# validation loader.
train_dataloader = prepare_data_report_generation(
    dataset_type='train',
    **dataset_kwargs,
)
VOCAB = train_dataloader.dataset.get_vocab()
val_dataloader = prepare_data_report_generation(
    dataset_type='val',
    vocab=VOCAB,
    **dataset_kwargs,
)

# Number of training samples.
len(train_dataloader.dataset)

In [None]:
# Number of entries in the vocabulary obtained from the training dataset.
VOCAB_SIZE = len(VOCAB)
VOCAB_SIZE

In [None]:
# Reader built from the training vocabulary; used at the end of the notebook
# (via idx_to_text) to turn word indices back into text.
REPORT_READER = ReportReader(VOCAB)
REPORT_READER

### Debug hierarchical dataloader

In [None]:
# Grab a single training batch to inspect. next() raises StopIteration on an
# empty loader instead of silently leaving `batch` undefined or stale.
batch = next(iter(train_dataloader))

In [None]:
# Value range of the masks in the sampled batch.
batch.masks.min(), batch.masks.max()

In [None]:
# Reader over the training-set vocabulary, used by the plotting cell below.
# NOTE(review): REPORT_READER is built from VOCAB, which comes from the same
# get_vocab() call -- presumably the two readers are equivalent; confirm.
report_reader = ReportReader(train_dataloader.dataset.get_vocab())

In [None]:
# Pick one sample of the batch and look at the shapes of its report and mask.
item_idx = 1
report = batch.reports[item_idx]
mask = batch.masks[item_idx]
report.size(), mask.size()

In [None]:
# One panel per entry along the mask's first dimension, laid out in a single
# row; the matching sentence text is printed alongside.
n_sentences = mask.size(0)
n_cols = n_sentences

plt.figure(figsize=(15, 5))

for i_sentence, submask in enumerate(mask):
    lowest = submask.min().item()
    highest = submask.max().item()

    # A constant mask gets its single value appended to the title.
    title = f'Sentence {i_sentence}'
    if lowest == highest:
        title += f' ({lowest})'

    plt.subplot(1, n_cols, i_sentence + 1)
    plt.imshow(submask)
    plt.title(title)
    plt.axis('off')

    sentence = report_reader.idx_to_text(report[i_sentence])
    print(f'{i_sentence}: {sentence}')

plt.show()

## Create model

In [None]:
%run ../models/classification/__init__.py
%run ../models/report_generation/cnn_to_seq.py
%run ../models/checkpoint/__init__.py

### Create CNN

#### Load CNN

In [None]:
# Load a previously saved classification run and reuse its model as the CNN.
# NOTE(review): the "New CNN" cell below rebinds `cnn`; presumably only one of
# the two cells is meant to be executed -- confirm.
cnn_run_name = '0706_134245_covid-kaggle_tfs-small_lr1e-06'
debug_run = True

compiled_cnn = load_compiled_model_classification(cnn_run_name,
                                                  debug=debug_run,
                                                  device=DEVICE)
cnn = compiled_cnn.model

#### New CNN

In [None]:
# Fresh backbone created from scratch and moved to DEVICE.
# Other names listed here: resnet-50, densenet-121
cnn = create_cnn(
    'mobilenet-v2',
    labels=[],
    imagenet=True,
    freeze=False,
)
cnn = cnn.to(DEVICE)

### Create Decoder

In [None]:
# Decoder configuration; features_size is read from the CNN created above.
decoder_kwargs = dict(
    decoder_name='lstm-v2',
    vocab=VOCAB,
    embedding_size=100,
    embedding_kwargs={'pretrained': 'radglove'},
    hidden_size=100,
    features_size=cnn.features_size,
    teacher_forcing=True,
    dropout_recursive=0,
    dropout_out=0,
    double_bias=False,
)

decoder = create_decoder(**decoder_kwargs)
decoder = decoder.to(DEVICE)

### CNN2Seq

In [None]:
%run ../losses/optimizers.py

In [None]:
# Combine the CNN and the decoder into one CNN2Seq model and move it to DEVICE.
model = CNN2Seq(cnn, decoder).to(DEVICE)

In [None]:
# Optimizer for the full model: lr=0.0001, plus a custom lr of 0.05 passed
# under the 'word_embedding' key.
# NOTE(review): presumably create_optimizer matches that key against
# parameter names -- confirm in losses/optimizers.py.
optimizer = create_optimizer(model, custom_lr={ 'word_embedding': 0.05 }, lr=0.0001)
optimizer

### Load CNN2Seq model

In [None]:
%run models/checkpoint/__init__.py

In [None]:
# Name of the saved run to load. It has to be assigned in this cell:
# RunId(run_name, ...) raises NameError when `run_name` is not defined.
# NOTE(review): confirm this is the run you want -- the cells further down
# call the loaded model with max_sentences and unpack four outputs.
run_name = '0717_184851_lstm_lr0.0001_densenet-121_size256'
debug = False

# 'rg' presumably selects the report-generation task -- confirm against RunId.
run_id = RunId(run_name, debug, 'rg')

compiled_model = load_compiled_model_report_generation(run_id, device=DEVICE)

compiled_model.metadata

In [None]:
# Vocabulary entry stored in the loaded run's metadata.
compiled_model.metadata['vocab']

## Debug attention-supervision loss

In [None]:
%run ../losses/out_of_target.py

In [None]:
# Masks of four fixed training samples, arranged as a 2 x 2 grid of mask
# tensors (stacked pairwise, then the two pairs stacked again).
masks1, masks2, masks3, masks4 = (
    train_dataloader.dataset[sample_idx].masks
    for sample_idx in (30, 100, 200, 1)
)

first_pair = torch.stack([masks1, masks2])
second_pair = torch.stack([masks3, masks4])
target = torch.stack([first_pair, second_pair])
target.size()

In [None]:
# Random 16x16 maps, one per leading (dim 0, dim 1) position of the target,
# normalised with a softmax over the 16*16 cells so that each map sums to 1.
shape = target.size()[:2]
grid = (16, 16)

raw_scores = torch.rand(*shape, *grid)
# raw_scores = torch.ones(*target.size())
flat_probs = torch.softmax(raw_scores.flatten(start_dim=-2), dim=-1)
output = flat_probs.view(*shape, *grid)
output.size()

In [None]:
# Evaluate OutOfTargetSumLoss on the random maps against the stacked masks
# and show the resulting scalar.
loss = OutOfTargetSumLoss()
x = loss(output, target)
x.item()

In [None]:
# target = (torch.rand(1, 1, 256, 256) > 0.5).long()
# NOTE(review): `masks` is not assigned anywhere in the visible cells (only
# masks1..masks4 and batch.masks are) -- confirm which tensor was intended.
target = masks
target.size()

In [None]:
# Downsample the target to 16x16 using nearest-neighbour interpolation.
# NOTE(review): `interpolate` is not imported in the visible cells; presumably
# torch.nn.functional.interpolate brought into scope by a %run -- confirm.
target2 = interpolate(target.float(), size=(16, 16), mode='nearest')
target2.size()

In [None]:
# Side-by-side grid: for every (idx1, idx2) position of the target, show the
# original mask next to its downsampled version (target2).
batch_size, n_sentences = target.size()[:2]

n_rows = batch_size
n_cols = n_sentences * 2  # two panels per mask

plt.figure(figsize=(15, 8))

plot_index = 1
for idx1 in range(batch_size):
    for idx2 in range(n_sentences):
        plt.subplot(n_rows, n_cols, plot_index)
        plt.imshow(target[idx1][idx2])
        plt.title(f'Original - {idx1},{idx2}')
        plot_index += 1

        plt.subplot(n_rows, n_cols, plot_index)
        # Display the downsampled target so the panel matches its title;
        # `output` holds the random softmax maps, not a downsampled mask.
        plt.imshow(target2[idx1][idx2])
        plt.title(f'Downsampled - {idx1},{idx2}')
        plot_index += 1

plt.show()

## Debug flattening of hierarchical reports

In [None]:
%run ../training/report_generation/hierarchical.py

In [None]:
for batch in tqdm(val_dataloader):
    images, reports = batch.images.to(DEVICE), batch.reports.to(DEVICE)

    # Forward pass with gradient tracking disabled.
    with torch.no_grad():
        output = compiled_model.model(
            images,
            reports,
            free=True,
            max_words=100,
            max_sentences=100,
        )
    gen_words, gen_stops, gen_scores, gen_topics = output

    # Only the first validation batch is needed.
    break

In [None]:
# Take the first item produced by _flatten_gen_reports. next() raises
# StopIteration if nothing is produced instead of leaving `report` stale.
report = next(iter(_flatten_gen_reports(gen_words, gen_stops)))

In [None]:
# Collapse the last dimension of gen_words to the index of its maximum.
# Rebinding gen_words makes this cell non-idempotent.
gen_words = torch.argmax(gen_words, dim=-1)

# 1 where the stop value is below 0.5, 0 otherwise, in gen_stops' own dtype.
valid_sentences = (gen_stops < 0.5).to(gen_stops.dtype)

gen_words.size(), valid_sentences.size()

In [None]:
# Column holding the end-of-sentence index, expanded (without copying) to
# gen_words' first two dimensions with a trailing size of 1.
# END_OF_SENTENCE_IDX is the same constant the later cell searches for in the
# report, so the appended token and the lookup cannot drift apart.
end_of_sentence_tensor = torch.tensor([END_OF_SENTENCE_IDX], device=gen_words.device)
end_of_sentence_tensor = end_of_sentence_tensor.expand(gen_words.size(0),
                                                       gen_words.size(1),
                                                       -1)
end_of_sentence_tensor.size()

In [None]:
# Append the end-of-sentence column as one extra position on the last dim.
# gen_words is rebound, so every re-run of this cell appends another column.
gen_words = torch.cat((gen_words, end_of_sentence_tensor), dim=-1)
gen_words.size()

In [None]:
# Take the first generated report together with its validity flags and keep
# only the rows whose flag is non-zero. next() raises StopIteration on an
# empty batch instead of leaving `report` stale.
report, valid = next(zip(gen_words, valid_sentences))
report = report[valid.nonzero(as_tuple=True)]
report.size()

In [None]:
# 1 where the report holds END_OF_SENTENCE_IDX, 0 elsewhere (as uint8).
dot_positions = (report == END_OF_SENTENCE_IDX).type(torch.uint8)
print(dot_positions.size())
dot_positions

In [None]:
# The cumsum is applied twice on purpose. Starting from the 0/1 indicator:
#   1st cumsum -> number of end-of-sentence tokens seen so far in the row;
#   2nd cumsum -> running sum of those counts: 0 before the first token,
#                 exactly 1 at the first token, and >= 2 at every later position.
# So `dot_positions <= 1` keeps each row up to and including its first
# end-of-sentence token. This cell is not idempotent: run it only once.
dot_positions = torch.cumsum(dot_positions, dim=1)
dot_positions = torch.cumsum(dot_positions, dim=1)
dot_positions

In [None]:
# Boolean-mask selection: keeps the entries at or before the first
# end-of-sentence token of each row and returns them as a flat tensor.
report[dot_positions <= 1]

In [None]:
# Text of the trimmed (flattened) report.
REPORT_READER.idx_to_text(report[dot_positions <= 1])

In [None]:
# Same conversion applied to the untrimmed report, for comparison.
REPORT_READER.idx_to_text(report)