# In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import VOCSegmentation
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt


# Target (height, width) that every image and its mask are resized to in VOCTransforms.
IMG_SIZE = (256, 256)
# Per-pixel output classes: 20 object classes plus one background class.
NUM_CLASSES = 20 + 1

# Transformation class
class VOCTransforms:
    """Joint preprocessing for one (image, segmentation mask) sample.

    An instance is a callable ``(img, target) -> (image, target)``, which is
    the signature expected by the ``transforms=`` argument of
    ``VOCSegmentation``.

    Image pipeline: bilinear resize to ``img_size`` -> tensor image ->
    float32 scaled to [0, 1] -> per-channel ``Normalize(mean, std)``. The
    defaults are the ImageNet channel statistics commonly used with
    torchvision models.

    Mask pipeline: nearest-neighbour resize to ``img_size`` (so label ids are
    never interpolated) -> tensor image -> int64 without value rescaling ->
    leading size-1 axis squeezed away.

    NOTE(review): Pascal VOC masks are documented to use 255 for void /
    boundary pixels. Those values pass through unchanged here, so the loss
    presumably needs ``ignore_index=255``; confirm in the training code.
    """

    def __init__(self, img_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        interp = transforms.InterpolationMode

        image_steps = [
            transforms.Resize(img_size, interpolation=interp.BILINEAR),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=mean, std=std),
        ]
        mask_steps = [
            # NEAREST guarantees every output pixel is an existing class id.
            transforms.Resize(img_size, interpolation=interp.NEAREST),
            transforms.ToImage(),
            # scale=False: class ids must stay integers, not be mapped to [0, 1].
            transforms.ToDtype(torch.long, scale=False),
        ]

        self.transform = transforms.Compose(image_steps)
        self.target_transform = transforms.Compose(mask_steps)

    def __call__(self, img, target):
        image = self.transform(img)
        # ToImage yields (C, H, W). For a single-band mask (C == 1) this leaves
        # (H, W); assumes single-band masks, as VOC provides — TODO confirm.
        mask = self.target_transform(target).squeeze(0)
        return image, mask

# Load the Pascal VOC 2012 train/val segmentation splits; each sample is an (image, mask) pair.

# One deterministic transform object serves both splits (no random augmentation).
voc_transforms = VOCTransforms(IMG_SIZE)

DATA_ROOT = '/data'


def _load_voc_split(image_set):
    """Build the VOC 2012 segmentation dataset for *image_set* under DATA_ROOT.

    ``download=False`` is passed, so the VOC files must already be present
    under DATA_ROOT.
    """
    return VOCSegmentation(
        root=DATA_ROOT,
        year='2012',
        image_set=image_set,
        download=False,
        transforms=voc_transforms,
    )


train_dataset = _load_voc_split('train')
val_dataset = _load_voc_split('val')

BATCH_SIZE = 8

# Options shared by both loaders; only shuffling differs between the splits.
# NOTE(review): with num_workers > 0 on platforms that start workers via
# "spawn" (Windows, macOS), classes defined in a notebook cell (VOCTransforms)
# may fail to unpickle in the workers. Drop to num_workers=0 if loading crashes.
_loader_opts = {"batch_size": BATCH_SIZE, "num_workers": 4}

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)


# Define U-Net model:
    # Encoder path (downsampling):
        # conv → relu → conv → relu → maxpool
        # repeat, doubling channels

    # Decoder path (upsampling):
        # upsample → concat with encoder feature map → convs
        # repeat, halving channels

    # Final 1x1 conv → class scores for each pixel

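# Training loop: forward pass → compute loss → backprop → update
# (VOC masks mark void/boundary pixels with 255, outside 0..NUM_CLASSES-1,
#  so use nn.CrossEntropyLoss(ignore_index=255) or those pixels break the loss)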

# Use trained model to predict segmentation masks
