In [409]:
from typing import Any, List, Optional, Sequence, Tuple

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

In [410]:
def count_jax_params(model: "nn.Module", input_shape: Optional[Tuple[int, ...]] = None, inputs: Optional[Any] = None) -> Tuple[int, Any]:
    """Initialise ``model`` once and count the scalars in its ``'params'`` collection.

    Args:
        model: module exposing ``init(rng, inputs)`` (e.g. a Flax module).
        input_shape: shape of a zeros input used for initialisation when
            ``inputs`` is not given.
        inputs: explicit sample input passed as the single positional argument
            to ``model.init``; takes precedence over ``input_shape``.

    Returns:
        ``(num_params, variables)``. ``num_params`` is the total number of scalar
        entries across all leaves of ``variables['params']``. ``variables`` is the
        full variable tree returned by ``model.init`` (all collections).

    Raises:
        ValueError: if neither ``input_shape`` nor ``inputs`` is provided.
    """
    if input_shape is None and inputs is None:
        raise ValueError("Input shape or inputs must be specified")
    if inputs is None:
        inputs = jnp.zeros(input_shape)
    # Fixed key: only the shapes matter here, not the initial values.
    variables = jax.lax.stop_gradient(model.init(jax.random.PRNGKey(0), inputs))

    # Scalars (shape ()) contribute 1, since prod(()) == 1.
    num_params = sum(
        int(np.prod(np.shape(leaf)))
        for leaf in jax.tree_util.tree_leaves(variables["params"])
    )
    return num_params, variables



In [411]:
import abc

class FlaxEmbeddingModule(nn.Module, metaclass=abc.ABCMeta):
    """Abstract Flax module that produces a table of embeddings.

    Subclasses are constructed positionally as ``(embedding_dim, num_embeddings)``,
    so the field order below is part of the interface.
    """

    # Expected width of each embedding vector.
    embedding_dim: int
    # Expected number of embedding vectors produced by one forward pass.
    num_embeddings: int

    @abc.abstractmethod
    def __call__(self, *args, **kwargs):
        """Run the forward pass and return the embeddings."""



In [412]:
class EmbeddingModule(FlaxEmbeddingModule):
    """Trainable embedding table implemented as a bias-free Dense layer.

    Pushing the identity matrix through a Dense layer without bias yields the
    layer's kernel. The output is therefore a learnable
    ``(num_embeddings, embedding_dim)`` table.
    """

    def setup(self):
        self.embedding = nn.Dense(features=self.embedding_dim, use_bias=False)

    def __call__(self):
        # Row i of the identity matrix is the one-hot selector for embedding i.
        selectors = jnp.eye(self.num_embeddings)
        return self.embedding(selectors)


In [413]:
import abc
import jax.numpy as jnp

class FlaxWeightGenerator(nn.Module, metaclass=abc.ABCMeta):
    """Abstract Flax module that maps embeddings to generated weight values.

    Subclasses are constructed positionally as ``(embedding_dim, hidden_dim)``,
    so the field order below is part of the interface.
    """

    # Expected width of the incoming embeddings.
    embedding_dim: int
    # Number of weight values generated per embedding (output width).
    hidden_dim: int

    @abc.abstractmethod
    def __call__(self, embedding: jnp.array, *args, **kwargs):
        """
            Forward pass mapping ``embedding`` to generated weight values
            (one ``hidden_dim``-wide row per embedding).
        """



In [414]:
class StaticWeightGenerator(FlaxWeightGenerator):
    """Two-layer MLP weight generator: Dense -> ReLU -> Dense(hidden_dim).

    Each input embedding row is mapped independently to ``hidden_dim`` values.
    """

    # Width of the first Dense layer. The default of 32 preserves the original
    # architecture and parameter tree (submodule names dense1/dense2 unchanged).
    intermediate_dim: int = 32

    def setup(self):
        self.dense1 = nn.Dense(self.intermediate_dim)
        self.dense2 = nn.Dense(self.hidden_dim)

    def __call__(self, embedding: jnp.array):
        hidden = nn.relu(self.dense1(embedding))
        return self.dense2(hidden)


In [415]:
from typing import Callable, Optional, Tuple, Any, List
import math
from hypernn.base_hypernet import BaseHyperNetwork

