# Meta-gradient learning with SGD

<img src="img/alg.png" alt="algorithm" style="width:800px;"/>

Optimize gamma with plain SGD (one manual gradient step on gamma per iteration; no Adam optimizer is used).

In [1]:
import gym
import numpy as np
import random
import time
import higher
import torch
from torch.distributions import Categorical
import torch.nn.functional as F

from models.models import ActorCritic


In [131]:
# --- Agent training schedule ---
ITERATION_NUMS = 20000  # number of outer (RL update + meta update) iterations
SAMPLE_NUMS = 50        # environment steps collected per rollout
LR = 0.01               # agent learning rate (RMSprop, linearly decayed to 0)
CLIP_GRAD_NORM = 40     # max gradient norm for the agent update

# --- Returns and meta-learning ---
LAMBDA = torch.tensor(0.98)  # GAE lambda
BETA = 0.001                 # step size of the SGD update on gamma
GAMMA_INIT = 0.99            # initial value of the learnable discount gamma
GAMMA_FIX = 1.0              # fixed discount used by the meta-objective

## Utility Function
### Rollout function & testing function

In [3]:
def roll_out(agent, task, sample_nums, init_state):
    """Collect ``sample_nums`` environment steps with the current policy.

    The environment is reset whenever an episode ends, so a single call may
    span several episodes.

    Args:
        agent: callable mapping a state tensor to ``(logits, value)``.
        task: environment using the 4-tuple ``step`` API
            ``(next_state, reward, done, info)``.
        sample_nums: number of steps to collect.
        init_state: observation to start sampling from.

    Returns:
        ``(states, actions, rewards, vts, vt1s, dones, state)`` where ``vts`` /
        ``vt1s`` hold numpy value estimates for s_t / s_{t+1}, ``dones`` uses
        1 for "not done" and 0 for "done", and ``state`` is the observation to
        resume from on the next call.
    """
    states, actions, rewards = [], [], []
    vts = []   # v-values at timestep t
    vt1s = []  # v-values at timestep t+1 (masked by ``dones`` in the GAE)
    dones = []

    state = init_state
    for _ in range(sample_nums):
        states.append(state)
        act, vt = choose_action(agent, state)
        actions.append(act)

        next_state, reward, done, _ = task.step(act.numpy())
        with torch.no_grad():
            _, vt1 = agent(torch.Tensor(next_state))
        rewards.append(reward)
        # Truthiness rather than ``done is False``: an environment returning a
        # numpy bool (np.bool_) would otherwise always be recorded as "done".
        dones.append(0 if done else 1)
        vts.append(vt.detach().numpy())
        vt1s.append(vt1.detach().numpy())

        state = task.reset() if done else next_state

    return states, actions, rewards, vts, vt1s, dones, state


def test(gym_name, agent, episodes=10, max_steps=500):
    """Evaluate ``agent`` on a freshly created ``gym_name`` environment.

    Args:
        gym_name: environment id passed to ``gym.make``.
        agent: callable used through ``choose_action``.
        episodes: number of evaluation episodes (default 10).
        max_steps: per-episode step cap (default 500).

    Returns:
        The summed reward over all episodes; divide by ``episodes`` to get
        the mean episode return.
    """
    result = 0
    test_task = gym.make(gym_name)
    try:
        for _episode in range(episodes):
            state = test_task.reset()
            for _step in range(max_steps):
                act, _ = choose_action(agent, state)
                state, reward, done, _ = test_task.step(act.numpy())
                result += reward
                if done:
                    break
    finally:
        # Release the environment's resources even if evaluation fails.
        test_task.close()
    return result

### Computational function
Functions to compute advantages and target v-values from a trajectory.

In [4]:
@torch.no_grad()
def choose_action(agent, state):
    """Sample one action from the agent's policy without tracking gradients.

    Returns the sampled action tensor and the value estimate the agent
    produced for ``state``.
    """
    obs = torch.Tensor(state)
    logits, value = agent(obs)
    policy = Categorical(F.softmax(logits, dim=-1))
    return policy.sample(), value


