## Imports

In [1]:
# GPUs visible to this kernel; CUDA device indices (e.g. DEVICE below) are relative to this list.
!echo $CUDA_VISIBLE_DEVICES

0,1,2,3


In [2]:
import torch
from torch import nn

In [3]:
# -n: run the script without setting __name__ to '__main__', so its module-level
# definitions (presumably train_model, CompiledModel, ...) load without starting training.
%run -n train_report_generation.py

In [4]:
# Index into CUDA_VISIBLE_DEVICES. Fall back to CPU when that GPU is not present,
# so the notebook still runs, slowly, on smaller machines.
GPU_INDEX = 1
DEVICE = (torch.device('cuda', GPU_INDEX)
          if torch.cuda.device_count() > GPU_INDEX
          else torch.device('cpu'))
DEVICE

device(type='cuda', index=1)

## Load data and models

### Load data

In [5]:
# Presumably defines IUXRayDataset, used in the next cell.
%run ./datasets/iu_xray.py

In [6]:
# Shared options for both splits; small subsets keep this debugging notebook fast.
dataset_kwargs = dict(
    max_samples=100,
    frontal_only=False,
)

train_dataset = IUXRayDataset(dataset_type='train', **dataset_kwargs)
# Validation reuses the training vocabulary so token indices agree across splits.
val_dataset = IUXRayDataset(
    dataset_type='val',
    vocab=train_dataset.get_vocab(),
    **dataset_kwargs,
)
train_dataset.size(), val_dataset.size()

((195, 100), (198, 100))

### Create Flat dataloader

In [8]:
# Presumably defines create_flat_dataloader.
%run training/report_generation/flat.py

In [7]:
BS = 5  # batch size

# Build the train loader first, then the validation loader.
train_dataloader, val_dataloader = (
    create_flat_dataloader(split, batch_size=BS)
    for split in (train_dataset, val_dataset)
)
train_dataloader.dataset.size()

(195, 100)

### ...or hierarchical dataloader

In [7]:
# Presumably defines create_hierarchical_dataloader.
%run training/report_generation/hierarchical.py

In [8]:
BS = 5  # batch size

# Build the train loader first, then the validation loader (overrides the flat loaders).
train_dataloader, val_dataloader = (
    create_hierarchical_dataloader(split, batch_size=BS)
    for split in (train_dataset, val_dataset)
)
train_dataloader.dataset.size()

(195, 100)

In [9]:
# Number of distinct tokens in the training vocabulary; sizes the decoders' output layer.
VOCAB_SIZE = len(train_dataset.word_to_idx)
VOCAB_SIZE

443

### Load model

In [10]:
# Model and checkpoint modules. They presumably provide init_empty_model, CNN2Seq,
# CompiledModel and load_compiled_model_classification used below.
%run ./models/classification/__init__.py
%run ./models/report_generation/cnn_to_seq.py
%run ./models/checkpoint/__init__.py

#### Load CNN

In [11]:
# Load a previously trained classification run and keep its underlying model as `cnn`.
cnn_run_name = '0706_134245_covid-kaggle_tfs-small_lr1e-06'
debug_run = True

compiled_cnn = load_compiled_model_classification(
    cnn_run_name, debug=debug_run, device=DEVICE,
)
cnn = compiled_cnn.model

#### ...or create CNN

In [12]:
# Alternative to loading: a fresh backbone, presumably ImageNet-initialised (imagenet=True).
# 'resnet-50' is the other architecture tried here.
cnn = init_empty_model(
    'densenet-121',
    labels=[],
    imagenet=True,
    freeze=False,
).to(DEVICE)

#### Create flat LSTM decoder

In [None]:
# NOTE(review): this cell shows In [None] (never executed in the saved session), yet the
# next cell built an LSTMDecoder, so the class came from earlier kernel state.
%run ./models/report_generation/decoder_lstm.py

In [11]:
# Flat LSTM decoder. The two 100s are presumably embedding and hidden sizes (TODO confirm).
decoder = LSTMDecoder(VOCAB_SIZE, 100, 100, cnn.features_size,
                      teacher_forcing=True).to(DEVICE)

#### ...or with attention

In [34]:
# Presumably defines LSTMAttDecoder.
%run ./models/report_generation/decoder_lstm_att.py

In [35]:
# Attention-based flat decoder; same size arguments as the plain LSTM decoder.
decoder_att = LSTMAttDecoder(VOCAB_SIZE, 100, 100, cnn.features_size,
                             teacher_forcing=True).to(DEVICE)

#### ...or hierarchical decoder

In [13]:
# Presumably defines HierarchicalLSTMAttDecoder.
%run ./models/report_generation/decoder_h_lstm_att.py

