## Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import torch
from torch import nn

In [None]:
%run ../utils/common.py

In [None]:
DEVICE = torch.device('cuda')
# DEVICE = torch.device('cpu')
DEVICE

## CNN

In [None]:
%run ../utils/conv.py

In [None]:
%run classification/transfusion.py
%run classification/resnet.py
%run classification/densenet.py
%run classification/mobilenet.py
%run classification/vgg.py
%run classification/load_imagenet.py
%run classification/tiny_res_scan.py

In [None]:
LABELS = [f'disease{idx}' for idx in range(14)]

In [None]:
def short_name(module):
    """Return a compact stand-in for `module` when listing a model's children.

    BatchNorm2d, MaxPool2d, ReLU and Conv2d instances are returned as-is;
    any other module is replaced by its class.
    """
    shown_in_full = (nn.BatchNorm2d, nn.MaxPool2d, nn.ReLU, nn.Conv2d)
    return module if isinstance(module, shown_in_full) else module.__class__

In [None]:
def model_details(model):
    """List `(name, short_name(child))` for every entry in `model._modules`.

    Only the direct children registered on `model` are visited; nested
    modules are not expanded.
    """
    details = []
    for child_name, child in model._modules.items():
        details.append((child_name, short_name(child)))
    return details

### Num trainable params

In [None]:
def print_trainable_params(cnn):
    """Print the trainable-parameter count of `cnn`, split into features and the rest.

    The feature extractor is `cnn.features` for ImageNetModel,
    TransfusionCBRCNN and TinyResScanCNN instances, and
    `cnn.base_cnn.features` for any other model. Whatever is not part of
    the feature extractor is reported under "FC".

    Counts come from `num_trainable_parameters`, a helper brought into the
    notebook namespace by an earlier %run cell (presumably it counts
    parameters with requires_grad=True -- confirm in that module).
    """
    total = num_trainable_parameters(cnn)
    if isinstance(cnn, (ImageNetModel, TransfusionCBRCNN, TinyResScanCNN)):
        feature_extractor = cnn.features
    else:
        feature_extractor = cnn.base_cnn.features

    feats = num_trainable_parameters(feature_extractor)
    fc = total - feats

    def _percent(part):
        # A model with no trainable parameters has total == 0; report 0.0%
        # instead of raising ZeroDivisionError.
        return part / total * 100 if total else 0.0

    print(f'Total: {total:,}')
    print(f'Feats: {feats:,} ({_percent(feats):.1f}%)')
    print(f'FC: {fc:,} ({_percent(fc):.1f}%)')

In [None]:
cnn = ImageNetModel(model_name='densenet-121', labels=LABELS)
print_trainable_params(cnn)

In [None]:
# Use the notebook-wide LABELS list; `labels` is only assigned in a later cell.
cnn = ImageNetModel(model_name='resnet-50', labels=LABELS)
print_trainable_params(cnn)

In [None]:
# Use the notebook-wide LABELS list; `labels` is only assigned in a later cell.
cnn = ImageNetModel(model_name='mobilenet', labels=LABELS)
print_trainable_params(cnn)

In [None]:
cnn = TransfusionCBRCNN(LABELS, name='small', n_channels=3)
print_trainable_params(cnn)

In [None]:
# Use the notebook-wide LABELS list; `labels` is only assigned in a later cell.
cnn = TransfusionCBRCNN(LABELS, name='tiny', n_channels=3)
print_trainable_params(cnn)

In [None]:
# Use the notebook-wide LABELS list; `labels` is only assigned in a later cell.
cnn = TransfusionCBRCNN(LABELS, name='wide', n_channels=3)
print_trainable_params(cnn)

In [None]:
# Use the notebook-wide LABELS list; `labels` is only assigned in a later cell.
cnn = TransfusionCBRCNN(LABELS, name='tall', n_channels=3)
print_trainable_params(cnn)

### Debug input/output

In [None]:
bs = 4
height = width = 512

images = torch.rand(bs, 3, height, width)
images.size()

In [None]:
features = cnn(images, features=True)
features.size()

In [None]:
output, = cnn(images, features=False)
output.size()

### Debug imagenet models

In [None]:
%run ./classification/load_imagenet.py

In [None]:
# Rebinds DEVICE to a plain string; the Imports section bound it to
# torch.device('cuda'). It is consumed below through `.to(DEVICE)`.
DEVICE = 'cuda'

