In [1]:
import os
os.environ["JAX_PLATFORMS"] = "cpu"  # Forces JAX to run on CPU

import jax #pip install jax
print(f"JAX is using: {jax.devices()[0]}")
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax #pip install distrax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper
import jaxmarl
from jaxmarl.wrappers.baselines import LogWrapper
from jaxmarl.environments.overcooked import Overcooked, overcooked_layouts, layout_grid_to_dict
from jaxmarl.viz.overcooked_visualizer import OvercookedVisualizer
from jaxmarl import make
import hydra #pip install hydra-core
from omegaconf import OmegaConf

import matplotlib.pyplot as plt

### Added to modify the structure of the definitions
from functools import partial

JAX is using: TFRT_CPU_0


In [2]:
import jax.lax as lax
import pickle
from datetime import datetime
from PIL import Image
import time
import re
from textwrap import dedent

# Save the model parameters only
def save_model(runner_state, save_dir, file):
    """Pickle the network parameters found in ``runner_state`` to ``save_dir/file``.

    Args:
        runner_state: Indexable whose first element exposes a ``params`` attribute.
        save_dir: Directory that receives the checkpoint; created when missing.
        file: File name of the checkpoint inside ``save_dir``.
    """
    target = os.path.join(save_dir, file)

    # The directory may not exist yet on the first save of a run.
    os.makedirs(save_dir, exist_ok=True)

    # Only the parameters are stored, wrapped in a dict under the 'params' key.
    state = runner_state[0]
    with open(target, 'wb') as sink:
        sink.write(pickle.dumps({'params': state.params}))

    print(f"Model parameters saved to {target}")

# Load the model parameters
def load_model(load_dir, file, network=None):
    """Load the parameters written by ``save_model`` from ``load_dir/file``.

    Args:
        load_dir: Directory containing the checkpoint.
        file: Checkpoint file name inside ``load_dir``.
        network: Unused; kept (now optional) so existing three-argument calls
            keep working.

    Returns:
        The object stored under the ``"params"`` key of the pickled payload.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If the payload has no ``"params"`` entry.
    """
    filename = os.path.join(load_dir, file)

    if not os.path.exists(filename):
        raise FileNotFoundError(f"Model file {filename} not found.")

    # NOTE(security): pickle.load can execute arbitrary code; only load
    # checkpoints that were produced by a trusted run.
    with open(filename, "rb") as f:
        saved_data = pickle.load(f)

    # Indexing inside try/except also covers payloads that are not mappings,
    # which the plain `"params" not in saved_data` test let through as TypeError.
    try:
        model_params = saved_data["params"]
    except (KeyError, TypeError, IndexError) as exc:
        raise ValueError("Invalid saved model file: missing 'params'.") from exc

    print(f"Model parameters loaded from {filename}")

    return model_params

In [3]:
def choose_tx(config):
    """Build the optax optimizer: global-norm gradient clipping followed by Adam.

    When ``config["ANNEAL_LR"]`` is truthy the Adam learning rate follows a
    linear schedule derived from ``LR``, ``NUM_MINIBATCHES``, ``UPDATE_EPOCHS``
    and ``NUM_UPDATES``; otherwise the constant ``config["LR"]`` is used.
    """
    def linear_schedule(count):
        # `count` is divided by the number of optimizer steps per update to
        # recover the update index, which is then turned into a decaying fraction.
        updates_done = count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])
        remaining = 1.0 - updates_done / config["NUM_UPDATES"]
        return config["LR"] * remaining

    learning_rate = linear_schedule if config["ANNEAL_LR"] else config["LR"]

    return optax.chain(
        optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
        optax.adam(learning_rate, eps=1e-5),
    )

def find_closest_checkpoint(fixed_step, load_dir, network=None):
    """Load the ``trained_model_{step}.pkl`` checkpoint whose step is nearest to ``fixed_step``.

    Args:
        fixed_step: Target step number.
        load_dir: Directory that is scanned for checkpoint files.
        network: Forwarded to ``load_model``. Optional; the previous version read
            an undefined global ``network`` here instead.

    Returns:
        ``(loaded_params, closest_step)``.

    Raises:
        FileNotFoundError: If ``load_dir`` holds no file named exactly
            ``trained_model_<digits>.pkl``.
    """
    # Full-name match so that e.g. "trained_model_5.backup.pkl" is not taken for
    # step 5. The real file name is remembered so it never has to be rebuilt
    # from the parsed integer (which would lose zero padding).
    pattern = re.compile(r"trained_model_(\d+)\.pkl")
    checkpoints = {}
    for name in os.listdir(load_dir):
        match = pattern.fullmatch(name)
        if match:
            checkpoints[int(match.group(1))] = name

    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint files found in {load_dir}")

    # Sorting makes ties deterministic (the lower step wins) regardless of the
    # order in which the directory listing is returned.
    closest_step = min(sorted(checkpoints), key=lambda s: abs(s - fixed_step))
    closest_filename = checkpoints[closest_step]

    print(f"Found checkpoint: {closest_filename} for step {closest_step}")

    loaded_params = load_model(load_dir, closest_filename, network)

    return loaded_params, closest_step

import imageio.v3 as iio

