## Plotting Digits

In [None]:
# --- Jupyter Starter Pack ---
# Notebook-wide setup: auto-reload edited modules, silence warnings, and set
# numpy / matplotlib display defaults.

# autoreload: re-import modified modules before every cell run
%reload_ext autoreload
%autoreload 2

# silence ALL Python warnings (note: this also hides warnings raised by libraries)
import warnings
warnings.filterwarnings("ignore")

# nicer printing of nested structures
from pprint import pprint

# numpy display: 4 decimals, small floats printed in fixed-point notation
import numpy as np
np.set_printoptions(precision=4, suppress=True)

# matplotlib defaults: 6x4 inch figures at 120 dpi
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6, 4)
plt.rcParams["figure.dpi"] = 120

# tqdm progress bars rendered as notebook widgets
from tqdm.notebook import tqdm

# render inline figures at high ("retina", 2x) resolution
%config InlineBackend.figure_format = "retina"

In [None]:
# Load the score model, the control agents and the classifier
import sys
import pathlib
import importlib.util
import torch
from hydra.utils import to_absolute_path
from typing import Dict
import copy

# Locate the repository root (the first candidate directory containing `src/`)
# and put it on sys.path so `import src...` works from inside the notebook.
candidates = []
try:
    this_path = pathlib.Path(__file__).resolve()
except NameError:
    # Interactive kernels do not define __file__; rely on cwd-based guesses.
    pass
else:
    candidates += [this_path.parent, this_path.parent.parent]

cwd_path = pathlib.Path.cwd().resolve()
# cwd plus up to two ancestor directories
callbacks = [cwd_path, cwd_path.parent, cwd_path.parent.parent]
candidates += callbacks

# Machine-specific checkout used as the last candidate and as the fallback.
explicit_root = pathlib.Path("/home/rbarbano/home/git/multi-agent-diffusion-working-repo")
candidates.append(explicit_root)

repo_root = next(
    filter(lambda candidate: (candidate / "src").exists(), candidates),
    explicit_root,
)
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

if importlib.util.find_spec("src") is None:
    raise ImportError(f"Could not import 'src'. repo_root={repo_root}, sys.path[0]={sys.path[0]}")

from src.models.registry import get_model_by_name
from src.samplers.diff_dyms import SDE
from src.envs.aggregator import ImageMaskAggregator
from src.samplers.samplers import euler_maruyama_controlled_sampler


def _load_state(
    module: torch.nn.Module,
    checkpoint_path: str,
    device: torch.device,
    *,
    strict: bool = False,
):
    """Load weights from ``checkpoint_path`` into ``module`` in place.

    The checkpoint may be a bare state dict, or a dict that nests one under
    ``"model_state"`` or ``"state_dict"`` (checked in that order). Otherwise
    the whole checkpoint is used as the state dict.

    Args:
        module: Model whose parameters/buffers are overwritten.
        checkpoint_path: Checkpoint file. Relative paths are resolved with
            Hydra's ``to_absolute_path``.
        device: Used as ``map_location`` so tensors are loaded onto it.
        strict: Forwarded to ``load_state_dict``. Defaults to ``False``
            (lenient), as before.

    Returns:
        The ``(missing_keys, unexpected_keys)`` result of ``load_state_dict``.
    """
    checkpoint = torch.load(to_absolute_path(checkpoint_path), map_location=device)

    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for key in ("model_state", "state_dict"):
            if checkpoint.get(key) is not None:
                state_dict = checkpoint[key]
                break

    result = module.load_state_dict(state_dict, strict=strict)
    # With strict=False, mismatched keys are skipped silently, which can leave
    # the model (partly) at its random initialisation. Report it with print so
    # it stays visible even when warnings are filtered.
    if result.missing_keys or result.unexpected_keys:
        print(
            f"[_load_state] {checkpoint_path}: "
            f"{len(result.missing_keys)} missing / {len(result.unexpected_keys)} unexpected keys "
            f"(e.g. missing={result.missing_keys[:3]}, unexpected={result.unexpected_keys[:3]})"
        )
    return result


