In [1]:
from __future__ import annotations

from typing import NamedTuple, List, Protocol, Any, Sequence, Optional, Protocol
import numpy as np
from pathlib import Path

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.nn.utils.rnn as rnn

import gymnasium as gym
import gymnasium.wrappers
from rsrch.utils.polyak import Polyak

In [2]:
class MLP(nn.Sequential):
    """Multi-layer perceptron assembled as a flat ``nn.Sequential``.

    For ``num_features = [f0, f1, ..., fk]`` the stack is
    ``Linear(f0, f1)``, then for every following layer
    ``norm_layer(f_i) -> act_layer() -> Linear(f_i, f_{i+1})``.

    Only the final ``Linear`` carries a bias; the hidden ones are created
    with ``bias=False`` (each of them is followed by ``norm_layer``).

    Args:
        num_features: Layer widths, input first and output last. Fewer than
            two entries produce an empty ``Sequential``.
        norm_layer: Callable taking the feature count and returning a module.
        act_layer: Zero-argument callable returning an activation module.
    """

    def __init__(self, num_features: List[int], norm_layer=nn.LayerNorm, act_layer=nn.ReLU):
        layers = []
        # There is one Linear per consecutive pair of widths, so the last
        # Linear has index ``num_linear - 1``. (The previous code compared
        # against ``len(num_features) - 1``, which is never reached, leaving
        # the output layer without a bias.)
        num_linear = len(num_features) - 1
        widths = zip(num_features[:-1], num_features[1:])
        for layer_idx, (in_features, out_features) in enumerate(widths):
            if layer_idx > 0:
                layers.append(norm_layer(in_features))
                layers.append(act_layer())
            is_final = layer_idx == num_linear - 1
            layers.append(nn.Linear(in_features, out_features, bias=is_final))

        super().__init__(*layers)