In [14]:
# Hierarchical (sentence-level, then word-level) attention decoder; pairs with the
# hierarchical dataloader and hierarchical=True in train_model.
decoder_h = HierarchicalLSTMAttDecoder(VOCAB_SIZE, 100, 100, cnn.features_size,
                                       teacher_forcing=True).to(DEVICE)

#### Full model

In [15]:
# Full image-to-report model: CNN encoder + hierarchical decoder.
model = CNN2Seq(cnn, decoder_h).to(DEVICE)

In [16]:
# Optional: uncomment to split each batch across all visible GPUs.
# model = nn.DataParallel(model)

In [17]:
# Refer to torch.optim explicitly: `optim` is never imported in this notebook and
# otherwise has to leak in from a %run script.
LEARNING_RATE = 1e-4

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

compiled_model = CompiledModel(model, optimizer)

## Train

In [18]:
%%time

train_metrics, val_metrics = train_model('debugging',
                                         compiled_model,
                                         train_dataloader,
                                         val_dataloader,
                                         n_epochs=2,
                                         hierarchical=True,
                                         dryrun=True,
                                         save_model=False,
                                         debug=True,
                                         device=DEVICE)

--------------------------------------------------
Training...
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.])
tensor([52, 10, 16, 13,  8,  5, 32, 11, 50,  4,  0,  0,  0])
tensor([121., 110., 431., 346., 314., 126., 167., 124., 124., 132., 121., 121.,
        322., 322., 322., 322., 322., 322., 322., 322., 121., 121., 322., 322.,
        322., 322., 322., 322., 322., 322.])
tensor([52, 10, 16, 13,  8,  5, 32, 11, 50,  4,  0,  0,  0])
tensor([121., 320., 413., 124., 345.,  97., 245., 347., 347., 322., 121., 110.,
        355., 355., 314., 347., 347., 347., 132., 132., 121.,  94., 132., 322.,
        322., 322., 322., 322., 322., 322.])
tensor([49, 32, 11, 50,  8, 51,  4, 52,  8,  9, 13,  4,  0])
tensor([121., 320., 413., 124., 345.,  97., 245., 347., 347., 322., 121., 110.,
        355., 355., 314., 347., 347., 347., 132., 132., 121.,  94., 132., 322.,
        322., 322., 322., 322., 322., 322.])
te

tensor([121, 110, 382, 346, 382, 382, 110, 322, 322, 322, 322, 121, 196, 346,
        343, 226, 347, 347, 272, 322, 322, 322, 121, 135, 136, 253, 357, 179,
          5, 346, 322,   8,  28, 121, 150, 322, 322, 322, 322, 322, 322, 322,
        322, 322])
tensor([ 52,  16, 154,   4,  18,  22,  11, 434,   4, 167,  71, 225,   8, 156,
        168,  22, 169,  11,  85,   4,   0,   0])
tensor([121, 110, 382, 346, 382, 382, 110, 322, 322, 322, 322, 121, 440, 357,
        343, 226, 347, 347, 272, 322, 322, 322, 121, 135, 136, 253, 357, 179,
          5, 346, 322,   8,  28, 121, 150, 322, 322, 322, 322, 322, 322, 322,
        322, 322])
tensor([ 52,  16, 154,   4,  18,  22,  11, 434,   4, 167,  71, 225,   8, 156,
        168,  22, 169,  11,  85,   4,   0,   0])
