In [92]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from collections import Counter, defaultdict, deque
import os, sys, glob, copy, json, time, pickle

In [10]:
from chainer import Chain, ChainList, cuda, gradient_check, Function, Link, \
    optimizers, serializers, utils, Variable, dataset, datasets, using_config, training, iterators
from chainer.training import extensions
from chainer import functions as F
from chainer import links as L
import chainer

In [238]:
class Agent:
    """
    General Q-learning agent for an environment whose states are given
    by images and whose actions are discrete.

    Observed transitions are kept in a bounded replay buffer; every
    LEARN_INTERVAL-th observation triggers one gradient step on a random
    mini-batch drawn from that buffer.
    """
    # The oldest transition is dropped once the buffer reaches this length.
    MEMORY_LIMIT = 10000
    # One learning step is run every LEARN_INTERVAL observations.
    LEARN_INTERVAL = 10

    def __init__(self, num_action):
        """
        num_action (int) : number of available actions.
        """
        self.sars_list = deque()   # replay buffer of [s0, a, r, s1] transitions
        self.num_action = num_action
        self.gamma = 0.95          # discount factor used in the Q-learning target
        self.epsilon = 0.1         # probability of picking a random action in determine()
        self.iteration = 0         # number of transitions observed so far
        self.batch_size = 10
        self.Q = QFunc(num_action)
        # Keep the optimizer object itself rather than setup()'s return value.
        # NOTE(review): setup() presumably does not return the optimizer in
        # every Chainer release -- confirm against the installed version.
        self.opt = optimizers.MomentumSGD()
        self.opt.setup(self.Q)
        self.opt.add_hook(chainer.optimizer.GradientClipping(5.0))

    def observe(self, sars):
        """
        Store one transition and periodically run a learning step.

        sars: [s0, a, r, s1]
            s0, s1: 3-dim image[c,h,w],  a : int, r : float  (32bit-ize is not needed)
        """
        self.sars_list.append(sars)
        if len(self.sars_list) >= self.MEMORY_LIMIT:
            self.sars_list.popleft()
        # Memory use looks heavy; reducing dimensionality beforehand (e.g. with a VAE) might help.
        # Might it be better to forget low-surprise events first?
        self.iteration += 1
        if self.iteration % self.LEARN_INTERVAL == 0:
            self._learn()

    def _learn(self):
        """
        Run one Q-learning gradient step on a random mini-batch.

        Does nothing while the replay buffer holds fewer than `batch_size`
        transitions (sampling without replacement would otherwise raise).
        """
        if len(self.sars_list) < self.batch_size:
            return
        idx = np.random.choice(len(self.sars_list), self.batch_size, replace=False)
        s0, a, r, s1 = zip(*[self.sars_list[i] for i in idx])
        s0, r, s1 = np.array(s0, dtype=np.float32), np.array(r, dtype=np.float32), np.array(s1, dtype=np.float32)
        a = np.array(a, dtype=np.int32)
        # calculate loss of Q-learning
        # The target is computed without recording a graph, so gradients
        # only flow through the prediction for (s0, a).
        with chainer.no_backprop_mode():
            rhs = r + self.gamma * F.max(self.Q(s1), axis=1)
        lhs = F.select_item(self.Q(s0), a)
        loss = F.mean((rhs - lhs) ** 2) # scalar
        self.Q.cleargrads()
        loss.backward()
        self.opt.update()

    def determine(self, s):
        """
        Epsilon-greedy action selection.

        s : single array => a : single int
        """
        if np.random.rand() > self.epsilon:
            batch = Variable(np.array([s], dtype=np.float32))
            # Pure inference: no computational graph is needed here.
            # NOTE(review): the network still runs in Chainer's default
            # (training) configuration, as before -- confirm whether
            # test-mode batch normalization is wanted for action selection.
            with chainer.no_backprop_mode():
                q_values = self.Q(batch)
            return np.argmax(q_values.data.flatten())
        else:
            return np.random.randint(self.num_action)

    def save(self, path):
        """Pickle the whole agent (buffer, network, optimizer) to `path`."""
        with open(path, 'wb') as f:
            pickle.dump(self, f)

    def load(self, path):
        """
        Replace this agent's attributes with those of the agent pickled at `path`.

        SECURITY: pickle.load can execute arbitrary code; only load files
        that this program wrote itself.
        """
        with open(path, 'rb') as f:
            self.__dict__.update(pickle.load(f).__dict__)
        
        
class QFunc(Chain):
    """
    Convolutional Q-function: three (batch-norm -> conv -> ReLU) stages
    followed by two fully connected layers, producing one value per action.
    """
    def __init__(self, num_action):
        super().__init__()
        with self.init_scope():
            # Child links are registered by attribute assignment inside init_scope().
            self.b1 = L.BatchNormalization(3)
            self.c1 = L.Convolution2D(3, 6, ksize=5, stride=4)
            self.b2 = L.BatchNormalization(6)
            self.c2 = L.Convolution2D(6, 12, ksize=5, stride=4)
            self.b3 = L.BatchNormalization(12)
            self.c3 = L.Convolution2D(12, 24, ksize=5, stride=4)
            self.l1 = L.Linear(None, 64)
            self.l2 = L.Linear(64, num_action)

    def __call__(self, x):
        '''
        x (Variable) => (Variable)

        x must be 4-dimensional with 3 entries along axis 1.
        '''
        assert x.ndim == 4
        assert x.shape[1] == 3
        feature = x
        stages = ((self.b1, self.c1), (self.b2, self.c2), (self.b3, self.c3))
        for norm, conv in stages:
            feature = F.relu(conv(norm(feature)))
        hidden = F.relu(self.l1(feature))
        return self.l2(hidden)

In [None]:
import mss
class DesktopEnv:
    """
    Desktop environment with an OpenAI-Gym-like interface.

    Every observation is a screenshot of the fixed screen region `ltrb`.
    """
    def __init__(self):
        self.sct = mss.mss()
        self.ltrb = (0, 0, 640, 480) # (l, t, r, b)

    def _reset(self):
        """
        Returns: initial state
        """
        return self._get_screenshot()

    def _step(self, action):
        """
        action (single int) => new_state, reward, is_done, info
        """
        # TODO: `action` is not applied to the desktop yet, and no reward
        # signal is computed; 0.0 is a placeholder so the method is callable.
        obs = self._get_screenshot()
        reward = 0.0
        return obs, reward, False, {}

    def _get_screenshot(self):
        """Grab the region `ltrb`, keep the first three channels, and return it channel-first."""
        return np.array(self.sct.grab(self.ltrb))[...,:3].transpose(2,0,1) # (color[BGR], height, width)
        
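# Realized after writing this far: DesktopEnv is on the passive side (it has to run in step with whatever the screen shows),
# so it may be easier to write it without following the openai-gym format.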

In [256]:
import mss
import pyautogui as pgui
import pytesseract
class DesktopTrainer:
    """
    Mediates between an Agent and the on-screen environment.

    Each tick it captures the game region, asks the agent for an action,
    reads the score with OCR, and feeds (s0, a, r, s1) transitions to the
    agent, using the change in score as the reward signal.
    """
    def __init__(self, agent, ltrb=(0, 0, 640, 480), ltrb_score=(0, 0, 100, 20), action=None):
        """
        agent      : object providing determine(state) and observe([s0, a, r, s1]).
        ltrb       : (left, top, right, bottom) screen region passed to the agent as state.
        ltrb_score : (left, top, right, bottom) screen region read by OCR for the score.
        action     : list with one entry per action index; defaults to
                     [[], ['left'], ['up'], ['right'], ['down']].
                     NOTE(review): stored but not used anywhere in this class yet --
                     presumably the keys to hold for each action index.
        """
        self.sct = mss.mss()
        self.ltrb = ltrb
        self.ltrb_score = ltrb_score
        # Build the default per instance so instances never share one mutable list.
        self.action = [[], ['left'], ['up'], ['right'], ['down']] if action is None else action
        self.agent = agent
        self.screen_info_log = deque()
        self.log_ssa = deque()
        self.gameover = False

    def run(self):
        """
        Start learning. Loops forever; interrupt the kernel to stop.
        """
        while True:
            time.sleep(0.1)
            info = self._get_screen_info()
            selected_action = self.agent.determine(info["screen_array"])
            info["selected_action"] = selected_action
            # TODO: update key pressing w.r.t. selected_action
            # pgui.keyDown('shift')
            # pgui.keyUp('shift')
            # https://pyautogui.readthedocs.io/en/latest/introduction.html

            # logging screen info
            DesktopTrainer._update_log(self.screen_info_log, info, 100)

            # Decide whether the game is still running. For now this is
            # judged by whether the score could be read from the screen.
            # Explicit float dtype: unreadable scores are NaN, and an
            # oversized OCR integer must not turn this into an object array.
            score = np.array([log["score"] for log in self.screen_info_log], dtype=np.float64)
            if len(self.screen_info_log) >= 20:
                unreadable = np.isnan(score[-20:])
                # NOTE(review): both messages are printed on every tick while
                # the condition holds, not only on a state change.
                if unreadable.all():
                    self.gameover = True
                    print("gameover detected")
                    # TODO: tell the agent a negative observation here?
                elif not unreadable.any():
                    self.gameover = False
                    print("game started")

            if self.gameover:
                continue

            if len(score) >= 20:
                key_frame = -11
                # Medians over two 19-sample windows offset by one frame
                # (score[-20:-1] and score[-19:]) smooth out OCR misreads.
                past_score = np.nanmedian(score[key_frame*2+2:-1])
                current_score = np.nanmedian(score[key_frame*2+3:])
                if not np.isnan(past_score) and not np.isnan(current_score):
                    # Pass the observation to the agent.
                    r = current_score - past_score
                    # Signed log compression of the score difference.
                    r = np.sign(r) * np.log(1 + np.abs(r))
                    s0 = self.screen_info_log[key_frame]["screen_array"]
                    a = self.screen_info_log[key_frame]["selected_action"]
                    s1 = self.screen_info_log[key_frame+1]["screen_array"]
                    self.agent.observe([s0, a, r, s1])
                    if past_score != current_score:
                        print("[score changed] {} => {}".format(past_score, current_score))

    @staticmethod
    def _update_log(dq, item, maxlen):
        """Append `item` to deque `dq`, dropping the oldest entry once it exceeds `maxlen`."""
        dq.append(item)
        if len(dq) > maxlen: dq.popleft()

    @staticmethod
    def _parse_score(score_str):
        """Return the ASCII digits of `score_str` joined into an int, or NaN if it has none."""
        score_digits = ''.join([s for s in score_str if s in '0123456789'])
        return int(score_digits) if score_digits != "" else np.nan

    def _get_screen_info(self):
        """
        Capture the game region and read the score region with OCR.

        Returns a dict with "screen_array" (channel-first image),
        "score_string" (raw OCR text) and "score" (int, or NaN when the
        OCR text contains no digits).
        """
        screen_ary = np.array(self.sct.grab(self.ltrb))[...,:3].transpose(2,0,1) # (color[BGR], height, width)
        score_str = pytesseract.image_to_string(np.array(self.sct.grab(self.ltrb_score)), lang='eng')
        score = DesktopTrainer._parse_score(score_str)
        return {
            "screen_array": screen_ary,
            "score_string": score_str,
            "score": score
        }
        

In [257]:
# Three discrete actions, matching the three entries passed as `action` below.
agent = Agent(3)
dt = DesktopTrainer(
    agent,
    ltrb=(0, 0, 320, 240),            # screen region given to the agent (l, t, r, b)
    ltrb_score=(25, 159, 158, 207),   # screen region read by OCR for the score (l, t, r, b)
    action=[[], ['up'], ['down']],
)

In [255]:
# Start the capture/learn loop. run() loops on `while True` with no exit
# condition, so this cell only ends when the kernel is interrupted.
dt.run()

past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0
past_score: 97.0 => current_score: 97.0


KeyboardInterrupt: 

# Implementing the score-reading feature is essential.

In [241]:
import mss, time
from PIL import Image  # Image.fromarray is used below but was never imported in this notebook

# Grab the score region of the screen as a numpy array.
with mss.mss() as sct:
    # ary = np.array(sct.grab(sct.monitors[0])) # full screen
    ary = np.array(sct.grab((25,159,158,207))) # left, top, right, bottom
    # ary: (height, width, color[BGRA]) uint8
# Reorder the channels BGRA -> RGBA before handing the array to PIL.
# Append .show() to the expression below to preview the capture.
img = Image.fromarray(ary[...,[2,1,0,3]])

In [21]:
import pytesseract

In [28]:
# Per-character OCR boxes for the captured image, requested as a dict
# (ocr() below reads its 'left', 'top', 'right', 'bottom' and 'char' entries).
res = pytesseract.image_to_boxes(img, output_type=pytesseract.Output.DICT)

In [249]:
import pytesseract
from PIL import ImageDraw, ImageFont
def ocr(PIL_image, config='--oem 1'): # add "--psm 7" for one-line
    """
    Run Tesseract on a PIL image, show an annotated copy, and return the text.

    PIL_image : PIL image to recognize; it is not modified (a copy is annotated).
    config    : extra Tesseract command-line options.

    Each recognized character is outlined and labelled on the copy, which is
    then opened with img.show(). Returns the string from image_to_string.
    """
    img = PIL_image.copy()
    draw = ImageDraw.Draw(img)
    font_path = '/System/Library/Fonts/ヒラギノ角ゴシック W0.ttc'
    try:
        draw.font = ImageFont.truetype(font_path, 20)
    except OSError:
        # The font above only exists on macOS; fall back to PIL's built-in font.
        draw.font = ImageFont.load_default()
    res = pytesseract.image_to_boxes(img, lang='eng', output_type=pytesseract.Output.DICT, config=config)
    res_t = pytesseract.image_to_string(img, lang='eng', config=config)
    # Missing keys are treated as "no boxes" so an image with no recognized
    # characters does not raise.
    columns = [res.get(key, []) for key in ('left', 'top', 'right', 'bottom', 'char')]
    for l,t,r,b,c in zip(*columns):
        # Box y-coordinates are flipped against the image height before
        # drawing, since PIL measures y from the top edge.
        draw.rectangle((l,img.height-t,r,img.height-b), outline=(255, 64, 64))
        draw.text((r,img.height-b), c, fill=(255, 64, 64))
    img.show()
    return res_t

In [250]:
# Run OCR on the captured image: ocr() opens the annotated copy in a viewer
# and returns the recognized text, which is shown as the cell result.
ocr(img)

'92'

In [169]:
class A:
    """Minimal scratch class for trying out the pickle-based save/load pattern."""

    def __init__(self):
        # Random payload, so that different instances can be told apart.
        self.hoge = np.random.randint(114514)

    def save(self, path):
        """Pickle this instance to `path`."""
        with open(path, 'wb') as out_file:
            pickle.dump(self, out_file)

    def load(self, path):
        """Overwrite this instance's attributes with those of the object pickled at `path`."""
        with open(path, 'rb') as in_file:
            restored = pickle.load(in_file)
        vars(self).update(vars(restored))

In [173]:
# Create a fresh instance; its hoge value is drawn at random in __init__.
a = A()

In [171]:
# Pickle the current instance to tmp.pickle.
a.save('tmp.pickle')

In [176]:
# Display the instance's current hoge value.
a.hoge

70483

In [175]:
# Overwrite this instance's attributes with those of the object pickled in tmp.pickle.
a.load('tmp.pickle')