In [None]:
# FOR USE IN CONJUNCTION WITH https://github.com/lweitkamp/feudalnets-pytorch


# fun_minigrid_train.py
"""
Minimal Feudal Network (FuN) training script for the MiniGrid environment.

Runs in a single Jupyter cell.
Prints average extrinsic reward every 10 epochs.
"""

import gymnasium as gym
import numpy as np
import torch
from types import SimpleNamespace

# Import wrappers from the MiniGrid package
from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper

import numpy as np
import torch
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def take_action(action_logits):
    """Sample one action per row from a categorical policy.

    Args:
        action_logits: Unnormalised action scores; the last dimension
            indexes actions.

    Returns:
        Tuple ``(actions, log_prob, entropy)``. ``actions`` is a NumPy array
        on the CPU (ready for ``env.step``). ``log_prob`` and ``entropy`` are
        tensors that stay attached to the autograd graph so they can be used
        in the loss.
    """
    # Building the distribution from logits lets torch normalise with
    # log-softmax internally. That is numerically stabler than
    # softmax -> probs -> log(probs), which clamps tiny probabilities.
    dist = Categorical(logits=action_logits)
    action = dist.sample()
    return action.cpu().numpy(), dist.log_prob(action), dist.entropy()


class ReturnWrapper:
    """Pass-through environment wrapper that reports episode returns in ``info``.

    On the step where an episode finishes (``terminated or truncated``), the
    sum of all rewards collected since the previous reset / episode end is
    written to ``info["returns/episodic_reward"]``. Every other attribute is
    delegated to the wrapped environment.

    In this script the wrapper is applied to each single (non-vectorised)
    MiniGrid env inside the ``SyncVectorEnv`` thunks, so ``info`` is normally
    a dict. The list-of-dicts branch handles inner envs that return one
    ``info`` entry per sub-environment.
    """

    def __init__(self, env):
        self.env = env
        # Running reward sum for the current episode(s). It is created lazily
        # with the same shape as the reward: scalar for one env, 1-D for
        # list-style infos.
        self._episode_return = None

    def reset(self, **kwargs):
        self._episode_return = None
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        done = np.logical_or(terminated, truncated)

        reward_arr = np.asarray(reward, dtype=np.float64)
        if self._episode_return is None:
            self._episode_return = np.zeros_like(reward_arr)
        episode_return = self._episode_return + reward_arr
        # Start a fresh tally for every episode that just ended.
        self._episode_return = np.where(done, 0.0, episode_return)

        if isinstance(info, list):
            fixed_info = []
            for i, entry in enumerate(info):
                # Copy so the inner env's info objects are never mutated.
                entry = dict(entry) if isinstance(entry, dict) else {}
                if done[i]:
                    entry["returns/episodic_reward"] = float(episode_return[i])
                fixed_info.append(entry)
            info = fixed_info
        else:
            info = dict(info) if isinstance(info, dict) else {}
            if done:
                info["returns/episodic_reward"] = float(episode_return)

        return obs, reward, terminated, truncated, info

    def __getattr__(self, name):
        # Only reached for attributes not found on the wrapper itself.
        # Guarding "env" makes a half-constructed instance (e.g. during
        # copy/pickle) raise AttributeError instead of recursing forever.
        if name == "env":
            raise AttributeError(name)
        return getattr(self.env, name)



from storage import Storage
from feudalnet import FeudalNetwork, feudal_loss