def FlaxHyperNetwork(
    input_shape: Tuple[int, ...],
    target_network: nn.Module,
    embedding_module_constructor: Callable[[int, int], FlaxEmbeddingModule] = EmbeddingModule,
    weight_generator_constructor: Callable[[int, int], FlaxWeightGenerator] = StaticWeightGenerator,
    embedding_dim: int = 100,
    num_embeddings: int = 3,
    hidden_dim: Optional[int] = None
):
    """Build a Flax hypernetwork module that generates every variable of ``target_network``.

    The returned module owns an embedding module and a weight generator.
    Generating weights produces a ``(num_embeddings, hidden_dim)`` output, which is
    flattened and sliced into the shapes of the target's variables. The variables
    are those returned by ``target_network.init`` on a zeros input of
    ``input_shape``. The target is then applied with the generated variables.

    Args:
        input_shape: shape of a sample input used to initialise ``target_network``
            and discover its variable shapes.
        target_network: the module whose variables are generated.
        embedding_module_constructor: called as ``(embedding_dim, num_embeddings)``.
        weight_generator_constructor: called as ``(embedding_dim, hidden_dim)``.
        embedding_dim: size of each embedding vector.
        num_embeddings: requested number of embeddings. It may be increased so
            that enough values are generated to cover the target's parameters.
        hidden_dim: values generated per embedding. If ``None`` it is derived
            from the target's parameter count.

    Returns:
        An unbound instance of the inner ``FlaxHyperNetwork`` module.
    """
    class FlaxHyperNetwork(nn.Module, BaseHyperNetwork):
        # Module whose variables this hypernetwork generates.
        _target: nn.Module
        embedding_module_constructor: Callable[[int, int], FlaxEmbeddingModule] = EmbeddingModule
        weight_generator_constructor: Callable[[int, int], FlaxWeightGenerator] = StaticWeightGenerator
        embedding_dim: int = 100

        def setup(self):
            # NOTE(review): this initialises the target inside setup, presumably on
            # every init/apply of this module. Consider computing it once outside.
            self.num_parameters, variables = count_jax_params(self._target, input_shape)
            self.setup_dims()
            self.embedding_module, self.weight_generator = self.get_networks()

            # NOTE(review): the treedef/shapes below cover ALL collections in
            # `variables`, while num_parameters counts only 'params'. Targets with
            # extra collections (e.g. batch_stats) may need more generated values.
            _value_flat, self.target_treedef = jax.tree_util.tree_flatten(variables)
            self.target_weight_shapes = [v.shape for v in _value_flat]

        @nn.nowrap
        def setup_dims(self):
            """Resolve num_embeddings/hidden_dim so a positive hidden_dim covers num_parameters.

            Guarantees ``num_embeddings * hidden_dim >= num_parameters`` whenever
            the resolved ``hidden_dim`` is positive.
            """
            self.num_embeddings = num_embeddings
            self.hidden_dim = hidden_dim
            if self.hidden_dim is None:
                # ceil(P / E) * E >= P already. The remainder adjustment adds one
                # spare embedding; it is kept so existing parameter trees keep
                # their shapes.
                self.hidden_dim = math.ceil(self.num_parameters / self.num_embeddings)
                if self.hidden_dim != 0:
                    remainder = self.num_parameters % self.hidden_dim
                    if remainder > 0:
                        diff = math.ceil(remainder / self.hidden_dim)
                        self.num_embeddings += diff
            elif self.hidden_dim > 0:
                # A user-supplied hidden_dim must still yield enough values.
                # Otherwise generate_params slices past the end of the generated
                # vector and the final reshape fails.
                required = math.ceil(self.num_parameters / self.hidden_dim)
                self.num_embeddings = max(self.num_embeddings, required)

        @nn.nowrap
        def get_networks(self) -> Tuple[FlaxEmbeddingModule, FlaxWeightGenerator]:
            """Instantiate the embedding module and weight generator from the resolved dims."""
            embedding_module = self.embedding_module_constructor(
                self.embedding_dim, self.num_embeddings
            )
            weight_generator = self.weight_generator_constructor(
                self.embedding_dim, self.hidden_dim
            )
            return embedding_module, weight_generator

        def generate_params(self, x: Optional[Any] = None, *args, **kwargs) -> List[jnp.array]:
            """Generate the target's variables as a flat list ordered like ``target_weight_shapes``.

            ``x`` and any extra arguments are accepted for interface compatibility
            but are not used.
            """
            embeddings = self.embedding_module()
            params = self.weight_generator(embeddings).reshape(-1)
            param_list = []
            curr = 0
            for shape in self.target_weight_shapes:
                num_params = int(np.prod(shape))
                param_list.append(params[curr:curr + num_params].reshape(shape))
                curr += num_params
            return param_list

        def __call__(self, x: Any, params: Optional[List[jnp.array]] = None) -> Tuple[jnp.array, List[jnp.array]]:
            """Apply the target to ``x`` using ``params``, generating them when not supplied.

            Returns:
                ``(target_output, params)``, where ``params`` is the flat list that
                was used (so callers can reuse it without regenerating).
            """
            if params is None:
                params = self.generate_params(x)
            param_tree = jax.tree_util.tree_unflatten(self.target_treedef, params)
            return self._target.apply(param_tree, x), params

    return FlaxHyperNetwork(target_network, embedding_module_constructor, weight_generator_constructor, embedding_dim)


