## Imports

In [1]:
!echo $CUDA_VISIBLE_DEVICES

0,1,2,3


In [2]:
import torch
from torch import nn

In [3]:
%run -n train_report_generation.py

In [4]:
# Device used by every `.to(DEVICE)` call and loader below.
# Uncomment the CUDA line (and drop the CPU one) to run on a specific GPU instead.
# DEVICE = torch.device('cuda', 2)
DEVICE = torch.device('cpu')
DEVICE

device(type='cpu')

## Load data and model

### Load data

In [5]:
%run ./datasets/iu_xray.py

In [6]:
# Options shared by the train and validation splits.
dataset_kwargs = dict(max_samples=200, frontal_only=False)

train_dataset = IUXRayDataset(dataset_type='train', **dataset_kwargs)

# The validation split reuses the vocabulary built from the training split.
val_dataset = IUXRayDataset(
    dataset_type='val',
    vocab=train_dataset.get_vocab(),
    **dataset_kwargs,
)

train_dataset.size(), val_dataset.size()

((388, 200), (390, 200))

In [7]:
VOCAB_SIZE = len(train_dataset.word_to_idx)
VOCAB_SIZE

580

#### Create Flat dataloader

In [8]:
%run training/report_generation/flat.py

In [9]:
BS = 5

train_dataloader = create_flat_dataloader(train_dataset, batch_size=BS)
val_dataloader = create_flat_dataloader(val_dataset, batch_size=BS)
train_dataloader.dataset.size()

(388, 200)

#### ...or hierarchical dataloader

In [8]:
%run training/report_generation/hierarchical.py

In [9]:
BS = 5

train_dataloader = create_hierarchical_dataloader(train_dataset, batch_size=BS)
val_dataloader = create_hierarchical_dataloader(val_dataset, batch_size=BS)
train_dataloader.dataset.size()

(388, 200)

### Create CNN2Seq model

In [10]:
%run ./models/classification/__init__.py
%run ./models/report_generation/cnn_to_seq.py
%run ./models/checkpoint/__init__.py

#### Load CNN

In [11]:
cnn_run_name = '0706_134245_covid-kaggle_tfs-small_lr1e-06'
debug_run = True

compiled_cnn = load_compiled_model_classification(cnn_run_name,
                                                  debug=debug_run,
                                                  device=DEVICE)
cnn = compiled_cnn.model

#### ...or create CNN

In [11]:
cnn = init_empty_model('mobilenet', # resnet-50 # densenet-121
                       labels=[],
                       imagenet=True,
                       freeze=False,
                       ).to(DEVICE)

#### Create Flat LSTM decoder

In [26]:
%run ./models/report_generation/decoder_lstm.py

In [27]:
decoder = LSTMDecoder(VOCAB_SIZE, 100, 100, cnn.features_size,
                      teacher_forcing=True).to(DEVICE)

#### ...or with attention

In [88]:
%run ./models/report_generation/decoder_lstm_att.py

In [89]:
decoder_att = LSTMAttDecoder(VOCAB_SIZE, 100, 100, cnn.features_size,
                             teacher_forcing=True).to(DEVICE)

#### ...or hierarchical decoder

In [74]:
%run ./models/report_generation/decoder_h_lstm_att.py

In [75]:
decoder_h = HierarchicalLSTMAttDecoder(VOCAB_SIZE, 100, 100, cnn.features_size,
                                       teacher_forcing=True).to(DEVICE)

#### Full model

In [90]:
model = CNN2Seq(cnn, decoder_att).to(DEVICE)

In [91]:
# model = nn.DataParallel(model)

In [92]:
# Reference the optimizer through `torch.optim` so this cell only depends on the
# notebook's own `import torch`, not on an `optim` name leaked in by a `%run` script.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Bundle the model with its optimizer; `compiled_model` is what the training cell consumes.
compiled_model = CompiledModel(model, optimizer)

### ...or Load CNN2Seq model

In [5]:
%run models/checkpoint/__init__.py

In [6]:
run_name = '0714_181427_lstm_lr0.0001_densenet-121'
debug = True

compiled_model = load_compiled_model_report_generation(run_name, debug=debug, device=DEVICE)

compiled_model.metadata

{'cnn_kwargs': {'model_name': 'densenet-121',
  'labels': [],
  'imagenet': True,
  'freeze': False},
 'decoder_kwargs': {'decoder_name': 'lstm',
  'vocab_size': 443,
  'embedding_size': 100,
  'hidden_size': 100,
  'features_size': [1024, 16, 16],
  'teacher_forcing': True},
 'opt_kwargs': {'lr': 0.0001},
 'hparams': {'pretrained_cnn': None}}

