In [1]:
import sys
import os

import torch 
from torch import nn
from torch.nn import functional as F

from pytorch_lightning import Trainer
from pytorch_lightning.core import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

sys.path.append('../..')
from lib.schedulers import DelayedScheduler
from lib.dataloaders import PatchesDataset
from lib.datasets import (actual_lbl_nums, 
                          patches_rgb_mean_av1, patches_rgb_std_av1, 
                          get_train_test_img_ids_split)
from lib.dataloaders import PatchesDataset
from lib.trainers import PatchesModule
from lib.augmentations import augment_empty, augment_v1
from lib.losses import SmoothLoss

In [2]:
class MyModule(PatchesModule):
    """Lightning module training a model with mask, label and reconstruction outputs.

    The wrapped model is called through ``self(...)`` and must return a triple
    ``(masks, labels, reconstructed_images)``. Three losses are combined with
    the weights in ``self.loss_weights``:

    * ``dec``   - MSE between the reconstructed and the input images;
    * ``mask``  - smoothed KL-divergence on the segmentation output;
    * ``label`` - smoothed KL-divergence on the classification output.

    Mask and label losses are computed separately per data provider:
    provider 0 uses the first six output channels, provider 1 the last three
    (its mask values are shifted down by 6 before being compared).
    """

    # Offset subtracted from provider-1 mask values before computing their loss.
    _PR1_MASK_OFFSET = 6

    def __init__(self, model, hparams, num_workers=0, log_train_every_batch=False):
        """Store the configuration and build the loss functions.

        Args:
            model: network returning ``(masks, labels, reconstructed_images)``.
            hparams: dict of hyper-parameters; this class reads ``batch_size``
                and ``dataset.scale`` from it.
            num_workers: worker count passed to both data loaders.
            log_train_every_batch: when true, ``training_step`` returns the
                ``{'loss': ..., 'log': ...}`` form.
        """
        super().__init__(model, hparams, num_workers, log_train_every_batch)
        self.hparams = hparams
        self.model = model
        self.num_workers = num_workers
        self.log_train_every_batch = log_train_every_batch

        self.dec_loss = nn.MSELoss()
        self.mask_loss = SmoothLoss(nn.KLDivLoss(), smoothing=0.2, one_hot_target=True)
        self.lbl_loss = SmoothLoss(nn.KLDivLoss(), smoothing=0.2, one_hot_target=False)

        self.loss_weights = {
            'dec': 1/3,
            'mask': 1/3,
            'label': 1/3
        }

    @staticmethod
    def _subset_loss(criterion, outputs, targets):
        """Apply ``criterion`` to a per-provider subset, or return 0 if it is empty.

        A batch may contain no samples of one provider; evaluating the loss on
        an empty selection would otherwise fail or yield NaN.
        """
        if outputs.shape[0] == 0:
            return outputs.new_zeros(())
        return criterion(outputs, targets)

    def _loss(self, o_masks, o_labels, o_imgs, masks, labels, imgs, n_imgs, provider):
        """Compute the weighted total loss and its three components.

        Args:
            o_masks, o_labels, o_imgs: model outputs (masks, labels, reconstruction).
            masks, labels, imgs: corresponding targets from the batch.
            n_imgs: normalised input images (accepted for interface
                compatibility; not used here).
            provider: per-sample provider id tensor (values 0 or 1 are handled).

        Returns:
            Tuple ``(loss, dec_loss, mask_loss, label_loss)``.

        Raises:
            ValueError: if shifted provider-1 mask values fall outside ``[0, 3]``.
        """
        pr0_idx = provider == 0
        pr1_idx = provider == 1

        dec_loss = self.dec_loss(o_imgs, imgs)

        pr1_masks = masks[pr1_idx] - self._PR1_MASK_OFFSET
        # The original check was `assert "<non-empty string>"`, which can never
        # fail; raise explicitly instead. Skip the check for an empty selection,
        # where max()/min() are undefined.
        # NOTE(review): only three channels are used for provider 1, so the
        # upper bound was possibly meant to be 2 rather than 3 - confirm.
        if pr1_masks.numel() > 0 and (pr1_masks.max() > 3 or pr1_masks.min() < 0):
            raise ValueError(
                "Provider-1 mask values out of range after offset: "
                f"max={pr1_masks.max().item()}, min={pr1_masks.min().item()}"
            )

        mask_loss = (
            self._subset_loss(self.mask_loss, o_masks[pr0_idx, :6], masks[pr0_idx]) +
            self._subset_loss(self.mask_loss, o_masks[pr1_idx, -3:], pr1_masks)
        )

        label_loss = (
            self._subset_loss(self.lbl_loss, o_labels[pr0_idx, :6], labels[pr0_idx, :6]) +
            self._subset_loss(self.lbl_loss, o_labels[pr1_idx, -3:], labels[pr1_idx, -3:])
        )

        loss = (self.loss_weights['dec'] * dec_loss +
                self.loss_weights['mask'] * mask_loss +
                self.loss_weights['label'] * label_loss)

        return loss, dec_loss, mask_loss, label_loss

    def step(self, batch, batch_idx, is_train):
        """Run one forward pass and compute losses and accuracies for ``batch``.

        Args:
            batch: tuple ``(imgs, masks, labels, provider, isup_grade, g_score)``.
            batch_idx: index of the batch (unused).
            is_train: selects the metric-name prefix ('' for training,
                'val_' for validation) and the return form.

        Returns:
            ``{'loss': tensor, 'log': metrics}`` when training with
            ``log_train_every_batch`` enabled, otherwise the flat metrics dict.
        """
        imgs, masks, labels, provider, isup_grade, g_score = batch
        b = imgs.shape[0]
        # Standardise the input: subtract the mean first, then divide by the std.
        # (Without the parentheses only the mean was divided by the std.)
        # NOTE(review): relies on rgb_mean / rgb_std broadcasting against imgs -
        # confirm the statistics' shape matches the image layout.
        n_imgs = (imgs - self.rgb_mean) / self.rgb_std

        o_masks, o_labels, o_imgs = self(n_imgs)

        loss, dec_loss, mask_loss, label_loss =\
            self._loss(o_masks, o_labels, o_imgs, masks, labels, imgs, n_imgs, provider)

        lbl_acc = self._accuracy(o_labels, labels.argmax(dim=1))
        mask_acc = self._accuracy(o_masks.view(b, actual_lbl_nums, -1),
                                  masks.view(b, -1))

        lr = self.optimizer.param_groups[0]['lr']

        pr = '' if is_train else 'val_'

        log_dict = {
            pr+'loss': loss.item(),
            pr+'dec_loss': dec_loss.item(),
            pr+'mask_loss': mask_loss.item(),
            pr+'label_loss': label_loss.item(),
            pr+'lbl_acc': lbl_acc.item(),
            pr+'mask_acc': mask_acc.item(),
            pr+'lr': lr
        }

        # NOTE(review): with log_train_every_batch disabled, the training step
        # returns 'loss' as a Python float rather than a tensor - verify that
        # the base class / trainer handles this form.
        if is_train and self.log_train_every_batch:
            return {'loss': loss, 'log': log_dict}
        else:
            return log_dict

    def training_step(self, batch, batch_idx):
        """Lightning hook: delegate to ``step`` in training mode."""
        return self.step(batch, batch_idx, True)

    def validation_step(self, batch, batch_idx):
        """Lightning hook: delegate to ``step`` in validation mode."""
        return self.step(batch, batch_idx, False)

    def prepare_data(self):
        """Load the RGB normalisation statistics and the train/test image id split."""
        self.rgb_mean, self.rgb_std = (torch.tensor(patches_rgb_mean_av1, dtype=torch.float32),
                                       torch.tensor(patches_rgb_std_av1, dtype=torch.float32))
        self.train_img_ids, self.test_img_ids = get_train_test_img_ids_split()

    def _apply(self, fn):
        """Apply ``fn`` (device/dtype conversion) to the module and its RGB statistics.

        ``rgb_mean`` / ``rgb_std`` are plain tensor attributes, so they are
        converted here by hand. They only exist after ``prepare_data`` has run;
        if the module is moved earlier they are simply skipped.
        """
        self = super()._apply(fn)
        for name in ('rgb_mean', 'rgb_std'):
            stat = getattr(self, name, None)
            if stat is not None:
                setattr(self, name, fn(stat))

        return self

    def train_dataloader(self):
        """Build the shuffled training loader (``augment_v1``) and record its length."""
        train_loader = torch.utils.data.DataLoader(
            PatchesDataset(self.train_img_ids, transform=augment_v1,
                           scale=self.hparams['dataset']['scale'], load_masks=True),
            batch_size=self.hparams['batch_size'], shuffle=True,
            num_workers=self.num_workers, pin_memory=True,
        )

        # Number of batches per epoch, despite the attribute's name.
        self.steps_in_batch = len(train_loader)

        return train_loader

    def val_dataloader(self):
        """Build the non-shuffled validation loader (``augment_empty``)."""
        val_loader = torch.utils.data.DataLoader(
            PatchesDataset(self.test_img_ids, transform=augment_empty,
                           scale=self.hparams['dataset']['scale'], load_masks=True),
            batch_size=self.hparams['batch_size'], shuffle=False,
            num_workers=self.num_workers, pin_memory=True,
        )

        return val_loader

