In [None]:
import sys

# Directory added to the import search path; presumably the checkout that
# contains the `probunet_multiattn` package imported further down -- confirm.
BRATS2020_ROOT = '/path/to/base/dir/BRATS2020/'

# Guarded so that re-running this cell does not keep appending duplicates
# of the same entry to sys.path.
if BRATS2020_ROOT not in sys.path:
    sys.path.append(BRATS2020_ROOT)

In [None]:
from time import perf_counter, time

In [None]:
import torch
from torch.utils.data import DataLoader

In [None]:
from probunet_multiattn.models.probabilistic_unet import ProbUNet
from probunet_multiattn.predictor import Predictor
from probunet_multiattn.data_utils.dataloader_predict import BratsDataset

In [None]:
import numpy as np

# Single source of truth for the seed, so the NumPy and PyTorch generators
# cannot accidentally be seeded with different values.
SEED = 42

np.random.seed(SEED)
# Trailing ';' keeps the returned torch.Generator repr out of the cell output.
torch.manual_seed(SEED);

In [None]:
class Config:
    """Plain settings container for this notebook.

    Every setting is an instance attribute that receives its default value
    in ``__init__``. Keyword arguments override individual defaults, e.g.
    ``Config(batch_size=8)``; a keyword that does not match an existing
    setting raises ``TypeError`` so typos are not silently ignored.

    The object is consumed in three places in this notebook: ``test_path``
    is given to ``BratsDataset``, the model-definition fields are given to
    ``ProbUNet``, and the whole object is given to ``Predictor``.
    """

    def __init__(self, **overrides):
        """Populate all defaults, then apply ``overrides`` on top.

        Args:
            **overrides: ``name=value`` pairs replacing the default of an
                existing setting.

        Raises:
            TypeError: if a keyword does not name an existing setting.
        """
        # Dataset locations
        self.training_path = '/path/to/train/dir'
        self.val_path = '/path/to/val/dir'
        self.test_path = '/path/to/test/samples/dir'

        # presumably the native in-plane slice size before resizing to the
        # 128x128 model resolution -- TODO confirm against Predictor
        self.original_dim = (240, 240)

        # Model Definition
        self.input_shape = (4, 128, 128)
        self.output_shape = (3, 128, 128)
        self.base_filters = 16  # Filters in the first block
        self.depth = 4  # Depth of the UNet
        self.nblocks = 2  # number of conv blocks at each stage
        self.zdim = 6  # dimensionality of the Gaussian
        self.nclasses = 3
        self.activation = 'relu'
        self.norm = 'bn'
        # Defined once here only; it used to be repeated under the training
        # parameters, where the later copy silently won over this one.
        self.nattn_blocks = 3

        # Logging utils
        self.reduce_class_dice = False
        self.split_seg_loss = True

        self.display = 'running'

        # Training parameters
        self.loss = {'f': 'dice', 'args': {'smooth': 1.}}
        self.epochs = 100
        self.batch_size = 32
        self.lr = 1e-3
        self.decay_every = 5
        self.validate_every = 1
        self.checkpoint_every = 2
        self.checkpoint_path = '/path/to/checkpoints/dir'

        # Generic model settings
        # Device placement per component.
        self.devices = {
            'unet': 'cuda:0',
            'prior_net': 'cuda:1',
            'output': 'cuda:1'
        }
        # Weight files per sub-network.
        self.checkpoints = {
            'unet': '/path/to/checkpoint/dir/unet.pth',
            'prior_net': '/path/to/checkpoint/dir/prior_net.pth',
            'posterior_net': '/path/to/checkpoint/dir/posterior_net.pth',
            'fcomb': '/path/to/checkpoint/dir/fcomb.pth',
        }
        self.output_path = '/path/to/output/dir'

        self.train_logdir = '/path/to/train/log.log'
        self.val_logdir = '/path/to/val/log.log'
        self.config_log = '/path/to/config/log'

        # Apply caller-supplied overrides last; only known settings are
        # accepted so a misspelled name fails loudly.
        for name, value in overrides.items():
            if name not in self.__dict__:
                raise TypeError(
                    'Config got an unexpected setting {!r}'.format(name))
            setattr(self, name, value)


config = Config()

In [None]:
# Dataset over the samples found under the configured test directory.
dataset = BratsDataset(config.test_path)

# One sample per batch, iterated in dataset order (no shuffling), loaded by
# four worker processes. Batches are assembled by the dataset's own
# `collate_batch` instead of the DataLoader default.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=dataset.collate_batch,
)

In [None]:
# Build the network from the settings object. Every argument except
# `use_posterior` is read straight from `config`.
model = ProbUNet(
    # tensor shapes
    input_shape=config.input_shape,
    output_shape=config.output_shape,
    # architecture
    depth=config.depth,
    nblocks=config.nblocks,
    nclasses=config.nclasses,
    zdim=config.zdim,
    base_filters=config.base_filters,
    # placement and weights
    devices=config.devices,
    checkpoints=config.checkpoints,
    # layer options
    activation=config.activation,
    norm=config.norm,
    # NOTE(review): posterior is switched off here; presumably because only
    # the prior is needed at prediction time -- confirm in ProbUNet.
    use_posterior=False,
    nattn_blocks=config.nattn_blocks,
)

In [None]:
# Tie the model, the settings and the test data stream together.
# NOTE(review): save=True presumably makes the predictor write its results
# to disk (likely under config.output_path) -- confirm in Predictor.
predictor = Predictor(
    model,
    config,
    dataloader,
    save=True,
)

In [None]:
# Run prediction over the whole dataloader and report how long it took.
# perf_counter() is used instead of time() because it is monotonic: the
# measured duration cannot be distorted by system clock adjustments that
# happen while the (potentially long) prediction pass is running.
start = perf_counter()
predictor.predict(visualize=False, nsamples=10, reduce='mean')
end = perf_counter()
# Elapsed seconds converted to minutes.
print('Finished in {} minutes.'.format((end - start)/60))