## Imports

In [1]:
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Timer #, EarlyStopping
from torch import optim

In [2]:
%run ./datasets/__init__.py

In [3]:
%run ./utils/__init__.py

In [4]:
%run ./models/classification/__init__.py

In [5]:
%run ./losses/__init__.py

In [6]:
%run ./metrics/classification/__init__.py

In [24]:
%run utils/__init__.py

## Functions

In [14]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [42]:
def get_step_fn(model, optimizer, loss_fn, training=True, device=DEVICE):
    """Create the per-batch process function for an ignite ``Engine``.

    Args:
        model: called on the image batch; element 0 of its return value is
            used as the outputs passed to ``loss_fn``.
        optimizer: zeroed and stepped around backprop when ``training`` is True.
        loss_fn: called as ``loss_fn(outputs, labels.float())``.
        training: if True, run the model in train mode and update the weights;
            if False, run in eval mode with gradient recording disabled.
        device: device the images and labels are moved to.

    Returns:
        ``step_fn(engine, data_batch)``, which returns
        ``(batch_loss, outputs, labels)``. ``batch_loss`` is a Python float and
        ``labels`` holds the batch labels on ``device``.
    """
    def step_fn(engine, data_batch):
        # Batches are 5-tuples; only images and labels are needed here.
        images, labels, _, _, _ = data_batch
        images = images.to(device)
        labels = labels.to(device)

        # Switch train/eval-specific layer behavior for this engine.
        model.train(training)

        # Scope the autograd mode to this step. Calling
        # torch.set_grad_enabled(False) as a plain function flips the global
        # flag, which then stays off for any code run after an evaluation pass.
        with torch.set_grad_enabled(training):
            if training:
                optimizer.zero_grad()

            # The model returns a tuple; the first element is the outputs.
            output_tuple = model(images)
            outputs = output_tuple[0]

            loss = loss_fn(outputs, labels.float())
            batch_loss = loss.item()

            if training:
                loss.backward()
                optimizer.step()

        return batch_loss, outputs, labels

    return step_fn

In [43]:
def train_model(model, train_dataloader, val_dataloader, n_epochs=1, lr=0.0001,
                loss_name='wbce'):
    """Train ``model`` with Adam, validating at the end of every epoch.

    Args:
        model: network to train; used by both the trainer and validator steps.
        train_dataloader: loader iterated by the trainer. Its
            ``dataset.labels`` are used when attaching metrics.
        val_dataloader: loader run once by the validator after each epoch.
        n_epochs: number of training epochs.
        lr: Adam learning rate.
        loss_name: key passed to ``get_loss_function`` and
            ``attach_metrics_classification``. It is also the metric name
            looked up when printing the per-epoch losses.

    Returns:
        ``(train_metrics, val_metrics)``: the ``state.metrics`` dicts of the
        trainer and validator after the last epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss = get_loss_function(loss_name)

    # Both engines report metrics over the training set's label names.
    labels = train_dataloader.dataset.labels

    # Validator: forward passes only (training=False), no optimizer steps.
    validator = Engine(get_step_fn(model, optimizer, loss, training=False))
    attach_metrics_classification(validator, labels, loss_name)

    trainer = Engine(get_step_fn(model, optimizer, loss, training=True))
    attach_metrics_classification(trainer, labels, loss_name)

    # Accumulate training time across epochs. The timer is reset once at the
    # start of the run, resumed/paused around each epoch, and counts one step
    # per epoch, so ``timer.value()`` is the mean epoch duration. Resetting on
    # every EPOCH_STARTED would make ``value()`` cover only the last epoch.
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.STARTED,
                 resume=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 step=Events.EPOCH_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_epoch_metrics(trainer):
        epoch = trainer.state.epoch
        max_epochs = trainer.state.max_epochs

        # Read this epoch's duration before validating, so it covers training
        # only (matching what the average above measures).
        duration_str = duration_to_str(timer._elapsed())

        validator.run(val_dataloader, max_epochs=1)

        # Fall back to 0 if the loss metric was not attached under loss_name.
        train_loss = trainer.state.metrics.get(loss_name, 0)
        val_loss = validator.state.metrics.get(loss_name, 0)

        loss_str = f'loss {train_loss:.4f}, {val_loss:.4f}'
        print(f'Finished epoch {epoch}/{max_epochs}, {loss_str} (took {duration_str})')

    print('-' * 50)
    print('Training...')
    trainer.run(train_dataloader, max_epochs=n_epochs)

    secs_per_epoch = timer.value()
    duration_per_epoch = duration_to_str(secs_per_epoch)
    print('Average time per epoch: ', duration_per_epoch)
    print('-'*50)

    return trainer.state.metrics, validator.state.metrics

## Load data and model

In [40]:
# Maximum number of images loaded per split; kept small for quick iterations.
MAX_IMAGES = 100
train_dataloader = prepare_data_classification(dataset_type='train', max_images=MAX_IMAGES)
val_dataloader = prepare_data_classification(dataset_type='val', max_images=MAX_IMAGES)
train_dataloader.dataset.size()

Loading train dataset...
Loading val dataset...


(100, 14)

In [18]:
# Build a 'resnet' classifier over the training set's label names and move it
# to DEVICE.
model = init_empty_model('resnet', train_dataloader.dataset.labels).to(DEVICE)

## Train

In [44]:
# Train for 10 epochs with plain BCE and keep both metric dicts for inspection.
train_metrics, val_metrics = train_model(
    model,
    train_dataloader,
    val_dataloader,
    n_epochs=10,
    loss_name='bce',
)

--------------------------------------------------
Training...
Finished epoch 1/10, loss 0.2112, 0.2235 (took 0h 0m 4s)
Finished epoch 2/10, loss 0.0704, 0.1989 (took 0h 0m 4s)
Finished epoch 3/10, loss 0.0544, 0.1791 (took 0h 0m 4s)
Finished epoch 4/10, loss 0.0624, 0.1479 (took 0h 0m 4s)
Finished epoch 5/10, loss 0.0533, 0.1285 (took 0h 0m 4s)
Finished epoch 6/10, loss 0.0559, 0.1255 (took 0h 0m 4s)
Finished epoch 7/10, loss 0.0592, 0.1287 (took 0h 0m 4s)
Finished epoch 8/10, loss 0.0528, 0.1273 (took 0h 0m 4s)
Finished epoch 9/10, loss 0.0526, 0.1253 (took 0h 0m 4s)
Finished epoch 10/10, loss 0.0524, 0.1255 (took 0h 0m 4s)
Average time per epoch:  0h 0m 4s
--------------------------------------------------


In [45]:
# Validation metrics whose name contains 'acc' (the per-label accuracies).
[item for item in val_metrics.items() if 'acc' in item[0]]

[('acc_Atelectasis', 0.95),
 ('acc_Cardiomegaly', 0.99),
 ('acc_Effusion', 0.89),
 ('acc_Infiltration', 0.83),
 ('acc_Mass', 0.95),
 ('acc_Nodule', 0.97),
 ('acc_Pneumonia', 0.99),
 ('acc_Pneumothorax', 0.97),
 ('acc_Consolidation', 0.93),
 ('acc_Edema', 0.99),
 ('acc_Emphysema', 0.98),
 ('acc_Fibrosis', 1.0),
 ('acc_Pleural_Thickening', 0.98),
 ('acc_Hernia', 1.0)]