In [3]:
import copy
from segmentation_models_pytorch.encoders import get_encoder, get_encoder_names, encoders
from segmentation_models_pytorch.unet import Unet
import segmentation_models_pytorch

from segmentation_models_pytorch.unet.decoder import DecoderBlock

In [4]:
# Reminder: do not forget the activation for the mask output!
# Both the segmentation head and the auxiliary classification head use
# log-softmax over `actual_lbl_nums` classes; the encoder starts untrained.
model = Unet(
    'resnet18',
    encoder_weights=None,
    classes=actual_lbl_nums,
    activation='logsoftmax',
    aux_params=dict(classes=actual_lbl_nums, activation='logsoftmax'),
)

In [5]:
class AutoDecoder(nn.Module):
    """Skip-less stack of ``DecoderBlock``s followed by a sigmoid.

    ``channels`` is read from the last entry to the first: every adjacent
    pair ``(channels[i], channels[i - 1])`` becomes one block mapping
    ``channels[i]`` input channels to ``channels[i - 1]`` output channels,
    with zero skip-connection channels.
    """

    def __init__(
            self,
            channels,
            use_batchnorm=True,
            attention_type=None,
    ):
        super().__init__()

        # Deepest-to-shallowest channel order; consecutive pairs define the blocks.
        deep_to_shallow = channels[::-1]
        stage_io = zip(deep_to_shallow[:-1], deep_to_shallow[1:])

        self.blocks = nn.ModuleList([
            DecoderBlock(
                n_in, 0, n_out,
                use_batchnorm=use_batchnorm,
                attention_type=attention_type,
            )
            for n_in, n_out in stage_io
        ])

    def forward(self, features):
        """Pass ``features`` through every block in order and squash with a sigmoid."""
        out = features
        for block in self.blocks:
            out = block(out)

        # Sigmoid output; tanh was considered as an alternative.
        return torch.sigmoid(out)


