In [1]:
import os
import abc
import random
import itertools
import collections
import warnings
from typing import NamedTuple, Any, Iterable, Callable

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
import torch.utils.data
import lightning
import lightning.pytorch
import lightning.pytorch.loggers
import lightning.pytorch.callbacks
import gym_minigrid
import gym_minigrid.wrappers
import gym
import lovely_tensors
import lovely_numpy
import wandb
from tqdm.auto import tqdm

from envs import make_env_minigrid, make_env_pong

cv2.ocl.setUseOpenCL(False)


def lovely(x):
    "summarizes important tensor properties"
    if isinstance(x, np.ndarray):
        return lovely_numpy.lovely(x)
    return lovely_tensors.lovely(x)




### Replay Buffer

In [2]:
from __future__ import annotations

class Transition(NamedTuple):
    state: torch.Tensor | np.ndarray
    action: torch.Tensor | np.ndarray
    reward: torch.Tensor | np.ndarray
    next_state: torch.Tensor | np.ndarray
    done: torch.Tensor | np.ndarray | bool


class ReplayBuffer(torch.utils.data.IterableDataset):
    def __init__(self, capacity: int, seed = None):
        self.capacity = capacity
        self.buffer: list[Transition] = []
        self.seed = seed
        self.rng = random.Random(seed)
    
    def __next__(self):
        transition = self.rng.choice(self.buffer)
        transition = Transition(
            state = torch.tensor(transition.state, dtype=torch.float32),
            action = torch.tensor(transition.action, dtype=torch.int64),
            reward = torch.tensor(transition.reward, dtype=torch.float32),
            next_state = torch.tensor(transition.next_state, dtype=torch.float32),
            done = torch.tensor(transition.done.item(), dtype=torch.bool),
        )
        return transition
    
    def __iter__(self):
        return self
    
    def add(self, t: Transition):
        self.buffer.append(t)
        if len(self.buffer) > self.capacity:
            self.buffer.pop(0)
    

### Architectures

In [3]:
from __future__ import annotations


class MlpQNet(torch.nn.Module):
    def __init__(self, num_actions: int):
        super().__init__()
        self.num_actions = num_actions
        self.fc = torch.nn.Sequential(
            torch.nn.Flatten(start_dim=1), 
            torch.nn.Linear(3*7**2, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, num_actions)
        )

    def forward(self, x) -> torch.Tensor:
        if isinstance(x, np.ndarray):
            param = next(self.parameters())
            x = torch.tensor(x).to(param)

        if len(x.shape) == 3:
            x = x.unsqueeze(dim=0)

        return self.fc(x)


def set_submodule(module: torch.nn.Module, submodule_path: str, submodule: torch.nn.Module) -> None:
    *parent_path, attr_name = submodule_path.split(".")
    parent = module
    for child in parent_path:
        parent = getattr(parent, child)
    setattr(parent, attr_name, submodule)


