# TTA Debug Notebook
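Visualise the SAM2 TTA pipeline: soft-teacher construction, region partition, student predictions, and losses (anchor / entropy / consistency), including the augmented view. The inputs below are placeholders and the predictor is a random-logit dummy; replace them with real outputs from the main pipeline and a real SAM2 predictor to debug actual runs.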

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys

def _prepend_to_sys_path(entry):
    """Put *entry* (as a string) at the front of sys.path unless it is already there."""
    entry_str = str(entry)
    if entry_str not in sys.path:
        sys.path.insert(0, entry_str)


# The repository root is taken to be the parent of the notebook's working
# directory; a sibling ``sam2`` checkout is added only when it exists. Because
# each entry goes to the front of sys.path, the sam2 checkout (added second)
# ends up ahead of ROOT.
ROOT = Path.cwd().parent
_prepend_to_sys_path(ROOT)

SAM2_ROOT = ROOT.parent / "sam2"
if SAM2_ROOT.exists():
    _prepend_to_sys_path(SAM2_ROOT)

from configs.pipeline_config import load_pipeline_config
from debug_tests.debug_test import MAIN_DIR, _load_constants
from image_processings.tta import TTAPipeline, TTALossWeights, build_soft_teacher, partition_regions
from image_processings.tta.tta_core import default_multi_view_augment

%matplotlib inline

In [None]:
# Project configuration.
# NOTE(review): pipeline_cfg is not used anywhere visible in this notebook.
constants = _load_constants()
pipeline_cfg = load_pipeline_config(MAIN_DIR / constants["pipeline_cfg"])

# ---- Placeholders: replace with actual outputs from the main pipeline ----
# Pool of candidate masks, with one score per mask.
nested_masks = [np.zeros((64, 64), dtype=bool)]
scores = [1.0]
# Input image: 64x64, 3 channels, uint8 (all zeros here).
image = np.zeros((64, 64, 3), dtype=np.uint8)
# Prompt dict: a single point at (32, 32) with label 1, no box and no mask prompt.
prompts = dict(
    point_coords=np.array([[32, 32]], dtype=np.float32),
    point_labels=np.array([1], dtype=np.int64),
    box=None,
    mask_input=None,
    multimask_output=False,
)

# Soft teacher from the mask pool, then its sure-FG / uncertain / sure-BG split.
teacher = build_soft_teacher(nested_masks, scores)
partition = partition_regions(nested_masks, teacher)

# Dummy predictor with required API; replace with SAM2 predictor in practice
class DummyPredictor:
    """Stand-in for a SAM2 image predictor that emits random logits.

    Only ``predict`` is implemented. Its return value follows the
    ``(masks, scores, low_res_masks)`` triple of SAM2's
    ``SAM2ImagePredictor.predict``:

    * a ``(1, H, W)`` float32 logit map,
    * a length-1 float32 score array,
    * ``None`` in place of the low-resolution masks.

    Parameters
    ----------
    shape : (H, W) of the logit map. When ``None`` (the default) the shape of
        the notebook-global ``teacher`` array is used, looked up at call time,
        which matches the previous behaviour.
    seed : seed for this instance's private random generator, so re-running
        the notebook reproduces the same logits and plots. Pass ``None`` for
        fresh randomness on every run.
    """

    def __init__(self, shape=None, seed=0):
        self.shape = None if shape is None else tuple(shape)
        # Private generator: reproducible, and does not disturb np.random's global state.
        self._rng = np.random.default_rng(seed)

    def predict(self, point_coords, point_labels, box, mask_input, multimask_output, return_logits=False):
        """Return ``(logits, scores, None)``.

        All prompt arguments are accepted for API compatibility but ignored.
        ``return_logits`` is ignored too: raw logits are always returned.
        """
        if self.shape is not None:
            height, width = self.shape
        else:
            height, width = teacher.shape[0], teacher.shape[1]
        logits = self._rng.standard_normal((1, height, width), dtype=np.float32)
        # Scores as a float32 array rather than a list, like SAM2's predictor.
        return logits, np.array([1.0], dtype=np.float32), None

# Wire the stand-in predictor into the TTA pipeline and call pipeline.step
# once on the placeholder inputs.
# NOTE(review): pipeline_cfg (loaded earlier) is not passed here; default
# TTALossWeights() are used — confirm this is intended.
predictor = DummyPredictor()
pipeline = TTAPipeline(
    predictor=predictor,
    loss_weights=TTALossWeights(),
    augment_fn=default_multi_view_augment(),
)
tta_out = pipeline.step(image, prompts, nested_masks, scores)

In [None]:
# Visualise teacher and partition
def show_mask(ax, mask, title, cmap="magma", vmin=0.0, vmax=1.0):
    """Draw a mask or probability map on *ax* with a fixed colour scale.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    mask : array-like image. Leading singleton axes (e.g. a ``(1, H, W)``
        batch of one) are dropped so it can be shown as an ``(H, W)`` image.
    title : panel title.
    cmap : matplotlib colormap name.
    vmin, vmax : colour-scale limits. They default to ``[0, 1]`` so that
        probability maps and boolean masks share an absolute scale, instead
        of each panel being stretched to its own min/max. Pass ``None`` to
        get matplotlib's per-image autoscaling back.
    """
    img = np.asarray(mask)
    # imshow rejects (1, H, W); peel off batch-of-one axes only.
    while img.ndim > 2 and img.shape[0] == 1:
        img = img[0]
    ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.axis("off")

# Teacher probability map next to the three regions from partition_regions.
fig, axes = plt.subplots(1, 4, figsize=(12, 3))
show_mask(axes[0], teacher, "Teacher prob")
show_mask(axes[1], partition.sure_fg, "Sure FG", cmap="gray")
show_mask(axes[2], partition.unsure, "Uncertain", cmap="gray")
show_mask(axes[3], partition.sure_bg, "Sure BG", cmap="gray")
plt.show()

# Visualise student preds: plain view on the left, augmented view (when the
# pipeline produced one) on the right.
fig, axes = plt.subplots(1, 2, figsize=(8, 3))
show_mask(axes[0], tta_out.student_probs, f"Student probs\nLoss total={tta_out.losses['total']:.4f}")
if tta_out.student_probs_aug is not None:
    show_mask(axes[1], tta_out.student_probs_aug, "Augmented probs")
else:
    # Otherwise the right panel is an empty, ticked axes that looks like a
    # failed plot rather than "no augmented view".
    axes[1].set_title("No augmented view")
    axes[1].axis("off")
plt.show()

print("Loss breakdown:", tta_out.losses)