In [1]:
import os

import cv2
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Dice
import segmentation_models_pytorch as smp

import albumentations as A
from albumentations.pytorch import ToTensorV2

import lightning as L
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from lightning.pytorch.loggers import WandbLogger

  from .autonotebook import tqdm as notebook_tqdm


## Decode mask

In [2]:
def rle_decode(mask_rle, shape):
    """Decode a run-length-encoded mask string into a binary 2-D array.

    Args:
        mask_rle: Whitespace-separated string of alternating
            ``start length`` tokens, where ``start`` is a 1-based index
            into the flattened (row-major) mask.
        shape: ``(height, width)`` of the mask to produce.

    Returns:
        ``np.uint8`` array of the given shape holding 1 inside every run
        and 0 elsewhere.
    """
    tokens = mask_rle.split()

    # Even positions hold the 1-based run starts, odd positions the lengths.
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat_mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(run_starts, run_ends):
        flat_mask[begin:stop] = 1

    return flat_mask.reshape(shape)

In [3]:
def make_binary_mask_images(img_name, mask_rle, save_path):
    """Decode one RLE annotation and save it as a mask image.

    The source image ``data/train_img/<img_name>`` is read only to obtain
    the height and width the mask must have.

    Args:
        img_name: File name of the training image (also used for the mask).
        mask_rle: Run-length-encoded mask string for that image.
        save_path: Directory the mask image is written to.

    Returns:
        The decoded mask array (values 0/1) that was written to disk.

    Raises:
        FileNotFoundError: If the source image cannot be read.
        OSError: If the mask image cannot be written.
    """
    src_path = os.path.join("data/train_img", img_name)
    src = cv2.imread(src_path)
    if src is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {src_path}")

    mask = rle_decode(mask_rle, (src.shape[0], src.shape[1]))

    dst_path = os.path.join(save_path, img_name)
    if not cv2.imwrite(dst_path, mask):
        # cv2.imwrite reports failure through its boolean return value.
        raise OSError(f"Could not write mask: {dst_path}")

    return mask

In [4]:
# Decode every RLE annotation listed in the training CSV and store it as a
# mask image in data/train_mask, under the same file name as its source image.
df = pd.read_csv("data/train.csv")

mask_path = "data/train_mask"
os.makedirs(mask_path, exist_ok=True)

for img_path, mask_rle in zip(df["img_path"], df["mask_rle"]):
    # Only the file name is needed; the directory part of the CSV path is dropped.
    img_name = img_path.split('/')[-1]
    make_binary_mask_images(img_name, mask_rle, mask_path)

## Validation offline augmentation
Each validation image (and its mask) is cropped into non-overlapping 256 × 256 tiles.

In [5]:
# Hold out 20% of the annotated images for validation; the fixed seed keeps
# the split identical across runs.
all_df = pd.read_csv("data/train.csv")
train_df, val_df = train_test_split(all_df, test_size=0.2, random_state=0)

In [6]:
# Sanity check of the split: number of training rows, then validation rows.
print(len(train_df))
print(len(val_df))

5712
1428


In [7]:

def crop_img(img, img_size=256):
    """Split an image into non-overlapping square tiles.

    The tile grid is derived from the image's own height and width rather
    than a fixed 1024-pixel size. Trailing rows/columns that do not fill a
    complete tile are dropped. Both ``(H, W)`` and ``(H, W, C)`` arrays are
    accepted.

    Tiles are emitted one vertical strip at a time: all tiles of the
    left-most strip from top to bottom, then the next strip, and so on.
    For a 1024x1024 input with the default ``img_size`` this yields 16
    tiles.

    Args:
        img: Image array whose first two axes are height and width.
        img_size: Side length of each square tile, in pixels.

    Returns:
        List of array views into ``img``, each ``img_size`` x ``img_size``
        in its first two axes.

    Raises:
        ValueError: If ``img_size`` is not positive.
    """
    if img_size <= 0:
        raise ValueError(f"img_size must be positive, got {img_size}")

    height, width = img.shape[:2]

    tiles = []
    # Outer loop walks along the width and inner loop along the height, so
    # tile indices (and the file names derived from them) keep their order.
    for left in range(0, width - img_size + 1, img_size):
        for top in range(0, height - img_size + 1, img_size):
            tiles.append(img[top:top + img_size, left:left + img_size])

    return tiles

