In [54]:
import torch
from torch import nn
import numpy as np
import gymnasium as gym

LR = 1e-3
LR_WEIGHTS = 1e-4
DF = 0.9

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [55]:
class NeuralNetwork(nn.Module):
    """Policy network: a one-hidden-layer ReLU MLP mapping observations to action logits.

    The defaults (4 inputs, 64 hidden units, 2 outputs) reproduce the original
    fixed architecture, sized for CartPole-v1's 4-value observation and 2 actions.

    Args:
        obs_dim: length of the (already flat) observation vector.
        n_actions: number of output logits, one per discrete action.
        hidden_dim: width of the hidden ReLU layer.
    """

    def __init__(self, obs_dim=4, n_actions=2, hidden_dim=64):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, x):
        """Return unnormalized action logits for an input of shape (..., obs_dim)."""
        return self.linear_relu_stack(x)

In [56]:
# Policy network and its optimizer.
model = NeuralNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Linear baseline parameters: one weight per tile-coded feature plus a trailing
# bias entry (initialized to 1). 256 = 2 tile coders x 2 actions x 4 tilings x (4 x 4)
# bins, which must stay in sync with the tile coders / create_feature_vector_ntiles.
weights = np.concatenate([np.zeros(256), [1.0]])

In [57]:
class Bin2D:
    """An axis-aligned rectangular bin with half-open extents.

    The bin covers [posx, posx + size) x [posy, posy + size_y).

    Args:
        posx, posy: lower-left corner of the bin.
        size: side length along x; also used along y when size_y is omitted,
            which reproduces the original square-bin behaviour.
        size_y: optional side length along y, for non-square bins.
    """

    def __init__(self, posx, posy, size, size_y=None):
        self.posx = posx
        self.posy = posy
        self.size = size  # extent along x
        self.size_y = size if size_y is None else size_y  # extent along y

    def check(self, x, y):
        """Return 1 if (x, y) lies inside this bin, otherwise 0."""
        inside_x = self.posx <= x < self.posx + self.size
        inside_y = self.posy <= y < self.posy + self.size_y
        return 1 if (inside_x and inside_y) else 0


class Tiling2D:
    """A single n_bins x n_bins grid of Bin2D cells whose origin is (offset_x, offset_y).

    Each cell is width / n_bins wide and height / n_bins tall. Points outside
    [offset_x, offset_x + width) x [offset_y, offset_y + height) activate no bin.
    """

    def __init__(self, offset_x, offset_y, width, height, n_bins):
        self.offset_x = offset_x
        self.offset_y = offset_y
        self.width = width
        self.height = height
        self.n_bins = n_bins

    def setup(self):
        """Create the grid of bins; bin index is i * n_bins + j (i along x, j along y)."""
        self.bins = []
        cell_w = self.width / self.n_bins
        cell_h = self.height / self.n_bins

        for i in range(self.n_bins):
            for j in range(self.n_bins):
                x = self.offset_x + i * cell_w
                y = self.offset_y + j * cell_h
                # Bins must be cell_h tall: sizing them cell_w in both directions
                # leaves gaps between rows whenever height != width.
                self.bins.append(Bin2D(x, y, cell_w, cell_h))

    def check(self, x, y):
        """Return a list of n_bins**2 ints: 1 for the bin containing (x, y), else 0."""
        return [b.check(x, y) for b in self.bins]


class Tile2D:
    """A 2-D tile coder built from several staggered Tiling2D grids.

    Tiling k is shifted by k / n_tilings of one cell along both axes, so each
    tiling partitions the plane slightly differently.
    """

    def __init__(self, x_range, y_range, n_tilings, n_bins):
        self.x_range = x_range
        self.y_range = y_range
        self.n_tilings = n_tilings
        self.n_bins = n_bins

    def setup(self):
        """Build and set up the n_tilings offset grids."""
        x_lo, x_hi = self.x_range
        y_lo, y_hi = self.y_range
        span_x = x_hi - x_lo
        span_y = y_hi - y_lo
        cell_x = span_x / self.n_bins
        cell_y = span_y / self.n_bins

        self.tilings = []
        for k in range(self.n_tilings):
            frac = k / self.n_tilings
            grid = Tiling2D(x_lo + frac * cell_x, y_lo + frac * cell_y,
                            span_x, span_y, self.n_bins)
            grid.setup()
            self.tilings.append(grid)

    def check(self, x, y):
        """Concatenate every tiling's binary vector into one float32 array."""
        return np.array(
            [bit for grid in self.tilings for bit in grid.check(x, y)],
            dtype=np.float32,
        )


