In [1]:
# Uncomment this to install dependencies for this notebook
# !pip install gym
# !pip install tqdm
# !pip install tensorboard

In [1]:
class Test:
    """Minimal two-attribute container used to try out ``functools.partial``."""

    def __init__(self, x, y=2):
        # Store both constructor arguments as attributes (x first, then y).
        self.x, self.y = x, y

In [2]:
import functools

In [7]:
# functools.partial pre-binds x=2; calling the partial with no further
# arguments is equivalent to Test(x=2), so `a.y` keeps its default of 2.
a = functools.partial(Test, x=2)()

In [9]:
# Display the value bound through the partial (expected: 2).
a.x

2

In [2]:
import gym
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
import tqdm
import torch.optim as optim


In [3]:
from hypernn.torch.hypernet import TorchHyperNetwork
from hypernn.torch.weight_generator import TorchWeightGenerator
from hypernn.torch.embedding_module import TorchEmbeddingModule

## Basic Hypernetwork

### 3 Components:
- EmbeddingModule: Layer-level or general parameter embeddings that are passed into the WeightGenerator
- WeightGenerator: Shared network that generates a chunk of parameters from each embedding in the EmbeddingModule
- Hypernetwork: Combines the EmbeddingModule and WeightGenerator into an end-to-end generator of the target network's parameters

### StaticEmbeddingModule

In [4]:
class StaticEmbeddingModule(TorchEmbeddingModule):
    """Embedding module that always returns the same set of learned embeddings.

    Holds ``num_embeddings`` learnable vectors of size ``embedding_dim`` in an
    ``nn.Embedding`` table. ``forward`` takes no input and returns the whole
    table, so the weight generator sees the same embeddings on every call.

    Args:
        embedding_dim: Size of each embedding vector.
        num_embeddings: Number of embedding vectors stored and returned.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__(embedding_dim, num_embeddings)
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self) -> torch.Tensor:
        """Return every embedding, shape ``(num_embeddings, embedding_dim)``."""
        # Build the index tensor on the same device as the embedding table.
        # A bare torch.arange lives on the CPU and would fail with a device
        # mismatch once the module has been moved to a GPU.
        indices = torch.arange(
            self.num_embeddings, device=self.embedding.weight.device
        )
        return self.embedding(indices)


### Static Weight Generator

In [5]:
class StaticWeightGenerator(TorchWeightGenerator):
    """Shared MLP mapping each embedding to a flat chunk of generated parameters.

    Args:
        embedding_dim: Size of each input embedding.
        hidden_dim: Number of parameter values produced per embedding.
        intermediate_dim: Width of the MLP's hidden layer. Defaults to 32, the
            value that was previously hard-coded, so existing parameter counts
            are unchanged.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int, intermediate_dim: int = 32):
        super().__init__(embedding_dim, hidden_dim)
        self.embedding_network = nn.Sequential(
            nn.Linear(embedding_dim, intermediate_dim),
            nn.ReLU(),
            nn.Linear(intermediate_dim, hidden_dim),
        )

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """Generate parameters from ``embedding`` and flatten them to 1-D.

        For an input of shape ``(..., embedding_dim)`` the MLP produces
        ``(..., hidden_dim)``. The result is flattened into a single vector.
        """
        return self.embedding_network(embedding).view(-1)


### Making a LunarLander HyperNetwork

#### Big Target Network

In [6]:
target_network = nn.Sequential(
    # 8 inputs -> two 256-unit tanh hidden layers -> 4 outputs (no final bias)
    nn.Linear(8, 256), nn.Tanh(),
    nn.Linear(256, 256), nn.Tanh(),
    nn.Linear(256, 4, bias=False),
)
# Number of trainable parameters in the target network.
pytorch_total_params = sum(
    torch.numel(param)
    for param in target_network.parameters()
    if param.requires_grad
)
pytorch_total_params

69120

#### Much smaller hypernetwork with heavy parameter sharing (21,895 trainable parameters vs. 69,120 in the target network)

In [7]:
# Wrap the target network so its weights are produced from 512 shared
# 32-dimensional embeddings by a single weight generator.
hypernetwork = TorchHyperNetwork(
    target_network,
    embedding_module_constructor=StaticEmbeddingModule,
    weight_generator_constructor=StaticWeightGenerator,
    embedding_dim=32,
    num_embeddings=512,
)
# Trainable parameters owned by the hypernetwork itself.
pytorch_total_params = sum(
    torch.numel(param)
    for param in hypernetwork.parameters()
    if param.requires_grad
)
pytorch_total_params

