In [122]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from collections import deque
from torch.utils.data import DataLoader

In [7]:
def rewards_to_go(rewards, discount_factor=0.99):
    # from https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html
    
    r2g = []
    discounted_reward = 0
    
    for reward in reversed(rewards):
        discounted_reward = reward + discount_factor * discounted_reward
        r2g.insert(0, discounted_reward)
    
    return torch.tensor(r2g, dtype=torch.float)

In [185]:
def compute_advantage(rewards, state_values):
    advantages = rewards - state_values
    return normalize(advantages)

In [9]:
def normalize(x):
    return (x - x.mean(0)) / (x.std(0) + 1e-7)

In [215]:
class Trajectory(torch.utils.data.Dataset):
    def __init__(self):
        self.states = []
        self.log_probs = []
        self.actions = []
        self.rewards = []
        self.dones = []

    def __len__(self):
        return len(self.states)
    
    def convert_rewards_to_go(self, discount_factor=0.99):
        # from https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html
        r2g = []
        discounted_reward = 0

        for reward in reversed(self.rewards):
            discounted_reward = reward + discount_factor * discounted_reward
            r2g.insert(0, discounted_reward)

        self.rewards = torch.tensor(r2g, dtype=torch.float)
        
    def fix_datatypes(self):
        self.states = torch.stack(self.states)
        self.actions = torch.tensor(self.actions, dtype=torch.long)
        #self.rewards = torch.tensor(self.rewards, dtype=torch.float)
        self.dones = torch.tensor(self.dones, dtype=torch.int)
        self.log_probs = torch.tensor(self.log_probs, dtype=torch.float)
        

    def store_timestep(self, state, action, reward, done, log_prob):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)
        self.log_probs.append(log_prob)
        
    def clear_memory(self):
        self.states = []
        self.log_probs = []
        self.actions = []
        self.rewards = []
        self.dones = []
    
    def __getitem__(self, index):
        state = self.states[index]
        action = self.actions[index]
        reward = self.rewards[index]
        done = self.dones[index]
        log_prob = self.log_probs[index]
        
        return state, action, reward, done, log_prob

In [242]:
class Agent:
    def __init__(self, env, action_dim, state_dim, batch_size, clip_ratio=0.2):
        self.env = env
        self.actor = Policy(action_dim, state_dim)
        self.critic = CriticNet(state_dim)
        self.trajectories = deque(maxlen=5)
        self.discount_factor = 0.99
        self.batch_size = batch_size
        self.clip_ratio = clip_ratio
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=3e-4)
        
    def get_action(self, obs):
        action_probs = self.actor(obs)
        action_dist = Categorical(logits=action_probs)
        action = action_dist.sample()
        
        return action.item(), action_dist.log_prob(action).item()
    
    def train(self):
        for trajectory in self.trajectories:
            loader = DataLoader(trajectory, batch_size=self.batch_size, shuffle=False)
            
            for states, actions, rewards, dones, old_log_probs in loader:
                state_values = self.critic(states)
                state_values = state_values.squeeze()
                advantages = compute_advantage(rewards, state_values.detach())
                
                action_probs = self.actor(states)
                action_dist = Categorical(logits=action_probs)
                log_probs = action_dist.log_prob(actions)
                
                # log trick for efficient computational graph during backprop
                probs = torch.exp(log_probs - old_log_probs)
                
                weighted_objective = advantages * probs
                clipped_objective = torch.clamp(probs, 1 - self.clip_ratio, 1 + self.clip_ratio)
                
                objective = -torch.min(weighted_objective, clipped_objective).mean()
                self.actor_opt.zero_grad()
                objective.backward(retain_graph=True)
                self.actor_opt.step()
                
                self.critic_opt.zero_grad()
                print(rewards.shape, state_values.shape)
                loss = nn.MSELoss()(state_values, rewards)
                loss.backward()
                self.critic_opt.step()
        
    
    def run_episode(self, render=False):
        trajectory = Trajectory()
        
        done = False
        obs = self.env.reset()
        
        while not done:
            if render:
                self.env.render()

            obs = torch.from_numpy(obs).float()
            action, log_prob = agent.get_action(obs)
            next_obs, reward, done, _ = self.env.step(action)
            
            trajectory.store_timestep(obs, action, reward, done, log_prob)
            obs = next_obs
            
        trajectory.convert_rewards_to_go()
        trajectory.fix_datatypes()
        self.trajectories.append(trajectory)
        

