In [3]:
! pip install segmentation_models_pytorch

In [5]:
from os import path
import glob
import os
import cv2
import shutil
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

import pytorch_lightning as pl
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchmetrics import Dice
from torchmetrics import MetricCollection

import wandb
from pytorch_lightning.loggers import WandbLogger

from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
from typing import Tuple, List

In [None]:
# Weights & Biases logger for project "RZD", run name "v0_unetpp".
# It is handed to pl.Trainer inside train().
wandb_logger = WandbLogger(project="RZD", name="v0_unetpp", log_model="all")

In [8]:
# --- Training hyper-parameters ------------------------------------------------
BATCH_SIZE = 8
NUM_WORKERS = 8
LOSS = "dice"  # key into LOSS_FNS (defined at the bottom of this cell)
OPTIMIZER = "Adam"
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-6
SCHEDULER = None
MIN_LR = 1e-6

FAST_DEV_RUN = False # Debug training
GPUS = 1
MAX_EPOCHS = 3

# --- Label encoding -----------------------------------------------------------
# CLASSES: pixel value stored in the mask files -> human-readable class name.
# MAP_MASKS: mask-file pixel value -> contiguous class index used for training.
# MAP_SUBMIT: inverse of MAP_MASKS, used when writing predicted masks.
CLASSES = {0:'background', 7: 'railway', 6: 'other railways', 10: 'trains'}
MAP_MASKS = {7: 1, 6: 2, 10: 3, 0: 0}
MAP_SUBMIT = {1: 7, 2: 6, 3: 10, 0: 0}

# --- Data locations -----------------------------------------------------------
TRAIN_DATASET_PATH = '../input/train-dataset/train_dataset_train/train'
TRAIN_PATH = {'images': path.join(TRAIN_DATASET_PATH, 'images'), 'mask': path.join(TRAIN_DATASET_PATH, 'mask')}
# glob.glob() returns files in arbitrary order, while the datasets pair
# images and masks by list index. Sorting both lists makes the pairing
# deterministic.
# NOTE(review): this assumes each mask file shares its image's file name — confirm.
ALL_MASKS = sorted(glob.glob(path.join(TRAIN_PATH['mask'], '*.png')))
ALL_IMAGES = sorted(glob.glob(path.join(TRAIN_PATH['images'], '*.png')))

PATH_TEST = '../input/train-dataset/test_dataset_test'
TEST_IMAGES = sorted(glob.glob(path.join(PATH_TEST, '*.png')))


# Loss name -> loss module; RZDModel selects one via its `loss` hyper-parameter.
LOSS_FNS = {
    "bce": smp.losses.SoftBCEWithLogitsLoss(),
    "dice": smp.losses.DiceLoss(mode="multiclass"),
    "focal": smp.losses.FocalLoss(mode="multiclass"),
    "jaccard": smp.losses.JaccardLoss(mode="multiclass"),
    "lovasz": smp.losses.LovaszLoss(mode="multiclass"),
    "tversky": smp.losses.TverskyLoss(mode="multiclass"),
}

In [11]:
class RZDDataset(Dataset):
    """Index-aligned pairs of RGB images and class-index masks.

    ``image_paths[i]`` and ``mask_paths[i]`` are treated as one sample, so the
    two lists must be in matching order.

    Args:
        image_paths: Paths of the input images.
        mask_paths: Paths of the mask images, aligned with ``image_paths``.
        transforms: Optional callable invoked as ``transforms(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, image_paths: List[Path] = ALL_IMAGES, mask_paths: List[Path] = ALL_MASKS, transforms: Callable = None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms

    def __len__(self) -> int:
        """Number of samples, defined by the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load sample ``idx`` and return ``(image, mask)``.

        Without ``transforms`` both items are NumPy arrays; otherwise they are
        whatever the transform pipeline produces.
        """
        image = self._load_image(self.image_paths[idx])
        mask = self._load_mask(self.mask_paths[idx])
        if self.transforms is not None:
            data = self.transforms(image=image, mask=mask)
            image, mask = data["image"], data["mask"]

        return image, mask

    @staticmethod
    def _load_image(image_path: Path) -> np.ndarray:
        """Read an image from disk and convert OpenCV's BGR layout to RGB.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        image = cv2.imread(str(image_path))
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _load_mask(mask_path: Path) -> np.ndarray:
        """Read a grayscale mask and remap its pixel values through ``MAP_MASKS``.

        Returns:
            An ``int64`` array of class indices with the mask's spatial shape.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
            KeyError: If the mask contains a pixel value missing from ``MAP_MASKS``.
        """
        raw_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if raw_mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # One vectorised table lookup replaces a Python-level call per pixel.
        # -1 marks pixel values that have no entry in MAP_MASKS.
        lookup = np.full(256, -1, dtype=np.int64)
        for pixel_value, class_index in MAP_MASKS.items():
            lookup[pixel_value] = class_index

        mask = lookup[raw_mask]
        unmapped = mask < 0
        if unmapped.any():
            unknown_values = np.unique(raw_mask[unmapped]).tolist()
            raise KeyError(f"Unknown mask values {unknown_values} in {mask_path}")
        return mask

