In [1]:
from typing import Sequence, List, Tuple, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

In [2]:
from hypernn.jax.embedding_module import FlaxEmbeddingModule
from hypernn.jax.weight_generator import FlaxWeightGenerator
from hypernn.jax.hypernet import FlaxHyperNetwork

In [3]:
class StaticEmbeddingModule(FlaxEmbeddingModule):
    """Input-independent embedding module.

    Calling the module takes no arguments and returns one learned embedding
    row per embedding index, i.e. an array of shape
    ``(num_embeddings, embedding_dim)``.
    """

    def setup(self):
        # A bias-free dense layer applied to one-hot rows simply selects rows
        # of its kernel, so it acts as a learnable embedding table.
        self.embedding = nn.Dense(self.embedding_dim, use_bias=False)

    def __call__(self):
        all_ids = jnp.arange(self.num_embeddings)
        one_hot_rows = jax.nn.one_hot(all_ids, self.num_embeddings)
        return self.embedding(one_hot_rows)


In [4]:
class StaticWeightGenerator(FlaxWeightGenerator):
    """Two-layer MLP that maps embeddings to generated weight chunks.

    The hidden layer has 32 units with a ReLU non-linearity; the output layer
    has ``self.hidden_dim`` units.
    """

    def setup(self):
        self.dense1 = nn.Dense(32)
        self.dense2 = nn.Dense(self.hidden_dim)

    def __call__(self, embedding: jnp.array):
        # dense1 -> ReLU -> dense2, written as a single composed expression.
        return self.dense2(nn.relu(self.dense1(embedding)))


In [5]:
class MLP(nn.Module):
    """Target network: two 256-unit tanh layers followed by a 4-unit, bias-free output layer."""

    @nn.compact
    def __call__(self, x):
        # Layers are created in the same order as before, so Flax's automatic
        # names (Dense_0, Dense_1, Dense_2) are unchanged.
        for width in (256, 256):
            x = nn.tanh(nn.Dense(width)(x))
        return nn.Dense(4, use_bias=False)(x)

In [10]:
from tensorflow_probability.substrates import jax as tfp

import jax
import functools

@functools.partial(jax.jit, static_argnames="apply_fn")
def sample_actions(apply_fn, hypernetwork_params, obs):
    """Jitted forward pass of ``apply_fn`` on a single observation.

    Args:
        apply_fn: Callable invoked as ``apply_fn({'params': ...}, batch)`` and
            returning a 2-tuple; treated as a static argument by ``jax.jit``.
        hypernetwork_params: Parameters forwarded under the ``'params'`` key.
        obs: Observation; a leading axis of size 1 is added before the call.

    Returns:
        The first element of the tuple returned by ``apply_fn``.
    """
    batched_obs = jnp.array(obs)[jnp.newaxis, ...]
    first_output, _ = apply_fn({'params': hypernetwork_params}, batched_obs)
    return first_output

def rollout(env, hypernetwork, hypernetwork_params, render=False, seed: int = 0) -> Tuple[list, list, list, list]:
    """Play one episode in ``env``, sampling actions from the hypernetwork policy.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` returning a
            4-tuple ``(obs, reward, done, info)``, ``render(mode=...)`` and
            ``close()``.
        hypernetwork: Module whose ``apply`` is used as the policy forward pass.
        hypernetwork_params: Parameters passed to ``hypernetwork.apply``.
        render: If true, an ``rgb_array`` frame is captured before every step.
        seed: Seed for the PRNG key used to sample actions.

    Returns:
        ``(observations, actions, rewards, rendereds)`` — four lists with one
        entry per step (``rendereds`` is empty unless ``render`` is true).
    """
    rng = jax.random.PRNGKey(seed)
    obs = env.reset()
    done = False
    observations, actions, rewards, rendereds = [], [], [], []
    while not done:
        if render:
            rendereds.append(env.render(mode="rgb_array"))

        # sample_actions already adds the batch axis, so pass the raw
        # observation instead of batching it a second time here.
        logits = sample_actions(hypernetwork.apply, hypernetwork_params, obs)

        # Split off a fresh key for every step; reusing one key would draw the
        # same random noise for every action in the episode.
        rng, sample_rng = jax.random.split(rng)
        dist = tfp.distributions.Categorical(logits=logits)
        action = dist.sample(seed=sample_rng).item()
        next_obs, r, done, _ = env.step(action)

        observations.append(obs)
        actions.append(action)
        rewards.append(r)

        obs = next_obs

    env.close()
    return observations, actions, rewards, rendereds

In [15]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import os
import tqdm

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    """Create a ``SummaryWriter`` for a new, timestamped run directory.

    The run directory is ``<base_log_path>/<experiment_name>_<now>``, where
    ``<now>`` is the string form of ``datetime.now()``. The absolute path
    (relative to the current working directory) is printed as a ready-to-use
    ``tensorboard --logdir`` hint.

    Returns:
        The ``SummaryWriter`` writing to that directory (flushed every 10 s).
    """
    run_dir = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
    writer = SummaryWriter(run_dir, flush_secs=10)
    hint = "Follow tensorboard logs with: tensorboard --logdir '{}'".format(
        os.path.join(os.getcwd(), run_dir)
    )
    print(hint)
    return writer


