In [3]:
import os

import gymnasium as gym
from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder
import numpy as np
import torch
import torch.nn as nn

# Define the DQN model (same as in the training script)
class DQN(nn.Module):
    """Fully connected Q-network mapping a state vector to one Q-value per action.

    The architecture must match the one the checkpoint was trained with: the
    attribute names ``fc1``/``fc2``/``fc3`` become the ``state_dict`` keys that
    ``load_state_dict`` matches against, so they must not be renamed.

    Args:
        state_size: Length of the input state vector.
        action_size: Number of discrete actions (size of the output).
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        hidden = 64  # width of both hidden layers
        self.fc1 = nn.Linear(state_size, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, action_size)

    def forward(self, x):
        """Return the raw (un-activated) Q-values computed from the states in *x*."""
        features = torch.relu(self.fc2(torch.relu(self.fc1(x))))
        return self.fc3(features)

# Set up the environment. render_mode='rgb_array' makes env.render() return
# frames as arrays, which the video recorder consumes.
env = gym.make('CartPole-v1', render_mode='rgb_array')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# Set up the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained model.
# map_location remaps tensors that were saved on another device (e.g. a GPU
# checkpoint) onto the device chosen above, so loading also works on a
# CPU-only machine. NOTE: torch.load unpickles the file -- only load
# checkpoints from a trusted source.
checkpoint_path = './folder/dqn_model_episode_240.pth'  # Adjust the filename as needed
model = DQN(state_size, action_size).to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Set up video recording. The recorder writes to this path without creating
# parent directories, so make sure the output directory exists first.
video_path = './video/dqn_cartpole_test.mp4'
os.makedirs(os.path.dirname(video_path), exist_ok=True)
video_recorder = VideoRecorder(env, video_path, enabled=True)

# Run the test episodes
num_test_episodes = 5
max_steps = 500

try:
    for episode in range(num_test_episodes):
        state, _ = env.reset()
        total_reward = 0

        for _ in range(max_steps):
            # capture_frame() calls env.render() itself (render_mode='rgb_array'),
            # so a separate env.render() call would only render every frame twice.
            video_recorder.capture_frame()

            state_tensor = torch.as_tensor(
                state, dtype=torch.float32, device=device
            ).unsqueeze(0)  # add a batch dimension of size 1
            with torch.no_grad():
                # Greedy policy: take the action with the highest Q-value.
                action = model(state_tensor).argmax().item()

            # step() returns (obs, reward, terminated, truncated, info). The
            # episode ends on either a terminal state or a truncation
            # (e.g. the environment's time limit).
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward

            if terminated or truncated:
                break

        print(f"Test Episode: {episode + 1}, Total Reward: {total_reward}")
finally:
    # close() is what encodes and writes the video file; always release the
    # recorder and the environment, even if an episode raised.
    video_recorder.close()
    env.close()

# Reported only after close() has actually written the video.
print("Testing completed. Saved video.")

Test Episode: 1, Total Reward: 500.0
Test Episode: 2, Total Reward: 500.0
Test Episode: 3, Total Reward: 500.0
Test Episode: 4, Total Reward: 500.0
Test Episode: 5, Total Reward: 500.0
Testing completed. Saved video.
Moviepy - Building video ./video/dqn_cartpole_test.mp4.
Moviepy - Writing video ./video/dqn_cartpole_test.mp4



                                                                  

Moviepy - Done !
Moviepy - video ready ./video/dqn_cartpole_test.mp4
