In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from pytorch_lightning import LightningModule, Trainer
import kornia as K

# import numpy as np
import matplotlib.pyplot as plt
# import plotly.express as px
# from PIL import Image

from sklearn.model_selection import train_test_split
import os

%matplotlib inline

In [2]:
# Start from a clean CUDA allocator cache and pick the first GPU when one exists.
torch.cuda.empty_cache()
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Dataset locations. The trailing '/' is kept so that appending a file name by
# plain string concatenation yields a valid path.
_DATA_ROOT = '../AerialImageDataset/crop/'
PATH_IMAGE = _DATA_ROOT + 'images/'
PATH_MASK = _DATA_ROOT + 'gt/'

# Raw directory listings (os.listdir returns entries in arbitrary order).
list_img = os.listdir(PATH_IMAGE)
list_mask = os.listdir(PATH_MASK)

In [3]:
# File names present in BOTH the image and the mask directory. Sorted so the
# order (and therefore any seeded split of it) does not depend on set iteration
# order, which changes between interpreter runs for str elements.
my_list = sorted(set(list_mask) & set(list_img))

In [4]:
# 80/20 train/validation split. The input is sorted and the shuffle is seeded
# so the same files land in the same split after every kernel restart.
train, val = train_test_split(sorted(my_list), train_size=0.8, random_state=42)

In [5]:
# Sanity check: train and val must form a disjoint partition of my_list.
assert len(train) + len(val) == len(my_list), "split sizes do not add up"
assert set(train).isdisjoint(val), "train and val share file names"
len(train), len(val), len(my_list)

(14400, 3600, 18000)

In [6]:
from torch._C import Size
# Preprocessing for validation samples (wired into the validation dataset).
class Preprocessing(nn.Module):
    """Deterministic crop + normalisation for one (image, mask) pair.

    Mirrors the training ``Augmentation`` pipeline's output geometry (a
    320x320 crop) and its ImageNet mean/std normalisation, but without any
    randomness. This keeps validation scores comparable from epoch to epoch
    and feeds the network inputs on the same scale it was trained on.
    """

    # Built once at class-definition time and shared by all instances.
    _preprocess = K.augmentation.AugmentationSequential(
        # Centre crop instead of RandomCrop/flips: validation must be repeatable.
        K.augmentation.CenterCrop((320, 320)),
        # Same normalisation as the training pipeline. Intensity-only ops are
        # expected to leave the 'mask' key untouched in kornia (as the training
        # pipeline already relies on) — verify for the installed version.
        K.augmentation.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        data_keys=['input', 'mask'],
    )

    def __init__(self):
        super(Preprocessing, self).__init__()

    @torch.no_grad()
    def forward(self, img: torch.Tensor, mask: torch.Tensor) -> tuple:
        """Preprocess a single sample.

        Args:
            img: image tensor (presumably (C, H, W) floats in [0, 1] — confirm with the dataset).
            mask: mask tensor spatially aligned with ``img``.

        Returns:
            ``(image, mask)``; index 0 is taken from each pipeline output,
            presumably dropping the batch dimension kornia adds to unbatched input.
        """
        x, y = self.preprocess(img, mask)
        return x[0], y[0]

    def preprocess(self, img: torch.Tensor, mask: torch.Tensor):
        """Run the kornia pipeline; outputs follow ``data_keys`` order (input, mask)."""
        return self._preprocess(img, mask)


