### PyTorch

In [12]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import tqdm
import torch.optim as optim


In [15]:
from hypernn.torch.static_hypernet import StaticTorchHyperNetwork

In [16]:
# Bias-free policy MLP: 8 inputs -> 256 -> 256 -> 4 logits, tanh activations.
target_network = nn.Sequential(
    nn.Linear(8, 256, bias=False), nn.Tanh(),
    nn.Linear(256, 256, bias=False), nn.Tanh(),
    nn.Linear(256, 4, bias=False),
)
# Number of trainable parameters in the target network (shown as the cell output).
pytorch_total_params = sum(
    param.numel() for param in target_network.parameters() if param.requires_grad
)
pytorch_total_params

68608

In [19]:
from typing import Any, Dict, List, Optional, Tuple, Type, Union  # noqa

class LunarLanderHypernetwork(StaticTorchHyperNetwork):
    """PyTorch static hypernetwork for the LunarLander policy.

    Only the weight generator is customised: a small tanh MLP mapping an
    ``embedding_dim`` vector to ``hidden_dim`` outputs.
    """

    def __init__(
        self,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        hidden_dim: Optional[int] = None,
        *args, **kwargs
    ):
        # NOTE(review): extra *args / **kwargs are accepted but NOT forwarded
        # to the parent constructor.
        super().__init__(
            target_network=target_network,
            num_target_parameters=num_target_parameters,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            hidden_dim=hidden_dim,
        )

    def make_weight_generator(self):
        # embedding_dim -> 32 (tanh) -> hidden_dim
        layers = [
            nn.Linear(self.embedding_dim, 32),
            nn.Tanh(),
            nn.Linear(32, self.hidden_dim),
        ]
        return nn.Sequential(*layers)


In [20]:
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 512

# Build a hypernetwork sized to the PyTorch target network.
hypernetwork = LunarLanderHypernetwork.from_target(
    target_network=target_network,
    embedding_dim=EMBEDDING_DIM,
    num_embeddings=NUM_EMBEDDINGS,
)

In [21]:
# Trainable parameters of the hypernetwork itself (compare with the target network's count).
pytorch_total_params = sum(
    param.numel() for param in filter(lambda t: t.requires_grad, hypernetwork.parameters())
)
pytorch_total_params

6630

