# TTA Debug Notebook
Structured visualisations for inspecting the TTA training input, the per-step losses, and how the predictions evolve from step to step.


In [None]:
import json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import sys

# Make the project root (parent of the working directory) importable, plus a
# sibling "sam2" checkout when one exists next to the project root.
ROOT = Path.cwd().parent
SAM2_ROOT = ROOT.parent / "sam2"

_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

_sam2_entry = str(SAM2_ROOT)
if SAM2_ROOT.exists() and _sam2_entry not in sys.path:
    # Inserted after ROOT, so it ends up ahead of ROOT on sys.path.
    sys.path.insert(0, _sam2_entry)

from configs.pipeline_config import load_pipeline_config
from debug_tests.debug_test import MAIN_DIR, _load_constants
from image_processings.tta import TTAPipeline, TTALossWeights
from image_processings.tta.tta_core import default_multi_view_augment

%matplotlib inline

def _to_uint8(image):
    if image.dtype == np.uint8:
        return image
    img = image.astype(np.float32)
    if img.max() > img.min():
        img = (img - img.min()) / (img.max() - img.min())
    return (img * 255).clip(0, 255).astype(np.uint8)

def _overlay(image, mask, color=(255, 0, 0), alpha=0.4):
    """Blend ``color`` into ``image`` wherever ``mask`` selects pixels.

    The image is first converted with ``_to_uint8``; a 2-D image is expanded
    to three identical channels. Pixels indexed by ``mask`` are painted with
    ``color`` on a copy, and the result is the ``alpha``-weighted mix of that
    copy with the untouched image, truncated to ``uint8``. The input image is
    never modified.
    """
    display = _to_uint8(image)
    if display.ndim == 2:
        # Grayscale: replicate the single plane into three channels.
        canvas = np.stack((display,) * 3, axis=-1)
    else:
        # _to_uint8 may hand back the caller's own array, so work on a copy.
        canvas = display.copy()

    tinted = canvas.copy()
    tinted[mask] = color

    blended = canvas * (1 - alpha) + tinted * alpha
    return blended.astype(np.uint8)

def _save_image(path, image):
    """Write ``image`` to ``path`` via ``plt.imsave``, creating parent folders.

    The image is passed through ``_to_uint8`` before being written.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    pixels = _to_uint8(image)
    plt.imsave(path, pixels)


In [None]:
# Placeholder inputs: swap these for real outputs of the main pipeline when
# they are available.
constants = _load_constants()
pipeline_cfg_path = MAIN_DIR / constants["pipeline_cfg"]
pipeline_cfg = load_pipeline_config(pipeline_cfg_path)

sample_name = "sample_0001"
# All artefacts written by this notebook go under this directory.
output_dir = Path("assets", "tta_debug", sample_name)

image = np.zeros((64, 64, 3), dtype=np.uint8)  # training input image
pseudo_mask = np.zeros(image.shape[:2], dtype=bool)  # pseudo label mask

# Keyword arguments forwarded to the predictor's ``predict`` call.
prompts = dict(
    point_coords=np.array([[32, 32]], dtype=np.float32),
    point_labels=np.array([1], dtype=np.int64),
    box=None,
    mask_input=None,
    multimask_output=False,
)

def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    Uses a numerically stable formulation so that large-magnitude logits do
    not overflow inside ``exp``. Array inputs keep their floating dtype;
    scalar inputs yield a NumPy scalar.
    """
    x = np.asarray(x)
    # exp(-|x|) always lies in (0, 1], so it cannot overflow. The two branches
    # below are algebraically identical to 1 / (1 + exp(-x)) for x >= 0 and
    # x < 0 respectively.
    decay = np.exp(-np.abs(x))
    result = np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    # Indexing with an empty tuple unwraps 0-d results to a scalar and leaves
    # n-d arrays untouched.
    return result[()]

# Dummy predictor with required API; replace with SAM2 predictor in practice
class DummyPredictor:
    """Stand-in predictor that returns random logits instead of running a model.

    The output size follows the image most recently passed to ``set_image``
    (or the ``image_shape`` given at construction). When neither was provided
    it falls back to the notebook-level ``image`` variable, so existing code
    that only calls ``predict`` keeps working.
    """

    def __init__(self, image_shape=None):
        # (height, width) of the current image, or None until one is known.
        self._image_hw = None if image_shape is None else tuple(image_shape[:2])

    def set_image(self, image):
        """Remember the spatial size of ``image`` for later ``predict`` calls.

        NOTE(review): assumes ``image`` is laid out height-first (H, W[, C]);
        confirm against whatever the caller passes in.
        """
        self._image_hw = tuple(np.asarray(image).shape[:2])

    def predict(self, point_coords, point_labels, box, mask_input, multimask_output, return_logits=False):
        """Return ``(logits, scores, low_res)`` filled with placeholder values.

        All prompt arguments and ``return_logits`` are accepted for API
        compatibility but ignored: the result is always one float32 map of
        standard-normal noise with shape ``(1, H, W)``, the score list
        ``[1.0]`` and ``None``.
        """
        if self._image_hw is not None:
            h, w = self._image_hw
        else:
            # Backward-compatible fallback to the notebook-level placeholder.
            h, w = image.shape[:2]
        logits = np.random.randn(1, h, w).astype(np.float32)
        return logits, [1.0], None