In [8]:
# Replace every validation image and its mask with their 256x256 tiles.
# The tiles are written next to the originals as "<name>_<NN>.png" and the
# originals are then deleted, so this cell is destructive and is meant to be
# run exactly once.
new_data_list = []
for _, row in val_df.iterrows():
    img_name = row["img_path"].split("/")[-1]
    img_stem = os.path.splitext(img_name)[0]
    img_path = os.path.join("data/train_img", img_name)
    mask_path = os.path.join("data/train_mask", img_name)

    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    if img is None or mask is None:
        # cv2.imread returns None instead of raising. This is what happens
        # when the cell is re-run after the originals were already removed.
        raise FileNotFoundError(
            f"Could not read {img_path} or {mask_path}; "
            "was this cell already run?"
        )

    # Separate names for the tiles so the outer loop variables are not shadowed.
    for tile_idx, (img_tile, mask_tile) in enumerate(zip(crop_img(img), crop_img(mask))):
        tile_id = img_stem + "_" + str(tile_idx).zfill(2)
        new_data_list.append({"img_id": tile_id})

        cv2.imwrite(os.path.join("data/train_img", tile_id + ".png"), img_tile)
        cv2.imwrite(os.path.join("data/train_mask", tile_id + ".png"), mask_tile)

    os.remove(img_path)
    os.remove(mask_path)

In [9]:
# One row per validation tile; the single "img_id" column holds the tile's
# file name without the ".png" extension.
new_val_df = pd.DataFrame(new_data_list)
new_val_df

Unnamed: 0,img_id
0,TRAIN_4972_00
1,TRAIN_4972_01
2,TRAIN_4972_02
3,TRAIN_4972_03
4,TRAIN_4972_04
...,...
22843,TRAIN_0346_11
22844,TRAIN_0346_12
22845,TRAIN_0346_13
22846,TRAIN_0346_14


In [10]:
# Persist both splits. to_csv is called with its defaults, so the DataFrame
# index is also written as the first (unnamed) column of each file, shifting
# the data columns one position to the right.
train_df.to_csv("data/new_train.csv")
new_val_df.to_csv("data/new_val.csv")

## Train Model

In [11]:
# Central experiment settings: "model" is forwarded to the segmentation
# network, "data" to the data module.
configs = dict(
    model=dict(
        encoder_name="timm-regnety_320",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,
    ),
    data=dict(
        root="data",
        batch_size=64,
    ),
)

#### Dataset

In [12]:
class DefaultTransforms:
    """Factory for the albumentations pipelines of the three data splits."""

    def __init__(self) -> None:
        super().__init__()

    def train_transform(self):
        """Build the training pipeline.

        Three ``OneOf`` groups (photometric, blur, geometric) are followed
        by a random 224x224 crop, normalization and tensor conversion.
        """
        photometric = A.OneOf(
            [
                A.RandomBrightness(p=1),
                A.RandomBrightnessContrast(p=1),
                A.Emboss(p=1),
                A.RandomShadow(p=1),
                A.NoOp(),
            ],
            p=1,
        )
        # The blur group is the only one that may be skipped entirely (p=0.6).
        blur = A.OneOf(
            [
                A.Blur(p=1),
                A.AdvancedBlur(p=1),
                A.MotionBlur(p=1),
            ],
            p=0.6,
        )
        geometric = A.OneOf(
            [
                A.NoOp(),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.ShiftScaleRotate(p=0.5),
                A.Rotate(limit=90, p=1, border_mode=cv2.BORDER_REPLICATE),
                A.RandomRotate90(p=1)
            ],
            p=1,
        )

        steps = [
            photometric,
            blur,
            geometric,
            A.RandomCrop(224, 224),
            A.Normalize(),
            ToTensorV2(transpose_mask=True),
        ]
        return A.Compose(steps)

    def val_transform(self):
        """Build the validation pipeline: resize to 224x224, normalize, to tensor."""
        steps = [
            A.Resize(224, 224),
            A.Normalize(),
            ToTensorV2(transpose_mask=True),
        ]
        return A.Compose(steps)

    def test_transform(self):
        """Build the test pipeline: normalize and convert to tensor, no resizing."""
        steps = [
            A.Normalize(),
            ToTensorV2(transpose_mask=True),
        ]
        return A.Compose(steps)