def gae_calculater_grad(rewards, v_t_s, v_t1_s, dones, gamma, lambda_):
    """Compute GAE advantages and value targets for one trajectory segment.

    The recursion stays differentiable with respect to ``gamma`` (and
    ``lambda_``) when they are tensors with ``requires_grad=True``; the
    meta-gradient step relies on this.

    Args:
        rewards: per-step rewards r_t.
        v_t_s: value estimates V(s_t), one per step (0-d or size-1 entries).
        v_t1_s: value estimates V(s_{t+1}), one per step.
        dones: 1 if the episode continues after step t, 0 if it ended there.
        gamma: discount factor (float or scalar tensor).
        lambda_: GAE lambda (float or scalar tensor).

    Returns:
        ``(advs, value_target)``: 1-D tensors of length ``len(rewards)`` with
        ``value_target = advs + V(s_t)``.
    """
    batch_size = len(rewards)
    if batch_size == 0:
        # Empty segment, e.g. the trailing piece after a terminal last step.
        return torch.zeros(0), torch.zeros(0)

    def _as_vector(xs):
        # Convert once to a float32 vector with exactly one value per step.
        return torch.as_tensor(np.asarray(xs, dtype=np.float32)).reshape(batch_size)

    r = _as_vector(rewards)
    v_t = _as_vector(v_t_s)
    v_t1 = _as_vector(v_t1_s)
    not_done = _as_vector(dones)

    reversed_advs = []
    running = torch.zeros(())  # GAE accumulator A_{t+1}; zero past the end
    for t in reversed(range(batch_size)):
        delta = r[t] - v_t[t] + gamma * v_t1[t] * not_done[t]
        running = delta + gamma * lambda_ * running * not_done[t]
        reversed_advs.append(running)
    # torch.stack keeps every step in the autograd graph (no in-place writes).
    advs = torch.stack(reversed_advs[::-1])
    value_target = advs + v_t  # target v is calculated from adv.

    return advs, value_target


def trajectory_cutter(rewards, vts, vt1s, dones):
    """Split a flat rollout into per-episode segments for return computation.

    Samples in the same segment were collected in the same episode. A segment
    is closed right after every step whose ``done`` flag is 0 ("not done" = 1,
    "done" = 0). Whatever follows the last such step forms a final segment.
    That segment is empty when the rollout ended exactly on a terminal step.

    Returns four lists of lists (rewards, vts, vt1s, dones), aligned by segment.
    """
    segments = []
    current = []
    for row in zip(rewards, vts, vt1s, dones):
        current.append(row)
        if row[3] == 0:
            segments.append(current)
            current = []
    segments.append(current)

    cut_rewards = [[r for r, _, _, _ in seg] for seg in segments]
    cut_vts = [[v for _, v, _, _ in seg] for seg in segments]
    cut_vt1s = [[v1 for _, _, v1, _ in seg] for seg in segments]
    cut_dones = [[d for _, _, _, d in seg] for seg in segments]
    return cut_rewards, cut_vts, cut_vt1s, cut_dones


def trajectory_gae(rewards, vts, vt1s, dones, gamma, lambda_):
    """Compute GAE advantages and value targets for a rollout.

    The rollout may span several episodes. It is first split at episode
    boundaries with ``trajectory_cutter``, so the GAE recursion never crosses
    an episode end. Per-segment results are then concatenated back in order
    and cast to float32.
    """
    segments = zip(*trajectory_cutter(rewards, vts, vt1s, dones))
    per_segment = [
        gae_calculater_grad(seg_r, seg_vt, seg_vt1, seg_d, gamma, lambda_)
        for seg_r, seg_vt, seg_vt1, seg_d in segments
    ]
    adv_parts = [adv for adv, _ in per_segment]
    target_parts = [target for _, target in per_segment]

    all_advs = torch.cat(adv_parts).float()
    all_targets = torch.cat(target_parts).float()
    return all_advs, all_targets

In [5]:
# Sanity check for trajectory_gae(): shapes are preserved and the advantages
# are differentiable with respect to gamma (needed by the meta-gradient).
# Uses the LAMBDA hyperparameter defined above instead of re-binding it, so an
# edited LAMBDA is not silently overwritten by this cell.
rewards = [1.0] * 10
dones = [1, 1, 1, 1, 1, 0, 1, 1, 1, 1]  # episode boundary after index 5
vts = [np.array(np.random.randn()) for _ in range(10)]
vt1s = [np.array(np.random.randn()) for _ in range(10)]
GAMMA = torch.tensor(0.99, requires_grad=True)

advs, v_targets = trajectory_gae(rewards, vts, vt1s, dones, GAMMA, LAMBDA)
assert advs.shape == (10,) and v_targets.shape == (10,), "test failed: unexpected shapes"

# Raises if advs does not depend on GAMMA (allow_unused is left off on purpose).
(grad,) = torch.autograd.grad(advs.mean(), GAMMA)
assert torch.isfinite(grad), "test failed: non-finite gradient w.r.t. gamma"
print("test passed")


