## Imports

In [1]:
import torch
from torch import nn
from torch import optim
from ignite.engine import Engine, Events
from ignite.handlers import Timer #, EarlyStopping

import time

## Functions

In [2]:
# Prefer the GPU, but fall back to the CPU so that every later `.to(DEVICE)`
# still works on a machine without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

### Dataloader

In [3]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [4]:
def create_dataloader(dataset, batch_size=128, shuffle=False):
    """Build a DataLoader that batches `(image, seq_out)` samples.

    Args:
        dataset: dataset whose items are `(image, seq_out)` pairs of tensors.
        batch_size: number of samples per batch.
        shuffle: forwarded to `DataLoader`.

    Returns:
        A `DataLoader` yielding `(images, batch_seq_out)` tuples: the images
        stacked along a new first dimension, and the sequences padded to a
        common length by `pad_sequence(..., batch_first=True)`.
    """
    def collate_fn(batch_tuples):
        # Split the list of pairs into two parallel lists.
        image_list = [image for image, _ in batch_tuples]
        seq_list = [seq_out for _, seq_out in batch_tuples]
        return (
            torch.stack(image_list),
            pad_sequence(seq_list, batch_first=True),
        )

    return DataLoader(
        dataset,
        batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )

### Step

In [5]:
def get_step_fn(model, optimizer=None, training=True, device=DEVICE):
    """Creates a step function for an Engine.

    Args:
        model: callable invoked as `model(images, max_sentence_len)`; the first
            element of its output is used as the per-word scores.
        optimizer: optimizer stepped after every batch. Required when
            `training` is True, unused otherwise.
        training: if True the step runs the model in train mode and updates
            its weights; if False it only evaluates.
        device: device the batch tensors are moved to.

    Returns:
        A function `step_fn(engine, data_batch)` that returns the tuple
        `(batch_loss, generated_words, reports)`.

    Raises:
        ValueError: if `training` is True and no optimizer was given.
    """
    if training and optimizer is None:
        # Fail here instead of with an AttributeError on the first batch.
        raise ValueError('An optimizer is required when training=True')

    loss_fn = nn.CrossEntropyLoss()

    def step_fn(engine, data_batch):
        # Images
        images = data_batch[0].to(device)
        # shape: batch_size, 3, height, width

        # Reports, as word ids
        reports = data_batch[1].to(device).long()
        _, max_sentence_len = reports.size()
        # shape: batch_size, max_sentence_len

        # Enable/disable training mode on the model
        model.train(training)

        # zero the parameter gradients
        if training:
            optimizer.zero_grad()

        # Record gradients only while training. The context manager restores
        # the previous grad mode on exit, so an evaluation step does not leave
        # autograd globally disabled for the rest of the session.
        with torch.set_grad_enabled(training):
            # Pass thru the model
            output_tuple = model(images, max_sentence_len)

            generated_words = output_tuple[0]
            _, _, vocab_size = generated_words.size()
            # 3-D scores; presumably (batch_size, max_sentence_len, vocab_size),
            # since they are flattened against `reports` below — TODO confirm

            # Compute classification loss over every word position
            # (reshape also accepts non-contiguous tensors, unlike view)
            loss = loss_fn(generated_words.reshape(-1, vocab_size),
                           reports.reshape(-1))

        batch_loss = loss.item()

        if training:
            loss.backward()
            optimizer.step()

        return batch_loss, generated_words, reports

    return step_fn

### Metrics

In [6]:
%run utils/__init__.py

In [7]:
from ignite.metrics import RunningAverage

In [8]:
%run metrics/report_generation/word_accuracy.py

In [9]:
def _transform_word_accuracy(outputs):
    _, generated_scores, seq = outputs
    _, words_predicted = generated_scores.max(dim=2)
    return words_predicted, seq

def attach_metrics_report_generation(engine):
    """Attach the report-generation metrics to `engine`.

    Registers a running average of the first element of the step output under
    the name 'loss', and a `WordAccuracy` metric under 'word_acc'.
    """
    metrics_by_name = {
        'loss': RunningAverage(output_transform=lambda step_output: step_output[0]),
        'word_acc': WordAccuracy(output_transform=_transform_word_accuracy),
    }
    for metric_name, metric in metrics_by_name.items():
        metric.attach(engine, metric_name)

