In [None]:
# Uncomment this to install dependencies for this notebook
# !pip install gym
# !pip install tqdm
# !pip install tensorboard
# !pip install matplotlib
# !pip install JSAnimation

In [None]:
# JAX GPU memory preallocation is disabled by the setting below.
# Comment these lines out to enable jax gpu preallocation, which might lead to memory issues.

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [None]:
from typing import Sequence, List, Tuple, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

In [None]:
from hypernn.jax.embedding_module import FlaxEmbeddingModule
from hypernn.jax.weight_generator import FlaxWeightGenerator
from hypernn.jax.hypernet import FlaxHyperNetwork

In [None]:
from typing import Optional, Any

class DefaultFlaxEmbeddingModule(FlaxEmbeddingModule):
    """Embedding module that returns every row of a learnable embedding table."""

    def setup(self):
        # Table with `num_embeddings` rows, each of width `embedding_dim`.
        self.embedding = nn.Embed(self.num_embeddings, self.embedding_dim)

    def __call__(self, inp: Optional[Any] = None):
        """Look up all embeddings; `inp` is accepted but not used."""
        every_index = jnp.arange(self.num_embeddings)
        return self.embedding(every_index)


In [None]:
class StaticFlaxWeightGenerator(FlaxWeightGenerator):
    """Two-layer dense network mapping embeddings to `hidden_dim`-sized outputs."""

    def setup(self):
        # Hidden projection (32 units) followed by the output projection.
        self.dense1 = nn.Dense(32)
        self.dense2 = nn.Dense(self.hidden_dim)

    def __call__(self, embedding: jnp.array, inp: Optional[Any] = None):
        """Apply dense -> ReLU -> dense to `embedding`; `inp` is accepted but not used."""
        hidden = nn.relu(self.dense1(embedding))
        return self.dense2(hidden)

In [None]:
class MLP(nn.Module):
    """MLP with two 256-unit tanh layers and a bias-free 4-unit output layer."""

    @nn.compact
    def __call__(self, x):
        # Layers are instantiated in the same order as a hand-unrolled version:
        # Dense(256), Dense(256), then the output Dense(4).
        for hidden_width in (256, 256):
            x = nn.tanh(nn.Dense(hidden_width)(x))
        return nn.Dense(4, use_bias=False)(x)

In [None]:
from tensorflow_probability.substrates import jax as tfp

import jax
import functools

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def sample_actions(apply_fn, hypernetwork_params, obs):
    """Run `apply_fn` on a single observation and return its first output.

    `apply_fn` is a static jit argument. `obs` is converted to an array and
    given a leading axis of size 1 before being passed on; the second value
    returned by `apply_fn` is discarded.
    """
    batched_obs = jnp.expand_dims(jnp.array(obs), axis=0)
    variables = {'params': hypernetwork_params}
    first_output, _ = apply_fn(variables, batched_obs)
    return first_output

def rollout(env, hypernetwork, hypernetwork_params, render=False, seed: int = 0) -> Tuple[list, list, list, list, int]:
    """Play one episode in `env`, sampling each action from the hypernetwork's logits.

    Args:
        env: environment providing `reset()`, `step(action)`, `render(mode=...)`
            and `close()`. `step` is unpacked as `(obs, reward, done, info)`.
        hypernetwork: object whose `apply` is handed to `sample_actions`.
        hypernetwork_params: parameters passed to `sample_actions`.
        render: when True, an "rgb_array" frame is captured before every step.
        seed: currently unused; sampling keys are seeded from NumPy's global
            RNG on every step. Kept for interface compatibility.

    Returns:
        `(observations, actions, rewards, rendereds, num_steps)`, where
        `rendereds` is empty unless `render` is True and `num_steps` is the
        number of steps taken. `env.close()` is called before returning.
    """
    # Upper bound for the per-step seed. It must fit in a 32-bit integer:
    # np.random.randint raises ValueError for larger bounds on platforms where
    # NumPy's default integer type is 32-bit.
    max_seed = np.iinfo(np.int32).max

    obs = env.reset()
    done = False
    observations, actions, rewards, rendereds = [], [], [], []
    while not done:
        if render:
            rendereds.append(env.render(mode="rgb_array"))

        out = sample_actions(hypernetwork.apply, hypernetwork_params, obs)
        out = jnp.squeeze(out)
        dist = tfp.distributions.Categorical(logits=out)
        rng = jax.random.PRNGKey(np.random.randint(0, max_seed))
        action = dist.sample(seed=rng).item()
        next_obs, r, done, _ = env.step(action)

        # Store the observation the action was chosen from, not the successor.
        observations.append(obs)
        actions.append(action)
        rewards.append(r)

        obs = next_obs

    num_steps = len(observations)
    env.close()
    return observations, actions, rewards, rendereds, num_steps