class VisualEncoder(nn.Module):
    """Convolutional encoder that flattens an image into a feature vector.

    Four ``Conv2d -> ReLU -> MaxPool2d(2)`` stages are applied, with channel
    counts ``hidden_dim, 2*hidden_dim, 4*hidden_dim, 8*hidden_dim``, followed
    by ``Flatten``. The convolutions use stride 1 and padding
    ``kernel_size // 2``; each pooling stage halves both spatial dimensions.

    Args:
        obs_shape: Three-element shape; the first entry is used as the number
            of input channels and the other two must be divisible by 16.
        kernel_size: Odd convolution kernel size.
        hidden_dim: Channel count of the first stage.

    Attributes:
        out_features: Size of the flattened output per sample.
    """

    def __init__(self, obs_shape: torch.Size, kernel_size=3, hidden_dim=32):
        super().__init__()
        self.obs_shape = obs_shape
        self.kernel_size = kernel_size
        self.hidden_dim = hidden_dim

        in_channels, width, height = self.obs_shape
        assert width % 16 == 0 and height % 16 == 0, \
            "image resolution should be divisible by 16"
        assert kernel_size % 2 == 1, \
            "kernel_size should be an odd number"
        padding = kernel_size // 2

        # Channel count entering/leaving each of the four stages.
        channels = [in_channels] + [hidden_dim * 2 ** stage for stage in range(4)]

        # Four 2x poolings shrink each spatial dimension by a factor of 16.
        self.out_features = (height // 16) * (width // 16) * channels[-1]

        stages = []
        for c_in, c_out in zip(channels, channels[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size, 1, padding))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool2d(2))
        stages.append(nn.Flatten())
        self.main = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Run the convolutional stack and return the flattened features."""
        features = self.main(x)
        return features

class ObsEncoder(nn.Module):
    """Encode observations from a ``Box`` space into ``enc_dim`` features.

    Observation shapes with three or more dimensions go through a
    ``VisualEncoder`` followed by a linear projection; anything else goes
    through an ``MLP`` with one hidden layer of width 256.

    Raises:
        NotImplementedError: If ``obs_space`` is not a ``gym.spaces.Box``.
    """

    def __init__(self, obs_space: gym.Space, enc_dim: int):
        super().__init__()
        self.out_features = enc_dim

        # Only Box observation spaces are supported.
        if not isinstance(obs_space, gym.spaces.Box):
            raise NotImplementedError()

        obs_shape = torch.Size(obs_space.shape)
        if len(obs_shape) >= 3:
            visual = VisualEncoder(obs_shape)
            projection = nn.Linear(visual.out_features, enc_dim)
            self.enc = nn.Sequential(visual, projection)
        else:
            self.enc = MLP([obs_shape.numel(), 256, enc_dim])

    def forward(self, obs: Tensor) -> Tensor:
        """Return the encoding of ``obs``."""
        encoded = self.enc(obs)
        return encoded

class ActionEncoder(nn.Module):
    """Encode actions into ``enc_dim`` features.

    ``Discrete`` action spaces use an embedding table with one row per
    action; ``Box`` action spaces use an ``MLP`` with one hidden layer of
    width 256 over the flattened action size.

    Raises:
        NotImplementedError: For any other kind of action space.
    """

    def __init__(self, act_space: gym.Space, enc_dim: int):
        super().__init__()
        self.out_features = enc_dim
        self.enc = self._build_encoder(act_space, enc_dim)

    @staticmethod
    def _build_encoder(act_space, enc_dim):
        """Pick the encoder module matching the type of ``act_space``."""
        if isinstance(act_space, gym.spaces.Discrete):
            return nn.Embedding(act_space.n, enc_dim)
        if isinstance(act_space, gym.spaces.Box):
            act_shape = torch.Size(act_space.shape)
            return MLP([act_shape.numel(), 256, enc_dim])
        raise NotImplementedError()

    def forward(self, act: Tensor) -> Tensor:
        """Return the encoding of ``act``."""
        encoded = self.enc(act)
        return encoded

In [3]:
class EventHandler(Protocol):
    """Callbacks that ``InvokeHandler`` fires around environment calls."""

    def on_reset(self, obs):
        """Called with the observation returned by ``env.reset``."""
        ...
    
    def on_step(self, act, next_obs, reward, term, trunc):
        """Called with the action given to ``env.step`` and its results (``info`` is not passed)."""
        ...

class InvokeHandler(gym.Wrapper):
    """Environment wrapper that reports every reset and step to a handler.

    The wrapped environment's return values are passed through unchanged.
    """

    def __init__(self, env: gym.Env, handler: EventHandler):
        super().__init__(env)
        self._handler = handler

    def reset(self, *, seed=None, options=None):
        """Reset the environment, then notify ``handler.on_reset``."""
        first_obs, reset_info = self.env.reset(seed=seed, options=options)
        self._handler.on_reset(first_obs)
        return first_obs, reset_info

    def step(self, act):
        """Step the environment, then notify ``handler.on_step``."""
        following_obs, rew, terminated, truncated, step_info = self.env.step(act)
        self._handler.on_step(act, following_obs, rew, terminated, truncated)
        return following_obs, rew, terminated, truncated, step_info

In [4]:
class Transition(NamedTuple):
    """One environment step.

    ``Rollout`` yields these positionally as
    ``(obs, act, next_obs, reward, term, trunc)`` and
    ``TransitionBuffer.__getitem__`` returns them, so the field order is part
    of the interface.
    """
    obs: Tensor  # observation the action was chosen from
    act: Tensor  # action passed to env.step
    next_obs: Tensor  # observation returned by env.step
    reward: float  # reward returned by env.step
    term: bool  # "terminated" flag returned by env.step
    trunc: bool  # "truncated" flag returned by env.step

class TransitionBatch(NamedTuple):
    """Batched counterpart of ``Transition`` in which every field is a tensor.

    Used as the annotated input type of ``DQNLoss.__call__``; it is not
    instantiated anywhere in the visible code.
    NOTE(review): presumably every field has a leading batch dimension -- confirm.
    """
    obs: Tensor
    act: Tensor
    next_obs: Tensor
    reward: Tensor
    term: Tensor
    trunc: Tensor

class TransitionBuffer(data.Dataset):
    """Fixed-capacity ring buffer of transitions backed by preallocated tensors.

    Once ``capacity`` transitions have been pushed, new ones overwrite the
    oldest slots. Indexing returns a ``Transition`` built by indexing every
    storage tensor with the given index.

    Args:
        env: Source of ``observation_space`` and ``action_space``; their
            ``shape`` and ``dtype`` determine the storage layout.
        capacity: Number of slots to allocate.
    """

    def __init__(self, env: gym.Env, capacity: int):
        def alloc_for_space(space):
            # Round-trip an empty numpy array through torch to translate the
            # space's numpy dtype (and shape) into torch terms.
            probe = torch.from_numpy(np.empty(space.shape, dtype=space.dtype))
            return torch.empty(capacity, *probe.shape, dtype=probe.dtype)

        def alloc_scalars(dtype):
            return torch.empty(capacity, dtype=dtype)

        self.obs = alloc_for_space(env.observation_space)
        self.act = alloc_for_space(env.action_space)
        self.next_obs = torch.empty_like(self.obs)
        self.reward = alloc_scalars(torch.float)
        self.term = alloc_scalars(torch.bool)
        self.trunc = alloc_scalars(torch.bool)

        self._capacity = capacity
        self._size = 0
        self._cursor = 0
        self._cur_obs = None

    @property
    def device(self):
        """Device of the observation storage."""
        return self.obs.device

    def _convert(self, x, type_as):
        """Turn ``x`` into a tensor with the same type as ``type_as``."""
        as_tensor = torch.as_tensor(x)
        return as_tensor.type_as(type_as)

    def push(self, obs, act, next_obs, reward, term, trunc):
        """Write one transition at the cursor and advance it; returns ``self``."""
        slot = self._cursor
        self.obs[slot] = self._convert(obs, self.obs)
        self.act[slot] = self._convert(act, self.act)
        self.next_obs[slot] = self._convert(next_obs, self.next_obs)
        self.reward[slot], self.term[slot], self.trunc[slot] = reward, term, trunc

        # Wrap around once the end of the storage is reached.
        self._cursor = (slot + 1) % self._capacity
        self._size = min(self._size + 1, self._capacity)

        return self

    def __len__(self):
        """Number of transitions stored so far (at most ``capacity``)."""
        return self._size

    def __getitem__(self, idx):
        """Return the transition(s) selected by ``idx`` as a ``Transition``."""
        field_names = ("obs", "act", "next_obs", "reward", "term", "trunc")
        return Transition(**{name: getattr(self, name)[idx] for name in field_names})

class CollectTransitions(gym.Wrapper):
    """Environment wrapper that records every step into a ``TransitionBuffer``.

    The wrapper remembers the most recent observation so that each pushed
    transition pairs the action with the observation it was taken from.
    Return values of the wrapped environment are passed through unchanged.
    """

    def __init__(self, env: gym.Env, buffer: TransitionBuffer):
        super().__init__(env)
        self._buffer = buffer
        # Observation the next action will be taken from; set by reset().
        self._obs = None

    def reset(self, *, seed=None, options=None):
        """Reset the environment and remember the initial observation."""
        obs, info = self.env.reset(seed=seed, options=options)
        self._obs = obs
        return obs, info

    def step(self, act):
        """Step the environment and push the resulting transition."""
        next_obs, reward, term, trunc, info = self.env.step(act)
        self._buffer.push(self._obs, act, next_obs, reward, term, trunc)
        # Advance the remembered observation. Without this, every transition
        # of an episode was stored with the reset observation as ``obs``.
        self._obs = next_obs
        return next_obs, reward, term, trunc, info

In [5]:
class Trajectory(Protocol):
    """Structural type for a recorded episode stored as five parallel sequences.

    ``TensorTrajectory`` and ``MmapTrajectory`` both provide these fields
    together with ``__len__`` and ``__getitem__``.
    """
    obs: Sequence
    act: Sequence
    reward: Sequence
    trunc: Sequence
    term: Sequence

    def __len__(self) -> int:
        """Length of the trajectory (the implementations in this file return ``len(self.obs)``)."""
        ...

    def __getitem__(self, idx):
        """Index the trajectory (the implementations in this file index every field with ``idx``)."""
        ...

class TensorTrajectory(NamedTuple):
    """Trajectory whose five fields are tensors.

    ``__len__`` and ``__getitem__`` are overridden: the length is the length
    of ``obs`` and indexing is applied to every field, returning a new
    ``TensorTrajectory`` (not a single tuple element).
    """
    obs: Tensor
    act: Tensor
    reward: Tensor
    trunc: Tensor
    term: Tensor

    @staticmethod
    def as_tensor(tr: Trajectory):
        """Build a ``TensorTrajectory`` by passing each field of ``tr`` to ``torch.tensor``."""
        converted = {
            name: torch.tensor(getattr(tr, name))
            for name in ("obs", "act", "reward", "trunc", "term")
        }
        return TensorTrajectory(**converted)

    def _map_fields(self, fn):
        """Apply ``fn`` to every field, in declaration order, and rebuild the tuple."""
        columns = (self.obs, self.act, self.reward, self.trunc, self.term)
        return TensorTrajectory(*(fn(column) for column in columns))

    def clone(self):
        """Return a trajectory whose fields are clones of this one's."""
        return self._map_fields(lambda column: column.clone())

    def to(self, device: torch.device):
        """Return a trajectory with every field moved to ``device``."""
        return self._map_fields(lambda column: column.to(device))

    def __getitem__(self, idx):
        """Index every field with ``idx``."""
        return self._map_fields(lambda column: column[idx])

    def __len__(self):
        """Length of the ``obs`` field."""
        return len(self.obs)

