In [None]:
import os
import re
import glob
import torch
import imageio
import matplotlib.pyplot as plt
from trainers.torch.networks import SplitValueSharedActorCritic
from mlagents_envs.base_env import ObservationSpec, DimensionProperty, ObservationType, ActionSpec, BehaviorSpec
from mlagents.trainers.settings import NetworkSettings, TrainerSettings
from trainers.policy.torch_policy import TorchPolicy
import matplotlib.image as mpimg
import numpy as np

def natural_sort_key(path):
    """
    Build a sort key from the integers embedded in a file name.

    Only the base name of `path` is inspected (directories are ignored).
    Every run of decimal digits is converted to an int, in order of
    appearance. Names without any digits get the key [inf], so they sort
    after every numbered name.
    """
    filename = os.path.basename(path)
    digit_runs = re.findall(r'\d+', filename)
    if not digit_runs:
        # No number in the name: push it to the end of the ordering.
        return [float('inf')]
    return [int(run) for run in digit_runs]

def build_policy(seed: int,
                 behavior_spec: BehaviorSpec,
                 trainer_settings: TrainerSettings,
                 network_settings: NetworkSettings,
                 tanh_squash: bool,
                 separate_critic: bool,
                 condition_sigma_on_obs: bool,
                 load_critic_only: str):
    """
    Construct and return a TorchPolicy from the given settings.

    All arguments except `network_settings` are forwarded to TorchPolicy as
    keyword arguments of the same name.

    NOTE(review): `network_settings` is accepted but never passed on to
    TorchPolicy here -- presumably the policy takes its network configuration
    from `trainer_settings`; confirm against TorchPolicy's constructor.
    """
    policy_kwargs = dict(
        seed=seed,
        behavior_spec=behavior_spec,
        trainer_settings=trainer_settings,
        tanh_squash=tanh_squash,
        separate_critic=separate_critic,
        condition_sigma_on_obs=condition_sigma_on_obs,
        load_critic_only=load_critic_only,
    )
    return TorchPolicy(**policy_kwargs)

def make_grid_positions(xmin, xmax, ymin, ymax, step=1, obs_dim=12):
    """
    Build one observation row per integer grid cell.

    The grid covers x in range(xmin, xmax + 1, step) and
    y in range(ymin, ymax + 1, step); rows are emitted y-major (all x values
    for the first y, then the next y, ...), so the result can be reshaped to
    (num_y, num_x).

    Each row is [x - 0.5, 0.5, y - 0.5] followed by zeros up to `obs_dim`
    entries.

    xmin, xmax : inclusive integer bounds along x
    ymin, ymax : inclusive integer bounds along y
    step       : integer stride between grid cells (passed to range())
    obs_dim    : total width of each observation row (>= 3); the default of
                 12 reproduces the original 3 coordinates + 9 zeros layout

    Returns a list containing a single float32 tensor of shape [N, obs_dim],
    where N = len(range(xmin, xmax + 1, step)) * len(range(ymin, ymax + 1, step)).
    An empty grid yields a tensor of shape [0, obs_dim].

    Raises ValueError if `obs_dim` is smaller than 3 (or, via range(), if
    `step` is 0).
    """
    if obs_dim < 3:
        raise ValueError(f"obs_dim must be at least 3, got {obs_dim}")

    zero_padding = [0.0] * (obs_dim - 3)
    positions = [
        [x - 0.5, 0.5, y - 0.5] + zero_padding
        for y in range(ymin, ymax + 1, step)
        for x in range(xmin, xmax + 1, step)
    ]
    # Explicit reshape keeps the [N, obs_dim] contract even when N == 0
    # (torch.tensor([]) alone would produce a 1-D tensor of shape [0]).
    tensor = torch.tensor(positions, dtype=torch.float32).reshape(len(positions), obs_dim)
    return [tensor]

def plot_field(field,
               out_path,
               capture_path,
               cmap='viridis',
               alpha=0.4,
               vmin=None,
               vmax=None,
               dpi=150):
    """
    Overlay a translucent heatmap of `field` on a background image and save
    the result as an image file.

    field         : 2D numpy array of your grid values (rows x cols)
    out_path      : where to save the overlayed png
    capture_path  : path to your background Capture.PNG
    cmap          : matplotlib colormap
    alpha         : heatmap transparency (0..1)
    vmin,vmax     : manual color-scale limits (optional); when omitted the
                    NaN-ignoring min/max of `field` is used
    dpi           : resolution for saved PNG

    The figure is always closed, even if drawing or saving fails, so calling
    this in a loop does not accumulate open figures.
    """
    # 1) load background
    background = mpimg.imread(capture_path)
    rows, cols = field.shape

    # Resolve the colour limits up front; an explicit 0 is a valid limit,
    # hence the `is None` checks rather than truthiness.
    color_min = np.nanmin(field) if vmin is None else vmin
    color_max = np.nanmax(field) if vmax is None else vmax

    # 2) prep figure so data-coords = grid coords (half an inch per cell)
    fig, ax = plt.subplots(figsize=(cols / 2, rows / 2))
    try:
        # Background and heatmap share one extent so they cover the same box.
        # NOTE(review): extent is (left, right, bottom, top); reversed bounds
        # change the direction of the axis coordinates. The vertical flip of
        # the data itself comes from np.flipud below -- confirm the intended
        # orientation against the background image.
        extent = [cols, 0, rows, 0]

        # 3) draw background (zorder=0)
        ax.imshow(background,
                  extent=extent,
                  aspect='auto',
                  zorder=0)

        # 4) overlay translucent heatmap (zorder=1), flipped top-to-bottom
        ax.imshow(np.flipud(field),
                  cmap=cmap,
                  alpha=alpha,
                  extent=extent,
                  interpolation='nearest',
                  aspect='auto',
                  vmin=color_min,
                  vmax=color_max,
                  zorder=1)

        # 5) strip axes/margins (on this figure, not pyplot's "current" one)
        ax.axis('off')
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

        # 6) save
        fig.savefig(out_path,
                    dpi=dpi,
                    bbox_inches='tight',
                    pad_inches=0)
    finally:
        # Release the figure regardless of success to avoid leaking memory.
        plt.close(fig)

