# In [None]:
# PPO with an RNN, using a single frame (one observation) per step
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy
from torch.distributions import Categorical
import time
import numpy as np
# Hyperparameters (module-level; read by PPO and main below)
learning_rate = 0.0005  # step size handed to the Adam optimizer in PPO.__init__
gamma         = 0.98    # discount factor applied to the bootstrapped next-state value
lmbda         = 0.95    # multiplied with gamma in the backward advantage recursion
eps_clip      = 0.2     # probability ratio is clamped to [1 - eps_clip, 1 + eps_clip]
K_epoch       = 5       # number of optimizer steps taken on each collected batch
T_horizon     = 20      # maximum environment steps collected before each train_net() call
class PPO(nn.Module):
    """Recurrent actor-critic network trained with the clipped PPO objective.

    The policy head (``rc1 -> fc0 -> fc_pi``) and the value head
    (``rc2 -> fc1 -> fc_v``) each own a single-layer ``nn.RNN``.  Every
    observation is fed as a sequence of length one, so temporal context is
    carried only through the hidden state passed in by the caller.

    Transitions are buffered with :meth:`put_data` and consumed by
    :meth:`train_net`.  Each transition is the tuple
    ``(s, a, r, s_prime, prob_a, h_in, h_out, done)``.

    Args:
        hidden: Hidden size of both RNNs.
        state_dim: Size of one observation vector (default 4).
        action_dim: Number of discrete actions (default 2).
        lr: Adam learning rate.  ``None`` (default) uses the module-level
            ``learning_rate`` constant.
    """

    def __init__(self, hidden, state_dim=4, action_dim=2, lr=None):
        super().__init__()
        self.data = []
        self.hidden = hidden
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Layer creation order is kept stable so seeded initialisation and
        # saved state_dicts stay compatible.
        self.rc1 = nn.RNN(input_size=state_dim, hidden_size=hidden,
                          num_layers=1, batch_first=True)
        self.rc2 = nn.RNN(input_size=state_dim, hidden_size=hidden,
                          num_layers=1, batch_first=True)
        self.fc0 = nn.Linear(hidden, hidden * 10)
        self.fc1 = nn.Linear(hidden, hidden * 10)
        self.fc_pi = nn.Linear(hidden * 10, action_dim)
        self.fc_v = nn.Linear(hidden * 10, 1)
        self.optimizer = optim.Adam(
            self.parameters(), lr=learning_rate if lr is None else lr
        )

    def pi(self, x, h0, softmax_dim=1):
        """Return action probabilities and the policy RNN's new hidden state.

        Args:
            x: Observations; reshaped to ``(N, 1, state_dim)``, i.e. N
                independent sequences of length one.
            h0: Initial hidden state for ``rc1``, ``(num_layers, N, hidden)``.
            softmax_dim: Dimension over which the softmax is taken.

        Returns:
            ``(prob, hidden_state)`` where ``prob`` is ``(N, action_dim)`` and
            ``hidden_state`` is whatever ``rc1`` returns as its final state.
        """
        x = x.reshape(-1, 1, self.state_dim)
        x, hidden_state = self.rc1(x, h0)
        # Sequence length is 1, so flattening yields one feature row per sample.
        x = x.reshape(-1, self.hidden)

        x = F.relu(self.fc0(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        # nan_to_num is out-of-place: the result must be reassigned to have
        # any effect.
        prob = torch.nan_to_num(prob, nan=0.0)
        return prob, hidden_state

    def v(self, x, h0):
        """Return state-value estimates of shape ``(N, 1)``.

        Args:
            x: Observations; reshaped to ``(N, 1, state_dim)``.
            h0: Initial hidden state for ``rc2``, ``(num_layers, N, hidden)``.
        """
        x = x.reshape(-1, 1, self.state_dim)
        x, _ = self.rc2(x, h0)
        x = x.reshape(-1, self.hidden)
        x = F.relu(self.fc1(x))
        return self.fc_v(x)

    def put_data(self, transition):
        """Append one ``(s, a, r, s_prime, prob_a, h_in, h_out, done)`` tuple."""
        self.data.append(transition)

    def make_batch(self):
        """Convert the buffered transitions into tensors and clear the buffer.

        The buffer is expected to be non-empty (``train_net`` checks this
        before calling).

        Returns:
            ``(s, a, r, s_prime, done_mask, prob_a, h_in, h_out)`` for T
            buffered transitions: ``s``/``s_prime`` are float ``(T, state)``;
            ``a`` is long ``(T, 1)``; ``r``, ``done_mask`` and ``prob_a`` are
            float ``(T, 1)``; ``h_in``/``h_out`` are float
            ``(num_layers, T, hidden)``.  ``done_mask`` is 0 where the
            transition ended the episode and 1 otherwise.
        """
        s_lst, a_lst, r_lst, s_prime_lst = [], [], [], []
        prob_a_lst, h_in_lst, h_out_lst, done_lst = [], [], [], []

        for s, a, r, s_prime, prob_a, h_in, h_out, done in self.data:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            h_in_lst.append(h_in.detach())
            h_out_lst.append(h_out.detach())
            done_lst.append([0 if done else 1])
        self.data = []

        # Stack observations into one ndarray first: building a tensor from a
        # list of ndarrays is PyTorch's slow path.
        s = torch.from_numpy(np.array(s_lst, dtype=np.float32))
        s_prime = torch.from_numpy(np.array(s_prime_lst, dtype=np.float32))
        a = torch.tensor(a_lst, dtype=torch.long)  # gather() needs int64
        r = torch.tensor(r_lst, dtype=torch.float)
        done_mask = torch.tensor(done_lst, dtype=torch.float)
        prob_a = torch.tensor(prob_a_lst, dtype=torch.float)

        # Each stored state is (num_layers, 1, hidden); concatenating along
        # the batch axis gives (num_layers, T, hidden) as nn.RNN expects.
        h_in = torch.cat(h_in_lst, dim=1).float()
        h_out = torch.cat(h_out_lst, dim=1).float()

        return s, a, r, s_prime, done_mask, prob_a, h_in, h_out

    def train_net(self):
        """Run ``K_epoch`` clipped-PPO updates on the buffered transitions.

        Reads the module-level ``gamma``, ``lmbda``, ``eps_clip`` and
        ``K_epoch`` constants.  Does nothing when the buffer is empty.
        """
        if not self.data:
            return

        s, a, r, s_prime, done_mask, prob_a, h1, h2 = self.make_batch()

        # Targets and advantages are constants for the optimisation below, so
        # they are computed without building an autograd graph.
        with torch.no_grad():
            td_target = r + gamma * self.v(s_prime, h2) * done_mask
            delta = (td_target - self.v(s, h1)).numpy()

        # Backward recursion over time: A_t = delta_t + gamma * lmbda * A_{t+1}.
        advantage_lst = []
        advantage = 0.0
        for delta_t in delta[::-1]:
            advantage = gamma * lmbda * advantage + delta_t[0]
            advantage_lst.append([advantage])
        advantage_lst.reverse()
        advantage = torch.tensor(advantage_lst, dtype=torch.float)

        for _ in range(K_epoch):
            pi, _ = self.pi(s, h1, softmax_dim=1)
            pi_a = pi.gather(1, a)

            # a / b == exp(log(a) - log(b))
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage

            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s, h1), td_target)

            self.optimizer.zero_grad()
            # The graph is rebuilt on every epoch, so it need not be retained.
            loss.mean().backward()
            self.optimizer.step()
        
