## Imports

In [1]:
!echo $CUDA_VISIBLE_DEVICES

0,1


In [2]:
import torch
from torch import nn
from torch import optim
from ignite.engine import Engine, Events
from ignite.handlers import Timer #, EarlyStopping

import time

## Functions

In [2]:
!echo ${CUDA_VISIBLE_DEVICES}

2,3


In [3]:
DEVICE = torch.device('cuda:1')
DEVICE

device(type='cuda', index=1)

### Dataloader

In [4]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [5]:
def create_dataloader(dataset, batch_size=128, shuffle=False):
    def collate_fn(batch_tuples):
        images = []
        batch_seq_out = []
        for image, seq_out in batch_tuples:
            images.append(image)
            batch_seq_out.append(seq_out)

        images = torch.stack(images)
        batch_seq_out = pad_sequence(batch_seq_out, batch_first=True)
        return images, batch_seq_out

    dataloader = DataLoader(dataset, batch_size, collate_fn=collate_fn,
                            shuffle=shuffle)
    return dataloader

### Step

In [6]:
def get_step_fn(model, optimizer=None, training=True, device=DEVICE):
    """Creates a step function for an Engine."""
    loss_fn = nn.CrossEntropyLoss()
    def step_fn(engine, data_batch):
        # Images
        images = data_batch[0].to(device)
        # shape: batch_size, 3, height, width

        # Reports, as word ids
        reports = data_batch[1].to(device).long() # shape: batch_size, max_sentence_len
        _, max_sentence_len = reports.size()
        
        # Enable training
        model.train(training)
        torch.set_grad_enabled(training) # enable recording gradients

        # zero the parameter gradients
        if training:
            optimizer.zero_grad()

        # Pass thru the model
        output_tuple = model(images, max_sentence_len, reports)

        generated_words = output_tuple[0]
        _, _, vocab_size = generated_words.size()
        # shape: batch_size, n_sentences, vocab_size

        # Compute classification loss
        loss = loss_fn(generated_words.view(-1, vocab_size), reports.view(-1))
        
        batch_loss = loss.item()

        if training:
            loss.backward()
            optimizer.step()

        return batch_loss, generated_words, reports

    return step_fn

### Metrics

In [7]:
%run utils/__init__.py

In [8]:
from ignite.metrics import RunningAverage

In [9]:
%run metrics/report_generation/word_accuracy.py

In [10]:
def _transform_word_accuracy(outputs):
    _, generated_scores, seq = outputs
    _, words_predicted = generated_scores.max(dim=2)
    return words_predicted, seq

def attach_metrics_report_generation(engine):
    loss = RunningAverage(output_transform=lambda x: x[0])
    loss.attach(engine, 'loss')
    
    word_acc = WordAccuracy(output_transform=_transform_word_accuracy)
    word_acc.attach(engine, 'word_acc')

### Train

In [11]:
%run utils/runs.py
%run tensorboard.py

In [12]:
def train_model(run_name, model, train_dataloader, val_dataloader, n_epochs=1, lr=0.0001,
                debug=True):
    # Prepare run
    run_state = RunState(run_name, classification=False, debug=debug)
    initial_epoch = run_state.current_epoch()
    if initial_epoch > 0:
        print('Found previous run on epoch: ', initial_epoch)

    # TB writer
    tb_writer = TBWriter(run_name, classification=False, debug=debug)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Create validator engine
    validator = Engine(get_step_fn(model, training=False))
    attach_metrics_report_generation(validator)
    
    # Create trainer engine
    trainer = Engine(get_step_fn(model, optimizer=optimizer, training=True))
    attach_metrics_report_generation(trainer)
    
    # Create Timer to measure wall time between epochs
    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def tb_write_metrics(trainer):
        # State
        epoch = trainer.state.epoch + initial_epoch
        max_epochs = trainer.state.max_epochs + initial_epoch
        run_state.save_state(epoch)
        
        # Run on evaluation
        validator.run(val_dataloader, 1)

        # Common time
        wall_time = time.time()

        # Metrics
        train_metrics = trainer.state.metrics
        val_metrics = validator.state.metrics
        
        # Log to TB
        tb_writer.write_metrics(train_metrics, 'train', epoch, wall_time)
        tb_writer.write_metrics(val_metrics, 'val', epoch, wall_time)
        tb_writer.write_histogram(model, epoch, wall_time)
        
        # Print metrics
        train_loss = train_metrics.get('loss', -1)
        val_loss = val_metrics.get('loss', -1)
        loss_str = f'loss {train_loss:.4f}, {val_loss:.4f}'
        
        train_acc = train_metrics.get('word_acc', -1)
        val_acc = val_metrics.get('word_acc', -1)
        acc_str = f'acc {train_acc:.4f}, {val_acc:.4f}'
        
        duration_str = duration_to_str(timer._elapsed())
        print(f'Finished epoch {epoch}/{max_epochs}, {loss_str}, {acc_str} (took {duration_str})')

    # Train!
    print('-' * 50)
    print('Training...')
    trainer.run(train_dataloader, n_epochs)

    # Capture time
    secs_per_epoch = timer.value()
    duration_per_epoch = duration_to_str(secs_per_epoch)
    print('Average time per epoch: ', duration_per_epoch)
    print('-'*50)
    
    # Close stuff
    tb_writer.close()

    return