def x_of_s_a(s, a, tile2d, n_actions=3):
    """Build the state-action feature vector x(s, a).

    The state features phi = tile2d.check(*s) are written into slot `a` of
    `n_actions` equal-length slots; every other slot is zero. The result has
    length n_actions * len(phi).
    """
    state_feats = tile2d.check(*s)
    slot = len(state_feats)
    features = np.zeros(slot * n_actions)
    start = a * slot
    features[start:start + slot] = state_feats
    return features

# Observation bounds used to lay out the tilings. Per the Gymnasium CartPole-v1
# docs, an episode terminates when |cart position| > 2.4 or |pole angle| > ~0.2095 rad.
# The velocity observations are unbounded there, so the (-4, 4) limits are a
# heuristic; values outside a tiling's extent activate none of its bins.
pos_range = (-2.4, 2.4)
pol_angle_range = (-0.2095, 0.2095)
cart_velocity = (-4, 4)
pol_ang_vel = (-4, 4)

# tile1 encodes (cart position, cart velocity); tile2 encodes (pole angle, pole angular velocity).
tile1 = Tile2D(pos_range, cart_velocity, n_tilings=4, n_bins=4)
tile2 = Tile2D(pol_angle_range, pol_ang_vel, n_tilings=4, n_bins=4)
tiles = [tile1, tile2]

tile1.setup()
tile2.setup()

def create_feature_vector_ntiles(obs, action):
    """Concatenate the state-action tile features from both tile coders.

    obs[:2] is encoded by tiles[0] and obs[2:] by tiles[1], each through
    x_of_s_a with 2 action slots. Returns a 1-D float64 array.
    """
    # `tiles` is only read here, so no `global` declaration is needed.
    first = x_of_s_a(obs[:2], action, tiles[0], 2)
    second = x_of_s_a(obs[2:], action, tiles[1], 2)
    return np.concatenate((first, second))

def get_state_val(obs_vec, w=None):
    """Evaluate the linear baseline: obs_vec . w[:-1] + w[-1].

    Args:
        obs_vec: feature vector, e.g. from create_feature_vector_ntiles.
        w: optional weight vector of length len(obs_vec) + 1 whose last entry is
            the bias. Defaults to the module-level `weights`.

    Returns:
        The scalar baseline value.
    """
    if w is None:
        w = weights  # read-only access to the global; no `global` needed
    return np.dot(obs_vec, w[:-1]) + w[-1]