In [12]:
# Dataset over the default training image/mask lists with no transforms,
# so indexing it yields raw NumPy arrays.
ds = RZDDataset()

In [13]:
def show_examples(name: str, pair: Tuple[np.ndarray, np.ndarray]):
    """Plot one dataset sample as two side-by-side panels.

    Args:
        name: Label appended to both panel titles.
        pair: ``(image, mask)`` as returned by ``RZDDataset.__getitem__``.
    """
    # The dataset yields (image, mask), so index 0 is the image and 1 the mask.
    image, mask = pair[0], pair[1]

    plt.figure(figsize=(10, 14))

    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f"Image: {name}")

    plt.subplot(1, 2, 2)
    plt.imshow(mask)
    plt.title(f"Mask: {name}")

In [14]:
# Visual sanity check of the untransformed training sample at index 34.
show_examples('train', ds[34])

In [17]:
class RZDDataModule(pl.LightningDataModule):
    """Lightning data module that splits image/mask pairs into train and val sets.

    Args:
        dataset: Dataset class, instantiated as ``dataset(images, masks, transforms)``.
        all_images: Image paths for the whole training pool.
        all_masks: Mask paths aligned index-by-index with ``all_images``.
        train_size_coef: ``train_size`` value passed to ``train_test_split``.
        batch_size: Batch size for both dataloaders.
        num_workers: Worker processes for both dataloaders.
        input_shape: ``(height, width)`` every sample is resized to.
    """

    def __init__(
        self,
        dataset = RZDDataset,
        all_images: List[Path] = ALL_IMAGES,
        all_masks: List[Path] = ALL_MASKS,
        train_size_coef: float = 0.8,
        batch_size: int = 8,
        num_workers: int = 2,
        input_shape: Tuple[int, int] = (512, 512)
    ):
        super().__init__()

        self.dataset = dataset
        self.all_images = all_images
        self.all_masks = all_masks
        self.save_hyperparameters()

        # Filled in by setup(); None means the split has not been made yet.
        self.train_dataset = None
        self.val_dataset = None

        self.train_transforms, self.val_transforms = self._init_transforms()

    def _init_transforms(self) -> Tuple[Callable, Callable]:
        """Build the (train, val) pipelines: resize, normalise, convert to tensor.

        Both pipelines are currently identical; no augmentation is applied to
        the training set.
        """
        train_transforms = [
            A.Resize(*self.hparams.input_shape),
            A.augmentations.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ToTensorV2(),
        ]

        val_transforms = [
            A.Resize(*self.hparams.input_shape),
            A.augmentations.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ToTensorV2(),
        ]

        return A.Compose(train_transforms), A.Compose(val_transforms)

    def setup(self, stage=None):
        """Create the train/val datasets from one random split of the path lists.

        The split is made only once per instance. Re-splitting on a later call
        (e.g. for another stage) would draw a different partition and could move
        training samples into the validation set.
        """
        if self.train_dataset is not None and self.val_dataset is not None:
            return

        images_train, images_val, masks_train, masks_val = train_test_split(
            self.all_images, self.all_masks, train_size=self.hparams.train_size_coef
        )
        self.train_dataset = self.dataset(images_train, masks_train, self.train_transforms)
        self.val_dataset = self.dataset(images_val, masks_val, self.val_transforms)

    def train_dataloader(self):
        """Training loader; reshuffled every epoch."""
        return self._dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Validation loader; fixed order."""
        return self._dataloader(self.val_dataset)

    def _dataloader(self, dataset: RZDDataset, shuffle: bool = False) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the stored batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=shuffle,
        )

In [18]:
def show_batch():
    """Plot one training batch as a 3x3 grid of images with masks overlaid.

    Also prints shape/min/max/mean/std of every (normalised) image and mask
    so the preprocessing can be checked at a glance.
    """
    nrows = 3
    ncols = 3
    batch_size = nrows * ncols
    data_module = RZDDataModule(batch_size=batch_size)
    data_module.setup()
    data_loader = data_module.train_dataloader()

    images, masks = next(iter(data_loader))

    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 10))
    # Hide the frame of every cell, including any left empty by a short batch.
    for ax in axes.flat:
        ax.axis('off')

    for ax, image, mask in zip(axes.flat, images, masks):
        image = image.permute(1, 2, 0).numpy()
        mask = mask.numpy()

        print(image.shape, image.min(), image.max(), image.mean(), image.std())
        print(mask.shape, mask.min(), mask.max(), mask.mean(), mask.std())

        # The data module normalises with mean=0.5, std=0.5, which puts pixels
        # in [-1, 1]. imshow expects floats in [0, 1], so undo the
        # normalisation for display only.
        ax.imshow(np.clip(image * 0.5 + 0.5, 0.0, 1.0))
        ax.imshow(mask, alpha=0.2)

    fig.tight_layout()

In [19]:
# Render one transformed training batch and print its per-sample statistics.
show_batch()

In [46]:
def test_model_and_loss():
    """Smoke test: push one batch through a fresh UNet++ and print the Dice loss."""
    model = smp.UnetPlusPlus(
        encoder_name='resnet34',
        encoder_depth=5,
        encoder_weights=None,  # was passed twice, which is a SyntaxError
        decoder_channels=(512, 256, 128, 64, 16),
        in_channels=3,
        classes=4,
        # The smp losses in LOSS_FNS take raw logits by default, and RZDModel
        # builds the same network with activation=None.
        activation=None,
    )
    data_module = RZDDataModule(batch_size=4)
    data_module.setup()
    data_loader = data_module.train_dataloader()
    images, masks = next(iter(data_loader))

    y_hat = model(images)
    dice_loss = LOSS_FNS['dice']
    print(dice_loss(y_hat, masks.type(torch.int64)))

In [47]:
# Run the model/loss smoke test on one batch.
test_model_and_loss()

In [20]:
class RZDModel(pl.LightningModule):
    """LightningModule wrapping a 4-class UNet++ (ResNet-34 encoder).

    All constructor arguments are stored via ``save_hyperparameters`` and read
    back through ``self.hparams``.

    Args:
        loss: Key into ``LOSS_FNS``.
        optimizer: One of ``"Adam"``, ``"AdamW"``, ``"SGD"``.
        learning_rate: Optimizer learning rate.
        weight_decay: Optimizer weight decay.
        scheduler: ``None``, ``"CosineAnnealingLR"`` or ``"CosineAnnealingWarmRestarts"``.
        T_max: ``T_max`` for ``CosineAnnealingLR``.
        T_0: ``T_0`` for ``CosineAnnealingWarmRestarts``.
        min_lr: ``eta_min`` for either scheduler.
    """

    def __init__(
        self,
        loss: str,
        optimizer: str,
        learning_rate: float,
        weight_decay: float,
        scheduler: str,
        T_max: int,
        T_0: int,
        min_lr: int,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = self._init_model()
        self.loss_fn = self._init_loss_fn()
        # Metric tracking (torchmetrics Dice) is currently disabled.

    def _init_model(self) -> nn.Module:
        """Build the segmentation network; it outputs raw logits (no activation)."""
        network_kwargs = dict(
            encoder_name='resnet34',
            encoder_depth=5,
            encoder_weights=None,
            decoder_channels=(512, 256, 128, 64, 16),
            in_channels=3,
            classes=4,
            activation=None,
        )
        return smp.UnetPlusPlus(**network_kwargs)

    def _init_loss_fn(self) -> Callable:
        """Look up the configured loss in ``LOSS_FNS``."""
        loss_name = self.hparams.loss
        assert loss_name in LOSS_FNS, 'Choose from exstisting!'
        return LOSS_FNS[loss_name]

    def configure_optimizers(self) -> Dict[str, Any]:
        """Create the optimizer and, if configured, a per-step LR scheduler.

        Raises:
            ValueError: If the optimizer or scheduler name is not recognised.
        """
        hparams = self.hparams

        optimizer_classes = {
            "Adam": torch.optim.Adam,
            "AdamW": torch.optim.AdamW,
            "SGD": torch.optim.SGD,
        }
        if hparams.optimizer not in optimizer_classes:
            raise ValueError(f"Unknown optimizer: {self.hparams.optimizer}")
        optimizer = optimizer_classes[hparams.optimizer](
            params=self.parameters(),
            lr=hparams.learning_rate,
            weight_decay=hparams.weight_decay,
        )

        # No scheduler requested: the optimizer alone is the whole config.
        if hparams.scheduler is None:
            return {"optimizer": optimizer}

        if hparams.scheduler == "CosineAnnealingLR":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=hparams.T_max, eta_min=hparams.min_lr
            )
        elif hparams.scheduler == "CosineAnnealingWarmRestarts":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=hparams.T_0, eta_min=hparams.min_lr
            )
        else:
            raise ValueError(f"Unknown scheduler: {self.hparams.scheduler}")

        scheduler_config = {"scheduler": scheduler, "interval": "step"}
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Run the wrapped network on a batch of images."""
        return self.model(images)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Compute, log and return the training loss for one batch."""
        return self.shared_step(batch, "train")

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Compute and log the validation loss for one batch (nothing is returned)."""
        self.shared_step(batch, "val")

    def shared_step(self, batch: Tuple[torch.Tensor, torch.Tensor], stage: str) -> torch.Tensor:
        """Forward pass plus loss, logged as ``"<stage>_loss"``."""
        images, masks = batch
        logits = self(images)

        # Masks are cast to int64 before being handed to the loss.
        targets = masks.type(torch.int64)
        loss = self.loss_fn(logits, targets)

        self._log(loss, metrics={}, stage=stage)
        return loss

    def _log(self, loss: torch.Tensor, metrics: dict, stage: str):
        """Log the loss under ``"<stage>_loss"``; ``metrics`` is currently unused."""
        self.log(f"{stage}_loss", loss)

    @classmethod
    def load_eval_checkpoint(cls, checkpoint_path: Path, device: str) -> nn.Module:
        """Restore a module from a checkpoint, move it to ``device``, set eval mode."""
        restored = cls.load_from_checkpoint(checkpoint_path=checkpoint_path)
        module = restored.to(device)
        module.eval()
        return module

