In [11]:
! pip install minigrid==2.2.1



In [146]:
import abc
import collections
import random
from typing import NamedTuple, Any, Iterable, Callable

import numpy as np
import torch
import minigrid
import gymnasium as gym
import lovely_tensors
import lovely_numpy
import lightning

def lovely(x):
    """Return a compact summary of an array or tensor.

    NumPy arrays are summarized with ``lovely_numpy``; any other input is
    handed to ``lovely_tensors``.
    """
    summarize = lovely_numpy.lovely if isinstance(x, np.ndarray) else lovely_tensors.lovely
    return summarize(x)

In [147]:
class Transition(NamedTuple):
    """One environment step: ``(state, action, reward, next_state, done)``.

    Instances are what ``ReplayBuffer.add`` stores and what ``DQN.training_step``
    receives as a batch.

    NOTE(review): the fields are annotated as tensors, but
    ``ParallelExperienceGenerator`` builds transitions directly from environment
    outputs, which are presumably NumPy values until a DataLoader collates
    them -- confirm.
    """
    # Observation the action was chosen from.
    state: torch.Tensor
    # Action that was taken in `state`.
    action: torch.Tensor
    # Reward received for that action.
    reward: torch.Tensor
    # Observation returned by the environment after the action.
    next_state: torch.Tensor
    # Episode-end flag; multiplies out the bootstrap term in the TD target.
    done: torch.Tensor


class ReplayBuffer(torch.utils.data.IterableDataset):
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    The object is its own iterator and never raises ``StopIteration``: every
    ``next()`` call returns one stored transition drawn uniformly at random
    (with replacement) using the buffer's own seeded generator.

    Args:
        capacity: Maximum number of transitions kept. Once full, adding a new
            transition discards the oldest one.
        seed: Seed forwarded to ``np.random.default_rng``; ``None`` gives a
            non-deterministic generator.

    Raises:
        ValueError: From the constructor if ``capacity`` is negative.
    """

    def __init__(self, capacity: int, seed = None):
        super().__init__()
        self.capacity = capacity
        # A bounded deque evicts the oldest entry in O(1) when it is full.
        self.buffer: "collections.deque[Transition]" = collections.deque(maxlen=capacity)
        self.seed = seed
        self.rng = np.random.default_rng(seed)

    def __next__(self):
        """Return one stored transition chosen uniformly at random.

        Raises:
            ValueError: If nothing has been added yet.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty ReplayBuffer; add transitions first")
        # Draw an index instead of calling `rng.choice(self.buffer)`: `choice`
        # first converts the whole buffer to an ndarray, which either returns an
        # array row instead of the stored transition or fails on ragged fields.
        index = int(self.rng.integers(len(self.buffer)))
        return self.buffer[index]

    def __iter__(self):
        return self

    def add(self, t: "Transition"):
        """Store transition ``t``, evicting the oldest one if the buffer is full."""
        self.buffer.append(t)
    

In [148]:
# Minigrid Environment
from minigrid.wrappers import ImgObsWrapper

class ChannelFirst(gym.ObservationWrapper):
    """Observation wrapper that swaps axes 0 and 2 of every observation.

    For an observation of shape ``(a, b, c)`` the result has shape
    ``(c, b, a)``, which moves a trailing channel axis to the front. The
    advertised observation space is derived from the wrapped environment, so a
    ``(7, 7, 3)`` input space becomes ``Box(0, 255, shape=(3, 7, 7))``.

    Raises:
        ValueError: If the wrapped observation space has fewer than three
            dimensions (there would be no axis 2 to swap).
    """

    def __init__(self, env):
        super().__init__(env)
        old_shape = env.observation_space.shape
        if old_shape is None or len(old_shape) < 3:
            raise ValueError(
                f"ChannelFirst needs observations with at least 3 dimensions, got shape {old_shape}"
            )
        # Mirror exactly what `np.swapaxes(obs, 2, 0)` does to the shape.
        new_shape = list(old_shape)
        new_shape[0], new_shape[2] = new_shape[2], new_shape[0]
        self.observation_space = gym.spaces.Box(0, 255, shape=tuple(new_shape))

    def observation(self, observation):
        """Return ``observation`` with its first and third axes exchanged."""
        return np.swapaxes(observation, 2, 0)