class MmapTrajectory(NamedTuple):
    """Trajectory whose fields are arrays loaded from ``.npy`` files.

    Each field lives in ``<root>/<field>.npy``. ``__len__`` and
    ``__getitem__`` are overridden: the length is the length of ``obs`` and
    indexing is applied to every field.
    """
    obs: np.memmap
    act: np.memmap
    reward: np.memmap
    trunc: np.memmap
    term: np.memmap

    @staticmethod
    def open(root: Path):
        """Load the five field files under ``root`` with ``mmap_mode="r"``."""
        def load(filename):
            return np.load(root / filename, mmap_mode="r")

        return MmapTrajectory(
            load("obs.npy"),
            load("act.npy"),
            load("reward.npy"),
            load("trunc.npy"),
            load("term.npy"),
        )

    @staticmethod
    def save(tr: Trajectory, root: Path):
        """Create ``root`` if necessary and write each field of ``tr`` as a ``.npy`` file."""
        root.mkdir(parents=True, exist_ok=True)
        targets = (
            ("obs.npy", "obs"),
            ("act.npy", "act"),
            ("reward.npy", "reward"),
            ("trunc.npy", "trunc"),
            ("term.npy", "term"),
        )
        for filename, field in targets:
            np.save(root / filename, np.asarray(getattr(tr, field)))

    def __len__(self):
        """Length of the ``obs`` field."""
        return len(self.obs)

    def __getitem__(self, idx):
        """Index every field with ``idx`` and wrap the results in a new ``MmapTrajectory``."""
        columns = (self.obs, self.act, self.reward, self.trunc, self.term)
        return MmapTrajectory(*(column[idx] for column in columns))