In [None]:
def test_pass(model, bs=4, height=512, width=512, features=False):
    """Run `model` once on a random 3-channel batch and return its output.

    The batch has shape (bs, 3, height, width), is drawn with `torch.rand`
    and moved to the notebook-level DEVICE before the call. `features` is
    forwarded to the model as a keyword argument.
    """
    batch_shape = (bs, 3, height, width)
    random_batch = torch.rand(*batch_shape)
    random_batch = random_batch.to(DEVICE)
    return model(random_batch, features=features)

In [None]:
model = ImageNetModel(list(range(3)), model_name='resnet-50').to(DEVICE)
# model

In [None]:
out = test_pass(model, height=1024, width=1024)
y, emb = out
y.size(), emb.size()

### Tiny densenet

In [None]:
from torchvision.models import densenet as dn

In [None]:
%run ./classification/tiny_densenet.py

In [None]:
cnn = dn.DenseNet(12, (6, 6, 6, 12), 64, num_classes=14)

In [None]:
f'{num_trainable_parameters(cnn.features):,}'

In [None]:
model_details(cnn.features)

In [None]:
cnn = SmallDenseNetCNN(list(range(14)))

In [None]:
f'{num_trainable_parameters(cnn):,}'

In [None]:
model_details(cnn.features.denseblock4.denselayer12.conv2)

In [None]:
f'{num_trainable_parameters(cnn):,}'

In [None]:
cnn = CustomDenseNetCNN(labels=list(range(14)), growth_rate=12,
                        block_config=(6, 6, 6, 12),
                        num_init_features=16,
                        bn_size=4,
                        drop_rate=0,
                       )
f'{num_trainable_parameters(cnn):,}'

In [None]:
x = torch.rand(7, 3, 512, 512)
x.size()

In [None]:
y = cnn.features(x)
y.size()

In [None]:
out = cnn(x)
out = out[0]
out.size()

### Tiny Resnet

In [None]:
# %run ./segmentation/scan.py
%run ./classification/tiny_res_scan.py
# %run ../utils/conv.py

In [None]:
labels = [f'd{i}' for i in range(14)]

In [None]:
cnn = TinyResScanCNN(labels)

In [None]:
f'{num_trainable_parameters(cnn):,}'

In [None]:
x = torch.randn(7, 1, 500, 500)
x.size()

In [None]:
out = cnn(x)
out = out[0]
out.size()

In [None]:
print_trainable_params(cnn)

### Cls-seg models

In [None]:
%run ./cls_seg/imagenet.py
%run ./cls_seg/scan.py

In [None]:
cl_labels = list(range(14))
seg_labels = list(range(4))
model = ImageNetClsSegModel(cl_labels, seg_labels, model_name='resnet-50')
# model = ScanClsSeg(cl_labels, seg_labels)
# model

In [None]:
x = torch.randn(7, 3, 200, 200)
x.size()

In [None]:
cl, seg = model(x)
cl.size(), seg.size()

#### Calculate necessary padding

In [None]:
# Layer hyper-parameters read by f() below (module-level, looked up at call time).
kernel, stride = 32, 16
dilation, out_padding = 1, 0


def f(in_size, out_size):
    """Return the padding that maps `in_size` to `out_size`.

    Solves
        out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + out_padding + 1
    for `padding` (the same relation documented for the output size of
    torch.nn.ConvTranspose2d). The result is a float; a non-integer value
    means no exact padding exists for that pair of sizes.
    """
    unpadded_size = (in_size - 1) * stride + dilation * (kernel - 1) + out_padding + 1
    return (unpadded_size - out_size) / 2


f(12, 200), f(16, 256), f(32, 512), f(64, 1024)

In [None]:
# Layer hyper-parameters read by f() below (module-level, looked up at call time).
kernel, stride = 4, 2
dilation, out_padding = 1, 0


def f(in_size, out_size):
    """Return the padding that maps `in_size` to `out_size`.

    Solves
        out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + out_padding + 1
    for `padding` (the same relation documented for the output size of
    torch.nn.ConvTranspose2d). The result is a float; a non-integer value
    means no exact padding exists for that pair of sizes.
    """
    unpadded_size = (in_size - 1) * stride + dilation * (kernel - 1) + out_padding + 1
    return (unpadded_size - out_size) / 2


f(6, 12), f(8, 16), f(16, 32), f(7, 14)

## SCAN

In [None]:
%run ./segmentation/scan.py

In [None]:
res = _ResBlock(50, 7)

In [None]:
images = torch.rand(7, 50, 400, 400)
res(images).size()