class ScaledFloatFrame(gym.ObservationWrapper):
    """Observation wrapper that returns each observation as a float32 array.

    Despite the name, the values are only cast; no rescaling is applied.
    """

    def __init__(self, env):
        super().__init__(env)

    def observation(self, observation):
        # Careful: float32 copies undo the memory optimization of compact
        # observations, so use this with smaller replay buffers only.
        as_array = np.array(observation)
        return as_array.astype(np.float32)

class MinigridEmpty5x5ImgObs(gym.Wrapper):
    """MiniGrid-Empty-5x5 with MiniGrid's partially observable image observations.

    The base environment is wrapped, innermost first, by ``ImgObsWrapper``,
    ``ChannelFirst`` and ``ScaledFloatFrame``.
    """

    def __init__(self):
        base_env = gym.make('MiniGrid-Empty-5x5-v0')
        image_env = ImgObsWrapper(base_env)
        channel_first_env = ChannelFirst(image_env)
        super().__init__(ScaledFloatFrame(channel_first_env))

In [149]:
from __future__ import annotations
import itertools

import torchvision


class MlpQNet(torch.nn.Module):
    """Fully connected Q-network: flattens the observation and maps it to one value per action.

    Args:
        num_actions: Number of output units (one Q-value per action).
        in_features: Size of a flattened observation. The default, ``3 * 7 ** 2``,
            matches the ``(3, 7, 7)`` observations used in this notebook.
    """

    def __init__(self, num_actions: int, in_features: int = 3 * 7 ** 2):
        super().__init__()
        self.num_actions = num_actions
        self.in_features = in_features
        self.fc = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(in_features, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, num_actions)
        )

    def forward(self, x) -> torch.Tensor:
        """Compute Q-values.

        Args:
            x: NumPy array or tensor. A 3-dimensional input is treated as a
                single unbatched observation and gets a leading batch axis.

        Returns:
            Tensor of shape ``(batch, num_actions)``.
        """
        # Bring every input (array or tensor, any real dtype) to the dtype and
        # device of the network weights; already-matching tensors are untouched.
        reference = next(self.parameters())
        x = torch.as_tensor(x).to(reference)

        if x.ndim == 3:
            x = x.unsqueeze(dim=0)

        return self.fc(x)


def set_submodule(module: torch.nn.Module, submodule_path: str, submodule: torch.nn.Module) -> None:
    """Replace the submodule found at a dotted path below ``module``.

    Args:
        module: Root module to start the lookup from.
        submodule_path: Dot-separated attribute path, e.g. ``"features.0.0"``.
        submodule: Module to store under the last path component.
    """
    segments = submodule_path.split(".")
    owner = module
    # Every component except the last one names a container on the way down.
    for name in segments[:-1]:
        owner = getattr(owner, name)
    setattr(owner, segments[-1], submodule)


