In [None]:
import os
import abc
import random
import itertools
import functools
import collections
import pathlib
import warnings
from typing import NamedTuple, Any, Iterable, Callable

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
import torch.utils.data
import lightning
import lightning.pytorch
import lightning.pytorch.loggers
import lightning.pytorch.callbacks
import gym
import gym.spaces
import lovely_tensors
import lovely_numpy
import wandb
import wandb.wandb_run
from tqdm.auto import tqdm

from envs import make_env_minigrid, make_env_pong

# Turn off OpenCV's OpenCL acceleration for this process.
# NOTE(review): presumably to avoid OpenCL issues when envs (which may use cv2 for
# frame preprocessing) run in worker processes — confirm the reason.
cv2.ocl.setUseOpenCL(False)


def lovely(x):
    """
    Return a compact, human-readable summary of an array or tensor.

    NumPy arrays are summarized by ``lovely_numpy``; every other input is handed
    to ``lovely_tensors`` (intended for torch tensors).
    """
    summarize = lovely_numpy.lovely if isinstance(x, np.ndarray) else lovely_tensors.lovely
    return summarize(x)

### Replay Buffer

In [None]:
from __future__ import annotations

class Transition(NamedTuple):
    """
    One environment step as stored in (and sampled from) the replay buffer.

    Field order is part of the contract: transitions are also built positionally as
    ``Transition(state, action, reward, next_state, done)``. Values are stored as
    whatever the environment returned (e.g. numpy) and converted to tensors when sampled.
    """
    state: torch.Tensor | np.ndarray  # observation the action was taken in
    action: torch.Tensor | np.ndarray  # action chosen for `state`
    reward: torch.Tensor | np.ndarray  # reward returned by the env for this step
    next_state: torch.Tensor | np.ndarray  # observation returned after the step
    done: torch.Tensor | np.ndarray | bool  # episode-end flag; masks bootstrapping in the TD target


class ReplayBuffer(torch.utils.data.IterableDataset):
    def __init__(self, capacity: int, seed = None):
        self.capacity = capacity
        self.buffer: list[Transition] = []
        self.seed = seed
        self.rng = random.Random(seed)
    
    def __next__(self):
        transition = self.rng.choice(self.buffer)
        transition = Transition(
            state = torch.tensor(transition.state, dtype=torch.float32),
            action = torch.tensor(transition.action, dtype=torch.int64),
            reward = torch.tensor(transition.reward, dtype=torch.float32),
            next_state = torch.tensor(transition.next_state, dtype=torch.float32),
            done = torch.tensor(transition.done.item(), dtype=torch.bool),
        )
        return transition
    
    def __iter__(self):
        return self
    
    def add(self, t: Transition):
        self.buffer.append(t)
        if len(self.buffer) > self.capacity:
            self.buffer.pop(0)
    

### Architectures

In [None]:
from __future__ import annotations


class MlpQNet(torch.nn.Module):
    """
    Fully connected Q-network: flattens each observation and outputs one Q-value per action.

    Args:
        num_actions: number of discrete actions (width of the output layer).
        in_features: number of values in one flattened observation. Defaults to
            ``3 * 7**2`` (presumably a 3x7x7 grid observation), which was previously hard-coded.
    """

    def __init__(self, num_actions: int, in_features: int = 3 * 7**2):
        super().__init__()
        self.num_actions = num_actions
        self.in_features = in_features
        self.fc = torch.nn.Sequential(
            torch.nn.Flatten(start_dim=1),
            torch.nn.Linear(in_features, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, num_actions)
        )

    def forward(self, x) -> torch.Tensor:
        """
        Compute Q-values.

        Args:
            x: a batch of observations, or a single 3-D observation (a batch dim is added).
               NumPy arrays and tensors of any numeric dtype are accepted; they are moved
               to the dtype/device of the network parameters.

        Returns:
            Tensor of shape (batch, num_actions).
        """
        param = next(self.parameters())
        # Cast torch inputs too: e.g. uint8 frames would otherwise fail inside Linear.
        x = torch.as_tensor(x).to(param)

        if x.dim() == 3:
            x = x.unsqueeze(dim=0)

        return self.fc(x)


def set_submodule(module: torch.nn.Module, submodule_path: str, submodule: torch.nn.Module) -> None:
    """
    Replace the submodule at dotted path ``submodule_path`` (e.g. ``"features.0.0"``).

    All path components except the last are resolved with ``getattr`` starting at
    ``module``; ``submodule`` is then assigned to the last component on that parent.
    """
    components = submodule_path.split(".")
    owner = functools.reduce(getattr, components[:-1], module)
    setattr(owner, components[-1], submodule)


