<a href="https://colab.research.google.com/github/uma-mahesh-24/CS-254-Lab/blob/main/segformer_hf_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# %% [imports]

In [None]:
import os
import glob
import math
import json
import random
from typing import Optional, Dict, Any

from google.colab import drive

import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

import torch
import torch.nn.functional as F

from transformers import (
    SegformerConfig,
    SegformerForSemanticSegmentation,
    TrainingArguments,
    Trainer,
)

from transformers.models.segformer.modeling_segformer import SemanticSegmenterOutput

# NEW: Hugging Face Datasets for streaming / memory-mapped data
from datasets import Dataset, DatasetDict, load_dataset, Features, Array2D, Value, IterableDataset

# %% [mount drive]

In [None]:
# Mount Google Drive (force_remount=True remounts even if a mount already exists).
drive.mount("/content/drive", force_remount=True)

# ---- Paths (edit if your Drive layout differs) ----
root_dir = "/content/drive/MyDrive"
data_dir = os.path.join(root_dir, "corrected_dataset")

# Each sample is stored under the same .npy file name in all three folders.
inputs_dir, targets_dir, weights_dir = (
    os.path.join(data_dir, "inputs"),
    os.path.join(data_dir, "targets"),
    os.path.join(data_dir, "weights30_drastic"),
)

# Trainer checkpoints are written here (used as TrainingArguments.output_dir).
output_dir = os.path.join(root_dir, "train-xii-weighted-sse")
os.makedirs(output_dir, exist_ok=True)

print("Data root: " + data_dir)
print("Output dir: " + output_dir)

Mounted at /content/drive
Data root: /content/drive/MyDrive/corrected_dataset
Output dir: /content/drive/MyDrive/train-xii-weighted-sse


# %% [dataset, new with HF datasets module + random flips]

In [None]:
from typing import List

# NEW: Hugging Face Datasets for streaming / memory-mapped data
from datasets import Dataset, DatasetDict, load_dataset, Features, Array2D, Value, IterableDataset

# Define the data augmentation functions
def random_horizontal_flip(x, y, w):
    """Mirror input, label and weight maps left-right together with probability 0.5.

    One coin toss (a single ``np.random.rand()`` draw) decides for all three
    arrays, so they stay pixel-aligned. Flipped results are materialized copies,
    because ``np.fliplr`` returns a negative-stride view that
    ``torch.from_numpy`` cannot wrap. When not flipped, the inputs are returned
    unchanged.
    """
    keep_orientation = not (np.random.rand() < 0.5)
    if keep_orientation:
        return x, y, w
    flipped_x = np.fliplr(x).copy()
    flipped_y = np.fliplr(y).copy()
    flipped_w = np.fliplr(w).copy()
    return flipped_x, flipped_y, flipped_w

def random_vertical_flip(x, y, w):
    """Mirror input, label and weight maps top-bottom together with probability 0.5.

    One coin toss (a single ``np.random.rand()`` draw) decides for all three
    arrays, so they stay pixel-aligned. Flipped results are materialized copies,
    because ``np.flipud`` returns a negative-stride view that
    ``torch.from_numpy`` cannot wrap. When not flipped, the inputs are returned
    unchanged.
    """
    keep_orientation = not (np.random.rand() < 0.5)
    if keep_orientation:
        return x, y, w
    flipped_x = np.flipud(x).copy()
    flipped_y = np.flipud(y).copy()
    flipped_w = np.flipud(w).copy()
    return flipped_x, flipped_y, flipped_w