In [34]:
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    """Create a TensorBoard ``SummaryWriter`` in a fresh, timestamped run directory.

    The run directory is ``<base_log_path>/<experiment_name>_<timestamp>``. The
    timestamp has no spaces or colons. Colons are illegal in Windows file names,
    and spaces complicate shell use. Microseconds are kept so that two runs
    started in the same second do not collide.

    Args:
        experiment_name: prefix for the run directory name.
        base_log_path: parent directory for all runs.

    Returns:
        The ``SummaryWriter`` (events flushed every 10 seconds).
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    log_path = os.path.join(base_log_path, "{}_{}".format(experiment_name, timestamp))
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.abspath(log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [None]:
import gym

def rollout(env, hypernetwork, render=False) -> tuple:
    """Run one episode with parameters generated once by the hypernetwork.

    Target-network parameters come from a single ``hypernetwork.generate_params()``
    call and are reused for every step. Each action is sampled from a
    categorical distribution over the policy logits. Everything runs under
    ``torch.no_grad()``.

    Args:
        env: gym-style environment using ``reset() -> obs`` and
            ``step(a) -> (obs, reward, done, info)``. It is closed on exit,
            including when an exception is raised.
        hypernetwork: model exposing ``generate_params()``, ``device`` and
            ``__call__(inp=..., generated_params=..., has_aux=False)``.
        render: if True, also collect ``env.render(mode="rgb_array")`` frames.

    Returns:
        ``(observations, actions, rewards, rendereds)``, one entry per step.
        ``rendereds`` is empty unless ``render`` is True.
    """
    observations, actions, rewards, rendereds = [], [], [], []
    try:
        with torch.no_grad():
            params, _ = hypernetwork.generate_params()
            obs = env.reset()
            done = False
            while not done:
                if render:
                    rendereds.append(env.render(mode="rgb_array"))

                # Cast to float32: observations may arrive as float64 while the
                # network weights are float32.
                obs_tensor = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(hypernetwork.device)
                action_logits = hypernetwork(inp=[obs_tensor], generated_params=params, has_aux=False)
                action = Categorical(logits=action_logits).sample().item()
                next_obs, r, done, _ = env.step(action)

                observations.append(obs)
                actions.append(action)
                rewards.append(r)

                obs = next_obs
    finally:
        env.close()
    return observations, actions, rewards, rendereds

def discount_reward(rews, gamma: float = 0.99):
    """Compute discounted rewards-to-go for one episode.

    ``rtgs[i] = rews[i] + gamma * rtgs[i + 1]``, with ``rtgs[n] = 0``. This is
    the discounted sum of rewards from step ``i`` to the end of the episode.

    Args:
        rews: sequence or 1-D array of per-step rewards.
        gamma: discount factor.

    Returns:
        A numpy array with the same shape as ``rews``. Integer rewards are
        promoted to floating point so discounted values are not truncated.
        float32 and float64 inputs keep their dtype.
    """
    rews = np.asarray(rews)
    rtgs = np.zeros(rews.shape, dtype=np.result_type(rews.dtype, np.float32))
    n = len(rews)
    for i in reversed(range(n)):
        future = rtgs[i + 1] if i + 1 < n else 0
        rtgs[i] = rews[i] + gamma * future
    return rtgs

def reinforce(
        num_epochs,
        env,
        hypernetwork,
        rollout_fn=rollout,
        lr: float = 0.0001,
        gamma: float = 0.99,
    ):
    """Train ``hypernetwork`` with REINFORCE, one sampled episode per epoch.

    Each epoch does the following:
    1. Roll out an episode with ``rollout_fn``.
    2. Compute standardised discounted returns.
    3. Take one Adam step on ``-sum_t G_t * log pi(a_t | s_t)``, with gradients
       clipped to a global norm of 10.
    4. Log the loss, episode reward and per-parameter gradient sums to TensorBoard.

    Args:
        num_epochs: number of episodes / optimiser steps.
        env: gym-style environment passed to ``rollout_fn``.
        hypernetwork: torch hypernetwork exposing ``device``, ``parameters()``
            and ``__call__(inp=..., has_aux=True) -> (logits, params, aux)``.
        rollout_fn: episode collector returning
            ``(observations, actions, rewards, frames)``. The default is bound
            to this cell's ``rollout`` when the function is defined.
        lr: Adam learning rate.
        gamma: discount factor.
    """
    writer = get_tensorboard_logger("HypernetworkTorchRL")
    optimizer = optim.Adam(hypernetwork.parameters(), lr=lr)
    device = hypernetwork.device

    bar = tqdm.tqdm(np.arange(num_epochs))
    try:
        for i in bar:
            observations, actions, rewards, _ = rollout_fn(env, hypernetwork)
            episode_reward = np.sum(rewards)

            # Standardise returns (zero mean, unit std) to reduce gradient variance.
            returns = discount_reward(np.array(rewards), gamma)
            returns = returns - np.mean(returns)
            returns = returns / (np.std(returns) + 1e-10)

            obs_t = torch.from_numpy(np.array(observations)).float().to(device)
            # Categorical.log_prob indexes classes, so pass integer actions.
            actions_t = torch.as_tensor(actions, dtype=torch.long, device=device)
            returns_t = torch.from_numpy(returns).float().to(device)

            logits, _, _ = hypernetwork(inp=[obs_t], has_aux=True)
            log_probs = Categorical(logits=logits).log_prob(actions_t)
            loss = -torch.sum(returns_t * log_probs)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 10.0)
            optimizer.step()

            grad_dict = {
                "{}_grad".format(name): float(torch.sum(param.grad).item())
                for name, param in hypernetwork.named_parameters()
                if param.grad is not None
            }
            metrics = {"loss": loss.item(), "rewards": episode_reward, **grad_dict}
            for key, value in metrics.items():
                writer.add_scalar(key, value, i)

            bar.set_description('Loss: {}, Sum Reward: {}'.format(loss.item(), episode_reward))
    finally:
        # Flush pending events and release resources even if training is interrupted.
        bar.close()
        writer.close()


In [None]:
# NOTE(review): rollout() uses `obs = env.reset()` and a 4-tuple from env.step().
# That is the pre-0.26 gym API; confirm the installed gym version.
env = gym.make("LunarLander-v2")

In [None]:
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 512

# New hypernetwork instance for the training run below.
hypernetwork = LunarLanderHypernetwork.from_target(
    target_network=target_network,
    embedding_dim=EMBEDDING_DIM,
    num_embeddings=NUM_EMBEDDINGS,
)

In [None]:
# Long training run (one episode per epoch).
reinforce(num_epochs=100_000, env=env, hypernetwork=hypernetwork, lr=1e-4)

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib ipympl

plt.style.use('ggplot')
# Imports specifically so we can render outputs in Jupyter.
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display
from celluloid import Camera
from IPython.display import HTML


def render_rollout(model, env_id: str = "LunarLander-v2", video_path: str = 'LunarLander.mp4', interval: int = 50):
    """Roll out one rendered episode with ``model`` and save it as a video.

    Args:
        model: hypernetwork passed to ``rollout(..., render=True)``.
        env_id: gym environment id to create for the rollout.
        video_path: output file for ``Animation.save``.
        interval: delay between frames in milliseconds.

    Returns:
        The celluloid-generated matplotlib animation.

    NOTE(review): ``rollout`` is resolved from the global namespace at call
    time. A later JAX cell rebinds ``rollout`` with a different signature,
    so call this only while the PyTorch ``rollout`` is in scope.
    """
    # plt.figure(label) returns an existing figure with that label. Camera would
    # then also snapshot the stale artists from a previous call, so start fresh.
    plt.close("Animation")
    fig = plt.figure("Animation", figsize=(7, 5))
    camera = Camera(fig)
    ax = fig.add_subplot(111)
    ax.axis('off')
    _, _, _, frames = rollout(gym.make(env_id), model, render=True)
    for frame in frames:
        ax.imshow(frame)
        camera.snap()
    anim = camera.animate(blit=False, interval=interval)
    anim.save(video_path)
    return anim


In [None]:
# NOTE(review): this rebinds `animation`, shadowing the `matplotlib.animation` module imported earlier.
animation = render_rollout(model=hypernetwork)

In [None]:
from IPython.display import Video

# Show the video written by render_rollout.
# NOTE(review): without embed=True the video is not stored in the notebook, so
# the .mp4 must sit next to the notebook for playback.
Video("LunarLander.mp4")

### JAX

In [25]:
from typing import Sequence, List, Tuple, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

# Disable JAX's default up-front GPU memory preallocation; XLA otherwise reserves
# a large fraction of GPU memory at startup. Delete this line to restore preallocation.
# NOTE(review): XLA reads this variable when the GPU backend is first initialised,
# so it must be set before any JAX computation runs. Setting it before
# `import jax` is the safest placement.
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [26]:
from hypernn.jax.static_hypernet import StaticFlaxHyperNetwork
from hypernn.jax.utils import count_jax_params

In [27]:
class MLP(nn.Module):
    """Bias-free Flax policy MLP: two 256-unit tanh hidden layers, 4 output logits."""

    @nn.compact
    def __call__(self, x):
        # Dense layers are created in the same order (Dense_0, Dense_1, Dense_2).
        for width in (256, 256):
            x = nn.tanh(nn.Dense(width, use_bias=False)(x))
        return nn.Dense(4, use_bias=False)(x)

In [28]:
target_network = MLP()

# Parameter count of the target network, initialised on a dummy (1, 8) input.
count_jax_params(target_network, inputs=jnp.zeros(shape=(1, 8)))

68608

In [29]:
class WeightGenerator(nn.Module):
    """Two-layer tanh MLP mapping its input to ``hidden_dim`` outputs (with biases)."""

    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        hidden = nn.tanh(nn.Dense(32, use_bias=True)(x))
        return nn.Dense(self.hidden_dim, use_bias=True)(hidden)

In [30]:
from typing import Any, Dict, List, Optional, Tuple, Type, Union  # noqa

class LunarLanderHypernetwork(StaticFlaxHyperNetwork):
    """Flax static hypernetwork for LunarLander.

    Only the weight generator is customised. This redefines the name used by
    the PyTorch section above.
    """

    def make_weight_generator(self):
        return WeightGenerator(hidden_dim=self.hidden_dim)


In [31]:
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 512

# The Flax variant also takes the target's input shape(s): one (1, 8) observation batch.
hyper = LunarLanderHypernetwork.from_target(
    target_network=target_network,
    target_input_shape=[(1, 8)],
    embedding_dim=EMBEDDING_DIM,
    num_embeddings=NUM_EMBEDDINGS,
)

In [32]:
# Output width of the weight generator. Here 512 embeddings x 134 = 68,608, which
# equals the target network's parameter count.
# NOTE(review): this suggests each embedding generates one contiguous chunk of
# target weights; confirm in hypernn.
hyper.hidden_dim

134

In [33]:
# Trainable parameters of the hypernetwork itself (vs. 68,608 in the target network).
count_jax_params(hyper, inputs=[jnp.zeros((1,8))])

6630

In [38]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state
import jax

def create_train_state(rng, hypernetwork, learning_rate, input_shape, max_grad_norm: float = 10.0):
    """Initialise ``hypernetwork`` and wrap it in a Flax ``TrainState``.

    Args:
        rng: JAX PRNG key used for parameter initialisation.
        hypernetwork: Flax hypernetwork module, initialised with the single
            dummy input ``[jnp.ones(input_shape)]``.
        learning_rate: constant Adam learning rate.
        input_shape: shape of the dummy target-network input, e.g. ``(1, 8)``.
        max_grad_norm: global-norm threshold for gradient clipping applied
            before the Adam update.

    Returns:
        A ``train_state.TrainState`` holding the initial params and optimizer.
    """
    params = hypernetwork.init(rng, [jnp.ones(input_shape)])['params']
    # Clip the global gradient norm first, then apply a constant-rate Adam update.
    tx = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=hypernetwork.apply, params=params, tx=tx)

In [45]:
import functools

@functools.partial(jax.jit, static_argnames=('apply_fn',))
def train_step(apply_fn, state, observations, actions, discounted_rewards):
    """One jitted REINFORCE update.

    The loss is ``-sum_t G_t * log pi(a_t | s_t)``. Policy logits come from
    ``apply_fn({'params': params}, inp=[observations], has_aux=False)``.
    Log-probabilities use ``jax.nn.log_softmax``, which is the same quantity
    as a categorical ``log_prob`` over logits. This avoids depending on a
    ``tfp`` import from another cell.

    Args:
        apply_fn: hypernetwork ``apply`` function (static, so it must be hashable).
        state: TrainState-like pytree with ``params`` and ``apply_gradients(grads=...)``.
        observations: one episode of observations, presumably ``(T, obs_dim)``.
        actions: integer actions taken, ``(T,)``.
        discounted_rewards: per-step (normalised) returns, ``(T,)``.

    Returns:
        ``(new_state, loss, grads)``.
    """
    def loss_fn(params):
        logits = apply_fn({'params': params}, inp=[observations], has_aux=False)
        # Collapse any singleton leading dims to (T, num_actions).
        logits = jnp.reshape(logits, (-1, logits.shape[-1]))
        log_probs_all = jax.nn.log_softmax(logits, axis=-1)
        action_idx = jnp.reshape(actions, (-1, 1)).astype(jnp.int32)
        log_probs = jnp.take_along_axis(log_probs_all, action_idx, axis=-1)[:, 0]
        return -jnp.sum(discounted_rewards * log_probs)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss, grads


In [46]:
from tensorflow_probability.substrates import jax as tfp

import jax
import functools

@functools.partial(jax.jit, static_argnames=('apply_fn', 'method'))
def generate_params(apply_fn, hypernetwork_params, method):
    """Jitted call of hypernetwork ``method``; returns only the generated params.

    ``method`` is expected to return a ``(generated_params, aux_output)`` pair.
    The aux output is discarded.
    """
    generated, _aux = apply_fn({'params': hypernetwork_params}, method=method)
    return generated

@functools.partial(jax.jit, static_argnames=('apply_fn',))
def sample_actions(apply_fn, hypernetwork_params, obs, generated_params):
    """Jitted policy forward pass for a single observation.

    Despite its name, this does not sample. It returns the target network's
    output (action logits) for ``obs`` with a leading batch axis of size 1
    added, using pre-generated target parameters.

    Args:
        apply_fn: hypernetwork ``apply`` (static, so it must be hashable).
        hypernetwork_params: hypernetwork parameter pytree.
        obs: one observation (array-like, no batch axis).
        generated_params: target parameters from ``generate_params``.
    """
    batched_obs = jnp.expand_dims(jnp.asarray(obs), 0)
    return apply_fn(
        {'params': hypernetwork_params},
        inp=[batched_obs],
        generated_params=generated_params,
        has_aux=False,
    )

def rollout(env, apply_fn, generate_params_fn, hypernetwork_params, render=False, seed: int = 0) -> tuple:
    """Run one episode with the JAX hypernetwork policy.

    Target parameters are generated once per episode. Each step's action is
    sampled from a categorical distribution over the policy logits.

    Args:
        env: gym-style environment using ``reset() -> obs`` and
            ``step(a) -> (obs, reward, done, info)``. It is closed on exit,
            including on error.
        apply_fn: hypernetwork ``apply`` function.
        generate_params_fn: hypernetwork method passed to ``generate_params``.
        hypernetwork_params: hypernetwork parameter pytree.
        render: if True, also collect ``env.render(mode="rgb_array")`` frames.
        seed: NOTE(review): currently unused. Each episode draws a fresh base
            key from NumPy's global RNG, because ``reinforce`` passes the same
            seed every epoch and a fixed key would reuse the same sampling
            noise every episode.

    Returns:
        ``(observations, actions, rewards, rendereds, num_steps)``.
    """
    observations, actions, rewards, rendereds = [], [], [], []
    # One key per episode, split per step. The bound stays within int32 so
    # np.random.randint works where NumPy's default int is 32-bit.
    key = jax.random.PRNGKey(np.random.randint(0, 2**31 - 1))
    try:
        obs = env.reset()
        done = False
        generated_params = generate_params(apply_fn, hypernetwork_params, generate_params_fn)
        while not done:
            if render:
                rendereds.append(env.render(mode="rgb_array"))

            logits = jnp.squeeze(sample_actions(apply_fn, hypernetwork_params, obs, generated_params))
            key, step_key = jax.random.split(key)
            action = int(jax.random.categorical(step_key, logits))

            next_obs, r, done, _ = env.step(action)

            observations.append(obs)
            actions.append(action)
            rewards.append(r)

            obs = next_obs
    finally:
        env.close()
    return observations, actions, rewards, rendereds, len(observations)

In [246]:
import pandas as pd
import tqdm

# def flatten(d):
#     df = pd.json_normalize(d, sep='_')
#     return df.to_dict(orient='records')[0]

from flax.core.frozen_dict import FrozenDict
import collections

def flatten(d, parent_key='', sep='_'):
    """Flatten a nested mapping into a single-level dict with joined keys.

    Nested keys are joined with ``sep``, e.g. ``{'a': {'b': 1}} -> {'a_b': 1}``.
    Any ``collections.abc.Mapping`` value is recursed into. That covers plain
    dicts and Flax ``FrozenDict``, which subclasses ``Mapping``. Later
    duplicate keys overwrite earlier ones.
    """
    # `collections.MutableMapping` was removed in Python 3.10; use the ABC module.
    from collections.abc import Mapping

    flat = {}
    for key, value in d.items():
        new_key = "{}{}{}".format(parent_key, sep, key) if parent_key else key
        if isinstance(value, Mapping):
            flat.update(flatten(value, new_key, sep=sep))
        else:
            flat[new_key] = value
    return flat

def discount_reward(rews, gamma: float = 0.99):
    """Compute discounted rewards-to-go for one episode.

    ``rtgs[i] = rews[i] + gamma * rtgs[i + 1]``, with ``rtgs[n] = 0``. This is
    the discounted sum of rewards from step ``i`` to the end of the episode.

    Args:
        rews: sequence or 1-D array of per-step rewards.
        gamma: discount factor.

    Returns:
        A numpy array with the same shape as ``rews``. Integer rewards are
        promoted to floating point so discounted values are not truncated.
        float32 and float64 inputs keep their dtype.
    """
    rews = np.asarray(rews)
    rtgs = np.zeros(rews.shape, dtype=np.result_type(rews.dtype, np.float32))
    n = len(rews)
    for i in reversed(range(n)):
        future = rtgs[i + 1] if i + 1 < n else 0
        rtgs[i] = rews[i] + gamma * future
    return rtgs

def reinforce(
        num_epochs,
        env,
        hypernetwork,
        seed: int = 0,
        lr: float = 0.0001,
        gamma: float = 0.99,
    ):
    """Train a Flax hypernetwork with REINFORCE, one sampled episode per epoch.

    Each epoch does the following:
    1. Roll out an episode.
    2. Standardise the discounted returns.
    3. Apply one jitted ``train_step`` (grad-norm clipping + Adam).
    4. Log loss, episode reward, episode length and per-parameter gradient sums
       to TensorBoard.

    Interrupting with Ctrl-C (KeyboardInterrupt) stops training early and still
    returns the current state.

    Args:
        num_epochs: number of episodes / optimiser steps.
        env: gym-style environment.
        hypernetwork: Flax hypernetwork module (provides ``apply`` and
            ``generate_params``).
        seed: seed for parameter initialisation. It is also forwarded to ``rollout``.
        lr: Adam learning rate.
        gamma: discount factor.

    Returns:
        ``(hypernetwork, state, grad_dict)``. ``grad_dict`` holds the flattened
        per-parameter gradient sums from the last completed step, and is empty
        if no step completed.
    """
    writer = get_tensorboard_logger("HypernetworkJaxRL")

    rng = jax.random.PRNGKey(seed)
    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, hypernetwork, lr, (1,8))

    # Defined up front so an interrupt before the first update can still return it.
    grad_dict = {}
    bar = tqdm.tqdm(np.arange(num_epochs))
    try:
        for i in bar:
            observations, actions, rewards, _, num_steps = rollout(
                env, hypernetwork.apply, hypernetwork.generate_params, state.params, seed=seed)

            # Standardise returns (zero mean, unit std) to reduce gradient variance.
            returns = discount_reward(np.array(rewards), gamma)
            returns = returns - np.mean(returns)
            returns = returns / (np.std(returns) + 1e-10)

            state, loss, grads = train_step(
                hypernetwork.apply, state,
                jnp.array(observations), jnp.array(actions), jnp.array(returns))

            flat_grads = flatten({k: dict(grads[k]) for k in grads.keys()})
            grad_dict = {k: np.sum(np.array(v)).item() for k, v in flat_grads.items()}

            episode_reward = np.sum(rewards)
            metrics = {"loss": loss.item(), "rewards": episode_reward, "num_steps": num_steps, **grad_dict}
            for key, value in metrics.items():
                writer.add_scalar(key, value, i)

            bar.set_description('Loss: {}, Sum Reward: {}'.format(loss.item(), episode_reward))
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C in the notebook ends training early; the partial state is returned below.
        pass
    finally:
        bar.close()
        writer.close()
    return hypernetwork, state, grad_dict


In [250]:
# Re-create the target module definition before building a fresh hypernetwork for training.
target_network = MLP()

In [251]:
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 512

# New Flax hypernetwork for the training run below; one (1, 8) input shape for the target.
hyper = LunarLanderHypernetwork.from_target(
    target_network=target_network,
    target_input_shape=[(1, 8)],
    embedding_dim=EMBEDDING_DIM,
    num_embeddings=NUM_EMBEDDINGS,
)

In [252]:
import gym

env = gym.make("LunarLander-v2")

# Long run (~1 s per episode). Ctrl-C stops it early and still returns the current state.
hyper, state, grad_dict = reinforce(
    num_epochs=100000, env=env, hypernetwork=hyper, seed=0, lr=0.0001
)

Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/tensorboard_logs/HypernetworkJaxRL_2022-04-26 12:42:55.752280'


Loss: -12.134172439575195, Sum Reward: -65.94675015305978:   3%|▎         | 3022/100000 [54:59<29:24:50,  1.09s/it]   