# Copied from my previous homework
class VisionQNet(torch.nn.Module):
    """
    Q-network built from a torchvision CNN classifier.

    The architecture is patched so that its input convolution accepts
    ``input_channels`` channels and its output linear layer produces one Q-value
    per action. Any CNN architecture works as long as:
    1) a single conv layer (``input_conv``) determines how many channels it can process
    2) a single linear layer (``output_lin``) determines the number of output units

    This should be enough to support all standard architectures.

    >>> q_net = VisionQNet(
    ...     actions=5,
    ...     architecture="efficientnet_b0",
    ...     input_conv="features.0.0",
    ...     output_lin="classifier.1",
    ...     input_channels=1,
    ...     pretrained_weights=None, # "IMAGENET1K_V1"
    ... )
    >>> inputs = torch.rand(size=(16, 1, 64, 64))
    >>> q_net(inputs).shape
    torch.Size([16, 5])

    """

    def __init__(
        self,
        actions: int | dict[int, str],
        architecture: str,
        input_conv: str,
        output_lin: str,
        input_channels: int,
        pretrained_weights = None,
        seed: int = 42,
    ) -> None:
        """
        Args:
            actions: number of actions, or a mapping action index -> action name.
            architecture: torchvision model name for ``torchvision.models.get_model``.
            input_conv: dotted path of the input ``Conv2d`` to patch.
            output_lin: dotted path of the output ``Linear`` to patch.
            input_channels: number of channels of the observations.
            pretrained_weights: torchvision weights identifier, or None for random init.
            seed: seed for torch's global RNG before initialization (see ``init``).
        """
        super().__init__()
        # gym spaces may report the action count as a numpy integer.
        if isinstance(actions, (int, np.integer)):
            self.n_units_out = int(actions)
            self.actions = {i: f"action_{i}" for i in range(self.n_units_out)}
        else:
            self.n_units_out = len(actions)
            self.actions = actions
        self.architecture = architecture
        self.input_conv = input_conv
        self.output_lin = output_lin
        self.input_channels = input_channels
        self.nn = torchvision.models.get_model(self.architecture)
        self._patch_input_shape()
        self._patch_output_shape()
        self.init(pretrained_weights, seed)

    def _is_patched_key(self, key: str) -> bool:
        """True if a state-dict key belongs to the patched input conv or output linear layer."""
        # Match on "<path>." so that e.g. "classifier.1" does not also claim "classifier.10.*".
        return any(key.startswith(prefix + ".") for prefix in (self.input_conv, self.output_lin))

    def init(self, pretrained_weights = None, seed: int | None = None) -> None:
        """
        Initializes the model parameters, either randomly or with pretrained weights.

        Every layer except the patched input conv / output linear layer is loaded from a
        freshly built torchvision model of the same architecture (``weights=pretrained_weights``;
        torchvision's random init when None). The two patched layers get a truncated normal
        init (mean 0, std 1e-4, bounds [-0.01, 0.01]).

        Args:
            pretrained_weights: torchvision weights identifier, or None.
            seed: if not None, ``torch.random.manual_seed(seed)`` is called first.

        Raises:
            ValueError: if the reference weights have unexpected keys, or are missing keys
                other than those of the patched layers.
        """
        if seed is not None:
            torch.random.manual_seed(seed)

        state_dict = torchvision.models.get_model(self.architecture, weights=pretrained_weights).state_dict()

        # Pop in place (rather than rebuilding the dict) to keep the state dict's `_metadata`.
        for key in [k for k in state_dict if self._is_patched_key(k)]:
            state_dict.pop(key)

        incompatible = self.nn.load_state_dict(state_dict, strict=False)
        if len(incompatible.unexpected_keys) != 0:
            raise ValueError(f"Unexpected additional keys in pretrained weights: {incompatible.unexpected_keys}")

        for key in incompatible.missing_keys:
            if not self._is_patched_key(key):
                raise ValueError(f"Unexpected missing key in pretrained weights: {key}")

        del state_dict

        in_out_params = itertools.chain(
            self.nn.get_submodule(self.input_conv).parameters(),
            self.nn.get_submodule(self.output_lin).parameters(),
        )

        for param in in_out_params:
            torch.nn.init.trunc_normal_(param, mean=0, std=1e-4, a=-0.01, b=0.01)

    def _patch_input_shape(self) -> None:
        """
        Make the architecture accept ``self.input_channels`` input channels.
        """
        old_in_conv = self.nn.get_submodule(self.input_conv)
        assert isinstance(old_in_conv, torch.nn.Conv2d)

        if old_in_conv.in_channels == self.input_channels:
            return

        new_in_conv = torch.nn.Conv2d(
            in_channels = self.input_channels,
            out_channels = old_in_conv.out_channels,
            kernel_size = old_in_conv.kernel_size,
            stride = old_in_conv.stride,
            padding = old_in_conv.padding,
            dilation = old_in_conv.dilation,
            bias = old_in_conv.bias is not None,
            padding_mode = old_in_conv.padding_mode,
        ).to(
            old_in_conv.weight
        )
        set_submodule(self.nn, self.input_conv, new_in_conv)

    def _patch_output_shape(self) -> None:
        """
        Make the architecture output one unit per action.
        """
        old_out_lin = self.nn.get_submodule(self.output_lin)
        assert isinstance(old_out_lin, torch.nn.Linear)

        if old_out_lin.out_features == self.n_units_out:
            return

        new_out_lin = torch.nn.Linear(
            in_features = old_out_lin.in_features,
            out_features = self.n_units_out,
            bias = old_out_lin.bias is not None,
        ).to(
            old_out_lin.weight
        )
        set_submodule(self.nn, self.output_lin, new_out_lin)

    def forward(self, x: torch.Tensor | np.ndarray) -> torch.Tensor:
        """
        Compute Q-values of shape (batch, n_actions).

        Accepts numpy arrays, and a single (C, H, W) observation (a batch dim is added),
        mirroring ``MlpQNet.forward``.
        """
        if isinstance(x, np.ndarray):
            param = next(self.parameters())
            x = torch.as_tensor(x).to(param)
        if x.dim() == 3:
            x = x.unsqueeze(dim=0)
        return self.nn(x)