class Subsample(nn.Module):
    """Cut a window of at most ``seq_len`` steps out of a trajectory.

    The window start is drawn uniformly from all valid indices with
    ``np.random.randint``, so windows starting near the end are shorter than
    ``seq_len``.
    """

    def __init__(self, seq_len: int):
        super().__init__()
        self.seq_len = seq_len

    def forward(self, traj: Trajectory):
        """Return ``traj[start:start + seq_len]`` for a random ``start``."""
        first = np.random.randint(len(traj))
        window = slice(first, first + self.seq_len)
        return traj[window]

class ToTensor(nn.Module):
    """Transform that converts a trajectory into a ``TensorTrajectory``."""

    def forward(self, traj: Trajectory):
        """Delegate to ``TensorTrajectory.as_tensor``."""
        tensor_traj = TensorTrajectory.as_tensor(traj)
        return tensor_traj

class MapDs(data.Dataset):
    """Map-style dataset that applies ``f`` to every item of ``ds`` on access."""

    def __init__(self, ds: data.Dataset, f):
        super().__init__()
        self.ds, self.f = ds, f

    def __len__(self):
        """Length of the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Fetch item ``idx`` from the wrapped dataset and transform it."""
        raw_item = self.ds[idx]
        return self.f(raw_item)

class MapIterDs(data.IterableDataset):
    """Iterable dataset that lazily applies ``f`` to every item of ``ds``."""

    def __init__(self, ds: data.IterableDataset, f):
        super().__init__()
        self.ds, self.f = ds, f

    def __iter__(self):
        """Yield ``f(item)`` for each item produced by the wrapped dataset."""
        yield from (self.f(item) for item in self.ds)

class TrajectoryBatch(NamedTuple):
    """Batch of variable-length trajectories; every field is a ``PackedSequence``."""
    obs: rnn.PackedSequence
    act: rnn.PackedSequence
    reward: rnn.PackedSequence
    trunc: rnn.PackedSequence
    term: rnn.PackedSequence

    @staticmethod
    def collate_fn(batch: List[TensorTrajectory]) -> TrajectoryBatch:
        """Pack a list of trajectories field by field.

        The trajectories are first ordered by decreasing length, which is the
        order ``pack_sequence`` expects with its default settings.
        """
        lengths = torch.as_tensor([len(tr) for tr in batch])
        order = torch.argsort(lengths, descending=True)
        longest_first = [batch[idx] for idx in order]

        def pack(field):
            return rnn.pack_sequence([getattr(tr, field) for tr in longest_first])

        return TrajectoryBatch(
            obs=pack("obs"),
            act=pack("act"),
            reward=pack("reward"),
            trunc=pack("trunc"),
            term=pack("term"),
        )