In [243]:
class Policy(nn.Module):
    def __init__(self, action_dim, state_dim):
        super(Policy, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(state_dim, 100),
            nn.ReLU()
        )
        
        self.fc2 = nn.Sequential(
            nn.Linear(100, action_dim)
        )
        
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        
        return x

In [244]:
class CriticNet(nn.Module):
    def __init__(self, state_dim):
        super(CriticNet, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(state_dim, 100),
            nn.ReLU()
        )
        
        self.fc2 = nn.Sequential(
            nn.Linear(100, 1)
        )       
        
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        
        return x

In [245]:
env = gym.make("CartPole-v0")
agent = Agent(env, 2, 4, 16)

In [248]:
#for _ in range(5):
#    agent.run_episode()
    
for _ in range(5):
    agent.train()
    agent.run_episode()

torch.Size([16]) torch.Size([16])
torch.Size([5]) torch.Size([5])
torch.Size([9]) torch.Size([9])
torch.Size([16]) torch.Size([16])
torch.Size([5]) torch.Size([5])
torch.Size([13]) torch.Size([13])
torch.Size([16]) torch.Size([16])
torch.Size([4]) torch.Size([4])
torch.Size([9]) torch.Size([9])
torch.Size([16]) torch.Size([16])
torch.Size([5]) torch.Size([5])
torch.Size([13]) torch.Size([13])
torch.Size([16]) torch.Size([16])
torch.Size([4]) torch.Size([4])
torch.Size([16]) torch.Size([16])
torch.Size([3]) torch.Size([3])
torch.Size([16]) torch.Size([16])
torch.Size([5]) torch.Size([5])
torch.Size([13]) torch.Size([13])
torch.Size([16]) torch.Size([16])
torch.Size([4]) torch.Size([4])
torch.Size([16]) torch.Size([16])
torch.Size([3]) torch.Size([3])
torch.Size([16]) torch.Size([16])
torch.Size([1]) torch.Size([])


ValueError: The parameter logits has invalid values

In [None]:
%debug

> [0;32m/home/lars/miniconda3/lib/python3.8/site-packages/torch/distributions/distribution.py[0m(53)[0;36m__init__[0;34m()[0m
[0;32m     51 [0;31m                    [0;32mcontinue[0m  [0;31m# skip checking lazily-constructed args[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m                [0;32mif[0m [0;32mnot[0m [0mconstraint[0m[0;34m.[0m[0mcheck[0m[0;34m([0m[0mgetattr[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mparam[0m[0;34m)[0m[0;34m)[0m[0;34m.[0m[0mall[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 53 [0;31m                    [0;32mraise[0m [0mValueError[0m[0;34m([0m[0;34m"The parameter {} has invalid values"[0m[0;34m.[0m[0mformat[0m[0;34m([0m[0mparam[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m        [0msuper[0m[0;34m([0m[0mDistribution[0m[0;34m,[0m [0mself[0m[0;34m)[0m[0;34m.[0m[0m__init__[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34

In [172]:
for obs, action, reward, done, log_prob in loader:
    print(obs)
    print()
    print(action)
    print()
    print(reward)
    print()
    print(done)
    print()
    print(log_prob)

tensor([[-0.1665, -0.9947,  0.1671,  1.4754],
        [-0.0355, -0.4057, -0.0154,  0.4988],
        [-0.0716, -0.9909,  0.0319,  1.3735],
        [-0.1864, -0.8020,  0.1966,  1.2392],
        [-0.0313, -0.2108, -0.0196,  0.2124],
        [-0.0232, -0.4063, -0.0299,  0.5144],
        [-0.0186, -0.0171, -0.0336, -0.0491],
        [-0.0189, -0.2117, -0.0346,  0.2328],
        [-0.1151, -1.3821,  0.0929,  1.9866],
        [-0.1427, -1.1881,  0.1326,  1.7240],
        [-0.0436, -0.6006, -0.0054,  0.7866],
        [-0.0914, -1.1864,  0.0594,  1.6760],
        [-0.0556, -0.7956,  0.0103,  1.0776]])

tensor([1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0])

tensor([ 1.9900,  8.6483,  5.8520,  1.0000,  9.5618, 10.4662, 12.2479, 11.3615,
         3.9404,  2.9701,  7.7255,  4.9010,  6.7935], dtype=torch.float64)

tensor([False, False, False,  True, False, False, False, False, False, False,
        False, False, False])

tensor([-0.5157, -0.7881, -0.8877, -0.5286, -0.7402, -0.6036, -0.7031, -0.7437,
      