def make_minigrid_envs(env_name: str, num_envs: int, seed: int = 0, tile_size: int = 14):
    """Create synchronously vectorised MiniGrid environments with RGB image observations.

    Each of the ``num_envs`` sub-environments is built as
    ``gym.make(env_name)`` -> ``RGBImgObsWrapper(tile_size=tile_size)`` ->
    ``ImgObsWrapper`` -> ``ReturnWrapper``. Each is reset once with its own
    seed ``seed + rank``, so every worker gets a distinct seed.

    MiniGrid's ``RGBImgObsWrapper`` renders the whole grid at ``tile_size``
    pixels per cell, so the frame side is ``grid_size * tile_size``. The
    default ``tile_size=14`` gives 84×84 frames for 6×6 grids (e.g.
    ``MiniGrid-Empty-Random-6x6-v0``); other grid sizes give a different
    resolution.

    Args:
        env_name: Gymnasium id of the MiniGrid environment.
        num_envs: Number of parallel sub-environments.
        seed: Base seed; sub-environment ``rank`` is reset with ``seed + rank``.
        tile_size: Pixels per grid cell in the rendered observation.

    Returns:
        A ``gym.vector.SyncVectorEnv`` over the wrapped environments.
    """
    def make_env(rank: int):
        # Passing `rank` as an argument gives each thunk its own value.
        # A bare closure over the comprehension variable would not.
        def _thunk():
            env = gym.make(env_name)
            env = RGBImgObsWrapper(env, tile_size=tile_size)
            env = ImgObsWrapper(env)
            env = ReturnWrapper(env)
            env.reset(seed=seed + rank)
            return env
        return _thunk

    return gym.vector.SyncVectorEnv([make_env(rank) for rank in range(num_envs)])



# ---------------------------------------------------------------------------
# Configuration: every hyper-parameter lives on `args`, which is also passed
# whole to FeudalNetwork (args=...) and to feudal_loss.
# ---------------------------------------------------------------------------
args = SimpleNamespace(**{
    "env_name": "MiniGrid-Empty-Random-6x6-v0",
    "num_workers": 8,
    "num_steps": 64,
    "epochs": 50_000,
    "seed": 0,
    "lr": 5e-4,
    "grad_clip": 5.0,
    "entropy_coef": 0.01,
    "time_horizon": 10,
    "hidden_dim_manager": 128,
    "hidden_dim_worker": 16,
    "gamma_w": 0.99,
    "gamma_m": 0.999,
    "alpha": 0.5,
    "eps": 1e-5,
    "dilation": 10,
    "mlp": 0,
    "cuda": torch.cuda.is_available(),
})
args.device = device = torch.device("cuda" if args.cuda else "cpu")

# Seed torch and NumPy before FeudalNetwork is constructed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# ---------------------------------------------------------------------------
# Environments, model and optimiser.
# ---------------------------------------------------------------------------
envs = make_minigrid_envs(args.env_name, args.num_workers, args.seed)
feudalnet = FeudalNetwork(
    num_workers=args.num_workers,
    input_dim=envs.single_observation_space.shape,
    hidden_dim_manager=args.hidden_dim_manager,
    hidden_dim_worker=args.hidden_dim_worker,
    n_actions=envs.single_action_space.n,
    time_horizon=args.time_horizon,
    dilation=args.dilation,
    device=device,
    mlp=args.mlp,
    args=args,
)
optimizer = torch.optim.RMSprop(
    feudalnet.parameters(),
    lr=args.lr,
    alpha=0.99,  # RMSprop smoothing constant; deliberately not args.alpha
    eps=1e-5,
)

# Initial goals/states/masks from the network, plus the first observations.
goals, states, masks = feudalnet.init_obj()
obs, _ = envs.reset()

# Logging buffers filled by the training loop.
episode_rewards_history, epoch_rewards_history = [], []

# Running extrinsic return of the in-progress episode in each worker env.
# Used only for logging the returns of completed episodes.
running_returns = np.zeros(args.num_workers, dtype=np.float64)

