In [1]:
import jax

print("jax devices", jax.devices())
import jax.numpy as jnp
import time
from train import make_train
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
import wandb
from config import config


def generate_combinations(
    lr,
    ent_coef,
    transition_model_lr,
    max_grad_norm,
    schedule_accelerator,
    num_envs,
    num_minibatches,
    clip_eps,
):
    """
    from chatGPT
    """
    # Generate all possible combinations of lr, ent_coef, and vf_coef
    combinations = np.array(
        list(
            product(
                lr,
                ent_coef,
                transition_model_lr,
                max_grad_norm,
                schedule_accelerator,
                num_envs,
                num_minibatches,
                clip_eps,
            )
        )
    )

    # Split the combinations into separate arrays
    lr_combinations = combinations[:, 0]
    ent_coef_combinations = combinations[:, 1]
    transition_model_lr_combinations = combinations[:, 2]
    max_grad_norm_combinations = combinations[:, 3]
    schedule_accelerator_combinations = combinations[:, 4]
    num_envs_combinations = combinations[:, 5]
    num_minibatches_combinations = combinations[:, 6]
    clips_eps_combinations = combinations[:, 7]

    # Return a tuple of the resulting arrays
    return (
        lr_combinations,
        ent_coef_combinations,
        transition_model_lr_combinations,
        max_grad_norm_combinations,
        schedule_accelerator_combinations,
        num_envs_combinations,
        num_minibatches_combinations,
        clips_eps_combinations,
    )


# group = wandb.util.generate_id()

config.update({"ANNEAL_LR": True})
config.update({"NUM_ENVS": 2})
config.update({"NUM_MINIBATCHES": 1})
ent_coef_search = [0.0, 0.001, 0.01]
lr_search = [0.00025, 0.0025, 0.000025]
transition_model_lr_search = [0.0001]
max_grad_norm_search = [0.05, 0.5, 5]
schedule_accelerator_search = [1.0]
num_envs = [config["NUM_ENVS"]]
num_minibatches = [config["NUM_MINIBATCHES"]]
clip_eps = [0.02, 0.2, 2]
(
    lr_combinations,
    ent_coef_combinations,
    transition_model_lr_combinations,
    max_grad_norm_combinations,
    schedule_accelerator_combinations,
    num_envs_combinations,
    num_minibatches_combinations,
    clips_eps_combinations,
) = generate_combinations(
    lr_search,
    ent_coef_search,
    transition_model_lr_search,
    max_grad_norm_search,
    schedule_accelerator_search,
    num_envs,
    num_minibatches,
    clip_eps,
)
lr_combinations = jnp.array(lr_combinations)
ent_coef_combinations = jnp.array(ent_coef_combinations)
transition_model_lr_combinations = jnp.array(transition_model_lr_combinations)
max_grad_norm_combinations = jnp.array(max_grad_norm_combinations)
schedule_accelerator_combinations = jnp.array(schedule_accelerator_combinations)
num_envs_combinations = jnp.array(num_envs_combinations)
num_minibatches_combinations = jnp.array(num_minibatches_combinations)
clips_eps_combinations = jnp.array(clips_eps_combinations)
combinations = [
    lr_combinations,
    ent_coef_combinations,
    transition_model_lr_combinations,
    max_grad_norm_combinations,
    schedule_accelerator_combinations,
    num_envs_combinations.astype(int),
    num_minibatches_combinations.astype(int),
    clips_eps_combinations,
]

NUMBER_OF_SEEDS = 5
# num_minibatches_combinations = jnp.ones([81,], dtype=jnp.int32) * 2

rng = jax.random.PRNGKey(NUMBER_OF_SEEDS * len(combinations))
rngs = jax.random.split(rng, NUMBER_OF_SEEDS)

train_vvjit = jax.jit(
    jax.vmap(jax.vmap(make_train(config), in_axes=(None, 0)), in_axes=(0, None))
)
t0 = time.time()
outs = jax.block_until_ready(train_vvjit(combinations, rngs))
print(f"time: {time.time() - t0:.2f} s")