In [None]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import os
import tqdm

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    """Create a `SummaryWriter` for a new timestamped run directory.

    The run directory is `<base_log_path>/<experiment_name>_<current datetime>`.
    A hint for launching TensorBoard on that directory is printed.

    NOTE(review): a function with the same name is defined again in a later
    cell of this notebook; whichever cell runs last wins.
    """
    run_dir = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
    writer = SummaryWriter(run_dir, flush_secs=10)
    absolute_run_dir = os.path.join(os.getcwd(), run_dir)
    hint = "Follow tensorboard logs with: tensorboard --logdir '{}'".format(absolute_run_dir)
    print(hint)
    return writer


In [None]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, hypernetwork, learning_rate, input_shape):
    """Creates initial `TrainState`.

    The hypernetwork is initialised on an all-ones input of `input_shape`, and
    the optimiser is Adam preceded by global-norm gradient clipping at 10.0.
    """
    dummy_input = jnp.ones(input_shape)
    initial_params = hypernetwork.init(rng, dummy_input)['params']
    optimizer = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=hypernetwork.apply,
        params=initial_params,
        tx=optimizer,
    )

In [None]:
@functools.partial(jax.jit, static_argnames=('apply_fn'))
def train_step(apply_fn, state, observations, actions, discounted_rewards):
    """Apply one gradient update on a batch of observations and actions.

    The loss is the negated sum of `discounted_rewards * log_prob(actions)`
    under a categorical distribution built from the logits `apply_fn` returns.

    Returns:
        `(new_state, loss, grads)`, where `new_state` is `state` after
        `apply_gradients` and `grads` are the gradients w.r.t. `state.params`.
    """
    def negative_weighted_log_prob(params):
        raw_logits, _ = apply_fn({'params': params}, observations)
        squeezed_logits = jnp.squeeze(jnp.array(raw_logits))
        policy = tfp.distributions.Categorical(logits=squeezed_logits)
        action_log_probs = jnp.squeeze(policy.log_prob(actions))
        weighted_sum = jnp.sum(discounted_rewards * action_log_probs)
        return -1*weighted_sum

    loss, grads = jax.value_and_grad(negative_weighted_log_prob)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss, grads


In [None]:
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    """Create a `SummaryWriter` that logs to a fresh, timestamped run directory.

    Args:
        experiment_name: prefix of the run directory name.
        base_log_path: directory under which run directories are created.

    Returns:
        A `SummaryWriter` (flushing every 10 seconds) writing to
        `<base_log_path>/<experiment_name>_<timestamp>`.

    The timestamp uses only digits, '-' and '_'. The default string form of
    `datetime.now()` contains spaces and ':' characters, and ':' is not
    allowed in directory names on Windows.
    """
    # Microseconds are kept so two loggers created within the same second
    # still get distinct directories.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    log_path = os.path.join(base_log_path, "{}_{}".format(experiment_name, timestamp))
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.join(os.getcwd(), log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [None]:
import pandas as pd

def flatten(d):
    """Flatten a nested dict into one level, joining nested keys with '_'.

    Uses `pd.json_normalize` and returns the first record of the resulting frame.
    """
    records = pd.json_normalize(d, sep='_').to_dict(orient='records')
    return records[0]

def discount_reward(rews, gamma: float = 0.99):
    """Compute the discounted reward-to-go for every step of an episode.

    `rtgs[i] = rews[i] + gamma * rtgs[i + 1]`, with the value after the last
    step taken as 0.

    Args:
        rews: per-step rewards (array-like, indexed along the first axis).
        gamma: discount factor.

    Returns:
        Array with the same shape as `rews`. The dtype is preserved for
        floating/complex input and is float64 otherwise, so integer rewards
        are not truncated when multiplied by a fractional `gamma`.
    """
    rews = np.asarray(rews)
    # An integer (or bool) buffer would silently drop the fractional part of
    # every discounted return, so promote non-inexact dtypes to float64.
    out_dtype = rews.dtype if np.issubdtype(rews.dtype, np.inexact) else np.float64
    n = len(rews)
    rtgs = np.zeros_like(rews, dtype=out_dtype)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + gamma*(rtgs[i + 1] if i + 1 < n else 0)
    return rtgs

def reinforce(
        num_epochs,
        env,
        hypernetwork,
        seed: int = 0,
        lr: float = 0.0001,
        gamma: float = 0.99,
    ):
    """Train `hypernetwork` on `env` with one policy-gradient update per episode.

    Each epoch runs one rollout, computes discounted rewards-to-go,
    standardises them (zero mean, unit std), and applies one `train_step`.
    Loss, episode reward, episode length and per-parameter gradient sums are
    logged to TensorBoard.

    Args:
        num_epochs: number of rollout/update iterations.
        env: environment passed to `rollout`.
        hypernetwork: module to train; its `apply` is used for the update.
        seed: seed for parameter initialisation; also forwarded to `rollout`.
        lr: learning rate passed to `create_train_state`.
        gamma: discount factor used for the rewards-to-go.

    Returns:
        `(hypernetwork, state)` with the latest train state. A
        `KeyboardInterrupt` stops training early and still returns the state
        reached so far.
    """
    writer = get_tensorboard_logger("HypernetworkJaxRL")

    rng = jax.random.PRNGKey(seed)
    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, hypernetwork, lr, (1,8))

    bar = tqdm.tqdm(np.arange(num_epochs))

    try:
        for i in bar:
            observations, actions, rewards, _, num_steps = rollout(env, hypernetwork, state.params, seed=seed)
            # Use the caller-supplied discount factor (it used to be ignored
            # in favour of a hard-coded 1.0).
            discounted_rewards = discount_reward(np.array(rewards), gamma)
            discounted_rewards = discounted_rewards - np.mean(discounted_rewards)
            # The small epsilon guards against division by zero when all
            # returns are equal.
            discounted_rewards = discounted_rewards / (
                np.std(discounted_rewards) + 1e-10
            )

            observations = jnp.array(observations)
            actions = jnp.array(actions)
            discounted_rewards = jnp.array(discounted_rewards)

            state, loss, grads = train_step(hypernetwork.apply, state, observations, actions, discounted_rewards)

            # Reduce every gradient leaf to a scalar sum and flatten the
            # nested names into single TensorBoard tags.
            grad_dict = {k:dict(grads[k]) for k in grads.keys()}
            grad_dict = flatten(grad_dict)

            grad_dict = {k: {kk: np.sum(vv).item() for kk, vv in v.items()}
                        for k, v in grad_dict.items()}
            grad_dict = flatten(grad_dict)

            loss_value = loss.item()
            total_reward = np.sum(rewards)
            metrics = {"loss":loss_value, "rewards":total_reward, "num_steps":num_steps, **grad_dict}

            for key in metrics:
                writer.add_scalar(key, metrics[key], i)

            bar.set_description('Loss: {}, Sum Reward: {}'.format(loss_value, total_reward))
    except KeyboardInterrupt:
        # Deliberate: interrupting the kernel ends training but keeps the progress.
        pass
    finally:
        # Flush pending events and release the log file handle on every exit path.
        writer.close()
    return hypernetwork, state