In [13]:
class SatelliteDataset(Dataset):
    """Dataset of satellite images, with masks when ``train`` is true.

    Args:
        root: Data directory containing ``train_img``, ``train_mask`` and
            ``test_img``.
        df: DataFrame listing the samples. The image id (file name without
            ``.png``) is taken from the ``img_id`` column when present,
            otherwise from positional column 1.
        train: If true, ``__getitem__`` returns ``(image, mask)`` read from
            ``train_img`` / ``train_mask``; otherwise only the image from
            ``test_img``.
        transform: Optional albumentations-style callable taking
            ``image=`` (and ``mask=``) keyword arguments.
    """

    def __init__(self, root, df, train=True, transform=None):
        self.root = root
        self.data = df
        self.train = train
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def _image_id(self, idx):
        """Return the image id of row ``idx``."""
        # Look the id up by name: CSVs saved together with their index have
        # it in column 1, but a CSV saved without an index column does not.
        # NOTE(review): presumably test.csv is such a file - confirm its schema.
        if "img_id" in self.data.columns:
            return self.data["img_id"].iloc[idx]
        return self.data.iloc[idx, 1]

    @staticmethod
    def _read(path, *flags):
        """Read an image with OpenCV, failing loudly when it cannot be read."""
        image = cv2.imread(path, *flags)
        if image is None:
            # cv2.imread returns None instead of raising on a missing/corrupt file.
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def __getitem__(self, idx):
        file_name = self._image_id(idx) + ".png"

        # NOTE(review): images are used in the channel order cv2.imread
        # returns; confirm this matches what the normalization/encoder expect.
        if self.train:
            image = self._read(os.path.join(self.root, "train_img", file_name))
            mask = self._read(
                os.path.join(self.root, "train_mask", file_name), cv2.IMREAD_GRAYSCALE
            )

            if self.transform:
                augmented = self.transform(image=image, mask=mask)
                image = augmented['image']
                mask = augmented['mask']

            return image, mask

        # Fixed: the original passed `+ ".png"` as a separate argument
        # (unary plus on a string), which raises TypeError.
        image = self._read(os.path.join(self.root, "test_img", file_name))

        if self.transform:
            image = self.transform(image=image)['image']
        return image
    
    
class SatelliteDataModule(L.LightningDataModule):
    """Lightning data module wiring the train / val / test datasets and loaders.

    Args:
        root: Data directory holding the CSV files and image folders.
        batch_size: Batch size of the training and validation loaders.
    """

    def __init__(self, root: str, batch_size: int) -> None:
        super().__init__()
        self.root = root
        self.batch_size = batch_size

        self.transforms = DefaultTransforms()
        self.train_transform = self.transforms.train_transform()
        self.val_transform = self.transforms.val_transform()
        self.test_transform = self.transforms.test_transform()

    def _read_split(self, file_name):
        """Load one CSV file from the data root."""
        return pd.read_csv(os.path.join(self.root, file_name))

    def _make_loader(self, dataset, batch_size, shuffle):
        """Create a DataLoader with the worker count shared by all splits."""
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8)

    def setup(self, stage: str) -> None:
        """Read the three split CSVs and build one dataset per split."""
        train_df = self._read_split("new_train.csv")
        test_df = self._read_split("test.csv")
        val_df = self._read_split("new_val.csv")

        self.train_dataset = SatelliteDataset(
            self.root, df=train_df, train=True, transform=self.train_transform
        )
        # Validation also uses train=True: that is the mode that returns masks.
        self.val_dataset = SatelliteDataset(
            self.root, df=val_df, train=True, transform=self.val_transform
        )
        self.test_dataset = SatelliteDataset(
            self.root, df=test_df, train=False, transform=self.test_transform
        )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, self.batch_size, True)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Sequential loader over the validation dataset."""
        return self._make_loader(self.val_dataset, self.batch_size, False)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Sequential loader over the test dataset with a fixed batch size of 100."""
        return self._make_loader(self.test_dataset, 100, False)


#### Model

In [14]:
class SMP(nn.Module):
    """Thin ``nn.Module`` wrapper around ``smp.UnetPlusPlus``.

    Args:
        encoder_name: Name of the encoder backbone.
        encoder_weights: Pretrained weights to load into the encoder.
        in_channels: Number of input image channels.
        classes: Number of output channels of the segmentation head.
    """

    def __init__(self, encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=1):
        super().__init__()

        unet_kwargs = dict(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
        )
        self.model = smp.UnetPlusPlus(**unet_kwargs)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its raw output."""
        return self.model(x)

#### Training

In [15]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` between ``sigmoid(inputs)`` and ``targets``.

        Args:
            inputs: Raw logits. Remove the sigmoid below if the model already
                outputs probabilities.
            targets: Ground-truth mask with the same number of elements as
                ``inputs``.
            smooth: Constant added to numerator and denominator so the ratio
                stays defined when both tensors are empty of positives.

        Returns:
            Scalar loss tensor.
        """
        probs = torch.sigmoid(inputs)
        if probs.dtype in (torch.float16, torch.bfloat16):
            # Summing a large number of half-precision values can overflow.
            probs = probs.float()

        # reshape (unlike view) also accepts non-contiguous tensors.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

        return 1 - dice


