## Imports

In [1]:
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Timer #, EarlyStopping
from torch import optim

import time

In [16]:
%run ./datasets/__init__.py

In [17]:
%run ./utils/__init__.py
%run ./utils/runs.py

In [18]:
%run ./models/classification/__init__.py

In [19]:
%run ./losses/__init__.py

In [20]:
%run ./metrics/classification/__init__.py

In [21]:
%run ./tensorboard.py

## Functions

In [8]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [9]:
def get_step_fn(model, optimizer, loss_fn, training=True, device=DEVICE):
    """Creates a step function for an Engine.

    Args:
        model: module called as ``model(images)``; the first element of the
            value it returns is used as the classification output.
        optimizer: optimizer that is zeroed and stepped on every batch when
            ``training`` is True. It is not touched when ``training`` is False.
        loss_fn: callable invoked as ``loss_fn(outputs, labels.float())``.
        training: if True, run the model in train mode with gradients enabled
            and update the weights; if False, run it in eval mode with
            gradients disabled.
        device: device the batch tensors are moved to.

    Returns:
        A ``step_fn(engine, data_batch)`` callable that processes one batch and
        returns the tuple ``(batch_loss, outputs, labels)``, where
        ``batch_loss`` is a Python number obtained with ``loss.item()``.
    """
    def step_fn(engine, data_batch):
        # Move inputs to the target device
        images = data_batch[0].to(device)
        # shape: batch_size, 3, height, width

        labels = data_batch[1].to(device)
        # shape: batch_size, n_labels

        # Switch the model between train and eval mode
        model.train(training)

        # Use set_grad_enabled as a context manager: calling it as a plain
        # function changes the grad mode globally, which left autograd
        # disabled for the rest of the session after the last validation step.
        with torch.set_grad_enabled(training):
            if training:
                # zero the parameter gradients
                optimizer.zero_grad()

            # Forward
            output_tuple = model(images)
            outputs = output_tuple[0]

            # Compute classification loss
            loss = loss_fn(outputs, labels.float())

            batch_loss = loss.item()

            if training:
                loss.backward()
                optimizer.step()

        return batch_loss, outputs, labels

    return step_fn

In [43]:
# SEED = 10 # will be ignored in ignite v0.4.0

In [44]:
def train_model(run_name, model, train_dataloader, val_dataloader, n_epochs=1, lr=0.0001,
               loss_name='wbce', debug=True):
    """Train a classification model, validating and logging after every epoch.

    After each training epoch the model is run once over ``val_dataloader``,
    the epoch number is saved through ``RunState``, metrics and histograms are
    written with ``TBClassificationWriter`` and a one-line summary is printed.
    Epoch numbers are offset by ``RunState.current_epoch()``, so a run that
    already has saved epochs continues its numbering.

    Args:
        run_name: name passed to ``RunState`` and ``TBClassificationWriter``.
        model: model to train; it is updated in place.
        train_dataloader: batches used for training. Its ``dataset.labels``
            attribute provides the classification labels.
        val_dataloader: batches used for the per-epoch validation pass.
        n_epochs: number of epochs to train in this call.
        lr: learning rate for the Adam optimizer.
        loss_name: name passed to ``get_loss_function``.
        debug: forwarded to ``RunState`` and ``TBClassificationWriter``.

    Returns:
        None.
    """
    # Prepare run
    run_state = RunState(run_name, classification=True, debug=debug)
    initial_epoch = run_state.current_epoch()
    if initial_epoch > 0:
        print('Found previous run on epoch: ', initial_epoch)

    # Prepare optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss = get_loss_function(loss_name)

    # Classification labels
    labels = train_dataloader.dataset.labels

    # TB writer
    tb_writer = TBClassificationWriter(run_name, labels=labels, debug=debug)

    # Create validator engine
    validator = Engine(get_step_fn(model, optimizer, loss, training=False))
    attach_metrics_classification(validator, labels)

    # Create trainer engine
    trainer = Engine(get_step_fn(model, optimizer, loss, training=True))
    attach_metrics_classification(trainer, labels)

    # Timer.attach resets the timer on its `start` event, so two timers are
    # needed: one restarted every epoch (duration of the current epoch) and
    # one started once per run and stepped every epoch (true average).
    # A single timer restarted on EPOCH_STARTED only ever reports the last
    # epoch as the "average".
    epoch_timer = Timer(average=False)
    epoch_timer.attach(trainer, start=Events.EPOCH_STARTED)

    average_timer = Timer(average=True)
    average_timer.attach(trainer, start=Events.STARTED, step=Events.EPOCH_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_metrics(trainer):
        # Run on evaluation
        validator.run(val_dataloader, 1)

        # State
        epoch = trainer.state.epoch + initial_epoch
        max_epochs = trainer.state.max_epochs + initial_epoch
        train_metrics = trainer.state.metrics
        val_metrics = validator.state.metrics

        # Save state
        run_state.save_state(epoch)

        # Walltime
        wall_time = time.time()

        # Log to TB
        tb_writer.write_histogram(model, epoch, wall_time)
        tb_writer.write_metrics(train_metrics, 'train', epoch, wall_time)
        tb_writer.write_metrics(val_metrics, 'val', epoch, wall_time)

        # Print results
        train_loss = train_metrics.get('loss', -1)
        val_loss = val_metrics.get('loss', -1)

        loss_str = f'loss {train_loss:.4f}, {val_loss:.4f}'
        # Public API; the timer is still running here, so this is the time
        # elapsed since the epoch started (validation included).
        duration_str = duration_to_str(epoch_timer.value())
        print(f'Finished epoch {epoch}/{max_epochs}, {loss_str} (took {duration_str})')

    try:
        # Train!
        print('-' * 50)
        print('Training...')
        trainer.run(train_dataloader, n_epochs)

        # Capture time per epoch
        secs_per_epoch = average_timer.value()
        duration_per_epoch = duration_to_str(secs_per_epoch)
        print('Average time per epoch: ', duration_per_epoch)
        print('-'*50)
    finally:
        # Close the writer even if training fails or is interrupted
        tb_writer.close()

## Load data and model

In [11]:
# Single knob for the value passed as `max_images` to both splits, so they cannot drift apart
MAX_IMAGES = 100

train_dataloader = prepare_data_classification(dataset_type='train', max_images=MAX_IMAGES)
val_dataloader = prepare_data_classification(dataset_type='val', max_images=MAX_IMAGES)
train_dataloader.dataset.size()

Loading train dataset...
Loading val dataset...


(100, 14)

In [28]:
# Create the 'resnet' model for the training dataset's labels and move it to DEVICE
model = init_empty_model('resnet', train_dataloader.dataset.labels).to(DEVICE)

## Train

In [48]:
# Train for 3 epochs with the 'wbce' loss. train_model offsets the epoch numbers by
# RunState.current_epoch(), so re-running with an existing run name continues its numbering.
train_model('run_name_6', model, train_dataloader, val_dataloader, n_epochs=3, loss_name='wbce')

Found previous run on epoch:  6
--------------------------------------------------
Training...
Finished epoch 7/9, loss 0.0268, 722.7981 (took 0h 0m 8s)
Finished epoch 8/9, loss 0.0065, 877.6345 (took 0h 0m 8s)
Finished epoch 9/9, loss 0.1154, 1115.6416 (took 0h 0m 8s)
Average time per epoch:  0h 0m 8s
--------------------------------------------------