class EpisodeBuffer(data.Dataset):
    """Ring buffer of whole episodes, filled through ``on_reset`` / ``on_step``.

    The episode in progress is written into preallocated scratch tensors of
    length ``max_seq_len``. At index ``t`` they hold the observation seen at
    step ``t``, the action taken from it, the reward and ``term`` / ``trunc``
    flags received on arriving at it. While an episode is running, its entry
    in the buffer is a *view* of the scratch tensors; once it finishes, the
    entry is replaced by an independent copy (or by memory-mapped files when
    ``mmap_root`` is given).

    Episodes longer than ``max_seq_len`` are cut: further steps are dropped
    and the last stored step is flagged ``trunc=True``.

    Args:
        env: Source of ``observation_space`` and ``action_space``; their
            ``shape`` and ``dtype`` determine the scratch layout.
        max_seq_len: Maximum number of steps stored per episode (>= 1).
        seq_capacity: Number of episode slots; the oldest is overwritten (>= 1).
        mmap_root: Optional directory; finished episodes are saved to
            ``mmap_root / f"{slot:06d}"`` and reopened memory-mapped.

    Raises:
        ValueError: If ``max_seq_len`` or ``seq_capacity`` is smaller than 1.
    """

    def __init__(self, env: gym.Env, max_seq_len: int, seq_capacity: int, mmap_root: Optional[Path] = None):
        if max_seq_len < 1 or seq_capacity < 1:
            raise ValueError("max_seq_len and seq_capacity must both be >= 1")

        def buffer_for(space=None, dtype=None):
            if space is not None:
                # Round-trip through numpy to translate the space's dtype
                # and shape into torch terms.
                x = np.empty(space.shape, dtype=space.dtype)
                x = torch.from_numpy(x)
                x = torch.empty(max_seq_len, *x.shape, dtype=x.dtype)
            else:
                x = torch.empty(max_seq_len, dtype=dtype)
            return x

        self.obs = buffer_for(space=env.observation_space)
        self.act = buffer_for(space=env.action_space)
        self.reward = buffer_for(dtype=torch.float)
        self.trunc = buffer_for(dtype=torch.bool)
        self.trunc.fill_(True)
        self.term = buffer_for(dtype=torch.bool)

        # Object array so that each slot can hold an arbitrary trajectory.
        self._episodes = np.empty((seq_capacity,), dtype=object)
        self.mmap_root = mmap_root

        self._cur_step = 0
        self._cur_ep_idx = -1
        self.num_episodes = 0
        self.seq_capacity = seq_capacity
        self.max_seq_len = max_seq_len

    @property
    def device(self):
        """Device of the observation scratch tensor."""
        return self.obs.device

    def _conv(self, x, type_as):
        """Turn ``x`` into a tensor with the same type as ``type_as``."""
        return torch.as_tensor(x).type_as(type_as)

    def on_reset(self, obs):
        """Start a new episode in the next slot with ``obs`` as its first step."""
        self._cur_step = self._cur_ep_len = 0
        self._cur_ep_idx = (self._cur_ep_idx + 1) % self.seq_capacity
        self.num_episodes = min(self.num_episodes + 1, self.seq_capacity)

        self.obs[self._cur_step] = self._conv(obs, self.obs)
        # No reward is received for the initial observation; write a defined
        # value instead of exposing whatever the scratch memory contained.
        self.reward[self._cur_step] = 0.0
        self.trunc[self._cur_step] = True
        self.term[self._cur_step] = False

        self._update_cur_episode_view()

    def _update_cur_episode_view(self):
        """Point the current slot at the filled prefix of the scratch tensors."""
        ep_len = self._cur_step + 1

        # NOTE: This only creates a view, no data is copied
        self._episodes[self._cur_ep_idx] = TensorTrajectory(
            obs=self.obs[:ep_len],
            act=self.act[:ep_len],
            reward=self.reward[:ep_len],
            trunc=self.trunc[:ep_len],
            term=self.term[:ep_len],
        )

    def on_step(self, act, next_obs, reward, term, trunc):
        """Append one step to the current episode and finalise it when done."""
        # The step is written at index ``_cur_step + 1``, so that index (not
        # ``_cur_step``) must fit. The previous check allowed an
        # out-of-bounds write once the scratch tensors were full.
        if self._cur_step + 1 < self.max_seq_len:
            self.act[self._cur_step] = self._conv(act, self.act)
            self.trunc[self._cur_step] = False
            self._cur_step += 1
            self.obs[self._cur_step] = self._conv(next_obs, self.obs)
            self.reward[self._cur_step] = reward
            self.term[self._cur_step] = term
            self.trunc[self._cur_step] = trunc
        else:
            # Out of room: drop the step and mark the stored episode as cut
            # off at its last recorded step.
            self.trunc[self._cur_step] = True

        self._update_cur_episode_view()

        done = term or trunc
        if done:
            # Detach the finished episode from the scratch tensors, which the
            # next episode is about to overwrite.
            cur_ep_view = self._episodes[self._cur_ep_idx]
            if self.mmap_root is not None:
                dst_root = self.mmap_root / f"{self._cur_ep_idx:06d}"
                MmapTrajectory.save(cur_ep_view, dst_root)
                self._episodes[self._cur_ep_idx] = MmapTrajectory.open(dst_root)
            else:
                self._episodes[self._cur_ep_idx] = cur_ep_view.clone()

    def __len__(self):
        """Number of occupied episode slots."""
        return self.num_episodes

    def __getitem__(self, idx):
        """Return the episode in slot ``idx``.

        Raises:
            IndexError: If ``idx`` is not below the number of stored episodes.
        """
        if idx >= self.num_episodes:
            raise IndexError()
        return self._episodes[idx]


class CollectEpisodes(gym.Wrapper):
    """Environment wrapper that feeds every reset and step to an ``EpisodeBuffer``.

    The wrapped environment's return values are passed through unchanged.
    """

    def __init__(self, env: gym.Env, buffer: EpisodeBuffer):
        super().__init__(env)
        self._buffer = buffer

    def reset(self, *, seed=None, options=None):
        """Reset the environment and open a new episode in the buffer."""
        first_obs, reset_info = self.env.reset(seed=seed, options=options)
        self._buffer.on_reset(first_obs)
        return first_obs, reset_info

    def step(self, act):
        """Step the environment and append the step to the current episode."""
        following_obs, rew, terminated, truncated, step_info = self.env.step(act)
        self._buffer.on_step(act, following_obs, rew, terminated, truncated)
        return following_obs, rew, terminated, truncated, step_info

In [6]:
import torch.utils.data as data
import typing
import sys
import logging

