In [1]:
import time

import gym
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
device = 'cuda'

In [3]:
class Policy(nn.Module):
    def __init__(self,N_s,N_a):
        super().__init__()
        self.layer_1 = nn.Linear(N_s,200)
        self.layer_2 = nn.Linear(200,200)
        self.layer_3 = nn.Linear(200,N_a)
        self.N_s = N_s
        self.N_a = N_a
    
    def forward(self, s):
        assert type(s) == torch.Tensor
        
        h = F.leaky_relu(self.layer_1(s))
        h = F.leaky_relu(self.layer_2(h))
        h = F.leaky_relu(self.layer_3(h))
        return F.log_softmax(h,dim=0)
    
    def sample_action(self,s):
        with torch.no_grad():
            P = torch.exp(self.forward(s)).detach().cpu().numpy()
        return np.random.choice([i for i in range(self.N_a)],p=P)
            

In [4]:
policy = Policy(4,2).to(device)
optimizer = torch.optim.Adam(policy.parameters(),lr=3e-4)

In [5]:
env = gym.make('CartPole-v0')

In [None]:
Ss = []
As = []
Rs = []

for episode in range(1000):
    #Initialize
    Ss = []
    As = []
    Rs = []
    s = env.reset()
    env.render()
    done = 0

    
    #Run Episode
    while(True):
        Ss.append(s)
        s = torch.tensor(s,dtype=torch.float32).to(device)
        
        a = policy.sample_action(s)
        As.append(a)
        
        ns, r, done, info = env.step(a)
        Rs.append(r)
        s = ns
        
        if done:
            break
        else:
            env.render()
    
    #When Episode ends
    Ss = np.array(Ss)
    As = np.array(As)
    Rs = np.array(Rs)
    
    #Train
    optimizer.zero_grad()
    T = Rs.shape[-1]
    objective = 0
    
    for t in range(T-1,0,-1):
        s = torch.tensor(Ss[t],dtype=torch.float32).to(device)
        a = As[t]
        R = np.sum(Rs[t:T])
        log_pi = policy(s)[a]
        
        objective -= R * log_pi
    
    objective.backward()
    optimizer.step()
    if episode % 1 == 0:
        print(R)
            
env.close() 

41.0
11.0
27.0
53.0
20.0
24.0
15.0
17.0
15.0
16.0
30.0
8.0
36.0
23.0
22.0
10.0
15.0
18.0
9.0
30.0
28.0
29.0
16.0
11.0
13.0
31.0
15.0
31.0
15.0
19.0
15.0
18.0
18.0
35.0
29.0
19.0
21.0
39.0
68.0
23.0
29.0
35.0
28.0
40.0
15.0
19.0
19.0
25.0
8.0
22.0
11.0
26.0
20.0
14.0
28.0
11.0
30.0
14.0
94.0
87.0
41.0
41.0
14.0
19.0
15.0
22.0
20.0
85.0
9.0
19.0
43.0
33.0
43.0
21.0
62.0
26.0
19.0
66.0
22.0
23.0
34.0
28.0
34.0
29.0
25.0
14.0
42.0
66.0
17.0
29.0
16.0
55.0
43.0
23.0
58.0
21.0
19.0
27.0
45.0
30.0
34.0
16.0
46.0
45.0
12.0
56.0
94.0
34.0
39.0
57.0
35.0
50.0
22.0
58.0
30.0
63.0
41.0
55.0
53.0
84.0
55.0
41.0
48.0
39.0
32.0
98.0
21.0
94.0
41.0
109.0
174.0
44.0
115.0
56.0
28.0
51.0
59.0
126.0
32.0
187.0
39.0
100.0
94.0
121.0
50.0
78.0
31.0
40.0
81.0
33.0
51.0
37.0
154.0
83.0
79.0
65.0
24.0
67.0
148.0
151.0
120.0
44.0
53.0
40.0
55.0
111.0
28.0
81.0
33.0
46.0
30.0
168.0
45.0
108.0
199.0
25.0
109.0
39.0
164.0
163.0
61.0
65.0
38.0
62.0
80.0
199.0
133.0
84.0
86.0
126.0
108.0
90.0
82.0
107.0
55.0
126.0