def npy_generator_augmented(file_list: List[str], inputs_dir: str, targets_dir: str, weights_dir: str,
                            standardize_targets: bool=False, target_mean: Optional[float]=None,
                            target_std: Optional[float]=None):
    """Stream training examples one file at a time, with random flip augmentation.

    For each name in ``file_list``, the same-named ``.npy`` file is loaded from
    ``inputs_dir``, ``targets_dir`` and ``weights_dir`` and cast to float32.
    A random horizontal flip and then a random vertical flip are applied, each
    shared across the three arrays. Only one file's arrays are held at a time.

    Targets are standardized as ``(y - target_mean) / (target_std + 1e-8)`` only
    when ``standardize_targets`` is truthy and both statistics are given.
    NOTE(review): this rescales every label pixel, including the negative ones
    that the losses/metrics treat as invalid (``target >= 0``). Valid pixels below
    the mean also turn negative and would then be masked out — confirm before
    enabling standardization.

    Yields:
        dict with ``file_name`` plus ``pixel_values`` / ``labels`` / ``weights``,
        each with a new leading axis (``[1, H, W]`` for 2-D files).
    """
    do_standardize = standardize_targets and target_mean is not None and target_std is not None
    for fname in file_list:
        x, y, w = (
            np.load(os.path.join(folder, fname)).astype(np.float32)
            for folder in (inputs_dir, targets_dir, weights_dir)
        )

        # Horizontal first, then vertical; both preserve the array shape.
        x, y, w = random_vertical_flip(*random_horizontal_flip(x, y, w))

        if do_standardize:
            y = (y - target_mean) / (target_std + 1e-8)

        yield {
            "file_name": fname,
            "pixel_values": x[np.newaxis, :, :],
            "labels": y[np.newaxis, :, :],
            "weights": w[np.newaxis, :, :],
        }

def npy_generator(file_list: List[str], inputs_dir: str, targets_dir: str, weights_dir: str,
                  standardize_targets: bool=False, target_mean: Optional[float]=None,
                  target_std: Optional[float]=None):
    """Stream examples one file at a time, without augmentation (used for validation).

    For each name in ``file_list``, the same-named ``.npy`` file is loaded from
    ``inputs_dir``, ``targets_dir`` and ``weights_dir`` and cast to float32.
    Only one file's arrays are held at a time.

    Targets are standardized as ``(y - target_mean) / (target_std + 1e-8)`` only
    when ``standardize_targets`` is truthy and both statistics are given.
    NOTE(review): standardization also rescales the negative (invalid) label
    pixels, and valid pixels below the mean become negative — this conflicts
    with the ``target >= 0`` validity masks used downstream; confirm before use.

    Yields:
        dict with ``file_name`` plus ``pixel_values`` / ``labels`` / ``weights``,
        each with a new leading axis (``[1, H, W]`` for 2-D files).
    """
    do_standardize = standardize_targets and target_mean is not None and target_std is not None
    for fname in file_list:
        x, y, w = (
            np.load(os.path.join(folder, fname)).astype(np.float32)
            for folder in (inputs_dir, targets_dir, weights_dir)
        )

        if do_standardize:
            y = (y - target_mean) / (target_std + 1e-8)

        yield {
            "file_name": fname,
            "pixel_values": x[np.newaxis, :, :],
            "labels": y[np.newaxis, :, :],
            "weights": w[np.newaxis, :, :],
        }


def create_hf_datasets(inputs_dir: str, targets_dir: str, weights_dir: str,
                       standardize_targets: bool=False, target_mean: Optional[float]=None,
                       target_std: Optional[float]=None,
                       shuffle_train: bool = True, seed: int = 42,
                       shuffle_buffer_size: int = 16) -> DatasetDict:
    """Build streaming train/validation datasets from per-sample ``.npy`` files.

    File names are the sorted ``*.npy`` basenames in ``inputs_dir``. The same
    names must exist in ``targets_dir`` and ``weights_dir``. The first 90% of
    names go to training (flip-augmented via ``npy_generator_augmented``) and the
    remaining 10% to validation (``npy_generator``, no augmentation).

    Args:
        inputs_dir, targets_dir, weights_dir: folders holding the per-sample arrays.
        standardize_targets, target_mean, target_std: forwarded to the generators.
        shuffle_train: shuffle the training stream. ``Trainer`` does not shuffle an
            ``IterableDataset`` itself, so without this every epoch replays the
            files in sorted order.
        seed: base seed for the training shuffle.
        shuffle_buffer_size: example-level shuffle buffer. Each buffered element
            holds three full arrays, so keep it small. File order is still fully
            permuted because every file is its own shard (see ``_stream``).

    Returns:
        ``DatasetDict`` with streaming ``"train"`` and ``"validation"`` splits.

    Raises:
        FileNotFoundError: if ``inputs_dir`` contains no ``.npy`` files.
        ValueError: if the 90/10 split leaves no training files.
    """
    all_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(inputs_dir, "*.npy")))
    if not all_files:
        raise FileNotFoundError(f"No .npy files found in {inputs_dir}")
    split_idx = int(0.9 * len(all_files))
    train_files = all_files[:split_idx]
    val_files = all_files[split_idx:]
    if not train_files:
        raise ValueError(f"Too few files ({len(all_files)}) in {inputs_dir} for a 90/10 train/validation split")

    def _stream(generator, files):
        # `datasets` treats list-valued gen_kwargs as shards (here: one file per
        # shard). DataLoader workers can then split the files, and .shuffle()
        # can permute the file order.
        return IterableDataset.from_generator(
            generator,
            gen_kwargs=dict(
                file_list=files,
                inputs_dir=inputs_dir,
                targets_dir=targets_dir,
                weights_dir=weights_dir,
                standardize_targets=standardize_targets,
                target_mean=target_mean,
                target_std=target_std,
            ),
        )

    train_ds = _stream(npy_generator_augmented, train_files)
    if shuffle_train:
        train_ds = train_ds.shuffle(seed=seed, buffer_size=shuffle_buffer_size)

    # Validation stays unaugmented and in a fixed order.
    val_ds = _stream(npy_generator, val_files)

    return DatasetDict({"train": train_ds, "validation": val_ds})

