In [None]:
import torch
from torch import nn
import gymnasium
import plotly.express as px
from collections import deque
from tqdm import tqdm

# Seed torch before any model is built so weight init and replay sampling are reproducible.
# torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
# TODO: also pass seed=SEED to the first env.reset() for reproducible episodes.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
env = gymnasium.make("CartPole-v1")
# Network sizes come from the environment's spaces:
# - input: first entry of the observation shape
# - output: `.n` of the action space (the action count for a Discrete space)
observation_shape = env.observation_space.shape
input_dim = observation_shape[0]
output_dim = env.action_space.n
print(f"Input dimension: {input_dim}, Output dimension: {output_dim}")

In [None]:
class Model(nn.Module):
    """Small MLP mapping an observation vector to one score per action.

    Presumably the scores are Q-values for DQN-style training; confirm against the training loop.

    Args:
        input_dim: size of the last dimension of the input tensor.
        output_dim: number of output scores (one per discrete action).
        hidden_dim: width of both hidden layers (defaults to the previously hard-coded 32).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=32):
        # Zero-argument super() is required here: super(self) raises TypeError
        # because its first argument must be a type, not an instance.
        super().__init__()
        self.mod = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        """Return the raw scores computed by the MLP for input `x`."""
        return self.mod(x)

    @torch.no_grad()
    def predict(self, x):
        """Return the index of the highest-scoring output, without tracking gradients.

        For a batched 2-D input this returns one index per row, with shape (batch,).
        For a single 1-D observation it returns a 0-d tensor.
        """
        scores = self(x)
        if scores.dim() == 1:
            # Unbatched observation: argmax(dim=1) would be out of range here.
            return scores.argmax()
        return scores.argmax(dim=1)



In [None]:


class ReplayMemory():
    """Bounded FIFO buffer of transitions with uniform random sampling (with replacement).

    Args:
        max_capacity: maximum number of stored transitions. Once full, the oldest
            transition is evicted on each add. Floats such as ``1e7`` are accepted
            and converted with ``int``.
    """

    def __init__(self, max_capacity):
        # deque(maxlen=...) requires an int, and the stored capacity must match
        # the deque's maxlen rather than a separately hard-coded value.
        self.max_capacity = int(max_capacity)
        self.memory = deque(maxlen=self.max_capacity)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

    def add(self, transition):
        """Append a transition. A full deque drops its oldest entry automatically."""
        self.memory.append(transition)

    def sample(self, batch_size):
        """Draw up to `batch_size` stored transitions uniformly at random, with replacement.

        The stored transitions are stacked along a new leading dimension. This assumes
        every transition is a tensor of identical shape; TODO confirm against the
        training loop.

        Raises:
            ValueError: if the buffer is empty.
        """
        n_stored = len(self.memory)
        if n_stored == 0:
            raise ValueError("cannot sample from an empty ReplayMemory")
        indices = torch.randint(0, n_stored, (min(batch_size, n_stored),))
        return torch.stack([self.memory[i] for i in indices.tolist()])

In [None]:
model = Model(input_dim, output_dim).to(device)
# Capacity must be an int: deque(maxlen=1e7) raises TypeError for a float.
memory = ReplayMemory(10_000_000)

In [None]:
# Training hyperparameters.
batch_size = 128
learning_rate = 1e-3
n_episodes = 1000

# Construct Adam directly so `optimizer` is only ever bound to the optimizer
# instance, never to the class.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
for episode in tqdm(range(n_episodes)):
    state, info = env.reset()