# Семантическая сегментация птиц
![birds_semantic_seg.jpg](birds_semantic_seg.jpg)

Необходимо реализовать бинарную сегментацию изображений птиц.

Классификация пикселей осуществляется на базе [Lite R-ASPP](https://pytorch.org/vision/main/models/generated/torchvision.models.segmentation.lraspp_mobilenet_v3_large.html). Исходная модель обучена предсказывать 21 класс (20 категорий Pascal VOC на подмножестве COCO и фон); нам нужен только `classes[3]=='bird'`, поэтому 
от свёрток в голове `low_classifier` и `high_classifier` оставляем только веса, соответствующие фону (`classes[0]`) и этому классу.  
При этом незамороженную часть бэкбона `MobileNetV3` тренируем с `lr` меньше, чем у классификатора, чтобы не потерять выученные представления и сконцентрироваться на обучении `classifier`.  


Функция потерь, оптимизируемая сетью, — $\alpha \cdot \mathrm{size}(\mathrm{batch}) \cdot \mathrm{DiceLoss} + (1-\alpha) \cdot \mathrm{CE}$, где $\alpha=0.1$. Она позволяет совместить эффективность $\mathrm{DiceLoss}$ при дисбалансе классов (в бинарной маске) и хорошую сходимость $\mathrm{CE}$ (не идеально, но для учебного примера подходит).

Структура хранения данных - в папке `image_folder` находятся изображения птиц, разбитые по папкам-классам, в `gt_folder` лежат ground truth маски сегментации:  
<ul>
    <li> image_folder </li>
    <ul>
        <li> bird_type_1 </li>
        <ul>
            <li> image_1.jpg </li>
            <li> image_2.jpg </li>
            <li> ... </li>
        </ul>
        <li> bird_type_2 </li>
        <li> ... </li>
    </ul>
    <li> gt_folder </li>
    <ul>
        <li> bird_type_1 </li>
        <ul>
            <li> image_1.png </li>
            <li> image_2.png </li>
            <li> ... </li>
        </ul>
        <li> bird_type_2 </li>
        <li> ... </li>
    </ul>
</ul>  

Для обучения делается `stratified split` по классам птиц: доли видов в обучающей и валидационной выборках совпадают, что не даёт сети переобучиться на преобладающих (мажоритарных) классах.

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
import torchvision
import albumentations as A
import numpy as np
from sklearn.model_selection import train_test_split

from torchvision.models.segmentation import lraspp_mobilenet_v3_large


import os
from tqdm import tqdm

from pathlib import Path
import cv2



# Maximum number of training epochs (passed to pl.Trainer as max_epochs).
EPOCHS_NUM = 40
# Mini-batch size used by both the train and the validation DataLoader.
BATCH_SIZE = 64
# Fraction of samples held out for validation by train_test_split.
TEST_SIZE = 0.2



def dice_loss(pred, target, smooth=1.):
    """Return the soft Dice loss averaged over the batch.

    Args:
        pred: predicted foreground scores with the batch on dim 0; dims 2 and
            then 1 are summed out, e.g. a (batch, height, width) tensor.
        target: ground-truth mask with the same shape as ``pred``.
        smooth: constant added to both numerator and denominator so the ratio
            stays defined when ``pred`` and ``target`` are both all zeros.

    Returns:
        Scalar tensor: mean over the batch of ``1 - dice``.
    """
    def _per_sample_sum(t):
        # Collapse dim 2 first, then dim 1, leaving one value per batch item.
        return t.sum(dim=2).sum(dim=1)

    pred = pred.contiguous()
    target = target.contiguous()

    overlap = _per_sample_sum(pred * target)
    total = _per_sample_sum(pred) + _per_sample_sum(target)
    dice = (2. * overlap + smooth) / (total + smooth)

    return (1 - dice).mean()


# Per-channel normalization statistics passed to A.Normalize below
# (presumably the ImageNet RGB mean/std expected by the pretrained backbone - confirm).
MEAN=[0.485, 0.456, 0.406]
STD=[0.229, 0.224, 0.225]

# Training pipeline: resize to 520x520, random-sized crop brought back to 520x520,
# random rotation, normalization, conversion to a torch tensor.
DEFAULT_TRAIN_TRANSFORM = A.Compose(
    [
        A.Resize(520, 520),
        # NOTE(review): the positional arguments look like the albumentations 1.x
        # signature (min_max_height, height, width) - confirm against the installed version.
        A.RandomSizedCrop([460,500], 520, 520),
        # Presumably a rotation limit of 20 degrees in either direction - confirm.
        A.Rotate(20),
        A.Normalize(mean=MEAN, std=STD),
        # NOTE(review): ToTensorV2 is defined in albumentations.pytorch; confirm that the
        # installed albumentations version exposes it as A.ToTensorV2.
        A.ToTensorV2(),
    ]
)

# Validation / inference pipeline: the same preprocessing without random augmentations.
DEFAULT_VAL_TRANSFORM = A.Compose(
    [
        A.Resize(520, 520),
        A.Normalize(mean=MEAN, std=STD),
        A.ToTensorV2(),
    ]
)

class BirdDataset(torch.utils.data.Dataset):
    """Dataset of (image, binary mask) pairs for the train and validation phases.

    Args:
        samples: sequence of image paths relative to ``image_folder``
            (e.g. ``'bird_type_1/image_1.jpg'``).
        image_folder: root directory that contains the input images.
        gt_folder: root directory that contains the ground-truth masks; a mask
            has the same relative path as its image, with a ``.png`` extension.
        transform: albumentations transform invoked as
            ``transform(image=..., mask=...)``, or ``None`` to return the raw
            numpy arrays.
    """

    def __init__(self, samples, image_folder, gt_folder, transform) -> None:
        super().__init__()
        self.samples = samples
        self.image_folder = image_folder
        self.gt_folder = gt_folder
        self.transform = transform

    def _mask_path(self, img_path):
        """Return the path of the ground-truth mask that belongs to ``img_path``."""
        # splitext handles extensions of any length ('.jpg', '.jpeg', ...),
        # unlike slicing off a fixed number of characters.
        stem, _ = os.path.splitext(img_path)
        return os.path.join(self.gt_folder, f'{stem}.png')

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, mask)``.

        Raises:
            FileNotFoundError: if the image or its mask is missing or cannot be
                decoded by OpenCV.
        """
        img_path = self.samples[index]

        full_img_path = os.path.join(self.image_folder, img_path)
        img = cv2.imread(full_img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Cannot read image: {full_img_path}')
        # OpenCV loads BGR; reverse the channel axis to get RGB.
        # astype() copies, so the result is a contiguous float32 array.
        img = img[:, :, ::-1].astype(np.float32)

        mask_path = self._mask_path(img_path)
        label = cv2.imread(mask_path)
        if label is None:
            raise FileNotFoundError(f'Cannot read mask: {mask_path}')
        # A multi-channel mask is reduced to its first channel only.
        if label.ndim == 3:
            label = label[:, :, 0]
        # Binarize: values above 127 are foreground (1.0), the rest background (0.0).
        label = (label > 127).astype(float)

        if self.transform is not None:
            output = self.transform(image=img, mask=label)
            img = output['image']
            label = output['mask']

        return img, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)


class BirdSegmenter(pl.LightningModule):
    """Binary bird/background segmenter on top of ``lraspp_mobilenet_v3_large``.

    ``forward`` returns per-pixel logits with two channels: channel 0 comes from
    original class 0 and channel 1 from original class 3 ('bird').

    Only the last ``_NUM_TRAINABLE_BACKBONE_LAYERS`` children of the backbone
    and the segmentation head are trained; earlier backbone children are frozen.

    NOTE(review): ``training_epoch_end`` / ``validation_epoch_end`` are
    PyTorch Lightning 1.x hooks; confirm the installed version still calls them.

    Args:
        pretrained: forwarded to ``lraspp_mobilenet_v3_large`` as both
            ``pretrained`` and ``pretrained_backbone``.
    """

    # Number of trailing backbone children that stay trainable. Used both when
    # freezing layers and when building optimizer parameter groups, so the two
    # cannot drift apart.
    _NUM_TRAINABLE_BACKBONE_LAYERS = 3

    def __init__(self, pretrained=False):
        super().__init__()
        self.model = lraspp_mobilenet_v3_large(pretrained=pretrained, pretrained_backbone=pretrained)

        # Keep only the output channels of class 0 and classes[3]=='bird'.
        kept_classes = [0, 3]
        head = self.model.classifier
        head.low_classifier = self._slice_classifier(head.low_classifier, kept_classes)
        head.high_classifier = self._slice_classifier(head.high_classifier, kept_classes)

        # Freeze everything in the backbone except its last few children.
        frozen_layers, _ = self._split_backbone()
        for layer in frozen_layers:
            for param in layer.parameters():
                param.requires_grad = False

        # CrossEntropy weight in total loss; the Dice term gets (1 - ce_weight).
        self.ce_weight = 0.9
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _slice_classifier(conv, class_indices):
        """Return a 1x1 conv that keeps only ``class_indices`` output channels of ``conv``."""
        sliced = nn.Conv2d(conv.in_channels, len(class_indices), kernel_size=(1, 1), stride=(1, 1))
        sliced.weight = nn.Parameter(conv.weight.data[class_indices])
        sliced.bias = nn.Parameter(conv.bias.data[class_indices])
        return sliced

    def _split_backbone(self):
        """Return ``(frozen_layers, finetuned_layers)`` of the backbone children."""
        layers = list(self.model.backbone.children())
        n_trainable = self._NUM_TRAINABLE_BACKBONE_LAYERS
        return layers[:-n_trainable], layers[-n_trainable:]

    def _compute_losses(self, batch):
        """Run the model on ``batch`` and return ``(total_loss, dice, ce)``."""
        x, y = batch
        y = y.to(torch.long)

        y_logit = self(x)
        ce = self.criterion(y_logit, y)

        # Probability of channel 1 (bird) for every pixel.
        pred = F.softmax(y_logit, dim=1)[:, 1, :, :]
        dice = dice_loss(pred, y)

        # The Dice term is additionally scaled by the batch size.
        loss = ce * self.ce_weight + dice * (1 - self.ce_weight) * y.size(0)
        return loss, dice, ce

    # REQUIRED
    def forward(self, x):
        """Return the ``'out'`` logits of the wrapped segmentation model."""
        return self.model(x)['out']

    # REQUIRED
    def training_step(self, batch, batch_idx):
        """Compute the combined CE + Dice loss for one training batch."""
        loss, _, _ = self._compute_losses(batch)
        return {'loss': loss}

    # REQUIRED
    def configure_optimizers(self):
        """Build AdamW with per-group learning rates and a plateau scheduler.

        The trainable backbone layers use a lower learning rate (1e-5) than the
        classifier head (1e-3) to avoid degrading the backbone representations
        while the head is being learned.
        """
        _, finetuned_layers = self._split_backbone()
        backbone_parameters_to_finetune = [
            p for layer in finetuned_layers for p in layer.parameters()
        ]
        grouped_parameters = [
            {'params': backbone_parameters_to_finetune, 'lr': 1e-5},
            {'params': self.model.classifier.parameters(), 'lr': 1e-3},
        ]

        optimizer = torch.optim.AdamW(grouped_parameters)

        # Reduce lr when the monitored validation loss stops improving.
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                  mode='min',
                                                                  factor=0.1,
                                                                  patience=10,
                                                                  verbose=True)

        lr_dict = {
            "scheduler": lr_scheduler,
            "interval": "epoch",
            "frequency": 1,
            "monitor": "val_loss"
        }

        return [optimizer], [lr_dict]

    # OPTIONAL
    def validation_step(self, batch, batch_idx):
        """Compute the combined loss and its Dice / CE parts for one validation batch."""
        loss, dice, ce = self._compute_losses(batch)
        return {'val_loss': loss, 'logs': {'dice': dice, 'ce': ce}}

    # OPTIONAL
    def training_epoch_end(self, outputs):
        """Print and log the mean training loss over the epoch."""
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()

        print(f"| Train_loss: {avg_loss:.3f}" )
        self.log('train_loss', avg_loss, prog_bar=True, on_epoch=True, on_step=False)

    # OPTIONAL
    def validation_epoch_end(self, outputs):
        """Print the mean validation loss, Dice and CE; log ``val_loss``."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        avg_dice = torch.stack([x['logs']['dice'] for x in outputs]).mean()
        avg_ce = torch.stack([x['logs']['ce'] for x in outputs]).mean()

        print(f"[Epoch {self.trainer.current_epoch:3}] Val_loss: {avg_loss:.3f}, Val_dice: {avg_dice:.3f}, Val_ce: {avg_ce:.3f}", end= " ")
        # 'val_loss' is the metric monitored by the lr scheduler.
        self.log_dict({'val_loss': avg_loss}, prog_bar=True, on_epoch=True, on_step=False)


