In [14]:
from __future__ import annotations

import os
import random
import sys
import time
from dataclasses import dataclass
from dataclasses import replace

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from lightning.pytorch import Trainer
from PIL import Image
from schedulefree.adamw_schedulefree import AdamWScheduleFree
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import functional as TF

from data_utils import ImageDatasetConfig, ImageDataset
from loss_function import BCEwithDiceLoss, BCELoss2d, dice_loss, dice_coeff
from model import UNet

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import torchvision.transforms.functional as TF

class ImageDataset(Dataset):
    """Segmentation dataset that pastes RGBA foregrounds onto random backgrounds.

    For every sample the foreground's RGB bands are composited over a randomly
    chosen background picture, using the alpha band as the paste mask; the same
    alpha band is returned as the segmentation target.

    Args:
        image_ds_config: object exposing ``foreground_dir``, ``background_dir``,
            ``mode``, ``image_size``, ``augment`` and ``augment_prob``. The
            values are copied at construction time, so mutating the config
            afterwards does not affect an already-built dataset.
    """

    def __init__(self, image_ds_config) -> None:
        super().__init__()

        self.foreground_directory = image_ds_config.foreground_dir
        self.background_directory = image_ds_config.background_dir
        self.mode = image_ds_config.mode
        self.image_size = image_ds_config.image_size
        self.augment = image_ds_config.augment
        self.augment_prob = image_ds_config.augment_prob
        # Not referenced anywhere in this class; kept so existing attribute
        # access keeps working.
        self.rotation_degree = [0, 90, 180, 270]

        # Both listings are built eagerly and hold *paths* only; the pixel
        # data is read from disk on every __getitem__ call.
        self.image_paths = self._load_image_paths()
        self.background_images = self._load_background_images()

        # Shared by image and mask: resize to a square, then convert to tensor.
        self.base_transform = T.Compose([
            T.Resize((self.image_size, self.image_size)),
            T.ToTensor(),
        ])

    def __len__(self):
        """Return the number of foreground images in the selected split."""
        return len(self.image_paths)

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths of the regular files directly inside ``directory``.

        Sorting makes the index -> sample mapping reproducible (``os.listdir``
        returns entries in arbitrary order); sub-directories are skipped
        because they cannot be opened as images.
        """
        candidates = (os.path.join(directory, name) for name in sorted(os.listdir(directory)))
        return [path for path in candidates if os.path.isfile(path)]

    def _load_image_paths(self):
        """List the foreground images of the current split.

        Any mode other than ``"train"`` or ``"val"`` falls back to the
        ``"test"`` sub-directory.
        """
        split = self.mode if self.mode in ("train", "val") else "test"
        return self._list_files(os.path.join(self.foreground_directory, split))

    def _load_background_images(self):
        """List the background images stored under ``<background_dir>/<mode>``.

        NOTE(review): unlike ``_load_image_paths`` there is no fallback to
        ``"test"`` here, so an unrecognised mode must exist as a directory.
        """
        return self._list_files(os.path.join(self.background_directory, self.mode))

    def __getitem__(self, index):
        """Build sample ``index`` and return it as an ``(image, mask)`` tensor pair."""
        img_path = self.image_paths[index]

        # Composite the foreground over a random background.
        img, mask = self._load_and_process_image(img_path)

        # Augment only training samples, and only with probability augment_prob.
        if self.mode == "train" and self.augment and random.random() < self.augment_prob:
            img, mask = self._apply_augmentation(img, mask)

        # Final resize and tensor conversion, identical for image and mask.
        img = self.base_transform(img)
        mask = self.base_transform(mask)

        return img, mask

    def _load_and_process_image(self, img_path: str):
        """Composite one RGBA foreground onto a random background.

        Args:
            img_path: path of an RGBA image file.

        Returns:
            ``(composite, mask)`` as PIL images: the RGB composite and the
            foreground's alpha band.

        Raises:
            AssertionError: if the file is not in RGBA mode.
        """
        # The context manager releases the file handle as soon as the bands
        # have been read into memory.
        with Image.open(img_path) as image_alpha:
            assert image_alpha.mode == 'RGBA', "Image should be RGBA"
            bands = image_alpha.split()

        img = Image.merge('RGB', bands[:3])
        mask = bands[-1]  # Alpha channel as mask

        # Force RGB so every sample has three channels regardless of how the
        # background file is stored (grayscale, palette, RGBA, ...).
        with Image.open(random.choice(self.background_images)) as bg_file:
            bg_img = bg_file.convert("RGB").resize(img.size)

        # Paste the foreground wherever the alpha mask is non-zero.
        bg_img.paste(img, mask=mask)

        return bg_img, mask

    def _apply_augmentation(self, img, mask):
        """Apply the same geometric augmentations to both the image and the mask.

        A rotation by a random angle in [-10, 10] degrees is always applied;
        horizontal and vertical flips are each applied with probability 0.5.
        Colour jitter touches the image only, since the mask carries no colour.
        """
        # Random Rotation
        angle = random.uniform(-10, 10)
        img = TF.rotate(img, angle)
        mask = TF.rotate(mask, angle)

        # Random Color Jitter (image only)
        color_jitter = T.ColorJitter(brightness=0.2, contrast=0.2, hue=0.2)
        img = color_jitter(img)

        # Random Horizontal Flip
        if random.random() > 0.5:
            img = TF.hflip(img)
            mask = TF.hflip(mask)

        # Random Vertical Flip
        if random.random() > 0.5:
            img = TF.vflip(img)
            mask = TF.vflip(mask)

        return img, mask

In [17]:
@dataclass
class ImageDatasetConfig:
    """Settings read by `ImageDataset` when it is constructed."""
    # Root directory containing one foreground sub-directory per split
    # ("train", "val", "test").
    foreground_dir: str = "shoe_dataset/"
    # Root directory containing one background sub-directory per mode.
    background_dir: str = "shoe_dataset/bg/"
    # Split to load; ImageDataset treats anything other than "train"/"val"
    # as "test" when listing foregrounds.
    mode: str = "train"
    # Side length (pixels) of the square every image and mask is resized to.
    image_size: int = 256
    # Whether training samples may be augmented at all.
    augment: bool = False
    # Probability that a given training sample is augmented when `augment` is on.
    augment_prob: float = 0.5

# Shared default configuration used by the dataset cells below.
img_ds_config = ImageDatasetConfig()

In [4]:
# Turn augmentation on in the shared config before the training dataset is built.
img_ds_config.augment = True

In [5]:
# Derive one config per split rather than mutating the shared `img_ds_config`.
# Mutating it in place would leave it in "val" mode with augmentation off, so
# re-running this cell would silently build the *training* dataset from
# validation data. With per-split copies the cell is idempotent.
train_ds_config = replace(img_ds_config, mode="train")
val_ds_config = replace(img_ds_config, mode="val", augment=False)

train_dataset = ImageDataset(train_ds_config)
val_dataset = ImageDataset(val_ds_config)

In [27]:
@dataclass
class SegmentationConfig:
    """Hyperparameters for the segmentation model, loss, optimizer and training run."""
    # Model input channels / output classes.
    n_channels: int = 3
    n_classes: int = 1
    # Loss settings handed to BCEwithDiceLoss. alpha and beta are fractional
    # values, so they are annotated as float (they were mistakenly marked int).
    alpha: float = 0.5
    beta: float = 0.5
    smooth: float = 1e-5
    # Optimizer settings.
    lr: float = 3e-4
    weight_decay: float = 0.1
    # Data loading.
    batch_size: int = 8
    num_workers: int = 4
    # Remaining optimizer settings.
    betas: tuple = (0.9, 0.999)
    eps: float = 1e-8
    # Training run.
    epochs: int = 1
    device: str = "cpu"

seg_config = SegmentationConfig()

In [36]:
# Training batches are reshuffled on every pass; validation keeps a fixed order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=seg_config.batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=seg_config.batch_size,
    shuffle=False,
)

In [37]:
# Number of batches per pass for the training and validation loaders.
len(train_dataloader), len(val_dataloader)

(79, 28)

In [38]:
# Build the UNet (imported from the project's `model` module) with the
# channel / class counts taken from the segmentation config.
model = UNet(n_channels=seg_config.n_channels, n_classes=seg_config.n_classes)

In [39]:
class SegmentationWrapper(L.LightningModule):
    """LightningModule that trains a segmentation model with a BCE + Dice loss.

    The optimizer is schedule-free AdamW, which has to be told whether it is
    being used for training or evaluation; that switch is made at the start
    of every training / validation step.

    Args:
        model: network that maps an image batch to mask predictions.
        config: hyperparameters for the loss (``alpha``, ``beta``, ``smooth``)
            and the optimizer (``lr``, ``betas``, ``eps``, ``weight_decay``).
    """

    def __init__(self, model, config: SegmentationConfig):
        super().__init__()
        self.model = model
        self.config = config
        self.loss_fn = BCEwithDiceLoss(alpha=config.alpha, beta=config.beta, smooth=config.smooth)
        self.dice_loss = dice_loss
        # No optimizer is created here: the Trainer calls
        # configure_optimizers() itself, and an extra instance built in
        # __init__ would never be stepped (and would be the wrong object to
        # switch between train/eval mode).

    def _compute_loss(self, batch):
        """Run the model on one ``(image, mask)`` batch and return the loss."""
        img, mask = batch
        output = self.model(img)
        return self.loss_fn(output, mask)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Lightning's automatic optimization handles ``zero_grad``, ``backward``
        and ``step``; zeroing gradients manually here would discard gradients
        accumulated from earlier micro-batches when gradient accumulation is
        enabled. Module train/eval mode is likewise toggled by Lightning.
        """
        # Put the schedule-free optimizer in training mode before the update.
        self.optimizers().train()

        loss = self._compute_loss(batch)
        self.log("train_loss", loss, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        # Put the schedule-free optimizer in evaluation mode for validation.
        self.optimizers().eval()

        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        """Create the schedule-free AdamW optimizer from the config."""
        return AdamWScheduleFree(
            self.model.parameters(),
            lr=self.config.lr,
            betas=self.config.betas,
            eps=self.config.eps,
            # Previously the configured weight decay was silently ignored.
            weight_decay=self.config.weight_decay,
        )

In [40]:
# Wrap the model together with its loss / optimizer settings for Lightning.
segmentation_wrapper = SegmentationWrapper(model, seg_config)

In [41]:
# Take the accelerator from the config instead of hard-coding it, so
# `seg_config.device` is the single place that selects the hardware.
trainer = Trainer(max_epochs=seg_config.epochs, accelerator=seg_config.device)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [43]:
# trainer.fit(segmentation_wrapper, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)