### Agent training function
RL algorithm: A2C + GAE
Calculate grad for substitute agent

In [6]:
def get_logits(agent, states, actions, action_dim):
    """Evaluate the agent on a batch of states.

    Returns:
        probs: action probabilities, shape (N, action_dim).
        log_probs: action log-probabilities, shape (N, action_dim).
        log_probs_act: log-probability of each taken action, shape (N,).
        v: value estimates flattened to shape (N,).
    """
    batch = torch.Tensor(np.array(states))
    chosen = torch.tensor(actions, dtype=torch.int64).view(-1, 1)

    raw_logits, values = agent(batch)
    flat_logits = raw_logits.view(-1, action_dim)

    log_probs = F.log_softmax(flat_logits, dim=1)
    probs = F.softmax(flat_logits, dim=1)
    taken = torch.gather(log_probs, 1, chosen).squeeze(1)

    return probs, log_probs, taken, values.view(-1)


def get_meta_logits(agent, states, actions, action_dim):
    """Return log pi(a_t | s_t) for each taken action, shape (N,).

    The agent's value output is discarded; only the policy logits are used.
    """
    batch = torch.Tensor(np.array(states))
    index = torch.tensor(actions, dtype=torch.int64).view(-1, 1)

    flat_logits, _ = agent(batch)
    flat_logits = flat_logits.view(-1, action_dim)
    return F.log_softmax(flat_logits, dim=1).gather(1, index).view(-1)

## Running section

In [102]:
# Seed the torch, numpy and python `random` generators for reproducibility,
# then create the training environment.
random_seed = 123456789
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

gym_name = "CartPole-v1"
task = gym.make(gym_name)
# Seeds the training environment through the env.seed API.
# NOTE(review): newer gym versions seed via reset(seed=...) instead; confirm
# against the installed gym version. The evaluation envs created in test()
# are not seeded here.
task.seed(random_seed)

[123456789]

In [129]:
# Derive network input/output sizes from the environment's spaces.
discrete = isinstance(task.action_space, gym.spaces.Discrete)
STATE_DIM = task.observation_space.shape[0]
# NOTE(review): actions are always sampled from a Categorical (choose_action),
# so only the discrete branch is actually usable by the rest of the notebook.
ACTION_DIM = task.action_space.n if discrete else task.action_space.shape[0]

# Built after the seeding cell; presumably ActorCritic draws its initial
# weights from torch's global RNG, making them reproducible (TODO confirm).
agent = ActorCritic(STATE_DIM, ACTION_DIM)
optim = torch.optim.RMSprop(agent.parameters(), lr=LR)
# Linearly decays the agent learning rate from LR to 0 over ITERATION_NUMS steps.
scheduler = torch.optim.lr_scheduler.LinearLR(optim, start_factor=1.0, end_factor=0.0, total_iters=ITERATION_NUMS)
assert next(agent.parameters()).is_cuda is False  # use only cpu for the current version

# Learnable discount factor (the meta-parameter). requires_grad=True lets the
# GAE recursion in trajectory_gae be differentiated with respect to it.
gamma = torch.tensor(GAMMA_INIT, requires_grad=True)

Use SGD to update gamma

In [130]:
iterations = []
test_results = []
gamma_buffer = []  # gamma value recorded at every evaluation point

