## Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=

In [None]:
import torch
from torch import nn

In [None]:
%run ../utils/common.py

In [None]:
# DEVICE = torch.device('cuda')
DEVICE = torch.device('cpu')
DEVICE

## Utils

In [None]:
def short_name(module):
    """Return ``module`` itself for selected layer types, otherwise its class.

    Instances of ``nn.Conv2d``, ``nn.BatchNorm2d``, ``nn.ReLU`` and
    ``nn.MaxPool2d`` are returned unchanged; every other object is replaced
    by its class.
    """
    kept_as_instance = (nn.Conv2d, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d)
    return module if isinstance(module, kept_as_instance) else module.__class__

In [None]:
def model_details(model):
    """List ``(name, short_name(child))`` pairs for the direct children of ``model``.

    Children are read from ``model._modules`` in its iteration order.
    """
    return [
        (child_name, short_name(child))
        for child_name, child in model._modules.items()
    ]

## CLS

In [None]:
%run ../utils/conv.py

In [None]:
%run classification/transfusion.py
%run classification/resnet.py
%run classification/densenet.py
%run classification/mobilenet.py
%run classification/vgg.py
%run classification/load_imagenet.py
%run classification/tiny_res_scan.py
%run classification/tiny_densenet.py

In [None]:
# Placeholder class names: 'disease0' ... 'disease13' (14 entries).
LABELS = ['disease' + str(idx) for idx in range(14)]

### Num trainable params

In [None]:
def print_trainable_params(cnn):
    """Print the trainable-parameter count of ``cnn``, split into two parts.

    The feature extractor is looked up as ``cnn.features`` or, failing that,
    as ``cnn.base_cnn.features``. Its count is reported as ``Feats``; the
    remainder of the model's count is reported as ``FC``.

    Raises:
        Exception: if ``cnn`` has neither a ``features`` nor a ``base_cnn``
            attribute.
    """
    total = num_trainable_parameters(cnn)
    if hasattr(cnn, 'features'):
        feature_extractor = cnn.features
    elif hasattr(cnn, 'base_cnn'):
        feature_extractor = cnn.base_cnn.features
    else:
        raise Exception('Cannot get features attr')

    feats = num_trainable_parameters(feature_extractor)
    fc = total - feats

    def percent(part):
        # A model with no trainable parameters (e.g. fully frozen) has
        # total == 0; report 0% instead of raising ZeroDivisionError.
        return part / total * 100 if total else 0.0

    print(f'Total: {total:,}')
    print(f'Feats: {feats:,} ({percent(feats):.1f}%)')
    print(f'FC: {fc:,} ({percent(fc):.1f}%)')

In [None]:
cnn = ImageNetModel(model_name='densenet-121', labels=LABELS)
print_trainable_params(cnn)

In [None]:
cnn = ImageNetModel(model_name='resnet-50', labels=LABELS)
print_trainable_params(cnn)

In [None]:
cnn = ImageNetModel(model_name='mobilenet', labels=LABELS)
print_trainable_params(cnn)

In [None]:
cnn = TransfusionCBRCNN(LABELS, name='small', n_channels=3)
print_trainable_params(cnn)

In [None]:
# Use the LABELS constant: lowercase `labels` is only assigned in a later
# cell, so referencing it here raises NameError on a top-to-bottom run.
cnn = TransfusionCBRCNN(LABELS, name='tiny', n_channels=3)
print_trainable_params(cnn)

In [None]:
# Use the LABELS constant: lowercase `labels` is only assigned in a later
# cell, so referencing it here raises NameError on a top-to-bottom run.
cnn = TransfusionCBRCNN(LABELS, name='wide', n_channels=3)
print_trainable_params(cnn)

In [None]:
# Use the LABELS constant: lowercase `labels` is only assigned in a later
# cell, so referencing it here raises NameError on a top-to-bottom run.
cnn = TransfusionCBRCNN(LABELS, name='tall', n_channels=3)
print_trainable_params(cnn)

In [None]:
cnn = SmallDenseNetCNN(list(range(14)))
print_trainable_params(cnn)

In [None]:
cnn = TinyDenseNetCNN(list(range(14)))
print_trainable_params(cnn)

### Debug input/output

In [None]:
# Random batch of four 3-channel 512x512 inputs for the debug cells below.
bs, height, width = 4, 512, 512

images = torch.rand((bs, 3, height, width))
images.size()

In [None]:
features = cnn(images, features=True)
features.size()

In [None]:
output, = cnn(images, features=False)
output.size()

### Debug imagenet models

In [None]:
%run ./classification/load_imagenet.py

