In [None]:
from pathlib import Path

from torch.utils.data import DataLoader

from src.pixseg.datasets import ADE20K, resolve_metadata
from src.pixseg.utils.transform import SegmentationTransform

# Location of the extracted ADE20K dataset -- replace with your local path
root = Path(r"path/to/ADE20K")
# Built-in ADE20K preset: class labels, background/ignore index, etc.
metadata = resolve_metadata("ADE20K")

# Bring every sample to a fixed 512x512 model-input format so batches can be stacked.
# Mask pixels the transform has to fill are set to the ignore index (presumably padding),
# matching the ignore index the loss is configured with.
transforms = SegmentationTransform(
    size=(512, 512),
    mask_fill=metadata.ignore_index,
)
dataset = ADE20K(root, split="training", transforms=transforms)

# Shuffle each epoch; drop_last discards a trailing partial batch so every batch has 4 samples
train_loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
)

In [None]:
from torch.optim.lr_scheduler import PolynomialLR

from src.pixseg.learn import DiceLoss, Padam
from src.pixseg.models import PSPNET_ResNet50_Weights, pspnet_resnet50
from src.pixseg.utils.rng import seed
from src.pixseg.utils.transform import SegmentationAugment

# Fix random seeds of random, numpy and pytorch *before* building the model,
# so any randomly initialized weights are identical on every run
seed(42)

# Choose ONE way to build the model. Only the selected constructor runs, so no
# model is built just to be overwritten.
USE_PRETRAINED = True
if USE_PRETRAINED:
    # Initialize with pretrained weights
    # NOTE(review): confirm the pretrained head's class count matches metadata.num_classes
    model = pspnet_resnet50(weights=PSPNET_ResNet50_Weights.DEFAULT)
else:
    # Create PSPNet with ResNet-50 backbone, sized for this dataset, and enable auxiliary loss
    model = pspnet_resnet50(num_classes=metadata.num_classes, aux_loss=True)

# Dice loss; pixels equal to the ignore index are meant to be excluded from the loss
criterion = DiceLoss(ignore_index=metadata.ignore_index)
# Padam optimizer; `partial` is presumably the partial-adaptivity exponent p
# (0.125 = 1/8) -- TODO confirm against src.pixseg.learn.Padam
optimizer = Padam(model.parameters(), lr=0.1, weight_decay=5e-4, partial=0.125)
# Polynomial decay of the LR to 0 over `total_iters` calls to lr_scheduler.step().
# 100 matches the 100 training epochs if the scheduler is stepped once per epoch.
lr_scheduler = PolynomialLR(optimizer, total_iters=100, power=0.9)

In [None]:
from torch import Tensor

from src.pixseg.utils.metrics import MetricStore
from src.pixseg.utils.rng import seed
from src.pixseg.utils.visual import exhibit_figure, plot_confusion_matrix

# Number of training epochs. If the LR scheduler is stepped once per epoch,
# keep this equal to the `total_iters` the scheduler was created with (100).
NUM_EPOCHS = 100

# Fix random seeds of random, numpy and pytorch so the run (data shuffling,
# random augmentation) is reproducible
seed(42)
# This project separates data transforms and data augmentations, so that visualizations can show the original images
# (hflip is presumably the horizontal-flip probability; filled mask pixels get the ignore index)
train_augment = SegmentationAugment(hflip=0.5, mask_fill=metadata.ignore_index)
for epoch in range(NUM_EPOCHS):
    model.train()
    # Fresh metric store every epoch, so the summary covers this epoch only
    train_ms = MetricStore(metadata.num_classes)
    for images, masks in train_loader:
        # Augment on the fly, after the dataset-level transforms
        images, masks = train_augment(images, masks)
        ...  # placeholder for your own code, e.g. moving images/masks to the training device
        logits: dict[str, Tensor] = model(images)
        # Store per-pixel predictions against ground truth; argmax(1) assumes the main
        # output "out" is laid out (N, C, H, W) with classes on dim 1 -- TODO confirm
        train_ms.store_results(masks, logits["out"].argmax(1))
        ...  # placeholder, e.g. compute `criterion`, backpropagate, `optimizer.step()`

    # Placeholder: step `lr_scheduler` here if the schedule is per epoch

    # Print training results
    print(train_ms.summarize())
    # Visualize confusion matrix and save as image; the same file is overwritten
    # every epoch, so it always shows the most recent epoch
    plot_confusion_matrix(train_ms.confusion_matrix, metadata.labels)
    exhibit_figure(show=False, save_to=Path("confusion_matrix.png"))