In [1]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from collections import deque
from torch.utils.data import DataLoader
from tqdm.notebook import trange
import wandb

In [2]:
# Register a custom environment id backed by gym's classic-control
# CartPoleEnv, with every episode capped at 200 steps.
# NOTE(review): re-running this cell in the same kernel registers the id a
# second time; some gym versions reject that - confirm for the installed one.
gym.envs.register(
    id='CartPole-v200',
    entry_point='gym.envs.classic_control:CartPoleEnv',
    max_episode_steps=200,
    reward_threshold=195.0,  # presumably the score at which the task counts as solved - confirm
)

In [3]:
def rewards_to_go(rewards, discount_factor=0.99):
    """Compute the discounted reward-to-go for every step of one episode.

    For step ``t`` the result is ``sum_k discount_factor**k * rewards[t + k]``,
    see https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html

    Args:
        rewards: Per-step scalar rewards in chronological order.
        discount_factor: Factor applied to the running return for each step
            further back in time.

    Returns:
        A 1-D ``torch.float`` tensor with one entry per reward (empty when
        ``rewards`` is empty).
    """
    r2g = []
    discounted_reward = 0

    # Walk the episode backwards, appending in O(1) and reversing once at the
    # end, instead of paying O(n) for insert(0, ...) on every step.
    for reward in reversed(rewards):
        discounted_reward = reward + discount_factor * discounted_reward
        r2g.append(discounted_reward)
    r2g.reverse()

    return torch.tensor(r2g, dtype=torch.float)

In [4]:
def compute_advantage(rewards, state_values):
    """Return the normalized advantage estimate for a batch.

    The raw advantage is the element-wise difference ``rewards - state_values``;
    it is passed through ``normalize`` before being returned.
    """
    return normalize(rewards - state_values)

In [5]:
def normalize(x):
    """Standardize ``x`` along dimension 0 to zero mean and unit std.

    Args:
        x: Floating-point tensor; mean and std are taken along dimension 0.

    Returns:
        ``(x - mean) / (std + 1e-8)``. When the standard deviation is
        undefined (NaN, e.g. only one entry along dimension 0) the tensor is
        only centered, not scaled.
    """
    centered = x - x.mean(0)
    std = x.std(0)

    # Guard on the same per-dimension std that is used as the divisor.
    # Checking the global x.std() instead lets a single-row 2-D input pass the
    # guard and come back as all-NaN.
    if torch.isnan(std).any():
        return centered

    return centered / (std + 1e-8)