## Load stuff

### Load data

In [13]:
%run ./datasets/iu_xray.py

In [14]:
train_dataset = IUXRayDataset(dataset_type='train')
val_dataset = IUXRayDataset(dataset_type='val', vocab=train_dataset.get_vocab())
train_dataset.size(), val_dataset.size()

((5927, 3062), (751, 382))

In [15]:
BS = 15

train_dataloader = create_dataloader(train_dataset, batch_size=BS)
val_dataloader = create_dataloader(val_dataset, batch_size=BS)
train_dataloader.dataset.size()

(5927, 3062)

### Load model

In [16]:
from ignite.handlers import Checkpoint

In [17]:
%run ./datasets/cxr14.py
%run ./models/classification/__init__.py
%run ./models/report_generation/decoder_lstm.py
%run ./models/report_generation/cnn_to_seq.py
%run ./models/checkpoint.py

#### Load CNN

In [18]:
# TODO: create load_classification_model() function

run_name = '0627_151524_cxr14_resnet_lr1e-06'
debug_run = True

pretrained_cnn = init_empty_model('resnet',
                                  CXR14_DISEASES,
                                  multilabel=True,
                                 ).to(DEVICE)

dummy_optimizer = optim.Adam(pretrained_cnn.parameters(), lr=0.0001)

compiled_model = CompiledModel(pretrained_cnn, dummy_optimizer)
filepath = get_latest_filepath(run_name, classification=True, debug=debug_run)
checkpoint = torch.load(filepath)
Checkpoint.load_objects(compiled_model.to_save_checkpoint(), checkpoint)

In [19]:
cnn = pretrained_cnn

#### Create Decoder

In [20]:
decoder = LSTMDecoder(len(train_dataset.word_to_idx), 100, 100,
                      teacher_forcing=True).to(DEVICE)

#### Full model

In [21]:
model = CNN2Seq(cnn, decoder).to(DEVICE)

## Train

In [22]:
%%time

train_model('tf-new', model, train_dataloader, val_dataloader, n_epochs=10, lr=0.0001)

--------------------------------------------------
Training...
Finished epoch 1/10, loss 6.3861, 5.0906, acc 0.1129, 0.1393 (took 0h 1m 31s)
Finished epoch 2/10, loss 4.4598, 4.4035, acc 0.1440, 0.1393 (took 0h 1m 32s)
Finished epoch 3/10, loss 4.2623, 4.3845, acc 0.1457, 0.1037 (took 0h 1m 32s)
Finished epoch 4/10, loss 4.0759, 4.3604, acc 0.1989, 0.0923 (took 0h 1m 32s)
Finished epoch 5/10, loss 3.8773, 4.3706, acc 0.2625, 0.1014 (took 0h 1m 31s)
Finished epoch 6/10, loss 3.6922, 4.3984, acc 0.3170, 0.0986 (took 0h 1m 32s)
Finished epoch 7/10, loss 3.5309, 4.5021, acc 0.3433, 0.0907 (took 0h 1m 32s)
Finished epoch 8/10, loss 3.3881, 4.5830, acc 0.3590, 0.0907 (took 0h 1m 32s)
Finished epoch 9/10, loss 3.2619, 4.5946, acc 0.3722, 0.1074 (took 0h 1m 32s)
Finished epoch 10/10, loss 3.1505, 4.6772, acc 0.3903, 0.1074 (took 0h 1m 32s)
Average time per epoch:  0h 1m 32s
--------------------------------------------------
CPU times: user 40min 12s, sys: 7min 48s, total: 48min 1s
Wall time: 1

## Test samples

In [19]:
idx_to_word = {v: k for k, v in train_dataset.get_vocab().items()}
# idx_to_word

In [20]:
def idx_to_text(idxs):
    return ' '.join([idx_to_word[int(g.item())] for g in idxs])

In [41]:
image, report = train_dataset[-1]
image.size(), report.size()

(torch.Size([3, 512, 512]), torch.Size([79]))

In [42]:
images = image.unsqueeze(0).to(DEVICE)
generated, = model(images, report.size()[0])
_, generated = generated.max(dim=2)
generated = generated.squeeze(0).cpu()
print(generated.size())
print(generated)

idx_to_text(generated)

torch.Size([79])
tensor([76,  4,  7,  8, 12, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14])


'status the and mediastinum normal . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .'

In [23]:
idx_to_text(report)

'heart size and pulmonary vascularity appear within normal limits . retrocardiac soft tissue density is present . there appears to be air within this which could suggest that this represents a hiatal hernia . vascular calcification is noted . calcified granuloma is seen . there has been interval development of bandlike opacity in the left lung base . this may represent atelectasis . no pneumothorax or pleural effusion is seen . osteopenia is present in the spine . END'

In [24]:
from tqdm.notebook import tqdm

In [25]:
import re

In [26]:
target = 'the size is normal'
found = []

for report in train_dataset.reports:
    report = idx_to_text(report['tokens_idxs'])
    if re.search(r'\A[a-zA-Z]+ size is normal', report):
        found.append(report)

len(found)

0

In [27]:
found

[]