In [1]:
from typing import Sequence, List, Tuple, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

In [2]:
from hypernn.jax.embedding_module import FlaxEmbeddingModule
from hypernn.jax.weight_generator import FlaxWeightGenerator
from hypernn.jax.hypernet import FlaxHyperNetwork

In [3]:
class StaticEmbeddingModule(FlaxEmbeddingModule):
    """Input-independent embedding table built from a bias-free dense layer.

    Reads ``self.num_embeddings`` and ``self.embedding_dim`` from the parent
    ``FlaxEmbeddingModule`` fields.
    """

    def setup(self):
        # Applied to one-hot rows, a bias-free Dense layer simply selects rows
        # of its kernel, so the kernel plays the role of the embedding table.
        self.embedding = nn.Dense(features=self.embedding_dim, use_bias=False)

    def __call__(self):
        """Return one embedding row per index in ``range(num_embeddings)``."""
        row_ids = jnp.arange(self.num_embeddings)
        one_hot_rows = jax.nn.one_hot(row_ids, self.num_embeddings)
        return self.embedding(one_hot_rows)


In [4]:
class StaticWeightGenerator(FlaxWeightGenerator):
    """Two-layer perceptron that maps embeddings to chunks of target weights.

    The output width is ``self.hidden_dim`` (a field of the parent
    ``FlaxWeightGenerator``); the intermediate layer has 32 units.
    """

    def setup(self):
        self.dense1 = nn.Dense(features=32)
        self.dense2 = nn.Dense(features=self.hidden_dim)

    def __call__(self, embedding: jnp.array):
        """Apply Dense(32) -> ReLU -> Dense(hidden_dim) to ``embedding``."""
        hidden = nn.relu(self.dense1(embedding))
        return self.dense2(hidden)


In [5]:
class MLP(nn.Module):
    """Target network: two tanh hidden layers of width 256 and a 4-way bias-free head."""

    @nn.compact
    def __call__(self, x):
        """Map ``x`` to 4 output values per row."""
        # Layers are created in the same order as before so that Flax's
        # automatic sub-module naming (and hence the parameter tree) is unchanged.
        for width in (256, 256):
            x = nn.tanh(nn.Dense(width)(x))
        return nn.Dense(4, use_bias=False)(x)

In [6]:
# Instantiate the target network; it is passed to FlaxHyperNetwork further down,
# which generates this network's weights.
target_network = MLP()

In [14]:
import math
from typing import Any, Callable, List, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax._src.tree_util import PyTreeDef
from dataclasses import field
import functools

from hypernn.base_hypernet import BaseHyperNetwork
from hypernn.jax.embedding_module import FlaxEmbeddingModule, FlaxStaticEmbeddingModule
from hypernn.jax.utils import count_jax_params
from hypernn.jax.weight_generator import FlaxStaticWeightGenerator, FlaxWeightGenerator


@functools.partial(jax.jit, static_argnames="apply_fn")
def target_forward(apply_fn, param_tree, inputs):
    """Jit-compiled call of ``apply_fn(param_tree, inputs)``.

    ``apply_fn`` is a static argument, so it must be hashable and each distinct
    function triggers its own compilation.
    """
    outputs = apply_fn(param_tree, inputs)
    return outputs

