## Imports

In [1]:
import torch
from torch import nn

In [4]:
# Expose only GPUs 0 and 1 to this kernel.
# NOTE(review): CUDA reads this variable when it first initialises, so this
# presumably only takes effect if no CUDA call has happened yet in this kernel
# (torch is already imported above) — keep it before any GPU work.
%env CUDA_VISIBLE_DEVICES=0,1

env: CUDA_VISIBLE_DEVICES=0,1


In [5]:
# Target the current CUDA device when one is usable; otherwise fall back to CPU
# so the notebook can still run end-to-end (slowly) on machines without a GPU,
# instead of failing later at `.to(DEVICE)`.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

## Load data

In [6]:
# Execute the datasets package and pull its top-level names into the notebook
# namespace; it presumably defines `prepare_data_segmentation`, used below.
%run ../datasets/__init__.py

In [7]:
BS = 20

# Build the JSRT segmentation loaders for both splits with the same batch size
# (train first, then val), then show how many samples each split holds.
train_dataloader, val_dataloader = (
    prepare_data_segmentation('jsrt', split_name, batch_size=BS)
    for split_name in ('train', 'val')
)
len(train_dataloader.dataset), len(val_dataloader.dataset)

(124, 61)

## Create model

In [8]:
from torch import optim

In [9]:
# Load the checkpoint helpers and the segmentation model definitions into the
# notebook namespace; presumably these provide `CompiledModel` and `ScanFCN`,
# which are used below — confirm.
%run ../models/checkpoint/__init__.py
%run ../models/segmentation/scan.py

In [10]:
# Seed torch's global RNG so the weight initialisation (and presumably any
# later draws from the same RNG, e.g. loader shuffling) is reproducible across
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = ScanFCN().to(DEVICE)

In [11]:
# Adam over every parameter of the freshly built model, learning rate 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [12]:
# Bundle the model and its optimizer into the object that train_model consumes.
# NOTE(review): the meaning of the third positional argument (`{}`) is not
# visible here — presumably an empty options/metadata mapping; confirm against
# CompiledModel's signature and consider passing it by keyword.
compiled_model = CompiledModel(model, optimizer, {})

## Train model

In [28]:
# `-n`: run the script without setting __name__ to '__main__', so any
# `if __name__ == '__main__':` entry point is skipped and only its definitions
# are loaded (presumably including `train_model`, used below).
%run -n ../train_segmentation.py

In [29]:
# Label under which train_model logs this run ("Training run: <name>").
run_name = "debugging"

In [30]:
# Four loss weights summing to 1.0: a light first term followed by three equal ones.
# NOTE(review): the reported metrics list four classes (background, heart,
# right lung, left lung); these presumably down-weight background — confirm
# against train_model.
loss_weights = [0.1] + [0.3] * 3

In [31]:
# Short debugging run (2 epochs, debug=True) on DEVICE; train_model returns
# the training and validation metric dicts.
train_metrics, val_metrics = train_model(
    run_name,
    compiled_model,
    train_dataloader,
    val_dataloader,
    n_epochs=2,
    loss_weights=loss_weights,
    print_metrics=None,
    debug=True,
    device=DEVICE,
)

INFO(11-06 12:58) Training run: debugging
INFO(11-06 12:58) ---------------------------------------------------
INFO(11-06 12:58) Training...
INFO(11-06 12:58) Finished epoch 1/2,  loss 1.386 1.386, iou 0.095 0.055, dice 0.168 0.097, 0h 0m 8s
INFO(11-06 12:58) Finished epoch 2/2,  loss 1.385 1.386, iou 0.117 0.049, dice 0.201 0.085, 0h 0m 8s
INFO(11-06 12:58) Average time per epoch: 0h 0m 8s
INFO(11-06 12:58) Finished training: debugging
INFO(11-06 12:58) --------------------------------------------------


In [32]:
# Metrics returned by train_model as (train, validation); in the saved run they
# match the last epoch's log line.
# NOTE(review): both losses sit at ~1.386 ≈ ln(4). If the loss is a 4-class
# cross-entropy, that means near-uniform predictions, i.e. the model has not
# started learning yet — worth checking before launching longer runs.
train_metrics, val_metrics

({'loss': 1.3845785856246948,
  'iou-background': 0.19311708211898804,
  'iou-heart': 0.08055438846349716,
  'iou-right lung': 0.17928826808929443,
  'iou-left lung': 0.015760594978928566,
  'iou': 0.11718008667230606,
  'dice-background': 0.3225189745426178,
  'dice-heart': 0.1487397402524948,
  'dice-right lung': 0.30328142642974854,
  'dice-left lung': 0.030922692269086838,
  'dice': 0.20136570930480957},
 {'loss': 1.3862011432647705,
  'iou-background': 0.0,
  'iou-heart': 0.023369964212179184,
  'iou-right lung': 0.1734345406293869,
  'iou-left lung': 0.0,
  'iou': 0.049201127141714096,
  'dice-background': 0.0,
  'dice-heart': 0.04567057639360428,
  'dice-right lung': 0.2947053015232086,
  'dice-left lung': 0.0,
  'dice': 0.08509396761655807})