In [2]:
import random

import numpy as np

import gymnasium as gym

import torch
from torch import nn

In [6]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using mps device


In [7]:
class PolicyNet(nn.Module):
    """Two-layer MLP that maps an observation vector to one logit per action.

    Args:
        obs_dim: Size of the last dimension of the input observations.
        act_dim: Number of output logits (one per discrete action).
        hidden_dim: Width of the single hidden layer. Defaults to 64, the
            value that was previously hard-coded.
    """

    def __init__(self, obs_dim, act_dim, hidden_dim=64):
        super().__init__()
        # The attribute name `stack` and the layer order determine the
        # state_dict keys ("stack.0.*", "stack.2.*"); keep them stable so
        # previously saved checkpoints still load.
        self.stack = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, act_dim),
        )

    def forward(self, x):
        """Return the unnormalized action logits for the observations `x`."""
        return self.stack(x)

In [8]:
# Create the CartPole-v1 environment used for training.
env = gym.make(id="CartPole-v1")

In [9]:
# Network input size comes from the first axis of the observation space shape;
# output size is the number of discrete actions.
obs_dim, act_dim = env.observation_space.shape[0], env.action_space.n

In [10]:
# Policy network to be trained, sized to the environment's spaces.
model_policy = PolicyNet(obs_dim=obs_dim, act_dim=act_dim)

In [11]:
# Adam optimizer over all policy parameters.
optimizer = torch.optim.Adam(params=model_policy.parameters(), lr=0.001)

In [40]:
# Train the policy with a REINFORCE-style update: one gradient step per
# episode, where every sampled action's log-probability is weighted by the
# episode's total (undiscounted) return.
try:
    for episode in range(4000):
        obs, _ = env.reset()
        log_probs = []
        rewards = []

        done = False
        while not done:
            # Add a leading batch dimension before the forward pass.
            obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
            logits = model_policy(obs_tensor)
            # Build the distribution directly from the logits rather than from
            # softmax probabilities: same distribution, but log_prob is
            # computed without going through a separate softmax.
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()

            log_probs.append(dist.log_prob(action))

            obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            rewards.append(reward)

        # Total episode return (plain sum of rewards).
        total_return = sum(rewards)

        # Policy update (REINFORCE): loss = -(sum_t log pi(a_t | s_t)) * R.
        # Each log_prob has shape (1,), so concatenating gives one entry per step.
        loss = -torch.cat(log_probs).sum() * total_return

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Episode {episode}, Return: {total_return}")
finally:
    # Release the environment even if training is interrupted or raises.
    env.close()

# Save the trained model (only reached when training completes).
torch.save(model_policy.state_dict(), "policy_cartpole.pt")

Episode 0, Return: 10.0
Episode 1, Return: 11.0
Episode 2, Return: 12.0
Episode 3, Return: 19.0
Episode 4, Return: 21.0
Episode 5, Return: 28.0
Episode 6, Return: 11.0
Episode 7, Return: 8.0
Episode 8, Return: 17.0
Episode 9, Return: 14.0
Episode 10, Return: 12.0
Episode 11, Return: 38.0
Episode 12, Return: 12.0
Episode 13, Return: 17.0
Episode 14, Return: 18.0
Episode 15, Return: 8.0
Episode 16, Return: 15.0
Episode 17, Return: 22.0
Episode 18, Return: 10.0
Episode 19, Return: 58.0
Episode 20, Return: 44.0
Episode 21, Return: 49.0
Episode 22, Return: 13.0
Episode 23, Return: 15.0
Episode 24, Return: 14.0
Episode 25, Return: 11.0
Episode 26, Return: 15.0
Episode 27, Return: 20.0
Episode 28, Return: 28.0
Episode 29, Return: 13.0
Episode 30, Return: 13.0
Episode 31, Return: 10.0
Episode 32, Return: 16.0
Episode 33, Return: 47.0
Episode 34, Return: 19.0
Episode 35, Return: 10.0
Episode 36, Return: 17.0
Episode 37, Return: 13.0
Episode 38, Return: 16.0
Episode 39, Return: 17.0
Episode 40, 

In [None]:
# Watch a saved policy play CartPole greedily for ten rendered episodes.
env = gym.make("CartPole-v1", render_mode="human")

try:
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    # Build a policy with the same architecture as the trained one
    policy = PolicyNet(obs_dim, act_dim)
    # NOTE(review): the training cell saves "policy_cartpole.pt", but this cell
    # loads "policy_cartpole_modified.pt" — confirm which checkpoint is intended;
    # the notebook itself never writes the "_modified" file.
    policy.load_state_dict(torch.load("policy_cartpole_modified.pt"))  # load the saved weights
    policy.eval()

    for episode in range(10):
        obs, _ = env.reset()
        done = False

        while not done:
            obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                logits = policy(obs_tensor)
            # Greedy action (no exploration). Softmax is monotonic, so the
            # argmax over the logits is the argmax over the probabilities.
            action = torch.argmax(logits, dim=-1).item()

            obs, reward, terminated, truncated, _ = env.step(action)
            if terminated:
                done = True
                print(f"episodio: {episode} - terminado")
            elif truncated:
                done = True
                print(f"episodio: {episode} - truncado")
finally:
    # Always close the environment (and its render window), even when loading
    # the checkpoint or stepping the environment raises.
    env.close()

episodio: 0 - truncado
episodio: 1 - truncado
episodio: 2 - truncado
episodio: 3 - truncado
episodio: 4 - truncado
episodio: 5 - truncado
episodio: 6 - terminado
episodio: 7 - truncado
episodio: 8 - terminado
episodio: 9 - truncado


: 