In [1]:
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
import gym, numpy as np
from itertools import count
from collections import deque

# ───────────────────────── FuN building blocks ──────────────────────────
class DilatedLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, dilation: int):
        super().__init__()
        self.hidden_size, self.dilation = hidden_size, dilation
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.register_buffer("_h", torch.zeros(dilation, 1, hidden_size))
        self.register_buffer("_c", torch.zeros_like(self._h))

    def reset(self, batch: int, device: torch.device):
        self._h = torch.zeros(self.dilation, batch, self.hidden_size, device=device)
        self._c = torch.zeros_like(self._h)

    def forward(self, x: torch.Tensor, t: int):
        i = t % self.dilation
        h, c = self.cell(x, (self._h[i], self._c[i]))
        self._h = self._h.clone()
        self._c = self._c.clone()
        self._h[i] = h
        self._c[i] = c
        return h


class Perception(nn.Module):
    def __init__(self, percept_size: int, hidden_size: int):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(7 * 7 * 3, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, percept_size)
        )

    def forward(self, x: torch.Tensor):
        return self.encoder(x.view(x.size(0), -1))  # Flatten the grid



class Manager(nn.Module):
    def __init__(self, percept_size: int, state_size: int, dilation: int):
        super().__init__()
        self.space = nn.Linear(percept_size, state_size)
        self.rnn = DilatedLSTM(state_size, state_size, dilation)

    def reset(self, batch: int, device: torch.device):
        self.rnn.reset(batch, device)

    def forward(self, z: torch.Tensor, t: int):
        s = F.relu(self.space(z))
        g_hat = self.rnn(s, t)
        return s, F.normalize(g_hat, dim=-1)


class Worker(nn.Module):
    def __init__(self, percept_size: int, action_n: int, state_size: int, goal_size: int):
        super().__init__()
        self.goal_size, self.action_n = goal_size, action_n
        self.goal_embed = nn.Linear(state_size, goal_size, bias=False)
        self.rnn = nn.LSTMCell(percept_size, goal_size * action_n)
        self.register_buffer("_h", torch.zeros(1, goal_size * action_n))
        self.register_buffer("_c", torch.zeros_like(self._h))

    def reset(self, batch: int, device: torch.device):
        shape = (batch, self.goal_size * self.action_n)
        self._h = torch.zeros(*shape, device=device)
        self._c = torch.zeros_like(self._h)

    def forward(self, z: torch.Tensor, goal_sum: torch.Tensor):
        h, c = self.rnn(z, (self._h, self._c))
        self._h, self._c = h, c
        U = h.view(-1, self.goal_size, self.action_n)
        w = self.goal_embed(goal_sum)
        logits = (w.unsqueeze(1) @ U).squeeze(1)
        return logits


class FuNAgent(nn.Module):
    def __init__(self, percept_size=128, hidden=256, state=64,
                 goal=8, dilation=10, action_n=7):
        super().__init__()
        self.dilation = dilation
        self.percept = Perception(percept_size, hidden)
        self.manager = Manager(percept_size, state, dilation)
        self.worker = Worker(percept_size, action_n, state, goal)

        self.worker_v = nn.Sequential(nn.Linear(state, 128), nn.ReLU(), nn.Linear(128, 1))
        self.manager_v = nn.Sequential(nn.Linear(state, 128), nn.ReLU(), nn.Linear(128, 1))

        self._g_queue: deque[torch.Tensor] = deque(maxlen=dilation)

    def reset(self, batch: int, device: torch.device):
        self._g_queue.clear()
        self.manager.reset(batch, device)
        self.worker.reset(batch, device)

    def forward(self, x: torch.Tensor, t: int):
        z = self.percept(x)
        s_t, g_t = self.manager(z, t)
        if torch.rand(1).item() < 0.05:
            noise = torch.randn_like(g_t, device=g_t.device)
            g_t = F.normalize(noise, dim=-1)
        self._g_queue.append(g_t)
        g_sum = torch.stack(tuple(self._g_queue)).sum(0)
        logits = self.worker(z, g_sum)
        return logits, s_t, g_t


def discounted_cumsum(x: torch.Tensor, gamma: float) -> torch.Tensor:
    out = torch.empty_like(x, device=x.device)
    G = torch.tensor(0.0, device=x.device)
    for t in reversed(range(len(x))):
        G = x[t] + gamma * G
        out[t] = G
    return out