# %% [split]

In [None]:
# NOTE(review): with standardization on, valid targets below the mean become
# negative and are then treated as invalid by the `target >= 0` masks used in
# the losses. Keep disabled unless the masking convention is changed.
use_standardization = False
t_mean, t_std = 0.0, 1.0


def compute_target_stats(targets_dir: str, file_names: List[str]) -> Dict[str, float]:
    """Mean and population std of the valid (``>= 0``) target pixels across ``file_names``.

    Streams one target file at a time and merges per-file (count, mean, M2)
    with the parallel-variance update. Memory therefore stays at one target
    array instead of concatenating every valid pixel in the training set.
    Sums are accumulated in float64.

    Returns:
        ``{"t_mean": mean, "t_std": std + 1e-8}``. These are the same keys and
        epsilon as the stats JSON written by this cell.

    Raises:
        ValueError: if none of the files contains a valid pixel.
    """
    count, mean, m2 = 0, 0.0, 0.0
    for fname in file_names:
        y = np.load(os.path.join(targets_dir, fname)).astype(np.float32)
        valid = y[y >= 0].astype(np.float64)
        n_file = valid.size
        if n_file == 0:
            continue
        mean_file = float(valid.mean())
        m2_file = float(np.square(valid - mean_file).sum())
        total = count + n_file
        delta = mean_file - mean
        mean += delta * n_file / total
        m2 += m2_file + delta * delta * count * n_file / total
        count = total
    if count == 0:
        raise ValueError("No valid (>= 0) target pixels found in the given files")
    return {"t_mean": mean, "t_std": math.sqrt(m2 / count) + 1e-8}


if use_standardization:
    stats_path = os.path.join(targets_dir, "standardization_vals.json")
    if not os.path.exists(stats_path):
        # Compute stats once, on training files only. This uses the same listing
        # (glob "*.npy", sorted basenames) and 90/10 split as create_hf_datasets,
        # so the statistics come from exactly the files the model trains on.
        all_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(inputs_dir, "*.npy")))
        train_files = all_files[: int(0.9 * len(all_files))]
        with open(stats_path, "w") as f:
            json.dump(compute_target_stats(targets_dir, train_files), f, indent=4)

    with open(stats_path, "r") as f:
        stats = json.load(f)
    t_mean, t_std = stats["t_mean"], stats["t_std"]

# Both splits are lazy streams: no .npy file is read until a split is iterated.
hf_datasets = create_hf_datasets(
    inputs_dir, targets_dir, weights_dir,
    standardize_targets=use_standardization,
    target_mean=t_mean,
    target_std=t_std,
)
train_ds, val_ds = hf_datasets["train"], hf_datasets["validation"]

print("HF IterableDataset created (streaming). Train and validation will be loaded on-the-fly.")

HF IterableDataset created (streaming). Train and validation will be loaded on-the-fly.


# %% [custom loss functions]

In [None]:
def weighted_sse_loss(pred, target, weights):
    """Weighted sum of squared errors: ``sum((pred - target) ** 2 * weights)``.

    No validity mask is applied; every element contributes according to its
    weight. For empty inputs a fresh 0-d zero tensor on ``pred``'s device is
    returned.
    """
    weighted_sq_err = (pred - target).pow(2) * weights
    if weighted_sq_err.numel() != 0:
        return weighted_sq_err.sum()
    return torch.tensor(0.0, device=pred.device)