In [None]:
# Empty list of Trainer callbacks; train() does not currently pass it to pl.Trainer.
callbacks = []

In [49]:
def train(seed: int = 42):
    """Train ``RZDModel`` with the notebook-level configuration.

    Args:
        seed: Seed given to ``pl.seed_everything``. It is a fixed integer
            because ``hash()`` of a string is salted per interpreter process,
            so a seed derived from it changes between kernel restarts.

    Returns:
        The fitted ``pl.Trainer``.
    """
    pl.seed_everything(seed)

    # T_max and T_0 are passed as 0; they are only read when a scheduler is set.
    model = RZDModel(LOSS, OPTIMIZER, LEARNING_RATE, WEIGHT_DECAY, SCHEDULER, 0, 0, MIN_LR)
    # Forward NUM_WORKERS as well, so the configured value actually takes effect.
    data_module = RZDDataModule(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    trainer = pl.Trainer(
        logger=wandb_logger,
        max_epochs=MAX_EPOCHS,
        fast_dev_run=FAST_DEV_RUN,
        gpus=GPUS,
        log_every_n_steps=10,
    )
    trainer.fit(model, data_module)
    return trainer

In [50]:
# Ask PyTorch to release its cached, unused GPU memory before training starts.
torch.cuda.empty_cache()

In [51]:
# Run the full training loop and keep the fitted Trainer.
trainer = train()

In [23]:
class TestRZD(RZDDataset):
    """Inference-time dataset yielding ``(image, file_name)`` with no mask.

    Args:
        image_paths: Paths of the test images.
    """

    def __init__(self, image_paths: List[Path] = TEST_IMAGES):
        super().__init__(image_paths)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, str]:
        """Return the untransformed RGB image at ``idx`` and its base file name."""
        image_path = self.image_paths[idx]
        # os.path.basename accepts str paths with any OS separator, and str()
        # also covers pathlib.Path entries, unlike a manual split('/').
        image_name = os.path.basename(str(image_path))
        image = self._load_image(image_path)
        return image, image_name