In [16]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, hypernet, learning_rate, input_shape, target_network):
    """Create the initial ``TrainState`` for the hypernetwork.

    Args:
        rng: PRNG key used to initialise the hypernetwork parameters.
        hypernet: Module providing ``init`` and ``apply``.
        learning_rate: Learning rate for the Adam optimiser.
        input_shape: Shape of the dummy input used to trace ``hypernet.init``.
        target_network: Unused; kept so existing callers keep working.

    Returns:
        A ``TrainState`` holding ``hypernet.apply``, the freshly initialised
        parameters and an Adam optimiser.
    """
    # Initialise with the caller-supplied shape rather than a hard-coded (1, 8).
    dummy_input = jnp.ones(input_shape)
    params = hypernet.init(rng, dummy_input)['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=hypernet.apply, params=params, tx=tx)

In [None]:
@functools.partial(jax.jit, static_argnames="apply_fn")
def train_step(apply_fn, state, observations, actions, discounted_rewards):
    """Apply one policy-gradient update.

    The loss is the negated sum over steps of
    ``discounted_rewards * log_prob(action)`` under a categorical distribution
    parameterised by the logits returned from ``apply_fn``.

    Args:
        apply_fn: Forward function, called as ``apply_fn({'params': p}, observations)``
            and returning a 2-tuple whose first element is the logits; static under jit.
        state: Train state providing ``params`` and ``apply_gradients``.
        observations: Batch of observations fed to ``apply_fn``.
        actions: Actions whose log-probabilities are evaluated.
        discounted_rewards: Per-step weights for the log-probabilities.

    Returns:
        ``(new_state, loss, grads)``.
    """
    def loss_fn(params):
        logits, _ = apply_fn({'params': params}, observations)
        policy = tfp.distributions.Categorical(logits=logits)
        weighted_log_probs = discounted_rewards * policy.log_prob(actions)
        # Logits ride along as the auxiliary output of value_and_grad.
        return -jnp.sum(weighted_log_probs), logits

    value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, _aux_logits), grads = value_and_grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss, grads


In [None]:
def discount_reward(rews, gamma: float = 0.99):
    """Compute discounted rewards-to-go.

    ``rtgs[i] = rews[i] + gamma * rtgs[i + 1]``, with the value past the last
    step taken as 0.

    Args:
        rews: Sequence or array of per-step rewards.
        gamma: Discount factor.

    Returns:
        A NumPy array with the same shape as ``rews``. Floating-point inputs
        keep their dtype; any other dtype (e.g. integer rewards) is promoted
        to ``float64`` so the discounted values are not truncated.
    """
    rews = np.asarray(rews)
    n = len(rews)
    # zeros_like would inherit an integer dtype and silently drop the
    # fractional part of every discounted value.
    out_dtype = rews.dtype if np.issubdtype(rews.dtype, np.floating) else np.float64
    rtgs = np.zeros(rews.shape, dtype=out_dtype)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + gamma * (rtgs[i + 1] if i + 1 < n else 0)
    return rtgs

def reinforce(
        num_epochs,
        env,
        hypernetwork,
        target_network,
        lr: float = 0.0001,
        gamma: float = 0.99,
    ):
    """Train the hypernetwork policy with REINFORCE, one episode per epoch.

    Each epoch runs one rollout, converts the rewards into discounted
    rewards-to-go, standardises them (zero mean, unit std) and applies a
    single gradient step. Progress is shown on a tqdm bar.

    Args:
        num_epochs: Number of episodes / gradient steps.
        env: Environment passed to ``rollout``.
        hypernetwork: Policy module (provides ``init`` / ``apply``).
        target_network: Forwarded to ``create_train_state``.
        lr: Adam learning rate.
        gamma: Discount factor for the rewards-to-go.
    """
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)
    # create_train_state takes the PRNG key first and the module second.
    state = create_train_state(init_rng, hypernetwork, lr, (1, 8), target_network)

    bar = tqdm.tqdm(np.arange(num_epochs))

    for i in bar:
        # Seed each episode differently; with the default seed every episode
        # would sample its actions from the same random stream.
        observations, actions, rewards, _ = rollout(
            env, hypernetwork, state.params, seed=int(i)
        )

        discounted_rewards = discount_reward(np.array(rewards), gamma)
        discounted_rewards = discounted_rewards - np.mean(discounted_rewards)
        # The small epsilon guards against division by zero for constant returns.
        discounted_rewards = discounted_rewards / (
            np.std(discounted_rewards) + 1e-10
        )

        observations = jnp.array(observations)
        actions = jnp.array(actions)
        discounted_rewards = jnp.array(discounted_rewards)

        state, loss, _ = train_step(hypernetwork.apply, state, observations, actions, discounted_rewards)

        bar.set_description('Loss: {}, Sum Reward: {}'.format(loss.item(), np.sum(rewards)))


In [None]:
# Show which backend JAX is running on (e.g. cpu / gpu / tpu).
# Uses the public jax.default_backend() API instead of the private,
# deprecated jax.lib.xla_bridge module.
print(jax.default_backend())


In [None]:
import gym

# Build the environment, the target policy network and the hypernetwork that
# generates its weights, then train with REINFORCE.
env = gym.make("LunarLander-v2")
target_network = MLP()
hyper = FlaxHyperNetwork(
    (1, 8),  # presumably the input shape: a batch of one 8-dim observation — confirm
    target_network,
    embedding_module_constructor=StaticEmbeddingModule,
    embedding_dim=32,
    num_embeddings=512,
)

reinforce(
    num_epochs=10000,
    env=env,
    hypernetwork=hyper,
    target_network=target_network,
)