<a href="https://colab.research.google.com/github/pandov/cups-mail/blob/master/segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Run the colab.sh script found in the current working directory; %%capture
# hides everything the script prints.
# NOTE(review): presumably this prepares the Colab environment (clone/install) —
# confirm against colab.sh itself.
! sh colab.sh

In [None]:
%%capture
import torch
from catalyst.utils import metrics, set_global_seed, prepare_cudnn
from catalyst.dl import Runner
from src.nn import BACTERIA, DiceLoss, get_model
# Reproducibility: request deterministic cuDNN behaviour and fix the global seed
# before any model or data loader is created.
# NOTE(review): exactly which RNGs `set_global_seed` covers is defined by
# catalyst — confirm there if bit-exact reruns matter.
prepare_cudnn(deterministic=True)
set_global_seed(7)

In [None]:
# Loss shared by every batch; read (never rebound) inside the runner.
criterion = DiceLoss()


class CustomRunner(Runner):
    """Catalyst runner with a hand-written batch step: Dice loss + IoU metric."""

    def _handle_batch(self, batch):
        """Forward one batch, record its metrics and, on train loaders only, update weights.

        Args:
            batch: a 3-item sequence. The first item is fed to the model and the
                second is compared with the prediction; the third item is not
                used here.
        """
        x, y, _ = batch
        # `stage_name` names the whole stage: `runner.train` runs *every* loader
        # (validation included) inside a stage called "train", so comparing it
        # with 'train' is always true and validation batches would be
        # back-propagated too. The loader name is what separates the phases.
        # NOTE(review): assumes train loaders are keyed with a "train" prefix
        # (Catalyst's convention) — confirm against BACTERIA.crossval.
        is_train = self.state.loader_name.startswith('train')
        with torch.set_grad_enabled(is_train):
            y_pred = self.model(x)
            loss = criterion(y_pred, y)
            metric = metrics.iou(y_pred, y)
            self.state.batch_metrics = {
                'loss': loss,
                'iou': metric,
            }
            if is_train:
                loss.backward()
                self.state.optimizer.step()
                # Clear gradients right after the step so the next batch starts clean.
                self.state.optimizer.zero_grad()

In [None]:
# Load the TensorBoard notebook extension and open it on the `logs` directory,
# the parent of the training run's logdir (./logs/segmentation).
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
# --- Hyperparameters for this run (tune here) ---
KFOLD = 4
BATCH_SIZE = 9
LEARNING_RATE = 1e-3
LR_STEP_SIZE = 50   # epochs between learning-rate decays
LR_GAMMA = 0.1      # multiplicative decay factor
NUM_EPOCHS = 100
LOGDIR = './logs/segmentation'

# Release cached GPU memory left over from a previous run of this cell.
torch.cuda.empty_cache()

# Fresh model, switched to training mode.
model = get_model()
model.train()

# Only the first cross-validation split is used.
dataset = BACTERIA()
loaders = next(dataset.crossval(kfold=KFOLD, batch_size=BATCH_SIZE))

# Adam with a step-decayed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=LR_STEP_SIZE,
    gamma=LR_GAMMA,
)

runner = CustomRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    num_epochs=NUM_EPOCHS,
    logdir=LOGDIR,
    # verbose=True,
)

In [None]:
# Archive the run's log directory into logs/segmentation.zip.
! zip -r logs/segmentation.zip logs/segmentation
# Stage src/ and logs/ (the zip above lives under logs/, so it is staged too).
! git add src logs
! git commit -m 'Changed from Colab'
# Push the commit to origin/master and set it as the upstream branch.
! git push -u origin master