class QNetwork(Protocol):
    """Structural type for the Q-functions used by the agents and ``DQNLoss``."""
    # Number of discrete actions; read by EpsGreedyAgent to draw random actions.
    num_actions: int

    def __call__(self, obs: Tensor) -> Tensor:
        """Return action values for ``obs``.

        Callers in this file reduce / gather the result along dim 1.
        NOTE(review): presumably shaped (batch, num_actions) -- confirm.
        """
        ...

class BaseQNetwork(nn.Module, QNetwork):
    """Q-network: a 128-d observation encoder followed by an MLP head.

    The head has one hidden layer of width 64 and one output per action.
    Only environments with a ``Discrete`` action space are accepted.
    """

    def __init__(self, env: gym.Env):
        super().__init__()
        act_space = env.action_space
        assert isinstance(act_space, gym.spaces.Discrete)

        self.num_actions = int(act_space.n)
        encoder = ObsEncoder(env.observation_space, 128)
        head = MLP([128, 64, self.num_actions])
        self.main = nn.Sequential(encoder, head)

    def forward(self, obs: Tensor):
        """Return the head's output for the encoded ``obs``."""
        q_values = self.main(obs)
        return q_values

class QAdvNetwork(nn.Module, QNetwork):
    """Q-network with a shared encoder and two heads whose outputs are summed.

    ``adv_head`` produces a single value per observation and ``q_head`` one
    value per action; the single value is broadcast across actions in the
    sum. Only environments with a ``Discrete`` action space are accepted.
    """

    def __init__(self, env: gym.Env):
        super().__init__()
        assert isinstance(env.action_space, gym.spaces.Discrete)

        self.num_actions = int(env.action_space.n)
        self.encoder = ObsEncoder(env.observation_space, 128)
        self.adv_head = MLP([128, 32, 1])
        self.q_head = MLP([128, 64, self.num_actions])

    def forward(self, obs: Tensor):
        """Return the summed head outputs for ``obs``."""
        enc = self.encoder(obs)
        adv = self.adv_head(enc)
        # Both heads expect the 128-d encoding. The previous code passed the
        # raw observation to ``q_head``.
        qs = self.q_head(enc)
        return adv + qs

class GreedyAgent:
    """Agent that always picks the action with the highest Q-value."""

    def __init__(self, Q: QNetwork):
        self.Q = Q

    def batch_act(self, obs: Tensor) -> Tensor:
        """Return the argmax over dim 1 of ``Q(obs)``, evaluated without gradients."""
        with torch.no_grad():
            action_values = self.Q(obs)
        return action_values.argmax(dim=1)

    def act(self, obs: Tensor) -> Tensor:
        """Act on a single observation by adding and removing a leading batch dim."""
        batched_obs = torch.unsqueeze(obs, 0)
        batched_act = self.batch_act(batched_obs)
        return torch.squeeze(batched_act, 0)

class EpsGreedyAgent:
    """Epsilon-greedy agent on top of a Q-network.

    For every observation, a uniformly random action is taken with
    probability ``eps``; otherwise the greedy action is taken.

    Attributes:
        eps: Exploration probability; may be changed at any time (for
            example by ``EpsScheduler``).
    """

    def __init__(self, Q: QNetwork, eps: float):
        self.Q = Q
        self.greedy = GreedyAgent(self.Q)
        self.eps = eps

    def batch_act(self, obs: Tensor) -> Tensor:
        """Return one action per row of ``obs``."""
        batch_size = len(obs)
        rand_p = torch.rand(batch_size, device=obs.device)
        # Create the random actions on the same device as everything else;
        # otherwise ``torch.where`` fails when ``obs`` is not on the CPU.
        rand_act = torch.randint(self.Q.num_actions, (batch_size,), device=obs.device)
        greedy_act = self.greedy.batch_act(obs)
        return torch.where(rand_p < self.eps, rand_act, greedy_act)

    def act(self, obs: Tensor) -> Tensor:
        """Act on a single observation by adding and removing a leading batch dim."""
        return self.batch_act(obs.unsqueeze(0)).squeeze(0)

class RandomAgent:
    """Agent that ignores the observation and samples from the action space.

    It can be called directly or through ``act``; the latter is the method
    ``Rollout`` invokes on its agent.
    """

    def __init__(self, env: gym.Env):
        self.action_space = env.action_space

    def act(self, obs):
        """Return ``action_space.sample()``; ``obs`` is unused."""
        return self.action_space.sample()

    def __call__(self, obs):
        """Same as ``act``; kept for callers that invoke the agent directly."""
        return self.act(obs)

class DQN(nn.Module):
    """Container for an online Q-network and a target copy.

    ``target_Q`` is a second ``BaseQNetwork`` initialised with the weights of
    ``Q``. Only environments with a ``Discrete`` action space are accepted.
    """

    def __init__(self, env: gym.Env):
        super().__init__()
        assert isinstance(env.action_space, gym.spaces.Discrete)

        online = BaseQNetwork(env)
        target = BaseQNetwork(env)
        # Start the target network as an exact copy of the online one.
        target.load_state_dict(online.state_dict())

        self.Q = online
        self.target_Q = target

