In [1]:
from mss import mss
import pydirectinput
import cv2
import numpy as np
import pytesseract
from matplotlib import pyplot as plt
import time
from gym import Env
from gym.spaces import Box, Discrete

In [14]:
### Create Environment

class WebGame(Env):
    """Gym environment that plays a browser game via screen capture and key presses.

    Observations are grayscale captures of ``game_location`` resized to 100x83
    and returned channels-first as a ``(1, 83, 100)`` uint8 array.
    Actions: 0 = press space, 1 = press down, 2 = no-op.
    Game over is detected by OCR-ing the ``done_location`` region.
    """

    def __init__(self):
        super().__init__()
        # Channels-first grayscale frame; must match get_observation()'s output
        self.observation_space = Box(low=0, high=255, shape=(1, 83, 100), dtype=np.uint8)
        self.action_space = Discrete(3)
        # Screen grabber and the absolute screen regions to capture. These
        # coordinates depend on where the game window sits on screen.
        self.cap = mss()
        self.game_location = {'top': 300, 'left': 0, 'width': 600, 'height': 500}
        self.done_location = {'top': 405, 'left': 630, 'width': 660, 'height': 70}

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        Action 0 presses space, 1 presses down, 2 sends nothing. The reward is
        a constant 1 per step: one point for every frame survived.
        """
        action_map = {
            0: 'space',
            1: 'down',
            2: 'no_op'
        }
        if action != 2:
            pydirectinput.press(action_map[action])

        # Check for game over, then grab the next frame
        done, _ = self.get_done()
        new_observation = self.get_observation()
        reward = 1
        info = {}
        return new_observation, reward, done, info

    def render(self):
        """Show the current game region in an OpenCV window; pressing 'q' closes it."""
        cv2.imshow('Game', np.array(self.cap.grab(self.game_location))[:, :, :3])
        # waitKey also pumps the HighGUI event loop so the window refreshes
        if cv2.waitKey(1) & 0xFF == ord('q'):
            self.close()

    def reset(self):
        """Restart the game and return the first observation.

        Waits one second, clicks at screen position (150, 150) (presumably to
        focus the game window), then presses space to start a new run.
        """
        time.sleep(1)
        pydirectinput.click(x=150, y=150)
        pydirectinput.press('space')
        return self.get_observation()

    def close(self):
        """Close any OpenCV windows opened by render()."""
        cv2.destroyAllWindows()

    def get_observation(self):
        """Capture the game region as a ``(1, 83, 100)`` uint8 grayscale array."""
        # Keep only the first three (BGR) channels of the capture
        raw = np.array(self.cap.grab(self.game_location))[:, :, :3].astype(np.uint8)
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
        # cv2.resize takes (width, height): result is 83 rows x 100 columns
        resized = cv2.resize(gray, (100, 83))
        # Add a leading channel axis to match observation_space
        return np.reshape(resized, (1, 83, 100))

    def get_done(self):
        """OCR the game-over region and report whether the game has ended.

        Returns:
            (done, done_cap): ``done`` is True when the first four OCR'd
            characters are 'GAME' or 'GAHE'; ``done_cap`` is the captured
            3-channel region that was read.
        """
        done_cap = np.array(self.cap.grab(self.done_location))[:, :, :3]
        # 'GAHE' is presumably a tesseract misread of 'GAME'
        done_strings = ('GAME', 'GAHE')
        res = pytesseract.image_to_string(done_cap)[:4]
        # Always bind `done` so the non-game-over path returns False
        done = res in done_strings
        return done, done_cap

In [15]:
# Instantiate the environment. The constructor opens an mss screen grabber,
# which needs a display: the recorded run failed with "$DISPLAY not set".
env = WebGame()

ScreenShotError: $DISPLAY not set.

In [16]:
# Preview one captured frame; get_observation reshapes it to (1, 83, 100).
# NOTE(review): the saved output `array(None, dtype=object)` does not match the
# current get_observation, which always returns an array; it likely comes from
# an earlier version of the class. Re-run to refresh.
np.array(env.get_observation())

array(None, dtype=object)

In [9]:
# Draw a random action: 0 = space, 1 = down, 2 = no-op (see WebGame.step)
env.action_space.sample()

1

In [11]:
# Draw a random observation to confirm the (1, 83, 100) uint8 space definition
env.observation_space.sample()

array([[[ 48, 165, 204, ...,  52, 186,  19],
        [ 79, 158, 181, ...,   6, 146, 127],
        [103, 247,  33, ..., 189, 128, 149],
        ...,
        [ 57,   3,  53, ...,  50,  64, 237],
        [ 67, 102, 218, ..., 195,  33, 188],
        [154, 200,  14, ..., 231,  84, 134]]], dtype=uint8)

In [None]:
### Test Environment
# Fresh environment for a smoke test driven by random actions
env = WebGame()

## Play 10 games, each until get_done() reports game over
for episode in range(10):
    obs = env.reset()
    total_reward = 0
    done = False
    while True:
        obs, reward, done, info = env.step(env.action_space.sample())
        total_reward += reward
        if done:
            break
    print(f'Total Reward for episode {episode} is {total_reward}')


In [None]:
#### Create Callback
import os
# The package is `stable_baselines3` (plural); `stable_baseline3` does not exist
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import env_checker

In [None]:
# Validate the custom environment's spaces and reset/step API with SB3's checker
env_checker.check_env(env)

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Callback that saves a model checkpoint every ``check_freq`` calls.

    Checkpoints are written to ``save_path/best_model_<n_calls>``. Despite the
    prefix, no "best" selection happens: every ``check_freq``-th call saves.

    Args:
        check_freq: save interval, in callback calls.
        save_path: directory for checkpoints (created if missing).
        verbose: verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        # Forward verbose so the base class actually records the caller's setting
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        # Ensure the checkpoint directory exists before any save
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)
        # Returning True lets training continue
        return True

In [None]:
# Output locations: checkpoints saved by the callback, and TensorBoard logs for DQN
CHECKPOINT_DIR, LOG_DIR = './run/checkpoints/', './run/logs/'

In [None]:
# Checkpoint into CHECKPOINT_DIR every 1000 callback calls
callback = TrainAndLoggingCallback(save_path=CHECKPOINT_DIR, check_freq=1000)

In [None]:
### Build DQN and Train
# The package is `stable_baselines3` (plural); `stable_baseline3` does not exist
from stable_baselines3 import DQN

In [None]:
# CNN-policy DQN over the (1, 83, 100) frames; learning starts immediately.
# NOTE(review): a 1.2M-entry buffer of 1*83*100 = 8,300-byte frames is ~10 GB per
# stored observation array; confirm the machine has enough RAM.
model = DQN(
    'CnnPolicy',
    env,
    buffer_size=1_200_000,
    learning_starts=0,
    tensorboard_log=LOG_DIR,
    verbose=1,
)

In [None]:
### Train model: 100k timesteps, checkpointing via the callback
model.learn(callback=callback, total_timesteps=100_000)

In [None]:
## load model
# NOTE(review): this loads from 'train_first/', but the training callback saves to
# CHECKPOINT_DIR ('./run/checkpoints/'), and nothing in this notebook creates
# 'train_first'. Confirm the intended checkpoint location.
model = DQN.load(os.path.join('train_first', 'best_model_88000'))

In [None]:
#### Test model
# Play 10 episodes with the trained agent. deterministic=True makes DQN act
# greedily; with the default (False), SB3's DQN.predict still takes random
# epsilon-greedy actions at the model's exploration rate.
for episode in range(10):
    obs = env.reset()
    done = False
    total_reward = 0

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(int(action))
        total_reward += reward

    print(f'Total Reward for episode {episode} is {total_reward}')