<a href="https://colab.research.google.com/github/pandov/diploma/blob/main/train_catalyst.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! git config --global user.email "ipandov17@gmail.com"
! git config --global user.name "Vyacheslav Pandov"
! git clone https://pandov:yujhnm12gh@github.com/pandov/diploma.git
! pip install -r diploma/requirements.txt

In [None]:
cd diploma

In [None]:
! rm -r logs/1
! git pull

In [None]:
import torch
from catalyst import utils, dl
utils.set_global_seed(17)
from src import criterion, dataset, metric, model, utils

In [None]:
class CustomRunner(dl.Runner):
    """Catalyst runner that trains with a weighted BCE + Dice objective.

    Batches are mappings with ``'images'`` (model inputs) and ``'masks'``
    (targets). ``self.criterion`` must be a dict with ``'bce'`` and ``'dice'``
    loss callables. When a scheduler is supplied, it is stepped once per epoch.
    """

    # Weights of the two loss terms. They are class attributes so that a subclass
    # or an instance can override them without editing the batch handler.
    bce_weight = 0.6
    dice_weight = 0.4

    def _handle_batch(self, batch):
        """Forward pass (plus backward/step on the train loader) and metric logging."""
        inputs = batch['images']
        targets = batch['masks']

        with torch.set_grad_enabled(self.is_train_loader):
            outputs = self.model(inputs)
            # NOTE(review): torch.nn.BCELoss expects probabilities in [0, 1], so this
            # assumes the model already applies a sigmoid. TODO: confirm in src.model.
            loss_bce = self.criterion['bce'](outputs, targets)
            # .mean() reduces a per-sample Dice loss to a scalar (a no-op if already scalar).
            loss_dice = self.criterion['dice'](outputs, targets).mean()
            loss = self.bce_weight * loss_bce + self.dice_weight * loss_dice
            if self.is_train_loader:
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

            metrics = {
                'loss': loss,
                'bce': loss_bce,
                # Report the Dice score rather than the Dice loss.
                'dice': 1 - loss_dice,
            }
            if self.scheduler is not None:
                metrics['lr'] = self.scheduler.get_last_lr()[0]
            self.batch_metrics.update(metrics)

    def on_epoch_end(self, runner):
        """Finish the epoch via the base runner, then advance the LR schedule."""
        super().on_epoch_end(runner)
        if self.scheduler is not None:
            self.scheduler.step()

In [None]:
# Data, loss, model and optimiser setup.
# This cell rebinds the global name `criterion` to a dict of losses. The loss module
# is therefore imported under an alias here; otherwise a second run would call
# `.DiceLoss()` on the dict and raise AttributeError.
from src import criterion as criterion_module

BATCH_SIZE = 15

datasets = {
    'train': dataset.CracksDataset('train'),
    'valid': dataset.CracksDataset('valid'),
}
loaders = {
    'train': datasets['train'].get_loader(batch_size=BATCH_SIZE, shuffle=True, drop_last=True),
    'valid': datasets['valid'].get_loader(batch_size=BATCH_SIZE),
}
criterion = {
    'dice': criterion_module.DiceLoss(),
    'bce': torch.nn.BCELoss(),
}
net = model.UNet()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2, momentum=0.9, nesterov=True)
# MultiStepLR multiplies the LR by gamma=0.1 once 4 and then 32 scheduler steps have
# been taken. CustomRunner.on_epoch_end steps it once per epoch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 32], gamma=0.1)

In [None]:
# Train for 50 epochs with the objects built in the setup cell.
# Run output is written under logs/1, which the earlier cell clears before each run.
runner = CustomRunner()
runner.train(
    # model and optimisation
    model=net,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    # data and schedule
    loaders=loaders,
    num_epochs=50,
    # output
    logdir='logs/1',
    verbose=True,
)

In [None]:
! git pull
! git add logs
! git commit -m 'Trained'
! git push