import types

def forward(self, x):
    """Sequentially pass `x` through the model's encoder, decoder and heads.

    Returns:
        * ``masks`` when the model has neither a classification head nor an
          autodecoder;
        * ``(masks, labels)`` when only the classification head is present;
        * ``(masks, labels, decoded)`` when an autodecoder is attached. In
          that case ``labels`` is ``None`` if there is no classification head.
    """
    features = self.encoder(x)
    decoder_output = self.decoder(*features)
    masks = self.segmentation_head(decoder_output)

    # Bind `labels` up front so the autodecoder branch never sees an unbound name.
    labels = None
    has_classifier = self.classification_head is not None
    if has_classifier:
        labels = self.classification_head(features[-1])

    # The autodecoder is attached to the model after construction, so it may
    # be missing entirely; treat a missing attribute like None.
    autodecoder = getattr(self, 'autodecoder', None)
    if autodecoder is not None:
        # Reconstruct from the deepest encoder feature map only.
        decoded = autodecoder(features[-1])
        return masks, labels, decoded

    if has_classifier:
        return masks, labels

    return masks

# Attach the reconstruction branch, sized from the encoder's per-stage channel
# counts, and rebind `forward` on this instance so the model also returns the
# autodecoder output.
channels = model.encoder.out_channels
model.autodecoder = AutoDecoder(channels)
model.forward = types.MethodType(forward, model)

In [6]:
# Total number of parameter elements in the whole model.
sum(p.numel() for p in model.parameters())

16734015

In [7]:
# Number of parameter elements contributed by the autodecoder branch alone.
sum(p.numel() for p in model.autodecoder.parameters())

2400029

In [None]:
# Bare expression: in a notebook this only displays the imported
# DelayedScheduler object; it has no effect on the training set-up.
DelayedScheduler

In [8]:
# Hard-coded number of training batches per epoch; used to size the LR schedule.
batches_in_epoch = 34362

epochs = 2
warmup_epochs = 0  # epoch-based warm-up alternative: 1
warmup_steps = 3000
batch_size = 58

hparams = dict(
    batch_size=batch_size,
    # Base LR of 0.1 scaled linearly by batch_size / 256.
    learning_rate=0.1 * batch_size / 256,
    dataset=dict(scale=0.5),
    optimizer=dict(
        name='SGD',
        params=dict(momentum=0.9, weight_decay=1e-4),
    ),
    scheduler=dict(
        name='CosineAnnealingLR',
        # Annealing spans every step left after the warm-up steps.
        # Epoch-based alternative: (epochs - warmup_epochs) * batches_in_epoch
        params=dict(T_max=epochs * batches_in_epoch - warmup_steps),
        interval='step',
    ),
    # 'warmup_epochs': warmup_epochs,
    warmup_steps=warmup_steps,
    steps_in_batch=batches_in_epoch,
    epochs=epochs,
)

