In [2]:
import numpy as np
import torch
import torch.nn as nn

import torchrl
import torchrl.network as network

# later incorporate this into torchrl
from rlutils.vec_env import SubprocVecEnv
from rlutils.utils import make_atari_env
from rlutils.vec_env import VecFrameStack

In [18]:
# Vectorised Breakout environment with one Atari instance per worker.
ENV_ID = "BreakoutNoFrameskip-v4"
NUM_ENVS = 4  # number of parallel workers; keep in sync with params["num_workers"]
env = make_atari_env(ENV_ID, NUM_ENVS)

In [19]:
# Stack the most recent frames as the network input.
# The guard keeps this cell idempotent: re-running it would otherwise wrap the
# already-stacked env again and stack the stacked frames a second time.
NUM_STACKED_FRAMES = 4  # presumably must equal params["num_in_channels"] — confirm
if not isinstance(env, VecFrameStack):
    env = VecFrameStack(env, NUM_STACKED_FRAMES)

In [23]:
# A2C hyper-parameters. Every key read by ActorCritic / collect_rollout must be
# set here; reading a missing key raises KeyError.
params = {
    "num_in_channels": 4,      # stacked frames fed to the CNN (see VecFrameStack above)
    "num_latent_nodes": 512,
    "num_actions": 4,          # Breakout's discrete action count
    "num_workers": 4,          # must match the number of envs created above
    # NOTE(review): 0.9 / 0.5 are lower than the commonly used 0.99 / 0.95 — confirm intended.
    "discount_gamma": 0.9,
    "gae_tau": 0.5,
    "use_gae": True,
    # The three values below match the OpenAI baselines A2C defaults
    # (ent_coef, vf_coef, max_grad_norm).
    "entropy_coef": 0.01,
    "value_coef": 0.5,
    "max_norm_grad": 0.5,
}
params

In [25]:
class ActorCritic(nn.Module):
    """Actor-critic network with a shared CNN trunk.

    `network.NatureCNN(params)` produces a feature vector that feeds two linear
    heads: `pf` (logits of a categorical policy over `num_actions` actions) and
    `vf` (a single state-value estimate per input row).

    Reads from `params`: "entropy_coef", "value_coef", "max_norm_grad" and
    "num_actions", plus whatever `network.NatureCNN` reads itself.
    """

    def __init__(self, params):
        super(ActorCritic, self).__init__()
        self.e_coef = params["entropy_coef"]
        self.v_coef = params["value_coef"]
        # Not used inside the module; stored so the training step can clip gradients.
        self.max_grad_norm = params["max_norm_grad"]
        self.n_actions = params["num_actions"]

        # cnn head
        # NOTE(review): both heads hard-code 512 input features; presumably this
        # must equal NatureCNN's output width / params["num_latent_nodes"] — confirm.
        self.cnn_head = network.NatureCNN(params)

        # policy function: logits -> categorical distribution over actions
        self.pf = nn.Linear(512, self.n_actions)
        self.pdist = torch.distributions.Categorical

        # value function: one scalar per input row
        self.vf = nn.Linear(512, 1)

        self.mse = nn.MSELoss()

    def forward(self, x):
        """Sample one action per input row from the current policy.

        Returns:
            (action, log_prob, entropy, value). `action` comes from
            `Categorical.sample()`. `log_prob` and `entropy` get a trailing
            singleton dim so they line up with `value`, which has shape
            (rows, 1) from the 1-unit value head.
        """
        features = self.cnn_head(x)
        logits = self.pf(features)
        value = self.vf(features)

        pd = self.pdist(logits=logits)
        action = pd.sample()
        log_prob = pd.log_prob(action).unsqueeze(-1)
        entropy = pd.entropy().unsqueeze(-1)

        return action, log_prob, entropy, value

    def value_func(self, x):
        """Return only the value-head estimate for `x` (no action sampling)."""
        return self.vf(self.cnn_head(x))

    def loss_func(self, rollout):
        """Combined A2C loss over a rollout.

        Args:
            rollout: sequence of per-step entries
                [log_probs, values, returns, advantages, entropys], each a
                tensor; matching fields are concatenated along dim 0.

        Returns:
            Scalar tensor:
            policy_loss - e_coef * mean(entropy) + v_coef * MSE(values, returns).
        """
        log_probs, values, returns, advantages, entropys = (
            torch.cat(field, dim=0) for field in zip(*rollout)
        )

        # advantages / returns are targets: gradients should flow only through
        # log_probs (actor) and values (critic)
        policy_loss = (-advantages.detach() * log_probs).mean()
        value_loss = self.mse(values, returns.detach())
        # reduce entropy to a scalar so the total loss is a scalar and
        # .backward() can be called without an explicit gradient argument
        entropy_bonus = entropys.mean()

        return policy_loss - entropy_bonus * self.e_coef + value_loss * self.v_coef

