In [489]:
!pip install gymnasium

In [490]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt

In [491]:
# Hyper-parameters for the REINFORCE agent.
learning_rate = 0.01     # Adam step size
discount_factor = 0.99   # gamma used when computing discounted returns

# Seed torch so network initialisation and action sampling are reproducible.
# The environment is seeded separately: pass `seed=SEED` to the first `env.reset()`.
SEED = 0
torch.manual_seed(SEED)

In [492]:
class Agent(nn.Module):
    """Stochastic policy over a discrete action space.

    A one-hidden-layer MLP maps an observation vector to one logit per action.
    While an episode runs, the agent keeps two per-step buffers:

    * ``rewards``   -- rewards appended by the caller after each environment step;
    * ``log_probs`` -- log-probabilities of the sampled actions, appended by ``action``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        hidden_units = 32
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, output_size),
        )
        # Both episode buffers start out empty.
        self.reset()

    def forward(self, x):
        """Return the unnormalised action logits for observation tensor ``x``."""
        return self.model(x)

    def reset(self):
        """Start a new episode by replacing both per-step buffers with empty lists."""
        self.log_probs = []
        self.rewards = []

    def action(self, state):
        """Sample an action for a NumPy observation and record its log-probability.

        The observation is cast to float32 before the forward pass. The stored
        log-prob tensor stays attached to the autograd graph so a later update
        can backpropagate through it.

        Returns:
            The sampled action index as a Python int.
        """
        obs = torch.from_numpy(state.astype(np.float32))
        policy = Categorical(logits=self.forward(obs))
        chosen = policy.sample()
        self.log_probs.append(policy.log_prob(chosen))
        return chosen.item()
        


In [493]:
def train(agent, optimizer, gamma=None):
    """Run one REINFORCE (Monte-Carlo policy-gradient) update from the agent's last episode.

    Computes the discounted return-to-go ``G_t = r_t + gamma * G_{t+1}`` for every
    step, then minimises ``-sum_t log_prob_t * G_t`` with one optimizer step.

    Args:
        agent: object exposing ``rewards`` (list of floats, one per step) and
            ``log_probs`` (list of scalar tensors, one per step, still attached to
            the policy's autograd graph).
        optimizer: torch optimizer over the policy parameters.
        gamma: discount factor; defaults to the notebook-level ``discount_factor``.

    Returns:
        The scalar loss tensor that was back-propagated, or a zero tensor (and no
        optimizer step) when the episode is empty.

    Raises:
        ValueError: if ``rewards`` and ``log_probs`` have different lengths, which
            would otherwise broadcast silently or fail with an opaque shape error.
    """
    if gamma is None:
        gamma = discount_factor
    T = len(agent.rewards)
    if T != len(agent.log_probs):
        raise ValueError(
            f"rewards ({T}) and log_probs ({len(agent.log_probs)}) must have the same length"
        )
    if T == 0:
        # torch.stack([]) would raise; there is nothing to learn from an empty episode.
        return torch.zeros(())

    # Discounted return-to-go, accumulated backwards from the final step.
    rets = np.empty(T, dtype=np.float32)
    future_ret = 0.0
    for t in reversed(range(T)):
        future_ret = agent.rewards[t] + gamma * future_ret
        rets[t] = future_ret

    log_probs = torch.stack(agent.log_probs)
    loss = -(log_probs * torch.from_numpy(rets)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

In [494]:
# CartPole environment shared by the training and evaluation cells.
# Note: no render_mode is passed here.
env = gym.make(id='CartPole-v1')

In [495]:
# Network dimensions: one input per observation feature, one logit per discrete action.
input_size, output_size = env.observation_space.shape[0], env.action_space.n

In [496]:
# Policy network and the Adam optimizer that updates its parameters.
agent = Agent(input_size=input_size, output_size=output_size)
optimizer = optim.Adam(params=agent.parameters(), lr=learning_rate)

In [None]:
# REINFORCE training loop: one policy-gradient update per episode.
# Runs at most MAX_EPISODES episodes. Once at least MIN_EPISODES have run, it stops
# as soon as the last SOLVED_WINDOW episodes all reached the MAX_STEPS cap.
MAX_EPISODES = 1000
MIN_EPISODES = 200
MAX_STEPS = 200          # per-episode step cap used in this notebook
SOLVED_WINDOW = 10

res = []        # total (undiscounted) reward of each training episode
episode = 1     # 1-based counter; ends one past the last episode trained
while episode <= MAX_EPISODES and (
    episode <= MIN_EPISODES or np.sum(res[-SOLVED_WINDOW:]) != SOLVED_WINDOW * MAX_STEPS
):
    state, _ = env.reset()
    for _step in range(MAX_STEPS):
        action = agent.action(state)
        state, reward, terminated, truncated, _ = env.step(action)
        agent.rewards.append(reward)
        # No env.render(): the env was made without a render_mode, so calling it
        # only produced a warning on every step.
        if terminated or truncated:
            break
    loss = train(agent, optimizer)
    res.append(sum(agent.rewards))
    agent.reset()
    episode += 1
# The env is intentionally left open: the evaluation cell below still uses it.

In [None]:
# Learning curve: total reward collected in each training episode.
fig, ax = plt.subplots()
ax.plot(res)
ax.set(
    title="REINFORCE on CartPole-v1: reward per training episode",
    xlabel="Episode",
    ylabel="Total reward",
)
plt.show()

In [None]:
# Number of training episodes actually run. The loop counter `episode` ends one
# past this, because it is incremented after every update.
len(res)

In [500]:
# Evaluate the trained policy on one fresh episode (actions are still sampled).
# The agent's episode buffers are cleared before and after, so re-running this cell
# reports only this episode's reward and leaves nothing stale for `train`.
# No update happens here, so torch.no_grad() skips building the autograd graph.
# env.render() is omitted: the env has no render_mode, so it only emitted warnings.
EVAL_MAX_STEPS = 200
agent.reset()
state, _ = env.reset()
with torch.no_grad():
    for _step in range(EVAL_MAX_STEPS):
        action = agent.action(state)
        state, reward, terminated, truncated, _ = env.step(action)
        agent.rewards.append(reward)
        if terminated or truncated:
            break
eval_reward = sum(agent.rewards)
agent.reset()
env.close()
print(eval_reward)

200.0


  gym.logger.warn(