def load_cmad_run(
    num_agents: int,
    path_to_controls: str,
    *,
    score_ckpt: str = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/checkpoints/vp/latest.ckpt",
    classifier_ckpt: str = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/checkpoints/cnet.pt",
    batch_size: int = 64,
    num_steps: int = 500,
):
    """Load the frozen models and trained control agents, then run the controlled sampler.

    Builds a VP ``SDE``, the score U-Net and the classifier (both in eval mode
    with gradients disabled), and ``num_agents`` control U-Nets. It then draws
    one batch with ``euler_maruyama_controlled_sampler``. Only the controlled
    sampler is run.

    Args:
        num_agents: Number of control agents. Agent ``i`` is loaded from
            ``<path_to_controls>/agent_{i}.pt``. Also passed to the
            ``"split"`` aggregator as ``num_processes``.
        path_to_controls: Directory holding the per-agent weight files
            (``str`` or ``pathlib.Path``).
        score_ckpt: Score-model checkpoint (default: the original location).
        classifier_ckpt: Classifier checkpoint (default: the original location).
        batch_size: Number of samples drawn by the sampler.
        num_steps: Number of sampler steps.

    Returns:
        Dict with keys ``"device"``, ``"controlled_samples"``,
        ``"controlled_info"``, ``"aggregator"``, ``"sde"`` and ``"classifier"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sde = SDE(mode="VP", device=device)

    # Frozen score model.
    score_model = get_model_by_name(
        "unet",
        marginal_prob_std=sde.marginal_prob_std,
        in_channels=1,
        out_channels=1,
        model_channels=64,
        channel_mult=[1, 4, 4],
        num_res_blocks=2,
        attention_resolutions=[16],
        max_period=0.005,
    ).to(device)
    score_model.eval()
    score_model.requires_grad_(False)
    _load_state(score_model, score_ckpt, device)

    # Classifier: not used by the sampler; returned so callers can score samples.
    classifier = get_model_by_name(
        "cnet",
        img_size=[28, 28],
        num_classes=10,
        num_hidden_layers=2,
    ).to(device)
    classifier.eval()
    classifier.requires_grad_(False)
    _load_state(classifier, classifier_ckpt, device)

    # Trained control agents: build each U-Net and load its weights directly.
    control_cfg = {
        "in_channels": 2,
        "out_channels": 1,
        "model_channels": 32,
        "channel_mult": [1, 2],
        "num_res_blocks": 1,
        "attention_resolutions": [],
        "max_period": 0.005,
    }
    weights_root = pathlib.Path(to_absolute_path(path_to_controls))
    control_agents = {}
    for i in range(num_agents):
        agent = get_model_by_name(
            "unet",
            marginal_prob_std=sde.marginal_prob_std,
            **control_cfg,
        ).to(device)
        agent.eval()
        _load_state(agent, str(weights_root / f"agent_{i}.pt"), device)
        control_agents[i] = agent

    aggregator = ImageMaskAggregator(
        mask_name="split",
        img_dims=(1, 28, 28),
        num_processes=num_agents,
        overlap_size=0,
        use_overlap=False,
        device=device,
    )

    controlled_samples, controlled_info = euler_maruyama_controlled_sampler(
        score_model=score_model,
        control_agents=control_agents,
        aggregator=aggregator,
        sde=sde,
        image_dim=(1, 28, 28),
        batch_size=batch_size,
        num_steps=num_steps,
        device=device,
        debug=True,
        save_debug_info=False,
    )

    return {
        "device": device,
        "controlled_samples": controlled_samples,
        "controlled_info": controlled_info,
        "aggregator": aggregator,
        "sde": sde,
        "classifier": classifier,
    }


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_samples_grid_small(samples_np, name_fig, *, nrows: int = 8, ncols: int = 8):
    """Plot up to ``nrows * ncols`` images as a borderless grid, save it and show it.

    Args:
        samples_np: Array-like of images indexed along axis 0. Each image may be
            2-D ``(H, W)``, channel-first ``(1, H, W)`` / ``(3, H, W)``, or
            channel-last ``(H, W, 1)`` / ``(H, W, 3)``. Extra images beyond
            the grid size are not drawn.
        name_fig: Output path without extension; ``.png`` is appended.
        nrows: Grid rows (default 8).
        ncols: Grid columns (default 8).

    Returns:
        The path the PNG was written to (300 dpi).
    """
    samples_np = np.asarray(samples_np)
    num_images = samples_np.shape[0]

    fig, axes = plt.subplots(nrows, ncols, figsize=(0.6 * ncols, 0.6 * nrows))
    axes = np.atleast_2d(axes)

    for i, ax in enumerate(axes.flat):
        if i < num_images:
            img = samples_np[i]
            if img.ndim == 3:
                if img.shape[0] == 1:
                    img = img[0]
                elif img.shape[0] == 3:
                    img = np.transpose(img, (1, 2, 0))
                elif img.shape[-1] == 1:
                    # imshow rejects (H, W, 1); drop the singleton channel.
                    img = img[..., 0]
            ax.imshow(
                img,
                cmap="gray" if img.ndim == 2 else None,
                vmin=0.0,
                vmax=1.0,
            )
        ax.axis("off")

    # Make the axes fill the whole figure with no gaps; this is what removes
    # all padding from the saved image.
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)

    save_name = f"{name_fig}.png"
    # Note: pad_inches only takes effect together with bbox_inches="tight".
    fig.savefig(save_name, dpi=300, pad_inches=0)
    plt.show()
    # Release the figure so repeated calls (e.g. one per digit) do not
    # accumulate open figures in non-inline backends.
    plt.close(fig)
    return save_name

In [None]:
from pathlib import Path
import json

def load_runs_json(path: str | Path) -> dict:
    path = Path(path)
    return json.loads(path.read_text())

def _run_digit(*, digit: str, run_dir: str, config_name: str, num_agents: int) -> dict:
    """Load one trained run, save its sample grid and return the run's outputs.

    Weights are read from ``<run_dir>/weights``. The grid is saved as
    ``c_map_A{num_agents}_digit{digit}.png`` (relative to the cwd).
    """
    path_to_controls = Path(run_dir) / "weights"
    print(f"Loading {num_agents} control agents for digit {digit} from {path_to_controls}")
    run_data = load_cmad_run(num_agents, path_to_controls)

    controlled_info = run_data["controlled_info"]
    per_agent = controlled_info["per_agent"]
    num_agents_local = len(per_agent)
    controlled_samples_np = run_data["controlled_samples"].detach().cpu().numpy()
    # Last entry of each agent's "per_agent" record (presumably its final
    # state -- confirm against the sampler).
    controlled_states = [per_agent[i][-1] for i in range(num_agents_local)]

    plot_samples_grid_small(controlled_samples_np, f"c_map_A{num_agents}_digit{digit}")

    return {
        "digit": digit,
        "run_dir": run_dir,
        "device": run_data["device"],
        "num_agents": num_agents_local,
        "controlled_samples_np": controlled_samples_np,
        "controlled_states": controlled_states,
        "aggregator": run_data["aggregator"],
        "sde": run_data["sde"],
        "config_name": config_name,
    }


def make_digit_runners(runs_json: dict, *, num_agents: int = 2):
    """Build one zero-argument runner per digit.

    Expects the structure ``runs_json["runs"][digit][agents][config_name][i]["run_dir"]``.
    For each digit, only the first config entry under ``str(num_agents)`` and
    its first run are used. Digits with no entry for ``num_agents``, or with an
    empty config map or run list, are skipped.

    Returns:
        ``dict[str(digit)] -> callable()``. Calling a runner loads that run,
        saves its grid figure and returns the run's outputs.
    """
    from functools import partial

    agents_key = str(num_agents)
    runners = {}
    for digit, agents_map in runs_json["runs"].items():
        cfg_map = agents_map.get(agents_key)
        if not cfg_map:
            continue
        cfg_name, run_list = next(iter(cfg_map.items()))
        if not run_list:
            continue
        # partial binds every per-digit value now; a closure would late-bind
        # loop variables and report the last digit's config for every runner.
        runners[str(digit)] = partial(
            _run_digit,
            digit=digit,
            run_dir=run_list[0]["run_dir"],
            config_name=cfg_name,
            num_agents=num_agents,
        )
    return runners


data = load_runs_json("matched_runs.json")
digit_runners = make_digit_runners(data, num_agents=2)

# Execute every digit's runner in ascending numeric order of the digit key.
for digit in sorted(digit_runners, key=lambda key: int(key)):
    print(f"Running digit {digit}...")
    runner = digit_runners[digit]
    runner()