In [7]:
def v_wrap(x, dtype=np.float32):
    """Convert array-like `x` into a CPU torch tensor of the given numpy dtype.

    `np.array` always copies, so the tensor never aliases the caller's buffer.
    It also accepts Python scalars, lists, and `np.dtype` objects as well as
    numpy scalar types (e.g. `np.float32`) for `dtype`.
    """
    return torch.from_numpy(np.array(x, dtype=dtype))

In [24]:
# Notebook-level rollout bookkeeping and discounting settings used by
# collect_rollout below.
num_workers = params["num_workers"]
gamma, tau, use_gae = params["discount_gamma"], params["gae_tau"], params["use_gae"]

episode_rewards = []                    # returns of finished episodes
online_rewards = np.zeros(num_workers)  # running return of each worker's current episode

curr_states = env.reset()

# collect rollout used for training
def collect_rollout(model, env, num_rollout=20):
    """Step every worker `num_rollout` times, then compute returns/advantages.

    Uses the notebook-level state set up in this cell:
    `curr_states` is advanced to the last observation, so the next call
    continues where this one stopped. `online_rewards` accumulates each
    worker's running episode return. `episode_rewards` receives the return of
    every finished episode. Also reads `gamma`, `tau` and `use_gae`.

    Args:
        model: callable returning (actions, log_probs, entropys, values) and
            exposing `value_func(states)`, e.g. `ActorCritic`.
        env: vectorised env whose `step(actions)` returns
            (next_states, rewards, dones, infos).
        num_rollout: number of env steps collected per worker.

    Returns:
        One entry per collected step, newest first, of
        [log_probs, values, returns, advantages, entropys]: tensors with one
        row per worker, ready for `ActorCritic.loss_func`.
    """
    # Both names are reassigned/updated here but live at notebook level;
    # without `global` Python treats them as unbound locals.
    global curr_states, online_rewards

    rollout = []

    # predict and collect
    for _ in range(num_rollout):
        # NOTE(review): v_wrap casts to float32 but does not rescale or transpose;
        # confirm the input range / channel layout network.NatureCNN expects.
        actions, log_probs, entropys, values = model(v_wrap(curr_states))

        # the vectorised env works on numpy actions, not torch tensors
        next_states, rewards, dones, _ = env.step(actions.cpu().numpy())
        online_rewards += rewards
        for worker in np.flatnonzero(dones):
            episode_rewards.append(online_rewards[worker])
            online_rewards[worker] = 0

        # rewards / not-done masks become (num_workers, 1) float tensors so they
        # combine element-wise with `values` rather than broadcasting an (N,)
        # array against (N, 1) tensors into (N, N)
        rollout.append([
            log_probs,
            entropys,
            values,
            torch.as_tensor(rewards, dtype=torch.float32).unsqueeze(-1),
            torch.as_tensor(1.0 - np.asarray(dones, dtype=np.float32)).unsqueeze(-1),
        ])
        curr_states = next_states

    # calculate discounted returns and advantages, bootstrapping from the value
    # of the state reached after the final collected step
    last_values = model.value_func(v_wrap(curr_states)).detach()
    returns = last_values
    next_values = last_values
    advantages = torch.zeros_like(last_values)
    discounted_rollout = []

    # walk over *all* collected steps, newest first
    for log_probs, entropys, values, rewards, not_dones in reversed(rollout):
        returns = rewards + gamma * not_dones * returns

        if not use_gae:
            advantages = returns - values.detach()
        else:
            # GAE: delta_t = r_t + gamma * mask_t * V(s_{t+1}) - V(s_t)
            td_error = rewards + gamma * not_dones * next_values - values.detach()
            advantages = advantages * tau * gamma * not_dones + td_error
            # V(s_t) is the bootstrap value for the previous step t-1
            next_values = values.detach()

        discounted_rollout.append([log_probs, values, returns, advantages, entropys])

    return discounted_rollout

In [None]:


# NOTE(review): no optimizer / learning rate was defined anywhere in the notebook;
# RMSprop with lr=7e-4 is a common A2C choice — tune as needed.
LEARNING_RATE = 7e-4


def train_step(model, optimizer, env, num_rollout=20):
    """Run one A2C update: collect a rollout, build the loss, take a clipped step.

    Returns the loss value as a Python float.
    """
    rollout = collect_rollout(model, env, num_rollout)
    loss = model.loss_func(rollout)

    model.zero_grad()
    loss.backward()
    # clip_grad_norm_ needs the parameter iterable, i.e. parameters() *called*;
    # the threshold lives on the model (there is no `self` at notebook level)
    torch.nn.utils.clip_grad_norm_(model.parameters(), model.max_grad_norm)
    optimizer.step()
    return loss.item()


# Re-running this cell starts again from a freshly initialised model.
model = ActorCritic(params)
optimz = torch.optim.RMSprop(model.parameters(), lr=LEARNING_RATE)
loss = train_step(model, optimz, env)
loss