## Train

In [45]:
%run -n train_report_generation.py

In [79]:
%%time

# Short debugging run: two epochs on whichever dataloaders/model were built above.
# NOTE(review): `hierarchical=True` presumably requires the hierarchical dataloader
# and a hierarchical decoder -- confirm it matches the alternative cells that were run.
# NOTE(review): `dryrun=True` / `save_model=False` presumably keep this run from
# persisting anything -- verify against `train_model`.
train_metrics, val_metrics = train_model('debugging',
                                         compiled_model,
                                         train_dataloader,
                                         val_dataloader,
                                         n_epochs=2,
                                         hierarchical=True,
                                         dryrun=True,
                                         save_model=False,
                                         debug=True,
                                         device=DEVICE)

Run:  debugging
--------------------------------------------------
Training...
Finished epoch 1/2 loss 6.9003 6.8527, bleu 0.0062 0.0008, 0h 2m 18s
Finished epoch 2/2 loss 6.5812 6.7018, bleu 0.0090 0.0000, 0h 2m 18s
Average time per epoch:  0h 2m 18s
--------------------------------------------------
CPU times: user 34min 42s, sys: 2min 5s, total: 36min 47s
Wall time: 4min 36s


In [34]:
val_metrics

{'loss': 4.245853958970024,
 'word_acc': 0.12073121735636802,
 'bleu1': 0.21052631578943184,
 'bleu2': 0.10222581386095834,
 'bleu3': 0.028250348198185463,
 'bleu4': 1.5015660752762683e-06,
 'bleu': 0.08525099485366272,
 'rougeL': 0.21206891447801543,
 'ciderD': 0.07031939644772105}

In [35]:
train_metrics

{'loss': 3.919986795364116,
 'word_acc': 0.14656188605108056,
 'bleu1': 0.17589893100093762,
 'bleu2': 0.04918162889166112,
 'bleu3': 7.98833985879885e-08,
 'bleu4': 1.0291031988067876e-10,
 'bleu': 0.05627015996972691,
 'rougeL': 0.16543961239792393,
 'ciderD': 0.044988202788620396}

## Debug flatten reports

In [21]:
import numpy as np

In [17]:
image_features = torch.rand(2, *cnn.features_size).to(DEVICE)
image_features.size()

torch.Size([2, 1024, 16, 16])

In [19]:
reports_h = torch.tensor([[[1, 2, 3, 0],
                           [1, 5, 0, 0],
                           [2, 2, 2, 0],
                          ],
                          [[7, 9, 10, 0],
                           [1, 4, 0, 0],
                           [8, 9, 0, 0],
                          ],
                         ]).to(DEVICE)
reports_h.size()

torch.Size([2, 3, 4])

In [37]:
def _flatten_gt_reports(reports):
    texts = []

    for report in reports:
        text = []
        for sentence in report:
            sentence = np.trim_zeros(sentence.detach().cpu().numpy())
            if len(sentence) > 0:
                text.extend(sentence)

        texts.append(torch.tensor(text))

    return pad_sequence(texts, batch_first=True)

In [38]:
_flatten_gt_reports(reports_h)

tensor([[ 1,  2,  3,  1,  5,  2,  2,  2],
        [ 7,  9, 10,  1,  4,  8,  9,  0]])

In [19]:
r, st, sc = decoder_h(image_features, 0, reports_h)
r.size(), st.size()

(torch.Size([2, 3, 4, 443]), torch.Size([2, 3]))

In [24]:
r2 = _flatten_h_reports(r, st)
r2.size()

torch.Size([2, 4])

In [75]:
threshold = 0.35

In [76]:
# Descending weights [n, n-1, ..., 1]: among the sentences whose stop score
# exceeds the threshold, the earliest one carries the largest weight, so
# argmax selects it.
n_sentences = st.size(1)
position_weights = torch.arange(n_sentences, 0, -1)
weighted_stops = (st.cpu() > threshold).long() * position_weights

# A row with no score above the threshold is all zeros, so a result of 0 is
# ambiguous between "stop at the first sentence" and "never stop".
indices = weighted_stops.argmax(dim=1, keepdim=True)
indices

tensor([[0],
        [0]])

In [81]:
indices.size()

torch.Size([2, 1])

In [79]:
r.size()

torch.Size([2, 3, 4, 443])

In [85]:
indices.view(-1)

