In [146]:
import torch
from torch import nn
import numpy as np
import gymnasium as gym

In [147]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

class NeuralNetwork(nn.Module):
    """Small MLP policy that maps an observation to action probabilities.

    The network is Linear -> ReLU -> Linear -> Softmax, so the output of
    ``forward`` is a probability distribution over actions (it sums to 1
    along the last dimension), not raw logits.

    Args:
        obs_dim: Number of features in one observation (default 4).
        hidden_dim: Width of the single hidden layer (default 64).
        n_actions: Number of discrete actions (default 2).
    """

    def __init__(self, obs_dim=4, hidden_dim=64, n_actions=2):
        super().__init__()
        # Not used by forward(); kept so existing code that touches this
        # attribute keeps working. It holds no parameters.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
            nn.Softmax(dim=-1)  # normalise over the action dimension
        )

    def forward(self, x):
        """Return action probabilities for ``x`` (last dim must be ``obs_dim``)."""
        probs = self.linear_relu_stack(x)
        return probs

# Hyperparameters.
DF = 0.99  # discount factor used when accumulating returns (G = reward + DF * G)
LR = 0.01  # learning rate the Adam optimizer is constructed with

# Policy network on the selected device, plus the optimizer over its parameters.
model = NeuralNetwork()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

def training_loop_REINFORCE(epochs):
    """Train the module-level ``model`` with REINFORCE on CartPole-v1.

    One optimizer step is taken per episode. For each episode the policy is
    rolled out until the environment reports ``terminated`` or ``truncated``,
    discounted returns are computed with the module-level ``DF``, normalised
    to zero mean / unit (population) std, and the loss
    ``-sum(log_prob_t * G_t)`` is minimised with the module-level
    ``optimizer``.

    Args:
        epochs: Number of episodes to run.

    Side effects:
        Updates ``model`` parameters and ``optimizer`` state in place and
        prints one line per episode with the episode's loss.
    """
    env = gym.make("CartPole-v1")

    try:
        for ep in range(epochs):
            terminated = truncated = False
            rewards = []
            log_probs = []
            obs, _ = env.reset()

            while not (terminated or truncated):
                # The model lives on `device`, so the observation must too;
                # otherwise the forward pass fails on any accelerator.
                obs_t = torch.tensor(obs, dtype=torch.float32, device=device)
                # The network ends in Softmax, so its output is probabilities.
                probs = model(obs_t)
                dist = torch.distributions.Categorical(probs=probs)
                action = dist.sample()
                log_probs.append(dist.log_prob(action))

                obs, reward, terminated, truncated, _ = env.step(action.item())
                rewards.append(reward)

            # Discounted return for every step, built back-to-front and then
            # reversed once (avoids repeated insert-at-front).
            returns = []
            G = 0
            for reward in reversed(rewards):
                G = reward + DF * G
                returns.append(G)
            returns.reverse()

            returns = torch.tensor(returns, dtype=torch.float32, device=device)
            # Normalise returns; the epsilon guards against zero std
            # (e.g. a one-step episode).
            returns = (returns - returns.mean()) / (returns.std(unbiased=False) + 1e-8)

            # Compute loss: -sum_t log pi(a_t | s_t) * G_t
            loss = -(torch.stack(log_probs) * returns).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Episode {ep+1}, loss={loss.item():.4f}")
    finally:
        # Release the environment even if an episode raises.
        env.close()

def testing_loop_REINFORCE(epochs):
    """Run the module-level ``model`` on CartPole-v1 with on-screen rendering.

    Actions are sampled from the policy's output distribution (not argmax).
    No learning takes place.

    Args:
        epochs: Number of episodes to play.
    """
    env = gym.make("CartPole-v1", render_mode="human")

    try:
        for _episode in range(epochs):
            # Reset at the start of EVERY episode; stepping an environment
            # that has already terminated/truncated is not valid.
            obs, _ = env.reset()
            terminated = truncated = False

            while not (terminated or truncated):
                # Evaluation only: skip building the autograd graph.
                with torch.no_grad():
                    # Observation must be on the same device as the model.
                    obs_t = torch.tensor(obs, dtype=torch.float32, device=device)
                    # The network ends in Softmax, so its output is probabilities.
                    probs = model(obs_t)
                dist = torch.distributions.Categorical(probs=probs)
                action = dist.sample()

                obs, _, terminated, truncated, _ = env.step(action.item())
    finally:
        # Close the render window even if an episode raises.
        env.close()


Using cpu device


In [148]:
# Three-stage schedule: 300 episodes at the initial learning rate, then 100
# episodes each at 1e-3 and 1e-4.
training_loop_REINFORCE(300)

# Rebinding the global LR alone does nothing: the Adam optimizer captured the
# value at construction time. The new rate has to be written into the
# optimizer's parameter groups for it to take effect.
LR = 0.001
for param_group in optimizer.param_groups:
    param_group["lr"] = LR
training_loop_REINFORCE(100)

LR = 0.0001
for param_group in optimizer.param_groups:
    param_group["lr"] = LR
training_loop_REINFORCE(100)

Episode 1, loss=0.2029
Episode 2, loss=0.1983
Episode 3, loss=0.4970
Episode 4, loss=-0.6761
Episode 5, loss=-0.1370
Episode 6, loss=-0.7056
Episode 7, loss=-0.0431
Episode 8, loss=-0.0916
Episode 9, loss=0.0454
Episode 10, loss=-1.7428
Episode 11, loss=0.3668
Episode 12, loss=0.4730
Episode 13, loss=-2.7938
Episode 14, loss=-0.0422
Episode 15, loss=1.2884
Episode 16, loss=-0.2788
Episode 17, loss=1.8216
Episode 18, loss=0.7543
Episode 19, loss=-3.9014
Episode 20, loss=1.0946
Episode 21, loss=0.4778
Episode 22, loss=0.3583
Episode 23, loss=0.4806
Episode 24, loss=-0.9957
Episode 25, loss=0.1514
Episode 26, loss=-0.0561
Episode 27, loss=0.7286
Episode 28, loss=-0.6525
Episode 29, loss=-0.6882
Episode 30, loss=-1.2162
Episode 31, loss=0.6176
Episode 32, loss=-0.1195
Episode 33, loss=-0.5412
Episode 34, loss=-0.5497
Episode 35, loss=1.4280
Episode 36, loss=1.8412
Episode 37, loss=0.9333
Episode 38, loss=0.6575
Episode 39, loss=-0.5315
Episode 40, loss=0.2421
Episode 41, loss=0.2772
Episod

In [149]:
# Play a single rendered episode with the current policy.
testing_loop_REINFORCE(epochs=1)