def FlaxHyperNetwork(
    input_shape: Tuple[int, ...],
    target_network: nn.Module,
    embedding_module_constructor: Callable[
        [int, int], FlaxEmbeddingModule
    ] = FlaxStaticEmbeddingModule,
    weight_generator_constructor: Callable[
        [int, int], FlaxWeightGenerator
    ] = FlaxStaticWeightGenerator,
    embedding_dim: int = 100,
    num_embeddings: int = 3,
    hidden_dim: Optional[int] = None,
):
    """Build a Flax hypernetwork module that generates ``target_network``'s weights.

    Args:
        input_shape: Shape handed to ``count_jax_params`` together with
            ``target_network`` to obtain the target's variables.
        target_network: Module whose parameters are produced by the hypernetwork.
        embedding_module_constructor: Called as ``(embedding_dim, num_embeddings)``.
        weight_generator_constructor: Called as ``(embedding_dim, hidden_dim)``.
        embedding_dim: Width of each embedding.
        num_embeddings: Number of embeddings (may be increased, see below).
        hidden_dim: Number of values generated per embedding. When ``None`` it
            is derived from the target's parameter count.

    Returns:
        An (uninitialised) ``nn.Module`` instance; calling it returns
        ``(target_output, generated_param_list)``.
    """
    num_parameters, variables = count_jax_params(target_network, input_shape)
    flat_values, target_treedef = jax.tree_util.tree_flatten(variables)
    target_weight_shapes = [v.shape for v in flat_values]

    if hidden_dim is None:
        hidden_dim = math.ceil(num_parameters / num_embeddings)
        if hidden_dim != 0:
            # NOTE: remainder < hidden_dim, so `diff` is 1 whenever
            # num_parameters is not a multiple of hidden_dim, i.e. one extra
            # embedding is added in that case. This sizing rule is kept as-is
            # so existing parameter trees keep their shapes.
            remainder = num_parameters % hidden_dim
            if remainder > 0:
                diff = math.ceil(remainder / hidden_dim)
                num_embeddings += diff

    class FlaxHyperNetwork(nn.Module):
        # Module being parameterised and the tree structure of its variables.
        _target: nn.Module
        target_treedef: PyTreeDef
        num_parameters: int
        num_embeddings: int
        hidden_dim: int
        embedding_module_constructor: Callable[
            [int, int], FlaxEmbeddingModule
        ] = FlaxStaticEmbeddingModule
        weight_generator_constructor: Callable[
            [int, int], FlaxWeightGenerator
        ] = FlaxStaticWeightGenerator
        embedding_dim: int = 100
        # Leaf shapes of the target variables, in tree_flatten order.
        target_weight_shapes: Optional[List[Any]] = field(default_factory=list)

        def setup(self):
            self.embedding_module, self.weight_generator = self.get_networks()

        def get_networks(self) -> Tuple[FlaxEmbeddingModule, FlaxWeightGenerator]:
            """Instantiate the embedding module and the weight generator."""
            embedding_module = self.embedding_module_constructor(
                self.embedding_dim, self.num_embeddings
            )
            weight_generator = self.weight_generator_constructor(
                self.embedding_dim, self.hidden_dim
            )
            return embedding_module, weight_generator

        def generate_params(
            self, x: Optional[Any] = None, *args, **kwargs
        ) -> List[jnp.array]:
            """Generate one array per target leaf, shaped like that leaf.

            Raises:
                ValueError: If the weight generator yields fewer values than
                    the target leaves require.
            """
            embeddings = self.embedding_module()
            flat_params = self.weight_generator(embeddings).reshape(-1)

            # int(...) matters: np.prod(()) is the float 1.0, which cannot be
            # used as a slice bound for scalar leaves.
            leaf_sizes = [int(np.prod(shape)) for shape in self.target_weight_shapes]
            required = sum(leaf_sizes)
            # Array sizes are static, so this check also works while tracing.
            if flat_params.size < required:
                raise ValueError(
                    "Weight generator produced {} values but the target network "
                    "needs {}; increase num_embeddings or hidden_dim.".format(
                        flat_params.size, required
                    )
                )

            param_list = []
            offset = 0
            for shape, size in zip(self.target_weight_shapes, leaf_sizes):
                param_list.append(flat_params[offset : offset + size].reshape(shape))
                offset += size
            return param_list

        def __call__(
            self, x: Any, params: Optional[List[jnp.array]] = None
        ) -> Tuple[jnp.array, List[jnp.array]]:
            """Run the target on ``x`` using ``params`` (generated when ``None``).

            Returns:
                ``(target_output, params)``.
            """
            if params is None:
                params = self.generate_params()
            param_tree = jax.tree_util.tree_unflatten(self.target_treedef, params)
            return target_forward(self._target.apply, param_tree, x), params

    return FlaxHyperNetwork(
        target_network,
        target_treedef,
        num_parameters,
        num_embeddings,
        hidden_dim,
        embedding_module_constructor,
        weight_generator_constructor,
        embedding_dim,
        target_weight_shapes,
    )


