## Imports

In [1]:
import torch
from torch import nn

## Load data

In [2]:
from torch.utils.data import DataLoader

In [3]:
%run ../datasets/jsrt.py

In [37]:
_SEG_DATASETS = {
    'jsrt': JSRTDataset,
}

In [38]:
def prepare_data_segmentation(dataset_name=None,
                              dataset_type='train',
                              image_size=(512, 512),
                              batch_size=10,
                              shuffle=False,
                              num_workers=2,
                             ):
    assert dataset_name in _SEG_DATASETS, f'Dataset not found: {dataset_name}'
    DatasetClass = _SEG_DATASETS[dataset_name]
    
    dataset = DatasetClass(dataset_type=dataset_type,
                           image_size=image_size,
                          )

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=shuffle,
                            num_workers=num_workers,
                           )
    return dataloader

In [40]:
BS = 20

train_dataloader = prepare_data_segmentation('jsrt', 'train', batch_size=BS)
val_dataloader = prepare_data_segmentation('jsrt', 'val', batch_size=BS)
len(train_dataloader.dataset), len(val_dataloader.dataset)

(124, 61)

## Create model

In [41]:
%run ../models/segmentation/scan.py

In [42]:
model = ScanFCN()
# model

## Train model

In [43]:
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage
from torch import optim
from ignite.utils import to_onehot

In [44]:
def get_step_fn(model, optimizer=None, training=False, device='cuda'):
    criterion = nn.CrossEntropyLoss()

    def step_fn(engine, batch):
        images = batch.image.to(device) # shape: batch_size, 1, height, width
        masks = batch.masks.to(device) # shape: batch_size, height, width
        
        # Enable training
        model.train(training)
        torch.set_grad_enabled(training)

        if training:
            optimizer.zero_grad()
        
        # Pass thru model
        output = model(images)
        # shape: batch_size, n_labels, height, width
        
        loss = criterion(output, masks)
        batch_loss = loss.item()
        
        if training:
            loss.backward()
            optimizer.step()
            
        return {
            'loss': batch_loss,
            'activations': output,
            'gt_map': masks,
        }

    return step_fn

In [45]:
DEVICE = torch.device('cuda')
DEVICE

device(type='cuda')

In [53]:
model = model.to(DEVICE)

In [54]:
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [55]:
trainer = Engine(get_step_fn(model, optimizer, training=True, device=DEVICE))

In [56]:
validator = Engine(get_step_fn(model, training=False, device=DEVICE))

In [59]:
%run ../metrics/segmentation/__init__.py

In [60]:
attach_metrics_segmentation(trainer, dataset.seg_labels, multilabel=False)
attach_metrics_segmentation(validator, dataset.seg_labels, multilabel=False)

In [61]:
print_metrics = ['loss', 'iou', 'dice']

@trainer.on(Events.EPOCH_COMPLETED)
def log_metrics(trainer):
    # Run on validation
    validator.run(val_dataloader, 1)
    
    # State
    epoch = trainer.state.epoch
    max_epochs = trainer.state.max_epochs
    train_metrics = trainer.state.metrics
    val_metrics = validator.state.metrics

    metrics_str = ''
    for metric in print_metrics:
        train_value = train_metrics.get(metric, -1)
        val_value = val_metrics.get(metric, -1)
        metrics_str += f' {metric} {train_value:.3f} {val_value:.3f},'
    
    print(f'Epoch {epoch}/{max_epochs}, {metrics_str}')

In [62]:
trainer.run(dataloader, 4)

Epoch 1/4,  loss 1.383 1.382, iou 0.115 0.058, dice 0.092 0.051,
Epoch 2/4,  loss 1.376 1.382, iou 0.178 0.065, dice 0.132 0.057,
Epoch 3/4,  loss 1.363 1.383, iou 0.249 0.058, dice 0.171 0.052,
Epoch 4/4,  loss 1.346 1.378, iou 0.296 0.128, dice 0.195 0.108,


State:
	iteration: 40
	epoch: 4
	epoch_length: 10
	max_epochs: 4
	output: <class 'dict'>
	batch: <class 'medai.datasets.common.BatchItem'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

In [64]:
validator.state.metrics

{'loss': 1.3776187896728516,
 'iou-background': 0.24102243781089783,
 'iou-heart': 0.002288548741489649,
 'iou-right lung': 0.1287536919116974,
 'iou-left lung': 0.1393049657344818,
 'iou': 0.12784241139888763,
 'dice-background': 0.19315950572490692,
 'dice-heart': 0.0022826166823506355,
 'dice-right lung': 0.11386360228061676,
 'dice-left lung': 0.12189631164073944,
 'dice': 0.10780051350593567}