def custom_animate(state_seq, agent_view_size, filename="animation.gif"):
    """Render a sequence of states to an animated GIF and save it to ``filename``.

    Args:
        state_seq: Iterable of states exposing ``maze_map``, ``agent_dir_idx``
            and ``agent_inv``.
        agent_view_size: Used to derive the border width (``agent_view_size - 2``)
            that is cropped from every side of ``maze_map`` before rendering.
        filename: Output path of the GIF.
    """
    import imageio
    TILE_PIXELS = 32

    padding = agent_view_size - 2  # border width removed on each side

    def get_frame(state):
        maze_map = state.maze_map
        n_rows, n_cols = maze_map.shape[0], maze_map.shape[1]
        # Explicit end indices: `x[padding:-padding]` would select nothing when
        # padding is 0, whereas this keeps the full map in that case.
        grid = np.asarray(
            maze_map[padding:n_rows - padding, padding:n_cols - padding, :]
        )
        # Render the state
        frame = OvercookedVisualizer._render_grid(
            grid,
            tile_size=TILE_PIXELS,
            highlight_mask=None,
            agent_dir_idx=state.agent_dir_idx,
            agent_inv=state.agent_inv
        )
        return frame

    frame_seq = [get_frame(state) for state in state_seq]

    imageio.mimsave(filename, frame_seq, 'GIF', duration=2)
    
def gif_to_mp4(gif_filename):
    """Convert a GIF file to an MP4 written next to it.

    Args:
        gif_filename: Path of the GIF to read.

    Returns:
        The path of the MP4 file that was written.
    """
    # Only the extension is swapped. The previous `str.replace` rewrote every
    # ".gif" occurrence in the path and, when there was no lowercase ".gif" at
    # all, produced an output path identical to the input.
    root, ext = os.path.splitext(gif_filename)
    if ext.lower() == ".gif":
        mp4_filename = root + ".mp4"
    else:
        mp4_filename = gif_filename + ".mp4"

    gif = iio.imread(gif_filename)
    fps = 2  # Set desired FPS

    iio.imwrite(mp4_filename, gif, extension=".mp4", fps=fps)
    print(f"MP4 saved to {mp4_filename}")
    return mp4_filename

In [4]:
class ActorCritic(nn.Module):
    """Feed-forward actor-critic with separate policy and value branches.

    Each branch is two 64-unit Dense layers followed by an output layer. The
    policy branch ends in ``action_dim`` logits wrapped in a categorical
    distribution; the value branch ends in a single unit.

    Attributes:
        action_dim: Output width of the policy head.
        activation: ``"relu"`` selects ReLU; any other value selects tanh.
    """
    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        """Return ``(pi, value)`` for the observation batch ``x``."""
        act_fn = nn.relu if self.activation == "relu" else nn.tanh

        def hidden_layer():
            # Shared recipe for the 64-unit hidden layers of both branches.
            return nn.Dense(
                64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
            )

        # Policy branch. Layers are created in the same order as before so the
        # automatically assigned sub-module names do not change.
        policy_hidden = act_fn(hidden_layer()(x))
        policy_hidden = act_fn(hidden_layer()(policy_hidden))
        logits = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(policy_hidden)
        pi = distrax.Categorical(logits=logits)

        # Value branch.
        value_hidden = act_fn(hidden_layer()(x))
        value_hidden = act_fn(hidden_layer()(value_hidden))
        value = nn.Dense(
            1, kernel_init=orthogonal(1.0), bias_init=constant(0.0)
        )(value_hidden)

        # Drop the trailing singleton axis of the value output.
        return pi, jnp.squeeze(value, axis=-1)

class Transition(NamedTuple):
    """One rollout record; the training loop builds one per environment step."""
    # Episode-termination flags.
    done: jnp.ndarray
    # Actions sampled from the policy.
    action: jnp.ndarray
    # Critic value estimates for the stored observations.
    value: jnp.ndarray
    # Rewards returned by the environment step.
    reward: jnp.ndarray
    # Log-probabilities of ``action`` under the policy that sampled it.
    log_prob: jnp.ndarray
    # Observations the policy acted on.
    obs: jnp.ndarray
    # Environment info after ``reshape_info``.
    # NOTE(review): this looks like a pytree (dict) rather than a single array — confirm.
    info: jnp.ndarray

In [5]:
def reshape_info(info, num_agents=2):
    """Flatten every array leaf of ``info`` to a 1-D array.

    Args:
        info: Pytree of arrays (for example the ``info`` dict returned by a
            vectorised environment step).
        num_agents: Number of agents. Defaults to 2, the value that was
            previously hard-coded.

    Returns:
        A pytree with the same structure. Leaves of shape ``(n, num_agents)``
        are flattened row by row; every other leaf has each element repeated
        ``num_agents`` times.
    """
    def reshape_fn(x):
        # `jax.tree.map` only ever passes leaves here (dicts are traversed, not
        # handed over), so no dict handling is needed.
        if len(x.shape) == 2 and x.shape[1] == num_agents:
            return x.reshape(-1)
        # Leaf without a per-agent axis: repeat each element once per agent.
        return jnp.repeat(x, num_agents)

    return jax.tree.map(reshape_fn, info)

In [6]:
def batchify(x: dict, agent_list, num_actors):
    """Stack the per-agent entries of ``x`` (in ``agent_list`` order) into a ``(num_actors, -1)`` array."""
    per_agent = tuple(x[agent] for agent in agent_list)
    stacked = jnp.stack(per_agent)
    return jnp.reshape(stacked, (num_actors, -1))