def weighted_mse_loss(pred, target, weights):
    """Weighted squared error averaged over valid pixels.

    Valid pixels are those with ``target >= 0``. The numerator sums
    ``(pred - target) ** 2 * weights`` over valid pixels only, so invalid pixels
    no longer leak into a mean whose denominator counts only valid pixels. The
    result is divided by the number of valid pixels, not by the sum of weights.

    ``pred``, ``target`` and ``weights`` are expected to share one shape (the
    model passes ``[B, 1, H, W]``).

    When no pixel is valid, this returns an exact zero that stays attached to the
    autograd graph, instead of 0/0 = NaN.
    """
    mask = target >= 0
    weighted_sq_err = ((pred - target) ** 2 * weights)[mask]
    n_valid = weighted_sq_err.numel()
    if n_valid == 0:
        # The sum of an empty selection is 0 and keeps grad_fn, so backward() still works.
        return weighted_sq_err.sum()
    return weighted_sq_err.sum() / n_valid

def masked_sse_loss_no_weights(pred, target, weights):
    """
    Sum of squared errors over valid pixels, ignoring ``weights``.

    Valid pixels are those with ``target >= 0``. ``weights`` is accepted only so
    this function is call-compatible with the weighted losses; it is not used.

    When no pixel is valid, the result is an exact zero that is still connected to
    ``pred``'s autograd graph. Returning a fresh constant tensor here would make
    ``loss.backward()`` raise on a fully masked batch.
    """
    mask = target >= 0
    # Select valid pixels before summing: invalid pixels never enter the sum, and
    # the sum of an empty selection is 0 with a grad_fn.
    return ((pred - target) ** 2)[mask].sum()

def masked_mse_loss(pred, target):
    """Mean squared error over valid pixels (``target >= 0``).

    Averaging over the number of valid pixels, rather than summing, keeps the
    value comparable across batch sizes. When no pixel is valid, the result is an
    exact zero that stays attached to the autograd graph, so ``backward()`` does
    not fail.
    """
    mask = target >= 0
    sq_err = (pred - target) ** 2
    n_valid = mask.sum()
    if n_valid == 0:
        # Empty selection: sums to 0 and keeps grad_fn.
        return sq_err[mask].sum()
    return (sq_err * mask).sum() / n_valid  # boolean mask broadcasts as 0/1

def masked_mean_abs_loss(pred, target):
    """Mean absolute error over valid pixels (``target >= 0``).

    When no pixel is valid, the result is an exact zero that stays attached to the
    autograd graph. A fresh constant tensor would make ``backward()`` fail, and
    ``.mean()`` of an empty selection would be NaN.
    """
    mask = target >= 0
    abs_err = torch.abs(pred - target)[mask]
    if abs_err.numel() == 0:
        return abs_err.sum()
    return abs_err.mean()

# %% [model]

In [None]:
class SegformerForPixelRegression(SegformerForSemanticSegmentation):
    """
    SegFormer repurposed as a dense (per-pixel) regression model.

    - Build it with ``num_labels=1`` so ``logits`` has a single channel.
    - The decoder's logits (1/4 of the input resolution in SegFormer) are
      bilinearly resized to the input's spatial size, so predictions line up
      pixel-for-pixel with ``labels``.
    - Loss (computed whenever ``labels`` is given) is
      ``masked_sse_loss_no_weights``: the sum of squared errors over pixels with
      ``labels >= 0``. ``weights`` is optional. It is reshaped and passed
      through so a weighted loss can be swapped in, but the active loss ignores it.
    """

    # Options forwarded to the parent forward(); any other kwargs are dropped.
    _PASSTHROUGH_KWARGS = ("output_attentions", "output_hidden_states")

    def forward(self, pixel_values: torch.FloatTensor, labels: Optional[torch.FloatTensor] = None,
                weights: Optional[torch.FloatTensor] = None, **kwargs):
        """Run the backbone and decoder, upsample the logits, and optionally compute the loss.

        Args:
            pixel_values: input batch; the spatial size is taken from its last two dims.
            labels: per-pixel targets, ``[B,1,H,W]`` or ``[B,H,W]``; pixels < 0 are ignored.
            weights: optional per-pixel weights, same layouts as ``labels``.

        Returns:
            ``SemanticSegmenterOutput`` whose ``logits`` match the input's spatial size.
            ``loss`` is ``None`` when ``labels`` is not given.
        """
        parent_kwargs = {k: v for k, v in kwargs.items() if k in self._PASSTHROUGH_KWARGS}
        # Always request a ModelOutput. The attribute access below
        # (.logits/.hidden_states) would fail on the tuple returned for return_dict=False.
        outputs = super().forward(pixel_values=pixel_values, labels=None, return_dict=True, **parent_kwargs)

        # Resize decoder logits back to the input size (e.g. HxW).
        logits = F.interpolate(outputs.logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False)

        loss = None
        if labels is not None:
            if labels.ndim == 3:  # [B,H,W] -> [B,1,H,W] to match logits
                labels = labels.unsqueeze(1)
            if weights is not None and weights.ndim == 3:
                weights = weights.unsqueeze(1)
            # Alternatives: F.mse_loss(logits, labels), weighted_sse_loss or
            # weighted_mse_loss (the latter two require `weights`).
            loss = masked_sse_loss_no_weights(logits, labels, weights)

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Smallest SegFormer variant (B0), ADE20K-finetuned at 512x512; b1/b2/b3 trade speed for capacity.
base_ckpt = "nvidia/segformer-b0-finetuned-ade-512-512"