dict_outs = {}
combinations = jnp.stack(combinations, axis=1)
for i in range(len(combinations)):
    (
        lr,
        ent_coef,
        transition_lr,
        max_grad_norm,
        schedule_accelerator,
        num_envs,
        num_minibatches,
        clip_eps,
    ) = combinations[i]
    to_plot_ent_coef = str(round(ent_coef, 4))
    to_plot_lr = str(round(lr, 4))
    to_plot_transition_model_lr = str(round(transition_lr, 4))
    to_plot_max_grad_norm = str(round(max_grad_norm, 4))
    to_plot_schedule_accelerator = str(round(schedule_accelerator, 4))
    to_plot_num_envs = str(round(num_envs, 4))
    to_plot_num_minibatches = str(round(num_minibatches, 4))
    to_plot_clip_eps = str(round(clip_eps, 4))
    new_config = {
        "LR": to_plot_lr,
        "TRANSITION_MODEL_LR": to_plot_transition_model_lr,
        "ENT_COEF": to_plot_ent_coef,
        "MAX_GRAD_NORM": to_plot_max_grad_norm,
        "SCHEDULE_ACCELERATOR": to_plot_schedule_accelerator,
        "NUM_ENVS": to_plot_num_envs,
        "NUM_MINIBATCHES": to_plot_num_minibatches,
        "CLIP_EPS": to_plot_clip_eps,
    }
    config.update(new_config)

    # wandb.init(project="purejaxrl", entity="self-supervisor", config=config, group=group)
    # list_to_log = [j.item() for j in outs["metrics"]["returned_episode_returns"][i].mean(0).mean(-1).reshape(-1)]
    # for a_val_to_log in list_to_log:
    #     wandb.log({"episode_returns": a_val_to_log})
    # wandb.finish()

    # plt.plot(
    #     outs["metrics"]["returned_episode_returns"][i].mean(0).mean(-1).reshape(-1),
    # )
    dict_outs[
        f"ent_coef={to_plot_ent_coef}, lr={to_plot_lr}, transition_lr={to_plot_transition_model_lr}, max_grad_norm={to_plot_max_grad_norm}, schedule_accelerator={to_plot_schedule_accelerator}, num_envs={to_plot_num_envs}, num_minibatches={to_plot_num_minibatches}, clip_eps={to_plot_clip_eps}"
    ] = round(
        outs["metrics"]["returned_episode_returns"][i]
        .mean(0)
        .mean(-1)
        .reshape(-1)[:-1000]
        .mean()
        .item(),
        1,
    )
# plt.savefig("hyperparam_search.png")
# plt.close()
dict_outs = {
    k: v
    for k, v in sorted(
        dict_outs.items(), key=lambda item: np.mean(item[1]), reverse=True
    )
}
headers = [
    "return",
    "ent_coef",
    "lr",
    "transition_lr",
    "max_grad_norm",
    "schedule_accelerator",
    "num_envs",
    "num_minibatches",
    "clip_eps",
]
print("|".join(headers))
print("-" * (len(headers) * 12))
for key, value in dict_outs.items():
    (
        ent_coef,
        lr,
        transition_lr,
        max_grad_norm,
        schedule_accelerator,
        num_envs,
        num_minibatches,
        clip_eps,
    ) = key.split(", ")
    row_values = [
        "{:.2f}".format(value),
        ent_coef.split("=")[1],
        lr.split("=")[1],
        transition_lr.split("=")[1],
        max_grad_norm.split("=")[1],
        schedule_accelerator.split("=")[1],
        num_envs.split("=")[1],
        clip_eps.split("=")[1],
    ]
    print("|".join(row_values))

jax devices [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
jax devices [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]


  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


time: 28.16 s
return|ent_coef|lr|transition_lr|max_grad_norm|schedule_accelerator|num_envs|num_minibatches|clip_eps
------------------------------------------------------------------------------------------------------------
259.90|0.0009999999|0.0025|1e-04|0.5|1.0|2.0|0.02
259.90|0.0009999999|0.0025|1e-04|0.5|1.0|2.0|0.19999999
259.90|0.0009999999|0.0025|1e-04|0.5|1.0|2.0|2.0
246.20|0.0|0.0025|1e-04|0.049999997|1.0|2.0|0.02
246.20|0.0|0.0025|1e-04|0.049999997|1.0|2.0|0.19999999
246.20|0.0|0.0025|1e-04|0.049999997|1.0|2.0|2.0
244.40|0.0009999999|0.0002|1e-04|5.0|1.0|2.0|0.02
244.40|0.0009999999|0.0002|1e-04|5.0|1.0|2.0|0.19999999
244.40|0.0009999999|0.0002|1e-04|5.0|1.0|2.0|2.0
237.40|0.01|0.0025|1e-04|0.5|1.0|2.0|0.02
237.40|0.01|0.0025|1e-04|0.5|1.0|2.0|0.19999999
237.40|0.01|0.0025|1e-04|0.5|1.0|2.0|2.0
236.80|0.01|0.0002|1e-04|5.0|1.0|2.0|0.02
236.80|0.01|0.0002|1e-04|5.0|1.0|2.0|0.19999999
236.80|0.01|0.0002|1e-04|5.0|1.0|2.0|2.0
224.40|0.0009999999|0.0025|1e-04|0.049999997|1.0|2.