In [2]:
!python -m pip install --upgrade cityscapesscripts[gui] timm torchvision transformers appdirs requests datasets nvidia-dali-cuda120 torch==2.5.1 triton

Collecting torchvision
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting transformers
  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting appdirs
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting datasets
  Downloading datasets-3.3.1-py3-none-any.whl.metadata (19 kB)
Collecting nvidia-dali-cuda120
  Downloading nvidia_dali_cuda120-1.46.0.tar.gz (1.6 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting triton
  Downloading triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting cityscapesscripts[gui]
  Downloading cityscapesScripts-2.2.4-py3-none-any.whl.metadata (9.8 kB)
Collecting nvidia

In [None]:
#!/usr/bin/env python
"""
SEGMENTER: More Official-Like Training Script with GPU-Accelerated Augmentation via NVIDIA DALI
and Automatic Resource-Based Tuning

This script reproduces the "Segmenter" paper's approach while:
  - Detecting system CPU cores and available RAM to automatically tune the DataLoader.
  - Checking GPU memory and printing a summary so you can adjust batch size if desired.
  - Using channels_last, AMP, and torch.compile (if available) for improved GPU throughput.
  - Optionally using Distributed Data Parallel (DDP) for multi-GPU training.
  - Optionally using NVIDIA DALI for GPU-accelerated data loading & augmentation.

It supports large or small datasets (ADE20K, Cityscapes, Pascal Context, or a toy dataset),
and includes standard or advanced augmentations (Mixup, CutMix, ColorJitter, etc.).
See the argument list for details.

Example usage (with DALI, on ADE20K):
python segmenter_official_like.py \
    --dataset_name ade20k \
    --num_classes 150 \
    --lr 0.001 \
    --total_iterations 161728 \
    --batch_size 16 \
    --grad_accum_steps 1 \
    --eval_freq 2 \
    --num_epochs 64 \
    --start_epoch 0 \
    --model_name vit_base_patch16_384 \
    --decoder mask \
    --stoch_depth 0.1 \
    --backbone_dropout 0.0 \
    --decoder_drop_path_rate 0.0 \
    --advanced_aug \
    --cutmix \
    --dataset_normalization vit \
    --iter_warmup 0.0 \
    --min_lr 1e-05 \
    --clip_grad 0.0 \
    --cache_data \
    --use_dali \
    --distilled false \
    --version normal \
    [--ddp]

NOTE:
  - If using multiple GPUs, launch the script with `torchrun` (or the legacy
    `python -m torch.distributed.launch`) and pass --ddp.
  - For GCP or Colab usage, select a CUDA GPU runtime; the DALI and CUDA code
    paths in this script do not run on TPUs.
"""

import os
import sys
import json
import stat
import tarfile
import getpass
import requests
import appdirs
import math
import numpy as np
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import timm  # Vision Transformers, etc.

# Allow PIL to load truncated images instead of raising an error on them.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Enable cuDNN autotuning; it pays off when input sizes are fixed
# (e.g. the square crops produced by build_transforms / the DALI pipelines).
torch.backends.cudnn.benchmark = True

########################################
# 0) Automatic Hardware Detection
########################################
# Best-effort CPU / RAM detection. The results are exposed as module globals
# (cpu_logical, cpu_physical, free_ram_gb) for DataLoader tuning.
try:
    import psutil
    # psutil.cpu_count() may return None (e.g. in some containers); fall back
    # so downstream arithmetic never sees None.
    cpu_logical = psutil.cpu_count(logical=True) or os.cpu_count() or 1
    cpu_physical = psutil.cpu_count(logical=False) or cpu_logical
    mem = psutil.virtual_memory()
    free_ram_gb = mem.available / (1024**3)
    print(f"Detected CPU: {cpu_logical} logical cores, {cpu_physical} physical cores.")
    print(f"Available system RAM: {free_ram_gb:.1f} GB")
except ImportError:
    print("psutil not found. Install psutil for CPU/RAM detection.")
    # os.cpu_count() still reports the logical core count. Physical cores and
    # free RAM are unknown: assume 2 threads per core and an 8 GB default.
    cpu_logical = os.cpu_count() or 4
    cpu_physical = max(1, cpu_logical // 2)
    free_ram_gb = 8

########################################
# 1) Download/Fetch Datasets (Illustrative Examples)
########################################
def login_cityscapes(timeout=60):
    """Log in to cityscapes-dataset.com and return an authenticated session.

    Credentials are read from ``credentials.json`` in the per-user app-data
    directory (``appdirs.user_data_dir``). If that file is absent, the user
    is prompted, and may choose to store the credentials (unencrypted).
    A stored file is created with owner-only read/write permissions from the
    start, so the password is never briefly world-readable.

    Args:
        timeout: seconds to wait for each HTTP request before giving up.

    Returns:
        requests.Session carrying the login cookies.

    Raises:
        requests.HTTPError: if either login request returns an HTTP error.
        Exception: if the login POST is not answered with a 302 redirect
            (treated as bad credentials). Any stored credentials file is
            deleted so the next call prompts again.
    """
    appname = "cityscapesLoginApp"
    appauthor = "cityscapes"
    data_dir = appdirs.user_data_dir(appname, appauthor)
    credentials_file = os.path.join(data_dir, 'credentials.json')
    if os.path.isfile(credentials_file):
        with open(credentials_file, 'r') as f:
            credentials = json.load(f)
    else:
        username = input("Cityscapes username/email: ")
        password = getpass.getpass("Cityscapes password: ")
        credentials = {'username': username, 'password': password}
        store_q = f"Store credentials unencrypted in '{credentials_file}'? [y/N] "
        if input(store_q).strip().lower() in ['y', 'yes']:
            os.makedirs(data_dir, exist_ok=True)
            # Create the file as 0o600 up front. Writing first and chmod-ing
            # afterwards would leave the password readable in between.
            fd = os.open(credentials_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'w') as f:
                json.dump(credentials, f)
            # Re-assert owner read/write for platforms where os.open ignores mode bits.
            os.chmod(credentials_file, stat.S_IREAD | stat.S_IWRITE)

    login_url = "https://www.cityscapes-dataset.com/login"
    sess = requests.Session()
    sess.get(login_url, allow_redirects=False, timeout=timeout).raise_for_status()
    credentials['submit'] = 'Login'
    r2 = sess.post(login_url, data=credentials, allow_redirects=False, timeout=timeout)
    r2.raise_for_status()
    # A successful login answers with a redirect; anything else means rejected credentials.
    if r2.status_code != 302:
        if os.path.isfile(credentials_file):
            os.remove(credentials_file)
        raise Exception("Cityscapes login failed. Check credentials.")
    return sess

def parse_size_to_bytes(size_str):
    """Convert a human-readable size string to a byte count (binary units).

    Examples: '12MB' => 12582912.0, ' 1kb ' => 1024.0, '512B' => 512.0.
    Case and surrounding whitespace are ignored; TB/GB/MB/KB/B suffixes are
    supported (1 KB = 1024 B).

    Returns:
        float: number of bytes.

    Raises:
        ValueError: if the suffix is unknown or the numeric part is invalid.
    """
    size_str = size_str.strip().upper()
    # Longer suffixes must be checked before the bare "B", which they all end with.
    multipliers = (
        ("TB", 1024 ** 4),
        ("GB", 1024 ** 3),
        ("MB", 1024 ** 2),
        ("KB", 1024),
        ("B", 1),
    )
    for suffix, factor in multipliers:
        if size_str.endswith(suffix):
            return float(size_str[:-len(suffix)]) * factor
    raise ValueError(f"Unknown size format '{size_str}' in parse_size_to_bytes")

def get_cityscapes_packages(session):
    """Fetch the list of downloadable Cityscapes packages.

    Args:
        session: an authenticated requests-style session.

    Returns:
        The decoded JSON listing returned by the downloads endpoint.

    Raises:
        requests.HTTPError: if the listing request fails.
    """
    listing_url = "https://www.cityscapes-dataset.com/downloads/?list"
    response = session.get(listing_url, allow_redirects=False)
    response.raise_for_status()
    packages = response.json()
    return packages

def download_cityscapes_packages(pkg_names, destination_path, session, resume=False, timeout=60):
    """Download official Cityscapes packages and verify each against its published MD5.

    Args:
        pkg_names: package file names as listed by the server
            (e.g. "gtFine_trainvaltest.zip").
        destination_path: directory to write into (created if missing).
        session: authenticated session (see ``login_cityscapes``).
        resume: if True, continue an existing partial file with an HTTP Range
            request. If False, refuse to touch existing files.
        timeout: seconds to wait on each HTTP connect/read before giving up.

    Raises:
        ValueError: unknown package name, or MD5 mismatch after download.
        FileExistsError: target file exists and ``resume`` is False.
        RuntimeError: the download endpoint answered with an unexpected status.
        requests.HTTPError: an HTTP error from the MD5 or download endpoint.
    """
    import hashlib

    pkgs_info = get_cityscapes_packages(session)
    name_to_id = {p['name']: p['packageID'] for p in pkgs_info}
    name_to_bytes = {p['name']: parse_size_to_bytes(p['size']) for p in pkgs_info}
    invalid = [n for n in pkg_names if n not in name_to_id]
    if invalid:
        raise ValueError(f"Package(s) not recognized: {invalid}")
    os.makedirs(destination_path, exist_ok=True)

    for pkg in pkg_names:
        pkg_id = name_to_id[pkg]
        out_path = os.path.join(destination_path, pkg)
        if os.path.exists(out_path) and not resume:
            raise FileExistsError(f"{out_path} exists. Pass resume=True to continue partial.")

        # Expected hash = first whitespace-separated token of the md5-sum response.
        md5_url = f"https://www.cityscapes-dataset.com/md5-sum/?packageID={pkg_id}"
        r_md5 = session.get(md5_url, allow_redirects=False, timeout=timeout)
        r_md5.raise_for_status()
        md5_expected = r_md5.text.split()[0]

        data_url = f"https://www.cityscapes-dataset.com/file-handling/?packageID={pkg_id}"
        start_offset = os.path.getsize(out_path) if resume and os.path.exists(out_path) else 0
        headers = {'Range': f'bytes={start_offset}-'} if start_offset > 0 else {}
        r_data = session.get(data_url, allow_redirects=False, stream=True,
                             headers=headers, timeout=timeout)
        try:
            # 416 = requested range starts at/after EOF: the local file is already
            # complete, so go straight to MD5 verification.
            already_complete = r_data.status_code == 416 and start_offset > 0
            if not already_complete:
                r_data.raise_for_status()
                if r_data.status_code not in (200, 206):
                    raise RuntimeError(f"Got {r_data.status_code} from server.")
                if r_data.status_code == 200 and start_offset > 0:
                    # The server ignored the Range header and is sending the whole
                    # file; appending it would corrupt the partial download.
                    start_offset = 0
                mode = 'ab' if start_offset > 0 else 'wb'
                with open(out_path, mode) as f_out, tqdm(
                    desc=f"Download {pkg}", total=name_to_bytes[pkg],
                    initial=start_offset, unit='B', unit_scale=True
                ) as pbar:
                    for chunk in r_data.iter_content(chunk_size=1 << 20):
                        f_out.write(chunk)
                        pbar.update(len(chunk))
        finally:
            r_data.close()

        calc_md5 = hashlib.md5()
        with open(out_path, "rb") as fch:
            for chunk in iter(lambda: fch.read(1 << 20), b""):
                calc_md5.update(chunk)
        if calc_md5.hexdigest() != md5_expected:
            raise ValueError(f"MD5 mismatch for {pkg} => corrupted download.")

def fetch_cityscapes_official(root_dir="datasets/Cityscapes_official", extract=True):
    """Download (credentials required) and unpack the official Cityscapes fine-annotation data.

    - If ``root_dir`` already contains ``leftImg8bit/`` and ``gtFine/``, nothing is done.
    - Only zips that are not already present in ``root_dir`` are downloaded,
      so re-running after an earlier download no longer fails with
      FileExistsError.
    - With ``extract=True`` (default), the zips are unpacked in place,
      producing ``leftImg8bit/`` and ``gtFine/``. With ``extract=False``, the
      user is asked to unzip them manually.

    Returns:
        root_dir

    Raises:
        RuntimeError: if a local zip is not a valid archive (e.g. an
            interrupted download). Delete it and re-run.
    """
    import zipfile

    left_dir = os.path.join(root_dir, "leftImg8bit")
    gt_dir = os.path.join(root_dir, "gtFine")
    if os.path.isdir(left_dir) and os.path.isdir(gt_dir):
        print("[Cityscapes] Found => skipping download.")
        return root_dir

    desired = ["leftImg8bit_trainvaltest.zip", "gtFine_trainvaltest.zip"]
    missing = [p for p in desired if not os.path.exists(os.path.join(root_dir, p))]
    if missing:
        print("[Cityscapes] Downloading official zips.")
        sess = login_cityscapes()
        download_cityscapes_packages(missing, root_dir, sess, resume=False)

    if not extract:
        print("[Cityscapes] Download complete. Please unzip these files in place. The final folder should have leftImg8bit/, gtFine/.")
        return root_dir

    for pkg in desired:
        zip_path = os.path.join(root_dir, pkg)
        print("[Cityscapes] Extracting =>", zip_path)
        try:
            with zipfile.ZipFile(zip_path) as zf:
                zf.extractall(root_dir)
        except zipfile.BadZipFile as e:
            raise RuntimeError(
                f"{zip_path} is not a valid zip (interrupted download?). Delete it and re-run."
            ) from e
    return root_dir

def fetch_voc2007_official(root_dir="datasets/VOC2007_official"):
    """Download and extract PASCAL VOC2007 train/val and test tarballs.

    Completion is recorded by ``VOCdevkit/VOC2007/_SUCCESS.txt``; when it
    exists, nothing is done. Downloads are streamed to ``<tar>.part`` and
    renamed only once complete. An interrupted download is therefore
    re-fetched on the next run, instead of being mistaken for a finished tar.
    Extraction uses tarfile's ``"data"`` filter when available, which rejects
    absolute paths, ``..`` traversal and special files.

    Returns:
        root_dir
    """
    voc_root = os.path.join(root_dir, "VOCdevkit", "VOC2007")
    success_file = os.path.join(voc_root, "_SUCCESS.txt")
    if os.path.exists(success_file):
        print("[VOC2007] Found => skip download.")
        return root_dir
    print("[VOC2007] Downloading train/val/test tar files from official site.")
    os.makedirs(root_dir, exist_ok=True)
    tv_url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar"
    tst_url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar"
    tv_tar = os.path.join(root_dir, "VOCtrainval_06-Nov-2007.tar")
    tst_tar = os.path.join(root_dir, "VOCtest_06-Nov-2007.tar")

    def dl(url, outpath):
        """Stream url to outpath via a temporary .part file (renamed when complete)."""
        if os.path.exists(outpath):
            print("[VOC2007] Already downloaded =>", outpath)
            return
        print("[VOC2007] Downloading =>", url)
        part_path = outpath + ".part"
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(part_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(part_path, outpath)

    def extract_tar(tfile, dest):
        """Extract tfile into dest, using the safe 'data' filter when supported."""
        print("[VOC2007] Extracting =>", tfile)
        with tarfile.open(tfile, "r") as tf:
            if hasattr(tarfile, "data_filter"):
                tf.extractall(dest, filter="data")
            else:
                tf.extractall(dest)

    dl(tv_url, tv_tar)
    dl(tst_url, tst_tar)
    extract_tar(tv_tar, root_dir)
    extract_tar(tst_tar, root_dir)

    os.makedirs(os.path.dirname(success_file), exist_ok=True)
    with open(success_file, "w") as f:
        f.write("Done.\n")

    return root_dir

def fetch_pascal_context_official(root_dir="datasets/PascalContext_official"):
    """Download and extract the Pascal-Context ``trainval.tar.gz`` archive.

    Completion is recorded by ``VOCdevkit/VOC2010/_SUCCESS.txt``. The archive
    is streamed to ``trainval.tar.gz.part`` and renamed only when complete,
    so an interrupted download is not reused. Extraction uses tarfile's
    ``"data"`` filter when available.

    After extraction, a warning is printed if
    ``VOCdevkit/VOC2010/SegmentationClassContext`` (the directory that
    SegmentationDataset reads masks from) does not exist.
    NOTE(review): this archive appears to hold only the Context annotations;
    VOC2010 images and per-split PNG masks likely have to be prepared
    separately — confirm.

    Returns:
        root_dir
    """
    url = "https://cs.stanford.edu/~roozbeh/pascal-context/trainval.tar.gz"
    voc2010 = os.path.join(root_dir, "VOCdevkit", "VOC2010")
    success = os.path.join(voc2010, "_SUCCESS.txt")
    if os.path.exists(success):
        print("[PascalContext] Found => skip download.")
        return root_dir
    os.makedirs(root_dir, exist_ok=True)
    tar_fp = os.path.join(root_dir, "trainval.tar.gz")
    if not os.path.exists(tar_fp):
        print("[PascalContext] Downloading =>", url)
        part_fp = tar_fp + ".part"
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(part_fp, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(part_fp, tar_fp)

    print("[PascalContext] Extracting =>", tar_fp)
    with tarfile.open(tar_fp, "r:gz") as tf:
        if hasattr(tarfile, "data_filter"):
            tf.extractall(path=root_dir, filter="data")
        else:
            tf.extractall(path=root_dir)

    if not os.path.isdir(os.path.join(voc2010, "SegmentationClassContext")):
        print("[PascalContext] WARNING: VOCdevkit/VOC2010/SegmentationClassContext not found "
              "after extraction; images/masks for SegmentationDataset must be prepared separately.")

    Path(voc2010).mkdir(parents=True, exist_ok=True)
    with open(success, "w") as f:
        f.write("Done.\n")
    return root_dir

def fetch_ade20k_via_hf(root_dir="datasets/ADE20K_HF"):
    """Export ADE20K (scene_parse_150) from the Hugging Face hub to an on-disk layout.

    Writes ``images/{training,validation}/*.jpg`` and
    ``annotations/{training,validation}/*.png`` under ``root_dir``. Files are
    named ``train_NNNNNN`` / ``val_NNNNNN``, so image and mask share a stem.
    ``_SUCCESS.txt`` marks completion; when present, nothing is done and the
    (heavy) ``datasets`` library is not even imported.

    Images are converted to RGB before saving, because JPEG cannot store every
    PIL mode (e.g. RGBA / palette).
    NOTE(review): newer ``datasets`` releases may require
    ``trust_remote_code=True`` (or a parquet revision) for this repo — confirm.

    Returns:
        root_dir
    """
    success_file = os.path.join(root_dir, "_SUCCESS.txt")
    if os.path.exists(success_file):
        print("[ADE20K] Found => skipping download.")
        return root_dir

    import datasets

    os.makedirs(root_dir, exist_ok=True)
    print("[ADE20K] Downloading from huggingface.co 'zhoubolei/scene_parse_150'")
    ds_all = datasets.load_dataset("zhoubolei/scene_parse_150", split=None)

    # (HF split name, on-disk subdirectory, file-name prefix)
    for hf_split, subdir, prefix in (("train", "training", "train"),
                                     ("validation", "validation", "val")):
        img_dir = os.path.join(root_dir, "images", subdir)
        msk_dir = os.path.join(root_dir, "annotations", subdir)
        os.makedirs(img_dir, exist_ok=True)
        os.makedirs(msk_dir, exist_ok=True)
        for i, sample in enumerate(tqdm(ds_all[hf_split], desc=f"[ADE20K] {hf_split}")):
            sample["image"].convert("RGB").save(os.path.join(img_dir, f"{prefix}_{i:06d}.jpg"))
            sample["annotation"].save(os.path.join(msk_dir, f"{prefix}_{i:06d}.png"))

    with open(success_file, "w") as f:
        f.write("Done\n")
    return root_dir


########################################
# 2) PyTorch Dataset & Transforms (CPU-based or fallback)
########################################
def build_transforms(dataset_name="ade20k", train=True, advanced_aug=False,
                     normalization="none", crop_size=512):
    """Build the torchvision pipeline applied to input images.

    Args:
        dataset_name: accepted for API symmetry; not used to choose transforms.
        train: random resized crop + horizontal flip if True; otherwise a
            deterministic resize + center crop.
        advanced_aug: additionally apply ColorJitter and RandomGrayscale (train only).
        normalization: "vit" (mean = std = 0.5), "imagenet" (ImageNet
            mean/std), or anything else / None for no normalization
            (ToTensor's [0, 1] range). Case-insensitive.
        crop_size: output side length in pixels (512 by default, the former
            hard-coded value).

    Returns:
        torchvision.transforms.Compose mapping a PIL image to a float tensor
        of shape (C, crop_size, crop_size).

    NOTE: every step acts on a single image. Geometric steps
    (RandomResizedCrop, RandomHorizontalFlip, Resize, CenterCrop) must be
    mirrored on the segmentation mask by the caller, or image and mask will
    no longer be aligned.
    """
    t_list = []
    if train:
        # NOTE(review): RandomResizedCrop's `scale` is a fraction of the source
        # area, so draws above 1.0 cannot be realized and are rejected/retried.
        # Kept as-is to preserve the existing augmentation distribution.
        t_list.append(T.RandomResizedCrop(crop_size, scale=(0.5, 2.0)))
        t_list.append(T.RandomHorizontalFlip(0.5))
        if advanced_aug:
            t_list.append(T.ColorJitter(0.4, 0.4, 0.4, 0.2))
            t_list.append(T.RandomGrayscale(p=0.1))
    else:
        # Resize the short side, then center-crop to a fixed square.
        t_list.append(T.Resize(crop_size))
        t_list.append(T.CenterCrop(crop_size))

    t_list.append(T.ToTensor())

    norm = (normalization or "none").lower()
    if norm == "vit":
        t_list.append(T.Normalize(mean=[0.5] * 3, std=[0.5] * 3))
    elif norm == "imagenet":
        t_list.append(T.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]))
    return T.Compose(t_list)


class SegmentationDataset(Dataset):
    """
    Minimal segmentation dataset for the layouts used in this script
    ("toy", "cityscapes", "voc2007", "ade20k", "pascalcontext").

      - Indexes (image path, mask path) pairs for ``split``.
      - Can decode everything up front and keep it in memory (``cache_data``).
        Unreadable pairs are dropped from the index too, so ``len()`` and
        indexing stay consistent with the cache.
      - Applies ``transform`` to the image. If ``transform`` is a
        ``torchvision.transforms.Compose``, its RandomResizedCrop,
        RandomHorizontalFlip, Resize and CenterCrop steps are replayed on the
        mask with the same random parameters and nearest-neighbour
        resampling, so image and mask stay pixel-aligned. All other steps
        (colour jitter, ToTensor, Normalize, ...) touch the image only.
      - Without ``target_transform``, the mask is returned as an int64 tensor
        with 255 as the ignore label: ADE20K 0 -> 255 and k -> k-1;
        Cityscapes labelIds -> the 19 trainIds.
    """

    # Cityscapes labelId -> trainId for the 19 evaluated classes; all other ids map to 255.
    _CITYSCAPES_LABEL_TO_TRAIN = {
        7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9,
        23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18,
    }

    def __init__(self, dataset_name, split="train", root_dir=None,
                 transform=None, target_transform=None, num_samples=50,
                 cache_data=False):
        super().__init__()
        self.dataset_name = dataset_name.lower()
        self.split = split.lower()
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        self.num_samples = num_samples
        self.cache_data = cache_data

        self.img_paths = []
        self.mask_paths = []

        if self.dataset_name == "toy":
            self.samples = [None] * num_samples
        elif self.dataset_name == "cityscapes":
            self._index_cityscapes(root_dir)
        elif self.dataset_name == "voc2007":
            self._index_from_list(os.path.join(root_dir, "VOCdevkit", "VOC2007"),
                                  "Segmentation", "SegmentationClass")
        elif self.dataset_name == "ade20k":
            self._index_ade20k(root_dir)
        elif self.dataset_name == "pascalcontext":
            self._index_from_list(os.path.join(root_dir, "VOCdevkit", "VOC2010"),
                                  "SegmentationContext", "SegmentationClassContext")

        self.cached = False
        if self.cache_data and len(self.img_paths) > 0:
            self._build_cache(dataset_name)
            self.cached = True

    # ---------------- indexing helpers ----------------
    def _index_cityscapes(self, root_dir):
        """Collect {root}/leftImg8bit/{split}/{city}/*_leftImg8bit.png with their gtFine labelIds masks."""
        left_dir = os.path.join(root_dir, "leftImg8bit", self.split)
        gt_dir = os.path.join(root_dir, "gtFine", self.split)
        suffix = "_leftImg8bit.png"
        for city in sorted(os.listdir(left_dir)):
            city_img_path = os.path.join(left_dir, city)
            if not os.path.isdir(city_img_path):
                continue
            city_msk_path = os.path.join(gt_dir, city)
            for fname in sorted(os.listdir(city_img_path)):
                if fname.endswith(suffix):
                    base = fname[:-len(suffix)]
                    self.img_paths.append(os.path.join(city_img_path, fname))
                    self.mask_paths.append(os.path.join(city_msk_path, base + "_gtFine_labelIds.png"))

    def _index_from_list(self, voc_root, list_subdir, mask_subdir):
        """Collect VOC-style pairs listed in {voc_root}/ImageSets/{list_subdir}/{split}.txt."""
        list_file = os.path.join(voc_root, "ImageSets", list_subdir, f"{self.split}.txt")
        with open(list_file, "r") as f:
            ids = [line.strip() for line in f if line.strip()]
        for img_id in ids:
            self.img_paths.append(os.path.join(voc_root, "JPEGImages", img_id + ".jpg"))
            self.mask_paths.append(os.path.join(voc_root, mask_subdir, img_id + ".png"))

    def _index_ade20k(self, root_dir):
        """Pair images/{training,validation}/* with annotations/* by file stem (not by sort order)."""
        subdir = "training" if self.split == "train" else "validation"
        self.img_dir = os.path.join(root_dir, "images", subdir)
        self.msk_dir = os.path.join(root_dir, "annotations", subdir)
        masks_by_stem = {os.path.splitext(m)[0]: m for m in os.listdir(self.msk_dir)}
        skipped = 0
        for img_name in sorted(os.listdir(self.img_dir)):
            mask_name = masks_by_stem.get(os.path.splitext(img_name)[0])
            if mask_name is None:
                skipped += 1
                continue
            self.img_paths.append(os.path.join(self.img_dir, img_name))
            self.mask_paths.append(os.path.join(self.msk_dir, mask_name))
        if skipped:
            print(f"[ade20k] Warning: skipped {skipped} image(s) without a matching annotation.")

    # ---------------- loading helpers ----------------
    @staticmethod
    def _load_pair(img_path, mask_path):
        """Decode one (RGB image, mask) pair fully into memory and close both files."""
        from PIL import Image
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        with Image.open(mask_path) as m:
            # copy() forces a full decode; PNG files are not closed by load() alone.
            mask = m.copy()
        return img, mask

    def _build_cache(self, display_name):
        """Decode all pairs into self.cache, dropping (from cache AND index) those that fail."""
        print(f"[{display_name}] Caching {len(self.img_paths)} samples in memory...")
        cache, kept_imgs, kept_masks = [], [], []
        for img_path, mask_path in zip(self.img_paths, self.mask_paths):
            try:
                pair = self._load_pair(img_path, mask_path)
            except OSError as e:
                print(f"Warning: skipping corrupted image {img_path} => {e}")
                continue
            cache.append(pair)
            kept_imgs.append(img_path)
            kept_masks.append(mask_path)
        self.cache = cache
        self.img_paths = kept_imgs
        self.mask_paths = kept_masks

    # ---------------- Dataset API ----------------
    def __len__(self):
        if self.dataset_name == "toy":
            return len(self.samples)
        return len(self.img_paths)

    def __getitem__(self, idx):
        if self.dataset_name == "toy":
            img, mask = self._make_toy_sample()
        elif self.cached:
            img, mask = self.cache[idx]
        else:
            img, mask = self._load_pair(self.img_paths[idx], self.mask_paths[idx])

        if self.transform:
            img, mask = self._apply_transform(img, mask)
        else:
            img = T.ToTensor()(img)

        if self.target_transform:
            return img, self.target_transform(mask)
        return img, self._mask_to_tensor(mask, img)

    @staticmethod
    def _make_toy_sample():
        """Draw one random rectangle/ellipse (class 1..3) on a 128x128 black image and its mask.

        Uses Python's `random`, which PyTorch re-seeds in every DataLoader
        worker; numpy's global RNG is not, so it would give identical samples
        across workers.
        """
        from PIL import Image, ImageDraw
        import random
        img = Image.new("RGB", (128, 128), (0, 0, 0))
        mask = Image.new("L", (128, 128), 0)
        shape = random.choice(["rectangle", "ellipse"])
        col = tuple(random.randint(0, 255) for _ in range(3))
        label = random.randint(1, 3)
        x1, y1 = random.randint(0, 63), random.randint(0, 63)
        box = [x1, y1, x1 + random.randint(20, 63), y1 + random.randint(20, 63)]
        img_draw, mask_draw = ImageDraw.Draw(img), ImageDraw.Draw(mask)
        if shape == "rectangle":
            img_draw.rectangle(box, fill=col)
            mask_draw.rectangle(box, fill=label)
        else:
            img_draw.ellipse(box, fill=col)
            mask_draw.ellipse(box, fill=label)
        return img, mask

    def _apply_transform(self, img, mask):
        """Apply self.transform to img, replaying its geometric steps on mask (nearest interp).

        Non-Compose transforms are applied to the image only; the mask is
        then just resized to the image size in _mask_to_tensor.
        """
        steps = getattr(self.transform, "transforms", None)
        if steps is None:
            return self.transform(img), mask
        TF = T.functional
        nearest = T.InterpolationMode.NEAREST
        for t in steps:
            if isinstance(t, T.RandomResizedCrop):
                top, left, h, w = t.get_params(img, t.scale, t.ratio)
                img = TF.resized_crop(img, top, left, h, w, t.size, t.interpolation)
                mask = TF.resized_crop(mask, top, left, h, w, t.size, nearest)
            elif isinstance(t, T.RandomHorizontalFlip):
                if torch.rand(1).item() < t.p:
                    img = TF.hflip(img)
                    mask = TF.hflip(mask)
            elif isinstance(t, T.Resize):
                img = t(img)
                mask = TF.resize(mask, t.size, interpolation=nearest, max_size=t.max_size)
            elif isinstance(t, T.CenterCrop):
                img = t(img)
                mask = TF.center_crop(mask, t.size)
            else:
                img = t(img)
        return img, mask

    def _mask_to_tensor(self, mask, img):
        """Convert a PIL mask to an int64 (H, W) tensor matching img's spatial size, with 255 = ignore."""
        from PIL import Image
        # No-op when geometry was already replayed; otherwise keeps mask dims == image dims.
        mask = mask.resize((img.shape[2], img.shape[1]), resample=Image.NEAREST)
        mask_np = np.array(mask, dtype=np.int64)
        if self.dataset_name == "ade20k":
            # label 0 => ignore, 1..150 => 0..149
            mask_np = np.where(mask_np == 0, 255, mask_np - 1)
        elif self.dataset_name == "cityscapes":
            mask_np = self._cityscapes_label_to_train_id(mask_np)
        return torch.from_numpy(mask_np)

    @staticmethod
    def _cityscapes_label_to_train_id(mask_np):
        """Map a Cityscapes labelIds array (values 0..255) to trainIds (0..18, 255 = ignore)."""
        lut = np.full(256, 255, dtype=np.int64)
        for label_id, train_id in SegmentationDataset._CITYSCAPES_LABEL_TO_TRAIN.items():
            lut[label_id] = train_id
        return lut[mask_np]

########################################
# 3) Optional Mixup / CutMix for segmentation
########################################
def mixup_segmentation_naive(batch, alpha=0.2):
    """Naive Mixup for segmentation.

    Images are blended as ``lam * x + (1 - lam) * x[perm]`` with
    ``lam ~ Beta(alpha, alpha)``. Labels cannot be blended numerically, so
    each pixel keeps its own label with probability ``lam`` and otherwise
    takes the partner's label.

    Args:
        batch: sequence of (image (C, H, W), mask (H, W)) tensor pairs.
        alpha: Beta-distribution parameter.

    Returns:
        (images (B, C, H, W), masks (B, H, W)). Batches of size < 2 are
        returned stacked but unmixed.
    """
    image_list, mask_list = zip(*batch)
    images = torch.stack(image_list, dim=0)
    labels = torch.stack(mask_list, dim=0)
    batch_size = images.size(0)
    if batch_size < 2:
        return images, labels

    perm = torch.randperm(batch_size)
    weight = float(np.random.beta(alpha, alpha))
    blended = weight * images + (1 - weight) * images[perm]

    keep_own = torch.rand_like(labels.float()) < weight
    return blended, torch.where(keep_own, labels, labels[perm])

def cutmix_segmentation_naive(batch, alpha=0.2):
    """CutMix for segmentation: paste one random rectangle from a partner sample.

    A box covering roughly ``1 - lam`` of the area (``lam ~ Beta(alpha, alpha)``),
    centred at a uniformly random pixel and clipped to the image, is copied
    from ``images[perm]`` / ``masks[perm]`` into every sample, so image and
    mask stay consistent.

    Args:
        batch: sequence of (image (C, H, W), mask (H, W)) tensor pairs.
        alpha: Beta-distribution parameter.

    Returns:
        (images (B, C, H, W), masks (B, H, W)). Batches of size < 2 are
        returned stacked but unmixed.
    """
    image_list, mask_list = zip(*batch)
    images = torch.stack(image_list, 0)
    labels = torch.stack(mask_list, 0)
    n, _, height, width = images.shape
    if n < 2:
        return images, labels

    perm = torch.randperm(n)
    lam = float(np.random.beta(alpha, alpha))
    cut_frac = np.sqrt(1 - lam)
    half_w = int(width * cut_frac) // 2
    half_h = int(height * cut_frac) // 2
    cx = np.random.randint(0, width)
    cy = np.random.randint(0, height)

    left, right = np.clip(cx - half_w, 0, width), np.clip(cx + half_w, 0, width)
    top, bottom = np.clip(cy - half_h, 0, height), np.clip(cy + half_h, 0, height)

    images[:, :, top:bottom, left:right] = images[perm, :, top:bottom, left:right]
    labels[:, top:bottom, left:right] = labels[perm, top:bottom, left:right]
    return images, labels

########################################
# 4) (Optional) NVIDIA DALI pipeline for GPU-based data loading
########################################
# NVIDIA DALI is optional. If any part of it fails to import, every DALI name
# is set to None and DALI_AVAILABLE is False, so callers can fall back to the
# CPU DataLoader path (check `Pipeline is None` or `DALI_AVAILABLE`).
try:
    from nvidia.dali.pipeline import Pipeline
    import nvidia.dali.fn as fn
    import nvidia.dali.types as types
    from nvidia.dali.plugin.pytorch import DALIGenericIterator
    DALI_AVAILABLE = True
except ImportError:
    Pipeline = fn = types = DALIGenericIterator = None
    DALI_AVAILABLE = False

class SegmentationTrainPipeline(Pipeline if Pipeline is not None else object):
    """DALI training pipeline: decode -> square resize -> random h-flip -> brightness/contrast.

    Images come from ``file_root`` and masks from ``mask_root``, via two
    ``fn.readers.file`` readers that share ``seed`` and ``random_shuffle``.
    Image/mask pairing therefore relies on both readers yielding files in the
    same order. NOTE(review): confirm this holds for your DALI version and
    directory layout.

    Outputs per sample: the image decoded on the "mixed" (GPU) device, and the
    GRAY-decoded mask on CPU. Both are resized to ``crop_size`` x ``crop_size``
    (nearest-neighbour for the mask) and flipped together. The image
    additionally gets random brightness/contrast in [0.8, 1.2].
    NOTE(review): no normalization or layout transpose is applied here
    (presumably HWC uint8 from the decoder) — the consumer must convert
    before feeding the model.

    The class is defined even when DALI is missing (``Pipeline is None``),
    so the rest of the script still loads; instantiating it then raises
    ImportError instead of failing at class-definition time.
    """
    def __init__(self, batch_size, num_threads, device_id,
                 file_root, mask_root, crop_size=512,
                 random_shuffle=True, seed=12):
        if Pipeline is None:
            raise ImportError("NVIDIA DALI is not installed; SegmentationTrainPipeline is unavailable.")
        super().__init__(batch_size, num_threads, device_id, seed=seed)
        self.input = fn.readers.file(file_root=file_root, random_shuffle=random_shuffle, seed=seed, name="Reader")
        self.mask_input = fn.readers.file(file_root=mask_root, random_shuffle=random_shuffle, seed=seed, name="ReaderMask")
        self.crop_size = crop_size

    def define_graph(self):
        jpegs, _ = self.input()
        masks, _ = self.mask_input()
        images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
        masks = fn.decoders.image(masks, device="cpu", output_type=types.GRAY)
        images = fn.resize(images, resize_x=self.crop_size, resize_y=self.crop_size)
        # Nearest-neighbour keeps mask values as valid class ids.
        masks = fn.resize(masks, resize_x=self.crop_size, resize_y=self.crop_size, interp_type=types.INTERP_NN)
        # One coin per sample, shared by image and mask so they stay aligned.
        flip_coin = fn.random.coin_flip(probability=0.5)
        images = fn.flip(images, horizontal=flip_coin)
        masks = fn.flip(masks, horizontal=flip_coin)
        images = fn.brightness_contrast(images,
                                        brightness=fn.random.uniform(range=[0.8, 1.2]),
                                        contrast=fn.random.uniform(range=[0.8, 1.2]))
        return images, masks

class SegmentationValPipeline(Pipeline if Pipeline is not None else object):
    """DALI validation pipeline: decode and resize image/mask to ``crop_size`` x ``crop_size``.

    No augmentation; files are read in order unless ``random_shuffle`` is set.
    The image is decoded on the "mixed" (GPU) device, and the mask as GRAY on
    CPU and resized with nearest-neighbour interpolation.
    NOTE(review): outputs are not normalized or transposed (presumably HWC
    uint8 images) — the consumer must convert before feeding the model.

    The class is defined even when DALI is missing (``Pipeline is None``);
    instantiating it then raises ImportError.
    """
    def __init__(self, batch_size, num_threads, device_id,
                 file_root, mask_root, crop_size=512, random_shuffle=False):
        if Pipeline is None:
            raise ImportError("NVIDIA DALI is not installed; SegmentationValPipeline is unavailable.")
        super().__init__(batch_size, num_threads, device_id)
        self.input = fn.readers.file(file_root=file_root, random_shuffle=random_shuffle, name="Reader")
        self.mask_input = fn.readers.file(file_root=mask_root, random_shuffle=random_shuffle, name="ReaderMask")
        self.crop_size = crop_size

    def define_graph(self):
        jpegs, _ = self.input()
        masks, _ = self.mask_input()
        images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
        masks = fn.decoders.image(masks, device="cpu", output_type=types.GRAY)
        images = fn.resize(images, resize_x=self.crop_size, resize_y=self.crop_size)
        masks = fn.resize(masks, resize_x=self.crop_size, resize_y=self.crop_size, interp_type=types.INTERP_NN)
        return images, masks


########################################
# 5) Segmenter Model (ViT backbone, plus linear or mask-transformer decoder)
########################################
class MaskTransformerDecoder(nn.Module):
    """
    Minimal Segmenter-style mask transformer.

      - One learnable embedding per class (K classes).
      - ``num_layers`` transformer-encoder layers run over the concatenation
        [patch tokens; class embeddings].
      - Mask logits = L2-normalised patch features · L2-normalised class features.
      - By default, followed by a LayerNorm over the class dimension
        (``mask_norm``, mirroring the reference Segmenter implementation).

    Why ``mask_norm``: dot products of unit vectors are cosine similarities
    bounded in [-1, 1]. Those are too flat to act as softmax logits for
    cross-entropy, and the class-wise LayerNorm rescales them. Pass
    ``use_mask_norm=False`` for the previous bounded logits (e.g. to load a
    checkpoint saved without ``mask_norm`` parameters).

    Args:
        emb_dim: token width E (must be divisible by ``num_heads``).
        num_layers: number of transformer-encoder layers.
        num_heads: attention heads per layer.
        num_classes: K.
        dropout: dropout inside the encoder layers.
        drop_path_rate: accepted for interface compatibility; currently unused.
        use_mask_norm: apply LayerNorm(K) to the output logits.

    Input: (B, N, E) patch tokens. Output: (B, N, K) logits.
    """
    def __init__(self, emb_dim, num_layers, num_heads, num_classes, dropout=0.1,
                 drop_path_rate=0.0, use_mask_norm=True):
        super().__init__()
        self.num_classes = num_classes
        self.class_emb = nn.Parameter(torch.randn(num_classes, emb_dim))
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=emb_dim,
                nhead=num_heads,
                dim_feedforward=4 * emb_dim,
                dropout=dropout,
                activation='gelu',
                batch_first=False
            )
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(emb_dim)
        self.mask_norm = nn.LayerNorm(num_classes) if use_mask_norm else None

    def forward(self, patch_tokens):
        B, N, E = patch_tokens.shape
        K = self.num_classes
        class_embs = self.class_emb.unsqueeze(0).expand(B, K, E)

        # Encoder layers are seq-first: (N+K, B, E).
        x = torch.cat([patch_tokens, class_embs], dim=1).transpose(0, 1)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x).transpose(0, 1)  # back to (B, N+K, E)

        patch_enc = F.normalize(x[:, :N, :], dim=-1)   # (B, N, E)
        class_enc = F.normalize(x[:, N:, :], dim=-1)   # (B, K, E)
        logits = torch.bmm(patch_enc, class_enc.transpose(1, 2))  # (B, N, K)
        if self.mask_norm is not None:
            logits = self.mask_norm(logits)
        return logits

