## Imports

In [1]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [2]:
import torch
from torch import nn

In [3]:
%run ../utils/common.py

In [4]:
DEVICE = torch.device('cuda')
# DEVICE = torch.device('cpu')
DEVICE

device(type='cuda')

## CNN

In [5]:
%run ../utils/conv.py

In [6]:
%run classification/transfusion.py
%run classification/resnet.py
%run classification/densenet.py
%run classification/mobilenet.py
%run classification/vgg.py
%run classification/load_imagenet.py
%run classification/tiny_res_scan.py

In [7]:
LABELS = [f'disease{idx}' for idx in range(14)]

### Num trainable params

In [8]:
def print_trainable_params(cnn):
    total = num_trainable_parameters(cnn)
    if isinstance(cnn, (ImageNetModel, TransfusionCBRCNN, TinyResScanCNN)):
        feats = cnn.features
    else:
        feats = cnn.base_cnn.features

    feats = num_trainable_parameters(feats)
    fc = total - feats

    print(f'Total: {total:,}')
    print(f'Feats: {feats:,} ({feats / total * 100:.1f}%)')
    print(f'FC: {fc:,} ({fc / total * 100:.1f}%)')

In [9]:
cnn = ImageNetModel(model_name='densenet-121', labels=LABELS)
print_trainable_params(cnn)

NameError: name 'labels' is not defined

In [41]:
cnn = ImageNetModel(model_name='resnet-50', labels=labels)
print_trainable_params(cnn)

Total: 23,536,718
Feats: 23,508,032 (99.9%)
FC: 28,686 (0.1%)


In [48]:
cnn = ImageNetModel(model_name='mobilenet', labels=labels)
print_trainable_params(cnn)

Total: 2,241,806
Feats: 2,223,872 (99.2%)
FC: 17,934 (0.8%)


In [9]:
cnn = TransfusionCBRCNN(LABELS, name='small', n_channels=3)
print_trainable_params(cnn)

Total: 2,117,134
Feats: 2,113,536 (99.8%)
FC: 3,598 (0.2%)


In [50]:
cnn = TransfusionCBRCNN(labels, name='tiny', n_channels=3)
print_trainable_params(cnn)

Total: 4,315,662
Feats: 4,308,480 (99.8%)
FC: 7,182 (0.2%)


In [51]:
cnn = TransfusionCBRCNN(labels, name='wide', n_channels=3)
print_trainable_params(cnn)

Total: 8,449,038
Feats: 8,441,856 (99.9%)
FC: 7,182 (0.1%)


In [52]:
cnn = TransfusionCBRCNN(labels, name='tall', n_channels=3)
print_trainable_params(cnn)

Total: 8,544,782
Feats: 8,537,600 (99.9%)
FC: 7,182 (0.1%)


### Debug input/output

In [53]:
bs = 4
height = width = 512

images = torch.rand(bs, 3, height, width)
images.size()

torch.Size([4, 3, 512, 512])

In [54]:
features = cnn(images, features=True)
features.size()

torch.Size([4, 512, 19, 19])

In [55]:
output, = cnn(images, features=False)
output.size()

torch.Size([4, 14])

### Debug imagenet models

In [5]:
%run ./classification/load_imagenet.py

In [15]:
DEVICE = 'cuda'

In [16]:
def test_pass(model, bs=4, height=512, width=512, features=False):
    images = torch.rand(bs, 3, height, width).to(DEVICE)
    
    return model(images, features=features)

In [17]:
model = ImageNetModel(list(range(3)), model_name='resnet-50').to(DEVICE)
# model

In [19]:
out = test_pass(model, height=1024, width=1024)
y, emb = out
y.size(), emb.size()

(torch.Size([4, 3]), torch.Size([4, 2048]))

### Tiny densenet

In [5]:
from torchvision.models import densenet as dn

In [14]:
cnn = dn.densenet121(pretrained=False)
# cnn = dn.DenseNet(32, (6, 12, 24, 16), 64)

In [16]:
f'{num_trainable_parameters(cnn.features):,}'

'6,953,856'

In [90]:
cnn = dn.DenseNet(12, (6, 6, 6, 12), 64, num_classes=14)

In [91]:
f'{num_trainable_parameters(cnn.features):,}'

'368,132'

In [9]:
def short_name(module):
    if isinstance(module, (nn.BatchNorm2d, nn.MaxPool2d, nn.ReLU, nn.Conv2d)):
        return module
    return module.__class__

In [10]:
def model_details(model):
    return list((k, short_name(v)) for k, v in model._modules.items())

In [88]:
model_details(cnn.features)

