In [1]:
import ast
import importlib
import json
import os
import pickle
import sys
from glob import glob
from pathlib import Path

import climlab
import fedrl_climate_envs
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tqdm
import xarray as xr

In [2]:
# Project root; override with the CLIMATE_RL_FEDRL_DIR environment variable to
# run this notebook outside the original cluster workspace.
BASE_DIR = os.environ.get(
    "CLIMATE_RL_FEDRL_DIR", "/gws/nopw/j04/ai4er/users/pn341/climate-rl-fedrl"
)
RECORDS_DIR = f"{BASE_DIR}/records"
DATASETS_DIR = f"{BASE_DIR}/datasets"
IMGS_DIR = f"{BASE_DIR}/results/imgs/"
# Training step of the checkpoint to load (matched against "*-<STEP_COUNT>.pth").
STEP_COUNT = 20000
# NOTE(review): not used in the cells shown here — presumably the evaluation
# rollout length; confirm against later cells.
NUM_STEPS = 200

# Make project packages under BASE_DIR importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)

In [3]:
# Which environment / experiment run to evaluate.
ENV_ID = "EnergyBalanceModel-v2"
EXPERIMENT_ID = "ebm-v2-optim-L-20k-a2-fed05"
# NOTE(review): hyperparameters are read from the v1 tuning group while the
# experiment is v2 — presumably intentional reuse; confirm.
OPTIM_GROUP = "ebm-v1-optim-L-20k"

ALGO, SEED, CLIENT_ID = "ddpg", 1, 0

# Flag inference mode through the process environment; presumably consumed by
# the env / algorithm code loaded later — TODO confirm where "INFERENCE" is read.
os.environ.update({"INFERENCE": "1"})

In [4]:
def get_make_env(algo, extra_globals=None):
    """Load only the ``make_env`` factory from ``rl-algos/<algo>/main.py``.

    The script is parsed with ``ast`` and only its top-level ``make_env``
    definition(s) are compiled and executed, so the rest of ``main.py`` (its
    imports and any other script-level code) does not run.

    Because the rest of the script is skipped, ``make_env`` can only see the
    names injected here: ``gym``, ``np``, ``BASE_DIR``, plus anything given in
    ``extra_globals``. Pass e.g. ``extra_globals={"os": os}`` if the upstream
    ``make_env`` references another module-level name.

    Args:
        algo: algorithm directory name under ``rl-algos`` (e.g. "ddpg").
        extra_globals: optional mapping of additional names to expose to
            ``make_env``; entries override the defaults listed above.

    Returns:
        The ``make_env`` function object. If ``main.py`` defines it more than
        once, the last definition wins, as it would when running the script.

    Raises:
        FileNotFoundError: if ``main.py`` does not exist.
        ValueError: if ``main.py`` has no top-level ``make_env`` definition.
    """
    file_path = Path(f"{BASE_DIR}/rl-algos/{algo}/main.py").resolve()
    source = file_path.read_text(encoding="utf-8")

    parsed = ast.parse(source, filename=str(file_path))
    func_defs = [
        node
        for node in parsed.body
        if isinstance(node, ast.FunctionDef) and node.name == "make_env"
    ]
    if not func_defs:
        raise ValueError(f"'make_env' not found in {file_path}")

    make_env_code = ast.Module(body=func_defs, type_ignores=[])
    compiled = compile(make_env_code, filename=str(file_path), mode="exec")

    # This dict becomes make_env's module globals, i.e. its entire name scope.
    namespace = {"gym": gym, "np": np, "BASE_DIR": BASE_DIR}
    if extra_globals:
        namespace.update(extra_globals)
    exec(compiled, namespace)
    return namespace["make_env"]


def get_actor(algo):
    """Return the ``Actor`` class defined in ``rl-algos/<algo>/<algo>_actor.py``.

    The module is imported by dotted name ``rl-algos.<algo>.<algo>_actor``, so
    the directory containing ``rl-algos`` must be on ``sys.path``.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module defines no ``Actor``.
    """
    return importlib.import_module(f"rl-algos.{algo}.{algo}_actor").Actor


