In [2]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import gymnasium as gym
from dataclasses import dataclass
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2025-02-05 07:58:48.057872: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-05 07:58:48.059679: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-05 07:58:48.068542: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-05 07:58:48.103284: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738735128.165780   24610 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738735128.18

In [65]:
# Params
BATCH_SIZE = 128
PERCENTILE = 70

In [52]:
model = nn.Sequential(
    nn.Linear(1, 100),
    nn.ReLU(),

    nn.Linear(100, 50),
    nn.ReLU(),
    
    nn.Linear(50, 25),
    nn.ReLU(),
    
    nn.Linear(25, 4),
)
model = model.to(device)

In [53]:
criterion = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=0.01)

In [54]:
@dataclass
class Episode:
    observations: list[list[float]]
    actions: list[int]
    reward: float

In [67]:
writer = SummaryWriter()
epochs = 100

env = gym.make('FrozenLake-v1', render_mode='rgb_array')
env = gym.wrappers.RecordVideo(env, './videos')
action_space = list(range(env.action_space.n))


for epoch in range(epochs):
    best_of_the_best = []

    # episodes loop
    for batch in range(BATCH_SIZE):
        obs, _ = env.reset()
        episode = Episode(observations=[], actions=[], reward=0)

        # game loop
        while True:
            obs_t = torch.tensor([obs], dtype=torch.float32, device=device)
            action_pred = model(obs_t)
            action_p_dist = torch.softmax(action_pred, dim=0)
            action = np.random.choice(
                action_space, p=action_p_dist.detach().cpu().numpy()
            )

            next_obs, reward, terminated, truncated, _ = env.step(action)

            if terminated or truncated:
                if terminated:
                    episode.reward -= 10
                break

            episode.observations.append(obs)
            episode.actions.append(action)
            episode.reward += reward * 100
            episode.reward -= 1
            episode.reward += obs

            obs = next_obs

        if episode.reward > 70:
            best_of_the_best.append(episode)

    # pick only elite episodes
    if len(best_of_the_best) == 0:
        continue
    rewards = np.array([episode.reward for episode in best_of_the_best])
    percentile = np.percentile(rewards, PERCENTILE)
    reward_mean = np.mean(rewards)
    best_max = np.max(rewards)

    writer.add_scalar('reward/mean', reward_mean, epoch)
    writer.add_scalar('reward/percentile', percentile, epoch)
    writer.add_scalar('reward/best', best_max, epoch)

    opt.zero_grad()

    for episode in best_of_the_best:
        if episode.reward < percentile:
            continue

        obs_t = torch.tensor(episode.observations, dtype=torch.float32, device=device).view(-1, 1) # shape (n, 1)
        action_target = torch.tensor(episode.actions, dtype=torch.long, device=device) # shape (n)
        action_pred = model(obs_t) # shape (n, 4)
        loss = criterion(action_pred, action_target)

        loss.backward()

    opt.step()

    print(f'epoch {epoch} reward mean: {reward_mean} percentile: {percentile} best: {best_max}')

  logger.warn(


epoch 0 reward mean: 97.0 percentile: 82.4 best: 168.0
epoch 1 reward mean: 102.25 percentile: 109.19999999999999 best: 145.0
epoch 2 reward mean: 89.27272727272727 percentile: 89.0 best: 140.0
epoch 3 reward mean: 93.5 percentile: 101.7 best: 114.0
epoch 4 reward mean: 78.8 percentile: 82.2 best: 85.0
epoch 5 reward mean: 98.5 percentile: 103.3 best: 141.0
epoch 6 reward mean: 110.54545454545455 percentile: 118.0 best: 149.0
epoch 7 reward mean: 97.8 percentile: 99.2 best: 125.0
epoch 8 reward mean: 105.77777777777777 percentile: 125.2 best: 144.0
epoch 9 reward mean: 100.42857142857143 percentile: 107.0 best: 168.0
epoch 10 reward mean: 103.25 percentile: 106.0 best: 140.0
epoch 11 reward mean: 122.6 percentile: 130.5 best: 203.0
epoch 12 reward mean: 101.5 percentile: 113.2 best: 130.0
epoch 13 reward mean: 128.125 percentile: 130.0 best: 225.0
epoch 14 reward mean: 96.75 percentile: 87.3 best: 154.0
epoch 15 reward mean: 109.26666666666667 percentile: 122.39999999999999 best: 168.0

In [68]:
torch.save(model.state_dict(), './weights.pth')
torch.save(model, './complete_model.pth')