In [416]:
class MLP(nn.Module):
    """Two-hidden-layer tanh MLP used as the hypernetwork's target (policy) network.

    The final layer has no bias. Its outputs are used as Categorical logits by
    the training code.
    """

    # Width of each of the two hidden Dense layers.
    hidden_features: int = 256
    # Size of the output layer (e.g. number of discrete actions).
    num_outputs: int = 4

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_features)(x)
        x = nn.tanh(x)
        x = nn.Dense(self.hidden_features)(x)
        x = nn.tanh(x)
        x = nn.Dense(self.num_outputs, use_bias=False)(x)
        return x

In [417]:
# Policy network whose variables the hypernetwork generates (4 output logits).
target_network: nn.Module = MLP()

In [418]:
# Hypernetwork sized for a batch of one 8-dimensional observation.
hyper = FlaxHyperNetwork(input_shape=(1, 8), target_network=target_network)

In [446]:
from tensorflow_probability.substrates import jax as tfp

def rollout(env, hypernetwork, hypernetwork_params, render=False, seed: int = 0) -> Tuple[List[Any], List[int], List[float], List[Any]]:
    """Run one episode in ``env`` with actions sampled from the hypernetwork's policy.

    The target weights are generated once up front and reused for every step.
    ``env`` is closed when the episode ends.

    Args:
        env: gym-style env with ``reset()`` -> obs and ``step(a)`` -> (obs, r, done, info).
        hypernetwork: module whose ``apply(variables, x, params)`` returns (logits, params).
        hypernetwork_params: the hypernetwork's ``'params'`` collection.
        render: if True, collect ``env.render(mode="rgb_array")`` frames.
        seed: seed for the action-sampling PRNG.

    Returns:
        ``(observations, actions, rewards, rendered_frames)`` as lists, one entry per step.
    """
    variables = {'params': hypernetwork_params}
    rng = jax.random.PRNGKey(seed)
    obs = env.reset()
    # Dummy batch only needed to trigger weight generation; shaped like one observation.
    dummy = jnp.zeros((1,) + np.shape(obs))
    _, target_params = hypernetwork.apply(variables, dummy)

    done = False
    observations, actions, rewards, rendereds = [], [], [], []
    while not done:
        if render:
            rendereds.append(env.render(mode="rgb_array"))

        logits, _ = hypernetwork.apply(variables, jnp.expand_dims(jnp.asarray(obs), 0), target_params)
        # Fresh subkey per step. Reusing one key would apply the same sampling
        # noise at every step and correlate the actions.
        rng, step_rng = jax.random.split(rng)
        action = jax.random.categorical(step_rng, logits).item()
        next_obs, r, done, _ = env.step(action)

        observations.append(obs)
        actions.append(action)
        rewards.append(r)

        obs = next_obs

    env.close()
    return observations, actions, rewards, rendereds

