## Imports

In [1]:
# Show which GPUs this kernel can see; nn.DataParallel (used below) spreads work
# across all visible CUDA devices by default.
!echo $CUDA_VISIBLE_DEVICES

2,3


In [2]:
import torch
from torch import nn

In [3]:
# Load the training helpers into this namespace. `-n` runs the script without
# setting __name__ to "__main__", so its `if __name__ == "__main__":` block is skipped.
# Presumably provides train_model, create_pad_dataloader, CompiledModel, ... — verify.
%run -n train_report_generation.py

In [4]:
# Use the GPU when available; fall back to CPU so `.to(DEVICE)` calls below
# still work on a machine without CUDA (training will just be much slower).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

## Load data and model

### Load data

In [5]:
%run ./datasets/iu_xray.py

In [6]:
# Cap on samples per split, passed to IUXRayDataset. Presumably None means
# "no cap" — set a small int for quick debugging runs (TODO confirm).
max_samples = None

In [7]:
# The validation split reuses the training vocabulary so token indices mean the
# same thing in both splits.
train_dataset = IUXRayDataset(dataset_type='train', max_samples=max_samples)
val_dataset = IUXRayDataset(dataset_type='val', max_samples=max_samples,
                            vocab=train_dataset.get_vocab())
train_dataset.size(), val_dataset.size()

((5927, 3062), (751, 382))

In [8]:
# Batch size shared by the train and validation loaders.
BS = 20

# create_pad_dataloader is presumably provided by one of the %run scripts above
# (the name suggests it pads variable-length reports within a batch — verify).
train_dataloader = create_pad_dataloader(train_dataset, batch_size=BS)
val_dataloader = create_pad_dataloader(val_dataset, batch_size=BS)
train_dataloader.dataset.size()

(5927, 3062)

### Load model

In [9]:
%run ./models/classification/__init__.py
%run ./models/report_generation/decoder_lstm.py
%run ./models/report_generation/decoder_lstm_att.py
%run ./models/report_generation/cnn_to_seq.py
%run ./models/checkpoint/__init__.py

#### Load CNN

In [10]:
# Option A: reuse the network from a previously trained classification run.
cnn_run_name = '0706_134245_covid-kaggle_tfs-small_lr1e-06'
debug_run = True  # presumably selects the debug runs folder — TODO confirm

compiled_cnn = load_compiled_model_classification(cnn_run_name,
                                                  debug=debug_run,
                                                  device=DEVICE)
cnn = compiled_cnn.model

#### Or create a new CNN (run instead of the cell above)

In [10]:
# Option B: build a fresh DenseNet-121 backbone instead. This cell has the same
# execution count as the one above — run only one of them; whichever runs last
# determines `cnn`.
cnn = init_empty_model('densenet-121', # resnet-50
                       labels=[],  # presumably no classifier labels: backbone only — verify
                       imagenet=True,
                       freeze=False,  # presumably keeps backbone weights trainable — verify
                       ).to(DEVICE)

#### Create Decoder

In [11]:
# Plain LSTM decoder. Args presumably: vocab size, embedding size, hidden size,
# CNN feature size — TODO confirm order.
# NOTE(review): the CNN2Seq model below uses `decoder_att`, not this `decoder`,
# and `decoder` is rebound in the attention-debug section; consider removing.
decoder = LSTMDecoder(len(train_dataset.word_to_idx), 100, 100, cnn.features_size,
                      teacher_forcing=True).to(DEVICE)

In [11]:
# LSTM decoder with attention over the CNN feature map; same size arguments as
# the plain decoder above. This is the decoder wired into the model below.
decoder_att = LSTMAttDecoder(len(train_dataset.word_to_idx), 100, 100, cnn.features_size,
                             teacher_forcing=True).to(DEVICE)

#### Full model

In [12]:
# Encoder-decoder: CNN image encoder feeding the attention LSTM decoder.
model = CNN2Seq(cnn, decoder_att).to(DEVICE)

