In [35]:
import os
from gymnasium import Env
from gymnasium.spaces import Discrete, Box
import cv2
import matplotlib.pyplot as plt
import numpy as np
from mss import mss
from time import time, sleep
from pynput.keyboard import Key, Controller as KeyboardController
from pynput.mouse import Button, Controller as MouseController
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import env_checker
from stable_baselines3 import DQN

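# Create the Environment

`DinoGame` wraps the Dino game as a Gymnasium `Env`. Observations are stacks of 5 grayscale screen captures of the game region. Actions are Up / Down / no-op key presses. Every step earns a reward of 1 until the game-over screen is detected.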

In [27]:
class DinoGame(Env):
    """Gymnasium environment that plays the Dino game via screen capture.

    Observations are uint8 stacks of 5 consecutive grayscale frames of the
    captured game region, shaped (5, 80, 244) as declared by
    ``observation_space``. Actions: 0 -> Key.up, 1 -> Key.down, 2 -> no key.
    Every step, including the terminal one, yields a reward of 1. The episode
    terminates when the game-over template image is found in a captured frame.
    The episode is never truncated.

    :param over_frame_path: Path to the grayscale game-over template image.
    :raises FileNotFoundError: If the template image cannot be read.
    """

    def __init__(self, over_frame_path="dino-over.jpg"):
        super().__init__()
        # Load the template first so a missing file fails before any
        # capture/input resources are opened. cv2.imread returns None instead
        # of raising, which would otherwise surface later as an opaque
        # matchTemplate error.
        self.over_frame = cv2.imread(over_frame_path, cv2.IMREAD_GRAYSCALE)
        if self.over_frame is None:
            raise FileNotFoundError(
                f"Could not read game-over template image: {over_frame_path!r}"
            )
        # The (80, 244) frame size must match what get_one_frame() produces
        # for the capture region in self.gamedims.
        self.observation_space = Box(low=0, high=255, shape=(5, 80, 244), dtype=np.uint8)
        self.action_space = Discrete(3)
        self.capture = mss()
        # Screen region (pixels) containing the game.
        self.gamedims = {"top": 300, "left": 60, "width": 1240, "height": 500}
        # Target rate used to pace the frames of one observation.
        self.fps = 29.97
        self.keyboard = KeyboardController()
        self.mouse = MouseController()

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``."""
        action_map = {0: Key.up, 1: Key.down, 2: None}
        # int() accepts Python ints, numpy integer scalars and 0-d arrays alike.
        key = action_map[int(action)]
        if key is not None:
            # Tap the key: press and release immediately.
            self.keyboard.press(key)
            self.keyboard.release(key)
        # Game over is detected on a frame grabbed before the observation frames.
        done = self.get_done()
        next_observation = self.get_observation()
        reward = 1
        return next_observation, reward, done, False, {}

    def reset(self, seed=None, options=None):
        """Restart the game by clicking into its window; return ``(obs, info)``.

        ``seed`` is forwarded to ``Env.reset`` so ``self.np_random`` is seeded
        as Gymnasium expects. ``options`` is accepted for API compatibility
        and is unused.
        """
        super().reset(seed=seed)
        sleep(1.5)
        self.mouse.position = (500, 500)
        self.mouse.press(Button.left)
        self.mouse.release(Button.left)
        sleep(1.5)
        return self.get_observation(), {}

    def render(self):
        """Rendering is not supported; the game is already visible on screen."""
        pass

    def close(self):
        """Release the screen-capture handle."""
        self.capture.close()

    def get_one_frame(self):
        """Grab the game region and return it as a downscaled grayscale uint8 image."""
        raw = np.array(self.capture.grab(self.gamedims))
        # mss yields BGRA pixels. Convert straight to grayscale; alpha is
        # ignored, which matches dropping it and converting from BGR.
        gray = cv2.cvtColor(raw, cv2.COLOR_BGRA2GRAY)
        # NOTE(review): these bounds exceed a 500x1240 capture and are clipped
        # by numpy slicing (effectively [100:500, 20:1240]). They may have been
        # tuned for a HiDPI capture. Confirm before changing gamedims.
        gray = gray[100:1200, 20:1300]
        scale = 0.2
        return cv2.resize(gray, (0, 0), fx=scale, fy=scale)

    def get_observation(self):
        """Capture 5 frames paced at ``self.fps`` and stack them along axis 0."""
        frame_interval = 1 / self.fps
        frames = []
        for _ in range(5):
            start = time()
            frames.append(self.get_one_frame())
            # Sleep out the rest of the frame interval instead of busy-waiting,
            # which would pin a CPU core for the whole of training.
            remaining = frame_interval - (time() - start)
            if remaining > 0:
                sleep(remaining)
        return np.array(frames)

    def get_done(self, threshold=0.8):
        """Return True if the game-over template matches the current frame.

        :param threshold: Minimum TM_CCOEFF_NORMED score, anywhere in the
            frame, that counts as a match.
        """
        frame = self.get_one_frame()
        result = cv2.matchTemplate(frame, self.over_frame, cv2.TM_CCOEFF_NORMED)
        return bool((result >= threshold).any())

# Instantiate the Environment

In [28]:
# Instantiating opens the mss screen-capture handle, loads the game-over
# template (dino-over.jpg) and creates the keyboard/mouse controllers.
env = DinoGame()

# Try Out the Game

Optional smoke test: uncomment the cell below to play a few episodes with random actions.

In [29]:
# print("Starting Simulation in 5 secs...")
# sleep(5)
# N_epi = 5

# for i in range(N_epi):
#     env.reset()
#     done = False
#     episode_reward = 0
#     while not done:
#         action = env.action_space.sample()
#         next_observation, reward, terminated, truncated, info = env.step(action)
#         done = terminated or truncated
#         episode_reward += reward
#     print(f"Episode {i+1}: {episode_reward}")

# Check the Environment

In [30]:
# Validate the env against the Stable-Baselines3 API. This actually resets and
# steps the env, so it clicks and presses keys: keep the game window visible.
env_checker.check_env(env)

# Define Callbacks

In [33]:
class TrainAndLoggingCallbacks(BaseCallback):
    """Periodically save the model being trained.

    Every ``check_freq`` callback calls, the model is saved under
    ``save_path`` as ``best_model_<n_calls>``. These are periodic checkpoints,
    not the best-scoring model. The file name is kept so existing loading code
    keeps working.

    :param check_freq: Save interval in callback calls; must be positive.
    :param save_path: Checkpoint directory, created when training starts.
        If ``None``, no checkpoints are written.
    :param verbose: Verbosity level; at 2 or above each saved path is printed.
    :raises ValueError: If ``check_freq`` is not positive.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        if check_freq <= 0:
            # n_calls % 0 would raise ZeroDivisionError on the first step.
            raise ValueError(f"check_freq must be positive, got {check_freq}")
        self.check_freq = check_freq
        self.save_path = save_path

    def _on_step(self):
        # Invoked by SB3 after each env step; returning False would abort training.
        # Skip saving entirely when no directory was configured, consistent
        # with _init_callback.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, f"best_model_{self.n_calls}")
            self.model.save(model_path)
            if self.verbose >= 2:
                print(f"Saving model checkpoint to {model_path}")
        return True

    def _init_callback(self):
        # Runs once before training: make sure the checkpoint directory exists.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

# TensorBoard logs are written here (passed to DQN as tensorboard_log).
LOG_PATH = "./dino-logs/"
# Periodic model checkpoints saved by the training callback go here.
CHECKPOINT_PATH = "./dino-checkpoints/"

In [34]:
# Checkpoint the model into CHECKPOINT_PATH every 1000 callback calls.
callback = TrainAndLoggingCallbacks(
    save_path=CHECKPOINT_PATH,
    check_freq=1000,
)

# Define the Deep Q-Network (DQN) Model

In [40]:
# Each observation is a (5, 80, 244) uint8 stack, about 98 KB. The default
# replay buffer stores obs and next_obs separately, so 50k transitions would
# need about 9.8 GB. optimize_memory_usage stores each observation once
# (about 4.9 GB). SB3 does not allow that mode together with timeout handling.
# Disabling timeout handling is safe here because DinoGame.step always
# reports truncated=False.
model = DQN(
    "CnnPolicy",
    env,
    tensorboard_log=LOG_PATH,
    verbose=1,
    buffer_size=50000,
    learning_starts=1000,
    optimize_memory_usage=True,
    replay_buffer_kwargs={"handle_timeout_termination": False},
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


# Start Training

In [41]:
# Train for 5000 environment steps; `callback` writes periodic checkpoints.
model.learn(
    callback=callback,
    total_timesteps=5000,
)

Logging to ./dino-logs/DQN_1
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 26.5     |
|    ep_rew_mean      | 26.5     |
|    exploration_rate | 0.799    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 3        |
|    time_elapsed     | 35       |
|    total_timesteps  | 106      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 28.4     |
|    ep_rew_mean      | 28.4     |
|    exploration_rate | 0.569    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 3        |
|    time_elapsed     | 70       |
|    total_timesteps  | 227      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 28.1     |
|    ep_rew_mean      | 28.1     |
|    exploration_rate | 0.36     |
| time/               |   

KeyboardInterrupt: 