In [None]:
# Prefer the GPU, but fall back to the CPU when CUDA is unavailable (the
# first cell of this notebook sets CUDA_VISIBLE_DEVICES to an empty value).
# A torch.device is used, matching the earlier DEVICE definition.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def test_pass(model, bs=4, height=512, width=512, features=False):
    """Call ``model`` on a random batch and return whatever it returns.

    The batch has shape ``(bs, 3, height, width)``, is drawn with
    ``torch.rand`` and is moved to the module-level ``DEVICE``. ``features``
    is forwarded to the model call as a keyword argument.
    """
    batch_shape = (bs, 3, height, width)
    random_batch = torch.rand(*batch_shape).to(DEVICE)
    return model(random_batch, features=features)

In [None]:
model = ImageNetModel(list(range(3)), model_name='densenet-121').to(DEVICE)
model

In [None]:
out = test_pass(model, height=1024, width=1024)
y, emb = out
y.size(), emb.size()

### Tiny densenet

In [None]:
from torchvision.models import densenet as dn

In [None]:
%run ./classification/tiny_densenet.py

In [None]:
cnn = dn.DenseNet(12, (6, 6, 6, 12), 64, num_classes=14)

In [None]:
f'{num_trainable_parameters(cnn.features):,}'

In [None]:
model_details(cnn.features)

In [None]:
f'{num_trainable_parameters(cnn):,}'

In [None]:
model_details(cnn.features.denseblock4.denselayer12.conv2)

In [None]:
f'{num_trainable_parameters(cnn):,}'

In [None]:
cnn = CustomDenseNetCNN(labels=list(range(14)), growth_rate=12,
                        block_config=(6, 6, 6, 12),
                        num_init_features=16,
                        bn_size=4,
                        drop_rate=0,
                       )
f'{num_trainable_parameters(cnn):,}'

In [None]:
x = torch.rand(7, 3, 512, 512)
x.size()

In [None]:
y = cnn.features(x)
y.size()

In [None]:
out = cnn(x)
out = out[0]
out.size()

### Tiny Resnet

In [None]:
# %run ./segmentation/scan.py
%run ./classification/tiny_res_scan.py
# %run ../utils/conv.py

In [None]:
# Short placeholder class names: 'd0' ... 'd13' (14 entries).
labels = list(map('d{}'.format, range(14)))

In [None]:
cnn = TinyResScanCNN(labels)

In [None]:
f'{num_trainable_parameters(cnn):,}'

In [None]:
x = torch.randn(7, 1, 500, 500)
x.size()

In [None]:
out = cnn(x)
out = out[0]
out.size()

In [None]:
print_trainable_params(cnn)

## Cls-spatial models

In [None]:
%run ./cls_spatial/imagenet_cls_spatial.py

In [None]:
model = ImageNetClsSpatialModel(list(range(14)))
# model

In [None]:
images = torch.randn(7, 3, 256, 256)
images.size()

In [None]:
cl_out, cl_spatial_out = model(images)
cl_out.size(), cl_spatial_out.size()

In [None]:
num_trainable_parameters(model.spatial_classifier), model.spatial_classifier

In [None]:
l = nn.Linear(1024, 14)
num_trainable_parameters(l), l

## Cls-seg models

In [None]:
%run ./cls_seg/imagenet.py
%run ./cls_seg/scan.py
%run ./cls_seg/tiny_densenet.py

In [None]:
cl_labels = list(range(12))
seg_labels = list(range(6))

In [None]:
model = ImageNetClsSegModel(
    cl_labels, seg_labels, model_name='densenet-121', dropout_features=0.5)
# model = ScanClsSeg(cl_labels, seg_labels, dropout_features=0.5)
# model

In [None]:
model = TinyDenseNetCNN(cl_labels, seg_labels)
num_trainable_parameters(model)

In [None]:
model = SmallDenseNetCNN(cl_labels, seg_labels)
num_trainable_parameters(model)

In [None]:
model

In [None]:
x = torch.randn(7, 3, 512, 512)
x.size()

In [None]:
cl, seg = model(x)
cl.size(), seg.size()

### Calculate necessary padding

In [None]:
# Layer hyper-parameters read (as globals) by f() below.
kernel, stride = 32, 16
dilation, out_padding = 1, 0

def f(in_size, out_size):
    """Return the padding that maps ``in_size`` to ``out_size``.

    Solves
    ``out_size = (in_size - 1) * stride - 2 * padding
                 + dilation * (kernel - 1) + out_padding + 1``
    for ``padding``. This has the same form as the output-size formula
    documented for ``torch.nn.ConvTranspose2d``. The result is a float; a
    non-integer value means no exact padding exists for these sizes.
    """
    unpadded_size = (in_size - 1) * stride + dilation * (kernel - 1) + out_padding + 1
    return (unpadded_size - out_size) / 2

