## Imports

In [None]:
# Show which GPUs this kernel is allowed to see.
!echo $CUDA_VISIBLE_DEVICES

In [None]:
import torch
from torch import nn

In [None]:
# `-n` runs the script without setting __name__ to '__main__', presumably so
# its entry point does not fire and only its definitions (e.g. train_model,
# prepare_data_report_generation used below) are loaded.
%run -n ../train_report_generation.py

In [None]:
# Every model/tensor below is moved to this device; switch to the commented
# CPU line to run without a GPU.
DEVICE = torch.device('cuda')
# DEVICE = torch.device('cpu')
DEVICE

## Load stuff

### Load data

In [None]:
# %run ../datasets/iu_xray.py
%run ../datasets/__init__.py
%run ../training/report_generation/flat.py
%run ../training/report_generation/hierarchical.py

In [None]:
# hierarchical = is_decoder_hierarchical(decoder_name)
hierarchical = True
if hierarchical:
    create_dataloader = create_hierarchical_dataloader
else:
    create_dataloader = create_flat_dataloader

In [None]:
BS = 150

dataset_kwargs = {
    'dataset_name': 'iu-x-ray',
    'max_samples': None,
    'batch_size': BS,
    'frontal_only': True,
    'image_size': (256, 256),
    'sort_samples': False,
    'masks': True,
    # 'num_workers': 0,
}

train_dataloader = prepare_data_report_generation(create_dataloader, dataset_type='train',
                                                  **dataset_kwargs)
VOCAB = train_dataloader.dataset.get_vocab()
val_dataloader = prepare_data_report_generation(create_dataloader, dataset_type='val',
                                                vocab=VOCAB,
                                                **dataset_kwargs)
len(train_dataloader.dataset)

In [None]:
VOCAB_SIZE = len(VOCAB)
VOCAB_SIZE

#### Debug hierarchical dataloader

In [None]:
#
### %%debug

for batch in train_dataloader:
    break

In [None]:
batch.masks.min(), batch.masks.max()

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
%run ../utils/nlp.py

In [None]:
report_reader = ReportReader(train_dataloader.dataset.get_vocab())

In [None]:
item_idx = 1
report = batch.reports[item_idx]
mask = batch.masks[item_idx]
report.size(), mask.size()

In [None]:
plt.figure(figsize=(15, 5))
n_sentences = mask.size(0)
n_cols = n_sentences

for i_sentence in range(n_sentences):
    submask = mask[i_sentence]
    
    title = f'Sentence {i_sentence}'
    
    min_value = submask.min().item()
    if min_value == submask.max().item():
        unique_value = min_value
        title += f' ({unique_value})'
    
    plt.subplot(1, n_cols, i_sentence + 1)
    plt.imshow(submask)
    plt.title(title)
    plt.axis('off')
    
    sentence = report_reader.idx_to_text(report[i_sentence])
    print(f'{i_sentence}: {sentence}')
    
plt.show()

### Create CNN2Seq model

In [None]:
# Load the CNN helpers, the CNN2Seq wrapper (presumably defined in
# cnn_to_seq.py) and the checkpoint-loading utilities used below.
%run ../models/classification/__init__.py
%run ../models/report_generation/cnn_to_seq.py
%run ../models/checkpoint/__init__.py

#### Load CNN

In [None]:
# Reuse the CNN of a previously trained classification run as the encoder.
cnn_run_name = '0706_134245_covid-kaggle_tfs-small_lr1e-06'
debug_run = True  # presumably: look the run up among debug runs — TODO confirm

compiled_cnn = load_compiled_model_classification(cnn_run_name,
                                                  debug=debug_run,
                                                  device=DEVICE)
# Keep only the model itself; it is passed to CNN2Seq further down.
cnn = compiled_cnn.model

#### ...or create CNN

In [None]:
# Alternatively build a fresh CNN encoder (other options: resnet-50, densenet-121).
# NOTE(review): labels=[] presumably means no classification targets and
# imagenet=True presumably loads ImageNet weights — confirm in create_cnn.
# freeze=False keeps the encoder weights trainable.
cnn = create_cnn('mobilenet-v2', # resnet-50 # densenet-121
                 labels=[],
                 imagenet=True,
                 freeze=False,
                ).to(DEVICE)