# Copied from my previous homework
class VisionQNet(torch.nn.Module):
    """
    Q-network that wraps a torchvision architecture and patches its input and output layers.

    The network is compatible with any CNN architecture as long as:
    1) One of the conv layers determines how many channels it can process
    2) One of the linear layers determines the number of output neurons
    
    This should be enough to support all standard architectures

    The conv layer named by ``input_conv`` is replaced by one that takes a single
    input channel, and the linear layer named by ``output_lin`` is replaced by
    one that emits one value per action.

    >>> q_nn = VisionQNet(
    ...     actions=5,
    ...     architecture="efficientnet_b0",
    ...     input_conv="features.0.0",
    ...     output_lin="classifier.1",
    ...     pretrained_weights=None, # "IMAGENET1K_V1"
    ... )
    >>> inputs = torch.rand(size=(16, 1, 64, 64))
    >>> q_nn(inputs).shape
    torch.Size([16, 5])

    """

    def __init__(
        self,
        actions: int | dict[int, str],
        architecture: str,
        input_conv: str,
        output_lin: str,
        pretrained_weights = None,
        seed: int = 42,
    ) -> None:
        """
        Args:
            actions: Either the number of actions or a mapping from action index
                to a name; an int is expanded to ``{i: "action_i"}``.
            architecture: Model name passed to ``torchvision.models.get_model``.
            input_conv: Dotted path (inside the torchvision model) of the conv
                layer to replace with a single-channel version.
            output_lin: Dotted path of the linear layer whose output size is set
                to the number of actions.
            pretrained_weights: Forwarded as ``weights`` to
                ``torchvision.models.get_model`` when loading parameters.
            seed: Seed forwarded to ``init``.
        """
        super(VisionQNet, self).__init__()
        if isinstance(actions, int):
            self.n_units_out = actions
            self.actions = {i: f"action_{i}" for i in range(actions)}
        else:
            self.n_units_out = len(actions)
            self.actions = actions
        self.architecture = architecture
        self.input_conv = input_conv
        self.output_lin = output_lin
        # Build the bare architecture first; its weights are overwritten in `init`.
        self.nn = torchvision.models.get_model(self.architecture)
        self._patch_input_shape()
        self._patch_output_shape()
        self.init(pretrained_weights, seed)

    def init(self, pretrained_weights = None, seed: int | None = None) -> None:
        """
        Initializes the model parameters, either randomly or with pretrained weights.

        A second model of the same architecture is created with the requested
        weights and its state dict is loaded into the already patched network,
        skipping the two patched layers. Those two layers are then re-initialized
        from a truncated normal distribution.

        NOTE(review): when ``seed`` is given this reseeds torch's global random
        generator, which also affects code running afterwards.

        Raises:
            ValueError: If the loaded state dict has keys the patched model does
                not know, or lacks keys outside the two patched layers.
        """
        if seed is not None:
            torch.random.manual_seed(seed)

        state_dict = torchvision.models.get_model(self.architecture, weights=pretrained_weights).state_dict()

        # The patched layers have different shapes, so their entries cannot be loaded.
        for key in list(state_dict.keys()):
            if key.startswith(self.input_conv) or key.startswith(self.output_lin):
                state_dict.pop(key)

        incompatible = self.nn.load_state_dict(state_dict, strict=False)
        if len(incompatible.unexpected_keys) != 0:
            raise ValueError(f"Unexpected additional keys in pretrained weights: {incompatible.unexpected_keys}")
        
        # Only the entries removed above are allowed to be missing.
        for key in incompatible.missing_keys:
            assert isinstance(key, str)
            if key.startswith(self.input_conv) or key.startswith(self.output_lin):
                continue
            raise ValueError(f"Unexpected missing key in pretrained weights: {key}")

        del state_dict

        in_out_params = itertools.chain(
            self.nn.get_submodule(self.input_conv).parameters(),
            self.nn.get_submodule(self.output_lin).parameters(),
        )

        # Small near-zero values (weights and biases alike) for the two new layers.
        for param in in_out_params:
            torch.nn.init.trunc_normal_(param, mean=0, std=1e-4, a=-0.01, b=0.01)

    def _patch_input_shape(self) -> None:
        """
        Make the architecture accept a single (grayscale) channel

        Does nothing if the layer already has one input channel; otherwise the
        layer is replaced by a new ``Conv2d`` that copies the old layer's output
        channels, kernel size, stride, padding and bias setting.

        NOTE(review): dilation, groups and padding mode of the old layer are not
        carried over -- confirm this is fine for the architectures in use.
        """
        old_in_conv = self.nn.get_submodule(self.input_conv)
        assert isinstance(old_in_conv, torch.nn.Conv2d)

        if old_in_conv.in_channels == 1:
            return

        new_in_conv = torch.nn.Conv2d(
            in_channels = 1, 
            out_channels = old_in_conv.out_channels,
            kernel_size = old_in_conv.kernel_size,
            stride = old_in_conv.stride,
            padding = old_in_conv.padding,
            bias = old_in_conv.bias is not None,
        ).to(
            # match dtype and device of the layer being replaced
            old_in_conv.weight
        )
        set_submodule(self.nn, self.input_conv, new_in_conv)

    def _patch_output_shape(self) -> None:
        """
        Make the architecture output the correct shape

        Does nothing if the layer already has ``n_units_out`` outputs; otherwise
        it is replaced by a new ``Linear`` with the same input size and bias
        setting.
        """
        old_out_lin = self.nn.get_submodule(self.output_lin)
        assert isinstance(old_out_lin, torch.nn.Linear)
        
        if old_out_lin.out_features == self.n_units_out:
            return

        new_out_lin = torch.nn.Linear(
            in_features = old_out_lin.in_features,
            out_features = self.n_units_out,
            bias = old_out_lin.bias is not None,
        ).to(
            # match dtype and device of the layer being replaced
            old_out_lin.weight
        )
        set_submodule(self.nn, self.output_lin, new_out_lin)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped network on the last channel of a 4-dimensional input."""
        # select only last channel if there are more
        # (indexing with the tuple `(-1,)` keeps the channel axis, with size 1)
        # TODO revisit this. it's coming from previous homework and maybe not relevant anymore
        x = x[:, (-1,), :, :]
        return self.nn(x)

In [150]:
# Smoke test: check that the network can be built with these layer paths.
VisionQNet(5, "mobilenet_v3_small", "features.0.0", "classifier.3")
# Ending the cell with None keeps the notebook from displaying the module's repr.
None

In [151]:
# Build one wrapped environment and look at what a reset returns.
envs = MinigridEmpty5x5ImgObs()
state, info = envs.reset()
print(f"obs: {lovely(state)!s}")
print(f"info: {info!s}")
print(f"action_space: {envs.action_space!s}")

obs: array[3, 7, 7] f32 n=147 x∈[0., 8.000] μ=2.020 σ=2.075
info: {}
action_space: Discrete(7)


In [152]:
# def set_seed(env, seed=None):
#     if seed is None:
#         return 
    
#     random.seed(seed)
#     env.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed(seed)
#         torch.cuda.manual_seed_all(seed)


In [155]:
class Agent(abc.ABC):
    """Interface for anything that can choose an action for an observation."""

    @abc.abstractmethod
    def select_action(self, state, greedy: bool = False) -> int:
        """Return the action to take in ``state``.

        Args:
            state: A single observation.
            greedy: When true, the agent should act without exploration.

        Returns:
            Index of the chosen action.
        """
        ...

class DQN(lightning.LightningModule, Agent):
    """Deep Q-learning agent with a soft-updated target network.

    Args:
        model_config: Keyword arguments for the Q-network constructor
            (``VisionQNet`` when ``use_vision`` is true, ``MlpQNet`` otherwise).
        num_actions: Number of actions sampled from during exploration.
        use_double_dqn: If true, the online network picks the next action and
            the target network evaluates it; otherwise the target network's
            maximum is used.
        optimizer_class: Name of a class in ``torch.optim``.
        use_pretrained_weights: Forwarded to ``VisionQNet`` as
            ``pretrained_weights``; ignored for the MLP network.
        optimizer_config: Keyword arguments for the optimizer (default: none).
        epsilon: Probability of taking a random action when not acting greedily.
        tau: Interpolation factor of the soft target update.
        gamma: Discount factor.
        use_vision: Selects ``VisionQNet`` (true) or ``MlpQNet`` (false).
    """

    def __init__(
        self,
        model_config: dict[str, Any],
        num_actions: int,
        use_double_dqn: bool,
        optimizer_class: str,
        use_pretrained_weights: str | None = None,
        optimizer_config: dict[str, Any] | None = None,
        epsilon: float = 0.1,
        tau: float = 1e-3,
        gamma: float = 0.99,
        use_vision: bool = True,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(ignore=["use_pretrained_weights"])
        self.model_config = model_config

        if optimizer_config is None:
            optimizer_config = {}
        self.optimizer_config = optimizer_config
        self.optimizer_class = optimizer_class

        if use_vision:
            self.qnn = VisionQNet(**model_config, pretrained_weights=use_pretrained_weights)
            self.qnn_target = VisionQNet(**model_config, pretrained_weights=use_pretrained_weights)
        else:
            self.qnn = MlpQNet(**model_config)
            self.qnn_target = MlpQNet(**model_config)

        # The target network must start as an exact copy of the online network;
        # it is only ever changed by `soft_update_target_nn`, never by gradients.
        self.qnn_target.load_state_dict(self.qnn.state_dict())
        self.qnn_target.requires_grad_(False)

        self.loss_fn = torch.nn.MSELoss()

        self.num_actions = num_actions
        self.use_double_dqn = use_double_dqn
        self.epsilon = epsilon
        self.tau = tau
        self.gamma = gamma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the online network's Q-values for ``x``."""
        return self.qnn(x)

    @torch.no_grad()
    def select_action(self, observation: torch.Tensor, greedy: bool = False) -> int | np.ndarray:
        """Choose actions epsilon-greedily (or greedily) from the online network.

        Args:
            observation: One observation (3 dimensions) or a batch of them;
                NumPy arrays and tensors are both accepted.
            greedy: If true, always take the highest-valued action.

        Returns:
            An ``int`` for a single observation, otherwise an integer NumPy
            array with one action per batch entry.
        """
        reference = next(self.qnn.parameters())
        obs = torch.as_tensor(observation).to(reference)
        is_single = obs.ndim == 3
        if is_single:
            obs = obs.unsqueeze(0)

        # argmax per row, so a batch yields one action per observation
        actions = self.qnn(obs).argmax(dim=1).cpu().numpy()
        if not greedy:
            for i in range(len(actions)):
                if random.random() <= self.epsilon:
                    actions[i] = random.randrange(self.num_actions)

        return int(actions[0]) if is_single else actions

    def configure_optimizers(self):
        """Create the optimizer named by ``optimizer_class`` over the online network only."""
        optimizer_class = getattr(torch.optim, self.optimizer_class)
        return optimizer_class(self.qnn.parameters(), **self.optimizer_config)

    def training_step(self, batch: Transition, batch_idx):
        """Compute the TD loss for one batch of transitions.

        The loss compares ``Q(state, action)`` of the online network with
        ``reward + gamma * V(next_state) * (1 - done)``, where ``V`` comes from
        the target network.
        """
        state, action, reward, next_state, done = batch

        # Q-value of the action that was actually taken in each transition.
        q_values = self.qnn(state)
        q_taken = q_values.gather(1, action.long().reshape(-1, 1)).squeeze(1)

        with torch.no_grad():
            next_q_target = self.qnn_target(next_state)
            if self.use_double_dqn:
                # Online network selects the action, target network evaluates it.
                best_next_action = self.qnn(next_state).argmax(dim=1, keepdim=True)
                next_value_estimate = next_q_target.gather(1, best_next_action).squeeze(1)
            else:
                next_value_estimate = next_q_target.max(dim=1).values
            # `done` may arrive as a bool tensor and `reward` as float64; bring
            # both to the dtype of the predictions before doing arithmetic.
            not_done = 1.0 - done.reshape(-1).to(q_taken.dtype)
            td_target = reward.reshape(-1).to(q_taken.dtype) + self.gamma * next_value_estimate * not_done

        td_error: torch.Tensor
        td_error = self.loss_fn(q_taken, td_target)
        self.log("td_error", td_error)
        return td_error

    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
        self.soft_update_target_nn()
        # TODO decay epsilon
        # self.epsilon = max(self.epsilon_lb, self.epsilon_ub - self.global_step / self.epsilon_decay)
        return super().on_train_batch_end(outputs, batch, batch_idx)

    @torch.no_grad()
    def soft_update_target_nn(self):
        """Soft update model parameters.
        θ_target = tau * θ_local + (1 - tau) * θ_target

        NOTE(review): only parameters are blended; module buffers (for example
        normalization statistics in a vision backbone) are left untouched.
        """
        params = zip(self.qnn_target.parameters(), self.qnn.parameters())
        for target_param, param in params:
            target_param.mul_(1.0 - self.tau).add_(param, alpha=self.tau)

    def validation_step(self, batch, batch_idx):
        # TODO
        ...