In [25]:
# Directory the predicted masks are written to (also the base name of the zip archive).
SUBMISSION_PATH='./submit'
# Restore the trained network from a saved checkpoint on CPU, in eval mode.
model = RZDModel.load_eval_checkpoint('../input/unetplusplus-weights-512px/epoch2-step2463.ckpt', device='cpu')
# Test images only; each item is (image, file_name).
test_dataset = TestRZD(TEST_IMAGES)

In [39]:
from tqdm import tqdm

In [40]:
def predict(model: torch.nn.Module, dataset: Dataset, savedir: Path = SUBMISSION_PATH) -> None:
    """Segment every image in ``dataset`` and save the masks as PNG files.

    Each image is resized to 512x512 and normalised, then passed through
    ``model``. The per-pixel argmax is scaled back to the image's original
    resolution, translated through ``MAP_SUBMIT`` and written to
    ``savedir/<image_name>``.

    Args:
        model: Segmentation network returning ``(1, classes, H, W)`` scores.
        dataset: Iterable of ``(rgb_image, file_name)`` pairs.
        savedir: Output directory; created if missing.
    """
    os.makedirs(savedir, exist_ok=True)

    to_model_input = A.Compose(
        [
            A.Resize(*(512, 512)),
            A.augmentations.transforms.Normalize(
                mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
            ),
            ToTensorV2(),
        ]
    )

    # Class index -> submission pixel value, as a table for vectorised lookup.
    submit_lut = np.zeros(max(MAP_SUBMIT) + 1, dtype=np.uint8)
    for class_index, submit_value in MAP_SUBMIT.items():
        submit_lut[class_index] = submit_value

    # Run on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for image, image_name in tqdm(dataset):
            model_input = to_model_input(image=image)["image"].unsqueeze(0).to(device)
            scores = model(model_input)

            class_map = scores.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

            # Class indices are categorical: nearest-neighbour keeps them intact,
            # whereas linear interpolation would invent in-between classes at
            # region borders. Note cv2.resize takes (width, height).
            height, width = image.shape[:2]
            class_map_full = cv2.resize(
                class_map, (width, height), interpolation=cv2.INTER_NEAREST
            )

            cv2.imwrite(os.path.join(savedir, image_name), submit_lut[class_map_full])

