## Imports

In [1]:
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Timer #, EarlyStopping
from torch import optim

import time

In [2]:
%run ./datasets/__init__.py

In [3]:
%run ./utils/__init__.py
%run ./utils/runs.py

In [4]:
%run ./models/classification/__init__.py

In [5]:
%run ./losses/__init__.py

In [6]:
%run ./tensorboard.py

## Functions

In [7]:
# Device used for the model and every batch.
# Prefer the second GPU (index 1) when the host has one; otherwise fall back to
# the first GPU, and finally to the CPU, so the notebook runs on any machine.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device('cuda:1')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
DEVICE

device(type='cuda', index=1)

In [19]:
def get_step_fn(model, optimizer, loss_fn, training=True, device=DEVICE):
    """Creates a step function for an Engine.

    Args:
        model: network called as `model(images)`; it must return a tuple whose
            first element holds the classification outputs, and it must expose
            a boolean `multilabel` attribute.
        optimizer: optimizer for the model parameters. It is only touched
            (`zero_grad`/`step`) when `training` is True.
        loss_fn: callable invoked as `loss_fn(outputs, labels)`; must return a
            scalar tensor.
        training: if True, the step runs backward and updates the weights;
            if False, the step only evaluates.
        device: device every batch is moved to.

    Returns:
        A function `step_fn(engine, data_batch)` returning the tuple
        `(batch_loss, outputs, labels)`, where `batch_loss` is a Python float.
    """
    def step_fn(engine, data_batch):
        # Move inputs to the target device
        images = data_batch[0].to(device)
        # shape: batch_size, 3, height, width

        labels = data_batch[1].to(device)
        # shape: batch_size, n_labels

        # Cast the targets to the dtype used for each setup
        if model.multilabel:
            labels = labels.float()
        else:
            labels = labels.long()

        # Switch between train/eval behaviour of the layers
        model.train(training)

        # Scope the grad mode to this step: calling torch.set_grad_enabled() as
        # a plain function changes the global mode and never restores it, which
        # would leave gradients disabled after the last validation step.
        with torch.set_grad_enabled(training):
            if training:
                # zero the parameter gradients
                optimizer.zero_grad()

            # Forward
            output_tuple = model(images)
            outputs = output_tuple[0]
            # shape: batch_size, n_labels

            # Compute classification loss
            loss = loss_fn(outputs, labels)

            batch_loss = loss.item()

            if training:
                loss.backward()
                optimizer.step()

        return batch_loss, outputs, labels

    return step_fn

In [29]:
%run ./metrics/classification/__init__.py