[('conv0',
  Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)),
 ('norm0',
  BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
 ('relu0', ReLU(inplace=True)),
 ('pool0',
  MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)),
 ('denseblock1', torchvision.models.densenet._DenseBlock),
 ('transition1', torchvision.models.densenet._Transition),
 ('denseblock2', torchvision.models.densenet._DenseBlock),
 ('transition2', torchvision.models.densenet._Transition),
 ('denseblock3', torchvision.models.densenet._DenseBlock),
 ('transition3', torchvision.models.densenet._Transition),
 ('denseblock4', torchvision.models.densenet._DenseBlock),
 ('norm5',
  BatchNorm2d(211, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]

In [92]:
cnn.classifier

Linear(in_features=215, out_features=14, bias=True)

In [6]:
%run ./classification/tiny_densenet.py

In [15]:
cnn = SmallDenseNetCNN(list(range(14)))
cnn

CustomDenseNetCNN(
  (features): Sequential(
    (conv0): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(32, 60, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(60, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       

In [14]:
model_details(cnn.features.denseblock4.denselayer12.conv2)

[]

In [36]:
f'{num_trainable_parameters(cnn):,}'

'371,156'

In [38]:
cnn = CustomDenseNetCNN(labels=list(range(14)), growth_rate=12,
                        block_config=(6, 6, 6, 12),
                        num_init_features=16,
                        bn_size=4,
                        drop_rate=0,
                       )
f'{num_trainable_parameters(cnn):,}'

'324,860'

In [8]:
x = torch.rand(7, 3, 512, 512)
x.size()

torch.Size([7, 3, 512, 512])

In [11]:
y = cnn.features(x)
y.size()

torch.Size([7, 215, 16, 16])

In [12]:
out = cnn(x)
out = out[0]
out.size()

torch.Size([7, 14])

### Tiny Resnet

In [105]:
# %run ./segmentation/scan.py
%run ./classification/tiny_res_scan.py
# %run ../utils/conv.py

In [106]:
labels = [f'd{i}' for i in range(14)]

In [107]:
cnn = TinyResScanCNN(labels)

In [109]:
f'{num_trainable_parameters(cnn):,}'

'226,846'

In [64]:
x = torch.randn(7, 1, 500, 500)
x.size()

torch.Size([7, 1, 500, 500])

In [65]:
out = cnn(x)
out = out[0]
out.size()

torch.Size([7, 14])

In [66]:
print_trainable_params(cnn)

Total: 226,846
Feats: 225,936 (99.6%)
FC: 910 (0.4%)


## SCAN

In [56]:
%run ./segmentation/scan.py

In [42]:
res = _ResBlock(50, 7)

In [43]:
images = torch.rand(7, 50, 400, 400)
res(images).size()

torch.Size([7, 50, 400, 400])

In [48]:
pres = _ParallelResBlocks(2, 16, 3)

In [49]:
images = torch.rand(7, 16, 100, 100)
pres(images).size()

torch.Size([7, 32, 100, 100])

In [59]:
model = ScanFCN()
# model
total = num_trainable_parameters(model)
print(f'{total:,}')

242,584


In [47]:
bs = 7
height = width = 1024

images = torch.rand(bs, 1, height, width)
out = model(images)
out.size()

torch.Size([7, 4, 1024, 1024])

## Decoder

In [21]:
%run report_generation/decoder_lstm.py

In [22]:
vocab_size = 1000
embedding_size = 200
hidden_size = 100

decoder = LSTMDecoder(vocab_size, embedding_size, hidden_size)
decoder

LSTMDecoder(
  (embeddings_table): Embedding(1000, 200, padding_idx=0)
  (lstm_cell): LSTMCell(200, 100)
  (W_vocab): Linear(in_features=100, out_features=1000, bias=True)
)

In [23]:
batch_size = 4
hidden_size = 100

initial_state = torch.rand(batch_size, hidden_size)
initial_state.size()

torch.Size([4, 100])

In [24]:
outputs = decoder(initial_state, 10)
words = outputs[0]
words.size()

torch.Size([4, 10, 1000])

## Decoder with attention

### Attention

In [7]:
%run ./report_generation/att_2layer.py

In [8]:
features_size = 512
lstm_size = 100
att = AttentionTwoLayers(features_size, lstm_size)
att

AttentionTwoLayers(
  (visual_fc): Linear(in_features=512, out_features=100, bias=True)
  (state_fc): Linear(in_features=100, out_features=100, bias=True)
  (last_fc): Sequential(
    (0): Tanh()
    (1): Linear(in_features=100, out_features=1, bias=True)
  )
  (softmax): Sequential(
    (0): Flatten()
    (1): Softmax(dim=-1)
  )
)

In [10]:
bs = 8
height = width = 16

features = torch.rand(bs, features_size, height, width)
h_state = torch.rand(bs, lstm_size)

In [11]:
feats, scores = att(features, h_state)
feats.size(), scores.size()

(torch.Size([8, 512]), torch.Size([8, 16, 16]))

In [12]:
scores.sum(dim=-1).sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)

### LSTM-att-v2

In [26]:
%run ./report_generation/decoder_lstm_att_v2.py

In [27]:
vocab_size = 1000
embedding_size = 100
hidden_size = 200
features_size = 512

model = LSTMAttDecoderV2(vocab_size, embedding_size, hidden_size, features_size)

In [28]:
bs = 7
height = width = 16

features = torch.rand(bs, features_size, height, width)
reports = (torch.rand(bs, 20) * (vocab_size - 1)).long()

In [29]:
generated, scores = model(features, reports)
generated.size(), scores.size()

(torch.Size([7, 20, 1000]), torch.Size([7, 20, 16, 16]))

## Debug hierarchical

In [60]:
import torch
from torch import nn
import numpy as np

from medai.utils.nlp import PAD_IDX, START_IDX, END_OF_SENTENCE_IDX

In [6]:
%run report_generation/att_2layer.py

### Debug input/output

In [2]:
%run report_generation/decoder_h_lstm_att.py

In [8]:
batch_size = 2

In [9]:
image_features = torch.rand(batch_size, 1024, 16, 16)

In [13]:
reports_h = torch.tensor([[[1, 2, 3, 0],
                           [1, 5, 0, 0],
                           [2, 2, 2, 0],
                          ],
                          [[7, 9, 10, 0],
                           [1, 4, 0, 0],
                           [8, 9, 0, 0],
                          ],
                         ])
reports_h.size()

torch.Size([2, 3, 4])

In [11]:
model = HierarchicalLSTMAttDecoder(200, 100, 100, (1024, 16, 16))

In [12]:
generated, stops, scores = model(image_features, reports_h)
generated.size(), stops.size(), scores.size()

(torch.Size([2, 3, 4, 200]), torch.Size([2, 3, 1]), torch.Size([2, 3, 16, 16]))

### Debug flat and hierarchical padding

In [31]:
%run ../training/report_generation/hierarchical.py
%run ../training/report_generation/flat.py

In [32]:
%run ../datasets/iu_xray.py

In [33]:
dataset = IUXRayDataset(max_samples=20)
dataset.size()

(42, 20)

In [34]:
dataloader = create_hierarchical_dataloader(dataset, batch_size=4)

In [35]:
for batch in dataloader:
    images, reports, stops = batch
    break

In [36]:
reports, stops

(tensor([[[49, 32, 11, 50,  8, 51,  4],
          [52,  8,  9, 13,  4,  0,  0],
          [ 0,  0,  0,  0,  0,  0,  0]],
 
         [[49, 32, 11, 50,  8, 51,  4],
          [52,  8,  9, 13,  4,  0,  0],
          [ 0,  0,  0,  0,  0,  0,  0]],
 
         [[76, 77, 63, 78,  4,  0,  0],
          [52, 10, 13,  4,  0,  0,  0],
          [32, 11, 50,  4,  0,  0,  0]],
 
         [[76, 77, 63, 78,  4,  0,  0],
          [52, 10, 13,  4,  0,  0,  0],
          [32, 11, 50,  4,  0,  0,  0]]]),
 tensor([[0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 0.],
         [0., 0., 0.]]))

In [38]:
from torch import nn

In [46]:
s = torch.tensor([[0, 0, 1],
                  [0, 0, 1],
                  [0, 0, 0],
                  [0, 0, 0],
                 ])

In [47]:
stops

tensor([[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [48]:
loss = nn.BCELoss()

In [61]:
loss(s, stops.long())

RuntimeError: "binary_cross_entropy" not implemented for 'Long'

In [59]:
a = torch.rand(10, 5, 1)
a.squeeze(-1).size()

torch.Size([10, 5])

In [129]:
flat_dataloader = create_flat_dataloader(dataset, batch_size=4)

In [130]:
for batch in flat_dataloader:
    _, flat_reports = batch
    break

In [132]:
flat_reports.size()

torch.Size([4, 14])

In [133]:
flat_reports

tensor([[49, 32, 11, 50,  8, 51,  4, 52,  8,  9, 13,  4,  1,  0],
        [49, 32, 11, 50,  8, 51,  4, 52,  8,  9, 13,  4,  1,  0],
        [76, 77, 63, 78,  4, 52, 10, 13,  4, 32, 11, 50,  4,  1],
        [76, 77, 63, 78,  4, 52, 10, 13,  4, 32, 11, 50,  4,  1]])

In [134]:
reports

tensor([[[49, 32, 11, 50,  8, 51,  4],
         [52,  8,  9, 13,  4,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0]],

        [[49, 32, 11, 50,  8, 51,  4],
         [52,  8,  9, 13,  4,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0]],

        [[76, 77, 63, 78,  4,  0,  0],
         [52, 10, 13,  4,  0,  0,  0],
         [32, 11, 50,  4,  0,  0,  0]],

        [[76, 77, 63, 78,  4,  0,  0],
         [52, 10, 13,  4,  0,  0,  0],
         [32, 11, 50,  4,  0,  0,  0]]])

## Dummy baselines

### Load data

In [5]:
%run ../datasets/iu_xray.py

In [185]:
dataset_kwargs = {
    'max_samples': 100,
    'frontal_only': False,
    'image_size': (512, 512),
}

train_dataset = IUXRayDataset(dataset_type='train', **dataset_kwargs)
dataset_kwargs['vocab'] = train_dataset.get_vocab()
val_dataset = IUXRayDataset(dataset_type='val', **dataset_kwargs)
test_dataset = IUXRayDataset(dataset_type='test', **dataset_kwargs)
len(train_dataset), len(val_dataset), len(test_dataset)

(195, 198, 196)

In [94]:
VOCAB = train_dataset.get_vocab()
vocab_size = len(VOCAB)
vocab_size

1775

### Random

In [76]:
%run ./report_generation/dummy/random.py

In [77]:
model = RandomReport(train_dataset)
model

RandomReport()

In [78]:
bs = 2
features = torch.rand(bs, 256, 16, 16)
reports = (torch.rand(bs, 20) * vocab_size).long()

In [79]:
vocab_size, reports.max().item()

(1775, 1673)

In [109]:
r, = model(features, None, free=True)
r.size()

torch.Size([2, 46, 1775])

### MostSimilarImage

In [181]:
from tqdm.notebook import tqdm

In [95]:
%run ../utils/nlp.py
%run ../training/report_generation/flat.py

In [103]:
reader = ReportReader(VOCAB)

In [60]:
cnn = cnn.to(DEVICE)

In [187]:
bs = 1

train_dataloader = create_flat_dataloader(train_dataset, batch_size=bs)
val_dataloader = create_flat_dataloader(val_dataset, batch_size=bs)
test_dataloader = create_flat_dataloader(test_dataset, batch_size=bs)

In [188]:
%run ./report_generation/dummy/most_similar_image.py

In [194]:
model = MostSimilarImage(cnn, VOCAB).to(DEVICE)
model.fit(train_dataloader, device=DEVICE)

#### Test with random example

In [169]:
bs_2 = 1

images = torch.rand(bs_2, 3, 256, 256).to(DEVICE)
reports = (torch.randn(bs_2, 4) * vocab_size).long().to(DEVICE)

In [170]:
out = model(images, reports, free=False)
out = out[0]
out.size()

torch.Size([1, 4, 1775])

#### Test with real sample

In [195]:
model.train(False)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fcef17c9710>

In [196]:
dataloader = train_dataloader

In [197]:
for batch in tqdm(iter(dataloader)):
    images = batch.images.to(DEVICE)
    reports = batch.reports.to(DEVICE)
    filenames = batch.filenames
    
    output, _ = model(images, reports, free=True)
    _, output = output.max(dim=-1)
    
    for report, gen, filename in zip(reports, output, filenames):
        report = reader.idx_to_text(report)
        gen = reader.idx_to_text(gen)

        if report != gen:
            print(filename)

HBox(children=(FloatProgress(value=0.0, max=195.0), HTML(value='')))




In [None]:
output, dist = model(images, reports, free=True)
_, output = output.max(dim=-1)
output.size(), dist.size()

In [161]:
reader.idx_to_text(reports)

'increased interstitial opacities non-specific . question edema or atypical infection? END heart size is normal and the lungs are clear . END'

In [162]:
reader.idx_to_text(output)

'increased interstitial opacities non-specific . question edema or atypical infection? END heart size is normal and the lungs are clear . END'

### Common sentences

In [221]:
%run ../training/report_generation/hierarchical.py

In [222]:
%run ./report_generation/dummy/common_sentences.py

In [223]:
model = MostCommonSentences(train_dataset)

In [224]:
model.n_sentences, model.n_weights

((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 12, 18, 17),
 (73, 209, 1242, 1740, 1277, 754, 334, 153, 74, 33, 17, 4, 11, 1, 1))

In [245]:
bs = 10
images = torch.rand(bs, 3, 16, 16)
reports = (torch.randn(bs, 8, 6) * vocab_size).long()

In [254]:
a, b = model(images, reports, free=False)
a.size()

torch.Size([10, 8, 6, 1775])

### Common words

In [255]:
%run ./report_generation/dummy/common_words.py

In [256]:
model = MostCommonWords(train_dataset)

In [262]:
bs = 10
images = torch.rand(bs, 3, 16, 16)
reports = (torch.randn(bs, 7) * vocab_size).long()

In [268]:
a, = model(images, reports, free=True)
a.size()

torch.Size([10, 26, 1775])