def make_gif_from_images(image_paths, out_gif, fps=2):
    """
    Read every image in `image_paths` (in the given order) and write them
    as the frames of a single animation at `out_gif`.

    image_paths : iterable of image file paths, one per frame
    out_gif     : destination path of the animation
    fps         : frames per second passed to imageio.mimsave
    """
    frames = [imageio.imread(frame_file) for frame_file in image_paths]
    imageio.mimsave(out_gif, frames, fps=fps)

if __name__ == "__main__":
    # Render one critic-value heatmap per saved checkpoint and stitch the
    # frames into a GIF showing how the value estimate evolves over training.

    # === CONFIGURATION ===
    run_name = "4-22-ppo-0"
    seeker_dir = f"/home/rmarr/Projects/visibility-game-env/results/{run_name}/Seeker"
    out_dir    = os.path.join(seeker_dir, "heatmap_frames")
    gif_path   = os.path.join(seeker_dir, f"{run_name}_seeker_evolution.gif")
    # Background image drawn under every heatmap (relative to the working dir).
    capture_path = 'map.PNG'
    # Inclusive integer bounds / stride of the square evaluation grid.
    grid_min, grid_max, grid_step = -4, 25, 1

    os.makedirs(out_dir, exist_ok=True)

    # === RECREATE YOUR SPECS & POLICY ===
    seed = 5404
    observation_specs = [ObservationSpec(
        name="position_observation",
        shape=(12,),
        dimension_property=(DimensionProperty.NONE,),
        observation_type=ObservationType.DEFAULT
    )]
    action_spec = ActionSpec(continuous_size=0, discrete_branches=(5,))
    behavior_spec = BehaviorSpec(
        observation_specs=observation_specs,
        action_spec=action_spec
    )
    trainer_settings = TrainerSettings(dual_critic=True)
    network_settings = NetworkSettings(
        deterministic=False,
        memory=None,
        hidden_units=128,
        num_layers=2,
    )

    policy = build_policy(
        seed=seed,
        behavior_spec=behavior_spec,
        trainer_settings=trainer_settings,
        network_settings=network_settings,
        tanh_squash=False,
        separate_critic=False,
        condition_sigma_on_obs=False,
        load_critic_only="position_only"
    )
    modules = policy.get_modules()

    # === PRECOMPUTE GRID POSITIONS ===
    # Here we build the same 30×30 grid you used interactively
    positions = make_grid_positions(grid_min, grid_max, grid_min, grid_max, step=grid_step)  # gives 30×30 = 900 points
    # Derive the side length from the bounds so the reshape below can never
    # drift out of sync with the grid definition.
    grid_side = len(range(grid_min, grid_max + 1, grid_step))

    # === FIND & SORT .pt FILES ===
    pt_files = glob.glob(os.path.join(seeker_dir, "*.pt"))
    pt_files.sort(key=natural_sort_key)

    print(f"Found {len(pt_files)} checkpoint files.")

    # If you want to normalize the color scale across all frames,
    # you could precompute all values once and track global min/max.
    # For simplicity, we do per-frame color scaling here.

    frame_paths = []
    for idx, ckpt in enumerate(pt_files):
        print(f"[{idx+1}/{len(pt_files)}] Loading {os.path.basename(ckpt)} …")
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        sd = torch.load(ckpt, map_location="cpu")
        # Strict load of policy weights
        modules['Policy'].load_state_dict(sd['Policy'], strict=True)

        # Forward pass through critic (inference only, so skip autograd)
        with torch.no_grad():
            vals = policy.actor.critic_pass(positions)[0]['extrinsic']
        # Positions are generated y-major, so rows index y and columns index x.
        grid = vals.detach().cpu().numpy().reshape((grid_side, grid_side))

        # Save a PNG frame. Each checkpoint must be written to its own file:
        # the GIF step below reads exactly the paths collected here.
        frame_path = os.path.join(out_dir, f"frame_{idx:03d}.png")
        plot_field(grid,
                   out_path=frame_path,
                   capture_path=capture_path,
                   cmap='hot',
                   alpha=0.5)
        frame_paths.append(frame_path)

    # === MAKE THE GIF ===
    if frame_paths:
        print(f"Writing GIF to {gif_path}")
        make_gif_from_images(frame_paths, gif_path, fps=2)
    else:
        # Nothing was rendered; writing an animation with zero frames would fail.
        print(f"No checkpoints found in {seeker_dir}; skipping GIF.")
    print("Done.")

separate_critic: False
Found 6 checkpoint files.
[1/6] Loading Seeker-499496.pt …
[2/6] Loading Seeker-999446.pt …
