In [1]:
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
import numpy as np

In [6]:
class QNet(nn.Module):
    """Fully connected network that maps a batch of inputs to per-action outputs.

    The stack is a sequence of ``nn.Dense`` layers named ``layers_0``,
    ``layers_1``, ... with ReLU after every layer except the last one.
    Layer widths are, in order: ``n_states``, each entry of
    ``hidden_features``, then ``n_actions``.

    NOTE(review): because the width list starts with ``n_states``, the first
    layer is an extra ``n_states``-wide Dense applied to the input.
    ``nn.Dense`` takes the *output* width and infers the input width, so this
    leading layer may be unintentional -- confirm before relying on the
    architecture. It is kept here so parameter names and shapes do not change.
    """
    # Width of the first Dense layer (see the review note above).
    n_states: int
    # Widths of the intermediate Dense layers, in order.
    hidden_features: Sequence[int]
    # Width of the final (non-activated) Dense layer.
    n_actions: int

    @nn.compact
    def __call__(self, inputs):
        """Run ``inputs`` through the Dense/ReLU stack and return the final layer's output."""
        widths = (self.n_states, *self.hidden_features, self.n_actions)
        last_idx = len(widths) - 1

        out = inputs
        for idx, width in enumerate(widths):
            # Explicit names are optional; Flax would otherwise auto-name the
            # layers "Dense_0", "Dense_1", ...
            out = nn.Dense(width, name=f'layers_{idx}')(out)
            # The final layer stays linear.
            if idx < last_idx:
                out = nn.relu(out)
        return out

class DuelingQNet(QNet):
    """Dueling variant of ``QNet`` with separate value and advantage heads.

    A shared Dense/ReLU trunk (``hidden_layers_*``) feeds two branches:

    * value branch: ``value_layers_*`` (Dense + ReLU each) followed by a
      1-unit Dense layer named ``value``;
    * advantage branch: ``advantage_layers_*`` (Dense + ReLU each) followed by
      an ``n_actions``-unit Dense layer named ``advantage``.

    The output is ``value + (advantage - mean(advantage))`` where the mean is
    taken over the last axis with ``keepdims=True``.

    NOTE(review): the trunk widths start with ``n_states``, so the first trunk
    layer is an ``n_states``-wide Dense on the raw input. ``nn.Dense`` infers
    its input width, so this may be unintentional -- confirm. It is kept here
    so parameter names and shapes do not change.
    """
    # Widths of the Dense layers in the value branch (before the 1-unit head).
    hidden_value_features: Sequence[int]
    # Widths of the Dense layers in the advantage branch (before the head).
    hidden_advantage_features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        """Compute the combined output from the value and advantage branches."""
        # Shared trunk: every layer is followed by ReLU.
        trunk = inputs
        for idx, width in enumerate((self.n_states, *self.hidden_features)):
            trunk = nn.relu(nn.Dense(width, name=f'hidden_layers_{idx}')(trunk))

        # Value branch, ending in a single linear unit.
        state_value = trunk
        for idx, width in enumerate(self.hidden_value_features):
            state_value = nn.relu(nn.Dense(width, name=f'value_layers_{idx}')(state_value))
        state_value = nn.Dense(1, name='value')(state_value)

        # Advantage branch, ending in one linear unit per action.
        adv = trunk
        for idx, width in enumerate(self.hidden_advantage_features):
            adv = nn.relu(nn.Dense(width, name=f'advantage_layers_{idx}')(adv))
        adv = nn.Dense(self.n_actions, name='advantage')(adv)

        # Subtract the mean over the last axis; the 1-wide value then
        # broadcasts across that axis when added.
        centered_adv = adv - jnp.mean(adv, axis=-1, keepdims=True)
        return state_value + centered_adv
  
# Smoke test for QNet: derive two keys from seed 0 -- key1 draws the dummy
# input batch, key2 initialises the parameters.
key1, key2 = random.split(random.key(0), num=2)
x = random.uniform(key1, shape=(4, 4))

model = QNet(
    n_states=4,
    hidden_features=[128, 512, 512, 128],
    n_actions=2,
)
params = model.init(key2, x)
y = model.apply(params, x)

# Show only the shape of every parameter leaf, then the forward-pass result.
print(
    'initialized parameter shapes:\n',
    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (4,), 'kernel': (4, 4)}, 'layers_1': {'bias': (128,), 'kernel': (4, 128)}, 'layers_2': {'bias': (512,), 'kernel': (128, 512)}, 'layers_3': {'bias': (512,), 'kernel': (512, 512)}, 'layers_4': {'bias': (128,), 'kernel': (512, 128)}, 'layers_5': {'bias': (2,), 'kernel': (128, 2)}}}