#### Create Flat LSTM decoder

In [None]:
%run ../models/report_generation/decoder_lstm.py

In [None]:
# Flat (single-sequence) LSTM decoder sized to the CNN's feature output.
# NOTE(review): the two 100s are presumably embedding and hidden sizes —
# confirm against LSTMDecoder's signature.
decoder = LSTMDecoder(VOCAB_SIZE, 100, 100, cnn.features_size,
                      teacher_forcing=True).to(DEVICE)

#### ...or with attention

In [None]:
%run ./models/report_generation/decoder_lstm_att.py

In [None]:
decoder_att = LSTMAttDecoder(VOCAB_SIZE, 100, 100, cnn.features_size,
                             teacher_forcing=True).to(DEVICE)

#### ...or hierarchical decoder

In [None]:
%run ../models/report_generation/decoder_h_lstm_att_v2.py

In [None]:
# Hierarchical decoder with attention; this is the decoder wrapped by CNN2Seq
# below and it matches `hierarchical = True` used when building the loaders.
# NOTE(review): the two 100s are presumably embedding and hidden sizes.
decoder_h = HierarchicalLSTMAttDecoderV2(VOCAB_SIZE, 100, 100, cnn.features_size,
                                         teacher_forcing=True, attention=True).to(DEVICE)

#### Full model

In [None]:
model = CNN2Seq(cnn, decoder_h).to(DEVICE)

In [None]:
# model = nn.DataParallel(model)

In [None]:
optimizer = optim.Adam(model.parameters(), lr=0.0001)

compiled_model = CompiledModel(model, optimizer, None, {})

### ...or load a CNN2Seq model

In [None]:
%run models/checkpoint/__init__.py

In [None]:
# run_name = '0714_181427_lstm_lr0.0001_densenet-121'
# run_name = '0716_234501_h-lstm-att_lr0.0001_densenet-121'
# run_name = '0717_015057_h-lstm_lr0.0001_densenet-121'

# run_name = '0717_041434_lstm_lr0.0001_densenet-121'
# run_name = '0716_211601_lstm-att_lr0.0001_densenet-121'
# debug = False

# run_name = '0717_183321_lstm_lr0.0001_densenet-121_size256'
# run_name = '0717_184851_lstm_lr0.0001_densenet-121_size256'
debug = False

run_name = 'supervise-att-2'
debug = True

compiled_model = load_compiled_model_report_generation(run_name, debug=debug, device=DEVICE)

compiled_model.metadata

In [None]:
compiled_model.metadata['vocab']

## Train

### Run training

In [None]:
%run -n ../train_report_generation.py

In [None]:
%%time

trainer, validator = train_model(
    'supervise-att-3',
    compiled_model,
    train_dataloader,
    val_dataloader,
    n_epochs=1,
    supervise_attention=True,
    early_stopping=False,
    hierarchical=True,
    dryrun=False,
    save_model=True,
    debug=True,
    device=DEVICE,
)

In [None]:
validator.state.metrics

In [None]:
trainer.state.metrics

In [None]:
t = trainer._event_handlers[Events.ITERATION_COMPLETED][12]
method, (engine,), d = t

In [None]:
method.__self__.words_seen

### Run evaluation (post-train)

In [None]:
run_name = 'debugging'

In [None]:
dataloaders = [train_dataloader, val_dataloader]

In [None]:
evaluate_and_save(run_name,
                  model,
                  dataloaders,
                  hierarchical=False,
                  free='both',
                  debug=True,
                  device=DEVICE,
                  )

## Test samples

In [None]:
_ = compiled_model.model.eval()

In [None]:
%run ../utils/nlp.py
%run ../utils/__init__.py

In [None]:
import matplotlib.pyplot as plt
from skimage.color import gray2rgb
from skimage.transform import resize

In [None]:
report_reader = ReportReader(VOCAB)

In [None]:
len(train_dataloader.dataset)

In [None]:
idx = 10

In [None]:
item = val_dataloader.dataset[idx]
image = item.image
report = item.report
image.size(), len(report)

In [None]:
reports = torch.tensor(report).unsqueeze(0).to(DEVICE)
reports.size()

In [None]:
report_reader.idx_to_text(report)

