<a href="https://colab.research.google.com/github/pandov/diploma/blob/main/train_catalyst.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! git clone https://pandov:yujhnm12gh@github.com/pandov/diploma.git
! pip install -r diploma/requirements.txt

In [None]:
import os

# Enter the cloned repository so `src` and `logs/` resolve relative to it.
# Pure Python (no reliance on IPython automagic), and guarded so re-running the
# cell does not try to descend into a non-existent diploma/diploma.
if os.path.basename(os.getcwd()) != 'diploma':
    os.chdir('diploma')

In [2]:
import os
import shutil

# Start every training run from an empty log directory. Skip the removal when
# the directory does not exist yet (e.g. on a fresh clone) instead of emitting
# a shell error; genuine failures (permissions, etc.) still raise.
if os.path.isdir('logs'):
    shutil.rmtree('logs')

In [3]:
import torch
from catalyst import dl
from catalyst import utils as catalyst_utils  # explicit alias: `utils` below is src.utils
from src import criterion, dataset, metric, model, utils

# Seed the global RNGs through catalyst so training runs are reproducible.
catalyst_utils.set_global_seed(17)

In [4]:
class CustomRunner(dl.Runner):
    """Catalyst runner for crack segmentation trained with BCE + weighted Dice.

    Each batch is a dict with keys 'images' (model input), 'masks'
    (segmentation target) and 'cracks' (multiplies the Dice term).
    Uses `self.criterion['dice']` and `self.criterion['bce']`.
    """

    def _handle_batch(self, batch):
        """Forward one batch; on the train loader also backprop and update.

        loss = bce(outputs, masks) + mean(cracks * dice(outputs, masks))

        Logs 'loss', 'bce', 'dice' (1 - mean Dice loss) and the current
        learning rate into `self.batch_metrics`.
        """
        inputs = batch['images']
        targets = batch['masks']
        classes = batch['cracks']

        with torch.set_grad_enabled(self.is_train_loader):
            outputs = self.model(inputs)
            loss_dice = self.criterion['dice'](outputs, targets)
            loss_bce = self.criterion['bce'](outputs, targets)
            # NOTE(review): assumes the Dice loss is per-sample and `classes`
            # is a 0/1 crack indicator, so crack-free images add no Dice
            # penalty — confirm against criterion.DiceLoss / CracksDataset.
            loss = loss_bce + (classes * loss_dice).mean()
            if self.is_train_loader:
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                # Step the LR schedule once, on the last training batch.
                # NOTE(review): this relies on `train_batch_step` restarting
                # every epoch (if it is a global counter the schedule steps only
                # once), and on catalyst not also stepping the scheduler via its
                # own callback — confirm for the installed catalyst version.
                if self.train_batch_step == self.train_len:
                    self.scheduler.step()

            self.batch_metrics.update({
                'loss': loss,
                'bce': loss_bce,
                'dice': 1 - loss_dice.mean(),
                'lr': self.scheduler.get_last_lr()[0],
            })

In [5]:
BATCH_SIZE = 10

# Train / validation splits and their loaders (training batches are shuffled
# and the last incomplete batch is dropped).
datasets = {
    'train': dataset.CracksDataset('train'),
    'valid': dataset.CracksDataset('valid'),
}
loaders = {
    'train': datasets['train'].get_loader(batch_size=BATCH_SIZE, shuffle=True, drop_last=True),
    'valid': datasets['valid'].get_loader(batch_size=BATCH_SIZE),
}

# Losses keyed by name, as read by CustomRunner. Binding this dict to
# `criterion` shadows the `src.criterion` module, so the module is imported
# under a private alias here; otherwise re-running this cell fails with
# "'dict' object has no attribute 'DiceLoss'".
# NOTE(review): BCELoss expects probabilities, i.e. a sigmoid already applied
# by the network — confirm for UNet.load_from_torch_hub().
from src import criterion as _criterion_module
criterion = {
    'dice': _criterion_module.DiceLoss(),
    'bce': torch.nn.BCELoss(),
}

net = model.UNet.load_from_torch_hub()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
# Multiply the learning rate by 0.1 at scheduler steps 5 and 20.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 20], gamma=0.1)

In [None]:
# Launch training: 30 epochs over the train/valid loaders, with verbose
# progress output and everything written under logs/.
runner = CustomRunner()
runner.train(
    model=net,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    logdir='logs/',
    num_epochs=30,
    verbose=True,
)