# TTA Debug Notebook (Aligned with run_tta.py)
This notebook reuses the same pipeline and TTA flow as `debug_tests/run_tta.py` and runs it on the real dataset and model.


In [None]:
import json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import sys
import torch
import torch.nn as nn

ROOT = Path.cwd().parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
SAM2_ROOT = ROOT.parent / "sam2"
if SAM2_ROOT.exists() and str(SAM2_ROOT) not in sys.path:
    sys.path.insert(0, str(SAM2_ROOT))

from configs.pipeline_config import load_pipeline_config
from datasets.dataset import load_dataset
from debug_tests.run_tta import (
    _load_constants,
    run_segmentation_with_info,
    load_tta_config,
)
from image_processings.tta import TTALossWeights, run_tta_from_pool, default_multi_view_augment, apply_lora_to_mask_decoder
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

%matplotlib inline

def _to_uint8(image):
    if image.dtype == np.uint8:
        return image
    img = image.astype(np.float32)
    if img.max() > img.min():
        img = (img - img.min()) / (img.max() - img.min())
    return (img * 255).clip(0, 255).astype(np.uint8)

def _overlay(image, mask, color=(255, 0, 0), alpha=0.4):
    """Blend a solid ``color`` over the ``mask`` pixels of ``image``.

    Args:
        image: 2-D grayscale or 3-channel array; converted with ``_to_uint8``.
            A 2-D image is repeated into three channels first.
        mask: Array whose non-zero entries select the pixels to tint. It is
            coerced to ``bool`` so that 0/1 integer or float masks select
            pixels instead of being interpreted as integer (fancy) indices.
        color: Per-channel value written onto the masked pixels.
        alpha: Weight of the painted copy in the blend; the untouched image
            gets ``1 - alpha``.

    Returns:
        ``uint8`` array with the blended result.
    """
    base = _to_uint8(image).copy()
    if base.ndim == 2:
        base = np.repeat(base[..., None], 3, axis=2)
    # Boolean indexing picks pixels; an int mask would index rows instead.
    mask = np.asarray(mask, dtype=bool)
    overlay = base.copy()
    overlay[mask] = color
    return (base * (1 - alpha) + overlay * alpha).astype(np.uint8)

def _save_image(path, image):
    """Write ``image`` to ``path`` with ``plt.imsave``, creating parent dirs.

    Args:
        path: Destination file; a ``Path`` or a plain string.
        image: Array to save; converted with ``_to_uint8`` first.
    """
    # Normalise so string paths work too (a str has no ``.parent``).
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.imsave(path, _to_uint8(image))


In [None]:
# Configuration (same as run_tta.py): constants, pipeline config, TTA config.
constants = _load_constants()
pipeline_cfg = load_pipeline_config(ROOT / constants["pipeline_cfg"])
tta_cfg = load_tta_config(ROOT / "configs" / "tta_config.json")

# Dataset (same utility as run_tta.py); return_paths=True also yields names.
images, gt_masks, image_names = load_dataset(
    pipeline_cfg.dataset.name,
    data_root=None,
    target_long_edge=pipeline_cfg.dataset.target_long_edge,
    return_paths=True,
)

# Sample under inspection; change sample_idx to debug a different image.
sample_idx = 0
image, gt_mask = images[sample_idx], gt_masks[sample_idx]
if image_names:
    sample_name = Path(image_names[sample_idx]).stem
else:
    # No names available: fall back to a zero-padded index.
    sample_name = f"sample_{sample_idx:04d}"

# Every artifact written by this notebook goes under this per-sample folder.
output_dir = ROOT / "assets" / "tta_debug" / sample_name
output_dir.mkdir(parents=True, exist_ok=True)


In [None]:
# Build model + predictor (same as run_tta.py)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_sam2(constants["model_cfg"], constants["checkpoint"], device=device)

# LoRA is applied only when the config explicitly targets the mask decoder.
lora_cfg = tta_cfg.get("lora", {})
if lora_cfg.get("target") == "mask_decoder":
    lora_kwargs = dict(
        r=int(lora_cfg.get("rank", 4)),
        lora_alpha=int(lora_cfg.get("alpha", 8)),
        lora_dropout=float(lora_cfg.get("dropout", 0.0)),
        target_modules=lora_cfg.get("target_modules"),
    )
    apply_lora_to_mask_decoder(model, **lora_kwargs)