### Train

In [10]:
def train_model(model, train_dataloader, val_dataloader, n_epochs=1, lr=0.0001):
    """Train `model` with Adam, validating after every epoch.

    Args:
        model: model to optimize; passed to `get_step_fn`.
        train_dataloader: batches used for the training steps.
        val_dataloader: batches used for the validation run after each epoch.
        n_epochs: number of training epochs.
        lr: learning rate for the Adam optimizer.

    Returns:
        Tuple `(trainer.state.metrics, validator.state.metrics)` as they stand
        after the last epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Create validator engine
    validator = Engine(get_step_fn(model, training=False))
    attach_metrics_report_generation(validator)

    # Create trainer engine
    trainer = Engine(get_step_fn(model, optimizer=optimizer, training=True))
    attach_metrics_report_generation(trainer)

    # Wall time of each finished epoch (training + validation), in seconds.
    # Durations are tracked by hand rather than with an ignite Timer attached
    # with start=EPOCH_STARTED: that timer is reset at every epoch, so its
    # "average" would only ever cover the last epoch.
    epoch_durations = []
    epoch_started_at = None

    @trainer.on(Events.EPOCH_STARTED)
    def mark_epoch_start(trainer):
        nonlocal epoch_started_at
        epoch_started_at = time.perf_counter()

    @trainer.on(Events.EPOCH_COMPLETED)
    def tb_write_metrics(trainer):
        epoch = trainer.state.epoch
        max_epochs = trainer.state.max_epochs

        # Run on evaluation
        validator.run(val_dataloader, 1)

        # Epoch duration includes the validation run above
        elapsed = time.perf_counter() - epoch_started_at
        epoch_durations.append(elapsed)

        train_loss = trainer.state.metrics.get('loss', 0)
        val_loss = validator.state.metrics.get('loss', 0)
        loss_str = f'loss {train_loss:.4f}, {val_loss:.4f}'

        train_acc = trainer.state.metrics.get('word_acc', 0)
        val_acc = validator.state.metrics.get('word_acc', 0)
        acc_str = f'acc {train_acc:.4f}, {val_acc:.4f}'

        duration_str = duration_to_str(elapsed)
        print(f'Finished epoch {epoch}/{max_epochs}, {loss_str}, {acc_str} (took {duration_str})')

    # Train!
    print('-' * 50)
    print('Training...')
    trainer.run(train_dataloader, n_epochs)

    # Average over all finished epochs (0 if no epoch was completed)
    if epoch_durations:
        secs_per_epoch = sum(epoch_durations) / len(epoch_durations)
    else:
        secs_per_epoch = 0.0
    duration_per_epoch = duration_to_str(secs_per_epoch)
    print('Average time per epoch: ', duration_per_epoch)
    print('-'*50)

    return trainer.state.metrics, validator.state.metrics

## Load datasets and build the model

In [11]:
%run ./datasets/iu_xray.py

In [12]:
# The validation split is built with the vocabulary taken from the training
# split (presumably so that word ids match across both splits — confirm in
# datasets/iu_xray.py).
train_dataset = IUXRayDataset(dataset_type='train')
val_dataset = IUXRayDataset(dataset_type='val', vocab=train_dataset.get_vocab())
train_dataset.size(), val_dataset.size()

((5927, 3062), (751, 382))

In [13]:
# Batch size shared by both dataloaders
BS = 10

# Shuffle the training split so batches are not drawn in the same dataset
# order every epoch; validation order does not matter and stays fixed.
train_dataloader = create_dataloader(train_dataset, batch_size=BS, shuffle=True)
val_dataloader = create_dataloader(val_dataset, batch_size=BS)
train_dataloader.dataset.size()

(5927, 3062)

In [14]:
%run ./datasets/cxr14.py
%run ./models/classification/__init__.py
%run ./models/report_generation/decoder_lstm.py
%run ./models/report_generation/cnn_to_seq.py

In [15]:
# Build the two sub-models and move them to DEVICE.
# NOTE(review): the two 100s are passed positionally to LSTMDecoder; presumably
# embedding and hidden sizes — confirm against models/report_generation/decoder_lstm.py.
cnn = init_empty_model('resnet', CXR14_DISEASES).to(DEVICE)
decoder = LSTMDecoder(len(train_dataset.word_to_idx), 100, 100).to(DEVICE)

In [16]:
# Wrap the CNN and the decoder into a single CNN2Seq model on DEVICE.
model = CNN2Seq(cnn, decoder).to(DEVICE)

## Run training

In [None]:
%%time

# Train for 10 epochs; train_model returns the trainer and validator metrics
# as they stand after the last epoch.
train_metrics, val_metrics = train_model(model, train_dataloader, val_dataloader,
                                         n_epochs=10, lr=0.0001)

--------------------------------------------------
Training...
Finished epoch 1/10, loss 4.9905, 4.9460, acc 0.1193, 0.1425 (took 0h 1m 26s)
Finished epoch 2/10, loss 4.5256, 4.6848, acc 0.1488, 0.1535 (took 0h 1m 27s)


## Test samples

In [18]:
# Invert the training vocabulary: keys and values are swapped so that an index
# can be looked up to get its word (assumes get_vocab() maps word -> index).
idx_to_word = {v: k for k, v in train_dataset.get_vocab().items()}

In [19]:
def idx_to_text(idxs):
    """Convert a sequence of word indices into a space-separated string.

    Each element of `idxs` must expose `.item()` (e.g. the entries of a 1-D
    tensor); its integer value is looked up in the global `idx_to_word`.
    """
    words = []
    for idx in idxs:
        word_id = int(idx.item())
        words.append(idx_to_word[word_id])
    return ' '.join(words)

In [22]:
# Take the first training sample: an image tensor and its report as word ids.
image, report = train_dataset[0]
image.size(), report.size()

(torch.Size([3, 512, 512]), torch.Size([5]))

In [25]:
# Add a batch dimension of size 1 and move the image to the model's device.
images = image.unsqueeze(0).to(DEVICE)

# Inference: put the model in eval mode and skip gradient tracking explicitly,
# instead of relying on whatever state the last training/validation step left.
model.eval()
with torch.no_grad():
    # The model output is unpacked as a one-element tuple holding the scores.
    generated_scores, = model(images, report.size()[0])

# Pick the highest-scoring word id at every position.
_, generated = generated_scores.max(dim=2)
generated = generated.squeeze(0).cpu()
print(generated.size())
print(generated)

idx_to_text(generated)

torch.Size([5])
tensor([119, 918, 941, 856, 288])


'an critical additional injuries defibrillator'

In [24]:
# Decode the reference report (word ids) of the sample into text.
idx_to_text(report)

'no active disease . END'

In [85]:
from tqdm.notebook import tqdm

In [94]:
import re

In [98]:
# NOTE: `target` is not used by the search below; the regex matches any single
# alphabetic word followed by ' size is normal' at the start of the report.
target = 'the size is normal'
size_is_normal_re = re.compile(r'\A[a-zA-Z]+ size is normal')

# Use dedicated loop names so the `report` tensor loaded from the sample cell
# is not overwritten (which would break re-running `idx_to_text(report)`).
found = []
for report_entry in train_dataset.reports:
    report_text = idx_to_text(report_entry['tokens_idxs'])
    if size_is_normal_re.search(report_text):
        found.append(report_text)

len(found)

213

In [99]:
# Display every matching report (the whole list, not a sample).
found

['heart size is normal and lungs are clear END',
 'heart size is normal and lungs are clear END',
 'heart size is normal the lungs are clear END',
 'heart size is normal the lungs are clear END',
 'heart size is normal and lungs are clear . END',
 'heart size is normal and lungs are clear . END',
 'heart size is normal and lungs are clear . END',
 'heart size is normal and lungs are clear . END',
 'heart size is normal and lungs are clear . END',
 'heart size is normal and lungs are clear . END',
 'heart size is normal and lungs are clear . END',
 'heart size is normal and lungs are clear . END',
 'heart size is normal and lungs are clear . END',
 'heart size is normal and the lungs are clear . END',
 'heart size is normal and the lungs are clear . END',
 'heart size is normal and the lungs are clear . END',
 'heart size is normal and the lungs are clear . END',
 'heart size is normal in the lungs are clear . END',
 'heart size is normal and the lungs are clear . END',
 'heart size is 