## Plotting Digits

In [None]:
# --- Jupyter Starter Pack ---

# autoreload: refresh code on every cell run
%reload_ext autoreload
%autoreload 2

# clean warnings
import warnings
warnings.filterwarnings("ignore")

# nicer printing
from pprint import pprint

# numpy / pandas nicer display
import numpy as np
np.set_printoptions(precision=4, suppress=True)

# matplotlib defaults
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6, 4)
plt.rcParams["figure.dpi"] = 120

# tqdm in notebooks
from tqdm.notebook import tqdm

# optional: render inline figures in "retina" (high-resolution) format
%config InlineBackend.figure_format = "retina"

In [None]:
import sys
import importlib
import pathlib, yaml, math
import torch
from hydra.utils import to_absolute_path
from tqdm.auto import tqdm

# Find the repo root so that src/ is importable from notebooks.
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded, so import it explicitly before find_spec is used below.
import importlib.util

candidates = []
try:
    # __file__ only exists when this code runs as a script/module; in a
    # notebook kernel the lookup raises NameError and these candidates are skipped.
    this_path = pathlib.Path(__file__).resolve()
    candidates.extend([this_path.parent, this_path.parent.parent])
except NameError:
    pass
cwd_path = pathlib.Path.cwd().resolve()
# Working directory plus its two nearest ancestors.
callbacks = [cwd_path, cwd_path.parent, cwd_path.parent.parent]
candidates.extend(callbacks)
# Machine-specific last resort; also the fallback when no candidate contains src/.
explicit_root = pathlib.Path("/home/rbarbano/home/git/multi-agent-diffusion-working-repo")
candidates.append(explicit_root)
# Require a directory: a stray plain file named "src" must not win the search.
repo_root = next((p for p in candidates if (p / "src").is_dir()), explicit_root)
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# find_spec returns None when the top-level package cannot be located.
if importlib.util.find_spec("src") is None:
    raise ImportError(f"Could not import 'src'. repo_root={repo_root}, sys.path[0]={sys.path[0]}")

from src.envs.aggregator import ImageMaskAggregator
from src.envs.registry import get_optimality_criterion
from src.models.registry import get_model_by_name
from src.samplers.diff_dyms import SDE
from src.samplers.samplers import euler_maruyama_controlled_sampler
from workflows.learning_agents_bptt_fictitious import _load_state