In [6]:
class Trajectory(torch.utils.data.Dataset):
    """Per-timestep record of one rollout, indexable as a torch ``Dataset``.

    Timesteps are collected with ``store_timestep`` into plain Python lists.
    ``convert_rewards_to_go`` and ``fix_datatypes`` then turn those lists into
    tensors so the object can be batched by a ``DataLoader``.
    """

    def __init__(self):
        # clear_memory() is the single place that defines the empty buffers.
        self.clear_memory()

    def __len__(self):
        """Number of stored timesteps."""
        return len(self.states)

    def convert_rewards_to_go(self, discount_factor=0.99):
        """Replace the stored per-step rewards with discounted rewards-to-go.

        See https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html

        Args:
            discount_factor: Factor applied to the running return for each
                step further back in time.

        After the call ``self.rewards`` is a 1-D ``torch.float`` tensor.
        Calling it a second time would discount the already converted values
        again.
        """
        returns = []
        running_return = 0

        # Accumulate from the last step backwards; append + a single reverse
        # keeps this linear, unlike insert(0, ...) inside the loop.
        for reward in reversed(self.rewards):
            running_return = reward + discount_factor * running_return
            returns.append(running_return)
        returns.reverse()

        self.rewards = torch.tensor(returns, dtype=torch.float)

    def fix_datatypes(self):
        """Convert the collected lists into tensors.

        States are stacked along a new leading dimension; actions become
        ``long``, dones ``int`` and log-probabilities ``float`` tensors.
        ``torch.stack`` raises on an empty list, so at least one timestep must
        have been stored.
        """
        self.states = torch.stack(self.states)
        self.actions = torch.tensor(self.actions, dtype=torch.long)
        self.dones = torch.tensor(self.dones, dtype=torch.int)
        self.log_probs = torch.tensor(self.log_probs, dtype=torch.float)

    def store_timestep(self, state, action, reward, done, log_prob):
        """Append one timestep to every buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)
        self.log_probs.append(log_prob)

    def clear_memory(self):
        """Reset every buffer to an empty list."""
        self.states = []
        self.log_probs = []
        self.actions = []
        self.rewards = []
        self.dones = []

    def __getitem__(self, index):
        """Return ``(state, action, reward, done, log_prob)`` for ``index``."""
        return (
            self.states[index],
            self.actions[index],
            self.rewards[index],
            self.dones[index],
            self.log_probs[index],
        )

In [7]:
class Agent:
    """Actor-critic agent trained with a clipped surrogate objective (as in PPO).

    The agent owns a ``Policy`` actor and a ``CriticNet`` critic, each with its
    own Adam optimizer, collects episodes from ``env`` into ``Trajectory``
    objects and logs metrics to Weights & Biases.
    """

    def __init__(self, env, action_dim, state_dim, batch_size, clip_ratio=0.2):
        """Build the networks and optimizers and start a wandb run.

        Args:
            env: Environment exposing ``reset``, ``step`` and ``render``.
            action_dim: Number of discrete actions (actor output size).
            state_dim: Size of the observation vector (network input size).
            batch_size: Minibatch size used by ``train``.
            clip_ratio: The probability ratio is clamped to
                ``[1 - clip_ratio, 1 + clip_ratio]`` in the actor loss.
        """
        self.env = env
        self.actor = Policy(action_dim, state_dim)
        self.critic = CriticNet(state_dim)
        # Keeps the most recent episodes; train() only reads the newest one.
        self.trajectories = deque(maxlen=5)
        self.discount_factor = 0.99
        self.batch_size = batch_size
        self.clip_ratio = clip_ratio
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=1e-3)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=5e-2)
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.critic_loss = nn.MSELoss()

        # initialize logging
        wandb.init(project="cartpole")
        wandb.watch(self.actor, log="all")
        wandb.watch(self.critic, log="all")

    def get_action(self, obs):
        """Sample an action for ``obs`` from the current policy.

        Args:
            obs: Observation tensor accepted by the actor.

        Returns:
            ``(action, log_prob)`` as a Python int and float.
        """
        # Only plain Python numbers leave this method, so there is no need to
        # build an autograd graph while sampling.
        with torch.no_grad():
            logits = self.actor(obs)
            action_dist = Categorical(logits=logits)
            action = action_dist.sample()
            log_prob = action_dist.log_prob(action)

        return action.item(), log_prob.item()

    def train(self, epochs=2):
        """Update actor and critic on the most recently collected episode.

        Args:
            epochs: Number of passes over the episode (default 2).

        Incomplete final batches are dropped, so an episode shorter than
        ``batch_size`` produces no update. Raises ``IndexError`` when no
        episode has been collected yet.
        """
        trajectory = self.trajectories[-1]
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory = self.device.type == 'cuda'

        for _ in range(epochs):
            loader = DataLoader(trajectory, batch_size=self.batch_size, shuffle=True,
                                pin_memory=pin_memory, drop_last=True)

            for states, actions, rewards, _dones, old_log_probs in loader:
                states = states.to(self.device)
                actions = actions.to(self.device)
                rewards = rewards.to(self.device)
                old_log_probs = old_log_probs.to(self.device)

                # squeeze(-1) removes only the trailing size-1 value dimension,
                # so a batch of one stays 1-D and matches `rewards`.
                state_values = self.critic(states).squeeze(-1)
                # Detached: the advantage is a constant for the actor update.
                advantages = compute_advantage(rewards, state_values.detach())

                logits = self.actor(states)
                action_dist = Categorical(logits=logits)
                log_probs = action_dist.log_prob(actions)

                # log trick for efficient computational graph during backprop
                ratio = torch.exp(log_probs - old_log_probs)

                weighted_objective = advantages * ratio
                clipped_objective = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages

                actor_loss = -torch.min(weighted_objective, clipped_objective).mean()
                critic_loss = self.critic_loss(state_values, rewards)

                # The actor and critic graphs share no nodes (the advantages
                # are detached), so neither backward pass needs retain_graph.
                self.actor_opt.zero_grad()
                actor_loss.backward()
                self.actor_opt.step()

                self.critic_opt.zero_grad()
                critic_loss.backward()
                self.critic_opt.step()

                wandb.log({'critic loss': critic_loss.item()})
                wandb.log({'avg state value': state_values.mean().item()})

    def run_episode(self, render=False):
        """Play one episode with the current policy and store it.

        Args:
            render: When true, ``env.render()`` is called before every step.

        The undiscounted episode return is logged under ``'reward'``; the
        finished trajectory is converted to tensors and appended to
        ``self.trajectories``. Written against the gym API in which ``reset``
        returns the observation and ``step`` returns a 4-tuple.
        """
        trajectory = Trajectory()

        done = False
        obs = self.env.reset()

        while not done:
            if render:
                self.env.render()

            obs = torch.tensor(obs, dtype=torch.float, device=self.device)
            # Use this instance, not a module-level `agent` variable.
            action, log_prob = self.get_action(obs)
            obs = obs.cpu()
            next_obs, reward, done, _ = self.env.step(action)

            trajectory.store_timestep(obs, action, reward, done, log_prob)
            obs = next_obs

        wandb.log({'reward': np.sum(trajectory.rewards)})
        trajectory.convert_rewards_to_go(self.discount_factor)
        trajectory.fix_datatypes()
        self.trajectories.append(trajectory)
        

In [8]:
class Policy(nn.Module):
    """Actor network: maps a state vector to one raw score per action."""

    def __init__(self, action_dim, state_dim):
        """Build a two-hidden-layer MLP and move it to the GPU if one exists.

        Args:
            action_dim: Size of the output layer.
            state_dim: Size of the input layer.
        """
        super().__init__()
        hidden = 64

        self.fc1 = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
        # No softmax on the head: the raw linear outputs are returned as-is.
        self.fc3 = nn.Sequential(nn.Linear(hidden, action_dim))

        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')
        self.to(self.device)

    def forward(self, x):
        """Apply fc1, fc2 and fc3 in order and return the result."""
        for stage in (self.fc1, self.fc2, self.fc3):
            x = stage(x)
        return x

In [9]:
class CriticNet(nn.Module):
    """Critic network: maps a state vector to a single scalar output."""

    def __init__(self, state_dim):
        """Build a two-hidden-layer MLP and move it to the GPU if one exists.

        Args:
            state_dim: Size of the input layer.
        """
        super().__init__()
        width = 64

        self.fc1 = nn.Sequential(nn.Linear(state_dim, width), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(width, width), nn.ReLU())
        # Single output unit, no activation.
        self.fc3 = nn.Sequential(nn.Linear(width, 1))

        device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)
        self.to(self.device)

    def forward(self, x):
        """Run ``x`` through fc1, fc2 and fc3; the last dimension has size 1."""
        hidden = self.fc2(self.fc1(x))
        return self.fc3(hidden)

In [10]:
# Create the registered CartPole-v200 environment and an agent sized for it:
# 2 actions, 4 state dimensions, minibatches of 16.
env = gym.make("CartPole-v200")
agent = Agent(env, action_dim=2, state_dim=4, batch_size=16)

# Alternate between collecting one episode and training on it.
for episode in range(500):
    agent.run_episode()
    agent.train()
        

[34m[1mwandb[0m: Currently logged in as: [33mtnfru[0m (use `wandb login --relogin` to force relogin)


In [11]:
# Play 20 more episodes with rendering switched on to watch the agent.
for _ in range(20):
    agent.run_episode(render=True)

In [12]:
# Shut the environment down now that training and the rendered runs are done.
env.close()