tensor([0, 0])

In [91]:
_, r2 = r.max(dim=-1)
r2.size()

torch.Size([2, 3, 4])

In [92]:
for a in r2:
    break

In [93]:
a.size()

torch.Size([3, 4])

In [94]:
a

tensor([[206,   5,   6,  38],
        [206,   5,   6,  38],
        [206,   5,   6,  38]], device='cuda:1')

## Test samples

In [93]:
_ = compiled_model.model.eval()

In [94]:
%run utils/nlp.py

In [95]:
report_reader = ReportReader(train_dataset.get_vocab())

In [96]:
idx = 10

In [97]:
image, report = train_dataset[idx]
image.size(), len(report)

(torch.Size([3, 512, 512]), 12)

In [98]:
reports = torch.tensor(report).unsqueeze(0).to(DEVICE)
reports.size()

torch.Size([1, 12])

In [100]:
images = image.unsqueeze(0).to(DEVICE)

# Inference only: disable autograd so no graph is built for the generated words.
with torch.no_grad():
    tup = compiled_model.model(images, free=True, max_words=100)

# First element holds per-word scores over the vocabulary; keep the best word
# index at each position and drop the batch dimension.
generated = tup[0].argmax(dim=2).squeeze(0).cpu()

report_reader.idx_to_text(generated)

'some excluded tortuosity includes calcifications osteopenia loss changes have spine osteopenia borderline mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal irregularity significantly have mid free frontal frontal'

In [101]:
report_reader.idx_to_text(report)

'both lungs are clear and expanded . heart and mediastinum normal .'

### Search reports with a certain pattern

In [24]:
from tqdm.notebook import tqdm
import re

In [65]:
# target = re.compile(r'\A[a-zA-Z]+ size is normal')
target = re.compile('both lungs are clear and expanded')

# Dedicated loop names: reusing `report` here would overwrite the sample taken
# from `train_dataset[idx]`, which the "Test samples" cells still read.
found = []
for report_entry in train_dataset.reports:
    report_text = idx_to_text(report_entry['tokens_idxs'])
    if target.search(report_text):
        found.append(report_text)

len(found)

162

In [70]:
found_diff = list(set(found))
len(found_diff)

19

In [76]:
found_diff[5]

'chest . both lungs are clear and expanded with no pleural air collections or parenchymal consolidations . heart and mediastinum remain normal . lumbosacral spine . xxxx , disc spaces , and alignment are normal . sacrum and sacroiliac joints are normal . END'

## Debug metrics

In [8]:
from ignite.metrics import MetricsLambda

In [6]:
%run metrics/report_generation/bleu.py

In [9]:
bleu_up_to_4 = Bleu(n=4)

In [10]:
# Derive one metric per position of the shared `Bleu(n=4)` output (presumably
# BLEU-1 to BLEU-4 -- confirm against metrics/report_generation/bleu.py), plus
# their mean.
bleu1 = MetricsLambda(lambda scores: scores[0], bleu_up_to_4)
bleu2 = MetricsLambda(lambda scores: scores[1], bleu_up_to_4)
bleu3 = MetricsLambda(lambda scores: scores[2], bleu_up_to_4)
bleu4 = MetricsLambda(lambda scores: scores[3], bleu_up_to_4)
bleuAvg = MetricsLambda(torch.mean, bleu_up_to_4)

## Debug attention

In [15]:
from torch import nn

In [146]:
%run models/report_generation/decoder_lstm_att.py

In [147]:
decoder = LSTMAttDecoder(200, 100, 100, (2048, 16, 16))

In [148]:
images = torch.randn(3, 2048, 16, 16).float()
images.size()

torch.Size([3, 2048, 16, 16])

In [149]:
out, scores = decoder(images, 10)
out.size(), scores.size()

(torch.Size([3, 10, 200]), torch.Size([3, 10, 16, 16]))

In [140]:
feats, scores = att(images, h_state)
feats.size(), scores.size()

(torch.Size([3, 2048]), torch.Size([3, 16, 16]))

## Debug LSTM

In [1]:
import torch

In [4]:
from mrg.utils.nlp import END_IDX
END_IDX

1

In [10]:
prediction_t = torch.rand(5, 4)
is_end_predicted = prediction_t.argmax(dim=-1) == END_IDX
is_end_predicted

tensor([False, False, False, False,  True])

In [11]:
is_end_predicted | is_end_predicted

tensor([False, False, False, False,  True])

In [16]:
torch.tensor(True).repeat(5).all()

tensor(True)