module = MyModule(model, hparams, num_workers=12, log_train_every_batch=True)

In [9]:
# Log to TensorBoard under the current working directory, experiment name
# 'Patches256TestRun'.
logger = TensorBoardLogger(save_dir=os.getcwd(), name='Patches256TestRun')

# accumulate_grad_batches=1
# Train on GPU 0 for `epochs` epochs with the logger above.
trainer = Trainer(logger=logger, max_epochs=epochs, gpus=[0], fast_dev_run=False)

INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]


In [1]:
# Start tensorboard.
#%load_ext tensorboard
#%tensorboard --logdir ./Patches256TestRun

In [11]:
# NOTE(review): `g_batch` is not referenced by any other cell visible in this
# notebook - purpose unclear; confirm before relying on or removing it.
g_batch = 123

In [20]:
# Train the module, then save a final checkpoint named "last.ckpt" into the
# checkpoint callback's directory.
# NOTE: the recorded output of this cell ends with "CUDA out of memory" during
# training, so in that run the save_checkpoint call below was never reached.
trainer.fit(module)
trainer.save_checkpoint(os.path.join(trainer.checkpoint_callback.dirpath, "last.ckpt"))

INFO:lightning:
    | Name                                            | Type               | Params
-----------------------------------------------------------------------------------
0   | model                                           | Unet               | 16 M  
1   | model.encoder                                   | ResNetEncoder      | 11 M  
2   | model.encoder.conv1                             | Conv2d             | 9 K   
3   | model.encoder.bn1                               | BatchNorm2d        | 128   
4   | model.encoder.relu                              | ReLU               | 0     
5   | model.encoder.maxpool                           | MaxPool2d          | 0     
6   | model.encoder.layer1                            | Sequential         | 147 K 
7   | model.encoder.layer1.0                          | BasicBlock         | 73 K  
8   | model.encoder.layer1.0.conv1                    | Conv2d             | 36 K  
9   | model.encoder.layer1.0.bn1                      | Batc

[{'interval': 'step', 'scheduler': <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7f939c20db10>}]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  return self.activation(x)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

RuntimeError: CUDA out of memory. Tried to allocate 960.00 MiB (GPU 0; 10.91 GiB total capacity; 8.44 GiB already allocated; 852.38 MiB free; 9.54 GiB reserved in total by PyTorch)

In [None]:
#!rm -r lightning_logs

In [None]:
# trainer.test()

In [None]:
# from sklearn import metrics
# metrics.cohen_kappa_score(y, X_p, weights='quadratic')

In [None]:
'''
@dataclass
class KappaScore(ConfusionMatrix):
    "Computes the rate of agreement (Cohens Kappa)."
    weights:Optional[str]=None      # None, `linear`, or `quadratic`

    def on_epoch_end(self, last_metrics, **kwargs):
        sum0 = self.cm.sum(dim=0)
        sum1 = self.cm.sum(dim=1)
        expected = torch.einsum('i,j->ij', (sum0, sum1)) / sum0.sum()
        if self.weights is None:
            w = torch.ones((self.n_classes, self.n_classes))
            w[self.x, self.x] = 0
        elif self.weights == "linear" or self.weights == "quadratic":
            w = torch.zeros((self.n_classes, self.n_classes))
            w += torch.arange(self.n_classes, dtype=torch.float)
            w = torch.abs(w - torch.t(w)) if self.weights == "linear" else (w - torch.t(w)) ** 2
        else: raise ValueError('Unknown weights. Expected None, "linear", or "quadratic".')
        k = torch.sum(w * self.cm) / torch.sum(w * expected)
        return add_metrics(last_metrics, 1-k)
        
        
        #self.optimizer = optim.SGD(
        #    self.parameters(),
        #    lr=self.hparams['learning_rate'],
        #    momentum=self.hparams['momentum'],
        #    weight_decay=self.hparams['weight_decay']
        #)
        
        self.optimizer = optim.Adam(
            self.parameters(),
            lr=self.hparams['learning_rate'],
            weight_decay=self.hparams['weight_decay']
        )        
        self.scheduler = DelayedScheduler(self.optimizer, self.hparams['sched_warmup_epoch'], 
            torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, 
                                                       self.hparams['epochs']-self.hparams['sched_warmup_epoch']))
        #self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, 
        #                                                      milestones=[100, 150], 
        #                                                      last_epoch=-1)
        #self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, 
        #                                                        gamma=self.hparams['lr_shed_gamma']) 
''';