21895

In [8]:
def rollout(env, hypernetwork, render=False) -> tuple:
    """Run one episode with actions sampled from the hypernetwork's policy.

    Parameters are generated once, up front, and reused for every step of the
    episode. No gradients are tracked.

    Args:
        env: Gym-style environment using the old API (``reset()`` returns the
            observation; ``step()`` returns ``(obs, reward, done, info)``).
            It is closed before this function returns, even on error.
        hypernetwork: Network exposing ``generate_params()``, ``device`` and
            ``__call__(obs_batch, params=...) -> (logits, _)``.
        render: If True, store an ``rgb_array`` frame before each step.

    Returns:
        ``(observations, actions, rewards, rendereds)``: per-step lists of the
        observation acted on, the sampled action (int), the reward, and the
        rendered frame (``rendereds`` is empty when ``render`` is False).
    """
    observations, actions, rewards, rendereds = [], [], [], []
    try:
        with torch.no_grad():
            params = hypernetwork.generate_params()
            obs = env.reset()
            done = False
            while not done:
                if render:
                    rendereds.append(env.render(mode="rgb_array"))

                # Same dtype/device handling as the training step in
                # `reinforce`, so rollouts also work when the hypernetwork
                # lives on a GPU or the env emits float64 observations.
                obs_tensor = (
                    torch.from_numpy(np.asarray(obs))
                    .float()
                    .unsqueeze(0)
                    .to(hypernetwork.device)
                )
                action_logits, _ = hypernetwork(obs_tensor, params=params)
                action = Categorical(logits=action_logits).sample().item()
                next_obs, r, done, _ = env.step(action)

                observations.append(obs)
                actions.append(action)
                rewards.append(r)
                obs = next_obs
    finally:
        env.close()
    return observations, actions, rewards, rendereds

#### Tensorboard logging

In [9]:
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    """Create a TensorBoard ``SummaryWriter`` in a fresh, timestamped run directory.

    The run directory is ``<base_log_path>/<experiment_name>_<timestamp>``,
    relative to the current working directory unless ``base_log_path`` is
    absolute. A command for following the logs is printed.

    Args:
        experiment_name: Prefix for the run directory name.
        base_log_path: Parent directory holding all runs.

    Returns:
        A ``SummaryWriter`` created with ``flush_secs=10``.
    """
    # Filesystem-safe timestamp. str(datetime.now()) contains ':' (invalid in
    # Windows paths) and a space (awkward in shell commands).
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_%f")
    log_path = os.path.join(base_log_path, "{}_{}".format(experiment_name, timestamp))
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.abspath(log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [10]:
def discount_reward(rews, gamma: float = 0.99):
    """Compute discounted rewards-to-go for a single episode.

    ``rtgs[i] = rews[i] + gamma * rtgs[i + 1]``, with the value after the
    final step taken as 0.

    Args:
        rews: Sequence (list or array) of per-step rewards.
        gamma: Discount factor.

    Returns:
        NumPy array with the same shape as ``rews``. Integer or boolean rewards
        are promoted to floating point so the discounting is not truncated.
        Floating-point inputs keep their dtype.
    """
    rews = np.asarray(rews)
    n = len(rews)
    # np.zeros_like(int_array) would be an int array and silently truncate
    # every discounted value, so promote to at least float32.
    rtgs = np.zeros(rews.shape, dtype=np.promote_types(rews.dtype, np.float32))
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + gamma * (rtgs[i + 1] if i + 1 < n else 0)
    return rtgs

def reinforce(
        num_epochs,
        env,
        hypernetwork,
        lr: float = 0.0001,
        gamma: float = 0.99,
        experiment_name: str = "HypernetworkTorchRL",
    ):
    """Train ``hypernetwork`` with the REINFORCE policy-gradient loss.

    Each epoch collects one episode with ``rollout`` and standardises the
    discounted rewards-to-go. It then takes one Adam step on
    ``-sum(returns * log_prob(actions))``. The loss, the episode's summed
    reward and the per-parameter gradient sums are logged to TensorBoard. The
    writer is closed even if training is interrupted.

    Args:
        num_epochs: Number of episodes / optimiser steps.
        env: Environment passed to ``rollout``.
        hypernetwork: Policy whose ``__call__`` returns ``(logits, _)`` and
            which exposes a ``device`` attribute.
        lr: Adam learning rate.
        gamma: Discount factor for rewards-to-go.
        experiment_name: Prefix for the TensorBoard run directory.
    """
    writer = get_tensorboard_logger(experiment_name)
    optimizer = optim.Adam(hypernetwork.parameters(), lr=lr)
    device = hypernetwork.device

    try:
        bar = tqdm.tqdm(range(num_epochs))
        for i in bar:
            observations, actions, rewards, _ = rollout(env, hypernetwork)
            episode_reward = float(np.sum(rewards))

            # Standardise rewards-to-go (zero mean, unit std); the epsilon
            # guards against a zero std, e.g. for single-step episodes.
            discounted_rewards = discount_reward(np.array(rewards), gamma)
            discounted_rewards = discounted_rewards - np.mean(discounted_rewards)
            discounted_rewards = discounted_rewards / (np.std(discounted_rewards) + 1e-10)

            observations = torch.from_numpy(np.array(observations)).float().to(device)
            # Categorical.log_prob expects integer class indices.
            actions = torch.from_numpy(np.array(actions)).long().to(device)
            discounted_rewards = torch.from_numpy(discounted_rewards).float().to(device)

            # Re-evaluate the policy with gradients enabled (rollout runs under no_grad).
            logits, _ = hypernetwork(observations)
            log_probs = Categorical(logits=logits).log_prob(actions)
            loss = -torch.sum(discounted_rewards * log_probs)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 10.0)
            optimizer.step()

            loss_value = loss.item()
            metrics = {"loss": loss_value, "rewards": episode_reward}
            for name, param in hypernetwork.named_parameters():
                if param.grad is not None:
                    metrics["{}_grad".format(name)] = float(torch.sum(param.grad).item())

            for key, value in metrics.items():
                writer.add_scalar(key, value, i)

            bar.set_description('Loss: {}, Sum Reward: {}'.format(loss_value, episode_reward))
    finally:
        # Flush buffered events and release the writer's files, including when
        # the run is stopped with KeyboardInterrupt.
        writer.close()