predictor = DummyPredictor()


In [None]:
# 1) Training input visualisation (original + augmented candidates)
augment_fn = default_multi_view_augment(scales=(0.75, 1.0, 1.25), do_flip=True)
# Materialise the views: they are iterated twice and counted below.
candidates = list(augment_fn(image))

# Evaluate candidates with a simple score: mean prob on pseudo-mask after align-back.
has_pseudo_pixels = bool(pseudo_mask.any())
can_set_image = hasattr(predictor, "set_image")
candidate_scores = []
for img_aug, align_back in candidates:
    # Give the augmented view to the predictor; otherwise every candidate is
    # scored on the same input and the ranking is meaningless.
    if can_set_image:
        predictor.set_image(img_aug)
    # NOTE(review): prompts are passed unchanged for every view; confirm
    # whether they need to be mapped into the augmented view's coordinates.
    logits, _, _ = predictor.predict(
        point_coords=prompts["point_coords"],
        point_labels=prompts["point_labels"],
        box=prompts["box"],
        mask_input=prompts["mask_input"],
        multimask_output=prompts["multimask_output"],
        return_logits=True,
    )
    probs = sigmoid(logits[0])
    aligned = align_back(probs)
    # With an empty pseudo-mask, fall back to the mean over the whole map.
    selected = aligned[pseudo_mask] if has_pseudo_pixels else aligned
    candidate_scores.append(float(np.mean(selected)))

if can_set_image:
    # Restore the original image so later cells do not run against the last
    # augmented view.
    predictor.set_image(image)

best_idx = int(np.argmax(candidate_scores)) if candidate_scores else -1

# squeeze=False always yields a 2-D axes array, so a single candidate needs no
# special case.
n_cols = max(1, len(candidates))
fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4), squeeze=False)

for idx, (ax, (img_aug, _)) in enumerate(zip(axes[0], candidates)):
    title = f"candidate {idx}"
    # NOTE(review): assumes the first candidate is the unaugmented input.
    if idx == 0:
        title += " (train input)"
    if idx == best_idx:
        title += " (best)"
    ax.imshow(_to_uint8(img_aug))
    ax.set_title(title)
    ax.axis("off")

fig.tight_layout()
# Save through the figure itself: reading the canvas buffer before a draw
# yields a blank or stale image and only works on Agg-based backends.
output_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(output_dir / "candidates_grid.png")
plt.show()

_save_image(output_dir / "train_input.png", image)


In [None]:
# 2) Three-step training process with loss breakdown
pipeline = TTAPipeline(
    predictor=predictor,
    loss_weights=TTALossWeights(anchor=1.0, entropy=0.1, consistency=0.5),
    augment_fn=augment_fn,
)

# Number of TTA steps to run and record.
num_steps = 3
step_outputs = []
for _ in range(num_steps):
    step_outputs.append(pipeline.step(image, prompts, pseudo_mask))

# One plain-float dict per step (JSON-serialisable), tagged with its 1-based
# step index.
step_losses = [
    {**{name: float(value) for name, value in out.losses.items()}, "step": idx}
    for idx, out in enumerate(step_outputs, start=1)
]

# Create the output folder here so this cell does not depend on an earlier
# cell having written into it first.
output_dir.mkdir(parents=True, exist_ok=True)
(output_dir / "step_losses.json").write_text(json.dumps(step_losses, indent=2), encoding="utf-8")

# Print a small table for quick comparison
headers = ["step", "total", "anchor", "entropy", "consistency"]
print("\t".join(headers))
for row in step_losses:
    # Losses missing from a step are printed as empty cells.
    print("\t".join(str(row.get(h, "")) for h in headers))


In [None]:
# 3) Prediction evolution: per-step probs and overlays
def _show_step_grid(panels, label, filename, **imshow_kwargs):
    """Plot one panel per step in a single row, save the figure, then show it."""
    # Size the grid from the data so a different number of steps still works.
    n_cols = max(1, len(panels))
    fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4), squeeze=False)
    for idx, (ax, panel) in enumerate(zip(axes[0], panels), start=1):
        ax.imshow(panel, **imshow_kwargs)
        ax.set_title(f"step {idx} {label}")
        ax.axis("off")
    fig.tight_layout()
    # Save through the figure itself: reading the canvas buffer before a draw
    # yields a blank or stale image and only works on Agg-based backends.
    output_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_dir / filename)
    plt.show()


prob_maps = []
overlays = []

for out in step_outputs:
    probs = out.student_probs
    # Tensor-like outputs (anything with ``detach``) are moved to NumPy.
    if hasattr(probs, "detach"):
        probs = probs.detach().cpu().numpy()
    # Strip leading singleton axes until a 2-D map remains.
    while probs.ndim > 2 and probs.shape[0] == 1:
        probs = probs[0]
    prob_maps.append(probs)
    overlays.append(_overlay(image, probs > 0.5))

# A fixed [0, 1] colour range keeps the panels comparable across steps;
# per-panel autoscaling would hide changes in confidence.
_show_step_grid(prob_maps, "probs", "step_compare_probs.png", cmap="magma", vmin=0.0, vmax=1.0)
_show_step_grid(overlays, "overlay", "step_compare_overlay.png")
