In [None]:
import sys

# Make the project package (probunet_multiattn) importable from this notebook.
# Guarded so that re-running this cell does not append duplicate entries to sys.path.
BASE_DIR = '/path/to/base/dir/BRATS2020/'
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)

In [None]:
import torch
from torch.utils.data import DataLoader

In [None]:
from probunet_multiattn.models.probabilistic_unet import ProbUNet
from probunet_multiattn.trainer import Trainer
from probunet_multiattn.data_utils.dataloader import BratsDataset

In [None]:
import random

import numpy as np

# Seed every RNG the pipeline may draw from (Python stdlib, NumPy, PyTorch) so that a
# Restart-Kernel -> Run-All reproduces the same data shuffling and initialisation.
# torch.manual_seed seeds the RNG on all devices, CUDA included.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# Trailing ';' suppresses the notebook display of the returned torch.Generator.
torch.manual_seed(SEED);

In [None]:
class Config:
    """All settings for one ProbUNet training run: data paths, architecture,
    training schedule, device placement and log locations.

    Every setting is a plain instance attribute; the rest of this notebook reads
    them as ``config.<name>`` when building the datasets, the model and the Trainer.
    """

    def __init__(self):
        # ---- Data locations ------------------------------------------------
        self.training_path = '/path/to/train/dir'
        self.val_path = '/path/to/val/dir'

        # ---- Model definition ----------------------------------------------
        # Input and output share one spatial extent; only the leading
        # (channel) dimension differs.
        # NOTE(review): 4 input channels presumably correspond to the BraTS MRI
        # modalities — confirm against BratsDataset.
        spatial_size = (128, 128)
        self.input_shape = (4, *spatial_size)
        self.output_shape = (3, *spatial_size)
        self.base_filters = 16   # filters in the first block
        self.depth = 4           # depth of the UNet
        self.nblocks = 2         # number of conv blocks at each stage
        self.zdim = 6            # dimensionality of the Gaussian latent
        self.nclasses = 3        # matches output_shape[0]
        self.activation = 'relu'
        self.norm = 'bn'
        self.nattn_blocks = 3

        # ---- Logging utilities (consumed by Trainer) -----------------------
        self.reduce_class_dice = False
        self.split_seg_loss = True
        self.display = 'running'

        # ---- Training schedule ---------------------------------------------
        self.loss = dict(f='dice', args=dict(smooth=1.0))
        self.epochs = 100
        self.batch_size = 32
        self.lr = 0.001
        self.decay_every = 5
        self.validate_every = 1
        self.checkpoint_every = 2
        self.checkpoint_path = '/path/to/checkpoints/dir'

        # ---- Device placement / weights ------------------------------------
        # One device per sub-network; the posterior net and the output share a GPU.
        self.devices = dict(
            unet='cuda:0',
            prior_net='cuda:1',
            posterior_net='cuda:2',
            output='cuda:2',
        )
        self.checkpoints = None

        # ---- Log locations -------------------------------------------------
        # The two *_logdir values are file paths; config_log is a directory.
        self.train_logdir = '/path/to/log/dir/train.log'
        self.val_logdir = '/path/to/log/dir/val.log'
        self.config_log = '/path/to/log/dir/'


config = Config()

In [None]:
# Training split with on-the-fly augmentation settings
# (rotate=45. is presumably the maximum rotation in degrees — confirm in BratsDataset).
# NOTE(review): phase='test' is the same phase the validation set uses. If BratsDataset
# only applies rotate/hflip/vflip (or loads labels differently) when phase='train',
# these augmentation settings are silently ignored — confirm which phase is intended.
train_augmentation = dict(rotate=45., hflip=True, vflip=True)
train_dataset = BratsDataset(config.training_path, phase='test', **train_augmentation)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=1,
    collate_fn=train_dataset.collate_batch,
)

In [None]:
# Validation split (no augmentation arguments passed).
# Iterated in a fixed order (shuffle=False) so every epoch sees identical validation
# batches: any metric computed per batch and then averaged stays comparable across
# epochs, instead of varying with a random batch composition.
val_dataset = BratsDataset(config.val_path, phase='test')

val_dataloader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=1,
    collate_fn=val_dataset.collate_batch,
)

In [None]:
# Build the probabilistic UNet entirely from Config; arguments are grouped by concern.
# NOTE(review): `devices` presumably maps each sub-network to the GPU it runs on — see ProbUNet.
model = ProbUNet(
    # tensor geometry
    input_shape=config.input_shape,
    output_shape=config.output_shape,
    nclasses=config.nclasses,
    # UNet backbone
    base_filters=config.base_filters,
    depth=config.depth,
    nblocks=config.nblocks,
    activation=config.activation,
    norm=config.norm,
    nattn_blocks=config.nattn_blocks,
    # latent space
    zdim=config.zdim,
    # placement and initial weights
    devices=config.devices,
    checkpoints=config.checkpoints,
)

In [None]:
# Wire the model, the run configuration and both data loaders into the Trainer.
trainer = Trainer(
    model,
    config,
    train_dataloader,
    val_dataloader=val_dataloader,
)

# Display the model's device mapping as a sanity check before starting a long run.
trainer.model.devices

In [None]:
# Launch the training run. The schedule (epochs, validate_every, checkpoint_every, ...)
# is presumably read by Trainer from `config` — see probunet_multiattn.trainer.
trainer.train()