class EpsScheduler:
    """Exponentially decay an agent's ``eps`` from ``max_eps`` towards ``min_eps``.

    After ``n`` calls to ``step`` (counting from zero), the value written is
    ``min_eps + (max_eps - min_eps) * exp(-step_decay * n)``. ``reset`` only
    rewinds the step counter; ``eps`` changes again on the next ``step``.
    """

    def __init__(self, agent: EpsGreedyAgent, max_eps, min_eps, step_decay):
        self.agent = agent
        self.step_decay = step_decay
        self.base_eps = min_eps
        self.eps_amp = max_eps - min_eps
        self.reset()

        # Start fully at the upper bound.
        self.cur_eps = agent.eps = max_eps

    def reset(self):
        """Rewind the step counter to zero."""
        self._cur_step = 0

    def step(self):
        """Write the value for the current step to the agent, then advance the counter."""
        decay_factor = np.exp(-self.step_decay * self._cur_step)
        new_eps = self.base_eps + self.eps_amp * decay_factor
        self.cur_eps = new_eps
        self.agent.eps = new_eps
        self._cur_step += 1

class DQNData(Protocol):
    """Environment factory consumed by ``DQNTrainer.train``."""

    def train_env(self) -> gym.Env:
        """Return the environment that training rollouts interact with."""
        ...
    
    def val_env(self) -> gym.Env:
        """Return the environment that validation episodes run in."""
        ...

class DQNLoss:
    """One-step temporal-difference loss for a Q-network with a target network.

    The prediction is ``Q(obs)`` at the taken action. The target is
    ``reward + (1 - term) * gamma * max_a target_Q(next_obs)``, computed
    without gradients through the target network. The two are compared with
    the smooth L1 loss.
    """

    def __init__(self, Q: QNetwork, target_Q: QNetwork, gamma: float):
        self.Q, self.target_Q, self.gamma = Q, target_Q, gamma

    def __call__(self, batch: TransitionBatch):
        """Return the scalar loss for ``batch``."""
        all_values = self.Q(batch.obs)
        taken = batch.act.unsqueeze(1)
        preds = all_values.gather(1, taken).squeeze(1)

        with torch.no_grad():
            next_V, _ = self.target_Q(batch.next_obs).max(dim=1)

        # Terminal transitions contribute no bootstrapped value.
        keep = 1.0 - batch.term.float()
        targets = batch.reward + keep * self.gamma * next_V

        return F.smooth_l1_loss(preds, targets)

class Rollout(data.IterableDataset[Transition]):
    """Iterable that drives ``agent`` through ``env`` and yields each ``Transition``.

    Iteration starts with ``env.reset()`` and stops once ``max_steps`` steps
    or ``max_episodes`` episodes have been completed; a limit of ``None``
    disables that check. The environment is reset whenever a step reports
    ``term`` or ``trunc``.
    """

    def __init__(self, env: gym.Env, agent, max_steps=None, max_episodes=None):
        super().__init__()
        self.env, self.agent = env, agent
        self.max_steps, self.max_episodes = max_steps, max_episodes

    def _limit_reached(self, step_idx, ep_idx):
        """Tell whether either the step budget or the episode budget is used up."""
        out_of_steps = self.max_steps is not None and step_idx >= self.max_steps
        if out_of_steps:
            return True
        return self.max_episodes is not None and ep_idx >= self.max_episodes

    def __iter__(self):
        """Generate transitions until one of the limits is reached."""
        obs, info = self.env.reset()
        step_idx = 0
        ep_idx = 0

        while not self._limit_reached(step_idx, ep_idx):
            act = self.agent.act(obs)
            next_obs, reward, term, trunc, info = self.env.step(act)
            yield Transition(obs, act, next_obs, reward, term, trunc)
            obs = next_obs
            step_idx += 1

            # Episode over: start the next one right away.
            if term or trunc:
                obs, info = self.env.reset()
                ep_idx += 1

class Logger:
    """Placeholder logger whose ``metric`` accepts anything and records nothing.

    ``level`` is accepted but not used.
    """

    def __init__(self, level: int):
        def discard(*args, **kwargs):
            return ...

        self.metric = discard

class RandomInfiniteSampler:
    """Endless stream of uniformly random indices into ``ds``.

    ``len(ds)`` is re-read for every draw, so the index range follows a
    dataset that grows while it is being sampled.
    """

    def __init__(self, ds: typing.Sized):
        self.ds = ds

    def __iter__(self):
        """Yield random indices in ``[0, len(ds))`` forever."""
        draw = np.random.randint
        while True:
            upper = len(self.ds)
            yield draw(upper)