predictor = SAM2ImagePredictor(model)
predictor.model.to(device)
model.train()

# The optimizer only sees parameters that still require gradients; with none
# left, no optimizer is built and the TTA step callback becomes a no-op.
trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
if trainable_params:
    opt_cfg = tta_cfg.get("optimizer", {})
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=float(opt_cfg.get("lr", 1e-4)),
        weight_decay=float(opt_cfg.get("weight_decay", 0.0)),
    )
else:
    optimizer = None
max_grad_norm = float(tta_cfg.get("optimizer", {}).get("max_grad_norm", 1.0))


In [None]:
# 1) Run pipeline once to get prompts + mask pool (aligned with run_tta.py)
base_mask, history, vis_image, segments, info = run_segmentation_with_info(image, pipeline_cfg, predictor)

# The prompts of the last history entry are the ones reused for TTA below.
final_prompts = history[-1].prompts
prompts = dict(
    point_coords=final_prompts.points,
    point_labels=final_prompts.labels,
    box=None,
    mask_input=None,
    multimask_output=pipeline_cfg.sam.multimask_output,
)

# Persist and display the image that the TTA steps will train on.
train_input_path = output_dir / "train_input.png"
_save_image(train_input_path, vis_image)

plt.imshow(_to_uint8(vis_image))
plt.title("Training input (pipeline resized)")
plt.axis("off")
plt.show()


In [None]:
# 2) Candidate/augmentation grid (aligned with TTA augment function)
augment_fn = default_multi_view_augment(
    scales=tta_cfg["augment"]["scales"],
    do_flip=tta_cfg["augment"]["use_flip"],
)
# Each candidate is an (augmented image, align_back callable) pair.
candidates = augment_fn(vis_image)

# Simple score: mean prob over pseudo-mask (if any) after align-back.
# The mask is coerced to bool so it selects pixels rather than acting as
# integer indices when the pool stores 0/1 masks.
mask_pool = info.get_mask_pool()
pseudo_mask = np.asarray(mask_pool[0]["mask"] if mask_pool else base_mask, dtype=bool)
candidate_scores = []
for img_aug, align_back in candidates:
    # Encode this view first: predict() works on the image most recently
    # given to set_image(), so without this call every candidate would be
    # scored on the same (pipeline) image and get an identical score.
    # NOTE(review): prompts are reused unchanged for every view; if the
    # augmentation rescales or flips, the point coordinates may need the
    # same transform - confirm against the TTA implementation.
    predictor.set_image(img_aug)
    logits, _, _ = predictor.predict(
        point_coords=prompts["point_coords"],
        point_labels=prompts["point_labels"],
        box=prompts["box"],
        mask_input=prompts["mask_input"],
        multimask_output=prompts["multimask_output"],
        return_logits=True,
    )
    probs = 1.0 / (1.0 + np.exp(-logits[0]))
    aligned = align_back(probs)
    score = float(np.mean(aligned[pseudo_mask])) if pseudo_mask.any() else float(np.mean(aligned))
    candidate_scores.append(score)

if candidates:
    # Leave the predictor holding the training image rather than the last
    # augmented view, in case later code relies on the cached embedding.
    predictor.set_image(vis_image)

best_idx = int(np.argmax(candidate_scores)) if candidate_scores else -1

# squeeze=False always returns a 2-D axes array, so one candidate needs no
# special case; an empty candidate list still yields a single blank axis.
n_cols = max(1, len(candidates))
fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4), squeeze=False)
axes = axes[0]
for idx, (img_aug, _) in enumerate(candidates):
    title = f"candidate {idx}"
    if idx == 0:
        title += " (train input)"
    if idx == best_idx:
        title += " (best)"
    axes[idx].imshow(_to_uint8(img_aug))
    axes[idx].set_title(title)
    axes[idx].axis("off")

plt.tight_layout()
# Rasterise before reading pixels: the Agg buffer is only filled by draw().
fig.canvas.draw()
_save_image(output_dir / "candidates_grid.png", np.array(fig.canvas.buffer_rgba()))
plt.show()


