<a href="https://colab.research.google.com/github/pandov/cups-mail/blob/master/multimodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Run the repository's Colab bootstrap script (colab.sh is not shown here; presumably
# installs dependencies / fetches data — TODO confirm). %%capture hides its output.
! sh colab.sh

In [2]:
%%capture
import torch
from src.nn import BACTERIA, Runner, get_multimodel_components, dice_and_iou, score_global, score_segmentation, score_classification
# Root directory for this notebook's logs: TensorBoard is pointed here, each trained
# fold writes to a numbered subdirectory, and the final cell zips and commits it.
logdir = './logs/multimodel'

In [None]:
# Start an inline TensorBoard on the log root so every fold's subdirectory shows up
# as a separate run. `$logdir` expands the Python variable defined above.
%load_ext tensorboard
%tensorboard --logdir $logdir

In [None]:
components = get_multimodel_components('resnet50', 'adam', 'steplr')
# The losses are applied by hand in MultiRunner._handle_batch (via the module-level
# `criterion`). Taking them out of `components` keeps them from being forwarded to
# runner.train(**components).
criterion = components.pop('criterion').copy()

In [None]:
class MultiRunner(Runner):
    """Runner for the joint segmentation + classification model.

    The model maps an image batch to a pair of predictions: one scored against the
    mask (dice loss, IoU metric) and one scored against the class label
    (cross-entropy loss, precision metric). The total loss is the sum of the two.
    Optimisation is done by hand inside ``_handle_batch``.

    Reads the notebook-level ``criterion`` dict (keys ``'dice'`` and
    ``'crossentropy'``) and the ``score_segmentation`` / ``score_classification``
    helpers imported from ``src.nn``.
    """

    def _handle_batch(self, batch):
        """Process one batch: forward pass, losses, metrics and (on train) an update.

        Args:
            batch: ``(x, y, z)`` triple — image, target mask and target label
                (presumably in the order of the dataset keys
                ``['image', 'mask', 'label']`` — TODO confirm).

        Side effects:
            Sets ``self.batch_metrics``, ``self.input`` and ``self.output``. On
            train loaders it also back-propagates, steps ``self.optimizer`` and may
            step ``self.scheduler``.
        """
        x, y, z = batch
        is_train = bool(self.is_train_loader)

        # The forward pass lives inside the grad context as well: otherwise
        # validation batches would still build an autograd graph and waste memory.
        with torch.set_grad_enabled(is_train):
            y_pred, z_pred = self.model(x)

            loss_dice = criterion['dice'](y_pred, y)
            loss_cross_entropy = criterion['crossentropy'](z_pred, z)
            loss = loss_dice + loss_cross_entropy

            metric_iou = score_segmentation(y_pred, y)
            metric_precision = score_classification(z_pred, z)
            metric_precision_sum = metric_precision.sum()

            # NOTE: the 'presicion' spelling is the existing metric tag; renaming it
            # would start a new TensorBoard series.
            self.batch_metrics = {
                'metric/presicion_sum': metric_precision_sum,
                'metric/presicion_mean': metric_precision.mean(),
                'metric/iou': metric_iou,
                'loss/dice': loss_dice,
                'loss/cross_entropy': loss_cross_entropy,
                'loss': loss,
                # Main metric used for checkpoint selection (maximised in train()).
                'score': metric_iou + metric_precision_sum,
                'lr': self.scheduler.get_last_lr()[0],
            }

            if is_train:
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                # NOTE(review): global_batch_step presumably increases monotonically,
                # so this equality holds at most once per run, and the scheduler is
                # stepped at most once. Confirm the intended schedule (StepLR is
                # usually stepped once per epoch).
                if self.global_batch_step == self.batch_size:
                    self.scheduler.step()

        # Exposed for downstream callbacks that read targets/logits.
        self.input = {'targets': z}
        self.output = {'logits': z_pred}

In [None]:
# Cross-validation / training settings for this run.
KFOLD = 4
BATCH_SIZE = 16
NUM_EPOCHS = 100

dataset = BACTERIA(keys=['image', 'mask', 'label'])
# One entry per fold; each entry is passed to train() as its `loaders`.
experiments = [fold for fold in dataset.crossval(kfold=KFOLD, batch_size=BATCH_SIZE)]

# Train only the selected fold; its logs go to <logdir>/<fold index>.
num_experiment = 0
loaders = experiments[num_experiment]
runner = MultiRunner()
runner.train(
    loaders=loaders,
    logdir=f'{logdir}/{num_experiment}',
    num_epochs=NUM_EPOCHS,
    main_metric='score',
    minimize_metric=False,  # 'score' is IoU + precision sum: higher is better
    **components
)

In [None]:
# Archive the run's logs and commit/push them to the repository's master branch.
# IPython expands {logdir} from the Python variable.
! zip -r {logdir}.zip {logdir}
# Integrate remote changes before committing.
! git pull origin master
# The archive ({logdir}.zip, i.e. ./logs/multimodel.zip) sits under logs/, so this
# stages both the raw log directory and the zip.
! git add logs
! git commit -m 'Changed from Colab'
# NOTE(review): requires git credentials in the Colab runtime (presumably set up by
# colab.sh — TODO confirm).
! git push -u origin master