class DQNTrainer:
    """Training loop for ``DQN`` with a replay buffer and epsilon-greedy exploration.

    All hyper-parameters are plain attributes set in ``__init__`` and may be
    changed before calling ``train``.
    """

    def __init__(self):
        self.train_steps = int(1e6)
        self.train_episodes = int(5e3)
        self.val_every_steps = int(10e3)
        self.val_episodes = 32
        self.buffer_capacity = int(1e5)
        self.max_eps, self.min_eps = 0.9, 0.05
        self.eps_step_decay = 1e-3
        self.val_eps = 0.05
        self.gamma = 0.99
        self.tau = 0.995
        self.batch_size = 128
        self.clip_grad = 100.0
        self.prefill = int(1e3)

        self.log = Logger(level=logging.WARN)

    def train(self, dqn: DQN, dqn_data: DQNData):
        """Train ``dqn.Q`` for ``train_steps`` steps.

        Every training step performs one environment step, whose transition
        lands in the replay buffer, followed by one optimisation step on a
        random batch from the buffer. A validation epoch of ``val_episodes``
        episodes runs every ``val_every_steps`` steps, including step 0 and
        the final step count.

        Args:
            dqn: Model holding the online network ``Q`` and ``target_Q``.
            dqn_data: Factory for the training and validation environments.
        """
        train_env = dqn_data.train_env()
        train_agent = EpsGreedyAgent(dqn.Q, self.max_eps)

        train_buffer = TransitionBuffer(train_env, self.buffer_capacity)
        train_env = CollectTransitions(train_env, train_buffer)

        # Online data from a rollout
        train_rollout = Rollout(train_env, train_agent)
        train_env_iter = iter(train_rollout)

        # Offline data from the buffer
        # NOTE to self: Can possibly replace it with the "Dreamer" buffer?
        # Would also make it super-clean
        train_loader = data.DataLoader(
            dataset=train_buffer,
            sampler=data.BatchSampler(
                sampler=RandomInfiniteSampler(train_buffer),
                batch_size=self.batch_size,
                drop_last=False,
            ),
            batch_size=None,
        )
        train_loader_iter = iter(train_loader)

        val_env = dqn_data.val_env()
        val_agent = EpsGreedyAgent(dqn.Q, self.val_eps)

        dqn_loss = DQNLoss(dqn.Q, dqn.target_Q, self.gamma)

        optim = torch.optim.AdamW(dqn.Q.parameters(), lr=1e-4, amsgrad=True)
        eps_sched = EpsScheduler(train_agent, self.max_eps, self.min_eps,
                                 self.eps_step_decay)
        polyak = Polyak(dqn.Q, dqn.target_Q, tau=self.tau)

        step_idx = 0

        # Prefill the replay buffer
        while len(train_buffer) < self.prefill:
            _ = next(train_env_iter)

        # Start training from a fresh episode. Restart the rollout rather
        # than calling ``train_env.reset()`` directly: the rollout generator
        # keeps its own copy of the current observation, so resetting the
        # environment behind its back made the next action be chosen from a
        # stale observation. A new iterator resets the environment itself on
        # its first step.
        train_env_iter = iter(train_rollout)

        def train_step():
            # Env interaction
            _ = next(train_env_iter)

            # Policy learning
            batch = next(train_loader_iter)

            loss = dqn_loss(batch)

            optim.zero_grad(True)
            loss.backward()
            if self.clip_grad is not None:
                nn.utils.clip_grad.clip_grad_value_(
                    dqn.Q.parameters(), self.clip_grad)
            optim.step()

            polyak.step()
            eps_sched.step()

            self.log.metric("train_loss", loss)
            self.log.metric("cur_train_eps", eps_sched.cur_eps)

        def val_epoch():
            all_ep_returns = []
            for _ in range(self.val_episodes):
                # Undiscounted sum of rewards over one episode.
                ep_returns = 0.0
                for tr in Rollout(val_env, val_agent, max_episodes=1):
                    ep_returns += tr.reward

                all_ep_returns.append(ep_returns)

            self.log.metric("val_returns", all_ep_returns)

        while True:
            if step_idx % self.val_every_steps == 0:
                val_epoch()

            if step_idx >= self.train_steps:
                break

            train_step()
            step_idx += 1

In [7]:
from gymnasium.wrappers import TransformObservation

class TensorAction(gym.Wrapper):
    """Wrapper that unwraps the action with ``.item()`` before stepping the environment."""

    def step(self, action: Any):
        """Convert ``action`` to a plain scalar and forward it to the wrapped environment."""
        scalar_action = action.item()
        return self.env.step(scalar_action)

class DQNData_v0(DQNData):
    """CartPole-v1 environments whose observations are tensors and whose actions may be tensors."""

    def train_env(self):
        """Training uses the same environment construction as validation."""
        return self.val_env()

    def val_env(self):
        """Create a fresh CartPole-v1 environment with the tensor adapters applied."""
        base_env = gym.make("CartPole-v1")
        # Actions arrive as tensors; observations leave as tensors.
        return TransformObservation(TensorAction(base_env), torch.as_tensor)

# Environment factory, model and trainer. The environment handed to DQN is
# only used to read its observation and action spaces.
dqn_data = DQNData_v0()
dqn = DQN(env=dqn_data.train_env())
trainer = DQNTrainer()

# Runs until `trainer.train_steps` training steps have been made; lower that
# attribute first for a quick run.
trainer.train(dqn, dqn_data)

KeyboardInterrupt: 