In [41]:
# Write a predicted mask for every test image into SUBMISSION_PATH.
predict(model, test_dataset)

In [42]:
import shutil

In [45]:
# Zip the contents of the submission directory; the archive is named SUBMISSION_PATH + '.zip'.
path_to_archive = shutil.make_archive(SUBMISSION_PATH,'zip',SUBMISSION_PATH)

In [51]:
# Predicted mask files written by predict(); unordered here, sorted where they are used.
TEST_MASKS = glob.glob(os.path.join(SUBMISSION_PATH, '*.png'))

In [63]:
# Pair each test image with its predicted mask. predict() saves a mask under the
# image's own file name, so sorting both lists aligns them. RZDDataset maps the
# submission pixel values back to class indices via MAP_MASKS when loading.
dataset_check = RZDDataset(sorted(TEST_IMAGES), sorted(TEST_MASKS))

In [122]:
def show_batch_t():
    """Plot a 3x3 grid of random test images with their predicted masks overlaid."""
    nrows = 3
    ncols = 3

    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 20))
    for ax in axes.flat:
        # Sample within the dataset's real size rather than a hard-coded 1000,
        # which fails on smaller sets and ignores samples beyond it.
        sample_idx = np.random.randint(0, len(dataset_check))
        image, mask = dataset_check[sample_idx]

        ax.axis('off')
        ax.imshow(image)
        ax.imshow(mask, alpha=0.2)

    fig.tight_layout()

In [124]:
# Eyeball nine random predictions overlaid on their source images.
show_batch_t()

In [87]:
# Show the test image / predicted mask pair at index 900 side by side.
show_examples('final', dataset_check[900])