In [1]:
import os
import pandas as pd
import numpy as np
import cv2
from PIL import Image
from patchify import patchify
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torch
from unet import LitUNet, UNet

In [2]:
# Shared configuration for the cells below.
scaler = MinMaxScaler()  # sklearn min-max scaler instance
patch_size = 256  # side length, in pixels, of the square tiles cut from each image

# Color to class mapping
def hex_to_rgb(hex_code):
    hex_code = hex_code.lstrip('#')
    return np.array(tuple(int(hex_code[i:i+2], 16) for i in (0, 2, 4)))

# Class index -> RGB color of that class in the mask images.
# The index is the position in the tuple below, so the tuple order defines the labels.
# Some codes lack the leading '#'; hex_to_rgb strips it, so both spellings work.
COLOR_MAP = {
    class_index: hex_to_rgb(hex_code)
    for class_index, hex_code in enumerate((
        '#3C1098',  # 0: Building
        '#8429F6',  # 1: Land
        '#6EC1E4',  # 2: Road
        'FEDD3A',   # 3: Vegetation
        'E2A929',   # 4: Water
        '#9B9B9B',  # 5: Unlabeled
    ))
}

In [3]:
def rgb_to_2D_label(label, color_map=None, fill_value=0):
    """Convert an RGB mask image into a 2-D array of class indices.

    Args:
        label: array of shape (H, W, 3) whose colors are in the same channel
            order as the colors in `color_map` (RGB for the module COLOR_MAP).
        color_map: mapping of class index -> length-3 color. Defaults to the
            module-level COLOR_MAP.
        fill_value: class index written for pixels whose color matches no entry
            in `color_map`. The default 0 reproduces the original behavior; with
            the default COLOR_MAP, 0 is Building, so unrecognized colors are
            silently labeled Building. Pass e.g. 5 (Unlabeled) or an ignore
            index to avoid that. Must fit in uint8 (0-255).

    Returns:
        np.ndarray of dtype uint8 and shape (H, W).

    Raises:
        ValueError: if `label` is not 3-D or `fill_value` is outside 0-255.
    """
    if color_map is None:
        color_map = COLOR_MAP
    label = np.asarray(label)
    if label.ndim != 3:
        raise ValueError(f"Expected an (H, W, 3) color mask, got shape {label.shape}")
    if not 0 <= fill_value <= 255:
        raise ValueError(f"fill_value must fit in uint8 (0-255), got {fill_value}")

    label_seg = np.full(label.shape[:2], fill_value, dtype=np.uint8)
    for class_idx, color in color_map.items():
        # A pixel matches a class only if all three channels are equal.
        label_seg[np.all(label == np.asarray(color), axis=-1)] = class_idx
    return label_seg


In [4]:
# Custom dataset
class SegmentationDataset(Dataset):
    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file)
        self.image_patches = []
        self.mask_patches = []
        self._prepare_data()

    def _prepare_data(self):
        for idx, row in self.df.iterrows():
            img = cv2.imread(row['Image'])
            h, w = img.shape[:2]
            img = Image.fromarray(img)
            img = img.crop((0, 0, (w // patch_size) * patch_size, (h // patch_size) * patch_size))
            img = np.array(img)
            img_patches = patchify(img, (patch_size, patch_size, 3), step=patch_size)

            mask = cv2.imread(row['Mask'])
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
            mask = Image.fromarray(mask)
            mask = mask.crop((0, 0, (w // patch_size) * patch_size, (h // patch_size) * patch_size))
            mask = np.array(mask)
            mask_patches = patchify(mask, (patch_size, patch_size, 3), step=patch_size)

            for i in range(img_patches.shape[0]):
                for j in range(img_patches.shape[1]):
                    img_patch = img_patches[i, j, 0]
                    img_patch = scaler.fit_transform(img_patch.reshape(-1, 3)).reshape(img_patch.shape)
                    mask_patch = rgb_to_2D_label(mask_patches[i, j, 0])
                    self.image_patches.append(img_patch)
                    self.mask_patches.append(mask_patch)

    def __len__(self):
        return len(self.image_patches)

    def __getitem__(self, idx):
        image = torch.tensor(self.image_patches[idx], dtype=torch.float32).permute(2, 0, 1)
        mask = torch.tensor(self.mask_patches[idx], dtype=torch.long)
        return image, mask

# LightningDataModule
class SegmentationDataModule(pl.LightningDataModule):
    def __init__(self, train_csv, test_csv, batch_size=16):
        super().__init__()
        self.train_csv = train_csv
        self.test_csv = test_csv
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = SegmentationDataset(self.train_csv)
        self.val_dataset = SegmentationDataset(self.test_csv)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

In [5]:
import matplotlib.pyplot as plt

SEED = 42
BATCH_SIZE = 16
MAX_EPOCHS = 10

# Seed Python/NumPy/torch RNGs so shuffling and weight init are reproducible.
pl.seed_everything(SEED)

# Prepare datamodule. Trainer.fit calls `setup("fit")` itself, so there is no
# manual setup() here; calling it beforehand could run the expensive patch
# extraction twice.
data_module = SegmentationDataModule('train.csv', 'test.csv', batch_size=BATCH_SIZE)

# Number of output channels = size of the label space defined by COLOR_MAP.
# Counting the unique labels present in the training masks undercounts whenever
# a class is absent from the training split, and then label indices would
# exceed the number of output channels. It also required materializing every
# mask as an int64 tensor.
n_classes = len(COLOR_MAP)

# Model
model = LitUNet(n_classes=n_classes, in_channels=3)

# Trainer
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, accelerator='auto')

# Train
trainer.fit(model, datamodule=data_module)


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/soumendusekharbhattacharjee/anaconda3/envs/unet_pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/Users/soumendusekharbhattacharjee/anaconda3/envs/unet_pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                                

/Users/soumendusekharbhattacharjee/anaconda3/envs/unet_pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0: 100%|█████████████████████████| 69/69 [00:28<00:00,  2.46it/s, v_num=1]
Validation: |                                             | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                        | 0/10 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                           | 0/10 [00:00<?, ?it/s][A
Validation DataLoader 0:  10%|█▉                 | 1/10 [00:00<00:00, 30.41it/s][A
Validation DataLoader 0:  20%|███▊               | 2/10 [00:00<00:00,  8.54it/s][A
Validation DataLoader 0:  30%|█████▋             | 3/10 [00:00<00:00,  8.37it/s][A
Validation DataLoader 0:  40%|███████▌           | 4/10 [00:00<00:00,  8.49it/s][A
Validation DataLoader 0:  50%|█████████▌         | 5/10 [00:00<00:00,  8.53it/s][A
Validation DataLoader 0:  60%|███████████▍       | 6/10 [00:00<00:00,  8.57it/s][A
Validation DataLoader 0:  70%|█████████████▎     | 7/10 [00:00<00:00,  8.53it/s][A
Validation DataLoader 0:  80%|███████████████▏   | 8/10 [00:00<00:00,  8.54it/s

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|█| 69/69 [01:38<00:00,  0.70it/s, v_num=1, val_loss=0.578, val_jac