def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Reshape ``x`` to ``(num_actors, num_envs, -1)`` and map each name in ``agent_list`` to its slice along axis 0."""
    grouped = jnp.reshape(x, (num_actors, num_envs, -1))
    by_agent = {}
    for position, agent in enumerate(agent_list):
        by_agent[agent] = grouped[position]
    return by_agent

In [None]:
def make_config(config):
    """Add the derived training sizes to ``config`` and print the main settings.

    ``config`` is updated in place and also returned. The derived keys are
    ``NUM_ACTORS``, ``NUM_UPDATES``, ``NUM_SAVES`` and ``MINIBATCH_SIZE``.

    Args:
        config: Hyperparameter dict with at least ENV_NAME, ENV_KWARGS, LR,
            NUM_ENVS, NUM_STEPS, TOTAL_TIMESTEPS, UPDATE_EPOCHS, NUM_MINIBATCHES
            and either SAVE_EVERY_N_EPOCHS or a preset NUM_SAVES.

    Returns:
        The same ``config`` object.

    Raises:
        KeyError: If a required key is missing.
    """
    # The environment is only instantiated to read its number of agents.
    env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])
    for key in ("LR", "NUM_ENVS", "NUM_STEPS", "TOTAL_TIMESTEPS",
                "UPDATE_EPOCHS", "NUM_MINIBATCHES"):
        print(f'{key}: {config[key]}')

    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    print(f'NUM_ACTORS: {config["NUM_ACTORS"]}')

    config["NUM_UPDATES"] = int(
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    print(f'NUM_UPDATES: {config["NUM_UPDATES"]}')

    if "SAVE_EVERY_N_EPOCHS" in config:
        # Clamp to at least one save: with fewer updates than the save period the
        # floor division gives 0, and the training loop divides by NUM_SAVES.
        config["NUM_SAVES"] = max(
            1, int(config["NUM_UPDATES"] // config["SAVE_EVERY_N_EPOCHS"])
        )
    elif "NUM_SAVES" not in config:
        raise KeyError("SAVE_EVERY_N_EPOCHS")
    # Otherwise a NUM_SAVES supplied directly by the caller is kept as-is.
    print(f'NUM_SAVES: {config["NUM_SAVES"]}')

    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    print(f'MINIBATCH_SIZE: {config["MINIBATCH_SIZE"]}')
    return config


In [8]:
def make_train(config):
    """Build a PPO training function for the JaxMARL environment named in ``config``.

    Args:
        config: Hyperparameter dict. Keys read here and in the returned function:
            ENV_NAME, ENV_KWARGS, ACTIVATION, NUM_ENVS, NUM_ACTORS, NUM_STEPS,
            NUM_UPDATES, NUM_SAVES, NUM_MINIBATCHES, MINIBATCH_SIZE,
            UPDATE_EPOCHS, GAMMA, GAE_LAMBDA, CLIP_EPS, VF_COEF, ENT_COEF, plus
            the optimizer keys consumed by ``choose_tx``.

    Returns:
        ``train(seed, rng, save_dir)``.
    """
    env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])

    env = LogWrapper(env)

    def train(seed, rng, save_dir):
        """Run ``NUM_UPDATES`` PPO updates and checkpoint the parameters.

        Args:
            seed: Label used only in progress messages.
            rng: JAX PRNG key driving initialisation, rollouts and shuffling.
            save_dir: Directory that receives ``trained_model_{step}.pkl`` files.

        Returns:
            Dict with the final ``runner_state`` and the collected ``metrics``.
        """
        # INIT NETWORK
        network = ActorCritic(env.action_space().n, activation=config["ACTIVATION"])

        rng, _rng = jax.random.split(rng)

        # A flattened all-zero observation is used as the sample input for init.
        init_x = jnp.zeros(env.observation_space().shape)
        init_x = init_x.flatten()

        network_params = network.init(_rng, init_x)

        # Same optimizer as `choose_tx` builds (gradient clipping + Adam, with an
        # optional linear learning-rate schedule); reuse it instead of a copy.
        tx = choose_tx(config)

        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        # One reset key per parallel environment.
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)

        # TRAIN LOOP
        @jax.jit
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)

                obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])

                pi, value = network.apply(train_state.params, obs_batch)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)

                env_act = {k: v.flatten() for k, v in env_act.items()}

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])

                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0))(
                    rng_step, env_state, env_act
                )

                info = reshape_info(info)
                transition = Transition(
                    batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(),
                    action,
                    value,
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
                    log_prob,
                    obs_batch,
                    info

                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            # Roll the environments forward for NUM_STEPS steps.
            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            # The value of the observation after the last step bootstraps GAE.
            train_state, env_state, last_obs, rng = runner_state
            last_obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
            _, last_val = network.apply(train_state.params, last_obs_batch)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                # Scanned in reverse so each step sees the advantage of its successor.
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):

                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS (clipped around the rollout-time value)
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS (PPO clipped surrogate)
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                # Flatten (steps, actors), shuffle, and split into minibatches.
                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ACTORS"]
                ), "batch size must be equal to number of steps * number of actors"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )

                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )

            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)

        # (train state, environment state, current observations, PRNG key)
        runner_state = (train_state, env_state, obsv, _rng)

        # The updates run in a Python loop rather than one `jax.lax.scan` so that
        # checkpoints can be written between updates.
        metrics = {}
        # Guard against NUM_SAVES == 0, which used to raise ZeroDivisionError here.
        num_saves = max(1, int(config["NUM_SAVES"]))
        save_intervals = max(1, config["NUM_UPDATES"] // num_saves)

        for step in range(config["NUM_UPDATES"]):
            runner_state, metric = _update_step(runner_state, step)

            # Always checkpoint the final update in addition to the periodic saves.
            if step % save_intervals == 0 or step == config["NUM_UPDATES"] - 1:
                save_model(runner_state, save_dir, f"trained_model_{step}.pkl")

            # Store each metric separately in the dictionary
            if not metrics:  # Initialize keys on the first iteration
                metrics = {key: [] for key in metric}

            for key in metric:
                metrics[key].append(metric[key])  # Append new values for each key

            if step % 100 == 0:
                print(f"Seed {seed}, step {step} completed")

        # NOTE(review): lists are pytree containers, so this tree map visits each
        # per-update array rather than each list; the lists appear to come back
        # unstacked (downstream cells stack them again). Confirm before changing.
        return {"runner_state": runner_state, "metrics": jax.tree.map(lambda x: jnp.stack(x), metrics)}

    return train

In [None]:
# Get one of the classic layouts (cramped_room, asymm_advantages, coord_ring, forced_coord, counter_circuit)
# Or make your own! (P=pot, O=onion, A=agent, B=plates, X=deliver)

# Grid text for every custom layout, one string per row. Tile legend follows the
# cell header comment (P=pot, O=onion, A=agent, B=plates, X=deliver).
CUSTOM_LAYOUT_GRIDS = {
    "custom_1": "\n".join([
        "WWBWW",
        "WA AW",
        "P   P",
        "W   W",
        "WOXOW",
    ]),
    "custom_2": "\n".join([
        "WWBWW",
        "WA AW",
        "P   W",
        "P   W",
        "WOXOW",
    ]),
    "custom_3": "\n".join([
        "WBWWBW",
        "P    P",
        "W AA W",
        "O    O",
        "WXWWXW",
    ]),
    "custom_4": "\n".join([
        "WBWXW",
        "P W O",
        "WAWAW",
        "P W O",
        "WBWXW",
    ]),
    "custom_5": "\n".join([
        "WPWPW",
        "W B W",
        "WAWAW",
        "W O W",
        "WXWXW",
    ]),
}


def load_custom_layout(layout_name):
    """Convert one of the named custom grids into a layout dict.

    Args:
        layout_name: One of the keys of ``CUSTOM_LAYOUT_GRIDS``
            (``"custom_1"`` ... ``"custom_5"``).

    Returns:
        The result of ``layout_grid_to_dict`` for the selected grid (also printed).

    Raises:
        ValueError: If ``layout_name`` is not a known layout. The previous
            if/elif chain failed with ``UnboundLocalError`` in that case.
    """
    try:
        custom_layout_grid = CUSTOM_LAYOUT_GRIDS[layout_name]
    except KeyError:
        known = ", ".join(sorted(CUSTOM_LAYOUT_GRIDS))
        raise ValueError(
            f"Unknown layout {layout_name!r}; expected one of: {known}"
        ) from None

    custom_layout = layout_grid_to_dict(custom_layout_grid)
    print(custom_layout)
    return custom_layout

In [None]:
layout_name = "custom_1"

# Every run writes into its own timestamped checkpoint directory.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
initial_dir = f'/data/samuel_lozano/hfsp_collective_learning/data_JaxMARL/Symmetric_Agents/{layout_name}/Checkpoints_{current_datetime}/'

print(layout_name)
print(current_datetime)
custom_layout = load_custom_layout(layout_name)

# set hyperparameters:
config = {
    # Number of possible actions in the environment.
    "NUM_ACTIONS": 6,
    # Controls how much the model updates its weights during optimization.
    "LR": 1e-4,
    # Number of parallel environments running simultaneously.
    "NUM_ENVS": 8,
    # Number of steps collected before running an update.
    "NUM_STEPS": 1500,
    # Total number of timesteps for training.
    "TOTAL_TIMESTEPS": 5e8,
    # Number of times each collected batch of experiences is used for gradient updates. More epochs allow for better learning from collected data but can lead to overfitting.
    "UPDATE_EPOCHS": 4,
    # Number of minibatches used during training updates. More minibatches reduce variance but increase computational cost.
    "NUM_MINIBATCHES": 4,
    # Discount factor for future rewards. A value close to 1 favors long-term rewards, while lower values prioritize immediate rewards.
    "GAMMA": 0.99,
    # Controls Generalized Advantage Estimation (GAE). Higher values lead to smoother, less biased advantage estimates.
    "GAE_LAMBDA": 0.99,
    # PPO-specific parameter that limits policy updates to prevent excessive changes. Lower values ensure stability but may slow training.
    "CLIP_EPS": 0.2,
    # Coefficient for entropy regularization, encouraging exploration. Higher values lead to more randomness in actions.
    "ENT_COEF": 0.1,
    # Coefficient for the value function loss. Higher values prioritize accurate value function learning.
    "VF_COEF": 0.5,
    # Limits the magnitude of gradients, preventing instability.
    "MAX_GRAD_NORM": 0.5,
    # Specifies the activation function for the network. Tanh helps with stable gradient flow but might limit expressiveness compared to ReLU.
    "ACTIVATION": "tanh",
    # The RL environment being used.
    "ENV_NAME": "overcooked",
    # Allows customization of environment settings, such as custom layouts.
    "ENV_KWARGS": {
      "layout" : custom_layout #write custom_layout without quotation marks if you want to use the custom layout
    },
    # If enabled, the learning rate decreases over time, improving stability in later training stages.
    "ANNEAL_LR": True,
    # Ensures reproducibility by fixing the random seed.
    "SEED": 0,
    # Runs multiple training instances with different seeds for robustness.
    "NUM_SEEDS": 3,
    # Saves model checkpoints every N epochs.
    "SAVE_EVERY_N_EPOCHS": 50,
}

#Comment the following line if you want to use the custom_layout
#config["ENV_KWARGS"]["layout"] = overcooked_layouts[config["ENV_KWARGS"]["layout"]]

config = make_config(config)

# One PRNG key per seed, all derived from the configured base seed.
rng = jax.random.PRNGKey(config["SEED"])
rngs = jax.random.split(rng, config["NUM_SEEDS"])

# The training function depends only on `config`, so it is built once instead
# of once per seed.
train_jit = make_train(config)

out=[]
# Training loop
for seed in range(config["NUM_SEEDS"]):
    save_dir = f'{initial_dir}Seed_{seed}/'
    out.append(train_jit(seed, rngs[seed], save_dir))

# Extract only the necessary data (returns) and save it
results = {"returns": [out[i]["metrics"]["returned_episode_returns"] for i in range(config["NUM_SEEDS"])]}
plot_data_path = os.path.join(initial_dir, "plot_data.npz")
# Make sure the target directory exists before writing the archive.
os.makedirs(initial_dir, exist_ok=True)
np.savez(plot_data_path, **results)
# Report the path that was actually written (the old message had a doubled slash).
print(f"Data saved in {plot_data_path}")

custom_4
2025-03-21_20-18-12
LR: 0.05
NUM_ENVS: 8
NUM_STEPS: 3000
TOTAL_TIMESTEPS: 5000000000.0
UPDATE_EPOCHS: 1000
NUM_MINIBATCHES: 100
NUM_ACTORS: 16
NUM_UPDATES: 208333
MINIBATCH_SIZE: 480


KeyboardInterrupt: 

In [None]:
# Plot the results
for i in range(config["NUM_SEEDS"]):
    returns = jnp.stack(out[i]["metrics"]["returned_episode_returns"])  # Stack list into a single array
    mean_returns = returns.mean(axis=-1).reshape(-1)  # Compute mean across the stacked dimension
    plt.plot(mean_returns, label=f"Seed {i}")
plt.xlabel("Update Step")
plt.ylabel("Return")
plt.legend()

# Guardar la imagen en el directorio especificado
save_path = os.path.join(initial_dir, "plot.png")
plt.savefig(save_path)

# Opcional: mostrar la imagen
plt.show()

In [None]:
# Number of leading entries of each curve to plot.
# NOTE(review): the original note said "convert timesteps to update steps", but the
# value is used directly as a slice bound on mean_returns — confirm the intended unit.
max_steps = int(0.17e6)

plt.figure(figsize=(10, 5))  # Optional: adjust the figure size

for i in range(config["NUM_SEEDS"]):
    returns = jnp.stack(out[i]["metrics"]["returned_episode_returns"])  
    mean_returns = returns.mean(axis=-1).reshape(-1)
    
    # Plot only the first max_steps entries
    plt.plot(mean_returns[:max_steps], label=f"Seed {i}")

plt.xlabel("Update Step")
plt.ylabel("Return")
plt.legend()

# Save the figure
save_path = os.path.join(initial_dir, "plot_zoom.png")
plt.savefig(save_path)
plt.show()

In [None]:
# For every seed, record where the averaged return curve peaks and the peak value.
max_return_per_seed = {}

for i in range(config["NUM_SEEDS"]):
    returns = jnp.stack(out[i]["metrics"]["returned_episode_returns"])
    mean_returns = returns.mean(axis=-1).reshape(-1)
    
    max_index = jnp.argmax(mean_returns)  # Index of the maximum return
    max_step = max_index / config["NUM_STEPS"]  # Convert to update steps
    max_value = mean_returns[max_index]  # Value at the maximum
    
    max_return_per_seed[f"Seed {i}"] = (int(max_step), float(max_value))  # Store in the dictionary

print(max_return_per_seed)

RECREATING RESULTS OF TRAINING

In [10]:
# Hyperparameters used by the result-recreation cells below; this rebinds the
# notebook-level `config`.
# NOTE(review): "layout" is the literal string "custom_layout" here, and
# NUM_SAVES is given directly while SAVE_EVERY_N_EPOCHS is absent — confirm
# both are intended before passing this dict to the training helpers.
config = {
    "NUM_ACTIONS": 6,
    "LR": 1e-2,
    "NUM_ENVS": 8,
    "NUM_STEPS": 1000,
    "TOTAL_TIMESTEPS": 6e7,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 4,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.99,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.01,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "tanh",
    "ENV_NAME": "overcooked",
    "ENV_KWARGS": {
      "layout" : "custom_layout" #write custom_layout without quotation marks if you want to use the custom layout
    },
    "ANNEAL_LR": True,
    "SEED": 0,
    "NUM_SAVES": 1000,
    "NUM_SEEDS": 3
}

In [None]:
# OLD PLOT_DATA

# TODO: review

# Recreate values and figure of evolution
def OLD_recreate_results(data_path, figure = 1, num_steps = None):
    """Rebuild per-seed return curves from one combined ``.npz`` file.

    The file must contain a ``"returns"`` entry whose first axis enumerates the
    seeds. For each seed the returns are averaged over the last axis and
    flattened into a single 1-D curve.

    Args:
        data_path: Path to the ``.npz`` file (e.g. ``plot_data.npz``).
        figure: When equal to 1, plot every curve, save the figure as
            ``plot.png`` next to ``data_path`` and show it.
        num_steps: Divisor used to convert the flat index of the maximum into
            update steps. Defaults to the global ``config["NUM_STEPS"]``.

    Returns:
        ``(mean_returns_per_seed, max_steps_per_seed)`` where the first item is
        a list with one flattened curve per seed and the second maps
        ``"Seed <i>"`` to ``(max_index / num_steps, max_value)``.
    """
    if num_steps is None:
        num_steps = config["NUM_STEPS"]

    # NOTE(review): allow_pickle=True can execute arbitrary code while loading;
    # only open files produced by this project.
    # The context manager closes the underlying file handle, which np.load
    # otherwise keeps open for .npz archives.
    with np.load(data_path, allow_pickle=True) as data:
        returns_list = data["returns"]

    max_steps_per_seed = {}
    mean_returns_per_seed = []

    for i, seed_returns in enumerate(returns_list):
        returns = jnp.stack(seed_returns)
        mean_returns = returns.mean(axis=-1).reshape(-1)
        mean_returns_per_seed.append(mean_returns)

        max_index = int(jnp.argmax(mean_returns))  # Flat index of the max return
        max_step = max_index / num_steps  # Convert to update steps
        max_value = float(mean_returns[max_index])

        max_steps_per_seed[f"Seed {i}"] = (max_step, max_value)

    # Recreate the plot
    if figure == 1:
        plt.figure()
        for i, curve in enumerate(mean_returns_per_seed):
            plt.plot(curve, label=f"Seed {i}")

        plt.xlabel("Update Step")
        plt.ylabel("Return")
        plt.legend()

        save_path = os.path.join(os.path.dirname(data_path), "plot.png")
        plt.savefig(save_path)
        plt.show()

    print("Max steps per seed (reconstructed):")
    print(max_steps_per_seed)
    return mean_returns_per_seed, max_steps_per_seed

# Run to analyse: layout name and timestamp of the checkpoint folder.
load_layout_name = 'custom_2'
load_datetime = '2025-03-17_10-32-29'

original_dir = f'/data/samuel_lozano/hfsp_collective_learning/data_JaxMARL/Symmetric_Agents/{load_layout_name}/Checkpoints_{load_datetime}/'

# 1 plots and saves the curves, any other value only computes them.
print_figure = 0

plot_data_path = os.path.join(original_dir, "plot_data.npz")
mean_returns_per_seed, max_steps_per_seed = OLD_recreate_results(
    plot_data_path,
    figure = print_figure,
)

Max steps per seed (reconstructed):
{'Seed 0': (1804.399, 200.0), 'Seed 1': (1670.799, 220.0), 'Seed 2': (17607.999, 240.0)}


In [178]:
# NEW version (one file per seed) -- TODO: review

def recreate_results(directory, figure=1, num_steps=None):
    """Rebuild per-seed return curves from ``plot_data_seed_<n>.npz`` files.

    Every file in ``directory`` whose whole name matches
    ``plot_data_seed_<n>.npz`` is loaded in increasing order of ``<n>``. Its
    ``"returns"`` entry is averaged over the last axis and flattened into a
    1-D curve.

    Args:
        directory: Folder containing the per-seed ``.npz`` files.
        figure: When equal to 1, plot every curve, save the figure as
            ``plot.png`` inside ``directory`` and show it.
        num_steps: Divisor used to convert the flat index of the maximum into
            update steps. Defaults to the global ``config["NUM_STEPS"]``.

    Returns:
        ``(mean_returns_per_seed, max_steps_per_seed)``. The first item is a
        list of curves; the second maps ``"Seed <i>"`` to
        ``(max_index / num_steps, max_value)``, where ``<i>`` is the position
        in the sorted file list (the same index used for the returned list).
        ``(None, None)`` is returned when no seed file is found, so callers
        that unpack two values keep working.
    """
    seed_file_pattern = re.compile(r"plot_data_seed_(\d+)\.npz")
    # fullmatch rejects look-alikes such as "plot_data_seed_0.npz.bak".
    seed_files = sorted(
        (name for name in os.listdir(directory) if seed_file_pattern.fullmatch(name)),
        key=lambda name: int(seed_file_pattern.fullmatch(name).group(1)),
    )

    if not seed_files:
        print("No seed files found!")
        return None, None

    if num_steps is None:
        num_steps = config["NUM_STEPS"]

    max_steps_per_seed = {}
    mean_returns_per_seed = []

    for i, file in enumerate(seed_files):
        file_path = os.path.join(directory, file)
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # loading; only open files produced by this project.
        # The context manager releases the archive's file handle.
        with np.load(file_path, allow_pickle=True) as data:
            returns = jnp.asarray(data["returns"])

        mean_returns = returns.mean(axis=-1).reshape(-1)
        mean_returns_per_seed.append(mean_returns)

        max_index = int(jnp.argmax(mean_returns))  # Flat index of the max return
        max_step = max_index / num_steps  # Convert to update steps
        max_value = float(mean_returns[max_index])

        max_steps_per_seed[f"Seed {i}"] = (max_step, max_value)

    # Recreate the plot
    if figure == 1:
        plt.figure()
        for i, curve in enumerate(mean_returns_per_seed):
            plt.plot(curve, label=f"Seed {i}")

        plt.xlabel("Update Step")
        plt.ylabel("Return")
        plt.legend()

        save_path = os.path.join(directory, "plot.png")
        plt.savefig(save_path)
        plt.show()

    print("Max steps per seed (reconstructed):")
    print(max_steps_per_seed)
    return mean_returns_per_seed, max_steps_per_seed

In [None]:
# Locate the checkpoint folder that holds the per-seed plot data.
load_layout_name = 'custom_5'
load_datetime = '2025-03-18_15-22-00'
original_dir = f'/data/samuel_lozano/hfsp_collective_learning/data_JaxMARL/Symmetric_Agents/{load_layout_name}/Checkpoints_{load_datetime}/'

# Rebuild the curves; 1 would also plot and save them, any other value skips the plot.
print_figure = 0
mean_returns_per_seed, max_steps_per_seed = recreate_results(
    original_dir,
    figure=print_figure,
)

Max steps per seed (reconstructed):
{'Seed 0': (8874.799, 180.0), 'Seed 1': (12084.799, 180.0), 'Seed 2': (8452.799, 180.0)}


In [None]:
# Number of leading points of each reconstructed curve to display.
max_steps = int(0.08e8)

plt.figure(figsize=(10, 5))  # Optional: adjust the figure size

# Plot every seed's own curve. The previous version plotted the leftover global
# `mean_returns` once per seed, i.e. the same line several times, and took the
# seed count from `config` instead of from the data that was actually loaded.
for i, seed_mean_returns in enumerate(mean_returns_per_seed):
    plt.plot(seed_mean_returns[:max_steps], label=f"Seed {i}")

plt.xlabel("Update Step")
plt.ylabel("Return")
plt.legend()

# Save the figure next to the loaded data, then display it.
save_path = os.path.join(original_dir, "plot_zoom.png")
plt.savefig(save_path)
plt.show()

In [15]:
def find_value_return_interval(seed, mean_returns_per_seed, config, searched_value, min_step, max_step):
    """Find the first position in ``[min_step, max_step]`` whose return equals ``searched_value``.

    Args:
        seed: Index into ``mean_returns_per_seed``.
        mean_returns_per_seed: Sequence holding one 1-D return series per seed.
        config: Mapping providing ``"NUM_STEPS"``, used to convert a position
            in the series into an update step.
        searched_value: Value compared with ``==`` against each entry.
        min_step: First position to inspect (inclusive, truncated to int).
        max_step: Last position to inspect (inclusive, truncated to int).

    Returns:
        The matching position floor-divided by ``config["NUM_STEPS"]``, or
        ``None`` when no entry in the range matches. The range is clipped to
        the bounds of the series, so positions outside it are never inspected.
    """
    series = mean_returns_per_seed[seed]
    start = max(int(min_step), 0)
    stop = min(int(max_step + 1), len(series))

    if start < stop:
        # Compare the whole window at once rather than indexing element by
        # element, which is very slow for device arrays.
        window = np.asarray(series[start:stop])
        hits = np.flatnonzero(window == searched_value)
        if hits.size:
            step = start + int(hits[0])
            scaled_step = step // config["NUM_STEPS"]  # Position -> update step
            print(f"First return value {searched_value} in seed {seed} at timestep: {step}. In update step: {scaled_step}")
            return scaled_step

    print(f"No return value {searched_value} found in seed {seed} within the specified range.")
    return None

def find_different_value_return_interval(seed, mean_returns_per_seed, config, searched_value, min_step, max_step):
    """Find the first position in ``[min_step, max_step]`` whose return differs from ``searched_value``.

    Args:
        seed: Index into ``mean_returns_per_seed``.
        mean_returns_per_seed: Sequence holding one 1-D return series per seed.
        config: Mapping providing ``"NUM_STEPS"``, used to convert a position
            in the series into an update step.
        searched_value: Value compared with ``!=`` against each entry.
        min_step: First position to inspect (inclusive, truncated to int).
        max_step: Last position to inspect (inclusive, truncated to int).

    Returns:
        The first differing position floor-divided by ``config["NUM_STEPS"]``,
        or ``None`` when every inspected entry equals ``searched_value``. The
        range is clipped to the bounds of the series.
    """
    series = mean_returns_per_seed[seed]
    start = max(int(min_step), 0)
    stop = min(int(max_step + 1), len(series))

    if start < stop:
        # Compare the whole window at once rather than indexing element by
        # element, which is very slow for device arrays.
        window = np.asarray(series[start:stop])
        hits = np.flatnonzero(window != searched_value)
        if hits.size:
            step = start + int(hits[0])
            scaled_step = step // config["NUM_STEPS"]  # Position -> update step
            print(f"First return value different from {searched_value} in seed {seed} at timestep: {step}. In update step: {scaled_step}")
            return scaled_step

    print(f"No return value different from {searched_value} found in seed {seed} within the specified range.")
    return None

In [102]:
seed = 2
searched_value = 0
min_step = 4e6
max_step = 4.3e6

found_step = find_value_return_interval(seed, mean_returns_per_seed, config, searched_value, min_step, max_step)

First return value 0 in seed 2 at update_step: 4063599. In update step: 4063


In [None]:
def get_trained_steps(original_dir, config, mean_returns):
    """Look up the mean return at every saved checkpoint of every seed folder.

    Sub-directories of ``original_dir`` named exactly ``Seed_<n>`` are scanned
    for files named exactly ``trained_model_<step>.pkl``. Whole-name matching
    keeps look-alikes such as ``Seed_0_old`` or ``trained_model_3.pkl.bak``
    from being picked up and overwriting a real seed's entry.

    Args:
        original_dir: Directory containing the ``Seed_<n>`` folders.
        config: Mapping providing ``"NUM_STEPS"``; a checkpoint's ``<step>`` is
            multiplied by it to obtain the index into ``mean_returns``.
        mean_returns: 1-D return series indexed as ``mean_returns[step * NUM_STEPS]``.
            NOTE(review): the same series is used for every seed folder, so
            pass the series of the seed whose entry will be read.

    Returns:
        ``{seed: {step: mean_returns[step * NUM_STEPS]}}`` with the steps of
        each seed inserted in increasing order.
    """
    seed_pattern = re.compile(r"Seed_(\d+)")
    step_pattern = re.compile(r"trained_model_(\d+)\.pkl")
    num_steps = config["NUM_STEPS"]

    results = {}  # Final dictionary with the reconstructed values

    for seed_folder in sorted(os.listdir(original_dir)):
        seed_match = seed_pattern.fullmatch(seed_folder)
        seed_path = os.path.join(original_dir, seed_folder)
        if seed_match is None or not os.path.isdir(seed_path):
            continue

        steps = sorted(
            int(step_match.group(1))
            for step_match in map(step_pattern.fullmatch, os.listdir(seed_path))
            if step_match
        )
        results[int(seed_match.group(1))] = {
            step: mean_returns[step * num_steps] for step in steps
        }

    return results

In [None]:
# Seed whose checkpoint returns are inspected.
inspected_seed = 2

# Pass the reconstructed series of that seed. The bare `mean_returns` used
# before is only a leftover loop variable from the training-session cells and
# is undefined (or belongs to another run) when results are recreated from disk.
# Only the entry of `inspected_seed` is meaningful, because the lookup uses one
# series for every seed folder.
steps_per_seed = get_trained_steps(original_dir, config, mean_returns_per_seed[inspected_seed])
print(steps_per_seed[inspected_seed])

LOADING AND EVALUATING

In [None]:
# Restore a trained policy from disk: pick the run and seed, rebuild the
# configuration and network, then load the checkpoint closest to `fixed_step`.
fixed_step = 25000  # Change this to the desired step

seed_idx = 0  # Change if needed
load_datetime = '2025-03-18_15-22-00'
load_layout_name = 'custom_5'
custom_layout = load_custom_layout(load_layout_name)

original_dir = f'/data/samuel_lozano/hfsp_collective_learning/data_JaxMARL/Symmetric_Agents/{load_layout_name}/Checkpoints_{load_datetime}/'
load_dir = f'{original_dir}Seed_{seed_idx}/'
# NOTE(review): load_filename is not referenced again in the visible cells; the
# checkpoint is resolved by find_closest_checkpoint below.
load_filename = f"trained_model_{fixed_step}.pkl"

# Set hyperparameters.
# NOTE(review): this rebinds the global `config` defined in an earlier cell, with
# a different TOTAL_TIMESTEPS / NUM_SAVES and the layout object instead of a
# string -- confirm these values match the run that produced the checkpoint.
config = {
    "NUM_ACTIONS": 6,
    "LR": 1e-2,
    "NUM_ENVS": 8,
    "NUM_STEPS": 1000,
    "TOTAL_TIMESTEPS": 2e8,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 4,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.99,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.01,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "tanh",
    "ENV_NAME": "overcooked",
    "ENV_KWARGS": {
      "layout" : custom_layout # the layout object loaded above (not a layout-name string)
    },
    "ANNEAL_LR": True,
    "SEED": 0,
    "NUM_SAVES": 500,
    "NUM_SEEDS": 3
}
# Uncomment the following line only when "layout" holds the name of a built-in layout
#config["ENV_KWARGS"]["layout"] = overcooked_layouts[config["ENV_KWARGS"]["layout"]]

# make_config is defined elsewhere in the notebook; its result replaces `config`.
config = make_config(config)

# Initialize network
network = ActorCritic(config["NUM_ACTIONS"], activation=config["ACTIVATION"])

# Load the parameters of the checkpoint nearest to `fixed_step` in `load_dir`
loaded_params, closest_step = find_closest_checkpoint(fixed_step, load_dir)
    
# Create the train state with loaded parameters
train_state = TrainState.create(
    apply_fn=network.apply,
    params=loaded_params,  # Restored parameters
    tx=choose_tx(config),
)

print(f"Model restored successfully from step {closest_step}")
# Requires `mean_returns_per_seed` from the results-recreation cell; the curve
# is indexed at closest_step * NUM_STEPS.
print(f"with mean return {float(mean_returns_per_seed[seed_idx][closest_step * config['NUM_STEPS']])}!")

FrozenDict({
    wall_idx: Array([ 4,  5,  6,  7,  8, 13, 15, 17, 22, 24, 26, 31, 33, 35, 40, 41, 42,
           43, 44], dtype=int32),
    agent_idx: Array([23, 25], dtype=int32),
    goal_idx: Array([41, 43], dtype=int32),
    plate_pile_idx: Array([15], dtype=int32),
    onion_pile_idx: Array([33], dtype=int32),
    pot_idx: Array([5, 7], dtype=int32),
    height: 6,
    width: 9,
})
LR: 0.01
NUM_ENVS: 8
NUM_STEPS: 1000
TOTAL_TIMESTEPS: 200000000.0
UPDATE_EPOCHS: 4
NUM_MINIBATCHES: 4
NUM_ACTORS: 16
NUM_UPDATES: 25000
MINIBATCH_SIZE: 4000
Found checkpoint: trained_model_24999.pkl for step 24999
Model parameters loaded from /data/samuel_lozano/hfsp_collective_learning/data_JaxMARL/custom_5/Checkpoints_2025-03-18_15-22-00/Seed_0/trained_model_24999.pkl
Model restored successfully from step 24999
with mean return 180.0!


In [197]:
# Set environment parameters
max_steps = 1000  # passed to the environment and reused as the rollout length
key = jax.random.PRNGKey(0)
# key_r seeds the reset below. NOTE(review): key_a is not used in the visible cells.
key, key_r, key_a = jax.random.split(key, 3)

# Choose layout
#layout = overcooked_layouts["cramped_room"]
layout = custom_layout

# Instantiate environment
env = make('overcooked', layout=layout, max_steps=max_steps)

# Reset environment
obs, state = env.reset(key_r)
print('List of agents in environment:', env.agents)

# Visualization setup
viz = OvercookedVisualizer()
state_seq = []  # environment states collected during the rollout

List of agents in environment: ['agent_0', 'agent_1']


In [None]:
# Roll out the restored policy, recording every visited state.
for _ in range(max_steps):
    state_seq.append(state)

    # Fresh sub-key for the first agent's sample in this step.
    key, key_s = jax.random.split(key, 2)

    actions = {}
    for agent in env.agents:
        # The network returns an object exposing .sample() together with a
        # second output that is not used here.
        pi, _value = network.apply(train_state.params, obs[agent].flatten())
        actions[agent] = pi.sample(seed=key_s)
        # Advance the key so the next consumer gets an unused sub-key.
        key, key_s = jax.random.split(key)

    # The sub-key left over from the last split drives the environment step.
    obs, state, rewards, dones, infos = env.step(key_s, state, actions)

In [199]:
# FIX VISUALIZATION (SAVING IT WORKS)

# Render to screen
#for s in state_seq:
#    viz.render(env.agent_view_size, s, highlight=False)
#    time.sleep(0.1)

# Save animation
agent_view_size = 5

# os.path.join avoids the doubled "/" that the previous f-string produced when
# load_dir already ends with a path separator.
output_filename = os.path.join(
    load_dir,
    f"animation_trained_model_{closest_step}_max_steps_{max_steps}.gif",
)

custom_animate(state_seq, agent_view_size=agent_view_size, filename=output_filename)

print(f"Animation saved to {output_filename} with adjusted speed")

# Also convert the saved GIF (helper defined elsewhere in the notebook).
gif_to_mp4(output_filename)

Animation saved to /data/samuel_lozano/hfsp_collective_learning/data_JaxMARL/custom_5/Checkpoints_2025-03-18_15-22-00/Seed_0//animation_trained_model_24999_max_steps_1000.gif with adjusted speed
MP4 saved to /data/samuel_lozano/hfsp_collective_learning/data_JaxMARL/custom_5/Checkpoints_2025-03-18_15-22-00/Seed_0//animation_trained_model_24999_max_steps_1000.mp4