try:
    for epoch in range(1, args.epochs + 1):
        # Detach goals from the previous epoch's graph (repackage_hidden
        # presumably does the same for the recurrent state).
        feudalnet.repackage_hidden()
        goals = [g.detach() for g in goals]

        storage = Storage(
            size=args.num_steps,
            keys=["r", "r_i", "v_w", "v_m", "logp", "entropy",
                  "s_goal_cos", "m", "ret_w", "ret_m", "adv_m", "adv_w"],
        )

        epoch_rewards = []  # returns of episodes that finished during this epoch

        for _ in range(args.num_steps):
            action_dist, goals, states, value_m, value_w = feudalnet(
                obs, goals, states, masks[-1]
            )
            actions, logp, entropy = take_action(action_dist)
            obs, reward, terminated, truncated, _infos = envs.step(actions)
            done = np.logical_or(terminated, truncated)

            # Gymnasium vector envs return `infos` as a dict of batched
            # arrays, so enumerating it walks its keys, not the workers.
            # Accumulate returns locally and log each one when its episode ends.
            running_returns += reward
            for i in np.flatnonzero(done):
                episode_return = float(running_returns[i])
                episode_rewards_history.append(episode_return)
                epoch_rewards.append(episode_return)
                running_returns[i] = 0.0

            # 0.0 for workers whose episode just ended, 1.0 otherwise.
            mask = torch.as_tensor(
                (~done).astype(np.float32), device=device
            ).unsqueeze(-1)
            masks.pop(0)
            masks.append(mask)

            storage.add({
                "r": torch.as_tensor(
                    reward, dtype=torch.float32, device=device
                ).unsqueeze(-1),
                "r_i": feudalnet.intrinsic_reward(states, goals, masks),
                "v_w": value_w,
                "v_m": value_m,
                "logp": logp.unsqueeze(-1),
                "entropy": entropy.unsqueeze(-1),
                "s_goal_cos": feudalnet.state_goal_cosine(states, goals, masks),
                "m": mask,
            })

        # Value estimates for the observation after the final rollout step.
        # They are passed to feudal_loss, presumably to bootstrap returns.
        with torch.no_grad():
            *_, next_v_m, next_v_w = feudalnet(
                obs, goals, states, masks[-1], save=False
            )

        optimizer.zero_grad()
        loss, _ = feudal_loss(storage, next_v_m.detach(), next_v_w.detach(), args)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(feudalnet.parameters(), args.grad_clip)
        optimizer.step()

        epoch_rewards_history.append(epoch_rewards)
        if epoch % 10 == 0:
            # Mean return over all episodes completed in the last 10 epochs.
            last_10_rewards = [r for sub in epoch_rewards_history[-10:] for r in sub]
            if last_10_rewards:
                avg = np.mean(last_10_rewards)
                print(f"Epoch {epoch:4d} | average extrinsic reward (last 10 epochs): {avg:.3f}")
            else:
                print(f"Epoch {epoch:4d} | no episodes completed in the last 10 epochs")
finally:
    # Also runs on KeyboardInterrupt (e.g. stopping the notebook cell), so
    # the worker environments are always closed.
    envs.close()

Epoch   10 | average extrinsic reward (last 10 epochs): 0.053
Epoch   20 | average extrinsic reward (last 10 epochs): 0.013
Epoch   30 | average extrinsic reward (last 10 epochs): 0.038
Epoch   40 | average extrinsic reward (last 10 epochs): 0.032
Epoch   50 | average extrinsic reward (last 10 epochs): 0.035
Epoch   60 | average extrinsic reward (last 10 epochs): 0.031
Epoch   70 | average extrinsic reward (last 10 epochs): 0.035
Epoch   80 | average extrinsic reward (last 10 epochs): 0.060
Epoch   90 | average extrinsic reward (last 10 epochs): 0.034
Epoch  100 | average extrinsic reward (last 10 epochs): 0.036
Epoch  110 | average extrinsic reward (last 10 epochs): 0.058
Epoch  120 | average extrinsic reward (last 10 epochs): 0.067
Epoch  130 | average extrinsic reward (last 10 epochs): 0.052
Epoch  140 | average extrinsic reward (last 10 epochs): 0.043
Epoch  150 | average extrinsic reward (last 10 epochs): 0.045
Epoch  160 | average extrinsic reward (last 10 epochs): 0.068
Epoch  1

KeyboardInterrupt: 