In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [2]:
# Learning rate handed to the Adam optimizer built in ActorCritic.__init__.
lr = 0.0002
# Discount factor multiplying the next-state value estimate in ActorCritic.train.
gamma = 0.98

In [3]:
def main():
    """Train an ActorCritic agent on CartPole-v1 and print a running score.

    The agent interacts with the environment in short rollouts of at most
    ``n_rollout`` steps; after each rollout (or when the episode ends) the
    buffered transitions are consumed by one ``model.train()`` update.
    Every ``print_int`` episodes the accumulated score divided by
    ``print_int`` is printed and the accumulator is reset.
    """
    env = gym.make('CartPole-v1')
    model = ActorCritic()  # actor AND critic in one model: this is what differs from a pure policy-based agent
    n_rollout = 5  # run an update after every 5 environment steps
    print_int = 20
    score = 0

    try:
        for episode in range(2000):
            done = False
            # Newer Gym versions return (observation, info) from reset();
            # older ones return the bare observation. Accept both.
            reset_out = env.reset()
            state = reset_out[0] if isinstance(reset_out, tuple) else reset_out

            while not done:
                for _ in range(n_rollout):
                    # A single state is a 1-D vector, so softmax runs over dim 0.
                    prob = model.pi(torch.as_tensor(state, dtype=torch.float), dim=0)
                    action = Categorical(prob).sample().item()

                    # Newer Gym versions return a 5-tuple with separate
                    # terminated/truncated flags; older ones a 4-tuple.
                    step_out = env.step(action)
                    if len(step_out) == 5:
                        next_state, reward, terminated, truncated, _info = step_out
                        done = terminated or truncated
                    else:
                        next_state, reward, done, _info = step_out

                    model.put_data((state, action, reward, next_state, done))
                    state = next_state
                    score += reward

                    if done:
                        break

                model.train()

            # `and`, not `&`: the bitwise operator binds tighter than the
            # comparisons and made the original condition always False.
            if episode % print_int == 0 and episode != 0:
                print('episode : {}, score : {}'.format(episode, score/print_int))
                score = 0
    finally:
        # Release the environment even if training is interrupted.
        env.close()
    
    

In [4]:
class ActorCritic(nn.Module):
    """Actor-critic network with a shared hidden layer and two heads.

    ``fc_common`` maps a 4-dimensional state to 128 hidden units; ``fc_pi``
    produces logits over 2 actions (the actor) and ``fc_v`` a scalar state
    value (the critic). Transitions are buffered with ``put_data`` and
    consumed by ``train``, which performs one TD(0) actor-critic update.

    Args:
        learning_rate: Adam step size. Defaults to the notebook-level ``lr``.
        discount: Discount factor for the TD target. Defaults to the
            notebook-level ``gamma``.
    """

    def __init__(self, learning_rate=None, discount=None):
        super(ActorCritic, self).__init__()
        self.data = []
        # Resolve hyperparameters lazily so the notebook-level constants are
        # only needed when no explicit override is supplied.
        self.discount = gamma if discount is None else discount
        self.fc_common = nn.Linear(4, 128)
        self.fc_pi = nn.Linear(128, 2)
        self.fc_v = nn.Linear(128, 1)
        self.opt = optim.Adam(
            self.parameters(),
            lr=lr if learning_rate is None else learning_rate,
        )

    def pi(self, x, dim=0):
        """Return action probabilities for ``x``.

        Args:
            x: State tensor whose last dimension has size 4.
            dim: Dimension over which softmax is taken. Use 0 for a single
                1-D state and 1 for a batch of states.
        """
        x = F.relu(self.fc_common(x))
        x = self.fc_pi(x)
        return F.softmax(x, dim=dim)

    def v(self, x):
        """Return the critic's value estimate for state tensor ``x``."""
        x = F.relu(self.fc_common(x))
        return self.fc_v(x)

    def put_data(self, item):
        """Buffer one ``(state, action, reward, next_state, done)`` transition."""
        self.data.append(item)

    def batch(self):
        """Convert the buffered transitions into tensors and clear the buffer.

        The buffer must be non-empty. Rewards are scaled by 1/100 and the
        ``done`` flag is turned into a continuation mask (0 when the
        transition ended the episode, 1 otherwise).

        Returns:
            Tuple ``(states, actions, rewards, next_states, masks)`` where
            actions is an int64 tensor of shape (N, 1) suitable as a
            ``gather`` index, and the rest are float tensors.
        """
        states, actions, rewards, next_states, masks = [], [], [], [], []

        for s, a, r, s_, done in self.data:
            states.append(torch.as_tensor(s, dtype=torch.float))
            # int() accepts Python ints as well as 0-dim tensors from sample().
            actions.append([int(a)])
            rewards.append([float(r) / 100.0])
            next_states.append(torch.as_tensor(s_, dtype=torch.float))
            masks.append([0.0 if done else 1.0])

        s_batch = torch.stack(states)
        a_batch = torch.tensor(actions, dtype=torch.long)  # gather needs int64 indices
        r_batch = torch.tensor(rewards, dtype=torch.float)
        s2_batch = torch.stack(next_states)
        d_batch = torch.tensor(masks, dtype=torch.float)
        self.data = []

        return s_batch, a_batch, r_batch, s2_batch, d_batch

    def train(self, mode=None):
        """Run one actor-critic update on the buffered transitions.

        Called with no argument, this consumes the buffer and performs one
        optimizer step (a no-op when the buffer is empty). Because this
        method shadows ``nn.Module.train``, an explicit ``mode`` is forwarded
        to the base class so that ``model.eval()`` / ``model.train(True)``
        keep working.
        """
        if mode is not None:
            return super(ActorCritic, self).train(mode)
        if not self.data:
            return None

        s, a, r, s_, mask = self.batch()

        # TD target bootstraps from the next state unless the episode ended.
        td_target = r + self.discount * self.v(s_) * mask
        value = self.v(s)
        td_error = td_target - value

        pi = self.pi(s, dim=1)
        pi_a = pi.gather(1, a)
        # detach: treat the advantage / target as constants (no gradient).
        actor_loss = -torch.log(pi_a) * td_error.detach()
        # The critic regresses onto the TD target, not the TD error.
        critic_loss = F.smooth_l1_loss(value, td_target.detach())
        loss = actor_loss + critic_loss

        self.opt.zero_grad()
        loss.mean().backward()
        self.opt.step()
        return None