In [1]:
# Uncomment to install the dependencies into this kernel's environment
# (%pip targets the running kernel; `os` is not imported until the next cell).
# %pip install gymnasium
# %pip install Pillow
# %pip install ipython
# %pip install pygame
# %pip install torchinfo
# %pip install tensorboardX

In [2]:
# Import useful packages

import os
import sys
import gymnasium as gym
import random
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary

from DTQN_Model import DTQN

from tensorboardX import SummaryWriter

from collections import namedtuple, deque

In [3]:
# Run configuration shared by the cells below.
env_name = 'CartPole-v1'  # Gymnasium environment id passed to gym.make
goal_score = 200  # NOTE(review): 200 is CartPole-v0's threshold; v1 is usually 475 — confirm intended
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
import pygame
from pygame.locals import QUIT

# Initialize pygame when this cell runs. play_game_with_model() re-initializes
# it as well, because it shuts pygame down when it finishes.
pygame.init()


def play_game_with_model(model_name, load_best=False, max_episodes=None):
    """Watch a trained DTQN agent play ``env_name`` in a pygame window.

    Weights are loaded from ``saved_models/DTQN/<model_name>/online_net.pth``
    (or ``best_online_net.pth`` when ``load_best`` is True). Episodes are
    replayed back to back until the window is closed or, if ``max_episodes``
    is given, until that many episodes have finished. The environment and
    pygame are always shut down on exit, including on KeyboardInterrupt.

    Args:
        model_name: Name of the training-run directory under saved_models/DTQN/.
        load_best: Load ``best_online_net.pth`` instead of ``online_net.pth``.
        max_episodes: Stop after this many finished episodes. ``None`` (the
            default) keeps playing until the window is closed.
    """
    # Idempotent; needed when a previous call already ran pygame.quit().
    pygame.init()

    # In "human" render mode Gymnasium draws a frame inside reset()/step(),
    # so env.render() must not be called again (it would draw every frame twice).
    env = gym.make(env_name, render_mode="human")
    try:
        state, _ = env.reset()

        # Rebuild the network with the env's dimensions and load the weights.
        state_size = env.observation_space.low.size
        action_size = env.action_space.n
        weights_file = 'best_online_net.pth' if load_best else 'online_net.pth'
        weights_path = f'saved_models/DTQN/{model_name}/{weights_file}'
        online_net = DTQN(state_size, action_size).to(device)
        online_net.load_state_dict(torch.load(weights_path, map_location=device))
        online_net.eval()  # inference mode

        pygame.display.set_mode((600, 400))
        pygame.display.set_caption("CartPole with DTQN")
        clock = pygame.time.Clock()

        episodes_finished = 0
        while True:
            # Drain the event queue; closing the window ends the session.
            if any(event.type == QUIT for event in pygame.event.get()):
                print("Quit")
                return

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                action = online_net.get_action(torch.Tensor(state).to(device))

            state, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                episodes_finished += 1
                if max_episodes is not None and episodes_finished >= max_episodes:
                    return
                # Start a fresh episode so the agent keeps playing.
                state, _ = env.reset()

            # Cap the loop at 60 iterations per second.
            clock.tick(60)
    finally:
        # Runs on window close, max_episodes reached, or KeyboardInterrupt.
        env.close()
        pygame.quit()

In [5]:
# Replay the latest (not best) weights from this training run; close the window to stop.
play_game_with_model(model_name='train_20240430_133324', load_best=False)


KeyboardInterrupt