In [61]:
def training_loop_REINFORCE(epochs):
    """Train the global policy `model` with REINFORCE and a learned linear baseline.

    For each of `epochs` CartPole-v1 episodes:
    1. roll out the current policy;
    2. walk the episode backwards to compute discounted returns G_t;
    3. move the global baseline `weights` towards G_t with step size LR_WEIGHTS;
    4. take one optimizer step on -sum_t log pi(a_t|s_t) * (G_t - b(s_t)).

    Args:
        epochs: number of episodes to run; one policy update per episode.
    """
    global weights  # rebound by the `+=` below, so it must be declared global

    env = gym.make("CartPole-v1")
    try:
        for ep in range(epochs):
            obs, _ = env.reset()
            terminated = truncated = False
            trajectory = []  # (observation, reward, log-prob) per step

            while not (terminated or truncated):
                obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                dist = torch.distributions.Categorical(logits=model(obs_t).squeeze(0))
                action = dist.sample()
                log_prob = dist.log_prob(action)

                next_obs, reward, terminated, truncated, _ = env.step(action.item())
                trajectory.append((obs, reward, log_prob))
                obs = next_obs

            # Backward pass over the episode: discounted return, baseline update,
            # and advantage G_t - b(s_t) for every step.
            advantages = []
            G = 0.0
            for step_obs, reward, _ in reversed(trajectory):
                G = reward + DF * G
                # The baseline must not depend on the action taken; an action-dependent
                # baseline biases the policy gradient. create_feature_vector_ntiles is
                # action-indexed, so always use the action-0 slot as pure state features.
                # TODO: give the baseline its own state-only feature vector.
                feature_vec = create_feature_vector_ntiles(step_obs, 0)
                advantage = G - get_state_val(feature_vec)
                weights += LR_WEIGHTS * advantage * np.append(feature_vec, 1)
                advantages.append(advantage)
            advantages.reverse()  # back to chronological order

            advantages = torch.tensor(advantages, dtype=torch.float32, device=device)
            log_probs = torch.stack([lp for _, _, lp in trajectory])
            loss = -(log_probs * advantages).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Episode {ep+1}, loss={loss.item():.4f}")
    finally:
        env.close()

def testing_loop_REINFORCE(epochs, nsteps, render_mode="human"):
    """Roll out the global policy `model` in CartPole-v1 for inspection.

    Runs `epochs` episodes. Each episode starts from a fresh reset and stops when
    the environment terminates or truncates, or after `nsteps` steps, whichever
    comes first. Actions are sampled from the policy, and no gradients are tracked.

    Args:
        epochs: number of episodes to run.
        nsteps: maximum number of steps per episode.
        render_mode: forwarded to gym.make; "human" opens a render window.
    """
    env = gym.make("CartPole-v1", render_mode=render_mode)
    try:
        for _ in range(epochs):
            # Reset per episode: stepping an episode that already ended is invalid.
            obs, _ = env.reset()
            terminated = truncated = False
            step = 0

            while not (terminated or truncated) and step < nsteps:
                obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                with torch.no_grad():  # inference only
                    logits = model(obs_t).squeeze(0)
                action = torch.distributions.Categorical(logits=logits).sample()

                obs, _, terminated, truncated, _ = env.step(action.item())
                step += 1
    finally:
        env.close()


In [59]:
# Train the policy for 800 episodes; one loss line is printed per episode.
training_loop_REINFORCE(800)

Episode 1, loss=32.4451
Episode 2, loss=131.1698
Episode 3, loss=33.2898
Episode 4, loss=23.8246
Episode 5, loss=88.6883
Episode 6, loss=45.7618
Episode 7, loss=22.8177
Episode 8, loss=99.8301
Episode 9, loss=70.9209
Episode 10, loss=192.2784
Episode 11, loss=22.3662
Episode 12, loss=116.0356
Episode 13, loss=107.8345
Episode 14, loss=115.2021
Episode 15, loss=30.1538
Episode 16, loss=79.4904
Episode 17, loss=58.0739
Episode 18, loss=225.3856
Episode 19, loss=222.1423
Episode 20, loss=42.1475
Episode 21, loss=201.6336
Episode 22, loss=70.5290
Episode 23, loss=34.1335
Episode 24, loss=28.5818
Episode 25, loss=55.2704
Episode 26, loss=64.6542
Episode 27, loss=43.3510
Episode 28, loss=275.1664
Episode 29, loss=67.9917
Episode 30, loss=115.7313
Episode 31, loss=247.5483
Episode 32, loss=40.7236
Episode 33, loss=68.3396
Episode 34, loss=18.0804
Episode 35, loss=57.3896
Episode 36, loss=42.6932
Episode 37, loss=437.2654
Episode 38, loss=149.1047
Episode 39, loss=33.8935
Episode 40, loss=32.2

In [63]:
# Roll out the trained policy in a rendered window (epochs=1, nsteps=2000).
testing_loop_REINFORCE(1, 2000)