output:
 [[0.26753563 0.49408063]
 [0.17382324 0.3050771 ]
 [0.16901806 0.33530086]
 [0.06821223 0.16916744]]


In [7]:
# Smoke test for DuelingQNet: derive two keys from seed 0 -- key1 draws the
# dummy input batch, key2 initialises the parameters.
key1, key2 = random.split(random.key(0), num=2)
x = random.uniform(key1, shape=(4, 4))

model = DuelingQNet(
    n_states=4,
    hidden_features=[128, 512],
    hidden_value_features=[64],
    hidden_advantage_features=[64],
    n_actions=2,
)
params = model.init(key2, x)
y = model.apply(params, x)

# Show only the shape of every parameter leaf, then the forward-pass result.
print(
    'initialized parameter shapes:\n',
    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print('output:\n', y)

initialized parameter shapes:
 {'params': {'advantage': {'bias': (2,), 'kernel': (64, 2)}, 'advantage_layers_0': {'bias': (64,), 'kernel': (512, 64)}, 'hidden_layers_0': {'bias': (4,), 'kernel': (4, 4)}, 'hidden_layers_1': {'bias': (128,), 'kernel': (4, 128)}, 'hidden_layers_2': {'bias': (512,), 'kernel': (128, 512)}, 'value': {'bias': (1,), 'kernel': (64, 1)}, 'value_layers_0': {'bias': (64,), 'kernel': (512, 64)}}}
output:
 [[-0.00040921  0.24043036]
 [-0.00195015  0.11650497]
 [ 0.00692727  0.17678855]
 [-0.03291429  0.12385821]]


In [None]:
class TransitionBuffer:
    """Preallocated, fixed-capacity storage for transition data.

    Allocates zero-filled NumPy arrays with ``size`` rows for states, actions,
    rewards, done flags and next states, plus a write pointer (initially 0)
    and a "full" flag (initially False).

    NOTE(review): no method for inserting or sampling transitions is defined
    here; ``pointer`` and ``full`` are presumably meant for circular writes --
    confirm when those methods are added.

    Args:
        n_states: Number of columns in the state and next-state arrays.
        n_actions: Number of columns in the action array.
        size: Number of rows (capacity) of every array. Defaults to 1000.
    """

    def __init__(self, n_states, n_actions, size=1000):
        state_shape = (size, n_states)

        # Per-transition storage; float arrays use NumPy's default dtype.
        self._states = np.zeros(state_shape)
        self._actions = np.zeros((size, n_actions))
        self._rewards = np.zeros(size)
        self._dones = np.zeros(size, dtype=bool)
        self._next_states = np.zeros(state_shape)

        # Bookkeeping.
        self._size = size
        self._pointer = 0
        self._full = False

    @property
    def size(self):
        """Capacity given at construction (number of rows per array)."""
        return self._size

    @property
    def full(self):
        """Current value of the "full" flag (False right after construction)."""
        return self._full

    @property
    def pointer(self):
        """Current value of the write pointer (0 right after construction)."""
        return self._pointer
    

In [None]:
from optax import adam

class Controller:
    """Holds a model together with its optimizer and loss function.

    Args:
        model: The Flax module to be controlled.
        optimizer: Optimizer object; stored as-is.
        loss_fn: Loss callable; stored as-is.
    """

    def __init__(self, model: nn.Module, optimizer, loss_fn):
        self.model, self.optimizer, self.loss_fn = model, optimizer, loss_fn

In [None]:
class Runner:
    """Groups an environment, a controller and a replay buffer.

    Args:
        env: Environment object; stored as-is.
        controller: The ``Controller`` to use.
        replay_buffer: Buffer object; stored as-is.
    """

    def __init__(self, env, controller: Controller, replay_buffer):
        self.env, self.controller, self.replay_buffer = env, controller, replay_buffer

In [None]:
class Trainer:
    """Wraps the model that is to be trained.

    Args:
        model: The Flax module to train.
    """

    def __init__(self, model: nn.Module):
        self.model = model

    def train(self):
        """Placeholder: training is not implemented yet; does nothing and returns None."""

In [None]:
class Experiment:
    """Pairs a ``Runner`` with a ``Trainer``.

    Args:
        runner: The ``Runner`` instance to use.
        trainer: The ``Trainer`` instance to use.
    """

    def __init__(self, runner: Runner, trainer: Trainer):
        self.runner, self.trainer = runner, trainer

    def run(self):
        """Placeholder: the experiment loop is not implemented yet; does nothing and returns None."""