In [None]:
# Smoke test: a patched MobileNet must map a batch of single-channel frames to one Q-value per action.
_smoke_qnet = VisionQNet(5, "mobilenet_v3_small", "features.0.0", "classifier.3", input_channels=1).eval()
with torch.no_grad():
    _smoke_q = _smoke_qnet(torch.rand(2, 1, 64, 64))
assert _smoke_q.shape == (2, 5), _smoke_q.shape
del _smoke_qnet, _smoke_q

### Environments

In [None]:
# Build one MiniGrid env and show what its observations and actions look like.
env_minigrid = make_env_minigrid()
state = env_minigrid.reset()
print(
    f"obs: {lovely(state)!s}",
    f"metadata: {env_minigrid.metadata!s}",
    f"action_space: {env_minigrid.action_space!s}",
    sep="\n",
)

In [None]:
# Build one Pong env and show what its observations and actions look like.
env_pong = make_env_pong()
state = env_pong.reset()
print(
    f"obs: {lovely(state)!s}",
    f"metadata: {env_pong.metadata!s}",
    f"action_space: {env_pong.action_space!s}",
    sep="\n",
)

### Measuring score

In [None]:
class Agent(abc.ABC):
    def select_action(self, state, collecting_data: bool) -> int:
        ...

In [None]:
def measure_scores_parallel(
    agent: Agent,
    make_env: Callable,
    make_env_kwargs: list[dict[str, Any]],
    select_action_kwargs: dict[str, Any],
    num_episodes: int | None = None,
    limit_steps: int = None,
    return_unfinished: bool = False,
    video_folder: pathlib.Path | str | None = None,
) -> Iterable[float]:
    if video_folder is not None:
        video_folder = pathlib.Path(video_folder)
        video_folder.mkdir(parents=True, exist_ok=True)
    
    if isinstance(make_env_kwargs, dict):
        make_env_kwargs = [make_env_kwargs] * num_episodes
    elif isinstance(make_env_kwargs, list):
        if num_episodes is None:
            num_episodes = len(make_env_kwargs)
    if len(make_env_kwargs) != num_episodes:
        raise ValueError("num_episodes does not match length of make_env_kwargs")

    def create_env(i, kwargs):
        env = make_env(**kwargs)
        if video_folder is not None:
            env = gym.wrappers.RecordVideo(env, video_folder/str(i))
        return env

    env = gym.vector.AsyncVectorEnv([
        # lambda: create_env(i, kwargs) would not work because of late binding
        functools.partial(create_env, i, kwargs) 
        for i, kwargs in enumerate(make_env_kwargs)
    ])
    
    if limit_steps is None:
        limit_steps = env.get_attr("spec")[0].max_episode_steps
    if limit_steps is None:
        warnings.warn("No limit_steps provided and env does not have max_episode_steps set. This might lead to infinite loops.")
    
    state = env.reset()
    score = np.zeros(num_episodes)
    done_in_past = np.zeros(num_episodes, dtype=bool)
    done_at_the_moment = np.zeros(num_episodes, dtype=bool)

    if limit_steps is None:
        steps = itertools.count()
    else:
        steps = range(limit_steps)
    for step in tqdm(steps, "Measuring score - step"):
        score[done_at_the_moment] = 0
        action = agent.select_action(state, **select_action_kwargs)
        state, reward, done_at_the_moment, _ = env.step(action)
        done_at_the_moment = np.array(done_at_the_moment, dtype=bool)
        reward = np.array(reward)
        # vector env resets done envs automatically and starts over
        # we only want to update score of episodes that were not done before
        # and output the score for that env only once
        score += reward
        done = done_in_past | done_at_the_moment
        done_for_first_time = done_in_past != done # this is not the same as done_at_the_moment, because of autoreset
        done_in_past = done
        
        for idx in done_for_first_time.nonzero()[0]:
            yield score[idx].item()

        if done_in_past.all():
            break
      
    env.close()

    if return_unfinished:
        yield from score[~done_in_past].tolist()