In [None]:
# 3) Run TTA steps with real model, collect losses + outputs
# anchor / entropy / consistency must be present in the config;
# regularization is optional and defaults to 0.0.
_weights_cfg = tta_cfg["loss_weights"]
tta_loss_weights = TTALossWeights(
    anchor=float(_weights_cfg["anchor"]),
    entropy=float(_weights_cfg["entropy"]),
    consistency=float(_weights_cfg["consistency"]),
    regularization=float(_weights_cfg.get("regularization", 0.0)),
)

def _optimizer_step(total_loss, _losses):
    """Apply one optimizer update for ``total_loss``.

    Passed to ``run_tta_from_pool`` as ``optimizer_step_fn``. Does nothing
    when no optimizer was built or when ``total_loss`` is not a tensor that
    requires grad. Otherwise it clears old gradients, backpropagates, clips
    the gradient norm of ``trainable_params`` to ``max_grad_norm`` (when
    that is positive) and steps the optimizer. ``_losses`` is unused.
    """
    if (
        optimizer is None
        or not isinstance(total_loss, torch.Tensor)
        or not total_loss.requires_grad
    ):
        return
    optimizer.zero_grad(set_to_none=True)
    total_loss.backward()
    if max_grad_norm > 0:
        nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
    optimizer.step()

# Settings that stay fixed across steps are read from the config once.
n_steps = int(tta_cfg.get("tta_steps", 3))
pseudo_cfg = tta_cfg.get("pseudo_label", {})
selection_strategy = pseudo_cfg.get("strategy", "score_top_k")
top_k_masks = int(pseudo_cfg.get("top_k_masks", 3))

step_outputs = []
step_losses = []
for step in range(n_steps):
    tta_out = run_tta_from_pool(
        predictor,
        vis_image,
        info.get_mask_pool(),
        prompts,
        loss_weights=tta_loss_weights,
        selection_strategy=selection_strategy,
        top_k=top_k_masks,
        augment_fn=augment_fn,
        optimizer_step_fn=_optimizer_step,
    )
    outputs = tta_out["tta_outputs"]
    step_outputs.append(outputs)

    # One JSON-friendly row per step: 1-based step number plus float losses.
    loss_row = {"step": step + 1}
    loss_row.update((name, float(value)) for name, value in outputs.losses.items())
    step_losses.append(loss_row)

# Same text goes to disk and to the cell output.
losses_json = json.dumps(step_losses, indent=2)
(output_dir / "step_losses.json").write_text(losses_json, encoding="utf-8")
print(losses_json)


In [None]:
# 4) Prediction evolution (step1 -> step3)
def _show_and_save_row(panels, title_suffix, filename, **imshow_kwargs):
    """Plot ``panels`` side by side, save the rendered figure, then show it.

    Args:
        panels: Images to draw, one per TTA step, in step order.
        title_suffix: Text appended to each ``step N`` title.
        filename: File name created under ``output_dir``.
        **imshow_kwargs: Forwarded to ``Axes.imshow`` (e.g. ``cmap``).
    """
    if not panels:
        # plt.subplots rejects zero columns, e.g. when tta_steps is 0.
        print(f"No TTA steps recorded; skipping {filename}")
        return
    # squeeze=False keeps axes 2-D so a single panel needs no special case.
    fig, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4), squeeze=False)
    for idx, (ax, panel) in enumerate(zip(axes[0], panels), start=1):
        ax.imshow(panel, **imshow_kwargs)
        ax.set_title(f"step {idx} {title_suffix}")
        ax.axis("off")
    plt.tight_layout()
    # Rasterise before reading pixels: the Agg buffer is only filled by draw().
    fig.canvas.draw()
    _save_image(output_dir / filename, np.array(fig.canvas.buffer_rgba()))
    plt.show()


prob_maps = []
overlays = []

for out in step_outputs:
    probs = out.student_probs
    # Tensors are detached and moved to host memory before plotting.
    if hasattr(probs, "detach"):
        probs = probs.detach().cpu().numpy()
    # Drop a leading singleton axis so imshow receives a 2-D map.
    if probs.ndim == 3 and probs.shape[0] == 1:
        probs = probs[0]
    prob_maps.append(probs)
    overlays.append(_overlay(vis_image, probs > 0.5))

_show_and_save_row(prob_maps, "probs", "step_compare_probs.png", cmap="magma")
_show_and_save_row(overlays, "overlay", "step_compare_overlay.png")
