<a href="https://colab.research.google.com/github/pandov/cups-mail/blob/master/multimodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Run the repository's Colab setup script (presumably installs dependencies
# and/or fetches data -- TODO confirm against colab.sh). %%capture hides its
# output.
! sh colab.sh

In [None]:
%%capture
import torch
from src.nn import BACTERIA, Runner, get_multimodel_components, iou_metric
# Root log directory for this notebook. TensorBoard is pointed at it, and each
# training run writes to the subdirectory f'{logdir}/{num_experiment}'.
logdir = './logs/multimodel'

In [None]:
# Load the TensorBoard notebook extension and display the dashboard inline,
# pointed at the notebook's log directory.
%load_ext tensorboard
%tensorboard --logdir $logdir

In [None]:
components = get_multimodel_components('resnet50', 'adam', 'steplr')
# Take the criterion out of `components` so it is not forwarded to
# `runner.train(**components)`; MultiRunner reads this notebook-level
# `criterion` name directly instead.
criterion = components.pop('criterion').copy()
# Separate copy of the callbacks; `components` keeps the original entry.
callbacks = components['callbacks'].copy()

In [None]:
class MultiRunner(Runner):
    """Runner for a model with two outputs trained with a summed loss.

    Each batch is unpacked as ``(x, y, z)``. The items are presumably the
    image, mask and label, in the order of the dataset keys (TODO: confirm).
    ``self.model(x)`` returns ``(y_pred, z_pred)``:

    * ``y_pred`` is scored against ``y`` with ``criterion['dice']`` and with
      ``iou_metric``;
    * ``z_pred`` is scored against ``z`` with ``criterion['crossentropy']``.

    The optimized loss is the sum of the two criterion values. The criteria
    come from the notebook-level ``criterion`` mapping. The ``'criterion'``
    entry is deleted from ``components`` before ``train`` is called.
    """

    def _handle_batch(self, batch):
        """Process one batch: forward pass, losses/metrics, and optimization.

        The optimization step (backward, optimizer step, zero_grad) runs only
        on train loaders. The method records ``IoU``, ``Dice``,
        ``CrossEntropy`` and ``loss`` in ``self.state.batch_metrics``. It
        exposes the second head through ``self.input['targets']`` and
        ``self.output['logits']`` so callbacks can consume them.
        """
        # `criterion` is only read (never assigned) here, so it resolves to
        # the module-level name without a `global` statement.
        x, y, z = batch
        is_train = bool(self.is_train_loader)

        # The forward pass belongs inside the grad context as well. Otherwise
        # validation batches would record an autograd graph that is never used.
        with torch.set_grad_enabled(is_train):
            y_pred, z_pred = self.model(x)
            loss_dice = criterion['dice'](y_pred, y)
            loss_crossentropy = criterion['crossentropy'](z_pred, z)
            loss = loss_dice + loss_crossentropy
            iou = iou_metric(y_pred, y)
            self.state.batch_metrics.update({
                'IoU': iou,
                'Dice': loss_dice,
                'CrossEntropy': loss_crossentropy,
                'loss': loss,
            })

            if is_train:
                loss.backward()
                self.state.optimizer.step()
                # Clear gradients so they do not accumulate into the next batch.
                self.state.optimizer.zero_grad()

        self.input = {'targets': z}
        self.output = {'logits': z_pred}

In [None]:
# Cross-validation with kfold=4. Each element of `experiments` is passed as
# `loaders` to `runner.train`, presumably one fold's loaders (TODO: confirm).
dataset = BACTERIA(keys=['image', 'mask', 'label'])
experiments = [*dataset.crossval(kfold=4, batch_size=16)]

# Which fold to train. Logs go to a per-fold subdirectory of `logdir`.
num_experiment = 0
loaders = experiments[num_experiment]

runner = MultiRunner()
runner.train(
    loaders=loaders,
    logdir='{}/{}'.format(logdir, num_experiment),
    num_epochs=150,
    **components
)

In [None]:
# Archive the log directory next to itself as {logdir}.zip, then commit the
# whole `logs` directory and push it to the remote master branch.
! zip -r {logdir}.zip {logdir}
# Pull first so the later push is less likely to be rejected as non-fast-forward.
! git pull origin master
! git add logs
! git commit -m 'Changed from Colab'
! git push -u origin master