tensor([121, 313, 355, 391, 158, 158, 313, 126,  50, 347, 347, 121,  74, 221,
        124, 276, 347, 347, 110, 132, 132, 132, 121, 400, 400, 346, 382, 346,
        346, 314, 347, 347, 272, 121,  94, 322, 322, 322, 322, 322, 322, 322,
      

tensor([121, 110, 365,  74, 150, 382, 150, 322, 322, 322, 121,  97, 110, 110,
        110, 110, 110, 132, 132, 132, 121,  97, 427, 347, 382, 322, 132, 132,
        132, 322, 121,  97, 322,  28,  28,  28,   9, 149, 322, 132, 121,  94,
        346, 154, 382, 382, 382, 322, 322, 322, 121,  94, 346, 346, 343, 314,
        382, 382, 382, 132, 121, 150,   8, 276, 347, 382, 110, 132, 132, 132])
tensor([180, 146,  52,   4,  17, 181,   4,  17,  19,   4,  17, 182,  39, 183,
          4, 184,  22,  99,   4, 184,  22,  16,  13,  10,   4,  32,  11,  50,
          4,   0])
tensor([121, 110, 365,  74, 150, 347, 150, 322, 322, 322, 121,  97, 110, 110,
        110, 110, 110, 132, 132, 132, 121,  97, 355, 347, 382, 322, 132, 132,
        132, 322, 121,  97, 322,  28,  28,  28,   9, 149, 322, 132, 121,  94,
        346, 154, 382, 382, 382, 322, 322, 322, 121,  94, 346, 346, 343, 314,
        382, 382, 382, 132, 121, 426,   8, 276, 347, 347, 110, 132, 132, 132])
tensor([180, 146,  52,   4,  17, 181,   4, 

tensor([121,  74, 110,  42, 143,  28,  28, 313, 314,  50, 314, 431, 355, 245,
        121, 172, 221, 124,  28, 126, 126,  97, 347, 347, 110, 110, 132, 132,
        121, 135, 110, 429, 313, 343,  50, 110, 226, 400, 347, 149, 149, 322,
        121, 150, 322, 322, 322, 322, 322, 322, 322, 126, 126, 126, 126, 126,
        121, 150, 322, 322, 322, 322, 322, 322, 322, 126, 126, 126, 126, 126])
tensor([  5,  29,   7,   8, 229,  11,  12,  13,  14,  30,  10,   8,  31,   4,
          5,  32,  11,  33,  34,   8,  50,   4, 143, 125,  11,  12,  13,  14,
         30, 147, 317,   4,   0,   0])
tensor([121, 398, 110,  42, 143,  28,  28, 313, 314,  50, 314, 431, 355, 245,
        121,  77, 322,   8, 377, 177, 377,  74, 346, 322, 126, 431,  94, 347,
        121, 410, 400,   8, 398, 110, 347, 382, 110, 110, 132, 132, 322, 322,
        121, 150, 132, 322, 322, 322, 322, 322, 322, 126, 126, 126, 126, 126,
        121, 150, 132, 322, 322, 322, 322, 322, 322, 126, 126, 126, 126, 126])
tensor([  5,  29,   7, 

tensor([121, 398, 110,  42, 235, 344, 382, 382, 382, 322, 322, 132, 132, 132,
        126, 126, 126, 126, 121,  77, 322, 198, 276, 347, 382, 382, 322, 322,
        322, 322, 322, 322, 126, 126, 126, 126, 121,  97, 343, 126, 322, 330,
        359, 400, 400, 319,   5, 346, 322,  98, 400, 201, 179,  74, 138,  50,
        110, 429, 313, 343,  50, 347, 347, 110, 132, 132, 132, 132, 132, 126,
        126, 126, 138, 150, 126, 322, 322, 322, 322, 126, 126, 126, 126, 126,
        126, 126, 126, 126, 126, 126, 138, 150, 126, 322, 322, 322, 322, 126,
        126, 126, 126, 126, 126, 126, 126, 126, 126, 126])
tensor([  5,  29,   7,  16,  40,   4,  45, 165, 395, 166,   4,  15,  16,  17,
         25, 396,  87,  93,  28,  39, 168,  22, 169,  87,  93, 154,  25,  26,
          4, 124, 125,  11,  12,  13,  14,   4,   0])
tensor([121, 398, 221, 124, 276, 399, 347, 347, 110, 132, 132, 132, 132, 132,
        126, 126, 126, 126, 121, 322, 276, 442, 322,   8, 355,  52, 245,  74,
        415, 272, 232, 429, 3

tensor([150, 122, 196, 343, 110,  52, 110, 313, 150, 122, 150, 122, 196, 343,
        110,  52, 110, 313, 150, 122, 150, 122, 196, 343, 110,  52, 110, 313,
        150, 122, 150, 122, 196, 343, 110,  52, 110, 313, 150, 122])
tensor([  5,  32,   8,  25, 128, 133,  17,  46,  48,   4,  52,  10,   8,  18,
         79,  12,  13,  14,   4,   4,   0,   0])
tensor([121,   5, 172, 382, 120,  74, 172, 136, 135, 400, 121,   5, 172, 382,
        120,  74, 172, 136, 135, 400, 121,   5, 172, 382, 120,  74, 172, 136,
        135, 400, 121,   5, 172, 382, 120,  74, 172, 136, 135, 400])
tensor([  5,   6, 123,  11,  13,   4,   5,  32,  11,  50,   4, 139, 360,   4,
         22,  22,  23,   5, 109,   4,   0,   0])
tensor([121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 121,   5, 398, 167,
         71, 136, 135, 276, 122,  52, 121,   5, 398, 167,  71, 136, 135, 276,
        122,  52, 121,   5, 398, 167,  71, 136, 135, 276, 122,  52])
tensor([  5,   6, 123,  11,  13,   4,   5,  32,  11,  50,   4, 139, 360

tensor([121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 110,  18, 111, 413,
        198, 121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 110,  18, 111,
        413, 198, 121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 110,  18,
        111, 413, 198, 121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 110,
         18, 111, 413, 198])
tensor([  3,   3,  43, 206,  25,  26,  56,  43, 200,  55,  61,   4,  13,   6,
         31,  56, 291, 116,  75,   5, 112,   4,  50,  63,  45,  22,   4])
tensor([121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 110,  18, 111, 413,
        198, 121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 110,  18, 111,
        413, 198, 138, 382, 295,  74, 172, 136, 135, 400,  65, 132,   9, 355,
        377, 136, 306, 138, 382, 295,  74, 172, 136, 135, 400,  65, 132,   9,
        355, 377, 136, 306])
tensor([  3,   3,  43, 206,  25,  26,  56,  43, 200,  55,  61,   4,  13,   6,
         31,  56, 291, 116,  75,   5, 112,   4,  50,  63,  45,  22,   4])
tensor([150, 1

tensor([121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 110,  18, 138, 382,
        295,  74, 172, 136, 135, 400,  65, 132,   9, 355, 138, 382, 295,  74,
        172, 136, 135, 400,  65, 132,   9, 355, 138, 382, 295,  74, 172, 136,
        135, 400,  65, 132,   9, 355, 138, 313, 400, 355, 330, 322, 187, 377,
        136, 306, 429, 126])
tensor([  5,  52,  10,   8,  29,   7,  11,  12,  13,  14,   4,  18, 229,  86,
         13,   4,  15,  16,  17,  20,  89, 110,  21,   4,  17,  25,  26,  39,
         28,   4,   0])
tensor([121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 110,  18, 121,   5,
        398, 167,  71, 136, 135, 276, 122,  52, 110,  18, 121,   5, 398, 167,
         71, 136, 135, 276, 122,  52, 110,  18, 121,   5, 398, 167,  71, 136,
        135, 276, 122,  52, 110,  18, 121,   5, 398, 167,  71, 136, 135, 276,
        122,  52, 110,  18])
tensor([ 32,  11,  50,  35,  20,  21,  38,  26,  38,  39,  28,   4,  13,  52,
         10,   4,  17, 334,  18, 182,  39, 183,   4, 124, 33

tensor([121,   5, 398, 167,  71, 136, 135, 276, 122,  52, 110,  18, 111, 150,
        122, 196, 201, 343, 110, 426,  16, 343, 110,  16, 343, 110, 150, 122,
        196, 201, 343, 110, 426,  16, 343, 110,  16, 343, 110, 150, 122, 196,
        201, 343, 110, 426,  16, 343, 110,  16, 343, 110, 150, 122, 196, 201,
        343, 110, 426,  16, 343, 110,  16, 343, 110, 150, 122, 196, 201, 343,
        110, 426,  16, 343, 110,  16, 343, 110])
tensor([ 17,  46, 143,  48,   4,  82,   3,  11,  12,  13,  14,   4,   3,   3,
         23,   5,  52,   4,  13, 228, 229,   4,  17,  20,  60,  23,  21,  38,
         25,  26,  38,  39,  28,   4,   0])
tensor([121,   5, 172, 382, 398,  28, 283, 442,  74, 203, 174, 174, 151, 121,
          5, 172, 382, 398,  28, 283, 442,  74, 203, 174, 174, 151, 121,   5,
        172, 382, 398,  28, 283, 442,  74, 203, 174, 174, 151, 121,   5, 172,
        382, 398,  28, 283, 442,  74, 203, 174, 174, 151, 121,   5, 172, 382,
        398,  28, 283, 442,  74, 203, 174, 174, 1

tensor([150,  49, 355, 391, 158, 158, 313, 126,  50, 347, 347, 110, 150, 398,
        221, 124, 276, 382, 382, 110, 110, 132, 132, 132, 150, 322, 110, 442,
        377, 347, 347, 382, 132, 132, 132, 132])
tensor([  6,   8, 122, 123,  11,  12,  13,  14,   4,   5,  32,  11,  50,   4,
        124, 125,  11, 126,   4])
tensor([150,  49, 355, 391, 158, 158, 313, 126,  50, 347, 347, 110, 150, 398,
        221, 124, 276, 382, 382, 110, 110, 132, 132, 132, 150, 322, 110, 442,
        377, 347, 347, 382, 132, 132, 132, 132])
tensor([  6,   8, 122, 123,  11,  12,  13,  14,   4,   5,  32,  11,  50,   4,
        124, 125,  11, 126,   4])
tensor([150,  49, 355, 391, 158, 158, 313, 126,  50, 347, 347, 110, 150, 398,
        221, 124, 276, 382, 382, 110, 110, 132, 132, 132, 150,  50, 110, 442,
        377, 347, 347, 382, 132, 132, 132, 132])
tensor([  6,   8, 122, 123,  11,  12,  13,  14,   4,   5,  32,  11,  50,   4,
        124, 125,  11, 126,   4])
tensor([121, 313, 355, 391, 158, 158, 313, 126,  

tensor([150, 398, 221, 124, 132, 382, 382, 110, 110, 132, 132, 313,  97, 343,
        126, 107,  96,  28,  77, 347, 347, 347, 313, 398,  74, 355, 355,  28,
        343, 347, 347, 110, 132, 276, 398, 252, 391, 429, 343, 347, 347, 110,
        132, 132, 276, 150, 132, 132, 132, 132, 132, 126, 126, 126, 126])
tensor([  5,  32,  11,  50,   4,  15,  16,  17,  25,  26,  39,  28,   4,   5,
         52,   8,   9,  11,  13,   4,   5, 400, 125,  11,  13,   4,   0,   0])
tensor([150, 398, 221, 124, 132, 382, 382, 110, 110, 132, 132, 150,  97, 343,
        126, 107,  96,  28,  77, 347, 347, 347, 150, 398, 110, 355, 355,  28,
        343, 347, 347, 110, 132, 150, 398, 252, 391, 429, 343, 347, 347, 110,
        132, 132, 150, 150, 132, 132, 132, 132, 132, 126, 126, 126, 126])
tensor([  5,  32,  11,  50,   4,  15,  16,  17,  25,  26,  39,  28,   4,   5,
         52,   8,   9,  11,  13,   4,   5, 400, 125,  11,  13,   4,   0,   0])
tensor([150, 398, 110, 431, 382, 313, 314,  50, 347, 382, 110, 150, 44

tensor([ 52,  10,   8, 122, 123,  11,  13,  62, 140,   4,  17, 141,  36, 127,
          4,  17, 142,  27,  23,  25,  26,  39,  28,   4, 132, 143, 125,  80,
        126,   4])
tensor([150, 110, 431, 355, 391, 158, 158, 343, 343, 343, 382, 382, 110, 132,
        132, 150,  97,  21,  16,  94, 382, 382, 110, 132, 132, 132, 132, 126,
        126, 126, 150, 442,  32, 322,   8, 179, 439,  28,  77, 382,  74,  74,
        322, 132, 132, 150, 276, 110, 110, 150, 313, 347, 382, 382, 132, 132,
        132, 132, 126, 126])
tensor([ 52,  10,   8, 122, 123,  11,  13,  62, 140,   4,  17, 141,  36, 127,
          4,  17, 142,  27,  23,  25,  26,  39,  28,   4, 132, 143, 125,  80,
        126,   4])
tensor([313, 110, 431, 382, 343, 382, 382, 110, 132, 132, 132, 132, 132, 126,
        126, 313, 440,  52, 344, 382, 382, 110, 322, 322, 322, 132, 132, 126,
        126, 126, 313, 313,  97, 313, 313,  16, 429, 190, 442, 442,  77, 196,
        272, 431, 429, 313, 382, 382, 314,  31, 304, 263, 263, 132, 132, 13

tensor([150, 110, 431, 382, 343,   0, 382, 110, 132, 132, 132, 132, 132, 126,
        126, 126, 126, 126, 150, 398, 221, 124, 132, 382, 382, 110, 110, 132,
        132, 132, 132, 126, 126, 126, 126, 126, 150,  97,  28, 442,  97, 132,
         18,  16, 382, 382, 110, 110, 132, 132, 132, 126, 126, 126, 150, 442,
        276,   8, 319,  55,   0, 110, 110, 110, 132, 132, 132, 132, 126, 126,
        126, 126, 150, 398, 132, 357, 391, 158, 158, 343, 347, 347, 110, 132,
        132, 132, 132, 126, 126, 126, 150, 314, 221, 391, 347, 347, 110, 132,
        132, 132, 132, 132, 126, 126, 126, 126, 126, 126])
tensor([ 52,  10,  16,  13,   4,   5,  32,  11,  50,   4,  15,  11,  17,  20,
         89, 110, 237,   4,  17,  25, 181,  39, 238,   4,   5, 228,   8, 122,
        123,  11,  13,   4,  13,  18,  79,   4])
tensor([150, 398, 110, 355, 179, 346, 346, 346, 346,  32, 132, 382, 382, 110,
        110, 132, 132, 132, 150, 320, 413, 124, 345, 126, 245, 347, 347, 322,
        322, 322, 132, 132, 132, 1

tensor([150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 150, 122, 196, 132,
        437, 252, 252, 399,   3, 126, 276, 122,  52, 391, 187, 126,   5, 398,
         28, 283])
tensor([ 17, 111,   4,  52,  10,  13,   4, 359,   4,   0,   0,   0,   0,   0,
          0,   0,   0])
tensor([150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 150, 122, 196, 132,
        437, 252, 252, 399,   3, 126, 150, 122, 196, 132, 437, 252, 252, 399,
          3, 126])
tensor([ 13,  29, 123,   4,  17,  28,  38,  25, 181,  39,  20,  45,  21,   4,
          0,   0,   0])
tensor([150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 150, 122, 196, 132,
        437, 252, 252, 399,   3, 126, 150, 122, 196, 132, 437, 252, 252, 399,
          3, 126])
tensor([ 13,  29, 123,   4,  17,  28,  38,  25, 181,  39,  20,  45,  21,   4,
          0,   0,   0])
tensor([150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 150, 122, 196, 132,
        437, 252, 252, 399,   3, 126, 150, 122, 196, 132, 437, 252, 252, 399,
          3, 

tensor([ 94, 431, 126,   5, 398, 167,  71, 132,   8, 358, 124, 322, 187, 110,
        150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164, 164, 167, 107,
        150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164, 164, 167, 107])
tensor([ 52,  10,  38,  29,   7,  38,   8,  18, 229,  11,  12,  13,  14,   4,
         15,  11,  17, 225,  38, 181,  38,  39,  28,   4,   0,   0])
tensor([150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164, 164, 167, 107,
        150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164, 164, 167, 107,
        150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164, 164, 167, 107])
tensor([ 52,  10,  38,  29,   7,  38,   8,  18, 229,  11,  12,  13,  14,   4,
         15,  11,  17, 225,  38, 181,  38,  39,  28,   4,   0,   0])
tensor([150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164, 164, 167, 107,
        150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164, 164, 167, 107,
        150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164, 164, 167, 107])
t

tensor([313, 400, 355, 122, 391, 431, 187, 126,   5, 398,  28, 313, 400, 355,
        122, 391, 431, 187, 126,   5, 398,  28, 150, 122, 196, 132, 437, 252,
        252, 399,   3, 126, 164, 150, 122, 196, 132, 437, 252, 252, 399,   3,
        126, 164, 150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164])
tensor([154,  41,  42,  62,   5,  43, 200,  55,  16, 265,   4,  17,  28,   4,
         52,  10,  16,  13,   4,  17, 154,  25, 181,   4,  17,  20,  36,   3,
          4])
tensor([121,   5, 172, 382, 398,  28, 283, 442,  74, 203, 174, 121,   5, 172,
        382, 398,  28, 283, 442,  74, 203, 174, 121,   5, 172, 382, 398,  28,
        283, 442,  74, 203, 174, 121,   5, 398, 167,  71, 136, 135, 276, 122,
         52, 110, 121,   5, 398, 167,  71, 136, 276, 107, 107,  28,  74])
tensor([154,  41,  42,  62,   5,  43, 200,  55,  16, 265,   4,  17,  28,   4,
         52,  10,  16,  13,   4,  17, 154,  25, 181,   4,  17,  20,  36,   3,
          4])
tensor([313, 400, 355, 122, 391, 431, 187, 1

tensor([150, 122, 196, 132, 437, 252, 252, 399,   3, 126, 164, 164, 167, 276,
        122,  52, 391, 187, 126,   5, 398,  28, 283, 313, 400, 355, 276, 122,
         52, 391, 187, 126,   5, 398,  28, 283, 313, 400, 355, 276, 122,  52,
        391, 187, 126,   5, 398,  28, 283, 313, 400, 355, 276, 122,  52, 391,
        187, 126,   5, 398,  28, 283, 313, 400, 355, 276, 122,  52, 391, 187,
        126,   5, 398,  28, 283, 313, 400, 355])
tensor([ 17,  46, 143,  48,   4,  82,   3,  11,  12,  13,  14,   4,   3,   3,
         23,   5,  52,   4,  13, 228, 229,   4,  17,  20,  60,  23,  21,  38,
         25,  26,  38,  39,  28,   4,   0])
tensor([276, 122,  52, 391, 187, 126,   5, 398,  28, 283, 313, 400, 355, 276,
        122,  52, 391, 187, 126,   5, 398,  28, 283, 313, 400, 355, 276, 122,
         52, 391, 187, 126,   5, 398,  28, 283, 313, 400, 355, 276, 122,  52,
        391, 187, 126,   5, 398,  28, 283, 313, 400, 355, 276, 122,  52, 391,
        187, 126,   5, 398,  28, 283, 313, 400, 3

In [19]:
# Validation metrics returned by the training run
val_metrics

{'loss': 6.681816525846348,
 'word_acc': 0.0019833399444664813,
 'bleu1': 0.029961832061062982,
 'bleu2': 7.634887167752613e-11,
 'bleu3': 1.049683176342944e-13,
 'bleu4': 3.911667576820676e-15,
 'bleu': 0.007490458034380183,
 'rougeL': 0.035949752366824246,
 'ciderD': 9.764717528892482e-05}

In [20]:
# Training metrics returned by the training run
train_metrics

{'loss': 6.727642167062097,
 'word_acc': 0.008928571428571428,
 'bleu1': 0.03051771117165658,
 'bleu2': 7.514120062760379e-11,
 'bleu3': 1.0209950204329273e-13,
 'bleu4': 3.781486327756373e-15,
 'bleu': 0.007629427811725916,
 'rougeL': 0.03438959201214743,
 'ciderD': 0.0002459661284152959}

## DEBUG

In [21]:
import numpy as np

In [17]:
# Fake CNN features for a batch of 2 images, shaped like cnn.features_size.
image_features = torch.rand(2, *cnn.features_size).to(DEVICE)
image_features.size()

torch.Size([2, 1024, 16, 16])

In [19]:
# Toy hierarchical ground truth: 2 reports x 3 sentences x 4 tokens.
# 0 is presumably the padding index (it is trimmed by _flatten_gt_reports).
reports_h = torch.tensor([[[1, 2, 3, 0],
                           [1, 5, 0, 0],
                           [2, 2, 2, 0],
                          ],
                          [[7, 9, 10, 0],
                           [1, 4, 0, 0],
                           [8, 9, 0, 0],
                          ],
                         ]).to(DEVICE)
reports_h.size()

torch.Size([2, 3, 4])

In [37]:
def _flatten_gt_reports(reports):
    texts = []

    for report in reports:
        text = []
        for sentence in report:
            sentence = np.trim_zeros(sentence.detach().cpu().numpy())
            if len(sentence) > 0:
                text.extend(sentence)

        texts.append(torch.tensor(text))

    return pad_sequence(texts, batch_first=True)

In [38]:
# Sanity check: sentences joined per report, padding removed, batch re-padded with 0
_flatten_gt_reports(reports_h)

tensor([[ 1,  2,  3,  1,  5,  2,  2,  2],
        [ 7,  9, 10,  1,  4,  8,  9,  0]])

In [19]:
# Forward pass with teacher forcing on the toy reports. The outputs are presumably
# word scores (batch, sentences, words, vocab), per-sentence stop scores, and attention
# scores. TODO confirm what the positional 0 argument means.
r, st, sc = decoder_h(image_features, 0, reports_h)
r.size(), st.size()

(torch.Size([2, 3, 4, 443]), torch.Size([2, 3]))

In [24]:
# NOTE(review): _flatten_h_reports is not defined anywhere in this notebook; it must
# come from an earlier/deleted cell or a %run script. This fails on a fresh kernel otherwise.
r2 = _flatten_h_reports(r, st)
r2.size()

torch.Size([2, 4])

In [75]:
# Stop-score cutoff used in the next cell to decide where each report ends
threshold = 0.35

In [76]:
# For each report, find the first sentence whose stop score exceeds `threshold`.
# Earlier positions get larger weights, so argmax over the masked weights picks the
# earliest hit.
# NOTE(review): if no score exceeds the threshold, every weight is 0 and argmax also
# returns 0, which cannot be told apart from "stop at the first sentence".
n_sentences = st.size(1)
position_weights = torch.arange(n_sentences, 0, -1)
masked_weights = position_weights * (st.cpu() > threshold).long()

indices = masked_weights.argmax(dim=1, keepdim=True)
indices

tensor([[0],
        [0]])

In [81]:
# One stop index per report
indices.size()

torch.Size([2, 1])

In [79]:
# Shape of the decoder's word-score output
r.size()

torch.Size([2, 3, 4, 443])

In [85]:
# Stop indices as a flat vector
indices.view(-1)

tensor([0, 0])

In [91]:
# Highest-scoring vocabulary index at every position (last dim is the vocabulary).
r2 = r.argmax(dim=-1)
r2.size()

torch.Size([2, 3, 4])

In [92]:
# First report of the batch. Indexing raises IndexError on an empty batch, whereas the
# former `for ...: break` silently kept a stale `a` from earlier kernel state.
a = r2[0]

In [93]:
# (sentences, words) for a single report
a.size()

torch.Size([3, 4])

In [94]:
# Predicted token indices for the first report
a

tensor([[206,   5,   6,  38],
        [206,   5,   6,  38],
        [206,   5,   6,  38]], device='cuda:1')

## Test samples

In [19]:
# Invert the training vocabulary (word -> index) into index -> word for decoding.
idx_to_word = dict((index, word) for word, index in train_dataset.get_vocab().items())

In [20]:
def idx_to_text(idxs, vocab=None):
    """Decode a sequence of token indices into a space-separated string.

    Args:
        idxs: iterable of token indices: a 1-D tensor, numpy array, or list of ints.
        vocab: optional ``{index: word}`` mapping. Defaults to the notebook-level
            ``idx_to_word`` built from the training vocabulary.

    Returns:
        The decoded words joined by single spaces.

    Raises:
        KeyError: if an index is missing from the vocabulary.
    """
    mapping = idx_to_word if vocab is None else vocab
    # int() accepts one-element tensors, numpy scalars and plain ints alike
    # (g.item() fails on plain Python ints).
    return ' '.join(mapping[int(g)] for g in idxs)

In [85]:
# Index of the training sample to inspect
idx = 150

In [86]:
# One (image, ground-truth token indices) pair from the training set
image, report = train_dataset[idx]
image.size(), report.size()

(torch.Size([3, 512, 512]), torch.Size([13]))

In [87]:
# Generate a report for one image, as many steps as the ground-truth report has tokens.
images = image.unsqueeze(0).to(DEVICE)  # add a batch dimension of 1

# eval() + no_grad(): do not update normalisation statistics or apply dropout
# (if the model has such layers) and do not build an autograd graph during inference.
# Restore the previous mode afterwards so later training cells are unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        generated, scores = model(images, report.size(0))
finally:
    model.train(was_training)

# Best-scoring token per step. Assumes a flat decoder output (1, seq_len, vocab).
# TODO confirm: the hierarchical decoder returns a 4-D tensor.
generated = generated.argmax(dim=2).squeeze(0).cpu()

idx_to_text(generated)

'both lungs are clear and expanded . heart and mediastinum normal . END'

In [88]:
# Ground-truth report for comparison with the generated text
idx_to_text(report)

'both lungs are clear and expanded . heart and mediastinum normal . END'

### Search reports with a certain pattern

In [24]:
from tqdm.notebook import tqdm
import re

In [65]:
# Decoded training reports containing a phrase.
# Alternative pattern tried: re.compile(r'\A[a-zA-Z]+ size is normal')
target = re.compile('both lungs are clear and expanded')

# Distinct names (not `report`) so this cell does not overwrite the `report` tensor
# loaded in the Test samples cells.
found = [
    text
    for text in (idx_to_text(entry['tokens_idxs']) for entry in train_dataset.reports)
    if target.search(text)
]

len(found)

162

In [70]:
# Deduplicate while keeping first-seen (dataset) order. list(set(...)) has no stable
# order across runs (str hashing is randomized), so positional lookups like
# found_diff[5] would return a different report each time.
found_diff = list(dict.fromkeys(found))
len(found_diff)

19

In [76]:
# Inspect one distinct matching report (result depends on found_diff's ordering)
found_diff[5]

'chest . both lungs are clear and expanded with no pleural air collections or parenchymal consolidations . heart and mediastinum remain normal . lumbosacral spine . xxxx , disc spaces , and alignment are normal . sacrum and sacroiliac joints are normal . END'

## Debug metrics

In [8]:
from ignite.metrics import MetricsLambda

In [6]:
# Presumably defines the Bleu metric used below.
%run metrics/report_generation/bleu.py

In [9]:
# Single metric presumably yielding BLEU-1..BLEU-4 scores together (n=4)
bleu_up_to_4 = Bleu(n=4)

In [10]:
def _select_bleu_order(order_idx):
    """Wrap `bleu_up_to_4` so the derived metric reports only entry `order_idx`."""
    return MetricsLambda(lambda scores: scores[order_idx], bleu_up_to_4)


# Entries 0..3 of bleu_up_to_4 -> bleu1..bleu4, then their mean.
bleu1, bleu2, bleu3, bleu4 = (_select_bleu_order(i) for i in range(4))
bleuAvg = MetricsLambda(lambda scores: torch.mean(scores), bleu_up_to_4)

## Debug attention

In [15]:
from torch import nn

In [146]:
# Reload the attention decoder definition (picks up edits to the file).
%run models/report_generation/decoder_lstm_att.py

In [147]:
# Standalone attention decoder: vocab of 200, features shaped (2048, 16, 16).
# NOTE(review): this rebinds `decoder` (the flat LSTM decoder built earlier); use a
# distinct name if the main pipeline cells are re-run after this one.
decoder = LSTMAttDecoder(200, 100, 100, (2048, 16, 16))

In [148]:
# Random "feature maps" for a batch of 3 (rebinds `images` from the Test samples cells)
images = torch.randn(3, 2048, 16, 16).float()
images.size()

torch.Size([3, 2048, 16, 16])

In [149]:
# Free-running decode for 10 steps: word scores and per-step attention maps
out, scores = decoder(images, 10)
out.size(), scores.size()

(torch.Size([3, 10, 200]), torch.Size([3, 10, 16, 16]))

In [140]:
# NOTE(review): `att` and `h_state` are not defined anywhere in this notebook, and this
# cell (In [140]) predates the cells above it. It fails on a fresh kernel.
feats, scores = att(images, h_state)
feats.size(), scores.size()

(torch.Size([3, 2048]), torch.Size([3, 16, 16]))