In [11]:
# Training environment; its observations feed the target network's
# nn.Linear(8, ...) input layer.
env = gym.make("LunarLander-v2")

In [12]:
# Long run: at ~3 episodes/s, 100k episodes takes many hours. The recorded run
# below was stopped with KeyboardInterrupt after ~4.3k episodes.
reinforce(100000, env, hypernetwork)

Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/tensorboard_logs/HypernetworkTorchRL_2022-03-11 15:15:45.823610'


Loss: 11.215566635131836, Sum Reward: 48.52306111343759:   4%|▍         | 4314/100000 [22:15<8:13:40,  3.23it/s]     


KeyboardInterrupt: 

In [15]:
# `rollout` closes the environment it is given, so create a fresh one before
# rendering an episode. The call returns (observations, actions, rewards, frames).
env = gym.make("LunarLander-v2")
rollout(env, hypernetwork, render=True)

([array([ 0.00166998,  1.4086931 ,  0.1691371 , -0.09898422, -0.00192832,
         -0.03831208,  0.        ,  0.        ], dtype=float32),
  array([ 3.2436370e-03,  1.4058961e+00,  1.5683393e-01, -1.2431073e-01,
         -1.3996738e-03,  1.0574637e-02,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([ 4.7286032e-03,  1.4024968e+00,  1.4569941e-01, -1.5107898e-01,
          1.3612686e-03,  5.5223711e-02,  0.0000000e+00,  0.0000000e+00],
        dtype=float32),
  array([ 0.00612841,  1.3984997 ,  0.13503139, -0.17765811,  0.00625875,
          0.0979587 ,  0.        ,  0.        ], dtype=float32),
  array([ 0.00743961,  1.3938966 ,  0.12391303, -0.20463419,  0.01338215,
          0.14248095,  0.        ,  0.        ], dtype=float32),
  array([ 0.00865555,  1.3886846 ,  0.1119798 , -0.23176104,  0.02289451,
          0.19026475,  0.        ,  0.        ], dtype=float32),
  array([ 0.00994558,  1.382868  ,  0.12128022, -0.25865215,  0.0305368 ,
          0.15285991,  0.   