In [None]:
# Uncomment these lines to install the dependencies for this notebook.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): the cells below use the pre-0.26 gym API (`env.reset()` returns only the
# observation, `env.step()` returns a 4-tuple, `env.render(mode=...)`), so an older gym is
# presumably required (e.g. "gym[box2d]<0.26" for LunarLander) — confirm and pin versions.
# %pip install gym
# %pip install tqdm
# %pip install tensorboard
# %pip install ipympl
# NOTE(review): later cells also import matplotlib and celluloid (and JSAnimation),
# which are not installed above.
# %pip install matplotlib celluloid

In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import tqdm
import torch.optim as optim


In [None]:
from hypernn.torch.hypernet import TorchHyperNetwork
from hypernn.torch.weight_generator import TorchWeightGenerator
from hypernn.torch.embedding_module import TorchEmbeddingModule

## Basic Hypernetwork

### 3 Components:
- EmbeddingModule: Produces the embeddings (e.g. one per layer or per parameter chunk) that are passed into the weight generator
- WeightGenerator: Shared network used to generate parameters from each embedding in the EmbeddingModule
- Hypernetwork: Combines the EmbeddingModule and WeightGenerator into an end-to-end parameter generator

### EmbeddingModule

In [None]:
from typing import Optional, Any

class DefaultTorchEmbeddingModule(TorchEmbeddingModule):
    """Static embedding module: a learned table of ``num_embeddings`` vectors.

    The output does not depend on the input, so every call returns the full
    ``(num_embeddings, embedding_dim)`` embedding table in index order.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int, input_shape: Optional[Any] = None):
        super().__init__(embedding_dim, num_embeddings, input_shape)
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self, inp: Optional[Any] = None, *args, **kwargs):
        # `inp` and any extra arguments are accepted for interface compatibility only.
        # Build the index vector directly on the module's device instead of copying it over.
        all_ids = torch.arange(self.num_embeddings, device=self.device)
        table = self.embedding(all_ids)
        return table


### Weight Generator

In [None]:
class DeepTorchWeightGenerator(TorchWeightGenerator):
    """Two-layer MLP weight generator shared across all embeddings.

    Each embedding is mapped ``embedding_dim -> 32 -> hidden_dim`` (ReLU in
    between) and all outputs are flattened into one 1-D parameter vector.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int, hidden_dim: int, input_shape: Optional[Any] = None):
        super().__init__(embedding_dim, num_embeddings, hidden_dim, input_shape)
        # Attribute names are kept so parameter names (and their logged gradients) stay stable.
        self.linear1 = nn.Linear(embedding_dim, 32)
        self.linear2 = nn.Linear(32, hidden_dim)

    def forward(
        self, embedding: torch.Tensor, inp: Optional[Any] = None
    ) -> torch.Tensor:
        # `embedding` is presumably (num_embeddings, embedding_dim); `inp` is unused.
        return self.linear2(F.relu(self.linear1(embedding))).view(-1)



### Making a LunarLander HyperNetwork

#### Big Target Network

In [None]:
# Target policy network: 8-dim observation -> 4 action logits, tanh activations, no biases.
target_network = nn.Sequential(
    nn.Linear(8, 256, bias=False),
    nn.Tanh(),
    nn.Linear(256, 256, bias=False),
    nn.Tanh(),
    nn.Linear(256, 4, bias=False),
)
# Count of trainable target-network parameters, shown as the cell output.
pytorch_total_params = sum(
    weight.numel()
    for weight in target_network.parameters()
    if weight.requires_grad
)
pytorch_total_params

In [None]:
# Length of each learned embedding vector.
EMBEDDING_DIM = 32
# Number of embeddings; the shared weight generator is applied to each one and the
# flattened outputs together form the target network's parameters.
NUM_EMBEDDINGS = 512

# `from_target` receives the target network so the modules can be sized from it
# (presumably hidden_dim is derived from the target's parameter count — confirm in hypernn).
embedding_network = DefaultTorchEmbeddingModule.from_target(target_network, EMBEDDING_DIM, NUM_EMBEDDINGS)
weight_generator = DeepTorchWeightGenerator.from_target(target_network, EMBEDDING_DIM, NUM_EMBEDDINGS)

#### Much smaller Hypernetwork, with lots of parameter sharing

In [None]:
# Assemble the hypernetwork from the pre-built modules. (1, 8) is the input shape:
# one observation of 8 features per call (see `unsqueeze(0)` in the rollout).
hypernetwork = TorchHyperNetwork(
                            (1, 8),
                            target_network,
                            embedding_module=embedding_network,
                            weight_generator=weight_generator,
                        )
