## Imports

In [1]:
# Shell escape: print the CUDA_VISIBLE_DEVICES environment variable.
!echo $CUDA_VISIBLE_DEVICES

0,1


In [2]:
import torch
from torch import nn
from torch import optim
from ignite.engine import Engine, Events
from ignite.handlers import Timer #, EarlyStopping

import time

In [3]:
# Device handle used by every `.to(DEVICE)` / `device=DEVICE` call below:
# the CUDA device with index 1.
DEVICE = torch.device('cuda:1')
DEVICE

device(type='cuda', index=1)

## Load data and model

### Load data

In [14]:
# Execute the dataset script and load the names it defines into this notebook.
# NOTE(review): presumably the source of IUXRayDataset and create_pad_dataloader
# used in the next cells — confirm.
%run ./datasets/iu_xray.py

In [15]:
# Build both splits; the validation split is handed the vocabulary returned
# by the training split's `get_vocab()`.
train_dataset = IUXRayDataset(dataset_type='train')
val_dataset = IUXRayDataset(
    dataset_type='val',
    vocab=train_dataset.get_vocab(),
)

# Display what each split reports as its size.
(train_dataset.size(), val_dataset.size())

((5927, 3062), (751, 382))

In [16]:
# Batch size shared by the train and validation loaders.
BS = 15

train_dataloader = create_pad_dataloader(
    train_dataset,
    batch_size=BS,
)
val_dataloader = create_pad_dataloader(
    val_dataset,
    batch_size=BS,
)

# Sanity check: size of the dataset wrapped by the training loader.
train_dataloader.dataset.size()

(5927, 3062)

### Load model

In [18]:
# Execute the project scripts below and load the names they define into this notebook.
# NOTE(review): presumably these provide load_compiled_model_classification,
# init_empty_model, LSTMDecoder, CNN2Seq and CompiledModel used in the following
# cells — confirm which script defines which name.
%run ./datasets/__init__.py
%run ./models/classification/__init__.py
%run ./models/report_generation/decoder_lstm.py
%run ./models/report_generation/cnn_to_seq.py
%run ./models/checkpoint/__init__.py

#### Load CNN

In [22]:
# Run name and debug flag handed to load_compiled_model_classification.
cnn_run_name = '0627_151lr=lr524_cxr14_resnet_lr1e-06'
debug_run = True

compiled_cnn = load_compiled_model_classification(
    cnn_run_name,
    debug=debug_run,
    device=DEVICE,
)

# Only the wrapped model is needed from the loaded object.
cnn = compiled_cnn.model

#### Or create a new CNN (alternative to loading one above)

In [None]:
# Alternative way to obtain `cnn`. The "Load CNN" cell assigns the same name,
# so whichever of the two cells runs last decides which CNN is used.
cnn = init_empty_model(
    'resnet',
    labels=[],
    imagenet=True,
    freeze=False,
).to(DEVICE)

#### Create Decoder

In [23]:
# The first positional argument is the number of entries in the training
# split's `word_to_idx` mapping.
# NOTE(review): the two 100s are presumably embedding and hidden sizes —
# confirm against LSTMDecoder's signature.
decoder = LSTMDecoder(
    len(train_dataset.word_to_idx),
    100,
    100,
    teacher_forcing=True,
).to(DEVICE)

#### Full model

In [24]:
# Combine the CNN and the decoder into one model on DEVICE, then pair it with
# an Adam optimizer over all of its parameters.
model = CNN2Seq(cnn, decoder).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
compiled_model = CompiledModel(model, optimizer)

## Train

In [1]:
# Execute the training script. `-n` keeps `__name__` from being set to
# '__main__', so any `if __name__ == '__main__':` block in it is skipped.
# NOTE(review): presumably the source of train_model used in the next cell — confirm.
%run -n train_report_generation.py

In [48]:
%%time

train_model('debugging-2', compiled_model, train_dataloader, val_dataloader,
            n_epochs=2, device=DEVICE)

--------------------------------------------------
Training...
Finished epoch 10/11 loss 7.4644 7.4678, word_acc 0.0214 0.0116, 0h 1m 0s
Finished epoch 11/11 loss 7.4637 7.4666, word_acc 0.0219 0.0119, 0h 1m 1s
Average time per epoch:  0h 1m 1s
--------------------------------------------------
CPU times: user 7min 11s, sys: 35.3 s, total: 7min 46s
Wall time: 2min 1s


## Test samples

In [19]:
# Invert the training vocabulary: each value of `get_vocab()` becomes a key
# that maps back to its original key.
idx_to_word = dict(
    (index, word) for word, index in train_dataset.get_vocab().items()
)

In [20]:
def idx_to_text(idxs, mapping=None):
    """Convert a sequence of token indices into a space-separated string.

    Args:
        idxs: iterable of token indices. Elements may be 0-d tensors (as
            produced by iterating over a 1-D tensor) or plain Python ints.
        mapping: optional index -> word lookup. When omitted, the
            notebook-level `idx_to_word` dictionary is used.

    Returns:
        The looked-up words joined by single spaces ('' for empty input).

    Raises:
        KeyError: if an index is not present in the lookup.
    """
    if mapping is None:
        mapping = idx_to_word
    # int() handles 0-d tensors as well as plain ints, whereas `.item()`
    # only exists on tensor-like elements.
    return ' '.join(mapping[int(idx)] for idx in idxs)

In [41]:
# Take the last training sample and display the shapes of its two tensors.
image, report = train_dataset[-1]
(image.shape, report.shape)

(torch.Size([3, 512, 512]), torch.Size([79]))

In [42]:
# Generate a report for the sampled image, asking for as many tokens as the
# sample's own report contains.
images = image.unsqueeze(0).to(DEVICE)  # add a leading batch dimension of 1

# Inference only: run in eval mode and without autograd bookkeeping, then
# put the model back into whatever mode it was in before this cell.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        generated, = model(images, report.size()[0])
finally:
    model.train(was_training)

# Pick the highest-scoring index along dim 2, then drop the batch dimension.
_, generated = generated.max(dim=2)
generated = generated.squeeze(0).cpu()
print(generated.size())
print(generated)

idx_to_text(generated)

torch.Size([79])
tensor([76,  4,  7,  8, 12, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14])


'status the and mediastinum normal . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .'

In [23]:
# Decode the sampled item's own report tokens, for comparison with the generated text.
idx_to_text(report)

'heart size and pulmonary vascularity appear within normal limits . retrocardiac soft tissue density is present . there appears to be air within this which could suggest that this represents a hiatal hernia . vascular calcification is noted . calcified granuloma is seen . there has been interval development of bandlike opacity in the left lung base . this may represent atelectasis . no pneumothorax or pleural effusion is seen . osteopenia is present in the spine . END'

In [24]:
from tqdm.notebook import tqdm

In [25]:
import re

In [26]:
# Example of a phrase the pattern below matches; the search does not use it.
target = 'the size is normal'

# Matches texts that start with "<letters> size is normal".
size_is_normal = re.compile(r'\A[a-zA-Z]+ size is normal')

# Decode every training report and keep the texts that match. Dedicated
# names are used so the `report` tensor from the sampling cell is not
# overwritten, which would break re-running the generation cell.
decoded_reports = (
    idx_to_text(entry['tokens_idxs']) for entry in train_dataset.reports
)
found = [text for text in decoded_reports if size_is_normal.search(text)]

len(found)

0

In [27]:
# Display the matching report texts collected by the search cell.
found

[]