In [161]:
class ParallelExperienceGenerator:
    """Endless stream of transitions collected from several environments in parallel.

    Args:
        agent: Provides ``select_action`` for a single observation.
        make_env: Factory that builds one environment.
        make_env_kwargs: Keyword arguments passed to ``make_env`` (``None`` means none).
        num_parallel_envs: Number of environments to run side by side.
    """

    def __init__(
        self,
        agent: Agent,
        make_env: Callable[[int | None, bool], gym.Env],
        make_env_kwargs: dict[str, Any],
        num_parallel_envs: int,
    ) -> None:
        self.agent = agent
        self.make_env = make_env
        self.num_parallel_envs = num_parallel_envs

        make_env_kwargs = dict(make_env_kwargs or {})
        self.envs = gym.vector.AsyncVectorEnv([
            lambda: make_env(**make_env_kwargs)
            for _ in range(num_parallel_envs)
        ])

    def __iter__(self) -> Iterable[Transition]:
        """Reset all environments, then yield one ``Transition`` per environment per step, forever."""
        states, _ = self.envs.reset()
        while True:
            # `Agent.select_action` is defined for a single observation, so ask per environment.
            actions = np.array([self.agent.select_action(state) for state in states])
            # Gymnasium's step returns five values; an episode is over when it
            # either terminated or was truncated.
            next_states, rewards, terminated, truncated, _ = self.envs.step(actions)
            dones = np.logical_or(terminated, truncated)
            # NOTE(review): depending on the Gymnasium version's autoreset
            # behaviour, `next_states` of finished episodes may already hold the
            # first observation of the next episode -- confirm.
            for s, a, r, s_, d in zip(states, actions, rewards, next_states, dones):
                yield Transition(s, a, r, s_, d)
            states = next_states

    def close(self) -> None:
        """Shut down the vectorized environments."""
        self.envs.close()