In [None]:
# Show which platform JAX is running on (e.g. cpu / gpu / tpu).
# Uses the public jax.default_backend() API rather than the internal
# jax.lib.xla_bridge module, which is deprecated and absent from recent JAX releases.
print(jax.default_backend())


In [None]:
import gym

# Environment to train on, and the hypernetwork built around the MLP.
env = gym.make("LunarLander-v2")
hyper = FlaxHyperNetwork(
    (1, 8),
    MLP(),
    embedding_module_constructor=DefaultFlaxEmbeddingModule,
    weight_generator_constructor=StaticFlaxWeightGenerator,
    embedding_dim=32,
    num_embeddings=512,
)

In [None]:
# Print the hypernetwork's `num_parameters` attribute
# (presumably the parameter count of the wrapped MLP — TODO confirm against hypernn).
target_num_params = hyper.num_parameters
print(target_num_params)

In [None]:
# Initialise the hypernetwork variables on an all-ones (1, 8) input.
# NOTE(review): `params` is not referenced by later cells; `reinforce` creates its own train state.
params = hyper.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

In [None]:
# Train for up to 100000 epochs; interrupting the kernel stops early and still
# returns the hypernetwork and the train state reached so far.
hyper, state = reinforce(100000, env, hyper, seed=10, lr=0.0001)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

# Imports specifically so we can render outputs in Jupyter.
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display


def render_rollout(model, state):
    """Run one rendered episode on a fresh LunarLander env and display it as a looping animation.

    Args:
        model: hypernetwork passed to `rollout`.
        state: train state whose `params` are used to act.
    """
    fig = plt.figure("Animation",figsize=(7,5))
    ax = fig.add_subplot(111)
    # Only the captured frames are needed from the rollout.
    _, _, _, rendereds, _ = rollout(gym.make("LunarLander-v2"), model, state.params, render=True)

    def _draw(image):
        # Each animation frame is a list holding a single image artist.
        artist = ax.imshow(image)
        ax.axis('off')
        return [artist]

    frames = [_draw(image) for image in rendereds]
    anim = animation.ArtistAnimation(fig, frames, interval=50, blit=True)
    display(display_animation(anim, default_mode='loop'))


In [None]:
# Visualise one episode played with the trained parameters.
render_rollout(hyper, state)