In [5]:
%load_ext autoreload
%autoreload 2

In [2]:
import gymnasium as gym
from utils.replay_buffer import ReplayBuffer
from policies.dqn import DQN
from torch import nn

env = gym.make("CartPole-v1", render_mode="human")

buffer = ReplayBuffer(capacity=10000)
module = nn.Sequential(
    nn.Linear(4, 128),
    nn.Linear(128, 2)
)
policy = DQN(
    nn_module=module,
    state_dim=4,
    action_dim=2,
    eps=0.01,
    gamma=0.98,
)

results = []

In [3]:
for epi in range(100):
    observation, info = env.reset(seed=42)
    terminated = False
    truncated = False
    epi_len = 0
    total_return = 0

    while not terminated and not truncated:
        action = policy(observation)
        prev_obs = observation
        observation, reward, terminated, truncated, info = env.step(action)
        buffer.add(prev_obs, action, reward, observation, terminated, truncated)
        epi_len += 1
        total_return += reward

        if buffer.size() > 64:
            sampled = buffer.sample(10)
            policy.update(sampled)

        env.render()
    
    print("epi: {}; len: {}; return: {}".format(epi, epi_len, total_return))
    results.append((epi_len, total_return))
env.close()


epi: 0; len: 87; return: 87.0
epi: 1; len: 85; return: 85.0
epi: 2; len: 72; return: 72.0
epi: 3; len: 74; return: 74.0
epi: 4; len: 94; return: 94.0
epi: 5; len: 80; return: 80.0
epi: 6; len: 100; return: 100.0
epi: 7; len: 120; return: 120.0
epi: 8; len: 63; return: 63.0
epi: 9; len: 71; return: 71.0
epi: 10; len: 78; return: 78.0
epi: 11; len: 65; return: 65.0
epi: 12; len: 105; return: 105.0
epi: 13; len: 112; return: 112.0
epi: 14; len: 67; return: 67.0
epi: 15; len: 71; return: 71.0
epi: 16; len: 80; return: 80.0
epi: 17; len: 86; return: 86.0
epi: 18; len: 122; return: 122.0
epi: 19; len: 30; return: 30.0
epi: 20; len: 70; return: 70.0
epi: 21; len: 23; return: 23.0
epi: 22; len: 22; return: 22.0
epi: 23; len: 36; return: 36.0
epi: 24; len: 27; return: 27.0
epi: 25; len: 57; return: 57.0
epi: 26; len: 78; return: 78.0
epi: 27; len: 16; return: 16.0
epi: 28; len: 47; return: 47.0
epi: 29; len: 68; return: 68.0
epi: 30; len: 74; return: 74.0
epi: 31; len: 26; return: 26.0
epi: 32;