Agent의 policy의 weight를 모두 1으로 초기화하면 학습이 거의 되지 않는다.

In [1]:
import gymnasium
import numpy as np
import torch

In [2]:
env = gymnasium.make('CartPole-v1')

In [3]:
gamma = 0.99
def discount_rewards(r):
    discounted_r = np.zeros_like(r)
    running_add = 0
    for t in reversed(range(r.size)):
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r
discount_rewards(np.array([3, 2, 1, 3, 2.0]))

array([10.79218902,  7.870898  ,  5.9302    ,  4.98      ,  2.        ])

In [4]:
env.reset()

(array([-0.02066758,  0.04918516,  0.00725879, -0.04990533], dtype=float32),
 {})

In [5]:
class Agent:
    def __init__(self, lr, s_size, a_size, h_size):
        self.policy = torch.nn.Sequential(
            torch.nn.Linear(s_size, h_size, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(h_size, a_size, bias=False),
            torch.nn.Softmax(dim=-1),
        )
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def policy_output(self, state):
        return self.policy(state)

    def loss(self, action_logits, rewards):
        loss = -rewards * torch.log(action_logits)
        return loss
    
    def update(self, losses):
        self.optimizer.zero_grad()
        losses.backward()
        self.optimizer.step()

    def action(self, state):
        x = self.policy(state)
        x = torch.multinomial(x, 1)
        return x.cpu().detach().item()

agent = Agent(1e-3, 4, 2, 8)
observation, info = env.reset()
output = agent.policy_output(torch.tensor(observation))
action = torch.multinomial(output, 1)
observation, reward, terminated, truncated, info = env.step(action.cpu().detach().item())
observation, reward, terminated, truncated, info

(array([ 0.04350268,  0.15241024,  0.04872172, -0.28389737], dtype=float32),
 1.0,
 False,
 False,
 {})

In [6]:
total_episode = 10000
max_ep = 999
update_frequency = 5

running_rewards = []
for i in range(total_episode):
    observation, info = env.reset()
    action_logits = []
    rewards = []
    losses = []
    for j in range(max_ep):
        output = agent.policy_output(torch.tensor(observation))
        action = torch.multinomial(output, 1)
        observation, reward, terminated, truncated, info = env.step(action.cpu().detach().item())
        action_logits.append(output[action])
        rewards.append(reward)
        if terminated:
            running_rewards.append(sum(rewards))
            action_logits = torch.concat(action_logits)
            rewards = torch.tensor(discount_rewards(np.array(rewards)))
            loss = agent.loss(action_logits, rewards)
            losses.append(loss)
            if i % update_frequency == 0:
                agent.update(torch.concat(losses).sum())
                losses.clear()
            if i % 1000 == 999:
                print(np.mean(running_rewards[-100:]))
            break

28.77
35.88
43.26
51.41
71.89
128.03
301.52
394.08
419.29
376.61


In [7]:
env = gymnasium.make('CartPole-v1', render_mode='human')
observation, info = env.reset()
while True:
    action = agent.action(torch.tensor(observation))
    observation, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated:
        env.close()
        break