In [15]:
# Build the hypernetwork for (1, 8)-shaped inputs, swapping in the custom
# StaticEmbeddingModule; the weight generator keeps the factory default.
hyper = FlaxHyperNetwork((1, 8), target_network, embedding_module_constructor=StaticEmbeddingModule)

In [16]:
import jax
import functools

@functools.partial(jax.jit, static_argnames="apply_fn")
def sample_actions(apply_fn, hypernetwork_params, obs):
    """Run the hypernetwork on one observation and return its first output.

    A leading axis of size 1 is prepended to ``obs`` before the call, and the
    second element returned by ``apply_fn`` is discarded.
    """
    batched_obs = jnp.expand_dims(jnp.array(obs), 0)
    variables = {'params': hypernetwork_params}
    network_output, _ = apply_fn(variables, batched_obs)
    return network_output


In [17]:
from tensorflow_probability.substrates import jax as tfp

def rollout(env, hypernetwork, hypernetwork_params, render=False, seed: int = 0) -> Tuple[list, list, list, list]:
    """Play one episode, sampling actions from the hypernetwork's logits.

    Args:
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(obs, reward, done, info)``.
        hypernetwork: Module whose ``apply`` is used to compute action logits.
        hypernetwork_params: Parameters passed to ``apply`` under ``'params'``.
        render: When true, collect ``env.render(mode="rgb_array")`` each step.
        seed: Seed for the PRNG key that drives action sampling.

    Returns:
        ``(observations, actions, rewards, rendereds)`` as lists with one entry
        per step (``rendereds`` is empty unless ``render`` is true).
    """
    rng = jax.random.PRNGKey(seed)
    observations, actions, rewards, rendereds = [], [], [], []
    try:
        obs = env.reset()
        done = False
        while not done:
            if render:
                rendereds.append(env.render(mode="rgb_array"))

            # sample_actions prepends the batch axis itself, so pass obs as-is.
            logits = sample_actions(hypernetwork.apply, hypernetwork_params, obs)

            # Use a fresh sub-key per step; reusing one key would draw the same
            # underlying randomness for every action in the episode.
            rng, step_rng = jax.random.split(rng)
            dist = tfp.distributions.Categorical(logits=logits)
            action = dist.sample(seed=step_rng).item()
            next_obs, r, done, _ = env.step(action)

            observations.append(obs)
            actions.append(action)
            rewards.append(r)

            obs = next_obs
    finally:
        # Close the environment even if the episode is interrupted.
        env.close()
    return observations, actions, rewards, rendereds

In [18]:
# Initialise the hypernetwork's parameters from a fixed PRNG key, using a dummy
# all-ones input of shape (1, 8).
rng = jax.random.PRNGKey(0)
params = hyper.init(rng, jnp.ones((1,8)))['params']


In [19]:
# Display the initialised parameter tree (the recorded output is truncated).
params

