In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gym

# Hyperparameters
learning_rate = 0.0005
gamma = 0.98
lmbda = 0.95
eps_clip = 0.1
K_epochs = 3
T_horizon = 20


In [2]:

class PPO(nn.Module):
    def __init__(self):
        super(PPO, self).__init__()
        self.data = []

        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)
        self.fc_v = nn.Linear(256, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x, softmax_dim=0):
        x = torch.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = torch.softmax(x, dim=softmax_dim)
        return prob

    def v(self, x):
        x = torch.relu(self.fc1(x))
        v = self.fc_v(x)
        return v

    def put_data(self, transition):
        self.data.append(transition)

    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst = [], [], [], [], [], []
        for transition in self.data:
            s, a, r, s_prime, prob_a, done = transition

            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            done_mask = 0 if done else 1
            done_lst.append([done_mask])

        s_batch, a_batch, r_batch, s_prime_batch, prob_a_batch, done_batch = (
            torch.tensor(s_lst, dtype=torch.float),
            torch.tensor(a_lst),
            torch.tensor(r_lst),
            torch.tensor(s_prime_lst, dtype=torch.float),
            torch.tensor(prob_a_lst, dtype=torch.float),
            torch.tensor(done_lst, dtype=torch.float),
        )
        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, prob_a_batch, done_batch

    def train_net(self):
        s, a, r, s_prime, prob_a, done = self.make_batch()

        for _ in range(K_epochs):
            td_target = r + gamma * self.v(s_prime) * done
            delta = td_target - self.v(s)
            delta = delta.detach().numpy()

            advantage_lst = []
            advantage = 0.0
            for delta_t in delta[::-1]:
                advantage = gamma * lmbda * advantage + delta_t[0]
                advantage_lst.append([advantage])
            advantage_lst.reverse()
            advantage = torch.tensor(advantage_lst, dtype=torch.float)

            pi = self.pi(s, softmax_dim=1)
            pi_a = pi.gather(1, a)
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
            loss = -torch.min(surr1, surr2
                + 0.5 * torch.nn.functional.mse_loss(self.v(s), td_target.detach())
                - 0.01 * pi_a.mean()
            )

        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()


  torch.tensor(s_lst, dtype=torch.float),


# of episode: 20, avg score: 19.5
# of episode: 40, avg score: 22.6
# of episode: 60, avg score: 23.1
# of episode: 80, avg score: 21.0
# of episode: 100, avg score: 14.1
# of episode: 120, avg score: 16.1
# of episode: 140, avg score: 14.8
# of episode: 160, avg score: 13.6
# of episode: 180, avg score: 12.9
# of episode: 200, avg score: 11.6
# of episode: 220, avg score: 11.7
# of episode: 240, avg score: 11.6
# of episode: 260, avg score: 12.2
# of episode: 280, avg score: 12.2
# of episode: 300, avg score: 13.1
# of episode: 320, avg score: 13.1
# of episode: 340, avg score: 12.7
# of episode: 360, avg score: 13.9
# of episode: 380, avg score: 12.3
# of episode: 400, avg score: 12.2
# of episode: 420, avg score: 11.8
# of episode: 440, avg score: 12.9
# of episode: 460, avg score: 11.2
# of episode: 480, avg score: 10.4
# of episode: 500, avg score: 10.6
# of episode: 520, avg score: 11.7
# of episode: 540, avg score: 11.4
# of episode: 560, avg score: 11.6
# of episode: 580, avg s

# of episode: 4720, avg score: 9.3
# of episode: 4740, avg score: 9.4
# of episode: 4760, avg score: 9.2
# of episode: 4780, avg score: 9.2
# of episode: 4800, avg score: 9.2
# of episode: 4820, avg score: 9.5
# of episode: 4840, avg score: 9.8
# of episode: 4860, avg score: 9.6
# of episode: 4880, avg score: 9.3
# of episode: 4900, avg score: 9.5
# of episode: 4920, avg score: 9.2
# of episode: 4940, avg score: 9.3
# of episode: 4960, avg score: 9.6
# of episode: 4980, avg score: 9.4
# of episode: 5000, avg score: 9.5
# of episode: 5020, avg score: 9.1
# of episode: 5040, avg score: 9.3
# of episode: 5060, avg score: 9.4
# of episode: 5080, avg score: 9.3
# of episode: 5100, avg score: 9.2
# of episode: 5120, avg score: 9.3
# of episode: 5140, avg score: 9.4
# of episode: 5160, avg score: 9.1
# of episode: 5180, avg score: 9.4
# of episode: 5200, avg score: 9.4
# of episode: 5220, avg score: 9.8
# of episode: 5240, avg score: 9.3
# of episode: 5260, avg score: 9.3
# of episode: 5280, 