In [447]:
# Initialise the hypernetwork with a dummy (1, 8) batch. Keep only the
# 'params' collection, which is what `rollout` expects.
rng = jax.random.PRNGKey(0)
params = hyper.init(rng, jnp.ones((1,8)))['params']


In [435]:
import gym

# Discrete-action environment with 8-dimensional observations (matches the (1, 8) input shape).
env = gym.make("LunarLander-v2")

In [436]:
rollout(env=env, hypernetwork=hyper, hypernetwork_params=params, render=False)

([array([ 1.7147065e-04,  1.4184132e+00,  1.7347123e-02,  3.3302727e-01,
         -1.9183687e-04, -3.9293827e-03,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([ 2.40898138e-04,  1.42650437e+00,  7.63659459e-03,  3.59606147e-01,
         -8.76591250e-04, -1.36952475e-02,  0.00000000e+00,  0.00000000e+00],
        dtype=float32),
  array([ 1.8815995e-04,  1.4345708e+00, -3.9888453e-03,  3.5850966e-01,
         -2.1495728e-03, -2.5461977e-02,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([-1.7642975e-05,  1.4431829e+00, -1.8539604e-02,  3.8275427e-01,
         -4.1639209e-03, -4.0291034e-02,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([-8.1443788e-05,  1.4526235e+00, -5.0052027e-03,  4.1957754e-01,
         -5.5167424e-03, -2.7058903e-02,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([-2.5501251e-04,  1.4624420e+00, -1.5443854e-02,  4.3637249e-01,
         -7.4127149e-03, -3.7923016e-02,  0.0000000e+00,  0.0000

In [448]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import os
import tqdm

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    """Create a TensorBoard ``SummaryWriter`` in a fresh, timestamped run directory.

    Args:
        experiment_name: prefix for the run directory name.
        base_log_path: parent directory for all runs. Relative paths resolve
            against the current working directory.

    Returns:
        A ``SummaryWriter`` that flushes every 10 seconds.
    """
    # str(datetime.now()) contains a space and ':' characters, which are awkward
    # in shells and invalid in Windows paths. Use a filesystem-safe timestamp
    # that keeps microseconds so back-to-back runs get distinct directories.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    log_path = os.path.join(base_log_path, "{}_{}".format(experiment_name, timestamp))
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.abspath(log_path)
    print(
        "Follow tensorboard logs with: tensorboard --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [450]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, learning_rate, input_shape, target_network):
    """Creates initial `TrainState` for a freshly built hypernetwork.

    Args:
        rng: PRNG key for parameter initialisation.
        learning_rate: Adam learning rate.
        input_shape: shape of a sample input batch for ``target_network``.
        target_network: module whose variables the hypernetwork generates.

    Returns:
        A ``TrainState`` whose ``apply_fn`` is the hypernetwork's ``apply`` and
        whose ``params`` are its ``'params'`` collection, optimised with Adam.
    """
    hypernet = FlaxHyperNetwork(input_shape, target_network)
    # Initialise with the same input shape the hypernetwork was sized for.
    params = hypernet.init(rng, jnp.ones(input_shape))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=hypernet.apply, params=params, tx=tx)

In [451]:
# Split the fixed root key; `init_rng` seeds parameter initialisation.
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))

state = create_train_state(init_rng, learning_rate=0.0001, input_shape=(1, 8), target_network=target_network)


In [452]:
# Inspect the freshly initialised hypernetwork parameters.
state.params

FrozenDict({
    embedding_module: {
        embedding: {
            kernel: DeviceArray([[ 0.2794615 ,  0.614751  , -0.3520885 ,  0.07950468,
                           0.29180396,  0.6010681 , -0.06202558, -0.8630064 ,
                          -0.06942614,  0.35426682, -0.5740729 ,  1.0088627 ,
                          -0.7688579 , -0.61084026,  0.20705049, -0.57455933,
                           0.50739807, -0.43176562, -0.1687114 ,  0.2547729 ,
                           0.5349665 ,  0.12229237, -0.61539537,  0.2644072 ,
                           0.6914668 , -0.41548407, -0.14418218, -0.1401382 ,
                          -0.52232295,  0.58492476, -0.74866724,  0.57164043,
                           0.4839978 , -0.93924844, -0.70880425,  0.08879238,
                          -0.07147714,  0.48945302, -0.4213996 , -0.01919367,
                          -0.21136707,  0.20323907, -0.3379675 , -0.53174704,
                          -0.198072  ,  0.09982792,  0.86448824,  0.86292523

In [467]:
@jax.jit
def train_step(state, observations, actions, discounted_rewards):
    """One REINFORCE gradient step on the hypernetwork parameters.

    The loss is ``-sum(discounted_rewards * log pi(actions | observations))``,
    where the policy logits come from ``state.apply_fn``.

    Args:
        state: a TrainState-like object with ``params``, ``apply_fn`` and
            ``apply_gradients``.
        observations: batch of observations fed to the policy.
        actions: integer actions, one per row of logits.
        discounted_rewards: per-step weights (e.g. normalised returns).

    Returns:
        ``(new_state, loss, grads)``.
    """
    def loss_fn(params):
        # Use the state's own apply_fn instead of a global module, so the step
        # always matches the network the state was created for.
        logits, _ = state.apply_fn({'params': params}, observations)
        # Categorical log-probability of the taken actions.
        log_probs = jnp.take_along_axis(
            jax.nn.log_softmax(logits), jnp.expand_dims(actions, -1), axis=-1
        )[..., 0]
        loss = -jnp.sum(discounted_rewards * log_probs)
        return loss, logits

    (loss, _), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss, grads


In [465]:
def discount_reward(rews, gamma: float = 0.99):
    """Compute discounted rewards-to-go ``G_t = r_t + gamma * G_{t+1}`` (with ``G_n = 0``).

    Args:
        rews: sequence or 1-D array of per-step rewards.
        gamma: discount factor.

    Returns:
        A floating-point array of the same length as ``rews``. Integer rewards
        are promoted so the discounted sums are not truncated; float inputs keep
        their dtype.
    """
    rews = np.asarray(rews)
    n = len(rews)
    # np.zeros_like on integer rewards would silently truncate each G_t.
    rtgs = np.zeros(rews.shape, dtype=np.promote_types(rews.dtype, np.float32))
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + gamma * (rtgs[i + 1] if i + 1 < n else 0)
    return rtgs

def reinforce(
        num_epochs,
        env,
        hypernetwork,
        target_network,
        lr: float = 0.0001,
        gamma: float = 0.99,
        input_shape: Tuple[int, ...] = (1, 8),
    ):
    """Train a hypernetwork-generated policy with vanilla REINFORCE.

    Each epoch collects one episode with ``rollout``, turns its rewards into
    normalised discounted returns and takes one ``train_step``.

    Args:
        num_epochs: number of episodes / gradient steps.
        env: environment passed to ``rollout``.
        hypernetwork: module ``rollout`` uses to act. Its parameter structure must
            match the network built by ``create_train_state``.
        target_network: target module handed to ``create_train_state``.
        lr: Adam learning rate.
        gamma: discount factor.
        input_shape: sample observation batch shape used to initialise the hypernetwork.

    Returns:
        The final train state.
    """
    _, init_rng = jax.random.split(jax.random.PRNGKey(0))
    state = create_train_state(init_rng, lr, input_shape, target_network)

    bar = tqdm.tqdm(range(num_epochs))
    for epoch in bar:
        # Distinct sampling seed per episode. A constant seed would replay the
        # same sampling noise every episode.
        observations, actions, rewards, _ = rollout(env, hypernetwork, state.params, seed=epoch)

        discounted_rewards = discount_reward(np.array(rewards), gamma)
        discounted_rewards = discounted_rewards - np.mean(discounted_rewards)
        discounted_rewards = discounted_rewards / (np.std(discounted_rewards) + 1e-10)

        # NOTE: train_step is jitted, so each new episode length triggers a re-trace.
        state, loss, _ = train_step(
            state,
            jnp.array(observations),
            jnp.array(actions),
            jnp.array(discounted_rewards),
        )

        bar.set_description('Loss: {}, Sum Reward: {}'.format(loss.item(), np.sum(rewards)))

    return state


In [466]:
reinforce(num_epochs=10000, env=env, hypernetwork=hyper, target_network=target_network)

  0%|          | 0/10000 [00:03<?, ?it/s]


ValueError: All input arrays must have the same shape.