# Copied from my previous homework
class VisionQNet(torch.nn.Module):
    """
    Policy network is compatible with any CNN architecture as long as:
    1) One of the conv layers determines how many channels it can process
    2) One of the linear layers determines the number of output neurons
    
    This should be enough to support all standard architectures

    >>> policy_nn = PolicyNetwork(
    ...     actions=5,
    ...     architecture="efficientnet_b0",
    ...     input_conv="features.0.0",
    ...     output_lin="classifier.1",
    ...     pretrained_weights=None, # "IMAGENET1K_V1"
    ... )
    >>> inputs = torch.rand(size=(16, 1, 64, 64))
    >>> policy_nn(inputs).shape
    torch.Size([16, 5])

    """

    def __init__(
        self,
        actions: int | dict[int, str],
        architecture: str,
        input_conv: str,
        output_lin: str,
        pretrained_weights = None,
        seed: int = 42,
    ) -> None:
        super(VisionQNet, self).__init__()
        if isinstance(actions, int):
            self.n_units_out = actions
            self.actions = {i: f"action_{i}" for i in range(actions)}
        else:
            self.n_units_out = len(actions)
            self.actions = actions
        self.architecture = architecture
        self.input_conv = input_conv
        self.output_lin = output_lin
        self.nn = torchvision.models.get_model(self.architecture)
        self._patch_input_shape()
        self._patch_output_shape()
        self.init(pretrained_weights, seed)

    def init(self, pretrained_weights = None, seed: int = None) -> None:
        """
        Initializes the model parameters, either randomly or with pretrained weights.
        Patches the architecture to have desired input and output shape.
        """
        if seed is not None:
            torch.random.manual_seed(seed)

        state_dict = torchvision.models.get_model(self.architecture, weights=pretrained_weights).state_dict()

        for key in list(state_dict.keys()):
            if key.startswith(self.input_conv) or key.startswith(self.output_lin):
                state_dict.pop(key)

        incompatible = self.nn.load_state_dict(state_dict, strict=False)
        if len(incompatible.unexpected_keys) != 0:
            raise ValueError(f"Unexpected additional keys in pretrained weights: {incompatible.unexpected_keys}")
        
        for key in incompatible.missing_keys:
            assert isinstance(key, str)
            if key.startswith(self.input_conv) or key.startswith(self.output_lin):
                continue
            raise ValueError(f"Unexpected missing key in pretrained weights: {key}")

        del state_dict

        in_out_params = itertools.chain(
            self.nn.get_submodule(self.input_conv).parameters(),
            self.nn.get_submodule(self.output_lin).parameters(),
        )

        for param in in_out_params:
            torch.nn.init.trunc_normal_(param, mean=0, std=1e-4, a=-0.01, b=0.01)            

    def _patch_input_shape(self) -> None:
        """
        Make the architecture accept a single (grayscale) channel
        """
        old_in_conv = self.nn.get_submodule(self.input_conv)
        assert isinstance(old_in_conv, torch.nn.Conv2d)

        if old_in_conv.in_channels == 1:
            return

        new_in_conv = torch.nn.Conv2d(
            in_channels = 1, 
            out_channels = old_in_conv.out_channels,
            kernel_size = old_in_conv.kernel_size,
            stride = old_in_conv.stride,
            padding = old_in_conv.padding,
            bias = old_in_conv.bias is not None,
        ).to(
            old_in_conv.weight
        )
        set_submodule(self.nn, self.input_conv, new_in_conv)

    def _patch_output_shape(self) -> None:
        """
        Make the architecture output the correct shape
        """
        old_out_lin = self.nn.get_submodule(self.output_lin)
        assert isinstance(old_out_lin, torch.nn.Linear)
        
        if old_out_lin.out_features == self.n_units_out:
            return

        new_out_lin = torch.nn.Linear(
            in_features = old_out_lin.in_features,
            out_features = self.n_units_out,
            bias = old_out_lin.bias is not None,
        ).to(
            old_out_lin.weight
        )
        set_submodule(self.nn, self.output_lin, new_out_lin)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # select only last channel if there are more
        # TODO revisit this. it's coming from previous homework and maybe not relevant anymore
        x = x[:, (-1,), :, :]
        return self.nn(x)

In [4]:
VisionQNet(5, "mobilenet_v3_small", "features.0.0", "classifier.3")
None

### Environments

In [5]:
env_minigrid = make_env_minigrid()
state = env_minigrid.reset()
print("obs:", lovely(state))
print("metadata:", env_minigrid.metadata)
print("action_space:", env_minigrid.action_space)

obs: array[3, 7, 7] f32 n=147 x∈[0., 8.000] μ=2.020 σ=2.075
metadata: {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 10, 'render_modes': ['human', 'rgb_array', 'single_rgb_array'], 'render_fps': 10}
action_space: Discrete(7)