def preprocess(frame: np.ndarray) -> torch.Tensor:
    frame = torch.tensor(frame, dtype=torch.float32) / 10.0  # Normalize discrete values
    return frame.unsqueeze(0).to(device)  # Add batch dimension


from gym import Wrapper

class CustomUnlockRewardWrapper(Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.key_picked = False

    def reset(self, **kwargs):
        self.key_picked = False
        obs, info = self.env.reset(**kwargs)
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        try:
            if not self.key_picked and self.env.carrying and self.env.carrying.type == "key":
                reward = 1.0
                self.key_picked = True
            elif terminated and self.env.door and self.env.door.is_open:
                reward = 1.0
            else:
                reward = 0.0
        except:
            pass
        
        return obs, reward, terminated, truncated, info


# ─────────────────────────── main loop ────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

ENV_ID = "MiniGrid-Empty-Random-5x5-v0"#"MiniGrid-Unlock-v0"
NUM_EPISODES = 20_000
DILATION = 10
HORIZON = DILATION
GAMMA_EXT = 0.99
GAMMA_INT = 0.90
INTRINSIC_SCALE = 0.1
LR_MANAGER = 3e-4
LR_WORKER = 3e-4
CLIP_NORM = 40.0

env = CustomUnlockRewardWrapper(gym.make(ENV_ID, render_mode="rgb_array"))
action_n = env.action_space.n
agent = FuNAgent(dilation=DILATION, action_n=action_n).to(device)

optimizer = optim.Adam([
    {"params": list(agent.percept.parameters()) +
               list(agent.worker.parameters()) +
               list(agent.worker_v.parameters()),
     "lr": LR_WORKER},
    {"params": list(agent.manager.parameters()) +
               list(agent.manager_v.parameters()),
     "lr": LR_MANAGER},
])

for ep in range(1, NUM_EPISODES + 1):
    obs, _ = env.reset()
    img = preprocess(obs["image"])
    agent.reset(batch=1, device=device)

    states, goals = [], []
    logps, entrs = [], []
    r_exts = []

    for t in count():
        logits, s_t, g_t = agent(img, t)
        dist = torch.distributions.Categorical(logits=logits.squeeze(0))
        a = dist.sample()
        logps.append(dist.log_prob(a))
        entrs.append(dist.entropy())

        obs, r_ext, term, trunc, _ = env.step(a.item())
        img = preprocess(obs["image"])

        states.append(s_t.squeeze(0))
        goals.append(g_t.squeeze(0))
        r_exts.append(torch.tensor(r_ext, dtype=torch.float32, device=device))

        if term or trunc:
            break

    s = torch.stack(states)
    g = torch.stack(goals)
    logp = torch.stack(logps)
    entr = torch.stack(entrs)
    r_ext = torch.stack(r_exts)
    T = s.size(0)

    r_int = torch.zeros(T, device=device)
    for t in range(T):
        h = min(HORIZON, t)
        if h > 0:
            diff = s[t].unsqueeze(0) - s[t - h:t]
            sim = F.cosine_similarity(diff, g[t - h:t], dim=-1)
            r_int[t] = sim.mean()

    R_ext = discounted_cumsum(r_ext, GAMMA_EXT)
    R_int = discounted_cumsum(r_int, GAMMA_INT)
    R_tot = R_ext + INTRINSIC_SCALE * R_int

    V_w = agent.worker_v(s).squeeze(-1)
    V_m = agent.manager_v(s).squeeze(-1)

    adv_w = (R_tot - V_w.detach())
    if len(adv_w) > 1:
        adv_w = (adv_w - adv_w.mean()) / (adv_w.std() + 1e-8)

    adv_m = (R_ext - V_m.detach())
    if len(adv_m) > 1:
        adv_m = (adv_m - adv_m.mean()) / (adv_m.std() + 1e-8)

    entropy = entr.mean()
    pg_w = -(adv_w * logp).mean()
    v_w_loss = 0.5 * F.mse_loss(V_w, R_tot.detach())
    loss_w = pg_w + v_w_loss - 0.01 * entropy

    if T > HORIZON:
        trans_cos = F.cosine_similarity(
            (s[HORIZON:] - s[:-HORIZON]).detach(), g[:-HORIZON], dim=-1)
        loss_m_pg = -(adv_m[:-HORIZON] * trans_cos).mean()
    else:
        loss_m_pg = torch.zeros((), device=device)

    v_m_loss = 0.5 * F.mse_loss(V_m, R_ext.detach())
    loss_m = loss_m_pg + v_m_loss

    optimizer.zero_grad(set_to_none=True)
    (loss_w + loss_m).backward()
    torch.nn.utils.clip_grad_norm_(agent.parameters(), CLIP_NORM)
    optimizer.step()

    if ep % 10 == 0:
        print(f"EP {ep:05d} | "
              f"⟨R_ext⟩={r_ext.float().mean():.3f} | "
              f"⟨R_int⟩={r_int.float().mean():.3f} | "
              f"steps={T:3d} | "
              f"loss_w={loss_w.item():+.3f} | "
              f"loss_m_pg={loss_m_pg.item():+.3f}")

  fn()


Using device cpu


  if not isinstance(terminated, (bool, np.bool8)):


EP 00010 | ⟨R_ext⟩=0.000 | ⟨R_int⟩=0.000 | steps=100 | loss_w=-0.021 | loss_m_pg=+0.011
EP 00020 | ⟨R_ext⟩=0.000 | ⟨R_int⟩=-0.004 | steps=100 | loss_w=-0.019 | loss_m_pg=+0.030
EP 00030 | ⟨R_ext⟩=0.000 | ⟨R_int⟩=0.002 | steps=100 | loss_w=-0.030 | loss_m_pg=+0.005
EP 00040 | ⟨R_ext⟩=0.005 | ⟨R_int⟩=0.003 | steps= 72 | loss_w=-0.059 | loss_m_pg=+0.018
EP 00050 | ⟨R_ext⟩=0.082 | ⟨R_int⟩=-0.032 | steps= 11 | loss_w=+0.096 | loss_m_pg=-0.253
EP 00060 | ⟨R_ext⟩=0.241 | ⟨R_int⟩=-0.066 | steps=  4 | loss_w=-0.003 | loss_m_pg=+0.000
EP 00070 | ⟨R_ext⟩=0.003 | ⟨R_int⟩=-0.001 | steps= 85 | loss_w=-0.134 | loss_m_pg=+0.010
EP 00080 | ⟨R_ext⟩=0.014 | ⟨R_int⟩=0.006 | steps= 44 | loss_w=-0.259 | loss_m_pg=-0.009
EP 00090 | ⟨R_ext⟩=0.008 | ⟨R_int⟩=-0.011 | steps= 59 | loss_w=-0.378 | loss_m_pg=-0.030
EP 00100 | ⟨R_ext⟩=0.000 | ⟨R_int⟩=-0.003 | steps=100 | loss_w=+0.041 | loss_m_pg=-0.023
EP 00110 | ⟨R_ext⟩=0.004 | ⟨R_int⟩=-0.004 | steps= 80 | loss_w=+0.116 | loss_m_pg=-0.006
EP 00120 | ⟨R_ext⟩=0.023 

KeyboardInterrupt: 

In [62]:
import torch 
import gym
import imageio
from PIL import Image
import numpy as np

def display_episodes(agent, preprocess, env_id=ENV_ID, num_episodes=10, max_steps=100):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent.to(device)

    all_frames = []

    for episode in range(num_episodes):
        env = gym.make(env_id, render_mode="rgb_array")
        obs, _ = env.reset()
        img = preprocess(obs["image"]).to(device)
        agent.reset(batch=1, device=device)

        for t in range(max_steps):
            frame = env.render()

            all_frames.append(np.array(frame))

            with torch.no_grad():
                logits, *_ = agent(img, t)
                action = torch.distributions.Categorical(logits=logits.squeeze(0)).sample().item()

            obs, _, terminated, truncated, _ = env.step(action)
            img = preprocess(obs["image"]).to(device)

            if terminated or truncated:
                break

        env.close()

    video_path = "all_episodes.mp4"
    print(f"Saving {video_path}...")

    # Save as MP4 video
    with imageio.get_writer(video_path, fps=5, codec='libx264', format='mp4') as writer:
        for frame in all_frames:
            writer.append_data(frame)

    print("MP4 video saved.")

# Usage:
display_episodes(agent, preprocess)


Saving all_episodes.mp4...
MP4 video saved.
