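## Plotting Digits

Compare digit samples from the multi-agent controlled sampler in two settings: the freshly initialized control agents (uncontrolled baseline) and the same agents loaded with trained weights. The notebook draws one figure per batch element.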

In [None]:
# --- Jupyter Starter Pack ---

# autoreload: refresh code on every cell run
%reload_ext autoreload
%autoreload 2

# clean warnings
import warnings
warnings.filterwarnings("ignore")

# nicer printing
from pprint import pprint

# numpy / pandas nicer display
import numpy as np
np.set_printoptions(precision=4, suppress=True)

# matplotlib defaults
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6, 4)
plt.rcParams["figure.dpi"] = 120

# tqdm in notebooks
from tqdm.notebook import tqdm

# optional: make exceptions show only the important frame
%config InlineBackend.figure_format = "retina"

In [None]:
# Trained control-agent weights, keyed by the number of agents used in each run.
# Selecting the run through NUM_AGENTS keeps the agent count and the weights
# directory in sync: a mismatch would either fail on a missing agent file or
# silently ignore some of the trained agents.
CONTROL_RUNS = {
    2: "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-01-26/11-28-41/9/weights",
    3: "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-01-26/18-06-34/9/weights",
}
NUM_AGENTS = 3
path_to_controls = CONTROL_RUNS[NUM_AGENTS]


In [None]:
# Load the score model, the guidance classifier and the control agents, then sample
import sys
import pathlib
import importlib.util
import torch
from hydra.utils import to_absolute_path
from typing import Dict
import copy

# Locate the repository root (the first candidate directory containing `src/`)
# and put it on sys.path so the project package can be imported from here.
# `__file__` may be undefined inside a notebook kernel, so that probe is optional.
candidates = []
try:
    this_path = pathlib.Path(__file__).resolve()
except NameError:
    pass
else:
    candidates += [this_path.parent, this_path.parent.parent]
cwd_path = pathlib.Path.cwd().resolve()
explicit_root = pathlib.Path("/home/rbarbano/home/git/multi-agent-diffusion-working-repo")
candidates += [cwd_path, cwd_path.parent, cwd_path.parent.parent, explicit_root]

# Fall back to the hard-coded checkout when no candidate has a `src` entry.
repo_root = next(filter(lambda p: (p / "src").exists(), candidates), explicit_root)
if not any(entry == str(repo_root) for entry in sys.path):
    sys.path.insert(0, str(repo_root))

if importlib.util.find_spec("src") is None:
    raise ImportError(f"Could not import 'src'. repo_root={repo_root}, sys.path[0]={sys.path[0]}")

from src.models.registry import get_model_by_name
from src.samplers.diff_dyms import SDE


def _load_state(module: torch.nn.Module, checkpoint_path: str, device: torch.device) -> None:
    checkpoint = torch.load(to_absolute_path(checkpoint_path), map_location=device)
    if isinstance(checkpoint, Dict):
        state_dict = checkpoint.get("model_state") or checkpoint.get("state_dict") or checkpoint
    else:
        state_dict = checkpoint
    module.load_state_dict(state_dict, strict=False)


device = "cuda" if torch.cuda.is_available() else "cpu"
sde = SDE(mode="VE", sigma=25, device=device)

# Checkpoints are resolved relative to the repository root located above rather
# than through a hard-coded absolute home directory.
CHECKPOINT_DIR = repo_root / "checkpoints"

# --- Score model ---
score_model_cfg = {
    "name": "unet",
    "in_channels": 1,
    "out_channels": 1,
    "model_channels": 64,
    "channel_mult": [1, 4, 4],
    "num_res_blocks": 2,
    "attention_resolutions": [16],
    "max_period": 0.005,
}
score_model_name = score_model_cfg.pop("name")
score_model = get_model_by_name(
    score_model_name,
    marginal_prob_std=sde.marginal_prob_std,
    **score_model_cfg,
).to(device)
score_model.eval()
# Freeze the model's parameters so they don't get gradients or updates.
# This does not stop autograd from building a graph or computing gradients with
# respect to other tensors.
score_model.requires_grad_(False)
_load_state(score_model, str(CHECKPOINT_DIR / "latest.ckpt"), device)

# --- Classifier for guidance ---
# NOTE(review): the classifier is loaded but never passed to the sampler calls
# below, so it does not influence these samples. TODO confirm this is intended.
classifier_cfg = {
    "name": "cnet",
    "img_size": [28, 28],
    "num_classes": 10,
    "num_hidden_layers": 2,
}
classifier_name = classifier_cfg.pop("name")
classifier = get_model_by_name(
    classifier_name,
    **classifier_cfg,
).to(device)
classifier.eval()
_load_state(classifier, str(CHECKPOINT_DIR / "cnet.pt"), device)

# --- Control agents (N agents) ---
# `control_agents_init` keep their freshly constructed weights (nothing is
# loaded into them) and drive the uncontrolled baseline run.
# `control_agent_optim` are copies of the same modules, loaded with the trained
# weights from `path_to_controls`.
control_cfg = {
    "name": "unet",
    "in_channels": 2,
    "out_channels": 1,
    "model_channels": 32,
    "channel_mult": [1, 2],
    "num_res_blocks": 1,
    "attention_resolutions": [],
    "max_period": 0.005,
}
control_name = control_cfg.pop("name")
weights_root = pathlib.Path(to_absolute_path(path_to_controls))
control_agents_init = {}
control_agent_optim = {}
for i in range(NUM_AGENTS):
    init_agent = get_model_by_name(
        control_name,
        marginal_prob_std=sde.marginal_prob_std,
        **control_cfg,
    ).to(device)
    init_agent.eval()
    control_agents_init[i] = init_agent
    control_agent_optim[i] = copy.deepcopy(init_agent)
    _load_state(control_agent_optim[i], str(weights_root / f"agent_{i}.pt"), device)

# --- Sampling ---
from src.envs.aggregator import ImageMaskAggregator
from src.samplers.samplers import euler_maruyama_controlled_sampler

IMAGE_DIM = (1, 28, 28)
BATCH_SIZE = 16
NUM_STEPS = 500
# Seed applied before each sampler run (see _sample_with_agents).
SAMPLER_SEED = 0

aggregator = ImageMaskAggregator(
    mask_name="split",
    img_dims=IMAGE_DIM,
    num_processes=NUM_AGENTS,
    overlap_size=0,
    use_overlap=False,
    device=device,
)


def _sample_with_agents(agents):
    """Run the controlled Euler-Maruyama sampler with the given control agents.

    Returns whatever ``euler_maruyama_controlled_sampler`` returns; it is
    unpacked below as ``(samples, info)``.
    """
    # Re-seed so every run starts from the same global RNG state. If the sampler
    # draws its noise from torch's global RNG (not visible here - TODO confirm),
    # the uncontrolled and controlled runs share their noise and can be compared
    # sample-by-sample in the plots below.
    torch.manual_seed(SAMPLER_SEED)
    return euler_maruyama_controlled_sampler(
        score_model=score_model,
        control_agents=agents,
        aggregator=aggregator,
        sde=sde,
        image_dim=IMAGE_DIM,
        batch_size=BATCH_SIZE,
        num_steps=NUM_STEPS,
        device=device,
        debug=True,
    )


uncontrolled_samples, uncontrolled_info = _sample_with_agents(control_agents_init)
controlled_samples, controlled_info = _sample_with_agents(control_agent_optim)


In [None]:
def _as_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and moved to CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


num_agents = len(uncontrolled_info["per_agent"])
# Both figure rows share one column layout, so both runs must report the same
# number of agents.
assert len(controlled_info["per_agent"]) == num_agents, "uncontrolled/controlled runs differ in agent count"

# convert to numpy for plotting
uncontrolled_samples_np = _as_numpy(uncontrolled_samples)
controlled_samples_np = _as_numpy(controlled_samples)

# Last recorded entry per agent, converted like the samples so the plotting
# helpers never receive (possibly CUDA) torch tensors. Indexed by position so
# this works whether "per_agent" is a list or an int-keyed dict.
uncontrolled_states = [_as_numpy(uncontrolled_info["per_agent"][i][-1]) for i in range(num_agents)]
controlled_states = [_as_numpy(controlled_info["per_agent"][i][-1]) for i in range(num_agents)]

In [None]:
import numpy as np

# Use LaTeX text rendering only when the external toolchain is installed.
# Setting rcParams["text.usetex"] = True never raises by itself; a missing LaTeX
# install only fails later, when a figure is drawn. A try/except around the
# assignment therefore cannot provide a fallback, so probe for the executables
# instead and fall back to matplotlib's built-in mathtext when they are absent.
import shutil

_use_tex = True
# latex typesets the text; dvipng rasterizes it for PNG output (the inline backend).
_tex_available = all(shutil.which(tool) is not None for tool in ("latex", "dvipng"))
plt.rcParams["text.usetex"] = _use_tex and _tex_available

# Font sizes and spacing
TITLE_FONTSIZE = 16
TITLE_PAD = 8
ROW_LABEL_FONTSIZE = 24


def _pick_image(arr: np.ndarray, step: int = -1, batch: int = 0) -> np.ndarray:
    if arr.ndim == 5:  # (steps, batch, c, h, w)
        return arr[step, batch]
    if arr.ndim == 4:  # (batch, c, h, w)
        return arr[batch]
    return arr


def _to_hw(img: np.ndarray) -> np.ndarray:
    if img.ndim == 3 and img.shape[0] in (1, 3):
        if img.shape[0] == 1:
            return img[0]
        return np.transpose(img, (1, 2, 0))
    return img


def _clip01(x: np.ndarray) -> np.ndarray:
    return np.clip(x, 0.0, 1.0)


def _seam_ys(img: np.ndarray, num_agents: int) -> list[float]:
    h = img.shape[0]
    return [(k * h / num_agents) - 0.5 for k in range(1, num_agents)]


def plot_agents_row(states_np, samples_np, batch_idx: int, axes_row, row_label: str):
    """Draw one figure row: each agent's final state followed by the aggregated sample.

    Args:
        states_np: Per-agent arrays. Each is reduced to a single image via
            ``_pick_image(..., step=-1, batch=batch_idx)``.
        samples_np: Aggregated samples, reduced the same way with ``batch=batch_idx``.
        batch_idx: Batch element to display.
        axes_row: Sequence of at least ``len(states_np) + 1`` Axes. Panels are
            drawn left to right; any extra Axes are left untouched.
        row_label: Row description. It is only used to choose the title style:
            a label containing "Uncontrolled" drops the ``u`` from the ``X``
            superscripts. The label is not drawn here, because every panel's
            axis is turned off (a ``set_ylabel`` would be invisible). Callers
            draw it themselves, e.g. with ``fig.text``.

    Raises:
        ValueError: if ``axes_row`` has fewer than ``len(states_np) + 1`` Axes.
            Without this check ``zip`` would silently drop panels.
    """
    num_agents = len(states_np)
    num_panels = num_agents + 1
    if len(axes_row) < num_panels:
        raise ValueError(
            f"plot_agents_row needs {num_panels} axes ({num_agents} agents + sample), got {len(axes_row)}"
        )

    state_imgs = [_clip01(_to_hw(_pick_image(st, step=-1, batch=batch_idx))) for st in states_np]
    sample_img = _clip01(_to_hw(_pick_image(samples_np, batch=batch_idx)))

    # Top row (uncontrolled) should not include 'u' in the X superscripts
    is_uncontrolled = ("Uncontrolled" in row_label)
    titles = [
        (fr"$X^{{{i + 1}}}_0$" if is_uncontrolled else fr"$X^{{u,{i + 1}}}_0$")
        for i in range(num_agents)
    ] + [r"$Y_0$"]

    for ax, img, title in zip(axes_row, state_imgs + [sample_img], titles):
        is_gray = img.ndim == 2 or (img.ndim == 3 and img.shape[-1] == 1)
        ax.imshow(img.squeeze(), cmap="gray" if is_gray else None, vmin=0.0, vmax=1.0)
        # red horizontal lines at the band boundaries computed by _seam_ys
        for y in _seam_ys(img, num_agents):
            ax.axhline(y, color="red", lw=2.5)
        ax.set_title(title, fontsize=TITLE_FONTSIZE, pad=TITLE_PAD)
        ax.axis("off")


# One figure per batch element: the uncontrolled run on the top row and the
# controlled run on the bottom row. The row spec does not depend on the batch
# index, so it is built once.
sde_rows = [
    (uncontrolled_states, uncontrolled_samples_np, r"$\mathrm{Uncontrolled\ SDEs}$"),
    (controlled_states, controlled_samples_np, r"$\mathrm{Controlled\ SDEs}$"),
]
n_rows = len(sde_rows)
n_cols = num_agents + 1

for batch_idx in range(uncontrolled_samples_np.shape[0]):
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
    axes = np.atleast_2d(axes)
    for row_idx, (row_states, row_samples, row_label) in enumerate(sde_rows):
        plot_agents_row(row_states, row_samples, batch_idx=batch_idx, axes_row=axes[row_idx], row_label=row_label)
        # Row label as figure text, vertically centred on this row; the panels'
        # axes are turned off, so an Axes ylabel would not be shown.
        fig.text(
            0.04,
            1.0 - (row_idx + 0.5) / n_rows,
            row_label,
            va="center",
            ha="left",
            rotation=90,
            fontsize=ROW_LABEL_FONTSIZE,
        )
    fig.tight_layout(pad=0.4, w_pad=0.6, h_pad=1.8)
    fig.subplots_adjust(left=0.08, wspace=0.08, hspace=0.2)
    plt.show()