In [None]:
pres = _ParallelResBlocks(2, 16, 3)

In [None]:
images = torch.rand(7, 16, 100, 100)
pres(images).size()

In [None]:
model = ScanFCN()
# model
total = num_trainable_parameters(model)
print(f'{total:,}')

In [None]:
bs = 7
height = width = 1024

images = torch.rand(bs, 1, height, width)
out = model(images)
out.size()

## Decoder

In [None]:
%run report_generation/decoder_lstm.py

In [None]:
vocab_size = 1000
embedding_size = 200
hidden_size = 100

decoder = LSTMDecoder(vocab_size, embedding_size, hidden_size)
decoder

In [None]:
batch_size = 4
hidden_size = 100

initial_state = torch.rand(batch_size, hidden_size)
initial_state.size()

In [None]:
outputs = decoder(initial_state, 10)
words = outputs[0]
words.size()

## Decoder with attention

### Attention

In [None]:
%run ./report_generation/att_2layer.py

In [None]:
features_size = 512
lstm_size = 100
att = AttentionTwoLayers(features_size, lstm_size)
att

In [None]:
bs = 8
height = width = 16

features = torch.rand(bs, features_size, height, width)
h_state = torch.rand(bs, lstm_size)

In [None]:
feats, scores = att(features, h_state)
feats.size(), scores.size()

In [None]:
scores.sum(dim=-1).sum(dim=-1)

### LSTM-att-v2

In [None]:
%run ./report_generation/decoder_lstm_att_v2.py

In [None]:
vocab_size = 1000
embedding_size = 100
hidden_size = 200
features_size = 512

model = LSTMAttDecoderV2(vocab_size, embedding_size, hidden_size, features_size)

In [None]:
bs = 7
height = width = 16

features = torch.rand(bs, features_size, height, width)
reports = (torch.rand(bs, 20) * (vocab_size - 1)).long()

In [None]:
generated, scores = model(features, reports)
generated.size(), scores.size()

## Debug hierarchical

In [None]:
import torch
from torch import nn
import numpy as np

from medai.utils.nlp import PAD_IDX, START_IDX, END_OF_SENTENCE_IDX

In [None]:
%run report_generation/att_2layer.py

### Debug input/output

In [None]:
%run report_generation/decoder_h_lstm_att.py

In [None]:
batch_size = 2

In [None]:
image_features = torch.rand(batch_size, 1024, 16, 16)

In [None]:
reports_h = torch.tensor([[[1, 2, 3, 0],
                           [1, 5, 0, 0],
                           [2, 2, 2, 0],
                          ],
                          [[7, 9, 10, 0],
                           [1, 4, 0, 0],
                           [8, 9, 0, 0],
                          ],
                         ])
reports_h.size()

In [None]:
model = HierarchicalLSTMAttDecoder(200, 100, 100, (1024, 16, 16))

In [None]:
generated, stops, scores = model(image_features, reports_h)
generated.size(), stops.size(), scores.size()

### Debug flat and hierarchical padding

In [None]:
%run ../training/report_generation/hierarchical.py
%run ../training/report_generation/flat.py

In [None]:
%run ../datasets/iu_xray.py

In [None]:
dataset = IUXRayDataset(max_samples=20)
dataset.size()

In [None]:
dataloader = create_hierarchical_dataloader(dataset, batch_size=4)

In [None]:
for batch in dataloader:
    images, reports, stops = batch
    break

In [None]:
reports, stops

In [None]:
from torch import nn

In [None]:
s = torch.tensor([[0, 0, 1],
                  [0, 0, 1],
                  [0, 0, 0],
                  [0, 0, 0],
                 ])

In [None]:
stops

In [None]:
loss = nn.BCELoss()

In [None]:
# nn.BCELoss works on floating-point input and target; `s` is built from
# Python ints (an integer tensor), so both operands are cast to float.
loss(s.float(), stops.float())

In [None]:
a = torch.rand(10, 5, 1)
a.squeeze(-1).size()

In [None]:
flat_dataloader = create_flat_dataloader(dataset, batch_size=4)

In [None]:
for batch in flat_dataloader:
    _, flat_reports = batch
    break

In [None]:
flat_reports.size()

In [None]:
flat_reports

In [None]:
reports

## Dummy baselines

### Load data

In [None]:
%run ../datasets/iu_xray.py

In [None]:
# Constructor arguments shared by the train/val/test IUXRayDataset splits.
dataset_kwargs = {
    'max_samples': 100,
    'frontal_only': False,
    'image_size': (512, 512),
}