def main():
    """Train a recurrent PPO agent on CartPole-v1 and return the best score.

    Runs 10000 episodes.  Within an episode, up to ``T_horizon`` steps are
    collected and then ``model.train_net()`` is called, repeating until the
    episode ends.  Prints each new best episode score and, every
    ``print_interval`` episodes, the average score since the last report.

    Returns:
        The highest single-episode score seen during training.
    """
    hidden = 32
    env = gym.make('CartPole-v1')
    model = PPO(hidden)
    #model.load_state_dict(torch.load('ppo_ddong_with_dongmin.pt'))

    score = 0.0
    episodes_since_print = 0  # so the first (101-episode) window averages correctly
    print_interval = 100
    best_score = 0.0
    one_score = 0.0

    try:
        for n_epi in range(10000):
            # Newer gym versions return (obs, info) from reset(); older ones
            # return the observation alone.
            reset_out = env.reset()
            s = reset_out[0] if isinstance(reset_out, tuple) else reset_out
            done = False
            # num_layers x batch x hidden size
            h_out = torch.zeros([1, 1, hidden], dtype=torch.float)

            while not done:
                for _ in range(T_horizon):
                    h_in = h_out
                    # Rollout needs no gradients; without no_grad the hidden
                    # state would chain an autograd graph across the episode.
                    with torch.no_grad():
                        prob, h_out = model.pi(torch.from_numpy(s).float(), h_in)
                    # Batch size is 1 here, so drop the batch dimension.
                    prob = prob.view(-1)

                    a = Categorical(prob).sample().item()

                    # Newer gym versions return a 5-tuple with separate
                    # terminated/truncated flags; older ones a 4-tuple.
                    step_out = env.step(a)
                    if len(step_out) == 5:
                        s_prime, r, terminated, truncated, _ = step_out
                        done = bool(terminated or truncated)
                    else:
                        s_prime, r, done, _ = step_out

                    model.put_data((s, a, r, s_prime, prob[a].item(), h_in, h_out, done))
                    s = s_prime.copy()

                    score += r
                    one_score += r
                    if done:
                        break

                model.train_net()

            episodes_since_print += 1

            if best_score < one_score:
                best_score = one_score
            #    print("... save model ...")
                print(f"best score : {one_score}")
            #    torch.save(model.state_dict(),'ppo_ddong_with_dongmin.pt')

            one_score = 0

            if n_epi % print_interval == 0 and n_epi != 0:
                print("# of episode :{}, avg score : {:.1f}".format(n_epi, score / episodes_since_print))
                score = 0.0
                episodes_since_print = 0
    finally:
        env.close()

    return best_score

# Script entry point: training runs only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    # The result is bound to a module-level name so it stays inspectable
    # after the run finishes.
    best_score = main()
    print(f"final best_score : {best_score}")
    