f(12, 200), f(16, 256), f(32, 512), f(64, 1024)

In [None]:
# Layer hyper-parameters read (as globals) by f() below.
kernel, stride = 4, 2
dilation, out_padding = 1, 0

def f(in_size, out_size):
    """Return the padding that maps ``in_size`` to ``out_size``.

    Solves
    ``out_size = (in_size - 1) * stride - 2 * padding
                 + dilation * (kernel - 1) + out_padding + 1``
    for ``padding``. This has the same form as the output-size formula
    documented for ``torch.nn.ConvTranspose2d``. The result is a float; a
    non-integer value means no exact padding exists for these sizes.
    """
    unpadded_size = (in_size - 1) * stride + dilation * (kernel - 1) + out_padding + 1
    return (unpadded_size - out_size) / 2

f(6, 12), f(8, 16), f(16, 32), f(7, 14)

## SCAN

In [None]:
%run ./segmentation/scan.py

In [None]:
res = _ResBlock(50, 7)

In [None]:
images = torch.rand(7, 50, 400, 400)
res(images).size()

In [None]:
pres = _ParallelResBlocks(2, 16, 3)

In [None]:
images = torch.rand(7, 16, 100, 100)
pres(images).size()

In [None]:
model = ScanFCN()
# model
total = num_trainable_parameters(model)
print(f'{total:,}')

In [None]:
bs = 7
height = width = 1024

images = torch.rand(bs, 1, height, width)
out = model(images)
out.size()

## Dummy baselines

### Load data

In [None]:
%run ../datasets/iu_xray.py

In [None]:
# Keyword arguments shared by the train/val/test dataset constructors.
dataset_kwargs = dict(
    max_samples=100,
    frontal_only=False,
    image_size=(512, 512),
)

train_dataset = IUXRayDataset(dataset_type='train', **dataset_kwargs)
# The val and test splits are built with the vocabulary returned by the
# training split.
dataset_kwargs.update(vocab=train_dataset.get_vocab())
val_dataset, test_dataset = (
    IUXRayDataset(dataset_type=split, **dataset_kwargs)
    for split in ('val', 'test')
)
len(train_dataset), len(val_dataset), len(test_dataset)

In [None]:
VOCAB = train_dataset.get_vocab()
vocab_size = len(VOCAB)
vocab_size

### Random

In [None]:
%run ./report_generation/dummy/random.py

In [None]:
model = RandomReport(train_dataset)
model

In [None]:
bs = 2
features = torch.rand(bs, 256, 16, 16)
reports = (torch.rand(bs, 20) * vocab_size).long()

In [None]:
vocab_size, reports.max().item()

In [None]:
r, = model(features, None, free=True)
r.size()

### MostSimilarImage

In [None]:
from tqdm.notebook import tqdm

In [None]:
%run ../utils/nlp.py
%run ../training/report_generation/flat.py

In [None]:
reader = ReportReader(VOCAB)

In [None]:
cnn = cnn.to(DEVICE)

In [None]:
bs = 1

train_dataloader = create_flat_dataloader(train_dataset, batch_size=bs)
val_dataloader = create_flat_dataloader(val_dataset, batch_size=bs)
test_dataloader = create_flat_dataloader(test_dataset, batch_size=bs)

In [None]:
%run ./report_generation/dummy/most_similar_image.py

In [None]:
model = MostSimilarImage(cnn, VOCAB).to(DEVICE)
model.fit(train_dataloader, device=DEVICE)

#### Test with random example

In [None]:
bs_2 = 1

images = torch.rand(bs_2, 3, 256, 256).to(DEVICE)
# Dummy report token ids drawn uniformly from [0, vocab_size). The previous
# `torch.randn(...) * vocab_size` produced negative and out-of-range ids.
reports = torch.randint(0, vocab_size, (bs_2, 4)).to(DEVICE)

In [None]:
out = model(images, reports, free=False)
out = out[0]
out.size()

#### Test with real sample

In [None]:
model.train(False)
torch.set_grad_enabled(False)

In [None]:
dataloader = train_dataloader

In [None]:
# Generate a report for every sample and print the filenames whose generated
# text differs from the reference text.
# The dataloader is passed to tqdm directly (not wrapped in iter()) so that
# tqdm can read its length, if it defines one, and show a total.
for batch in tqdm(dataloader):
    images = batch.images.to(DEVICE)
    reports = batch.reports.to(DEVICE)
    filenames = batch.filenames

    output, _ = model(images, reports, free=True)
    # Keep only the index of the maximum along the last dimension.
    _, output = output.max(dim=-1)

    for report, gen, filename in zip(reports, output, filenames):
        report = reader.idx_to_text(report)
        gen = reader.idx_to_text(gen)

        if report != gen:
            print(filename)