train_dataset = IUXRayDataset(dataset_type='train', **dataset_kwargs)
# Added only after the train split exists: val and test are built with the
# vocab obtained from the train dataset.
dataset_kwargs['vocab'] = train_dataset.get_vocab()
val_dataset = IUXRayDataset(dataset_type='val', **dataset_kwargs)
test_dataset = IUXRayDataset(dataset_type='test', **dataset_kwargs)
len(train_dataset), len(val_dataset), len(test_dataset)

In [None]:
VOCAB = train_dataset.get_vocab()
vocab_size = len(VOCAB)
vocab_size

### Random

In [None]:
%run ./report_generation/dummy/random.py

In [None]:
model = RandomReport(train_dataset)
model

In [None]:
bs = 2
features = torch.rand(bs, 256, 16, 16)
reports = (torch.rand(bs, 20) * vocab_size).long()

In [None]:
vocab_size, reports.max().item()

In [None]:
r, = model(features, None, free=True)
r.size()

### MostSimilarImage

In [None]:
from tqdm.notebook import tqdm

In [None]:
%run ../utils/nlp.py
%run ../training/report_generation/flat.py

In [None]:
reader = ReportReader(VOCAB)

In [None]:
cnn = cnn.to(DEVICE)

In [None]:
bs = 1

train_dataloader = create_flat_dataloader(train_dataset, batch_size=bs)
val_dataloader = create_flat_dataloader(val_dataset, batch_size=bs)
test_dataloader = create_flat_dataloader(test_dataset, batch_size=bs)

In [None]:
%run ./report_generation/dummy/most_similar_image.py

In [None]:
model = MostSimilarImage(cnn, VOCAB).to(DEVICE)
model.fit(train_dataloader, device=DEVICE)

#### Test with random example

In [None]:
bs_2 = 1

images = torch.rand(bs_2, 3, 256, 256).to(DEVICE)
# Random word indices in [0, vocab_size). torch.randn is zero-mean, so scaling
# it by vocab_size produced negative and out-of-range indices.
reports = torch.randint(0, vocab_size, (bs_2, 4)).to(DEVICE)

In [None]:
out = model(images, reports, free=False)
out = out[0]
out.size()

#### Test with real sample

In [None]:
model.train(False)
torch.set_grad_enabled(False)

In [None]:
dataloader = train_dataloader

In [None]:
# Run the model on every batch and print the filenames whose generated report
# text differs from the ground-truth report text.
# The loader is handed to tqdm directly rather than through iter(), so tqdm can
# use the loader's length (when it defines one) for the progress-bar total.
for batch in tqdm(dataloader):
    images = batch.images.to(DEVICE)
    reports = batch.reports.to(DEVICE)
    filenames = batch.filenames

    output, _ = model(images, reports, free=True)
    # Keep the arg-max index along the last axis (presumably the vocabulary
    # axis -- confirm against MostSimilarImage's output).
    _, output = output.max(dim=-1)

    for report, gen, filename in zip(reports, output, filenames):
        report = reader.idx_to_text(report)
        gen = reader.idx_to_text(gen)

        if report != gen:
            print(filename)

In [None]:
output, dist = model(images, reports, free=True)
_, output = output.max(dim=-1)
output.size(), dist.size()

In [None]:
reader.idx_to_text(reports)

In [None]:
reader.idx_to_text(output)

### Common sentences

In [None]:
%run ../training/report_generation/hierarchical.py

In [None]:
%run ./report_generation/dummy/common_sentences.py

In [None]:
model = MostCommonSentences(train_dataset)

In [None]:
model.n_sentences, model.n_weights

In [None]:
bs = 10
images = torch.rand(bs, 3, 16, 16)
# Random word indices in [0, vocab_size). torch.randn is zero-mean, so scaling
# it by vocab_size produced negative and out-of-range indices.
reports = torch.randint(0, vocab_size, (bs, 8, 6))

In [None]:
a, b = model(images, reports, free=False)
a.size()

### Common words

In [None]:
%run ./report_generation/dummy/common_words.py

In [None]:
model = MostCommonWords(train_dataset)

In [None]:
bs = 10
images = torch.rand(bs, 3, 16, 16)
# Random word indices in [0, vocab_size). torch.randn is zero-mean, so scaling
# it by vocab_size produced negative and out-of-range indices.
reports = torch.randint(0, vocab_size, (bs, 7))

In [None]:
a, = model(images, reports, free=True)
a.size()