def train_model(train_data_path, pretrained=False):
    """Train a ``BirdSegmenter`` and return it in eval mode.

    Images are read from ``<train_data_path>/images/<class>/<file>`` and masks
    from ``<train_data_path>/gt``. Samples are split into train and validation
    parts with a split stratified by the class folder name. Metrics are printed
    by the model during training.

    Args:
        train_data_path: directory that contains the ``images`` and ``gt`` folders.
        pretrained: forwarded to ``BirdSegmenter``.

    Returns:
        The trained model, switched to eval mode.

    Raises:
        ValueError: if no image files are found under the ``images`` folder.
    """
    trainer = pl.Trainer(
        accelerator='auto',
        max_epochs=EPOCHS_NUM
    )
    model = BirdSegmenter(pretrained=pretrained)

    gt_folder = os.path.join(train_data_path, 'gt')
    image_folder = os.path.join(train_data_path, 'images')

    # Take only regular files exactly one level below a class folder.
    # Directories or deeper files would yield names that do not resolve
    # relative to image_folder.
    image_paths = sorted(p for p in Path(image_folder).glob('*/*') if p.is_file())
    if not image_paths:
        raise ValueError(f'No images found under {image_folder}')

    # A sample is '<class>/<file>'; the class is the parent folder name.
    samples = ['/'.join(p.parts[-2:]) for p in image_paths]
    classes = [p.parts[-2] for p in image_paths]

    # Stratified split keeps class proportions equal in both parts, so that
    # underrepresented bird classes appear in train and validation alike.
    train_idx, valid_idx = train_test_split(
        np.arange(len(classes)),
        test_size=TEST_SIZE,
        stratify=classes
    )

    samples = np.array(samples)
    train_dataset = BirdDataset(samples[train_idx], image_folder, gt_folder,
                                transform=DEFAULT_TRAIN_TRANSFORM)
    val_dataset = BirdDataset(samples[valid_idx], image_folder, gt_folder,
                              transform=DEFAULT_VAL_TRANSFORM)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE)

    trainer.fit(model, train_dataloader, val_dataloader)

    model.eval()
    return model