In [13]:
# Wrap for multi-GPU execution. Guarded so that re-running this cell does not
# nest the wrapper (DataParallel(DataParallel(...))).
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)

In [14]:
# Reference torch.optim explicitly: `optim` is never imported in this notebook
# and would only exist if a %run script happened to leak it into the namespace.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

compiled_model = CompiledModel(model, optimizer)

## Train

In [None]:
%%time

# NOTE: at ~75 s/epoch (see the log below), 100 epochs take roughly two hours.
# With save_model=False the trained weights are presumably not persisted — verify.
train_metrics, val_metrics = train_model('lstm_att',
                                         compiled_model,
                                         train_dataloader,
                                         val_dataloader,
                                         n_epochs=100,
                                         dryrun=False,
                                         save_model=False,
                                         debug=True,
                                         device=DEVICE)

--------------------------------------------------
Training...
Finished epoch 1/100 loss 6.0208 5.4252, bleu 0.0368 0.0351, 0h 1m 16s
Finished epoch 2/100 loss 4.7187 4.7282, bleu 0.0370 0.0351, 0h 1m 14s
Finished epoch 3/100 loss 4.4436 4.6362, bleu 0.0370 0.0351, 0h 1m 14s
Finished epoch 4/100 loss 4.2943 4.6106, bleu 0.0388 0.0700, 0h 1m 14s


In [15]:
# NOTE(review): the training cell above is recorded as `In [None]` (unfinished),
# so the metrics shown here may come from an earlier run — re-run after training.
val_metrics

{'loss': 5.431304320209503, 'word_acc': 0.0, 'bleu': 0.011711711785813284}

In [14]:
# NOTE(review): like val_metrics, this output may predate the training log above.
train_metrics

{'loss': 5.42281898562622,
 'word_acc': 0.004357298474945534,
 'bleu': 0.01834862394682142}

## Test samples

In [19]:
# Invert the training vocabulary (word -> index) into index -> word for decoding.
idx_to_word = {idx: word for word, idx in train_dataset.get_vocab().items()}
# Inverting silently drops words that share an index; make sure none did.
assert len(idx_to_word) == len(train_dataset.get_vocab()), 'vocabulary indices are not unique'

In [20]:
def idx_to_text(idxs, vocab=None):
    """Decode a sequence of token indices into a space-separated string.

    Args:
        idxs: iterable of token indices — a 1-D tensor, a list of ints, or any
            elements that convert with ``int()``.
        vocab: optional index -> word mapping; defaults to the notebook-level
            ``idx_to_word`` built from the training vocabulary.

    Returns:
        The decoded words joined by single spaces ('' for empty input).

    Raises:
        KeyError: if an index is missing from the vocabulary.
    """
    if vocab is None:
        vocab = idx_to_word
    # int() (rather than .item()) accepts 0-d tensors, numpy scalars and plain ints.
    return ' '.join(vocab[int(idx)] for idx in idxs)

In [41]:
# Take the last training example: an image tensor and its tokenised report.
image, report = train_dataset[-1]
image.size(), report.size()

(torch.Size([3, 512, 512]), torch.Size([79]))

In [42]:
# Greedy decoding of the sampled image, for as many steps as the reference
# report has tokens. Run in eval mode without autograd so BatchNorm uses (and
# does not update) its running statistics; restore the previous mode afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        images = image.unsqueeze(0).to(DEVICE)  # add a batch dimension of 1
        outputs = model(images, report.size()[0])
finally:
    model.train(was_training)

# The first output holds the per-step vocabulary scores, presumably
# (batch, steps, vocab); an attention decoder may return attention maps as well,
# so index instead of unpacking a 1-tuple. argmax also avoids clobbering IPython's `_`.
generated = outputs[0].argmax(dim=2).squeeze(0).cpu()
print(generated.size())
print(generated)

idx_to_text(generated)