FrozenDict({
    embedding_module: {
        embedding: {
            kernel: DeviceArray([[ 0.471646  , -1.0702633 ,  0.46410823,  0.15693954,
                           0.20041348, -0.34952757, -0.49351946,  1.2054473 ,
                          -0.84482384, -0.3916509 , -0.5571263 ,  0.47865412,
                          -0.37104475, -0.21414787,  0.71070576, -0.46190038,
                          -0.48479974, -0.4342748 ,  0.24446385,  0.3088161 ,
                           0.5623318 , -0.1759971 , -0.8852003 , -0.16054398,
                           0.37777218,  0.0428249 ,  0.46296674, -0.2089282 ,
                          -1.2209946 , -0.5637566 , -0.02068712,  0.8296711 ,
                          -0.31478047,  0.80464214,  0.01968026, -0.32316324,
                           0.6855821 ,  0.27597412,  0.39884514,  0.00604566,
                           0.14924732,  1.2728313 , -0.22321522,  0.72782063,
                          -0.23503335, -0.12701213,  0.0195619 ,  0.48308682

In [20]:
import gym

# Create the LunarLander-v2 environment used by the rollout and training cells.
env = gym.make("LunarLander-v2")

In [22]:
# Run one episode with the freshly initialised (untrained) parameters; the cell
# output is the returned (observations, actions, rewards, rendereds) tuple.
rollout(env, hyper, params, render=False)

([array([-5.5437087e-04,  1.4023566e+00, -5.6164313e-02, -3.8060778e-01,
          6.4914080e-04,  1.2722072e-02,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([-1.0639190e-03,  1.3941774e+00, -5.1817141e-02, -3.6351767e-01,
          1.4957504e-03,  1.6934676e-02,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([-1.7573356e-03,  1.3863037e+00, -6.9304220e-02, -3.4994677e-01,
          1.4637661e-03, -6.3969108e-04,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([-2.5861741e-03,  1.3784418e+00, -8.2197130e-02, -3.4941489e-01,
          7.8579114e-04, -1.3560888e-02,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([-3.5046577e-03,  1.3713793e+00, -9.0732738e-02, -3.1389001e-01,
         -3.1786770e-04, -2.2075048e-02,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([-4.2401315e-03,  1.3651805e+00, -7.3315926e-02, -2.7550304e-01,
         -5.4817565e-04, -4.6068798e-03,  0.0000000e+00,  0.0000000e+00]

In [16]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import os
import tqdm

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    """Create a ``SummaryWriter`` in a timestamped run directory.

    The directory is ``<base_log_path>/<experiment_name>_<datetime.now()>``.
    A ready-to-copy ``tensorboard --logdir`` command is printed.

    Returns:
        The ``SummaryWriter`` (flushing every 10 seconds).
    """
    run_dir = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
    writer = SummaryWriter(run_dir, flush_secs=10)
    absolute_run_dir = os.path.join(os.getcwd(), run_dir)
    hint = "Follow tensorboard logs with: tensorboard --logdir '{}'".format(absolute_run_dir)
    print(hint)
    return writer


In [17]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, learning_rate, input_shape, target_network):
    """Creates initial `TrainState`.

    Args:
        rng: PRNG key used to initialise the hypernetwork parameters.
        learning_rate: Learning rate of the Adam optimiser.
        input_shape: Shape of the dummy input used both to build the
            hypernetwork and to initialise it (the notebook passes ``(1, 8)``,
            i.e. including the batch axis).
        target_network: Module whose weights the hypernetwork generates.

    Returns:
        A ``TrainState`` holding the hypernetwork's ``apply`` function, its
        initial parameters and the Adam optimiser state.
    """
    hypernet = FlaxHyperNetwork(input_shape, target_network)
    # Initialise with the caller-supplied shape instead of a hard-coded (1, 8),
    # so the function also works for targets with a different input size.
    params = hypernet.init(rng, jnp.ones(input_shape))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=hypernet.apply, params=params, tx=tx)

In [18]:
# Seed a root PRNG key and split off a dedicated key for parameter initialisation.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# Build the initial train state (learning rate 1e-4, input shape (1, 8)).
state = create_train_state(init_rng, 0.0001, (1,8), target_network)


In [19]:
@functools.partial(jax.jit, static_argnames="apply_fn")
def train_step(apply_fn, state, observations, actions, discounted_rewards):
    """Apply one policy-gradient update to ``state``.

    The loss is ``-sum(discounted_rewards * log_prob(actions))`` under a
    categorical distribution over the logits returned by ``apply_fn``.

    Returns:
        ``(updated_state, loss, grads)``.
    """
    def loss_fn(params):
        logits, _ = apply_fn({'params': params}, observations)
        policy = tfp.distributions.Categorical(logits=logits)
        weighted_log_probs = discounted_rewards * policy.log_prob(actions)
        # Logits are returned as auxiliary output alongside the scalar loss.
        return -1 * jnp.sum(weighted_log_probs), logits

    value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, _logits), grads = value_and_grad_fn(state.params)
    updated_state = state.apply_gradients(grads=grads)
    return updated_state, loss, grads


In [20]:
def discount_reward(rews, gamma: float = 0.99):
    """Compute discounted rewards-to-go for one episode.

    ``out[i] = rews[i] + gamma * out[i + 1]`` with ``out[n] = 0``.

    Args:
        rews: Sequence or array of per-step rewards.
        gamma: Discount factor.

    Returns:
        Floating-point array with the same length as ``rews``. Floating inputs
        keep their dtype; any other dtype is converted to float64.
    """
    rews = np.asarray(rews)
    # Integer (or bool) rewards would otherwise give an integer output buffer,
    # silently truncating every discounted value on assignment.
    if not np.issubdtype(rews.dtype, np.floating):
        rews = rews.astype(np.float64)

    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + gamma * (rtgs[i + 1] if i + 1 < n else 0)
    return rtgs

def reinforce(
        num_epochs,
        env,
        hypernetwork,
        target_network,
        lr: float = 0.0001,
        gamma: float = 0.99,
    ):
    """Train ``hypernetwork`` with REINFORCE, one episode per epoch.

    Args:
        num_epochs: Number of rollout-plus-update iterations.
        env: Environment passed to ``rollout``.
        hypernetwork: Hypernetwork module used for both rollouts and gradients.
        target_network: Kept for backward compatibility; no longer read here,
            because parameters are now initialised from ``hypernetwork``
            itself.
        lr: Adam learning rate.
        gamma: Reward discount factor.

    Returns:
        The final ``TrainState`` (previously the trained state was discarded).
    """
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    # Initialise the parameters from the very module that is rolled out and
    # differentiated below. Building a separate default hypernetwork here can
    # yield a parameter tree that does not match `hypernetwork` whenever it was
    # constructed with non-default options.
    params = hypernetwork.init(init_rng, jnp.ones((1, 8)))['params']
    state = train_state.TrainState.create(
        apply_fn=hypernetwork.apply, params=params, tx=optax.adam(lr))

    bar = tqdm.tqdm(np.arange(num_epochs))

    for i in bar:
        # Vary the seed per episode so action sampling is not identical
        # from one episode to the next.
        observations, actions, rewards, _ = rollout(
            env, hypernetwork, state.params, seed=int(i)
        )

        # Standardise the rewards-to-go (zero mean, unit variance).
        discounted_rewards = discount_reward(np.array(rewards), gamma)
        discounted_rewards = discounted_rewards - np.mean(discounted_rewards)
        discounted_rewards = discounted_rewards / (
            np.std(discounted_rewards) + 1e-10
        )

        observations = jnp.array(observations)
        actions = jnp.array(actions)
        discounted_rewards = jnp.array(discounted_rewards)

        state, loss, grads = train_step(hypernetwork.apply, state, observations, actions, discounted_rewards)

        bar.set_description('Loss: {}, Sum Reward: {}'.format(loss.item(), np.sum(rewards)))

    return state


In [21]:
from jax.lib import xla_bridge
# Print the platform name of the active XLA backend (recorded output: "gpu").
# NOTE(review): jax.default_backend() looks like the public equivalent — confirm
# against the installed JAX version before switching.
print(xla_bridge.get_backend().platform)


gpu


In [22]:
# Long-running: 10000 training epochs. The recorded run averaged ~3.3 s/it and
# was interrupted manually after 31 iterations.
reinforce(10000, env, hyper, target_network)

Loss: 77.35972595214844, Sum Reward: -848.3317884003933:   0%|          | 31/10000 [01:41<9:05:47,  3.28s/it]     


KeyboardInterrupt: 