## Imports

In [1]:
# Print the CUDA_VISIBLE_DEVICES environment variable (shell command run via IPython's `!`),
# i.e. which GPU ids are exposed to this kernel.
!echo $CUDA_VISIBLE_DEVICES

0,1


In [2]:
import torch
# from torch import nn
# import time

In [3]:
# Execute the training script inside this notebook's namespace to reuse its definitions.
# `-n` keeps __name__ from being set to '__main__', so any `if __name__ == '__main__'`
# section of the script is not executed.
# NOTE(review): presumably this is where `optim`, `CompiledModel` and `train_model`
# (used further down) come from — confirm against the script.
%run -n train_report_generation.py

In [4]:
# Device used for every model/tensor in this notebook.
# GPU index 1 is preferred; if fewer GPUs are visible the highest available
# index is used instead, and the CPU is the last resort, so the notebook still
# runs on machines that do not expose two GPUs.
_PREFERRED_GPU_INDEX = 1
_visible_gpus = torch.cuda.device_count()  # 0 when CUDA is unavailable

if _visible_gpus > 0:
    DEVICE = torch.device('cuda', min(_PREFERRED_GPU_INDEX, _visible_gpus - 1))
else:
    DEVICE = torch.device('cpu')
DEVICE

device(type='cuda', index=1)

## Load data and model

### Load data

In [5]:
# Run the IU X-ray dataset module so its definitions become available in this notebook.
# NOTE(review): presumably defines `IUXRayDataset` (and possibly `create_pad_dataloader`) — confirm.
%run ./datasets/iu_xray.py

In [6]:
# Value passed as `max_samples` to both the train and val dataset constructors.
max_samples = 20

In [7]:
# The training split is built first; the validation split is then given the
# training split's vocabulary through its `vocab` argument.
train_dataset = IUXRayDataset(
    max_samples=max_samples,
    dataset_type='train',
)
val_dataset = IUXRayDataset(
    vocab=train_dataset.get_vocab(),
    max_samples=max_samples,
    dataset_type='val',
)

# Display the size reported by each split.
(train_dataset.size(), val_dataset.size())

((42, 20), (37, 20))

In [8]:
# Batch size shared by both dataloaders.
BS = 5

# One padded dataloader per split, created in train -> val order.
train_dataloader, val_dataloader = (
    create_pad_dataloader(split, batch_size=BS)
    for split in (train_dataset, val_dataset)
)

# Sanity check: the loader wraps the training dataset.
train_dataloader.dataset.size()

(42, 20)

### Load model

In [9]:
# Run the project's model modules so their definitions land in the notebook namespace.
# NOTE(review): presumably these provide `init_empty_model`, `LSTMDecoder`, `CNN2Seq`
# and `load_compiled_model_classification` used in the cells below — confirm against the files.
%run ./models/classification/__init__.py
%run ./models/report_generation/decoder_lstm.py
%run ./models/report_generation/cnn_to_seq.py
%run ./models/checkpoint/__init__.py

#### Load CNN

In [11]:
# Name of the previously trained classification run to restore.
cnn_run_name = '0706_134245_covid-kaggle_tfs-small_lr1e-06'
# NOTE(review): presumably tells the loader to look among debug runs — confirm in the loader.
debug_run = True

# Restore the compiled classification model on DEVICE and keep only its network.
# Gotcha: the "Or create CNN" cell also assigns `cnn`; if both cells are run,
# whichever ran last wins.
compiled_cnn = load_compiled_model_classification(cnn_run_name,
                                                  debug=debug_run,
                                                  device=DEVICE)
cnn = compiled_cnn.model

#### Or create CNN (alternative to loading one — run only one of the two cells)

In [10]:
# Build a new 'resnet-50' model through init_empty_model and move it to DEVICE.
# NOTE(review): `imagenet=True` presumably loads ImageNet-pretrained weights,
# `freeze=False` presumably leaves them trainable, and `labels=[]` suggests no
# classification outputs are needed — confirm in init_empty_model.
# Assigns `cnn`, replacing the value from the "Load CNN" cell if that cell ran earlier.
cnn = init_empty_model('resnet-50',
                       labels=[],
                       imagenet=True,
                       freeze=False,
                       ).to(DEVICE)

#### Create Decoder

In [11]:
# LSTM decoder whose first argument is the training vocabulary size, placed on DEVICE.
# NOTE(review): the two 100s are presumably the embedding and hidden sizes — confirm
# against LSTMDecoder's signature.
decoder = LSTMDecoder(len(train_dataset.word_to_idx), 100, 100,
                      teacher_forcing=True).to(DEVICE)

#### Full model

In [12]:
# Combine the CNN and the decoder into one image-to-sequence model on DEVICE.
model = CNN2Seq(cnn, decoder).to(DEVICE)