class ExperienceCollectionCallback(lightning.Callback):
    """Lightning callback that keeps a replay buffer filled during training.

    Experience is collected once when training starts and again whenever the
    trainer's global step is a multiple of ``collect_every_n_steps``.

    Args:
        parallel_experience_generator: Iterable yielding single transitions.
        buffer: Replay buffer that receives the transitions.
        collect_every_n_steps: Collection period in optimizer steps.
        collect_num_steps_in_each_env: Environment steps per environment per collection.
        num_parallel_envs: Number of environments behind the generator.
            Defaults to the generator's own ``num_parallel_envs``.
    """

    def __init__(
        self,
        parallel_experience_generator: ParallelExperienceGenerator,
        buffer: ReplayBuffer,
        collect_every_n_steps: int,
        collect_num_steps_in_each_env: int,
        num_parallel_envs: int | None = None,
    ) -> None:
        super().__init__()
        self.parallel_experience_collector = parallel_experience_generator
        self.buffer = buffer
        self.collect_every_n_steps = collect_every_n_steps
        self.collect_num_steps_in_each_env = collect_num_steps_in_each_env
        if num_parallel_envs is None:
            num_parallel_envs = parallel_experience_generator.num_parallel_envs
        self.num_parallel_envs = num_parallel_envs
        # Created on first use so the environments are reset only once.
        self._transitions = None

    def on_train_start(self, trainer, pl_module) -> None:
        self.collect_experience()

    # `dataloader_idx` is optional because current Lightning versions do not pass it.
    def on_train_batch_end(self, trainer: lightning.Trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None:
        if trainer.global_step % self.collect_every_n_steps == 0:
            self.collect_experience()

    def collect_experience(self):
        """Add ``collect_num_steps_in_each_env`` steps from every environment to the buffer."""
        if self._transitions is None:
            self._transitions = iter(self.parallel_experience_collector)
        # The generator yields one transition per environment per step.
        wanted = self.collect_num_steps_in_each_env * self.num_parallel_envs
        for transition in itertools.islice(self._transitions, wanted):
            self.buffer.add(transition)


In [162]:
# Number of actions the agent chooses from.
# NOTE(review): the environment inspected above reports Discrete(7) -- confirm
# that restricting the agent to the first 5 actions is intended.
num_actions = 5

# Constructor arguments for the available Q-network variants.
mobilenet_v3_small_config = dict(
    actions=num_actions,
    architecture="mobilenet_v3_small",
    input_conv="features.0.0",
    output_lin="classifier.3",
)

efficientnet_b0_config = dict(
    actions=num_actions,
    architecture="efficientnet_b0",
    input_conv="features.0.0",
    output_lin="classifier.1",
)

mlp_config = dict(num_actions=num_actions)

# The MLP variant is used here, so `use_pretrained_weights` has no effect.
dqn = DQN(
    model_config=mlp_config,
    num_actions=num_actions,
    use_double_dqn=True,
    optimizer_class="AdamW",
    use_pretrained_weights="DEFAULT",
    optimizer_config={"lr": 1e-4},
    epsilon=0.1,
    tau=1e-3,
    use_vision=False,
)

In [163]:
# Replay memory for collected transitions; seeded so sampling is repeatable.
buffer = ReplayBuffer(capacity=100_000, seed=42)

In [None]:
# Every worker builds its own copy of the notebook's MiniGrid environment.
# Its constructor takes no arguments, hence the empty keyword dictionary.
experience_generator = ParallelExperienceGenerator(
    agent = dqn,
    make_env = MinigridEmpty5x5ImgObs,
    make_env_kwargs = {},
    num_parallel_envs = 16,
)

In [None]:
# Collection schedule: after every optimizer step, advance each environment by
# one step. Both numbers are starting points and meant to be tuned.
experience_collection_callback = ExperienceCollectionCallback(
    parallel_experience_generator = experience_generator,
    buffer = buffer,
    collect_every_n_steps = 1,
    collect_num_steps_in_each_env = 1,
    num_parallel_envs = experience_generator.num_parallel_envs,
)