## Imports

In [1]:
!echo $CUDA_VISIBLE_DEVICES

0,1,2,3


In [2]:
import torch

In [3]:
%run ../utils/common.py

In [4]:
DEVICE = torch.device('cuda', 1)
DEVICE

device(type='cuda', index=1)

## CNN

In [5]:
%run ../utils/conv.py

In [13]:
%run classification/transfusion.py
%run classification/resnet.py
%run classification/densenet.py
%run classification/vgg.py
%run classification/mobilenet.py

In [7]:
labels = [f'disease{idx}' for idx in range(14)]

In [8]:
cnn = Densenet121CNN(labels, multilabel=True)
num_trainable_parameters(cnn)

7993206

In [9]:
cnn = Resnet50CNN(labels, multilabel=True)
num_trainable_parameters(cnn)

25585718

In [10]:
cnn = TransfusionCBRCNN(labels, multilabel=True, name='tiny', n_channels=3)
num_trainable_parameters(cnn)

4315662

In [11]:
cnn = VGG19CNN(labels)
num_trainable_parameters(cnn)

139627598

In [14]:
cnn = MobileNetV2CNN(labels)
num_trainable_parameters(cnn)

3522806

In [100]:
cnn

MobileNetV2CNN(
  (base_cnn): MobileNetV2(
    (features): Sequential(
      (0): ConvBNReLU(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNReLU(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNReLU(
            (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=

In [15]:
batch_size = 4
h = w = 224

images = torch.rand(4, 3, h, w)
images.size()

torch.Size([4, 3, 224, 224])

In [16]:
features = cnn(images, features=True)
features.size()

torch.Size([4, 1280, 7, 7])

In [27]:
output, = cnn(images, features=False)
output.size()

torch.Size([4, 14])

## Decoder

In [21]:
%run report_generation/decoder_lstm.py

In [22]:
vocab_size = 1000
embedding_size = 200
hidden_size = 100

decoder = LSTMDecoder(vocab_size, embedding_size, hidden_size)
decoder

LSTMDecoder(
  (embeddings_table): Embedding(1000, 200, padding_idx=0)
  (lstm_cell): LSTMCell(200, 100)
  (W_vocab): Linear(in_features=100, out_features=1000, bias=True)
)

In [23]:
batch_size = 4
hidden_size = 100

initial_state = torch.rand(batch_size, hidden_size)
initial_state.size()

torch.Size([4, 100])

In [24]:
outputs = decoder(initial_state, 10)
words = outputs[0]
words.size()

torch.Size([4, 10, 1000])

## Encoder-Decoder

In [44]:
%run report_generation/cnn_to_seq.py

In [45]:
model = CNN2Seq(cnn, decoder)
# model

In [46]:
words, = model(images, 10)
words.size()

torch.Size([4, 10, 1000])

## Debug hierarchical

In [60]:
import torch
from torch import nn
import numpy as np

from mrg.utils.nlp import PAD_IDX, START_IDX, END_OF_SENTENCE_IDX

In [6]:
%run report_generation/att_2layer.py

### Debug input/output

In [2]:
%run report_generation/decoder_h_lstm_att.py

In [8]:
batch_size = 2

In [9]:
image_features = torch.rand(batch_size, 1024, 16, 16)

In [13]:
reports_h = torch.tensor([[[1, 2, 3, 0],
                           [1, 5, 0, 0],
                           [2, 2, 2, 0],
                          ],
                          [[7, 9, 10, 0],
                           [1, 4, 0, 0],
                           [8, 9, 0, 0],
                          ],
                         ])
reports_h.size()

torch.Size([2, 3, 4])

In [11]:
model = HierarchicalLSTMAttDecoder(200, 100, 100, (1024, 16, 16))

In [12]:
generated, stops, scores = model(image_features, reports_h)
generated.size(), stops.size(), scores.size()

(torch.Size([2, 3, 4, 200]), torch.Size([2, 3, 1]), torch.Size([2, 3, 16, 16]))

### Debug flat and hierarchical padding

In [31]:
%run ../training/report_generation/hierarchical.py
%run ../training/report_generation/flat.py

In [32]:
%run ../datasets/iu_xray.py

In [33]:
dataset = IUXRayDataset(max_samples=20)
dataset.size()

(42, 20)

In [34]:
dataloader = create_hierarchical_dataloader(dataset, batch_size=4)

In [35]:
for batch in dataloader:
    images, reports, stops = batch
    break

In [36]:
reports, stops

(tensor([[[49, 32, 11, 50,  8, 51,  4],
          [52,  8,  9, 13,  4,  0,  0],
          [ 0,  0,  0,  0,  0,  0,  0]],
 
         [[49, 32, 11, 50,  8, 51,  4],
          [52,  8,  9, 13,  4,  0,  0],
          [ 0,  0,  0,  0,  0,  0,  0]],
 
         [[76, 77, 63, 78,  4,  0,  0],
          [52, 10, 13,  4,  0,  0,  0],
          [32, 11, 50,  4,  0,  0,  0]],
 
         [[76, 77, 63, 78,  4,  0,  0],
          [52, 10, 13,  4,  0,  0,  0],
          [32, 11, 50,  4,  0,  0,  0]]]),
 tensor([[0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 0.],
         [0., 0., 0.]]))

In [38]:
from torch import nn

In [46]:
s = torch.tensor([[0, 0, 1],
                  [0, 0, 1],
                  [0, 0, 0],
                  [0, 0, 0],
                 ])

In [47]:
stops

tensor([[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [48]:
loss = nn.BCELoss()

In [61]:
loss(s, stops.long())

RuntimeError: "binary_cross_entropy" not implemented for 'Long'

In [59]:
a = torch.rand(10, 5, 1)
a.squeeze(-1).size()

torch.Size([10, 5])

In [129]:
flat_dataloader = create_flat_dataloader(dataset, batch_size=4)

In [130]:
for batch in flat_dataloader:
    _, flat_reports = batch
    break

In [132]:
flat_reports.size()

torch.Size([4, 14])

In [133]:
flat_reports

tensor([[49, 32, 11, 50,  8, 51,  4, 52,  8,  9, 13,  4,  1,  0],
        [49, 32, 11, 50,  8, 51,  4, 52,  8,  9, 13,  4,  1,  0],
        [76, 77, 63, 78,  4, 52, 10, 13,  4, 32, 11, 50,  4,  1],
        [76, 77, 63, 78,  4, 52, 10, 13,  4, 32, 11, 50,  4,  1]])

In [134]:
reports

tensor([[[49, 32, 11, 50,  8, 51,  4],
         [52,  8,  9, 13,  4,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0]],

        [[49, 32, 11, 50,  8, 51,  4],
         [52,  8,  9, 13,  4,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0]],

        [[76, 77, 63, 78,  4,  0,  0],
         [52, 10, 13,  4,  0,  0,  0],
         [32, 11, 50,  4,  0,  0,  0]],

        [[76, 77, 63, 78,  4,  0,  0],
         [52, 10, 13,  4,  0,  0,  0],
         [32, 11, 50,  4,  0,  0,  0]]])