# Adam over every parameter exposed by `model.parameters()`.
# NOTE(review): `optim` is not imported in any visible cell; presumably it comes from
# the `%run` of train_report_generation.py — confirm.
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Bundle model and optimizer; this object is what gets passed to train_model.
compiled_model = CompiledModel(model, optimizer)

## Train

In [13]:
%%time

# Short 2-epoch training run under the run name 'debugging'; returns the final
# train and validation metrics, which are inspected in the following cells.
# NOTE(review): `dryrun=True` presumably disables persisting checkpoints/logs — confirm
# in train_model.
train_metrics, val_metrics = train_model('debugging',
                                         compiled_model,
                                         train_dataloader,
                                         val_dataloader,
                                         n_epochs=2,
                                         dryrun=True,
                                         device=DEVICE)

--------------------------------------------------
Training...
Finished epoch 1/2 loss 5.4560 5.4668, word_acc 0.0022 0.0161, 0h 0m 5s
Finished epoch 2/2 loss 5.4228 5.4313, word_acc 0.0044 0.0000, 0h 0m 5s
Average time per epoch:  0h 0m 5s
--------------------------------------------------
CPU times: user 46.9 s, sys: 39.5 s, total: 1min 26s
Wall time: 11.8 s


In [15]:
# Validation metrics returned by train_model above.
val_metrics

{'loss': 5.431304320209503, 'word_acc': 0.0, 'bleu': 0.011711711785813284}

In [14]:
# Training metrics returned by train_model above.
train_metrics

{'loss': 5.42281898562622,
 'word_acc': 0.004357298474945534,
 'bleu': 0.01834862394682142}

## Test samples

In [19]:
# Invert the training vocabulary (word -> index) into an index -> word lookup.
idx_to_word = dict(
    (index, word) for word, index in train_dataset.get_vocab().items()
)

In [20]:
def idx_to_text(idxs, mapping=None):
    """Convert a sequence of token indices into a space-separated string.

    Args:
        idxs: iterable of token indices. Each element may be a plain int or a
            single-element tensor/array (anything ``int()`` accepts), so a 1-D
            tensor, a list of ints, or a list of 0-dim tensors all work.
        mapping: optional index -> word dict. When omitted, the notebook-level
            ``idx_to_word`` lookup is used.

    Returns:
        The words joined by single spaces ('' for an empty input).

    Raises:
        KeyError: if an index is not present in the mapping.
    """
    if mapping is None:
        # Resolved at call time so re-running the vocabulary cell is picked up.
        mapping = idx_to_word
    return ' '.join(mapping[int(idx)] for idx in idxs)

In [41]:
# Take the last training sample as the test case: an image and its tokenized report.
image, report = train_dataset[-1]

# Show the shapes of both.
(image.shape, report.shape)

(torch.Size([3, 512, 512]), torch.Size([79]))

In [42]:
# Generate a report for the selected sample.
# This is inference only, so the forward pass runs in eval mode and without
# autograd bookkeeping; the model's previous train/eval mode is restored afterwards.
images = image.unsqueeze(0).to(DEVICE)  # add a leading batch dimension of 1

_was_training = model.training
model.eval()
try:
    with torch.no_grad():
        # The model output unpacks to a single element; the second argument
        # requests as many steps as the reference report has tokens.
        scores, = model(images, report.size(0))
finally:
    model.train(_was_training)

# Greedy choice: index of the highest score along dim 2 at every step
# (assumes scores is (batch, steps, vocab) — TODO confirm), then drop the batch dim.
generated = scores.argmax(dim=2).squeeze(0).cpu()
print(generated.size())
print(generated)

idx_to_text(generated)

torch.Size([79])
tensor([76,  4,  7,  8, 12, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14])


'status the and mediastinum normal . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .'

In [23]:
# Text of the `report` token tensor (the sample's reference report), for comparison
# with the generated text.
idx_to_text(report)

'heart size and pulmonary vascularity appear within normal limits . retrocardiac soft tissue density is present . there appears to be air within this which could suggest that this represents a hiatal hernia . vascular calcification is noted . calcified granuloma is seen . there has been interval development of bandlike opacity in the left lung base . this may represent atelectasis . no pneumothorax or pleural effusion is seen . osteopenia is present in the spine . END'

In [24]:
from tqdm.notebook import tqdm

In [25]:
import re

In [26]:
# Example of the kind of opening sentence being searched for (illustrative only;
# the matching itself is done by the regex below).
target = 'the size is normal'

# Matches reports that START with "<one alphabetic word> size is normal".
size_normal_re = re.compile(r'\A[a-zA-Z]+ size is normal')

# Distinct names are used on purpose: the sample tensor `report` from the
# "Test samples" cells must not be overwritten, otherwise re-running those
# cells after this one fails.
report_texts = (
    idx_to_text(entry['tokens_idxs']) for entry in train_dataset.reports
)
found = [text for text in report_texts if size_normal_re.search(text)]

len(found)

0

In [27]:
# Matching report texts collected by the search above (empty when nothing matched).
found

[]