def get_agent(algo):
    """Return the ``Agent`` class defined in ``rl-algos/<algo>/<algo>_agent.py``.

    The module is imported by dotted name ``rl-algos.<algo>.<algo>_agent``, so
    the directory containing ``rl-algos`` must be on ``sys.path``.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module defines no ``Agent``.
    """
    return importlib.import_module(f"rl-algos.{algo}.{algo}_agent").Agent

In [5]:
# Actor/critic hidden-layer width: fixed at 64 for "64L" experiments, otherwise
# read from the tuned results of OPTIM_GROUP for ALGO.
if "64L" in EXPERIMENT_ID:
    actor_layer_size = critic_layer_size = 64
else:
    best_results_fn = f"{BASE_DIR}/param_tune/results/{OPTIM_GROUP}/best_results.json"
    with open(best_results_fn, "r", encoding="utf-8") as file:
        best_results = json.load(file)
    # Keep every entry for ALGO except "algo", "episodic_return" and "date".
    opt_params = {
        k: v
        for k, v in best_results[ALGO].items()
        if k not in {"algo", "episodic_return", "date"}
    }
    # Fail loudly rather than leaving the sizes undefined (or silently stale
    # from an earlier run in this kernel) when the tuned width is missing.
    if "actor_critic_layer_size" not in opt_params:
        raise KeyError(
            f"'actor_critic_layer_size' missing for {ALGO!r} in {best_results_fn}"
        )
    actor_layer_size = critic_layer_size = opt_params["actor_critic_layer_size"]

In [6]:
# Locate the actor checkpoint saved at STEP_COUNT for this experiment/algo/seed.
# glob() returns matches in arbitrary order, so sort for a deterministic pick,
# and raise a clear error (instead of IndexError) when nothing matches.
actor_weights_pattern = (
    RECORDS_DIR
    + f"/{EXPERIMENT_ID}_*/*_{ALGO}_torch__{SEED}__*__*/fedrl-weights/actor/*-{STEP_COUNT}.pth"
)
actor_weights_candidates = sorted(glob(actor_weights_pattern))
if not actor_weights_candidates:
    raise FileNotFoundError(f"No actor weights match {actor_weights_pattern}")
if len(actor_weights_candidates) > 1:
    print(
        f"{len(actor_weights_candidates)} checkpoints match; "
        f"using {actor_weights_candidates[0]}"
    )
actor_weights_fn = actor_weights_candidates[0]

In [7]:
make_env = get_make_env(ALGO)

# Positional arguments forwarded to make_env (see its definition in
# rl-algos/<ALGO>/main.py for their meaning). "ppo"/"trpo" take two trailing
# values (0.99, 10); every other algorithm takes one (10).
# NOTE(review): 0.99 looks like a discount factor — confirm against make_env.
trailing_args = [0.99, 10] if ALGO in ("ppo", "trpo") else [10]
env_args = [ENV_ID, SEED, CLIENT_ID, 0, False, "test", *trailing_args]

envs = gym.vector.SyncVectorEnv([make_env(*env_args)])

Loading NCEP surface temperature data ...
Loading NCEP surface temperature data ...
[RL Env] Environment ID: 0
[RL Env] Number of clients: 2


  logger.deprecation(


In [8]:
# Rebuild the actor on CPU and load the STEP_COUNT checkpoint into it.
# map_location lets checkpoints saved from a GPU run load on a CPU-only node;
# weights_only=True restricts unpickling to tensors / state-dict containers.
DEVICE = "cpu"
Actor = get_actor(ALGO)
actor = Actor(envs, actor_layer_size).to(DEVICE)
actor.eval()  # inference only: put any dropout/batch-norm layers in eval mode
actor_weights = torch.load(actor_weights_fn, map_location=DEVICE, weights_only=True)
actor.load_state_dict(actor_weights)

<All keys matched successfully>