# Re-target the pretrained config: one regression output map, one input channel.
cfg = SegformerConfig.from_pretrained(base_ckpt)
cfg.num_labels = 1
cfg.num_channels = 1
cfg.ignore_mismatched_sizes = True

# Checkpoint weights are loaded wherever shapes match. Per the load log, the
# first patch-embedding conv (3 -> 1 input channels) and the final classifier
# (150 -> 1 outputs) are freshly initialized.
# NOTE(review): the stem could instead be seeded from the pretrained RGB kernel
# (e.g. summed over channels) rather than random init — consider for fresh runs.
model = SegformerForPixelRegression.from_pretrained(
    base_ckpt, ignore_mismatched_sizes=True, config=cfg
)

Some weights of SegformerForPixelRegression were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([1]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([1, 256, 1, 1]) in the model instantiated
- segformer.encoder.patch_embeddings.0.proj.weight: found shape torch.Size([32, 3, 7, 7]) in the checkpoint and torch.Size([32, 1, 7, 7]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# %% [trainer utils]

In [None]:
def data_collator(batch):
    """Stack a list of examples into a batch of float32 tensors.

    Each example must provide ``pixel_values``, ``labels`` and ``weights`` with
    identical shapes across the batch (the generators yield ``[1, H, W]``), so
    plain stacking suffices and no padding is needed. Extra keys such as
    ``file_name`` are dropped.

    Values may be NumPy arrays or nested lists, e.g. if the dataset is later
    given a features schema or a formatting. float32 NumPy arrays are wrapped
    without a copy, as before.

    Returns:
        dict of ``[B, 1, H, W]`` float32 tensors.
    """
    keys = ("pixel_values", "labels", "weights")
    return {
        key: torch.stack([torch.as_tensor(np.asarray(example[key], dtype=np.float32)) for example in batch])
        for key in keys
    }


def compute_metrics(eval_pred, standardize: Optional[bool] = None,
                    target_mean: Optional[float] = None, target_std: Optional[float] = None):
    """Per-pixel regression metrics (MAE, RMSE, R^2) over valid pixels, for ``Trainer``.

    Args:
        eval_pred: ``(predictions, labels)`` pair (e.g. ``EvalPrediction``). Each
            may be a NumPy array or tensor shaped ``[B,1,H,W]`` or ``[B,H,W]``.
        standardize: whether predictions/labels are in standardized units.
            Defaults to the notebook-level ``use_standardization``.
        target_mean, target_std: standardization constants. They default to the
            notebook-level ``t_mean`` / ``t_std`` and are only read when
            ``standardize`` is true.

    Returns:
        dict with float ``mae``, ``rmse`` and ``r2`` in original target units.
        All values are NaN when no pixel is valid. ``r2`` is NaN when the valid
        labels are constant (including a single valid pixel).
    """
    preds, labels = eval_pred

    # Trainer usually hands over np.ndarray, but accept tensors too.
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    # [B,1,H,W] -> [B,H,W]
    if preds.ndim == 4 and preds.shape[1] == 1:
        preds = preds[:, 0, :, :]
    if labels.ndim == 4 and labels.shape[1] == 1:
        labels = labels[:, 0, :, :]

    if standardize is None:
        standardize = use_standardization
    if standardize:
        mean = t_mean if target_mean is None else target_mean
        std = t_std if target_std is None else target_std
        # Undo standardization BEFORE masking. Invalid pixels are the negative
        # ones in ORIGINAL units, and valid values below the mean are negative
        # in standardized units, so masking first would drop them.
        preds = preds * (std + 1e-8) + mean
        labels = labels * (std + 1e-8) + mean

    mask = labels >= 0
    preds = preds[mask].astype(np.float64)
    labels = labels[mask].astype(np.float64)

    if preds.size == 0:
        nan = float("nan")
        return {"mae": nan, "rmse": nan, "r2": nan}

    err = preds - labels
    sq_err_sum = float(np.sum(err ** 2))
    mae = float(np.mean(np.abs(err)))
    rmse = float(np.sqrt(sq_err_sum / err.size))

    # R^2 = 1 - SS_res / SS_tot; undefined for constant labels.
    ss_tot = float(np.sum((labels - labels.mean()) ** 2))
    r2 = 1.0 - sq_err_sum / ss_tot if ss_tot > 0 else float("nan")

    return {"mae": mae, "rmse": rmse, "r2": r2}

# %% [training args]

In [None]:
args = TrainingArguments(
    # --- I/O & checkpointing ---
    output_dir=output_dir,
    save_strategy="steps",
    save_steps=200,                   # aligned with eval_steps, required by load_best_model_at_end
    save_total_limit=2,               # keep Drive usage low
    save_safetensors=True,
    report_to="none",
    # --- optimisation schedule ---
    # An IterableDataset has no length, so the run is bounded by max_steps rather
    # than num_train_epochs.
    # NOTE(review): 39400 + 7000 presumably extends an earlier 39400-step run (the
    # log resumes near step 39600). The cosine schedule is built for the new total,
    # so the LR after resuming is presumably not where the first run ended — confirm intended.
    max_steps=39400 + 7000,
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.01,
    # --- batching & data loading ---
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
    # On resume, do not fast-forward the data stream to the checkpointed position.
    ignore_data_skip=True,
    # --- logging, evaluation & model selection ---
    # NOTE(review): without eval_accumulation_steps, all validation logits are
    # gathered on the device before compute_metrics; set it if evaluation runs out of memory.
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="rmse",
    greater_is_better=False,
    # --- precision ---
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


# %% [train]

In [None]:
# Resume from the newest checkpoint in output_dir when one exists; otherwise
# start fresh. Errors raised *during* training (e.g. CUDA OOM) are deliberately
# NOT caught. Silently restarting from scratch would throw away progress and
# write a new run's checkpoints into the same output_dir, where
# save_total_limit rotation interacts with the existing ones.
has_checkpoint = any(
    os.path.isdir(path) and os.path.basename(path).split("-", 1)[1].isdigit()
    for path in glob.glob(os.path.join(args.output_dir, "checkpoint-*"))
)
if has_checkpoint:
    trainer.train(resume_from_checkpoint=True)
else:
    print("No checkpoint found in output_dir; starting training from scratch")
    trainer.train()

Step,Training Loss,Validation Loss,Mae,Rmse,R2
39600,17091.1012,7625.02832,0.053171,0.095347,0.879943
39800,16205.76,7294.237305,0.051199,0.093255,0.885152
40000,28686.63,7914.490234,0.053276,0.097141,0.875383
40200,13614.93,7346.259277,0.050334,0.093587,0.884334
40400,15589.1863,7302.291992,0.049919,0.093308,0.885022
40600,14734.0462,6932.196777,0.048791,0.090914,0.890848
40800,26250.0075,7412.987305,0.050371,0.094011,0.883284
41000,14423.1625,6965.774902,0.048372,0.091132,0.890324
41200,13798.075,6911.824219,0.048373,0.090779,0.891172
41400,14357.5163,6782.726562,0.047937,0.089926,0.893207


# %% [eval]

In [None]:
# Final validation pass. When training ran in this session,
# load_best_model_at_end=True means the best (lowest-rmse) checkpoint is presumably loaded here.
print(metrics := trainer.evaluate())