In [6]:
env_pong = make_env_pong()
state = env_pong.reset()
print("obs:", lovely(state))
print("metadata:", env_pong.metadata)
print("action_space:", env_pong.action_space)

obs: array[4, 84, 84] u8 n=28224 x∈[52, 236] μ=106.303 σ=47.242
metadata: {'render_modes': ['human', 'rgb_array']}
action_space: Discrete(6)


A.L.E: Arcade Learning Environment (version 0.7.5+db37282)
[Powered by Stella]


### Measuring score

In [7]:
class Agent(abc.ABC):
    def select_action(self, state) -> int:
        ...

In [8]:

def measure_scores_parallel(agent: Agent, make_env: Callable, make_env_kwargs: list[dict[str, Any]], select_action_kwargs: dict[str, Any], num_episodes: int | None = None, limit_steps: int = None, return_unfinished: bool = False) -> Iterable[float]:
    if isinstance(make_env_kwargs, dict):
        make_env_kwargs = [make_env_kwargs] * num_episodes
    elif isinstance(make_env_kwargs, list):
        if num_episodes is None:
            num_episodes = len(make_env_kwargs)
    if len(make_env_kwargs) != num_episodes:
        raise ValueError("num_episodes does not match length of make_env_kwargs")

    env = gym.vector.AsyncVectorEnv([lambda: make_env(**kwargs) for kwargs in make_env_kwargs])
    
    if limit_steps is None:
        limit_steps = env.get_attr("spec")[0].max_episode_steps
    if limit_steps is None:
        warnings.warn("No limit_steps provided and env does not have max_episode_steps set. This might lead to infinite loops.")
    
    state = env.reset()
    score = np.zeros(num_episodes)
    done_in_past = np.zeros(num_episodes, dtype=bool)
    done_at_the_moment = np.zeros(num_episodes, dtype=bool)

    if limit_steps is None:
        steps = itertools.count()
    else:
        steps = range(limit_steps)
    for step in tqdm(steps, "Measuring score - step"):
        score[done_at_the_moment] = 0
        action = agent.select_action(state)
        state, reward, done_at_the_moment, _ = env.step(action)
        done_at_the_moment = np.array(done_at_the_moment, dtype=bool)
        reward = np.array(reward)
        # vector env resets done envs automatically and starts over
        # we only want to update score of episodes that were not done before
        # and output the score for that env only once
        score += reward
        done = done_in_past | done_at_the_moment
        done_for_first_time = done_in_past != done # this is not the same as done_at_the_moment, because of autoreset
        done_in_past = done
        
        for idx in done_for_first_time.nonzero()[0]:
            yield score[idx].item()

        if done_in_past.all():
            break
      
    env.close()

    if return_unfinished:
        yield from score[~done_in_past].tolist()

In [9]:
import numpy as np
score = np.zeros(5)
done_in_past = np.array([False, True, False, True, False])
reward = np.array([1, 2, 3, 4, 5])
score[~done_in_past] += reward[~done_in_past]
score

array([1., 0., 3., 0., 5.])

In [10]:
def cvar(scores: np.ndarray | pd.Series, *, percent: float) -> float:
    """
    Computes conditional value at risk (CVaR) of a given set of scores.
    CVaR is the expected value of the worst percents of the scores.
    """
    assert 0 < percent < 100
    if not isinstance(scores, pd.Series):
        scores = pd.Series(scores)
    quantile = scores.quantile(percent / 100)
    return scores[scores <= quantile].mean()
    
cvar(np.arange(100), percent=25)

12.0

### DQN