def load_controls_yaml(yaml_path: str):
    """Load the ``controls`` section of an experiment-paths YAML file.

    Args:
        yaml_path: Path of the YAML file to read.

    Returns:
        The value stored under the top-level ``controls`` key. ``iter_runs``
        walks it as agent key -> scheme -> digit key -> path.

    Raises:
        KeyError: If the document is empty, is not a mapping, or has no
            ``controls`` key.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(yaml_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # safe_load yields None for an empty document; report that (and a missing
    # key) with a message that names the offending file.
    if not isinstance(cfg, dict) or "controls" not in cfg:
        raise KeyError(f"No top-level 'controls' section found in {yaml_path!r}")
    return cfg["controls"]

def iter_runs(controls_cfg):
    """Flatten the nested controls mapping into one record per run.

    ``controls_cfg`` is walked as ``agent key -> scheme -> digit key -> path``.
    The integer after the last underscore of the agent key (e.g. ``"agent_2"``)
    and of the digit key becomes ``num_agents`` and ``digit`` respectively.

    Yields:
        dict with keys ``agent_key``, ``num_agents``, ``scheme``, ``digit``
        and ``path_to_controls``.
    """
    def _suffix_as_int(key):
        # Text after the last "_" (the whole key when there is none) as an int.
        return int(str(key).rpartition("_")[2])

    for agent_key, schemes in controls_cfg.items():
        agent_count = _suffix_as_int(agent_key)
        for scheme, digits in schemes.items():
            for digit_key, controls_path in digits.items():
                yield dict(
                    agent_key=agent_key,
                    num_agents=agent_count,
                    scheme=scheme,
                    digit=_suffix_as_int(digit_key),
                    path_to_controls=controls_path,
                )

# Default checkpoint locations (machine-specific). Pass `score_ckpt` /
# `classifier_ckpt` to load_cmad_run_for_eval to use other files.
_DEFAULT_SCORE_CKPT = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/checkpoints/vp/latest.ckpt"
_DEFAULT_CLASSIFIER_CKPT = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/checkpoints/cnet.pt"


def load_cmad_run_for_eval(
    num_agents: int,
    path_to_controls: str,
    *,
    score_ckpt: str = _DEFAULT_SCORE_CKPT,
    classifier_ckpt: str = _DEFAULT_CLASSIFIER_CKPT,
    device=None,
):
    """Build the models and helpers needed to sample from and score one run.

    Args:
        num_agents: Number of control models to load; the files
            ``agent_0.pt`` .. ``agent_{num_agents - 1}.pt`` are read from
            ``path_to_controls``.
        path_to_controls: Directory holding the per-agent weights. It is
            resolved with hydra's ``to_absolute_path``.
        score_ckpt: Checkpoint loaded into the score model.
        classifier_ckpt: Checkpoint loaded into the classifier.
        device: Device passed to every component. ``None`` selects ``"cuda"``
            when available, else ``"cpu"``.

    Returns:
        dict with keys ``device``, ``sde``, ``score_model``, ``control_agents``
        (dict mapping agent index to model), ``aggregator``, ``classifier``
        and ``optimality_criterion``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    sde = SDE(mode="VP", device=device)

    # score model: put in eval mode and frozen before its weights are loaded
    score_model_cfg = {
        "in_channels": 1,
        "out_channels": 1,
        "model_channels": 64,
        "channel_mult": [1, 4, 4],
        "num_res_blocks": 2,
        "attention_resolutions": [16],
        "max_period": 0.005,
    }
    score_model = get_model_by_name(
        "unet",
        marginal_prob_std=sde.marginal_prob_std,
        **score_model_cfg
    ).to(device)
    score_model.eval()
    score_model.requires_grad_(False)
    _load_state(score_model, score_ckpt, device)

    # classifier
    classifier_cfg = {
        "img_size": [28, 28],
        "num_classes": 10,
        "num_hidden_layers": 2,
    }
    classifier = get_model_by_name("cnet", **classifier_cfg).to(device)
    classifier.eval()
    _load_state(classifier, classifier_ckpt, device)

    # controls: one model per agent, all built from the same configuration
    control_cfg = {
        "in_channels": 3,
        "out_channels": 1,
        "model_channels": 32,
        "channel_mult": [1, 2],
        "num_res_blocks": 1,
        "attention_resolutions": [],
        "max_period": 0.005,
    }
    weights_root = pathlib.Path(to_absolute_path(path_to_controls))
    control_agents = {}
    for i in range(num_agents):
        m = get_model_by_name("cond_unet", **control_cfg).to(device)
        m.eval()
        _load_state(m, str(weights_root / f"agent_{i}.pt"), device)
        control_agents[i] = m

    aggregator = ImageMaskAggregator(
        mask_name="split",
        img_dims=(1, 28, 28),
        num_processes=num_agents,
        overlap_size=0,
        use_overlap=False,
        device=device,
    )

    optimality_criterion = get_optimality_criterion(
        name="classifier_ce_with_cooperation",
        classifier=classifier,
        aggregator=aggregator,
    ).to(device)

    return dict(
        device=device,
        sde=sde,
        score_model=score_model,
        control_agents=control_agents,
        aggregator=aggregator,
        classifier=classifier,
        optimality_criterion=optimality_criterion,
    )


@torch.no_grad()
def classification_accuracy(classifier, x: torch.Tensor, target_digit: int) -> float:
    """Return the fraction of the batch that ``classifier`` labels as ``target_digit``.

    The predicted label is the argmax over dimension 1 of the classifier
    output. Runs with gradient tracking disabled.
    """
    # x is expected to be a batch of images, noted as (B,1,28,28) by the author.
    predicted = classifier(x).argmax(dim=1)
    expected = torch.full_like(predicted, target_digit)
    is_hit = predicted == expected
    return is_hit.float().mean().item()

def mean_optimality(optimality_criterion, x: torch.Tensor, target_digit: int) -> float:
    """Return the batch-mean terminal value reported by ``optimality_criterion``.

    A long target tensor filled with ``target_digit`` (one entry per row of
    ``x``, on ``x``'s device) is passed together with ``x`` to the first
    method the criterion provides, tried in this order:
    ``get_terminal_cost``, then ``get_terminal_state_loss``.

    Returns:
        The mean of the result when it has at least one dimension, otherwise
        the scalar itself, as a Python float.

    Raises:
        NotImplementedError: If the criterion provides neither method.
    """
    batch_targets = torch.full((x.shape[0],), target_digit, device=x.device, dtype=torch.long)

    for candidate in ("get_terminal_cost", "get_terminal_state_loss"):
        if hasattr(optimality_criterion, candidate):
            value = getattr(optimality_criterion, candidate)(x, batch_targets)
            if value.ndim:
                # per-sample values: reduce to the batch mean
                value = value.mean()
            return float(value.item())

    raise NotImplementedError("Optimality criterion does not have a known method to compute terminal cost.")