# Evaluate about 100 times over the run; max() avoids a zero modulus when
# ITERATION_NUMS < 100.
test_interval = max(1, ITERATION_NUMS // 100)

init_state = task.reset()
for i in range(ITERATION_NUMS):
    # Line 3-4  Sample trajectories for both RL and meta-grad learning
    states1, actions1, rewards1, vts1, vt1s1, dones1, current_state = roll_out(agent, task, SAMPLE_NUMS, init_state)
    init_state = current_state

    # ---------------- RL phase ----------------
    probs, log_probs, log_probs_act, v = get_logits(agent, states1, actions1, ACTION_DIM)
    # advs / v_targets stay differentiable w.r.t. gamma for the trace below.
    advs, v_targets = trajectory_gae(rewards1, vts1, vt1s1, dones1, gamma, LAMBDA)

    # Line 5  Loss computation. Targets are detached, so the agent update
    # itself does not propagate gradients into gamma.
    loss_policy = -(advs.detach() * log_probs_act).mean()
    loss_critic = F.mse_loss(v_targets.detach(), v, reduction='mean')
    loss_entropy = -(log_probs * probs).mean()

    loss = loss_policy + .25 * loss_critic - .001 * loss_entropy
    optim.zero_grad()

    # Trace z, approximating d(theta')/d(gamma) as
    #   LR * d(mean v_target)/d(gamma) * grad_theta(policy + critic loss).
    # NOTE(review): this is a product of two separate derivatives rather than
    # the mixed second derivative. It also uses the constant LR, not the
    # scheduled RMSprop step. Confirm the sign convention against the
    # meta-update below.
    (f1,) = torch.autograd.grad(v_targets.mean(), gamma, retain_graph=True)
    # retain_graph: loss.backward() below reuses this graph.
    f2 = torch.autograd.grad(loss_policy + .25 * loss_critic, agent.parameters(), retain_graph=True)
    f2 = torch.cat([item.view(-1) for item in f2])

    z = LR * f1 * f2

    # Last use of the RL-phase graph, so it does not need to be retained.
    loss.backward()
    torch.nn.utils.clip_grad_norm_(agent.parameters(), CLIP_GRAD_NORM)

    # Line 6  Obtain the updated agent
    optim.step()
    scheduler.step()
    optim.zero_grad()

    # ---------------- Meta-gradient phase ----------------
    states2, actions2, rewards2, vts2, vt1s2, dones2, current_state = roll_out(agent, task, SAMPLE_NUMS, init_state)
    init_state = current_state

    # Line 7-9  Update meta-parameter
    log_probs_act_dash = get_meta_logits(agent, states2, actions2, ACTION_DIM)
    # The meta-objective uses the fixed discount GAMMA_FIX, not the learned gamma.
    advs_dash, _ = trajectory_gae(rewards2, vts2, vt1s2, dones2, torch.tensor(GAMMA_FIX), LAMBDA)

    # Compute meta-gradient dJ'/d(theta') of the updated agent.
    meta_loss = (advs_dash.detach() * log_probs_act_dash).mean()
    J_dash = torch.autograd.grad(meta_loss, agent.parameters(), allow_unused=True)
    # Parameters that do not influence the policy logits (get_meta_logits
    # discards the value output) get a None gradient; treat those as zero.
    J_dash = torch.cat([
        (torch.zeros_like(params.data) if item is None else item).view(-1)
        for item, params in zip(J_dash, agent.parameters())
    ])

    # Chain rule dJ'/d(gamma) ~= dJ'/d(theta') . z, followed by a plain SGD step.
    delta_gamma = -BETA * (J_dash @ z)
    gamma.data += delta_gamma
    # Alternative (disabled): cap the step size at 0.1.
    # gamma.data += torch.sign(delta_gamma) * torch.min(torch.abs(delta_gamma), torch.tensor(0.1))
    # Keep gamma within [0.0001, 0.9999]. The original author noted that
    # problems may appear once this limit becomes active.
    gamma.data.clamp_(min=0.0001, max=0.9999)

    # testing
    if (i + 1) % test_interval == 0:
        result = test(gym_name, agent)
        print("iteration:", i + 1, "test result:", result / 10.0, "gamma:", gamma.data)
        iterations.append(i + 1)
        test_results.append(result / 10)
        gamma_buffer.append(gamma.item())

iteration: 200 test result: 9.8 gamma: tensor(0.7000)
iteration: 400 test result: 9.1 gamma: tensor(0.7000)
iteration: 600 test result: 264.5 gamma: tensor(0.6997)
iteration: 800 test result: 173.2 gamma: tensor(0.7017)
iteration: 1000 test result: 122.6 gamma: tensor(0.6965)
iteration: 1200 test result: 177.8 gamma: tensor(0.6962)
iteration: 1400 test result: 123.5 gamma: tensor(0.6975)
iteration: 1600 test result: 131.3 gamma: tensor(0.6956)
iteration: 1800 test result: 165.0 gamma: tensor(0.6947)
iteration: 2000 test result: 209.1 gamma: tensor(0.6956)
iteration: 2200 test result: 114.8 gamma: tensor(0.6965)
iteration: 2400 test result: 145.1 gamma: tensor(0.6960)
iteration: 2600 test result: 185.5 gamma: tensor(0.6915)
iteration: 2800 test result: 220.5 gamma: tensor(0.6922)
iteration: 3000 test result: 119.5 gamma: tensor(0.6924)
iteration: 3200 test result: 140.8 gamma: tensor(0.6899)
iteration: 3400 test result: 230.6 gamma: tensor(0.6881)
iteration: 3600 test result: 271.3 gamm