def predict(model, img_path):
    """Predict the bird-probability mask of the image stored at ``img_path``.

    Args:
        model: segmentation module that returns per-pixel logits whose
            channel 1 is the bird class.
        img_path: path of the image file to segment.

    Returns:
        2-D numpy array with the height and width of the source image,
        holding the softmax probability of channel 1 for every pixel.

    Raises:
        FileNotFoundError: if the image is missing or cannot be decoded by OpenCV.
    """
    model.eval()
    with torch.inference_mode():
        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Cannot read image: {img_path}')
        # Convert BGR to RGB and remember the original height and width.
        img = np.ascontiguousarray(bgr[:, :, ::-1])
        original_shape = img.shape[:2]

        # Transform the image and build a batch of one on the model's device,
        # so inference also works when the model lives on an accelerator.
        tensor = DEFAULT_VAL_TRANSFORM(image=img)['image']
        batch = torch.stack([tensor])
        first_param = next(model.parameters(), None)
        if first_param is not None:
            batch = batch.to(first_param.device)

        output = model(batch)
        normalized_masks = torch.nn.functional.softmax(output, dim=1)
        mask = normalized_masks[0, 1, :, :].cpu().numpy()

        # Restore the mask to the original image size.
        mask = A.Resize(*original_shape)(image=mask)['image']

        # NOTE: an earlier channel-stacking branch tested shape[2] of the
        # transformed tensor, so it never ran; the result has always been 2-D.
        return mask