In [None]:
images = image.unsqueeze(0).to(DEVICE)
tup = compiled_model.model(images, reports=reports, free=False, max_words=100)
generated = tup[0]
_, generated = generated.max(dim=2)
generated = generated.squeeze(0).cpu()
print(generated.size())

In [None]:
report_reader.idx_to_text(generated)

In [None]:
generated.size()

In [None]:
scores = tup[2]
scores.size()

### Report 1
Sample `idx = 5001` from the train dataset, predicted with the LSTM decoder.

In [None]:
# Generated report (from the free=False forward pass above) vs. ground truth.
report_reader.idx_to_text(generated)

In [None]:
report_reader.idx_to_text(report)

#### Plot x-ray

In [None]:
plt.figure(figsize=(8, 8))

# The image tensor is channel-first (CHW); imshow expects HWC.
# NOTE(review): arr_to_range presumably rescales values into a displayable range.
plt.imshow(arr_to_range(image.permute(1, 2, 0)))
plt.axis('off')

#### Plot with attention

In [None]:
att = tup[1][0].detach().cpu().numpy()
att.shape

In [None]:
plt.imshow(att[-10])

In [None]:
att_idx = 13

In [None]:
len(att)

In [None]:
# Transpose image to plot with imshow
norm_image_CHW = arr_to_range(image.detach().cpu().numpy())
norm_image_HWC = norm_image_CHW.transpose(1, 2, 0)

# Resize activation
height, width = norm_image_HWC.shape[:2]
act = resize(att[att_idx], (height, width))
act = arr_to_range(act, 0, 1)

# Apply pretty colormap
cm = plt.get_cmap('jet')
act = cm(act)

# Add both images
image_plus_act = (norm_image_HWC + act[:, :, :3]) # / 2
image_plus_act = arr_to_range(image_plus_act)

plt.imshow(image_plus_act)

In [None]:
plt.figure(figsize=(8, 8))

plt.imshow(image_plus_act)

### Search reports with a certain pattern

In [None]:
from tqdm.notebook import tqdm
import re

In [None]:
# target = re.compile(r'\A[a-zA-Z]+ size is normal')
target = re.compile('both lungs are clear and expanded')
found = []

for report in train_dataset.reports:
    report = idx_to_text(report['tokens_idxs'])
    if target.search(report):
        found.append(report)

len(found)

In [None]:
found_diff = list(set(found))
len(found_diff)

In [None]:
found_diff[5]

## Debug attention-supervision loss

In [None]:
%run ../losses/out_of_target.py

In [None]:
masks1 = train_dataloader.dataset[30].masks
masks2 = train_dataloader.dataset[100].masks
masks3 = train_dataloader.dataset[200].masks
masks4 = train_dataloader.dataset[1].masks
target = torch.stack([torch.stack([masks1, masks2]), torch.stack([masks3, masks4])])
target.size()

In [None]:
shape = target.size()[:2]
output = torch.rand(*shape, 16, 16)
# output = torch.ones(*target.size())
output = output.view(*shape, -1)
output = torch.softmax(output, dim=-1)
output = output.view(*shape, 16, 16)
output.size()

In [None]:
loss = OutOfTargetSumLoss()
x = loss(output, target)
x.item()

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
# target = (torch.rand(1, 1, 256, 256) > 0.5).long()
target = masks
target.size()

In [None]:
target2 = interpolate(target.float(), size=(16, 16), mode='nearest')
target2.size()

In [None]:
batch_size, n_sentences = target.size()[:2]

n_rows = batch_size
n_cols = n_sentences * 2

plt.figure(figsize=(15, 8))

plot_index = 1
for idx1 in range(batch_size):
    for idx2 in range(n_sentences):
        plt.subplot(n_rows, n_cols, plot_index)
        plt.imshow(target[idx1][idx2])
        plt.title(f'Original - {idx1},{idx2}')
        plot_index += 1
        
        plt.subplot(n_rows, n_cols, plot_index)
        plt.imshow(output[idx1][idx2])
        plt.title(f'Downsampled - {idx1},{idx2}')
        plot_index += 1

In [None]:
masks.min(), masks.max()

In [None]:
target2.min(), target2.max()

In [None]:
from collections import Counter

In [None]:
Counter(x.item() for x in target2.long().view(-1))

In [None]:
target2 = target2.long()