In [30]:
def train_model(run_name, model, train_dataloader, val_dataloader, n_epochs=1, lr=0.0001,
               loss_name='wbce', debug=True, print_metrics=['loss', 'acc']):
    """Trains a classification model, validating and logging after every epoch.

    Args:
        run_name: name of the run, passed to `RunState` and `TBWriter`.
        model: model to train; must expose a `multilabel` attribute.
        train_dataloader: loader for the training split; its dataset must
            expose a `labels` attribute.
        val_dataloader: loader evaluated once at the end of each epoch.
        n_epochs: number of epochs to run in this call.
        lr: learning rate for the Adam optimizer.
        loss_name: name resolved through `get_loss_function`.
        debug: forwarded to `RunState` and `TBWriter`.
        print_metrics: names of the metrics printed after each epoch. Only
            metrics present in both train and val results are shown. The
            default list is only read, never mutated.

    Returns:
        None. Results are printed, written to TensorBoard and saved through
        the run state.
    """
    # Prepare run
    run_state = RunState(run_name, classification=True, debug=debug)
    initial_epoch = run_state.current_epoch()
    if initial_epoch > 0:
        print('Found previous run on epoch: ', initial_epoch)

    # Prepare optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss = get_loss_function(loss_name)

    # Classification labels
    labels = train_dataloader.dataset.labels

    # TB writer
    tb_writer = TBWriter(run_name, classification=True, debug=debug)

    # Everything below runs under try/finally so the writer is closed even if
    # training or validation raises (including a manual interrupt).
    try:
        # Create validator engine
        validator = Engine(get_step_fn(model, optimizer, loss, training=False))
        attach_metrics_classification(validator, labels, multilabel=model.multilabel)

        # Create trainer engine
        trainer = Engine(get_step_fn(model, optimizer, loss, training=True))
        attach_metrics_classification(trainer, labels, multilabel=model.multilabel)

        # Create Timer to measure wall time between epochs
        timer = Timer(average=True)
        timer.attach(trainer, start=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_metrics(trainer):
            # Run on evaluation
            validator.run(val_dataloader, 1)

            # State: epoch counters are offset by the epochs of a previous run
            epoch = trainer.state.epoch + initial_epoch
            max_epochs = trainer.state.max_epochs + initial_epoch
            train_metrics = trainer.state.metrics
            val_metrics = validator.state.metrics

            # Save state
            run_state.save_state(epoch)

            # Walltime
            wall_time = time.time()

            # Log to TB
            tb_writer.write_histogram(model, epoch, wall_time)
            tb_writer.write_metrics(train_metrics, 'train', epoch, wall_time)
            tb_writer.write_metrics(val_metrics, 'val', epoch, wall_time)

            # Print results
            print_str = f'Finished epoch {epoch}/{max_epochs}'
            for metric in print_metrics:
                if not (metric in train_metrics and metric in val_metrics):
                    continue
                train_value = train_metrics[metric]
                val_value = val_metrics[metric]
                metric_str = f' {metric} {train_value:.4f} {val_value:.4f},'
                print_str += metric_str

            # NOTE(review): `_elapsed` is a private Timer method; kept as-is
            # to preserve the printed value.
            print_str += f' (took {duration_to_str(timer._elapsed())})'
            print(print_str)

        # Train!
        print('-' * 50)
        print('Training...')
        trainer.run(train_dataloader, n_epochs)

        # Capture time per epoch
        secs_per_epoch = timer.value()
        duration_per_epoch = duration_to_str(secs_per_epoch)
        print('Average time per epoch: ', duration_per_epoch)
        print('-'*50)
    finally:
        tb_writer.close()

    return

## Load data and model

In [11]:
%run datasets/__init__.py

In [32]:
# Build the train and val loaders for the 'cxr14' dataset, in that order.
# `max_samples=10` keeps this a quick smoke run.
train_dataloader, val_dataloader = (
    prepare_data_classification('cxr14', dataset_type=split_name, max_samples=10)
    for split_name in ('train', 'val')
)
train_dataloader.dataset.size()

Loading train dataset...
Loading val dataset...


(10, 14)

In [13]:
# Fresh 'resnet' classifier configured from the training dataset's label set,
# then moved to the working device.
model = init_empty_model(
    'resnet',
    train_dataloader.dataset.labels,
    multilabel=train_dataloader.dataset.multilabel,
)
model = model.to(DEVICE)

## Train

In [33]:
# Short training run (3 epochs) with the 'wbce' loss and a very small learning rate.
train_model(
    'cxr',
    model,
    train_dataloader,
    val_dataloader,
    n_epochs=3,
    loss_name='wbce',
    lr=0.000001,
    # print_metrics=['loss', 'spec_covid', 'rec_covid'],
)

--------------------------------------------------
Training...
Finished epoch 1/3 loss 302.4202 277.8905, (took 0h 0m 3s)
Finished epoch 2/3 loss 286.4681 279.3835, (took 0h 0m 3s)
Finished epoch 3/3 loss 271.4820 280.6645, (took 0h 0m 3s)
Average time per epoch:  0h 0m 3s
--------------------------------------------------


## Manual checks (scratch)

### Test metrics

In [None]:
from ignite.metrics import Recall, Precision

In [None]:
%run ./metrics/classification/__init__.py
%run ./metrics/classification/specificity.py

In [None]:
# One instance of each metric under test.
sp, rec, prec = Specificity(), Recall(), Precision()

In [None]:
# Transform built for class index 0. It is applied below to a
# (loss, outputs, target) tuple and returns an (outputs, target) pair.
# NOTE(review): the exact semantics live in the helper's own module; presumably
# it reduces the outputs/targets to a one-class problem - confirm.
fn = _get_transform_one_class(0)

In [None]:
# Toy batch: 3 samples with 3 scores each, plus one integer target per sample.
outputs = torch.tensor([
    [0, 20, -1],
    [-40, 2, 3],
    [17, 5, 6],
])
target = torch.tensor([0, 0, 2])

# The transform takes the same 3-tuple layout a step function returns
# (loss, outputs, labels); the loss slot is a dummy 0 here.
outputs, target = fn((0, outputs, target))
outputs, target

In [None]:
# Specificity on the toy batch: clear any previous state, feed the single
# (outputs, target) pair, then display the computed value.
sp.reset()
sp.update((outputs, target))
sp.compute()

In [None]:
# Recall on the toy batch: clear any previous state, feed the single
# (outputs, target) pair, then display the result as a Python scalar.
rec.reset()
rec.update((outputs, target))
rec.compute().item()

In [None]:
# Precision on the toy batch: clear any previous state, feed the single
# (outputs, target) pair, then display the result as a Python scalar.
prec.reset()
prec.update((outputs, target))
prec.compute().item()

### Test samples

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Inspect one training sample. The requested index (200) is clamped to the
# loaded subset, because the loaders may be built with a small `max_samples`.
# NOTE(review): assumes dataset.size() returns (n_samples, n_labels) - confirm.
n_samples = train_dataloader.dataset.size()[0]
sample_idx = min(200, n_samples - 1)

image, label = train_dataloader.dataset[sample_idx]
images = image.unsqueeze(0)  # batched copy with a leading dimension of 1
image = image.numpy().transpose(1, 2, 0)  # channels-first -> channels-last for matplotlib
print(image.shape)
print(label)

plt.imshow(image)