# Trainable parameters of the hypernetwork, for comparison with the target network.
# NOTE(review): if TorchHyperNetwork registers the target network as a submodule,
# its parameters are counted here too — confirm.
pytorch_total_params = sum(p.numel() for p in hypernetwork.parameters() if p.requires_grad)
pytorch_total_params

In [None]:
# Display the hypernetwork's module tree.
hypernetwork

#### Tensorboard logging

In [None]:
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    """Create a TensorBoard ``SummaryWriter`` for a new, timestamped run.

    Args:
        experiment_name: Prefix of the run directory name.
        base_log_path: Directory (relative to the CWD unless absolute) in which
            the run directory is created.

    Returns:
        A ``SummaryWriter`` logging to ``<base_log_path>/<experiment_name>_<timestamp>``
        and flushing every 10 seconds. Callers should ``close()`` it when done.
    """
    # str(datetime.now()) contains a space and colons; colons are invalid in Windows
    # paths, so format a filesystem-safe timestamp instead.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    log_path = os.path.join(base_log_path, "{}_{}".format(experiment_name, timestamp))
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.abspath(log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [None]:
def rollout(env, hypernetwork, render=False) -> tuple:
    """Play one episode with actions sampled from the hypernetwork policy.

    Target parameters are generated once up front and passed back in at every
    step. No gradients are recorded.

    Args:
        env: Environment with the pre-0.26 gym API (``reset() -> obs``,
            ``step(a) -> (obs, reward, done, info)``, ``render(mode=...)``).
        hypernetwork: Object exposing ``generate_params()``, ``device`` and
            ``__call__(x, params=...) -> (logits, params, _)``.
        render: If True, capture an RGB frame before every step.

    Returns:
        ``(observations, actions, rewards, rendereds)`` lists with one entry per
        step; ``rendereds`` is empty unless ``render`` is True.

    The environment is closed when the function exits, even on error.
    """
    observations, actions, rewards, rendereds = [], [], [], []
    try:
        with torch.no_grad():
            params, _ = hypernetwork.generate_params()
            obs = env.reset()
            done = False
            while not done:
                if render:
                    rendereds.append(env.render(mode="rgb_array"))

                # Cast to float32 (as training does) so float64 observations also work.
                obs_tensor = torch.as_tensor(
                    obs, dtype=torch.float32, device=hypernetwork.device
                ).unsqueeze(0)
                action_logits, params, _ = hypernetwork(obs_tensor, params=params)
                action = Categorical(logits=action_logits).sample().item()
                next_obs, r, done, _ = env.step(action)

                observations.append(obs)
                actions.append(action)
                rewards.append(r)

                obs = next_obs
    finally:
        env.close()
    return observations, actions, rewards, rendereds

def discount_reward(rews, gamma: float = 0.99):
    """Return the discounted reward-to-go for every step of an episode.

    ``out[i] = rews[i] + gamma * out[i + 1]`` with ``out[n] = 0``.

    Args:
        rews: 1-D sequence or array of per-step rewards.
        gamma: Discount factor.

    Returns:
        A NumPy array with the same shape as ``rews``. Floating-point inputs keep
        their dtype; integer (or other non-float) inputs are promoted to float64.
    """
    rews = np.asarray(rews)
    n = len(rews)
    # np.zeros_like(rews) would inherit an integer dtype from integer rewards and
    # silently truncate every discounted sum.
    out_dtype = rews.dtype if np.issubdtype(rews.dtype, np.floating) else np.float64
    rtgs = np.zeros(rews.shape, dtype=out_dtype)
    for i in reversed(range(n)):
        future = rtgs[i + 1] if i + 1 < n else 0
        rtgs[i] = rews[i] + gamma * future
    return rtgs

def reinforce(
        num_epochs,
        env,
        hypernetwork,
        lr: float = 0.0001,
        gamma: float = 0.99,
    ):
    """Train ``hypernetwork`` with REINFORCE (vanilla policy gradient).

    Each epoch:
    1. plays one episode with ``rollout``;
    2. normalises its discounted rewards-to-go;
    3. recomputes the action log-probabilities with gradients;
    4. takes one Adam step on ``-sum(advantage * log_prob)``, with gradients
       clipped to norm 10.

    Loss, episode return and the sum of every parameter's gradient are logged to
    TensorBoard at step ``i``. The writer is closed when training ends or is
    interrupted.

    Args:
        num_epochs: Number of episodes (= optimiser steps).
        env: Environment passed to ``rollout``.
        hypernetwork: Policy hypernetwork; its parameters are updated in place.
        lr: Adam learning rate.
        gamma: Discount factor.
    """
    writer = get_tensorboard_logger("HypernetworkTorchRL")
    optimizer = optim.Adam(hypernetwork.parameters(), lr=lr)
    device = hypernetwork.device

    try:
        bar = tqdm.trange(num_epochs)
        for i in bar:
            observations, actions, rewards, _ = rollout(env, hypernetwork)
            episode_return = float(np.sum(rewards))

            discounted_rewards = discount_reward(np.array(rewards), gamma)
            discounted_rewards = discounted_rewards - np.mean(discounted_rewards)
            discounted_rewards = discounted_rewards / (
                np.std(discounted_rewards) + 1e-10
            )

            observations = torch.from_numpy(np.array(observations)).float().to(device)
            # Categorical.log_prob expects integer class indices.
            actions = torch.as_tensor(actions, dtype=torch.long, device=device)
            discounted_rewards = torch.from_numpy(discounted_rewards).float().to(device)

            # Re-run the whole episode in one batch, this time recording gradients.
            logits, _, _ = hypernetwork(observations)
            log_probs = Categorical(logits=logits).log_prob(actions)

            optimizer.zero_grad()
            loss = -torch.sum(discounted_rewards * log_probs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 10.0)
            optimizer.step()

            loss_value = loss.item()
            metrics = {"loss": loss_value, "rewards": episode_return}
            for name, param in hypernetwork.named_parameters():
                if param.grad is not None:
                    metrics["{}_grad".format(name)] = float(torch.sum(param.grad).item())

            for key, value in metrics.items():
                writer.add_scalar(key, value, i)

            bar.set_description('Loss: {}, Sum Reward: {}'.format(loss_value, episode_return))
    finally:
        writer.close()


In [None]:
# Single LunarLander environment, reused for every training episode below.
env = gym.make("LunarLander-v2")

In [None]:
# Train the static hypernetwork with REINFORCE: 100000 episodes, one Adam step per episode.
reinforce(100000, env, hypernetwork, lr=0.0001)

In [None]:
# Fresh environment for one rendered evaluation episode (frames are collected in `rendereds`).
env = gym.make("LunarLander-v2")
observations, actions, rewards, rendereds = rollout(env, hypernetwork, render=True)

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Interactive (widget) matplotlib backend provided by ipympl.
%matplotlib ipympl

plt.style.use('ggplot')
# Imports specifically so we can render outputs in Jupyter.
# NOTE(review): `display_animation`, `animation`, `display` and `HTML` are not used by any
# visible cell, and JSAnimation is presumably a legacy package — consider dropping them.
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display
# celluloid's Camera snapshots the figure once per frame and assembles the animation.
from celluloid import Camera
from IPython.display import HTML


def render_rollout(model, env=None, filename='animation.mp4'):
    """Play one rendered episode with ``model`` and save it as a video.

    Args:
        model: Hypernetwork passed to ``rollout``.
        env: Environment to play in. When omitted, a new ``LunarLander-v2``
            environment is created (the previous hard-coded behaviour).
        filename: Output path of the saved animation.

    Returns:
        The celluloid-built matplotlib animation.
    """
    if env is None:
        env = gym.make("LunarLander-v2")
    # Reuse the named "Animation" figure but clear it. Otherwise, re-running the cell
    # makes the camera snapshot artists left over from the previous call.
    fig, ax = plt.subplots(figsize=(7, 5), num="Animation", clear=True)
    ax.axis('off')
    camera = Camera(fig)
    _, _, _, rendereds = rollout(env, model, render=True)
    for frame in rendereds:
        ax.imshow(frame)
        camera.snap()
    # Named `anim` so the imported `matplotlib.animation` module is not shadowed.
    anim = camera.animate(blit=False, interval=50)
    anim.save(filename)
    return anim


In [None]:
# Record one episode of the trained static hypernetwork; this also writes animation.mp4.
vid = render_rollout(hypernetwork)

In [None]:
from IPython.display import Video

# Embed the video file saved by render_rollout.
Video("animation.mp4")

## Dynamic Hypernetwork

In [None]:

from typing import Optional, Any, Tuple
import functools
import torch.nn.functional as F

class DynamicTorchEmbeddingModule(TorchEmbeddingModule):
    """Input-conditioned embedding module with a recurrent gate.

    A ``GRUCell`` with ``num_embeddings`` hidden units consumes the flattened input
    on every call. The sigmoid of its output is one scalar gate per embedding,
    which scales the corresponding row of a learned embedding table.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int, input_shape):
        # Forward input_shape to the base class, as DefaultTorchEmbeddingModule does.
        super().__init__(embedding_dim, num_embeddings, input_shape)
        # One GRU hidden unit per embedding, so the hidden state doubles as the gate vector.
        self.rnn_hidden_dim = num_embeddings
        self.gru = nn.GRUCell(int(np.prod(input_shape)), num_embeddings)
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self, inp: torch.Tensor, hidden_state: Optional[torch.Tensor] = None):
        """Return ``(gated_embeddings, hidden_state)``.

        Args:
            inp: Tensor fed to the GRU cell, presumably ``(1, prod(input_shape))``.
            hidden_state: Previous hidden state, or ``None`` to start from zeros.

        Returns:
            ``gated_embeddings`` of shape ``(num_embeddings, embedding_dim)`` and the
            new hidden state.

        Only a batch size of 1 works: the hidden state is reshaped to one gate per
        embedding row.
        """
        x = inp
        if hidden_state is None:
            hidden_state = torch.zeros(x.size(0), self.rnn_hidden_dim, device=self.device)
        # NOTE(review): the sigmoid-squashed value (not the raw GRU output) is both the
        # gate and the state fed back next step, so the recurrent state stays in (0, 1).
        hidden_state = torch.sigmoid(self.gru(x, hidden_state))
        indices = torch.arange(self.num_embeddings, device=self.device)
        embedding = self.embedding(indices) * hidden_state.view(self.num_embeddings, 1)
        return embedding, hidden_state

class DynamicTorchWeightGenerator(TorchWeightGenerator):
    """Two-layer MLP mapping each (gated) embedding row to ``hidden_dim`` outputs.

    ``forward`` receives a tuple whose first element is the embedding matrix
    (presumably the ``(embedding, hidden_state)`` pair of the dynamic embedding
    module). Only that first element is used, and the result is flattened to 1-D.
    """

    # NOTE(review): this signature orders (embedding_dim, hidden_dim, num_embeddings),
    # unlike DeepTorchWeightGenerator (embedding_dim, num_embeddings, hidden_dim), and
    # forwards them to the base class positionally in this order. Confirm the order
    # TorchWeightGenerator / TorchHyperNetwork expect.
    def __init__(self, embedding_dim: int, hidden_dim: int, num_embeddings: int, input_shape: Optional[Any] = None):
        super().__init__(embedding_dim, hidden_dim, num_embeddings, input_shape)
        self.linear1 = nn.Linear(embedding_dim, 32)
        self.linear2 = nn.Linear(32, hidden_dim)

    def forward(
        self, embedding: Tuple[torch.Tensor, torch.Tensor], inp: Optional[Any] = None
    ) -> torch.Tensor:
        gated_rows = embedding[0]
        return self.linear2(F.relu(self.linear1(gated_rows))).view(-1)


In [None]:
# Rebuild the same target architecture (8 -> 256 -> 256 -> 4, no biases) for the dynamic
# experiment. This rebinds the global `target_network`; the static `hypernetwork` was
# given the earlier instance.
target_network = nn.Sequential(
    nn.Linear(8, 256, bias=False),
    nn.Tanh(),
    nn.Linear(256,256, bias=False),
    nn.Tanh(),
    nn.Linear(256, 4, bias=False)
)

# Quick construction check with embedding_dim=8. The network is rebuilt with
# embedding_dim=32 before training in a later cell.
dynamic_hypernetwork = TorchHyperNetwork(
                            (1,8),
                            target_network,
                            embedding_module_constructor=DynamicTorchEmbeddingModule,
                            weight_generator_constructor=DynamicTorchWeightGenerator,
                            embedding_dim = 8,
                            num_embeddings = 512
                        )

In [None]:
# Smoke test: one forward pass on a zero observation with no prior hidden state.
# Presumably returns (target-network output, generated parameters, embedding-module output).
out, parameters, embedding_output = dynamic_hypernetwork(torch.zeros(1,8), embedding_kwargs={"hidden_state":None})

In [None]:
def dynamic_rollout(env, dynamic_hypernetwork, render=False) -> tuple:
    """Play one episode with the recurrent (dynamic) hypernetwork policy.

    The hypernetwork is called on every observation, and the embedding module's
    hidden state (``embedding_output[1]``) is threaded from step to step, starting
    from ``None``. No gradients are recorded.

    Args:
        env: Environment with the pre-0.26 gym API (``reset() -> obs``,
            ``step(a) -> (obs, reward, done, info)``, ``render(mode=...)``).
        dynamic_hypernetwork: Object exposing ``device`` and
            ``__call__(x, embedding_kwargs=...) -> (logits, _, embedding_output)``.
        render: If True, capture an RGB frame before every step.

    Returns:
        ``(observations, actions, rewards, rendereds)`` lists with one entry per
        step; ``rendereds`` is empty unless ``render`` is True.

    The environment is closed when the function exits, even on error.
    """
    observations, actions, rewards, rendereds = [], [], [], []
    try:
        with torch.no_grad():
            obs = env.reset()
            done = False
            hidden = None
            while not done:
                if render:
                    rendereds.append(env.render(mode="rgb_array"))

                # Cast to float32 (as training does) so float64 observations also work.
                obs_tensor = torch.as_tensor(
                    obs, dtype=torch.float32, device=dynamic_hypernetwork.device
                ).unsqueeze(0)
                action_logits, _, embedding_output = dynamic_hypernetwork(
                    obs_tensor, embedding_kwargs={"hidden_state": hidden}
                )
                hidden = embedding_output[1]
                action = Categorical(logits=action_logits).sample().item()
                next_obs, r, done, _ = env.step(action)

                observations.append(obs)
                actions.append(action)
                rewards.append(r)

                obs = next_obs
    finally:
        env.close()
    return observations, actions, rewards, rendereds

def discount_reward(rews, gamma: float = 0.99):
    """Return the discounted reward-to-go for every step of an episode.

    ``out[i] = rews[i] + gamma * out[i + 1]`` with ``out[n] = 0``.

    NOTE(review): this redefines ``discount_reward`` from an earlier cell; keep the
    two definitions in sync or remove one.

    Args:
        rews: 1-D sequence or array of per-step rewards.
        gamma: Discount factor.

    Returns:
        A NumPy array with the same shape as ``rews``. Floating-point inputs keep
        their dtype; integer (or other non-float) inputs are promoted to float64.
    """
    rews = np.asarray(rews)
    n = len(rews)
    # np.zeros_like(rews) would inherit an integer dtype from integer rewards and
    # silently truncate every discounted sum.
    out_dtype = rews.dtype if np.issubdtype(rews.dtype, np.floating) else np.float64
    rtgs = np.zeros(rews.shape, dtype=out_dtype)
    for i in reversed(range(n)):
        future = rtgs[i + 1] if i + 1 < n else 0
        rtgs[i] = rews[i] + gamma * future
    return rtgs

def dynamic_reinforce(
        num_epochs,
        env,
        hypernetwork,
        lr: float = 0.0001,
        gamma: float = 0.99,
    ):
    """Train a recurrent (dynamic) hypernetwork with REINFORCE.

    Each epoch:
    1. plays one episode with ``dynamic_rollout`` and normalises its discounted
       rewards-to-go;
    2. replays the episode step by step with gradients, threading the embedding
       module's hidden state, to get the log-probability of each action taken;
    3. takes one Adam step on ``-sum(advantage * log_prob)``, with gradients
       clipped to norm 10.

    Loss, episode return and the sum of every parameter's gradient are logged to
    TensorBoard at step ``i``. The writer is closed when training ends or is
    interrupted.

    Args:
        num_epochs: Number of episodes (= optimiser steps).
        env: Environment passed to ``dynamic_rollout``.
        hypernetwork: Dynamic hypernetwork; its parameters are updated in place.
        lr: Adam learning rate.
        gamma: Discount factor.
    """
    writer = get_tensorboard_logger("HypernetworkTorchRL")
    optimizer = optim.Adam(hypernetwork.parameters(), lr=lr)
    device = hypernetwork.device

    try:
        bar = tqdm.trange(num_epochs)
        for i in bar:
            observations, actions, rewards, _ = dynamic_rollout(env, hypernetwork)
            episode_return = float(np.sum(rewards))

            discounted_rewards = discount_reward(np.array(rewards), gamma)
            discounted_rewards = discounted_rewards - np.mean(discounted_rewards)
            discounted_rewards = discounted_rewards / (
                np.std(discounted_rewards) + 1e-10
            )

            observations = torch.from_numpy(np.array(observations)).float().to(device)
            # Categorical.log_prob expects integer class indices.
            actions = torch.as_tensor(actions, dtype=torch.long, device=device)
            discounted_rewards = torch.from_numpy(discounted_rewards).float().to(device)

            log_probs = []
            hidden = None
            for j in range(observations.size(0)):
                logits, _, embedding_output = hypernetwork(
                    observations[j:j + 1], embedding_kwargs={"hidden_state": hidden}
                )
                hidden = embedding_output[1]
                # Score only the action actually taken at step j. Scoring the whole
                # `actions` vector here would broadcast to a (T, T) matrix and
                # corrupt the loss.
                log_probs.append(Categorical(logits=logits).log_prob(actions[j:j + 1]))
            # (T,) — cat keeps the time dimension even for one-step episodes.
            log_probs = torch.cat(log_probs)

            optimizer.zero_grad()
            loss = -torch.sum(discounted_rewards * log_probs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 10.0)
            optimizer.step()

            loss_value = loss.item()
            metrics = {"loss": loss_value, "rewards": episode_return}
            for name, param in hypernetwork.named_parameters():
                if param.grad is not None:
                    metrics["{}_grad".format(name)] = float(torch.sum(param.grad).item())

            for key, value in metrics.items():
                writer.add_scalar(key, value, i)

            bar.set_description('Loss: {}, Sum Reward: {}'.format(loss_value, episode_return))
    finally:
        writer.close()


In [None]:
# Fresh dynamic hypernetwork for training (embedding_dim=32, 512 embeddings).
dynamic_hypernetwork = TorchHyperNetwork(
                            (1,8),
                            target_network,
                            embedding_module_constructor=DynamicTorchEmbeddingModule,
                            weight_generator_constructor=DynamicTorchWeightGenerator,
                            embedding_dim = 32,
                            num_embeddings = 512
                        )

# Use the GPU when present, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dynamic_hypernetwork = dynamic_hypernetwork.to(device)

In [None]:
env = gym.make("LunarLander-v2")

# Train the dynamic hypernetwork: 100000 episodes, one Adam step per episode.
dynamic_reinforce(100000, env, dynamic_hypernetwork, lr=0.00002)

In [None]:
# Generate a parameter vector from an all-zero observation and no prior hidden state.
# The input is created on the model's device, since the model was moved with `.to(device)`.
# NOTE(review): the embedding module's forward uses its input as a tensor (`x.size(0)`);
# confirm that generate_params unpacks this (inp, hidden) tuple as intended.
param_vector, embeddings = dynamic_hypernetwork.generate_params(
    (torch.zeros(1, 8, device=dynamic_hypernetwork.device), None)
)

In [None]:
# Split the flat generated vector back into per-layer arrays, in the parameter order of
# the network that generated it (the dynamic hypernetwork's own target).
params = []
start = 0
for name, p in dynamic_hypernetwork._target.named_params:
    end = start + p.numel()
    params.append(param_vector[start:end].view(p.size()).detach().cpu().numpy())
    start = end

In [None]:
# Shapes of the reconstructed per-layer parameter arrays.
[p.shape for p in params]

In [None]:
# Heatmaps of the first two generated weight arrays, titled so the figure stands alone.
f, (a0, a1) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 1]})
a0.imshow(params[0])
a0.set_title("params[0] {}".format(params[0].shape))
a1.imshow(params[1])
a1.set_title("params[1] {}".format(params[1].shape))
f.tight_layout()


In [None]:
import matplotlib.pyplot as plt

def plot_params(params):
    """Draw each array in ``params`` as a heatmap, side by side in a single row.

    1-D arrays are shown as a single column.
    """
    fig = plt.figure(figsize=(12, 12))
    count = len(params)
    for idx, arr in enumerate(params, start=1):
        fig.add_subplot(1, count, idx)
        # imshow needs 2-D input, so turn a vector into a column.
        plt.imshow(np.expand_dims(arr, 1) if len(arr.shape) == 1 else arr)
    fig.tight_layout()
    plt.show()

plot_params(params)


In [None]:
import matplotlib.pyplot as plt
# The first two generated weight arrays on a larger canvas, via explicit Axes objects.
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 12))
ax0.imshow(params[0])
ax0.set_title("params[0]")
ax1.imshow(params[1])
ax1.set_title("params[1]")
# tight_layout() recomputes subplot spacing (including wspace/hspace), so a preceding
# subplots_adjust(wspace=0, hspace=0) would have no visible effect and is omitted.
fig.tight_layout()
plt.show()

