In [1]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import tqdm
import torch.optim as optim


In [2]:
from hypernn.torch.static_hypernetwork import StaticTorchHyperNetwork

In [3]:
target_network = nn.Sequential(
    nn.Linear(8, 256, bias=False),
    nn.Tanh(),
    nn.Linear(256,256, bias=False),
    nn.Tanh(),
    nn.Linear(256, 4, bias=False)
)
pytorch_total_params = sum(p.numel() for p in target_network.parameters() if p.requires_grad)
pytorch_total_params

68608

In [25]:
from typing import Any, Dict, List, Optional, Tuple, Type, Union  # noqa

class LunarLanderHypernetwork(StaticTorchHyperNetwork):
    def __init__(
        self,
        target_network: nn.Module,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        hidden_dim: Optional[int] = None,
    ):
        super().__init__(
                    target_network = target_network,
                    embedding_dim = embedding_dim,
                    num_embeddings = num_embeddings,
                    hidden_dim = hidden_dim
                )

    def make_weight_generator(self):
        return nn.Sequential(
            nn.Linear(self.embedding_dim, 32),
            nn.Tanh(),
            nn.Linear(32, self.hidden_dim)
        )


In [27]:
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 512

hypernetwork = LunarLanderHypernetwork.from_target(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS
)

TypeError: __init__() got an unexpected keyword argument 'target_input_shape'

In [8]:
generated_params, out = hypernetwork.generate_params()

In [10]:
pytorch_total_params = sum(p.numel() for p in hypernetwork.parameters() if p.requires_grad)
pytorch_total_params

2718

In [12]:
out, generated_params, extra = hypernetwork([torch.zeros((1, 8))], has_aux=True)

In [15]:
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    log_path = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.join(os.getcwd(), log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [22]:
import gym

def rollout(env, hypernetwork, render=False) -> float:
    with torch.no_grad():
        params, _ = hypernetwork.generate_params()
        obs = env.reset()
        done = False
        observations, actions, rewards, rendereds = [], [], [], []
        while not done:
            rendered = None
            if render:
                rendered = env.render(mode="rgb_array")
                rendereds.append(rendered)

            action_logits = hypernetwork(inp=[torch.from_numpy(obs).unsqueeze(0).to(hypernetwork.device)], generated_params=params, has_aux=False)
            dist = Categorical(logits=action_logits)
            action = dist.sample().item()
            next_obs, r, done, _ = env.step(action)

            observations.append(obs)
            actions.append(action)
            rewards.append(r)

            obs = next_obs

    env.close()
    return observations, actions, rewards, rendereds

def discount_reward(rews, gamma: float = 0.99):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + gamma*(rtgs[i + 1] if i + 1 < n else 0)
    return rtgs

def reinforce(
        num_epochs,
        env,
        hypernetwork,
        rollout_fn=rollout,
        lr: float = 0.0001,
        gamma: float = 0.99,
    ):

    writer = get_tensorboard_logger("HypernetworkTorchRL")
    optimizer = optim.Adam(hypernetwork.parameters(), lr=lr)

    bar = tqdm.tqdm(np.arange(num_epochs))
    for i in bar:
        observations, actions, rewards, _ = rollout_fn(env, hypernetwork)

        discounted_rewards = discount_reward(np.array(rewards), gamma)
        discounted_rewards = discounted_rewards - np.mean(discounted_rewards)
        discounted_rewards = discounted_rewards / (
            np.std(discounted_rewards) + 1e-10
        )

        observations = torch.from_numpy(np.array(observations)).float().to(hypernetwork.device)
        actions = torch.from_numpy(np.array(actions)).float().to(hypernetwork.device)
        discounted_rewards = torch.from_numpy(discounted_rewards).float().to(hypernetwork.device)

        logits, generated_params, aux_output = hypernetwork(inp=[observations], has_aux=True)
        dist = Categorical(logits=logits)
        log_probs = dist.log_prob(actions)

        optimizer.zero_grad()
        loss = -1 * torch.sum(discounted_rewards * log_probs)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 10.0)
        optimizer.step()

        grad_dict = {}
        for n, W in hypernetwork.named_parameters():
            if W.grad is not None:
                grad_dict["{}_grad".format(n)] = float(torch.sum(W.grad).item())

        metrics = {"loss":loss.item(), "rewards":np.sum(rewards), **grad_dict}

        for key in metrics:
            writer.add_scalar(key, metrics[key], i)

        bar.set_description('Loss: {}, Sum Reward: {}'.format(loss.item(), np.sum(rewards)))


In [23]:
env = gym.make("LunarLander-v2")

In [24]:
reinforce(100000, env, hypernetwork, lr=0.00002)

Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/tensorboard_logs/HypernetworkTorchRL_2022-04-25 21:53:01.709378'


Loss: -0.9996805191040039, Sum Reward: -17.84515538135497:  20%|█▉        | 19814/100000 [24:47<1:40:19, 13.32it/s]   


KeyboardInterrupt: 