def evaluate_yaml_runs(yaml_path: str, total_samples: int = 1024, batch_size: int = 64, num_steps: int = 500):
    """Sample from every run listed in a YAML file and score the samples.

    For each run yielded by ``iter_runs`` the models are loaded with
    ``load_cmad_run_for_eval``, ``total_samples`` samples are drawn in batches
    with ``euler_maruyama_controlled_sampler``, and each batch is scored with
    ``classification_accuracy`` and ``mean_optimality``.

    Side effects: saves the samples of each run to
    ``<agent_key>__<scheme>__digit_<digit>_samples.pt`` in the working
    directory and prints one line per run.

    Args:
        yaml_path: YAML file read by ``load_controls_yaml``.
        total_samples: Number of samples to draw per run; must be positive.
        batch_size: Maximum number of samples per sampler call; must be positive.
        num_steps: Forwarded to the sampler as ``num_steps``.

    Returns:
        List with one dict per run, with keys ``agent_key``, ``scheme``,
        ``digit``, ``path``, ``n``, ``mean_acc`` and ``mean_optimality``.
        Both means are weighted by batch size.

    Raises:
        ValueError: If ``total_samples`` or ``batch_size`` is not positive.
    """
    # Fail fast: non-positive sizes would otherwise surface only after the
    # models are loaded, as a division by zero or an empty torch.cat.
    if total_samples <= 0:
        raise ValueError(f"total_samples must be positive, got {total_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    controls_cfg = load_controls_yaml(yaml_path)
    n_batches = math.ceil(total_samples / batch_size)
    results = []

    for run in iter_runs(controls_cfg):
        num_agents = run["num_agents"]
        digit = run["digit"]
        path_to_controls = run["path_to_controls"]

        # Everything (score model and classifier included) is rebuilt for each run.
        run_env = load_cmad_run_for_eval(num_agents=num_agents, path_to_controls=path_to_controls)
        device = run_env["device"]

        acc_sum = 0.0
        opt_sum = 0.0
        n_seen = 0
        all_samples = []
        for _ in tqdm(
            range(n_batches),
            desc=f"{run['agent_key']} | {run['scheme']} | digit {digit}",
            leave=False,
        ):
            # Only the last batch can be smaller than batch_size.
            cur_bs = min(batch_size, total_samples - n_seen)

            samples = euler_maruyama_controlled_sampler(
                score_model=run_env["score_model"],
                control_agents=run_env["control_agents"],
                aggregator=run_env["aggregator"],
                sde=run_env["sde"],
                image_dim=(1, 28, 28),
                batch_size=cur_bs,
                num_steps=num_steps,
                device=device,
                debug=False,
                optimality_criterion=run_env["optimality_criterion"],
                optimality_target=digit,
                enable_optimality_loss_on_processes=False,
            )

            # sampler returns final aggregated samples when debug=False.
            # Detach so the stored batches do not keep any autograd graph alive.
            x = samples.detach()
            all_samples.append(x.cpu())

            acc = classification_accuracy(run_env["classifier"], x, digit)
            opt = mean_optimality(run_env["optimality_criterion"], x, digit)

            acc_sum += acc * cur_bs
            opt_sum += opt * cur_bs
            n_seen += cur_bs

        results.append({
            "agent_key": run["agent_key"],
            "scheme": run["scheme"],
            "digit": digit,
            "path": path_to_controls,
            "n": n_seen,
            "mean_acc": acc_sum / n_seen,
            "mean_optimality": opt_sum / n_seen,
        })
        base_name = f"{run['agent_key']}__{run['scheme']}__digit_{digit}"
        # Concatenate along the batch dimension; the slice caps the count at total_samples.
        stacked = torch.cat(all_samples, dim=0)[:total_samples]
        torch.save(stacked, f"{base_name}_samples.pt")

        print(f"[{run['agent_key']}/{run['scheme']}/digit={digit}] "
              f"acc={results[-1]['mean_acc']:.4f}  opt={results[-1]['mean_optimality']:.4f}")

    return results

# Evaluate every run listed in the experiment-paths YAML
# (evaluate_yaml_runs also writes one samples file per run).
yaml_path = "paths_to_exps.yaml"
results = evaluate_yaml_runs(
    yaml_path,
    total_samples=1024,
    batch_size=64,
    num_steps=500,
)

# One aligned line per evaluated run.
print("\nSummary:")
for r in results:
    summary_line = (
        f"{r['agent_key']:>7} | {r['scheme']:<10} | digit {r['digit']} | "
        f"acc {r['mean_acc']:.4f} | opt {r['mean_optimality']:.4f}"
    )
    print(summary_line)

In [None]:
import math
import numpy as np
import torch
from tqdm.auto import tqdm

from src.envs.aggregator import ImageMaskAggregator
from src.envs.registry import get_optimality_criterion
from src.models.registry import get_model_by_name
from src.samplers.diff_dyms import SDE
from src.samplers.samplers import euler_maruyama_dps_sampler


# --- CDPS experiment settings -------------------------------------------------
# Use the GPU when torch can see one, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoints for the score model and the digit classifier (machine-specific paths).
SCORE_CKPT = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/checkpoints/vp/latest.ckpt"
CLF_CKPT = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/checkpoints/cnet.pt"

# Sampling budget per (agents, digit) combination.
NUM_STEPS = 500
TOTAL_SAMPLES = 1024
BATCH_SIZE = 64
GUIDANCE_SCALE = 100.0

# Combinations to sweep: target digits and agent counts.
DIGITS = [0, 3, 9]
AGENTS = [2, 3]

# Also export samples as .npy; set False if you only want .pt
SAVE_NPY = True


# --- Shared models (built once, reused for every combination) -----------------
sde = SDE(mode="VP", device=DEVICE)

score_model_cfg = {
    "name": "unet",
    "in_channels": 1,
    "out_channels": 1,
    "model_channels": 64,
    "channel_mult": [1, 4, 4],
    "num_res_blocks": 2,
    "attention_resolutions": [16],
    "max_period": 0.005,
}
# "name" selects the model in the registry; the remaining keys are passed as kwargs.
score_model_name = score_model_cfg.pop("name")
score_model = get_model_by_name(
    score_model_name,
    marginal_prob_std=sde.marginal_prob_std,
    **score_model_cfg,
).to(DEVICE)
# Eval mode with parameters frozen, then load the checkpoint.
score_model.eval()
score_model.requires_grad_(False)
_load_state(score_model, SCORE_CKPT, DEVICE)

classifier_cfg = {
    "name": "cnet",
    "img_size": [28, 28],
    "num_classes": 10,
    "num_hidden_layers": 2,
}
classifier_name = classifier_cfg.pop("name")
classifier = get_model_by_name(classifier_name, **classifier_cfg).to(DEVICE)
classifier.eval()
_load_state(classifier, CLF_CKPT, DEVICE)


@torch.no_grad()
def acc(classifier, x: torch.Tensor, digit: int) -> float:
    """Return the share of rows in ``x`` whose argmax class (dim 1) equals ``digit``.

    Evaluated with gradient tracking disabled; the result is a Python float.
    """
    class_scores = classifier(x)
    predicted = class_scores.argmax(dim=1)
    wanted = torch.full_like(predicted, digit)
    hit_mask = (predicted == wanted).float()
    return hit_mask.mean().item()


@torch.no_grad()
def terminal_loss(opt_crit, x: torch.Tensor, digit: int) -> float:
    """Return the batch-mean terminal value reported by ``opt_crit`` for ``digit``.

    A long target tensor filled with ``digit`` (one entry per row of ``x``, on
    ``x``'s device) is passed with ``x`` to the first available method, tried
    in this order: ``get_terminal_state_loss``, then ``get_terminal_cost``.
    A result with at least one dimension is averaged; a 0-d result is used as is.

    Raises:
        NotImplementedError: If ``opt_crit`` provides neither method.
    """
    target = torch.full((x.shape[0],), digit, device=x.device, dtype=torch.long)

    for method_name in ("get_terminal_state_loss", "get_terminal_cost"):
        if not hasattr(opt_crit, method_name):
            continue
        raw = getattr(opt_crit, method_name)(x, target)
        reduced = raw.mean() if raw.ndim else raw
        return float(reduced.item())

    raise NotImplementedError("Optimality criterion has no get_terminal_state_loss/get_terminal_cost.")


def sample_cdps_for_combo(num_agents: int, digit: int):
    """Draw TOTAL_SAMPLES samples with the DPS sampler for one (agents, digit) pair.

    Relies on the module-level settings (``DEVICE``, ``TOTAL_SAMPLES``,
    ``BATCH_SIZE``, ``NUM_STEPS``, ``GUIDANCE_SCALE``) and on the shared
    ``sde``, ``score_model`` and ``classifier`` objects. Every agent index is
    given the same ``score_model``.

    Returns:
        Tuple ``(samples, mean_acc, mean_loss)``: the concatenated samples on
        the CPU, and the batch-size-weighted means of ``acc`` and
        ``terminal_loss`` over all batches.
    """
    aggregator = ImageMaskAggregator(
        mask_name="split",
        img_dims=(1, 28, 28),
        num_processes=num_agents,
        overlap_size=0,
        use_overlap=False,
        device=DEVICE,
    )
    opt_crit = get_optimality_criterion(
        name="classifier_ce_with_cooperation",
        classifier=classifier,
        aggregator=aggregator,
    ).to(DEVICE)

    collected = []
    weighted_acc = 0.0
    weighted_loss = 0.0
    drawn = 0

    progress = tqdm(
        range(math.ceil(TOTAL_SAMPLES / BATCH_SIZE)),
        desc=f"CDPS | agents={num_agents} | digit={digit}",
        leave=False,
    )
    for _ in progress:
        # The final batch is truncated so the total never exceeds TOTAL_SAMPLES.
        remaining = TOTAL_SAMPLES - drawn
        this_batch = min(BATCH_SIZE, remaining)
        if this_batch <= 0:
            break

        batch = euler_maruyama_dps_sampler(
            score_models={k: score_model for k in range(num_agents)},
            aggregator=aggregator,
            sde=sde,
            optimality_loss=opt_crit,
            target=digit,
            image_dim=(1, 28, 28),
            batch_size=this_batch,
            num_steps=NUM_STEPS,
            device=DEVICE,
            eps=1e-3,
            guidance_scale=GUIDANCE_SCALE,
            debug=False,
        )

        # Drop any autograd history before storing and scoring the batch.
        batch = batch.detach()
        collected.append(batch.cpu())

        batch_acc = acc(classifier, batch, digit)
        batch_loss = terminal_loss(opt_crit, batch, digit)

        weighted_acc += batch_acc * this_batch
        weighted_loss += batch_loss * this_batch
        drawn += this_batch

    # Concatenate along the batch dimension, capped at TOTAL_SAMPLES.
    samples = torch.cat(collected, dim=0)[:TOTAL_SAMPLES]
    return samples, weighted_acc / drawn, weighted_loss / drawn

# Sweep every (agent count, digit) pair: sample, save to disk, record metrics.
cdps_results = []

for n_agents in AGENTS:
    for d in DIGITS:
        samples, mean_acc, mean_loss = sample_cdps_for_combo(n_agents, d)

        base = f"cdps__agents_{n_agents}__digit_{d}"
        pt_name = f"{base}__samples.pt"
        # The .npy copy is optional; its name is recorded as None when skipped.
        npy_name = f"{base}__samples.npy" if SAVE_NPY else None

        torch.save(samples, pt_name)
        if SAVE_NPY:
            np.save(npy_name, samples.numpy())

        record = dict(
            method="CDPS",
            num_agents=n_agents,
            digit=d,
            mean_acc=mean_acc,
            mean_terminal_loss=mean_loss,
            samples_pt=pt_name,
            samples_npy=npy_name,
        )
        cdps_results.append(record)

        print(f"[CDPS] agents={n_agents} digit={d} | acc={mean_acc:.4f} loss={mean_loss:.4f} | saved {base}__samples.*")

# Bare expression: shown as the cell output in a notebook.
cdps_results

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from pathlib import Path

# Folder holding the saved sample tensors ("." is the current working directory).
DATA_DIR = Path(".")

# Every "*_samples.pt" file directly inside DATA_DIR, in sorted path order.
sample_files = list(DATA_DIR.glob("*_samples.pt"))
sample_files.sort()

def show_grid(x, title="", nrow=8):
    """Tile the image batch ``x`` into a single grid and display it.

    Args:
        x: Batch of images handed to torchvision's ``make_grid``.
        title: Text shown above the figure.
        nrow: Forwarded to ``make_grid`` as ``nrow``.
    """
    tiled = make_grid(x, nrow=nrow, normalize=True, value_range=(0.0, 1.0))
    # Keep only channel 0 of the grid so it can be drawn as a 2-D grayscale image.
    first_channel = tiled[0]

    plt.figure(figsize=(6, 6))
    plt.imshow(first_channel, cmap="gray")
    plt.title(title)
    plt.axis("off")
    plt.show()

# Show the first 64 samples of every saved file as one grid per file.
for f in sample_files:
    samples = torch.load(f, map_location="cpu")  # (N,1,28,28)
    samples64 = samples[:64]

    # Saved files end in either "_samples" or "__samples" depending on the
    # cell that wrote them. Removing "_samples" from the latter leaves a
    # dangling underscore, so strip trailing underscores as well.
    show_grid(
        samples64,
        title=f.stem.replace("_samples", "").rstrip("_")
    )