In [None]:
def cvar(scores: np.ndarray | pd.Series, *, percent: float) -> float:
    """
    Computes conditional value at risk (CVaR) of a given set of scores.
    CVaR is the expected value of the worst percents of the scores.
    """
    assert 0 < percent < 100
    if not isinstance(scores, pd.Series):
        scores = pd.Series(scores)
    quantile = scores.quantile(percent / 100)
    return scores[scores <= quantile].mean()
    
cvar(np.arange(100), percent=25)

### DQN

In [None]:
class DQN(lightning.LightningModule, Agent):
    """
    Deep Q-Network agent (optionally Double DQN) trained with Lightning.

    The online network ``qnn`` is optimized on replay batches. The target network
    ``qnn_target`` starts as a copy of ``qnn`` and is updated by Polyak averaging with
    rate ``tau`` after every training batch. Validation plays full episodes with the
    greedy policy via ``measure_scores_parallel`` and logs score statistics.

    Args:
        model_config: keyword arguments for ``VisionQNet`` or ``MlpQNet``.
        use_vision: build ``VisionQNet`` networks if True, otherwise ``MlpQNet``.
        num_actions: number of discrete actions.
        use_double_dqn: choose next actions with the online net and evaluate them with the target net.
        optimizer_class: name of a class in ``torch.optim`` (e.g. ``"AdamW"``).
        epsilon: constant exploration rate, or ``(upper, lower, decay_steps)`` for a linear
            decay over ``decay_steps`` data-collection calls of ``select_action``.
        use_pretrained_weights: torchvision weights for vision models; must be None for the MLP.
        optimizer_config: keyword arguments for the optimizer.
        make_env: env factory used for validation rollouts.
        tau: Polyak averaging rate of the target network.
        gamma: discount factor.
        video_root_folder: if set, validation episodes are recorded below this folder.
    """

    def __init__(
        self,
        model_config: dict[str, Any],
        use_vision: bool,
        num_actions: int,
        use_double_dqn: bool,
        optimizer_class: str,
        epsilon: tuple[float, float, int] | float,
        use_pretrained_weights: str | None = None,
        optimizer_config: dict[str, Any] = None,
        make_env: Callable | None = None,
        tau: float = 1e-3,
        gamma: float = 0.99,
        video_root_folder: pathlib.Path | str | None = None,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(ignore=["use_pretrained_weights", "make_env", "video_root_folder"])
        self.model_config = model_config

        if optimizer_config is None:
            optimizer_config = {}
        self.optimizer_config = optimizer_config
        self.optimizer_class = optimizer_class

        if use_vision:
            self.qnn = VisionQNet(**model_config, pretrained_weights=use_pretrained_weights)
            self.qnn_target = VisionQNet(**model_config, pretrained_weights=use_pretrained_weights)
        else:
            if use_pretrained_weights is not None:
                raise ValueError("Pretrained weights only supported for vision models")
            self.qnn = MlpQNet(**model_config)
            self.qnn_target = MlpQNet(**model_config)

        # As in standard DQN, the target network starts as an exact copy of the online
        # network. It only changes through `soft_update_target_nn`, so it needs no gradients.
        self.qnn_target.load_state_dict(self.qnn.state_dict())
        self.qnn_target.requires_grad_(False)

        self.loss_fn = torch.nn.MSELoss()

        self.num_actions = num_actions
        self.use_double_dqn = use_double_dqn
        self.epsilon = epsilon
        # number of select_action calls made while collecting data; drives the epsilon decay
        self.iter_collecting_data = 0
        self.tau = tau
        self.gamma = gamma

        self.make_env = make_env
        self.validation_step_scores = []
        if video_root_folder is None:
            self.video_root_folder = None
        else:
            self.video_root_folder = pathlib.Path(video_root_folder)
            self.video_root_folder.mkdir(parents=True, exist_ok=True)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.qnn(observations)

    def get_current_epsilon(self) -> float:
        """Constant epsilon, or the linearly decayed value clamped at the lower bound."""
        if isinstance(self.epsilon, (int, float)):
            return float(self.epsilon)
        upper, lower, decay_steps = self.epsilon
        epsilon = upper - self.iter_collecting_data / decay_steps * (upper - lower)
        epsilon = max(epsilon, lower)
        return epsilon

    def select_action(self, observations: torch.Tensor | np.ndarray, collecting_data: bool = False, epsilon: float | None = None) -> np.ndarray:
        """
        Epsilon-greedy actions for a batch of observations.

        Each observation independently gets a uniformly random action with probability
        ``epsilon`` and the greedy (argmax-Q) action otherwise.

        Args:
            observations: batch of observations, one per environment (first dimension).
            collecting_data: True when the actions gather training experience; advances
                the epsilon schedule by one step.
            epsilon: exploration probability; defaults to ``get_current_epsilon()``.

        Returns:
            int64 array with one action per observation.
        """
        if epsilon is None:
            epsilon = self.get_current_epsilon()

        if collecting_data:
            self.iter_collecting_data += 1

        if isinstance(observations, np.ndarray):
            param = next(self.qnn.parameters())
            observations = torch.from_numpy(observations).to(param)

        batch_size = observations.shape[0]

        # Act in eval mode so dropout / batch-norm training behaviour does not perturb
        # action choice while data is collected during training; restore the mode afterwards.
        was_training = self.qnn.training
        self.qnn.eval()
        try:
            with torch.no_grad():
                q_values = self.qnn(observations)
        finally:
            self.qnn.train(was_training)

        greedy_actions = torch.argmax(q_values, dim=-1).cpu().numpy()
        random_actions = np.random.randint(self.num_actions, size=batch_size)
        explore = np.random.random(batch_size) < epsilon
        return np.where(explore, random_actions, greedy_actions)

    def configure_optimizers(self):
        optimizer_class = getattr(torch.optim, self.optimizer_class)
        return optimizer_class(self.qnn.parameters(), **self.optimizer_config)

    def training_step(self, batch: Transition, batch_idx):
        """One TD(0) regression step of ``qnn`` toward the (double) DQN target."""
        bs = batch.reward.shape[0]
        rows = torch.arange(bs, device=batch.reward.device)

        # The bootstrap target is a constant for the regression: no autograd graph needed.
        with torch.no_grad():
            next_q_value = self.qnn_target(batch.next_state)
            if self.use_double_dqn:
                next_action = torch.argmax(self.qnn(batch.next_state), dim=-1)
                next_value_estimate = next_q_value[rows, next_action]
            else:
                next_value_estimate = torch.max(next_q_value, dim=-1).values
            # terminal transitions do not bootstrap
            td_target = batch.reward + self.gamma * next_value_estimate * (~batch.done)

        td_pred = self.qnn(batch.state)[rows, batch.action]
        td_error = self.loss_fn(td_pred, td_target)
        self.log("train/td_error", td_error)
        self.log("epsilon", self.get_current_epsilon())
        return td_error

    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
        self.soft_update_target_nn()
        return super().on_train_batch_end(outputs, batch, batch_idx)

    @torch.no_grad()
    def soft_update_target_nn(self):
        """Polyak update: target <- (1 - tau) * target + tau * online."""
        for target_param, param in zip(self.qnn_target.parameters(), self.qnn.parameters()):
            target_param.lerp_(param, self.tau)

    def validation_step(self, batch, batch_idx):
        """
        Play one episode per env-kwargs dict in ``batch`` with the greedy policy.

        Scores are collected for ``on_validation_epoch_end``; for the first batch,
        recorded videos are logged to W&B when a video folder is configured.
        """
        if self.video_root_folder is None:
            video_folder = None
        else:
            video_folder = self.video_root_folder / f"step-{self.trainer.global_step}-batch-{batch_idx}"

        scores = list(measure_scores_parallel(
            agent=self,
            make_env=self.make_env,
            make_env_kwargs=batch,
            # evaluate the greedy policy, without exploration noise
            select_action_kwargs={"epsilon": 0.0},
            video_folder=video_folder,
        ))
        self.validation_step_scores.extend(scores)

        if batch_idx == 0 and video_folder is not None:
            for logger in self.loggers:
                if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
                    for worker_dir in sorted(video_folder.glob("*")):
                        videos = {
                            f"video/worker_{worker_dir.name}": wandb.Video(str(video))
                            for video in worker_dir.glob("*0.mp4")
                        }
                        if videos:  # a worker may not have produced a matching video
                            logger.log_metrics(videos, step=self.trainer.global_step)
                    break

        return scores

    def on_validation_epoch_end(self):
        scores = torch.tensor(self.validation_step_scores)
        self.validation_step_scores.clear()
        if scores.numel() == 0:
            warnings.warn("No validation episode finished; skipping score metrics.")
            return
        self.log_simulation_metrics(scores)

    def simulation_metrics(self, scores: torch.Tensor | np.ndarray) -> dict[str, float]:
        """Summary statistics (mean, std, min, max, percentiles, CVaR) of episode scores."""
        scores_np = np.asarray(scores)  # pandas/numpy helpers get a plain array
        metrics = {
            "score/avg": float(scores.mean()),
            "score/std": float(scores.std()),
            "score/min": float(scores.min()),
            "score/max": float(scores.max()),
            "score/num_episodes": float(len(scores)),
        }
        for p in [25, 50, 75]:
            metrics[f"score/percentile_{p}"] = float(np.percentile(scores_np, p))
            metrics[f"score/cvar_{p}"] = float(cvar(scores_np, percent=p))
        return metrics

    def log_simulation_metrics(self, scores: Iterable[float], metrics: dict[str, float] | None = None):
        """Log a score histogram to W&B loggers and the score metrics via ``log_dict``."""
        if not isinstance(scores, (torch.Tensor, np.ndarray)):
            scores = torch.tensor(list(scores), dtype=torch.float32)
        if metrics is None:
            metrics = self.simulation_metrics(scores)

        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
                logger.log_metrics({"score/score": wandb.Histogram(np.asarray(scores))}, step=self.global_step)

        self.log_dict(metrics)
            


### Experience collection

In [None]:
class DataGeneratorParallel:
    """
    Endless stream of transitions gathered by ``agent`` in parallel environments.

    Each ``iter()`` starts a fresh ``gym.vector.AsyncVectorEnv`` with one environment
    per entry of ``make_env_kwargs`` and yields one ``Transition`` per environment per
    vector step, using ``agent.select_action(states, collecting_data=True)``.

    Args:
        agent: policy that acts in the environments.
        make_env: env factory, called as ``make_env(**kwargs)``.
        make_env_kwargs: kwargs per environment; a single dict is shared by all
            ``num_parallel_envs`` environments; None means no kwargs.
        num_parallel_envs: number of environments; inferred from a list of kwargs if None.

    Raises:
        ValueError: if the number of environments cannot be determined or does not
            match the length of ``make_env_kwargs``.
    """

    def __init__(
        self,
        agent: Agent,
        make_env: Callable[[int | None, bool], gym.Env],
        make_env_kwargs: dict[str, Any] | list[dict[str, Any]] | None = None,
        num_parallel_envs: int | None = None,
    ) -> None:
        if num_parallel_envs is None:
            if make_env_kwargs is None:
                raise ValueError("Either make_env_kwargs or num_parallel_envs must be set")
            if isinstance(make_env_kwargs, dict):
                raise ValueError("num_parallel_envs must be set if make_env_kwargs is a dict")

        if make_env_kwargs is None:
            make_env_kwargs = {}
        if isinstance(make_env_kwargs, dict):
            make_env_kwargs = [make_env_kwargs for _ in range(num_parallel_envs)]
        if num_parallel_envs is None:
            num_parallel_envs = len(make_env_kwargs)
        if len(make_env_kwargs) != num_parallel_envs:
            raise ValueError("make_env_kwargs must have the same length as num_parallel_envs")

        self.agent = agent
        self.make_env = make_env
        self.make_env_kwargs = make_env_kwargs
        self.num_parallel_envs = num_parallel_envs

    def __iter__(self) -> Iterable[Transition]:
        envs = gym.vector.AsyncVectorEnv([
            functools.partial(self.make_env, **kwargs)
            for kwargs in self.make_env_kwargs
        ])
        # The generator never finishes on its own; `finally` shuts the worker processes
        # down when it is closed or garbage-collected.
        try:
            states = envs.reset()
            while True:
                actions = self.agent.select_action(states, collecting_data=True)
                next_states, rewards, dones, _ = envs.step(actions)
                # NOTE(review): the vector env auto-resets finished envs, so for done=True
                # transitions `next_state` is presumably the first observation of the next
                # episode; the learner must not bootstrap from it (verify the TD target masks it).
                for s, a, r, s_, d in zip(states, actions, rewards, next_states, dones):
                    yield Transition(s, a, r, s_, d)
                states = next_states
        finally:
            envs.close()


class DataCollectionCallback(lightning.Callback):
    """
    Keeps the replay buffer filled with fresh experience during training.

    Experience is collected once when training starts, and again at the end of every
    training batch after which ``trainer.global_step`` is a multiple of
    ``collect_every_n_steps``. One collection advances every parallel environment of
    the generator by ``collect_num_steps_in_each_env`` steps.
    """

    def __init__(
        self,
        parallel_experience_generator: DataGeneratorParallel,
        buffer: ReplayBuffer,
        collect_every_n_steps: int,
        collect_num_steps_in_each_env: int,
    ) -> None:
        super().__init__()
        self.parallel_experience_collector = parallel_experience_generator
        self.buffer = buffer
        self.collect_every_n_steps = collect_every_n_steps
        self.collect_num_steps_in_each_env = collect_num_steps_in_each_env
        # The generator yields one transition per environment for every vector step,
        # so "N steps in each env" means N * num_parallel_envs transitions.
        self.num_parallel_envs = getattr(parallel_experience_generator, "num_parallel_envs", 1)
        # A single iterator is reused so the envs keep running between collections.
        self.experience_iterator = iter(self.parallel_experience_collector)

    def on_train_start(self, trainer, pl_module) -> None:
        self.collect_experience()

    def on_train_batch_end(self, trainer: lightning.Trainer, pl_module, outputs, batch, batch_idx) -> None:
        if trainer.global_step % self.collect_every_n_steps == 0:
            self.collect_experience()

    def collect_experience(self):
        """Advance every env by ``collect_num_steps_in_each_env`` steps, storing all transitions."""
        num_transitions = self.collect_num_steps_in_each_env * self.num_parallel_envs
        for transition in itertools.islice(self.experience_iterator, num_transitions):
            self.buffer.add(transition)


### Hyperparameters

In [None]:
learning_rate = 1e-4
epsilon_upper = 0.60
epsilon_lower = 0.01
epsilon_decay_steps = 10_000
gamma = 0.99
tau = 0.01
batch_size = 128
use_double_dqn = True

buffer_capacity = 100_000
collect_data_num_parallel_envs = 10
collect_data_every_n_steps = 1000 # TODO
collect_data_num_steps_in_each_env = 2500  # TODO

val_check_interval = 1000
valid_num_parallel_envs = 10
valid_num_batches = 1

use_pretrained_weights = None # set 'DEFAULT' to use weights from torchvision

accelerator = "cpu"
# fp16 mixed precision is GPU-only; on CPU Lightning's mixed precision uses bfloat16.
precision = "bf16-mixed" if accelerator == "cpu" else "16-mixed"
devices = []

training_env = "pong"
model_config_name = "efficientnet-b0" # or "mobilenet_v3_small" or "mlp"

checkpoint_monitor = ("score/avg", "max")


In [None]:
# CPU runs let Lightning choose ("auto"); other accelerators use the configured `devices`,
# falling back to "auto" when that list is empty (an empty list is not a valid device spec).
trainer_devices = devices if accelerator != "cpu" and devices else "auto"

In [None]:
# Pick the environment factory for the configured training environment.
if training_env not in ("pong", "gridworld"):
    raise ValueError("training_env must be either 'pong' or 'gridworld'")

make_env = {"pong": make_env_pong, "gridworld": make_env_minigrid}[training_env]

In [None]:
# Build one env only to read the size of its discrete action space, then release it.
probe_env = make_env()
try:
    num_actions = probe_env.action_space.n
finally:
    probe_env.close()
del probe_env

mlp_config = {
    "num_actions": num_actions,
}

mobilenet_v3_small_config = {
    "actions": num_actions,
    "architecture": "mobilenet_v3_small",
    "input_conv": "features.0.0",
    "output_lin": "classifier.3",
    "input_channels": 4,
}

efficientnet_b0_config = {
    "actions": num_actions,
    "architecture": "efficientnet_b0",
    "input_conv": "features.0.0",
    "output_lin": "classifier.1",
    "input_channels": 4,
}

model_configs = {
    "mlp": mlp_config,
    "mobilenet_v3_small": mobilenet_v3_small_config,
    "efficientnet-b0": efficientnet_b0_config,
}
if model_config_name not in model_configs:
    raise ValueError(f"Unknown model_config: {model_config_name}")
model_config = model_configs[model_config_name]

In [None]:
# Local folders for W&B run files and validation videos.
pathlib.Path("./wandb").mkdir(parents=True, exist_ok=True)
pathlib.Path("./videos").mkdir(parents=True, exist_ok=True)

# All DQN runs on the same environment share one W&B group.
wandb_logger = lightning.pytorch.loggers.WandbLogger(
    save_dir="./wandb",
    project="jku-deep-rl_dqn",
    group=f"dqn-{training_env}",
)
wandb_experiment: wandb.wandb_run.Run = wandb_logger.experiment

In [None]:
dqn = DQN(
    make_env = make_env,
    model_config = model_config,
    # Vision configs are the ones naming a torchvision architecture (same test as used for W&B tags).
    use_vision = "architecture" in model_config,
    num_actions = num_actions,
    use_double_dqn = use_double_dqn,
    optimizer_class = "AdamW",
    optimizer_config = dict(lr=learning_rate),
    epsilon = (epsilon_upper, epsilon_lower, epsilon_decay_steps),
    gamma = gamma,
    tau = tau,
    use_pretrained_weights = use_pretrained_weights,
    video_root_folder = f"./videos/{wandb_experiment.name}",
)

In [None]:
# Replay buffer that the training loader streams from; the callback below keeps refilling it.
buffer = ReplayBuffer(seed=42, capacity=buffer_capacity)

experience_generator = DataGeneratorParallel(
    agent=dqn,
    make_env=make_env,
    make_env_kwargs=[dict() for _ in range(collect_data_num_parallel_envs)],
    num_parallel_envs=collect_data_num_parallel_envs,
)

experience_collection_callback = DataCollectionCallback(
    parallel_experience_generator=experience_generator,
    buffer=buffer,
    collect_every_n_steps=collect_data_every_n_steps,
    collect_num_steps_in_each_env=collect_data_num_steps_in_each_env,
)

# Keep the best checkpoints according to the validation score metric (plus the latest one).
monitor_metric, metric_mode = checkpoint_monitor
checkpointing = lightning.pytorch.callbacks.ModelCheckpoint(
    dirpath=f"checkpoints/{wandb_logger.experiment.name}/",
    filename="step={step}_score_avg={score/avg:.2f}",
    auto_insert_metric_name=False,
    monitor=monitor_metric,
    mode=metric_mode,
    save_top_k=10,
    save_last=True,
)

In [None]:
train_loader = torch.utils.data.DataLoader(
    buffer,
    batch_size = batch_size,
    # pinned host memory only speeds up host-to-GPU copies, so skip it when training on CPU
    pin_memory = accelerator != "cpu",
    shuffle = False, # ReplayBuffer already samples randomly
)

# One distinct env seed per validation episode: 0 .. valid_num_parallel_envs * valid_num_batches - 1.
valid_make_env_kwargs = [
    {"seed": seed}
    for seed in range(valid_num_parallel_envs * valid_num_batches)
]

# validation loader just gives parameters for creating the gym env
valid_loader = torch.utils.data.DataLoader(
    valid_make_env_kwargs,
    batch_size = valid_num_parallel_envs,
    shuffle = False,
    drop_last = False,
    collate_fn=lambda x: x,
)


In [None]:
# Record the run settings that are not already captured by DQN's saved hyperparameters.
wandb_logger.experiment.config.update(dict(
    use_pretrained_weights=use_pretrained_weights,
    buffer_capacity=buffer_capacity,
    collect_data=dict(
        num_parallel_envs=collect_data_num_parallel_envs,
        every_n_steps=collect_data_every_n_steps,
        num_steps_in_each_env=collect_data_num_steps_in_each_env,
    ),
    checkpointing=dict(
        monitor=checkpointing.monitor,
        mode=checkpointing.mode,
        save_last=checkpointing.save_last,
        save_top_k=checkpointing.save_top_k,
    ),
    train_loader=dict(batch_size=train_loader.batch_size),
    validation=dict(
        num_parallel_envs=valid_num_parallel_envs,
        make_env_kwargs=valid_make_env_kwargs,
        total_runs=len(valid_make_env_kwargs),
    ),
))

# Tags make runs easy to filter in the W&B UI.
wandb_logger.experiment.tags += (
    "double-dqn-on" if dqn.use_double_dqn else "double-dqn-off",
    f"init-pretrained-{use_pretrained_weights}",
)

if "architecture" in dqn.model_config:
    wandb_logger.experiment.tags += (f"architecture-{dqn.model_config['architecture']}",)


In [None]:
trainer = lightning.Trainer(
    accelerator=accelerator,
    devices=trainer_devices,
    precision=precision,
    logger=wandb_logger,
    callbacks=[experience_collection_callback, checkpointing],
    # No epoch limit: the replay buffer is an endless stream, so training runs until stopped.
    max_epochs=-1,
    val_check_interval=val_check_interval,
)

In [None]:
# Start training; validation rollouts run every `val_check_interval` training batches.
trainer.fit(
    model=dqn,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)