In [11]:
class DQN(lightning.LightningModule, Agent):
    def __init__(
        self,
        model_config: dict[str, Any],
        num_actions: int,
        use_double_dqn: bool,
        optimizer_class: str,
        use_pretrained_weights: str | None = None,
        optimizer_config: dict[str, Any] = None,
        make_env: Callable | None = None,
        epsilon: float = 0.1,
        tau: float = 1e-3,
        gamma: float = 0.99,
        use_vision: bool = True,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(ignore=["use_pretrained_weights", "make_env"])
        self.model_config = model_config

        if optimizer_config is None:
            optimizer_config = {}
        self.optimizer_config = optimizer_config
        self.optimizer_class = optimizer_class
        
        # TODO does it matter that they are initialized with the same weights?
        if use_vision:
            self.qnn = VisionQNet(**model_config, pretrained_weights=use_pretrained_weights)
            self.qnn_target = VisionQNet(**model_config, pretrained_weights=use_pretrained_weights)
        else:
            if use_pretrained_weights is not None:
                raise ValueError("Pretrained weights only supported for vision models")
            self.qnn = MlpQNet(**model_config)
            self.qnn_target = MlpQNet(**model_config)
        
        self.loss_fn = torch.nn.MSELoss()

        self.num_actions = num_actions
        self.use_double_dqn = use_double_dqn
        self.epsilon = epsilon
        self.tau = tau
        self.gamma = gamma

        self.make_env = make_env
        self.validation_step_scores = []

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.qnn(observations)

    def select_action(self, observations: torch.Tensor, greedy: bool = False) -> np.ndarray:
        with torch.no_grad():
            batch_size = observations.shape[0]
            if greedy or random.random() > self.epsilon:
                q_values = self.qnn(observations)
                return q_values.argmax(dim=-1).cpu().numpy()
            return np.random.randint(self.num_actions, size=batch_size)

    def configure_optimizers(self):
        optimizer_class = getattr(torch.optim, self.optimizer_class)
        return optimizer_class(self.qnn.parameters(), **self.optimizer_config)

    def training_step(self, batch: Transition, batch_idx):
        bs = batch.reward.shape[0]
        next_q_value = self.qnn_target(batch.next_state)

        if self.use_double_dqn:
            next_q_value_local = self.qnn(batch.next_state)
            next_action = torch.argmax(next_q_value_local, dim=-1)
            next_value_estimate = next_q_value[range(bs), next_action]
        else:
            next_value_estimate = torch.max(next_q_value, dim=-1).values
        
        td_target = batch.reward + self.gamma * next_value_estimate * (~batch.done)
        td_target = td_target.detach()
        td_pred = self.qnn(batch.state)[range(bs), batch.action]
        td_error = self.loss_fn(td_pred, td_target)
        self.log("td_error", td_error)
        return td_error
    
    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
        self.soft_update_target_nn()
        # TODO decay epsilon
        # self.epsilon = max(self.epsilon_lb, self.epsilon_ub - self.global_step / self.epsilon_decay)
        return super().on_train_batch_end(outputs, batch, batch_idx)

    @torch.no_grad()
    def soft_update_target_nn(self):
        params = zip(self.qnn_target.parameters(), self.qnn.parameters())
        for target_param, param in params:
            target_param.copy_(self.tau * param + (1.0 - self.tau) * target_param)

    def validation_step(self, batch, batch_idx):
        scores = measure_scores_parallel(
            agent=self,
            make_env=self.make_env,
            make_env_kwargs=batch,
            select_action_kwargs={},#{"greedy": True},
        )
        scores = list(scores)
        self.validation_step_scores.extend(scores)
        return scores
    
    def on_validation_epoch_end(self):
        scores = torch.tensor(self.validation_step_scores)
        self.log_simulation_metrics(scores)
        self.validation_step_scores.clear()

    def simulation_metrics(self, scores: torch.Tensor | np.ndarray) -> dict[str, float]:
        metrics = {
            "score/avg": float(scores.mean()),
            "score/std": float(scores.std()),
            "score/min": float(scores.min()),
            "score/max": float(scores.max()),
            "score/num_episodes": float(len(scores)),
        }
        for p in [25, 50, 75]:
            metrics[f"score/percentile_{p}"] = float(np.percentile(scores, p))
            metrics[f"score/cvar_{p}"] = float(cvar(scores, percent=p))
        return metrics

    def log_simulation_metrics(self, scores: Iterable[float], metrics: dict[str, float] | None = None):
        if metrics is None:
            metrics = self.simulation_metrics(scores)

        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
                logger.log_metrics({f"score/score": wandb.Histogram(scores)}, step=self.global_step)
        
        self.log_dict(metrics)
            


### Experience collection

In [12]:
class ParallelExperienceGenerator:
    def __init__(
        self,
        agent: Agent,
        make_env: Callable[[int | None, bool], gym.Env],
        num_parallel_envs: int,
        make_env_kwargs: dict[str, Any] = None,
    ) -> None:
        if make_env_kwargs is None:
            make_env_kwargs = {}

        self.agent = agent
        self.make_env = make_env
        self.num_parallel_envs = num_parallel_envs
        self.make_env_kwargs = make_env_kwargs

        self.envs = gym.vector.AsyncVectorEnv([
            lambda: make_env(**make_env_kwargs)
            for _ in range(num_parallel_envs)
        ])

    def __iter__(self) -> Iterable[Transition]:
        states = self.envs.reset()
        while True:
            actions = self.agent.select_action(states)
            next_states, rewards, dones, _ = self.envs.step(actions)
            for s, a, r, s_, d in zip(states, actions, rewards, next_states, dones):
                yield Transition(s, a, r, s_, d)
            states = next_states


class ExperienceCollectionCallback(lightning.Callback):
    def __init__(
        self,
        parallel_experience_generator: ParallelExperienceGenerator,
        buffer: ReplayBuffer,
        collect_every_n_steps: int,
        collect_num_steps_in_each_env: int,
    ) -> None:
        super().__init__()
        self.parallel_experience_collector = parallel_experience_generator
        self.buffer = buffer
        self.collect_every_n_steps = collect_every_n_steps
        self.collect_num_steps_in_each_env = collect_num_steps_in_each_env
        self.experience_iterator = iter(self.parallel_experience_collector)

    def on_train_start(self, trainer, pl_module) -> None:
        self.collect_experience()

    def on_train_batch_end(self, trainer: lightning.Trainer, pl_module, outputs, batch, batch_idx) -> None:
        if trainer.global_step % self.collect_every_n_steps == 0:
            self.collect_experience()

    def collect_experience(self):
        for step in range(self.collect_num_steps_in_each_env):
            transition = next(self.experience_iterator)
            self.buffer.add(transition)


### Training Minigrid

In [13]:
LR = 1e-4
EPSILON = 0.1
TAU = 0.01
BATCH_SIZE = 128

BUFFER_CAPACITY = 50_000
COLLECT_EXP_NUM_PARALLEL_ENVS = 10
COLLECT_EXP_EVERY_N_STEPS = 1000 # TODO
COLLECT_EXP_NUM_STEPS_IN_EACH_ENV = 2500  # TODO

VAL_CHECK_INTERVAL = 1000
VAL_NUM_PARALLEL_ENVS = 10
VAL_RUNS = 1

In [14]:
NUM_ACTIONS = env_minigrid.action_space.n
USE_PRETRAINED_WEIGHTS = None

mlp_config = {
    "num_actions": NUM_ACTIONS,
}
make_env = make_env_minigrid

dqn = DQN(
    make_env = make_env,
    model_config = mlp_config,
    use_vision = False,
    num_actions = NUM_ACTIONS,
    use_double_dqn = False,
    optimizer_class = "AdamW",
    optimizer_config = dict(lr=LR),
    epsilon = EPSILON,
    tau = TAU,
    use_pretrained_weights=USE_PRETRAINED_WEIGHTS,
)

In [15]:
buffer = ReplayBuffer(
    capacity = BUFFER_CAPACITY,
    seed = 42
)

In [16]:
experience_generator = ParallelExperienceGenerator(
    agent = dqn,
    make_env = make_env,
    num_parallel_envs = COLLECT_EXP_NUM_PARALLEL_ENVS,
)

In [17]:
experience_collection_callback = ExperienceCollectionCallback(
    parallel_experience_generator = experience_generator,
    buffer = buffer,
    collect_every_n_steps = COLLECT_EXP_EVERY_N_STEPS,
    collect_num_steps_in_each_env = COLLECT_EXP_NUM_STEPS_IN_EACH_ENV,
)

In [18]:
os.makedirs("./wandb", exist_ok=True)

wandb_logger = lightning.pytorch.loggers.WandbLogger(
    project="jku-deep-rl_dqn",
    group="dqn-minigrid",
    save_dir="./wandb",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmarkcheeky[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [19]:
wandb_logger.experiment.config.update({
    "use_pretrained_weights": USE_PRETRAINED_WEIGHTS,
    "buffer_capacity": BUFFER_CAPACITY,
    "collect_experience": {
        "num_parallel_envs": COLLECT_EXP_NUM_PARALLEL_ENVS,
        "every_n_steps": COLLECT_EXP_EVERY_N_STEPS,
        "num_steps_in_each_env": COLLECT_EXP_NUM_STEPS_IN_EACH_ENV,
    }
})

wandb_logger.experiment.tags += (
    "double-dqn-on" if dqn.use_double_dqn else "double-dqn-off",
    f"init-pretrained-{USE_PRETRAINED_WEIGHTS}",
)

if "architecture" in dqn.model_config:
    wandb_logger.experiment.tags += (
        f"architecture-{dqn.model_config['architecture']}"
    )


In [20]:
checkpointing = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor="score/avg",
    mode="max",
    save_last=True,
    save_top_k=10,
    dirpath=f"checkpoints/{wandb_logger.experiment.name}/",
    filename="step={step}_score_avg={score/avg:.2f}",
    auto_insert_metric_name=False,
)

In [21]:
trainer = lightning.Trainer(
    logger = wandb_logger,
    max_epochs = -1,
    val_check_interval = VAL_CHECK_INTERVAL,
    #devices = [0],
    precision="16-mixed",
    accelerator="cpu",
    callbacks=[
        experience_collection_callback,
        checkpointing,
    ]
)

  rank_zero_warn(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [22]:
train_loader = torch.utils.data.DataLoader(
    buffer,
    batch_size = BATCH_SIZE,
    pin_memory = True,
    shuffle = False, # ReplayBuffer already samples randomly
)

valid_make_env_kwargs = [
    {} for i in range(VAL_NUM_PARALLEL_ENVS)
]

# validation loader just gives parameters for creating the gym env
valid_loader = torch.utils.data.DataLoader(
    [valid_make_env_kwargs] * VAL_RUNS,
    batch_size = 1,
    shuffle = False,
    drop_last = False,
)


In [23]:
trainer.fit(
    dqn,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)


  | Name       | Type    | Params
---------------------------------------
0 | qnn        | MlpQNet | 120 K 
1 | qnn_target | MlpQNet | 120 K 
2 | loss_fn    | MSELoss | 0     
---------------------------------------
241 K     Trainable params
0         Non-trainable params
241 K     Total params
0.965     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Measuring score - step: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Measuring score - step: 0it [00:00, ?it/s]

### Training Pong

In [None]:
NUM_ACTIONS = env_pong.action_space.n

mobilenet_v3_small_config = {
    "actions": NUM_ACTIONS,
    "architecture": "mobilenet_v3_small",
    "input_conv": "features.0.0",
    "output_lin": "classifier.3",
}

efficientnet_b0_config = {
    "actions": NUM_ACTIONS,
    "architecture": "efficientnet_b0",
    "input_conv": "features.0.0",
    "output_lin": "classifier.1",
}