class SegmenterModel(nn.Module):
    """
    Segmenter: a timm Vision Transformer backbone plus a 'linear' or 'mask' decoder.

    Args:
        backbone_name: timm model name (created with pretrained weights).
        num_classes: number of output classes K.
        drop_path_rate: stochastic-depth rate passed to timm.
        decoder: "linear" (1x1 conv on the patch grid) or anything else for the
            mask transformer.
        img_size: square input size given to timm (position embeddings are
            sized for it).
        backbone_dropout: probability set on every ``nn.Dropout`` inside the
            backbone (including attention dropout).
        decoder_drop_path_rate: forwarded to MaskTransformerDecoder.
        distilled, version: stored for reference; prefix tokens (CLS /
            distillation / registers) are detected automatically.

    forward(x): x of shape (B, C, H, W), with H and W multiples of the patch
    size -> logits of shape (B, K, H, W).
    """
    def __init__(self,
                 backbone_name="vit_base_patch16_384",
                 num_classes=150,
                 drop_path_rate=0.1,
                 decoder="mask",
                 img_size=512,
                 backbone_dropout=0.0,
                 decoder_drop_path_rate=0.0,
                 distilled=False,
                 version="normal"):
        super().__init__()
        self.num_classes = num_classes
        self.decoder_type = decoder
        self.distilled = distilled
        self.version = version

        self.backbone = timm.create_model(
            backbone_name, pretrained=True,
            drop_path_rate=drop_path_rate,
            img_size=(img_size, img_size),
        )
        # timm builds its nn.Dropout modules at construction, so changing the
        # `drop_rate` attribute afterwards does not reach them. Set each
        # module's `p` directly instead.
        if hasattr(self.backbone, "drop_rate"):
            self.backbone.drop_rate = backbone_dropout
        for module in self.backbone.modules():
            if isinstance(module, nn.Dropout):
                module.p = backbone_dropout
        # Remove the classification head.
        if hasattr(self.backbone, "head"):
            self.backbone.head = nn.Identity()
        if hasattr(self.backbone, "classifier"):
            self.backbone.classifier = nn.Identity()

        # Make self.backbone(x) return the token sequence instead of class scores.
        self.backbone.forward = self.backbone.forward_features

        emb_dim = getattr(self.backbone, 'num_features', None)
        if emb_dim is None:
            # fallback for certain older timm ViT definitions
            emb_dim = self.backbone.embed_dim

        if decoder == "linear":
            # A 1x1 conv on the patch grid, upsampled to full resolution in forward().
            self.seg_head = nn.Conv2d(emb_dim, num_classes, kernel_size=1)
        else:
            self.mask_transformer = MaskTransformerDecoder(
                emb_dim, num_layers=2, num_heads=self._decoder_num_heads(emb_dim),
                num_classes=num_classes, dropout=0.1,
                drop_path_rate=decoder_drop_path_rate
            )

    @staticmethod
    def _decoder_num_heads(emb_dim, preferred=12):
        """Return `preferred` heads if it divides emb_dim, else the largest divisor <= emb_dim // 64."""
        if emb_dim % preferred == 0:
            return preferred
        heads = max(1, emb_dim // 64)
        while emb_dim % heads:
            heads -= 1
        return heads

    def _backbone_patch_size(self):
        """Return (patch_h, patch_w) from the backbone's patch embedding; 16x16 if not exposed."""
        patch_embed = getattr(self.backbone, "patch_embed", None)
        size = getattr(patch_embed, "patch_size", 16)
        if isinstance(size, int):
            return size, size
        return int(size[0]), int(size[1])

    def forward(self, x):
        B, C, H, W = x.shape
        out = self.backbone(x)  # presumably (B, prefix + N, E) for timm ViTs

        patch_h, patch_w = self._backbone_patch_size()
        grid_h, grid_w = H // patch_h, W // patch_w
        n_patches = grid_h * grid_w
        # Prefix tokens (CLS, distillation, registers) precede the patch tokens.
        n_prefix = out.shape[1] - n_patches
        expected_prefix = getattr(self.backbone, "num_prefix_tokens", None)
        if n_prefix < 0 or (expected_prefix is not None and n_prefix != expected_prefix):
            raise ValueError("Token count does not match the patch grid => check your image/patch size alignment.")

        tokens = out[:, n_prefix:, :]  # (B, N, E)
        E = tokens.shape[-1]

        if self.decoder_type == "linear":
            patch_grid = tokens.transpose(1, 2).reshape(B, E, grid_h, grid_w)
            logits_patch = self.seg_head(patch_grid)
        else:
            logits_patch = self.mask_transformer(tokens)  # (B, N, K)
            logits_patch = logits_patch.transpose(1, 2).reshape(B, self.num_classes, grid_h, grid_w)
        return F.interpolate(logits_patch, size=(H, W), mode='bilinear', align_corners=False)


########################################
# 6) Optimizer, Scheduler, Evaluate
########################################
def poly_lr_scheduler(optimizer, base_lr, curr_iter, max_iter, power=0.9,
                      iter_warmup=0.0, min_lr=0.0):
    """
    Poly LR schedule with optional linear warmup; writes the LR into every param group.

    During warmup (0 < curr_iter < iter_warmup):  lr = base_lr * curr_iter / iter_warmup
    Afterwards:                                   lr = base_lr * (1 - t) ** power,
        with t = (curr_iter - iter_warmup) / (max_iter - iter_warmup) clamped to [0, 1].
    The final value is floored at ``min_lr``.

    Args:
        optimizer: object with ``param_groups`` (list of dicts with an 'lr' key).
        base_lr: peak learning rate.
        curr_iter: current iteration index.
        max_iter: iteration at which the poly decay reaches zero.
        power: poly exponent.
        iter_warmup: number of linear-warmup iterations (<= 0 disables warmup).
        min_lr: lower bound on the returned learning rate.

    Returns:
        The learning rate that was applied (float).
    """
    if iter_warmup > 0 and curr_iter < iter_warmup:
        lr = base_lr * (curr_iter / iter_warmup)
    else:
        warmup = iter_warmup if iter_warmup > 0 else 0.0
        span = max_iter - warmup
        # Clamp progress: past max_iter a negative base would make `** power` complex,
        # and a non-positive span (warmup >= max_iter) would divide by zero.
        if span <= 0:
            progress = 1.0
        else:
            progress = min(max((curr_iter - warmup) / span, 0.0), 1.0)
        lr = base_lr * (1.0 - progress) ** power

    lr = max(lr, min_lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def evaluate(model, loader, num_classes=150, ignore_index=255, device='cuda',
             im_size=512, window_size=512, window_stride=512):
    """
    Compute pixel accuracy and mean IoU of ``model`` over ``loader``.

    Each batch is evaluated with a single full-image forward pass. ``im_size``,
    ``window_size`` and ``window_stride`` are accepted for interface compatibility
    but are currently unused (no sliding-window inference happens here).

    Args:
        model: module mapping images (B, C, H, W) to logits (B, num_classes, H, W).
        loader: iterable yielding ``(imgs, masks)`` tuples, or DALI-style batches
            ``[{"img": ..., "mask": ...}]``.
        num_classes: number of classes K. Pixels whose label is outside [0, K)
            are skipped, like ``ignore_index`` pixels.
        ignore_index: label value excluded from all metrics.
        device: device on which inference runs.

    Returns:
        (pix_acc, mIoU): accuracy over scored pixels, and mean IoU over the
        classes that occur in the ground truth or the predictions (0 if none).

    Side effect: the model is put back into train() mode before returning.
    """
    model.eval()
    total_correct = 0
    total_pixels = 0
    # Flattened K*K confusion matrix (row = ground truth, col = prediction),
    # accumulated on `device` and copied to numpy once at the end.
    conf_mat = torch.zeros(num_classes * num_classes, dtype=torch.int64, device=device)

    with torch.no_grad():
        for batch in loader:
            # DALI iterators yield a list holding a dict; DataLoaders yield (imgs, masks).
            if isinstance(batch, (list, tuple)) and isinstance(batch[0], dict):
                imgs, masks = batch[0]["img"], batch[0]["mask"]
            else:
                imgs, masks = batch

            imgs = imgs.to(device)
            masks = masks.to(device)
            # Drop a singleton channel dim if present (presumably from DALI-decoded masks).
            if masks.dim() == 4:
                masks = masks.squeeze(1) if masks.shape[1] == 1 else masks.squeeze(-1)
            # int64 labels: uint8 masks would overflow in `label * num_classes` below.
            masks = masks.long()

            preds = model(imgs).argmax(dim=1)  # (B, H, W)

            valid = (masks >= 0) & (masks < num_classes) & (masks != ignore_index)
            gt = masks[valid]
            dt = preds[valid]
            total_correct += (gt == dt).sum().item()
            total_pixels += gt.numel()
            conf_mat += torch.bincount(gt * num_classes + dt,
                                       minlength=num_classes * num_classes)

    conf = conf_mat.view(num_classes, num_classes).cpu().numpy()
    pix_acc = total_correct / (total_pixels + 1e-10)

    inter = np.diag(conf)
    union = conf.sum(axis=1) + conf.sum(axis=0) - inter
    present = union > 0  # classes seen in GT or predictions
    mIoU = np.mean(inter[present] / union[present]) if present.any() else 0

    model.train()
    return pix_acc, mIoU



########################################
# 7) Main Training Loop
########################################
def _parse_args():
    """
    Parse the training command line.

    Jupyter/Colab kernels append ``-f <connection>.json`` to ``sys.argv``; those
    entries are filtered out first so ``argparse`` does not reject them.
    Cheap-to-check invalid combinations are rejected here, before any expensive setup.
    """
    import argparse
    sys.argv = [arg for arg in sys.argv if not (arg.startswith('-f') or arg.endswith('.json'))]

    parser = argparse.ArgumentParser()
    # Data & training schedule
    parser.add_argument("--dataset_name", default="ade20k",
                        choices=["toy","ade20k","voc2007","cityscapes","pascalcontext"])
    parser.add_argument("--num_classes", default=150, type=int)
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--grad_accum_steps", default=1, type=int)
    parser.add_argument("--num_epochs", default=64, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument("--eval_freq", default=2, type=int)
    parser.add_argument("--total_iterations", default=160000, type=int)  # e.g. 161728
    parser.add_argument("--power", default=0.9, type=float)
    parser.add_argument("--iter_warmup", default=0.0, type=float)
    parser.add_argument("--min_lr", default=0.0, type=float)
    parser.add_argument("--clip_grad", default=0.0, type=float)
    # Optim
    parser.add_argument("--lr", default=0.001, type=float)
    parser.add_argument("--weight_decay", default=0.0, type=float)
    parser.add_argument("--momentum", default=0.9, type=float)
    # Model
    parser.add_argument("--model_name", default="vit_base_patch16_384")
    parser.add_argument("--stoch_depth", default=0.1, type=float)
    parser.add_argument("--backbone_dropout", default=0.0, type=float)
    parser.add_argument("--decoder", default="mask", choices=["linear","mask"])
    parser.add_argument("--decoder_drop_path_rate", default=0.0, type=float)
    parser.add_argument("--distilled", action="store_true")
    parser.add_argument("--version", default="normal")
    parser.add_argument("--backbone_img_size", default=512, type=int)
    # Augment
    parser.add_argument("--advanced_aug", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--dataset_normalization", default="vit")
    parser.add_argument("--cache_data", action="store_true")
    # `--resume` stays accepted; `--no-resume` can now actually turn resuming off.
    parser.add_argument("--resume", action=argparse.BooleanOptionalAction, default=True,
                        help="Resume from --checkpoint_path when it exists (disable with --no-resume).")
    parser.add_argument("--checkpoint_path",
                        default="/content/drive/MyDrive/your_folder/checkpoint.pth",
                        help="Where the best-mIoU checkpoint is written and resumed from.")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--ddp", action="store_true", default=False)
    parser.add_argument("--use_dali", action="store_true", default=False)

    # Inference
    parser.add_argument("--inference_im_size", default=512, type=int)
    parser.add_argument("--window_size", default=512, type=int)
    parser.add_argument("--window_stride", default=512, type=int)

    args = parser.parse_args()
    if args.mixup and args.cutmix:
        raise ValueError("Pick either --mixup or --cutmix, not both.")
    if args.grad_accum_steps < 1:
        parser.error("--grad_accum_steps must be >= 1")
    return args


def _strip_wrapper_prefixes(state_dict):
    """Remove leading ``_orig_mod.`` (torch.compile) / ``module.`` (DDP) key prefixes."""
    cleaned = {}
    for key, value in state_dict.items():
        while True:
            if key.startswith("_orig_mod."):
                key = key[len("_orig_mod."):]
            elif key.startswith("module."):
                key = key[len("module."):]
            else:
                break
        cleaned[key] = value
    return cleaned


def _build_dataloaders(args, dsname, root, device, num_workers):
    """Create ``(train_loader, val_loader)`` backed by NVIDIA DALI or torch DataLoaders."""
    if args.use_dali:
        if Pipeline is None:
            raise ImportError("Please install nvidia-dali to use --use_dali")
        if root is None:
            raise ValueError("--use_dali requires an on-disk dataset (not supported for 'toy').")
        device_id = device.index if device.type == 'cuda' else 0
        # Hard-coded for e.g. ADE20K layout
        train_file_root = os.path.join(root, "images", "training")
        train_mask_root = os.path.join(root, "annotations", "training")
        val_file_root   = os.path.join(root, "images", "validation")
        val_mask_root   = os.path.join(root, "annotations", "validation")

        train_pipeline = SegmentationTrainPipeline(
            batch_size=args.batch_size,
            num_threads=num_workers,
            device_id=device_id,
            file_root=train_file_root,
            mask_root=train_mask_root,
            crop_size=args.backbone_img_size,
            random_shuffle=True)
        train_pipeline.build()
        train_loader = DALIGenericIterator(
            train_pipeline, output_map=["img","mask"],
            size=train_pipeline.epoch_size("Reader"), auto_reset=True)

        val_pipeline = SegmentationValPipeline(
            batch_size=1,
            num_threads=num_workers,
            device_id=device_id,
            file_root=val_file_root,
            mask_root=val_mask_root,
            crop_size=args.inference_im_size,
            random_shuffle=False)
        val_pipeline.build()
        val_loader = DALIGenericIterator(
            val_pipeline, output_map=["img","mask"],
            size=val_pipeline.epoch_size("Reader"),
            auto_reset=True)
        return train_loader, val_loader

    # Standard PyTorch datasets
    train_tf = build_transforms(dsname, train=True,
                                advanced_aug=args.advanced_aug,
                                normalization=args.dataset_normalization)
    val_tf   = build_transforms(dsname, train=False,
                                advanced_aug=False,
                                normalization=args.dataset_normalization)

    if dsname == "toy":
        train_ds = SegmentationDataset("toy", "train", transform=train_tf,
                                       cache_data=args.cache_data)
        val_ds   = SegmentationDataset("toy", "val", transform=val_tf,
                                       cache_data=args.cache_data)
    else:
        train_ds = SegmentationDataset(dsname, "train", root, transform=train_tf,
                                       cache_data=args.cache_data)
        val_ds   = SegmentationDataset(dsname, "val", root, transform=val_tf,
                                       cache_data=args.cache_data)

    # Only training data is sharded under DDP: rank 0 evaluates the *whole* validation set.
    train_sampler = (torch.utils.data.distributed.DistributedSampler(train_ds)
                     if args.ddp else None)

    # persistent_workers / prefetch_factor are only valid when worker processes exist.
    worker_kwargs = {"num_workers": num_workers, "pin_memory": True}
    if num_workers > 0:
        worker_kwargs.update(persistent_workers=True, prefetch_factor=4)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size,
                              shuffle=(train_sampler is None),
                              sampler=train_sampler, drop_last=True, **worker_kwargs)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                            drop_last=False, **worker_kwargs)
    return train_loader, val_loader


def main():
    """
    Train a Segmenter model end to end.

    Parses CLI args, mounts Google Drive when the checkpoint lives there (Colab only),
    sets up optional DDP, fetches the dataset, builds loaders and the model, optionally
    resumes from ``--checkpoint_path``, then runs the epoch loop (poly LR, AMP, gradient
    accumulation/clipping). Evaluation runs on rank 0; the checkpoint is rewritten
    whenever validation mIoU improves.
    """
    args = _parse_args()
    if args.debug:
        print("Running in debug mode...")

    checkpoint_path = args.checkpoint_path
    if checkpoint_path.startswith("/content/drive"):
        # Mount Google Drive only when checkpoints go there, and only if we are on Colab.
        try:
            from google.colab import drive
            drive.mount('/content/drive')
        except ImportError:
            print("google.colab is not available => not mounting Google Drive.")
    # Create the folder up front; otherwise torch.save fails at the first new best mIoU.
    os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)

    print(f"Using batch size {args.batch_size} and grad_accum={args.grad_accum_steps}.")

    ########################
    # DDP setup
    ########################
    if args.ddp:
        import torch.distributed as dist
        dist.init_process_group(backend="nccl", init_method="env://")
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rank0 = (not args.ddp) or torch.distributed.get_rank() == 0

    if device.type == "cuda":
        gpu_props = torch.cuda.get_device_properties(device)
        gpu_mem_gb = gpu_props.total_memory / (1024 ** 3)
        print(f"Detected GPU: {gpu_props.name} with total mem {gpu_mem_gb:.1f} GB")

    # Auto-enable caching if enough system RAM (free_ram_gb is a module-level value).
    if not args.cache_data and free_ram_gb > 20:
        args.cache_data = True
        print("Automatically enabling data caching (cache_data=True) due to high available system RAM.")

    # Fetch dataset if needed
    dsname = args.dataset_name.lower()
    if dsname == "toy":
        root = None
    elif dsname == "ade20k":
        root = fetch_ade20k_via_hf("datasets/ADE20K_HF")
    elif dsname == "cityscapes":
        root = fetch_cityscapes_official("datasets/Cityscapes_official")
    elif dsname == "voc2007":
        root = fetch_voc2007_official("datasets/VOC2007_official")
    elif dsname == "pascalcontext":
        root = fetch_pascal_context_official("datasets/PascalContext_official")
    else:
        raise ValueError(f"Unknown dataset: {dsname}")

    num_workers = min(cpu_physical, 16)
    print(f"Setting DataLoader num_workers = {num_workers}")

    train_loader, val_loader = _build_dataloaders(args, dsname, root, device, num_workers)

    # Build model; keep a handle on the unwrapped module for checkpoints and rank-0 eval.
    base_model = SegmenterModel(
        backbone_name=args.model_name,
        num_classes=args.num_classes,
        drop_path_rate=args.stoch_depth,
        decoder=args.decoder,
        img_size=args.backbone_img_size,
        backbone_dropout=args.backbone_dropout,
        decoder_drop_path_rate=args.decoder_drop_path_rate,
        distilled=args.distilled,
        version=args.version
    ).to(device)
    if device.type == "cuda":
        base_model = base_model.to(memory_format=torch.channels_last)

    model = base_model
    if args.ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device.index] if device.type == 'cuda' else None
        )
    # Optional compile (PyTorch 2.0+)
    if hasattr(torch, 'compile'):
        model = torch.compile(model)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
    criterion = nn.CrossEntropyLoss(ignore_index=255)

    # Possibly resume
    global_step = 0
    best_mIoU = 0.0
    start_epoch = args.start_epoch
    if args.resume and os.path.isfile(checkpoint_path):
        if rank0:
            print(f"Resuming from {checkpoint_path} ...")
        ckpt = torch.load(checkpoint_path, map_location=device)
        # Accept checkpoints saved with or without compile/DDP wrapper prefixes.
        base_model.load_state_dict(_strip_wrapper_prefixes(ckpt["model_state_dict"]))
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if scaler is not None and "scaler_state_dict" in ckpt:
            scaler.load_state_dict(ckpt["scaler_state_dict"])
        global_step = ckpt.get("global_step", 0)
        # The stored epoch had already finished; continue with the next one.
        if "epoch" in ckpt:
            start_epoch = ckpt["epoch"] + 1
        best_mIoU = float(ckpt.get("best_mIoU", 0.0))

    if rank0:
        print(f"Start training for {args.num_epochs} epochs (total iterations: {args.total_iterations}).")

    for epoch in range(start_epoch, args.num_epochs):
        if args.ddp and (not args.use_dali):
            train_loader.sampler.set_epoch(epoch)

        model.train()
        epoch_loss = 0.0
        num_batches = 0

        iterator = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False) if rank0 else train_loader
        for batch in iterator:
            if args.use_dali:
                # DALI returns a list of dicts
                imgs = batch[0]["img"]
                masks = batch[0]["mask"]
            else:
                imgs, masks = batch

            imgs = imgs.to(device, memory_format=torch.channels_last)
            masks = masks.to(device)

            # Optional mixup/cutmix (mutually exclusive; validated in _parse_args)
            if args.mixup:
                imgs, masks = mixup_segmentation_naive(list(zip(imgs, masks)), alpha=0.2)
            if args.cutmix:
                imgs, masks = cutmix_segmentation_naive(list(zip(imgs, masks)), alpha=0.2)

            if not masks.is_floating_point():
                # CrossEntropyLoss needs int64 class indices of shape (B, H, W); masks may
                # arrive as uint8, presumably with a singleton channel dim from DALI.
                if masks.dim() == 4:
                    masks = masks.squeeze(1) if masks.shape[1] == 1 else masks.squeeze(-1)
                masks = masks.long()

            poly_lr_scheduler(optimizer, args.lr, global_step, args.total_iterations,
                              power=args.power, iter_warmup=args.iter_warmup, min_lr=args.min_lr)

            with torch.autocast(device_type=device.type, enabled=(scaler is not None)):
                logits = model(imgs)
                loss = criterion(logits, masks) / args.grad_accum_steps

            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            step_loss = loss.item() * args.grad_accum_steps
            epoch_loss += step_loss
            num_batches += 1
            global_step += 1

            if rank0:
                iterator.set_postfix({
                    'loss': f"{step_loss:.4f}",
                    'lr':   f"{optimizer.param_groups[0]['lr']:.6f}"
                })

            if global_step % args.grad_accum_steps == 0:
                if scaler is not None:
                    if args.clip_grad > 0:
                        # Gradients are loss-scaled; unscale before clipping to the true norm.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    if args.clip_grad > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
                    optimizer.step()
                optimizer.zero_grad()

            if global_step >= args.total_iterations:
                break

        avg_loss = epoch_loss / max(num_batches, 1)
        if rank0:
            print(f"Epoch {epoch} completed. Avg Loss: {avg_loss:.4f}")

        reached_limit = global_step >= args.total_iterations
        is_eval_epoch = ((args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0)
                         or epoch == args.num_epochs - 1
                         or reached_limit)
        if rank0 and is_eval_epoch:
            # Under DDP evaluate the unwrapped module so rank 0 issues no collectives.
            eval_model = base_model if args.ddp else model
            pixacc, miou = evaluate(eval_model, val_loader,
                                    num_classes=args.num_classes,
                                    ignore_index=255,
                                    device=device,
                                    im_size=args.inference_im_size,
                                    window_size=args.window_size,
                                    window_stride=args.window_stride)
            print(f"Evaluation at epoch {epoch}: pixAcc={pixacc:.4f}  mIoU={miou:.4f}")
            if miou > best_mIoU:
                best_mIoU = float(miou)
                print(f"New best mIoU={best_mIoU:.4f} => saving checkpoint.")
                state = {
                    "epoch": epoch,
                    "global_step": global_step,
                    # Unwrapped weights: keys do not depend on compile/DDP wrappers.
                    "model_state_dict": base_model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_mIoU": best_mIoU,
                }
                if scaler is not None:
                    state["scaler_state_dict"] = scaler.state_dict()
                torch.save(state, checkpoint_path)

        if args.ddp:
            # Keep the other ranks in step while rank 0 evaluates / saves.
            torch.distributed.barrier()

        if reached_limit:
            if rank0:
                print("Reached total iteration limit. Stopping.")
            break

    if rank0:
        print("Training complete!")
        print(f"Best mIoU => {best_mIoU:.4f}")
        print("Done.")

    if args.ddp:
        torch.distributed.destroy_process_group()

# Script entry point. Inside a Jupyter/Colab kernel __name__ is also "__main__",
# so executing this cell starts training right away.
if __name__ == "__main__":
    main()


Detected CPU: 12 logical cores, 6 physical cores.
Available system RAM: 36.9 GB
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using batch size 32 and grad_accum=1.
Detected GPU: NVIDIA A100-SXM4-40GB with total mem 39.6 GB
Automatically enabling data caching (cache_data=True) due to high available system RAM.
[ADE20K] Found => skipping download.
Setting DataLoader num_workers = 6
[ade20k] Caching 20210 samples in memory...
