In [1]:
import gymnasium as gym
from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 64
GAMMA = 0.99
EPS_START = 1.0
EPS_END = 0.01
EPS_DECAY = 0.995
LR = 0.001
MEMORY_SIZE = 10000

In [5]:
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=MEMORY_SIZE)
        self.model = DQN(state_size, action_size).to(device)
        self.target_model = DQN(state_size, action_size).to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=LR)
        self.epsilon = EPS_START

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action_values = self.model(state)
        return action_values.argmax().item()

    def replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(states).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        next_states = torch.FloatTensor(next_states).to(device)
        dones = torch.FloatTensor(dones).to(device)

        current_q_values = self.model(states).gather(1, actions.unsqueeze(1))
        next_q_values = self.target_model(next_states).max(1)[0].detach()
        target_q_values = rewards + (1 - dones) * GAMMA * next_q_values

        loss = nn.MSELoss()(current_q_values, target_q_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def decay_epsilon(self):
        self.epsilon = max(EPS_END, self.epsilon * EPS_DECAY)

In [6]:
path_of_video_with_name = './video/dqn_cartpole.mp4'
env = gym.make('CartPole-v1', render_mode='rgb_array')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

agent = DQNAgent(state_size, action_size)
video_recorder = VideoRecorder(env, path_of_video_with_name, enabled=True)

num_episodes = 250
max_steps = 500

for episode in range(num_episodes):
    state, _ = env.reset()
    total_reward = 0

    for step in range(max_steps):
        env.render()
        video_recorder.capture_frame()
        
        action = agent.act(state)
        next_state, reward, done, _, _ = env.step(action)
        
        agent.remember(state, action, reward, next_state, done)
        agent.replay()
        
        state = next_state
        total_reward += reward

        if done:
            break

    agent.decay_epsilon()
    agent.update_target_model()

    print(f"ep: {episode + 1}, totalreward: {total_reward}, epsilion: {agent.epsilon:.2f}")

    if episode % 10 == 0:
        torch.save(agent.model.state_dict(), f'./folder/dqn_model_episode_{episode}.pth')

print("Training completed.")

video_recorder.close()
video_recorder.enabled = False
env.close()

Episode: 1, Total Reward: 48.0, Epsilon: 0.99
Episode: 2, Total Reward: 13.0, Epsilon: 0.99


  states = torch.FloatTensor(states).to(device)


Episode: 3, Total Reward: 37.0, Epsilon: 0.99
Episode: 4, Total Reward: 37.0, Epsilon: 0.98
Episode: 5, Total Reward: 16.0, Epsilon: 0.98
Episode: 6, Total Reward: 14.0, Epsilon: 0.97
Episode: 7, Total Reward: 15.0, Epsilon: 0.97
Episode: 8, Total Reward: 18.0, Epsilon: 0.96
Episode: 9, Total Reward: 27.0, Epsilon: 0.96
Episode: 10, Total Reward: 17.0, Epsilon: 0.95
Episode: 11, Total Reward: 27.0, Epsilon: 0.95
Episode: 12, Total Reward: 14.0, Epsilon: 0.94
Episode: 13, Total Reward: 12.0, Epsilon: 0.94
Episode: 14, Total Reward: 25.0, Epsilon: 0.93
Episode: 15, Total Reward: 20.0, Epsilon: 0.93
Episode: 16, Total Reward: 14.0, Epsilon: 0.92
Episode: 17, Total Reward: 16.0, Epsilon: 0.92
Episode: 18, Total Reward: 14.0, Epsilon: 0.91
Episode: 19, Total Reward: 14.0, Epsilon: 0.91
Episode: 20, Total Reward: 19.0, Epsilon: 0.90
Episode: 21, Total Reward: 14.0, Epsilon: 0.90
Episode: 22, Total Reward: 17.0, Epsilon: 0.90
Episode: 23, Total Reward: 34.0, Epsilon: 0.89
Episode: 24, Total R

                                                                   

Moviepy - Done !
Moviepy - video ready ./video/dqn_cartpole.mp4