class Augmentation(nn.Module):
    """Random training-time augmentation for one (image, mask) pair.

    Geometric ops (crop, flips, affine, rotation) are applied to image and mask
    with the same sampled parameters. Photometric ops (Planckian jitter,
    equalisation, box blur) and the final ImageNet normalisation target the
    image. ``same_on_batch=False`` samples independent parameters per element.

    Disabled alternatives kept out of the pipeline: RandomChannelShuffle,
    RandomGaussianNoise, RandomMotionBlur, RandomPosterize, RandomPerspective,
    RandomElasticTransform, RandomSharpness.
    """

    # Built once at class-definition time and shared by every instance.
    _augmentations = K.augmentation.AugmentationSequential(
        K.augmentation.RandomCrop((320, 320)),          # 320x320 training patch
        K.augmentation.RandomHorizontalFlip(p=0.5),
        K.augmentation.RandomVerticalFlip(p=0.5),
        K.augmentation.RandomPlanckianJitter(p=0.3),
        K.augmentation.RandomEqualize(p=0.3),
        # positional args: degrees, translate, scale, shear
        K.augmentation.RandomAffine([-45., 45.], [0., 0.15], [0.5, 1.5], [0., 0.15], p=0.6),
        K.augmentation.RandomBoxBlur((3, 3), p=0.2),
        K.augmentation.RandomRotation(20, p=0.8),
        K.augmentation.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        data_keys=['input', 'mask'],
        same_on_batch=False,
    )

    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def forward(self, img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Augment one sample and return ``(image, mask)`` taken at index 0 of each output."""
        batched_img, batched_mask = self.augmentations(img, mask)
        return batched_img[0], batched_mask[0]

    def augmentations(self, img: torch.Tensor, mask: torch.Tensor) -> dict:
        """Apply the shared pipeline and unpack its two outputs (image, then mask)."""
        out_img, out_mask = self._augmentations(img, mask)
        return out_img, out_mask


In [7]:
# One shared instance of each pipeline: Pre feeds the validation set, Aug the training set.
Pre, Aug = Preprocessing(), Augmentation()

In [8]:
class GeoDataset(Dataset):
    """Map-style dataset of (image, mask) pairs stored under PATH_IMAGE / PATH_MASK.

    An image and its mask share the same file name in their two directories.
    Both are decoded with ``torchvision.io.read_image`` (uint8) and divided by
    255 before ``preprocess`` is applied.

    Args:
        list_names: indexable sequence of file names (list, tuple, array, ...).
        preprocess: callable ``(image, mask) -> sample``.
        c, h, w: accepted for backward compatibility only. They used to size
            shared-memory buffers that every read immediately replaced, so the
            pre-allocation was wasted work.
    """

    def __init__(self,
                 list_names, preprocess, c, h, w):
        # Last sample read by get_img, kept for callers that read these attributes.
        self.shared_image = None
        self.shared_mask = None
        self.preprocess = preprocess
        self.list_names = list_names

    @torch.no_grad()
    def get_img(self, idx):
        """Read sample ``idx`` from disk and return ``(image, mask)`` scaled to [0, 1].

        The tensors are also stored on ``self.shared_image`` / ``self.shared_mask``.
        """
        # Index directly: the former `type(...) == list` guard left `name`
        # unbound (UnboundLocalError) for tuples, arrays, Series, etc.
        name = str(self.list_names[idx])
        self.shared_image = torchvision.io.read_image(os.path.join(PATH_IMAGE, name)) / 255.
        self.shared_mask = torchvision.io.read_image(os.path.join(PATH_MASK, name)) / 255.
        return self.shared_image, self.shared_mask

    def __len__(self):
        return len(self.list_names)

    def __getitem__(self, idx):
        # Use the returned tensors rather than re-reading instance state.
        image, mask = self.get_img(idx)
        return self.preprocess(image, mask)

In [9]:
# Training data: random augmentation (Aug), reshuffled every epoch.
train_dataset = GeoDataset(list_names=train, preprocess=Aug, c=3, h=500, w=500)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=16,
)

# Validation data: the Pre pipeline, iterated in list order (no shuffle).
val_dataset = GeoDataset(list_names=val, preprocess=Pre, c=3, h=500, w=500)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=4,
    num_workers=16,
)

In [10]:
# for i, m in train_dataloader:
#     print(i.shape, m.shape)

In [11]:
# Pull a single validation batch (images `i`, masks `m`) to sanity-check shapes.
i, m = next(iter(val_dataloader))


In [12]:
# Both are (batch, channels, height, width); the mask has a single channel.
i.shape, m.shape

(torch.Size([4, 3, 320, 320]), torch.Size([4, 1, 320, 320]))

In [13]:
import segmentation_models_pytorch as smp

In [14]:
# Model configuration.
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['building']
# 'sigmoid' makes the network emit per-pixel probabilities; use None for raw
# logits or 'softmax2d' for multiclass segmentation.
ACTIVATION = 'sigmoid'
DEVICE = 'cuda'

# U-Net whose encoder starts from pretrained weights; one output channel per class.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    activation=ACTIVATION,
    classes=len(CLASSES),
)

In [15]:
# smp losses default to from_logits=True and apply a sigmoid themselves. The
# Unet already applies ACTIVATION, so with ACTIVATION='sigmoid' that would be a
# second sigmoid squashing predictions into (0.5, 0.73). Only treat outputs as
# logits when the model has no activation.
_FROM_LOGITS = ACTIVATION is None
loss = smp.losses.DiceLoss(mode='binary', from_logits=_FROM_LOGITS)
# Despite the name this is a loss (1 - IoU), not a score.
metric = smp.losses.JaccardLoss(mode='binary', from_logits=_FROM_LOGITS)

In [16]:
class SemSegment(LightningModule):
    """Lightning wrapper that trains a segmentation network on (image, mask) batches.

    The optimised objective is ``metric`` (a Jaccard loss); ``loss`` (a Dice
    loss) is only logged. Both are taken from module-level names at
    construction time.

    Args:
        lr: learning rate for Adam.
        num_classes: not used by this class (the network is built elsewhere).
        bilinear: stored as ``self.bilinear``; not otherwise used here.
        net: network to wrap. When ``None`` (default) the module-level
            ``model`` is used, as before.
    """

    def __init__(
        self,
        lr: float = 0.001,
        num_classes: int = 1,
        bilinear: bool = False,
        net: nn.Module = None,
    ):
        super().__init__()

        self.bilinear = bilinear
        self.lr = lr
        if net is None:
            net = model
            # Once `model = SemSegment()` has run, the global `model` is itself a
            # SemSegment. Unwrap it so re-running that cell does not nest
            # wrappers (which would also change checkpoint key names). The
            # unwrapped network keeps its current weights; rebuild it for a
            # fresh start.
            while isinstance(net, LightningModule) and hasattr(net, "net"):
                net = net.net
        self.net = net
        self.loss = loss
        self.metric = metric

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        img, mask = batch
        out = self(img.float())
        target = mask.float()
        loss_train_dice = self.loss(out, target)
        metric = self.metric(out, target)
        # Per-step training curve for TensorBoard.
        self.log('train_dice_step', loss_train_dice, on_step=True)

        # "loss" (the Jaccard loss) is what Lightning backpropagates.
        return {"loss": metric, "log": metric.detach(), "progress_bar": loss_train_dice.detach()}

    def validation_step(self, batch, batch_idx):
        img, mask = batch
        out = self(img.float())
        metric = self.metric(out, mask.float())
        # Per-step validation log for TensorBoard.
        self.log('val_loss_step', metric, on_step=True)

        return {"val_loss": metric}

    def validation_epoch_end(self, outputs):
        # Mean validation Jaccard loss over all batches of the epoch.
        loss_val = torch.stack([x["val_loss"] for x in outputs]).mean()
        log_dict = {"val_loss": loss_val.detach()}
        # Epoch-level validation log; monitored by the LR scheduler and checkpointing.
        self.log('val_epoch_total_step', log_dict['val_loss'], on_epoch=True)

        return log_dict

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min'),
            # Plateau detection needs a stable signal: the epoch-mean validation
            # loss rather than the Dice value of a single training batch.
            'monitor': 'val_epoch_total_step',
        }
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
        }


In [17]:
from pytorch_lightning.callbacks import ModelCheckpoint

In [18]:
# To resume instead: model = SemSegment.load_from_checkpoint('lightning_logs/best/epoch=222-step=16501.ckpt')
model = SemSegment()

# Keep the two checkpoints with the lowest epoch-level validation Jaccard loss
# (ModelCheckpoint minimises the monitored value by default).
checkpoint_callback = ModelCheckpoint(
    dirpath=f"lightning_logs/best__{ENCODER}",
    monitor="val_epoch_total_step",
    save_top_k=2,
)

In [19]:
# Single-GPU training for at most 500 epochs, saving checkpoints via the callback above.
trainer = Trainer(
    gpus=1,
    max_epochs=500,
    callbacks=[checkpoint_callback],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [20]:
# Train; each epoch is followed by validation on val_dataloader.
trainer.fit(model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type        | Params
---------------------------------------
0 | net    | Unet        | 32.5 M
1 | loss   | DiceLoss    | 0     
2 | metric | JaccardLoss | 0     
---------------------------------------
32.5 M    Trainable params
0         Non-trainable params
32.5 M    Total params
130.084   Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