# of episode: 9480, avg score: 9.4
# of episode: 9500, avg score: 9.5
# of episode: 9520, avg score: 9.2
# of episode: 9540, avg score: 9.6
# of episode: 9560, avg score: 9.2
# of episode: 9580, avg score: 9.4
# of episode: 9600, avg score: 9.4
# of episode: 9620, avg score: 9.2
# of episode: 9640, avg score: 9.4
# of episode: 9660, avg score: 9.3
# of episode: 9680, avg score: 9.4
# of episode: 9700, avg score: 9.3
# of episode: 9720, avg score: 9.6
# of episode: 9740, avg score: 9.4
# of episode: 9760, avg score: 9.3
# of episode: 9780, avg score: 9.5
# of episode: 9800, avg score: 9.2
# of episode: 9820, avg score: 9.4
# of episode: 9840, avg score: 9.4
# of episode: 9860, avg score: 9.3
# of episode: 9880, avg score: 9.1
# of episode: 9900, avg score: 9.2
# of episode: 9920, avg score: 9.3
# of episode: 9940, avg score: 9.4
# of episode: 9960, avg score: 9.7
# of episode: 9980, avg score: 9.2


In [3]:


def main():
    env = gym.make("CartPole-v0")
    model = PPO()
    score = 0.0
    print_interval = 20

    for n_epi in range(10000):
        s = env.reset()
        done = False
        while not done:
            for t in range(T_horizon):
                prob = model.pi(torch.from_numpy(s).float())
                m = Categorical(prob)
                a = m.sample().item()
                s_prime, r, done, _ = env.step(a)

                model.put_data((s, a, r / 100.0, s_prime, prob[a].item(), done))
                s = s_prime
                score += r

                if done:
                    break

            model.train_net()

        if n_epi % print_interval == 0 and n_epi != 0:
            print("# of episode: {}, avg score: {:.1f}".format(n_epi, score / print_interval))
            score = 0.0

    env.close()


In [4]:


if __name__ == "__main__":
    main()


# of episode: 20, avg score: 24.9
# of episode: 40, avg score: 23.4
# of episode: 60, avg score: 19.8
# of episode: 80, avg score: 18.9
# of episode: 100, avg score: 15.9
# of episode: 120, avg score: 16.1
# of episode: 140, avg score: 14.0
# of episode: 160, avg score: 17.4
# of episode: 180, avg score: 16.3
# of episode: 200, avg score: 17.1
# of episode: 220, avg score: 17.8
# of episode: 240, avg score: 13.3
# of episode: 260, avg score: 13.6
# of episode: 280, avg score: 16.2
# of episode: 300, avg score: 13.2
# of episode: 320, avg score: 13.9
# of episode: 340, avg score: 15.2
# of episode: 360, avg score: 14.3
# of episode: 380, avg score: 11.8
# of episode: 400, avg score: 11.0
# of episode: 420, avg score: 11.9
# of episode: 440, avg score: 10.9
# of episode: 460, avg score: 12.6
# of episode: 480, avg score: 12.2
# of episode: 500, avg score: 11.6
# of episode: 520, avg score: 11.2
# of episode: 540, avg score: 10.8
# of episode: 560, avg score: 10.4
# of episode: 580, avg s

# of episode: 4760, avg score: 9.4
# of episode: 4780, avg score: 9.2
# of episode: 4800, avg score: 9.2
# of episode: 4820, avg score: 9.0
# of episode: 4840, avg score: 9.7
# of episode: 4860, avg score: 9.2
# of episode: 4880, avg score: 9.5
# of episode: 4900, avg score: 9.7
# of episode: 4920, avg score: 9.6
# of episode: 4940, avg score: 9.2
# of episode: 4960, avg score: 9.4
# of episode: 4980, avg score: 9.4
# of episode: 5000, avg score: 9.7
# of episode: 5020, avg score: 9.2
# of episode: 5040, avg score: 9.2
# of episode: 5060, avg score: 9.5
# of episode: 5080, avg score: 9.4
# of episode: 5100, avg score: 9.2
# of episode: 5120, avg score: 9.3
# of episode: 5140, avg score: 9.6
# of episode: 5160, avg score: 9.6
# of episode: 5180, avg score: 9.6
# of episode: 5200, avg score: 9.4
# of episode: 5220, avg score: 9.3
# of episode: 5240, avg score: 9.3
# of episode: 5260, avg score: 9.8
# of episode: 5280, avg score: 9.1
# of episode: 5300, avg score: 9.3
# of episode: 5320, 

# of episode: 9480, avg score: 9.2
# of episode: 9500, avg score: 9.3
# of episode: 9520, avg score: 9.2
# of episode: 9540, avg score: 9.5
# of episode: 9560, avg score: 9.8
# of episode: 9580, avg score: 9.6
# of episode: 9600, avg score: 9.5
# of episode: 9620, avg score: 9.3
# of episode: 9640, avg score: 9.2
# of episode: 9660, avg score: 9.4
# of episode: 9680, avg score: 9.1
# of episode: 9700, avg score: 9.4
# of episode: 9720, avg score: 9.3
# of episode: 9740, avg score: 9.9
# of episode: 9760, avg score: 9.3
# of episode: 9780, avg score: 9.4
# of episode: 9800, avg score: 9.2
# of episode: 9820, avg score: 9.3
# of episode: 9840, avg score: 9.4
# of episode: 9860, avg score: 9.2
# of episode: 9880, avg score: 9.7
# of episode: 9900, avg score: 9.1
# of episode: 9920, avg score: 9.6
# of episode: 9940, avg score: 9.5
# of episode: 9960, avg score: 9.5
# of episode: 9980, avg score: 8.9