In [16]:

class LitSeg(L.LightningModule):
    """Lightning module that trains, validates and tests the segmentation net.

    The network is built from the module-level ``configs["model"]`` entry,
    optimized with Dice loss and evaluated with the torchmetrics Dice score.
    Testing collects one run-length-encoded mask per image and writes them
    to ``./submit.csv``.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(ignore=["net", "loss_module", "metric_module"])

        self.net = SMP(
            encoder_name=configs["model"]["encoder_name"],
            encoder_weights=configs["model"]["encoder_weights"],
            in_channels=configs["model"]["in_channels"],
            classes=configs["model"]["classes"],
        )

        self.loss_module = DiceLoss()
        self.metric_module = Dice()

    def configure_optimizers(self):
        """SGD with momentum plus a cosine-annealing learning-rate schedule."""
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=300)

        return [optimizer], [scheduler]

    def forward(self, x):
        """Return the raw logits of the segmentation network."""
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the Dice loss of one training batch."""
        img, mask = batch
        pred = self(img)

        loss = self.loss_module(pred, mask)
        self.log("train/loss", loss.item())

        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss, Dice score and a sample image for one validation batch."""
        img, mask = batch
        # Drop only the channel axis; a bare squeeze() would also drop the
        # batch axis whenever a batch holds a single sample.
        pred = self(img).squeeze(1)

        loss = self.loss_module(pred, mask)
        self.log("val/loss", loss.item(), on_epoch=True, on_step=False)

        # The metric thresholds its input at 0.5, which is only meaningful
        # for probabilities, so the logits go through a sigmoid first.
        dice_score = self.metric_module(torch.sigmoid(pred), mask)
        self.log("val/dice_score", dice_score, on_epoch=True, on_step=False, prog_bar=True)

        self.logger.log_image(
            key="results",
            images=[img.unbind(dim=0)[0], pred.unbind(dim=0)[0], mask.float().unbind(dim=0)[0]],
            caption=["image", "pred", "mask"]
        )

    def on_test_start(self):
        """Reset the list that accumulates one RLE string (or -1) per test image."""
        self.result = []

    def test_step(self, batch, batch_idx):
        """Predict binary masks for one test batch and store their RLE strings."""
        img = batch
        pred = self(img)

        pred = torch.sigmoid(pred).cpu().numpy()
        pred = np.squeeze(pred, axis=1)
        pred = (pred > 0.35).astype(np.uint8)  # Threshold = 0.35

        # Fixed: the original indexed an undefined name `masks` here.
        for binary_mask in pred:
            mask_rle = self.rle_encode(binary_mask)
            if mask_rle == '':  # no building pixel predicted at all -> -1
                self.result.append(-1)
            else:
                self.result.append(mask_rle)

    def on_test_end(self):
        """Fill the sample submission with the collected masks and save it."""
        submit = pd.read_csv('data/sample_submission.csv')
        submit['mask_rle'] = self.result
        submit.to_csv('./submit.csv', index=False)

    def rle_encode(self, mask):
        """Run-length encode a binary mask.

        Args:
            mask: Array of 0/1 values; it is flattened before encoding.

        Returns:
            Space-separated ``start length`` pairs with 1-based starts, or an
            empty string when the mask contains no 1.
        """
        pixels = mask.flatten()
        # Pad with zeros so runs touching either end still produce a transition.
        pixels = np.concatenate([[0], pixels, [0]])
        runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
        # Turn every (start, end) pair into (start, length).
        runs[1::2] -= runs[::2]
        return ' '.join(str(x) for x in runs)

In [None]:
# Assemble the data module, model and W&B logger, then train on GPU with
# 16-bit mixed precision for 300 epochs.
data_module = SatelliteDataModule(
    batch_size=configs["data"]["batch_size"],
    root=configs["data"]["root"],
)
model = LitSeg()
logger = WandbLogger(project="DACON")

trainer = L.Trainer(
    max_epochs=300,
    logger=logger,
    precision="16-mixed",
    accelerator="gpu",
)
trainer.fit(model=model, datamodule=data_module)

In [None]:
# load_from_checkpoint is a classmethod: call it on the class itself, not on
# a freshly built (and immediately discarded) instance.
model = LitSeg.load_from_checkpoint("DACON/...")
trainer.test(model=model, datamodule=data_module)

![picture](https://github.com/silverstar0727/artifacts/blob/main/dacon-buliding-wandb.png?raw=true)