torch.Size([79])
tensor([76,  4,  7,  8, 12, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14])


'status the and mediastinum normal . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .'

In [23]:
# Ground-truth report for the same sample, to compare with the generated text.
idx_to_text(report)

'heart size and pulmonary vascularity appear within normal limits . retrocardiac soft tissue density is present . there appears to be air within this which could suggest that this represents a hiatal hernia . vascular calcification is noted . calcified granuloma is seen . there has been interval development of bandlike opacity in the left lung base . this may represent atelectasis . no pneumothorax or pleural effusion is seen . osteopenia is present in the spine . END'

In [24]:
from tqdm.notebook import tqdm

In [25]:
import re

In [26]:
# Collect training reports that start with "<word> size is normal".
SIZE_IS_NORMAL_RE = re.compile(r'\A[a-zA-Z]+ size is normal')
found = []

# Distinct loop names so this cell does not overwrite `report`, the sample
# tensor used by the generation cells above.
for report_entry in train_dataset.reports:
    report_text = idx_to_text(report_entry['tokens_idxs'])
    if SIZE_IS_NORMAL_RE.search(report_text):
        found.append(report_text)

len(found)

0

In [27]:
# Matching reports (empty in the recorded run).
found

[]

## Debug metrics

In [8]:
from ignite.metrics import MetricsLambda

In [6]:
%run metrics/report_generation/bleu.py

In [9]:
# Single BLEU metric up to 4-grams; presumably its result is the sequence of
# BLEU-1..4 scores that the next cell indexes — confirm against Bleu.compute().
bleu_up_to_4 = Bleu(n=4)

In [10]:
# Expose each BLEU-n score of bleu_up_to_4 as its own ignite metric
# (presumably x = (BLEU-1, BLEU-2, BLEU-3, BLEU-4) — verify).
bleu1 = MetricsLambda(lambda x: x[0], bleu_up_to_4)
bleu2 = MetricsLambda(lambda x: x[1], bleu_up_to_4)
bleu3 = MetricsLambda(lambda x: x[2], bleu_up_to_4)
bleu4 = MetricsLambda(lambda x: x[3], bleu_up_to_4)
# torch.as_tensor makes the average work whether Bleu returns a tensor or a
# plain list/tuple of floats (torch.mean alone rejects non-tensors).
bleuAvg = MetricsLambda(lambda x: torch.as_tensor(x).mean(), bleu_up_to_4)

## Debug attention

In [15]:
from torch import nn

In [146]:
%run models/report_generation/decoder_lstm_att.py

In [147]:
# Standalone attention decoder for shape debugging: vocab size 200 (matches the
# output below), presumably embedding/hidden size 100, feature maps (2048, 16, 16).
# NOTE: rebinds `decoder`, replacing the LSTMDecoder created earlier.
decoder = LSTMAttDecoder(200, 100, 100, (2048, 16, 16))

In [148]:
# Random stand-in for a batch of 3 CNN feature maps (2048 channels, 16x16 grid).
# NOTE: reuses the name `images`, overwriting the real image batch from above.
images = torch.randn(3, 2048, 16, 16, dtype=torch.float)
images.size()

torch.Size([3, 2048, 16, 16])

In [149]:
# Decode 10 steps; the output below shows per-step vocab scores (3, 10, 200)
# and per-step attention maps over the 16x16 grid (3, 10, 16, 16).
out, scores = decoder(images, 10)
out.size(), scores.size()

(torch.Size([3, 10, 200]), torch.Size([3, 10, 16, 16]))

In [140]:
# FIXME: `att` and `h_state` are not defined anywhere in this notebook; this cell
# (execution count 140, lower than the cells above) only ran thanks to leftover
# kernel state. Define them explicitly first — presumably the decoder's attention
# module and an LSTM hidden state — or delete the cell.
feats, scores = att(images, h_state)
feats.size(), scores.size()

(torch.Size([3, 2048]), torch.Size([3, 16, 16]))