# 收集数据

In [None]:
import os
import time

import cv2
import numpy as np
import warnings
# `np.warnings` was only a re-export of the stdlib module and was removed in NumPy 1.24.
# VisibleDeprecationWarning lives in `np.exceptions` since NumPy 1.25 (and is no longer
# exposed at top level in NumPy 2.x), so fall back to the top-level name on older NumPy.
warnings.filterwarnings('ignore', category=getattr(getattr(np, 'exceptions', np), 'VisibleDeprecationWarning'))

from pysekiro.img_tools.get_vertices import roi
from pysekiro.img_tools.grab_screen import get_screen
from pysekiro.key_tools.get_keys import key_check

def get_output(keys):
    """Map the currently pressed keys to a one-hot action list of length 5.

    Keys are checked in priority order J, K, LSHIFT, SPACE; the first one found
    in ``keys`` sets its slot (indices 0-3). When none of them is present the
    last slot (index 4) is set instead.
    """
    priority = ('J', 'K', 'LSHIFT', 'SPACE')
    chosen = len(priority)  # default: the "no tracked key pressed" slot
    for index, key in enumerate(priority):
        if key in keys:
            chosen = index
            break
    return [1 if slot == chosen else 0 for slot in range(len(priority) + 1)]

# Size of each stored frame; frames are converted to grayscale, hence one channel.
height, width, channels = 300, 300, 1

# Bounds handed to roi(screen, x, x_w, y, y_h) to crop the captured screen.
# Presumably left/right and top/bottom pixel coordinates — confirm against
# pysekiro.img_tools.get_vertices.roi.
x, x_w = 250, 550
y, y_h = 75, 375

class Data_collection:
    """Record (screen ROI, one-hot action) pairs while fighting ``target``.

    Data is written to ``The_battle_memory/<target>/training_data-<n>.npy``,
    using the first ``n`` (starting at 1) that does not exist yet.
    """

    frame_time = 0.04  # target duration in seconds of one capture iteration

    def __init__(self, target):
        self.target = target    # target name; also the sub-directory name
        self.dataset = list()    # collected [screen_roi, action_list] pairs
        self.save_path = os.path.join('The_battle_memory', self.target)    # output directory
        # makedirs also creates 'The_battle_memory' when missing and tolerates an existing dir.
        os.makedirs(self.save_path, exist_ok=True)

        self.step = 0    # number of frames processed since recording started

    @staticmethod
    def _as_object_array(dataset):
        """Pack ``[image, action_list]`` pairs into an (N, 2) object array.

        Built element by element because NumPy >= 1.24 raises ValueError when it
        has to infer an array from ragged nested sequences (an image and a
        5-element action list have different lengths).
        """
        packed = np.empty((len(dataset), 2), dtype=object)
        for row, (img, action) in enumerate(dataset):
            packed[row, 0] = img
            packed[row, 1] = action
        return packed

    def save_data(self):
        """Save ``self.dataset`` to the first unused training_data-<n>.npy file."""
        print('\n\nStop, please wait')
        n = 1
        # Find the first file name that does not exist yet, so nothing is overwritten.
        while os.path.exists(os.path.join(self.save_path, f'training_data-{n}.npy')):
            n += 1
        save_path = os.path.join(self.save_path, f'training_data-{n}.npy')
        print(save_path)
        np.save(save_path, self._as_object_array(self.dataset))
        print('Done!')

    def collect_data(self):
        """Capture frames until P is pressed, then save them.

        Recording starts once T is pressed. Each iteration grabs the screen,
        converts it to grayscale, crops it with ``roi`` and resizes it to
        (width, height), and stores it with the action from ``get_output``.
        """
        print('Ready!')
        paused = True
        while True:
            last_time = time.time()
            keys = key_check()
            if paused:
                if 'T' in keys:
                    paused = False
                    print('Starting!')
                continue

            self.step += 1

            screen = get_screen()    # grab the screen image
            screen_gray = cv2.cvtColor(screen, cv2.COLOR_BGR2GRAY)
            screen_roi = cv2.resize(roi(screen_gray, x, x_w, y, y_h), (width, height))

            # Pad the iteration up to `frame_time`: subtract the time elapsed since
            # `last_time` (not the absolute timestamp, which would never sleep).
            t = self.frame_time - (time.time() - last_time)
            if t > 0:
                time.sleep(t)

            # Skip frames in which more than 1/8 of the pixels are black (value 0).
            if not (np.sum(screen_roi == 0) > int(width * height * channels / 8)):
                action_list = get_output(keys)    # one-hot action for the keys held
                self.dataset.append([screen_roi, action_list])    # keep image and label paired

            print(f'\rstep:{self.step:>4}. Loop took {round(time.time()-last_time, 3):>5} seconds.', end='')

            if 'P' in keys:    # stop recording
                self.save_data()
                break

# Target to record data for; keep exactly one of these lines uncommented.
target = 'Genichiro_Ashina' # Genichiro Ashina
# target = 'Inner_Genichiro' # Inner Genichiro
# target = 'Isshin,_the_Sword_Saint' # Isshin, the Sword Saint
# target = 'Inner_Isshin' # Inner Isshin

# Start collecting: press T to begin recording, P to stop and save.
c = Data_collection(target)
c.collect_data()

# 数据预览

**你得先有数据才能预览，对吧……（请先运行上面“收集数据”部分的代码。）**

In [None]:
import os

import cv2
import numpy as np

# Replay one recorded dataset: show every frame and print its recorded action label.
data_path = os.path.join('The_battle_memory', 'Genichiro_Ashina', 'training_data-1.npy')
# allow_pickle is needed because each row holds Python objects (an image and an action list).
# Only load files you recorded yourself: unpickling untrusted data can execute code.
data = np.load(data_path, allow_pickle=True)

# Labels for action indices 0-3 (attack, guard, step-dodge, jump); any other index is '其他' (other).
action_names = ('攻击', '防御', '垫步', '跳跃')

try:
    for img, action_value in data:
        action = np.argmax(action_value)
        a = action_names[action] if action < len(action_names) else '其他'
        print(f'\r{a:>3}', end='')

        cv2.imshow('screen', img)
        cv2.waitKey(1)
finally:
    # Close the preview window even when playback is interrupted (e.g. KeyboardInterrupt).
    cv2.destroyAllWindows()

# 训练模型

In [None]:
import os

import cv2
import numpy as np

from pysekiro.model import MODEL

in_depth    = 5     # consecutive frames stacked into one model input sample
in_width    = 300
in_height   = 300
in_channels = 1     # single-channel (grayscale) frames
lr = 0.01           # learning rate passed to MODEL

n_action = 5        # number of action classes (model outputs)


def one_hot(x):
    """Return a list of length ``n_action`` with 1 at index ``x`` and 0 elsewhere."""
    return [1 if k == x else 0 for k in range(n_action)]

def _windows(data):
    """Yield consecutive, non-overlapping slices of ``in_depth`` samples.

    Every complete window is produced, including the last one; a trailing
    partial window (fewer than ``in_depth`` samples) is dropped.
    """
    for start in range(0, len(data) - in_depth + 1, in_depth):
        yield data[start:start + in_depth]


def get_X(data):
    """Stack images into inputs of shape (-1, in_depth, in_height, in_width, in_channels).

    ``data`` is a sequence of (image, action_vector) pairs, e.g. a dataset loaded
    with ``np.load(..., allow_pickle=True)``. Each sample is the images of one
    window of ``in_depth`` consecutive pairs; rows line up with ``get_Y``.
    """
    X = [[sample[0] for sample in window] for window in _windows(data)]
    return np.array(X).reshape(-1, in_depth, in_height, in_width, in_channels)


def get_Y(data):
    """Build one-hot targets, one per window, aligned with ``get_X``.

    The label is the most frequent action among indices 0 .. n_action-2 within
    the window (ties go to the lowest index). If none of those actions occurs,
    the last action index (n_action-1) is used.
    """
    Y = list()
    for window in _windows(data):
        actions = [np.argmax(sample[1]) for sample in window]
        counts = [actions.count(a) for a in range(n_action)]
        if np.any(counts[:n_action-1]):
            label = np.argmax(counts[:n_action-1])
        else:
            label = n_action - 1
        Y.append(one_hot(label))
    return np.array(Y)

def train(
    target,
    start=1,
    end=1,
    batch_size=8,
    epochs=3,
    load_weights_path=None,
):
    """Train MODEL on ``training_data-<start>.npy`` .. ``training_data-<end>.npy``.

    The files are read from ``The_battle_memory/<target>/`` one after another.
    After each file the weights are written to 'dl_weights.h5'. Missing files
    are reported and skipped.
    """
    model = MODEL(
        in_depth=in_depth,
        in_width=in_width,
        in_height=in_height,
        in_channels=in_channels,
        outputs=n_action,
        lr=lr,
        load_weights_path=load_weights_path,
    )
    model.summary()

    weights_file = 'dl_weights.h5'
    dataset_dir = os.path.join('The_battle_memory', target)

    # Train on one dataset file at a time, saving the weights after each file.
    for index in range(start, end + 1):
        filename = f'training_data-{index}.npy'
        data_path = os.path.join(dataset_dir, filename)

        if not os.path.exists(data_path):
            print(f'{filename} does not exist ')
            continue

        data = np.load(data_path, allow_pickle=True)
        print('\n', filename, f'total:{len(data):>5}')

        # Turn the raw (image, action) pairs into stacked-frame inputs and one-hot targets.
        features, labels = get_X(data), get_Y(data)

        model.fit(features, labels, batch_size=batch_size, epochs=epochs, verbose=1, shuffle=False)
        model.save_weights(weights_file)

# Target whose recorded data to train on; keep exactly one of these lines uncommented.
target = 'Genichiro_Ashina' # Genichiro Ashina
# target = 'Inner_Genichiro' # Inner Genichiro
# target = 'Isshin,_the_Sword_Saint' # Isshin, the Sword Saint
# target = 'Inner_Isshin' # Inner Isshin
train(target)

# 测试模型（临时代码）

In [None]:
import os
import time

import cv2
import numpy as np
import warnings
# `np.warnings` was only a re-export of the stdlib module and was removed in NumPy 1.24.
# VisibleDeprecationWarning lives in `np.exceptions` since NumPy 1.25 (and is no longer
# exposed at top level in NumPy 2.x), so fall back to the top-level name on older NumPy.
warnings.filterwarnings('ignore', category=getattr(getattr(np, 'exceptions', np), 'VisibleDeprecationWarning'))

from pysekiro.img_tools.get_vertices import roi
from pysekiro.img_tools.grab_screen import get_screen
from pysekiro.key_tools.actions import act
from pysekiro.key_tools.get_keys import key_check
from pysekiro.model import MODEL

in_depth    = 5     # consecutive frames stacked into one model input sample
in_width    = 300
in_height   = 300
in_channels = 1     # single-channel (grayscale) frames
lr = 0.01           # learning rate passed to MODEL

n_action = 5        # number of action classes (model outputs)

load_weights_path = 'dl_weights.h5'   # weights file written by train()

# Bounds handed to roi(screen, x, x_w, y, y_h); must match the values used
# when the training data was collected.
x   = 250
x_w = 550
y   = 75
y_h = 375


def one_hot(x):
    """Return a list of length ``n_action`` with 1 at index ``x`` and 0 elsewhere."""
    return [1 if k == x else 0 for k in range(n_action)]

def test():
    """Drive the game with the trained model.

    Press T to start and P to stop. While running, the function captures
    grayscale ROI frames. Every ``in_depth`` frames it predicts an action from
    the stacked frames and executes it with ``act``.
    """
    model = MODEL(
        in_depth = in_depth,
        in_width = in_width,
        in_height = in_height,
        in_channels = in_channels,
        outputs = n_action,
        lr = lr,
        load_weights_path = load_weights_path
    )

    frame_time = 0.04   # target duration in seconds of one loop iteration
    screen_box = list()  # frames collected for the next prediction

    print('Ready!')
    paused = True
    while True:
        last_time = time.time()
        keys = key_check()
        if paused:
            if 'T' in keys:
                paused = False
                print('Starting!')
        else:
            screen = get_screen()    # grab the screen image
            screen_gray = cv2.cvtColor(screen, cv2.COLOR_BGR2GRAY)
            screen_roi = cv2.resize(roi(screen_gray, x, x_w, y, y_h), (in_width, in_height))

            # Pad the iteration up to `frame_time`: subtract the time elapsed since
            # `last_time` (not the absolute timestamp, which would never sleep).
            t = frame_time - (time.time() - last_time)
            if t > 0:
                time.sleep(t)

            if len(screen_box) == in_depth:
                X = np.array(screen_box).reshape(-1, in_depth, in_height, in_width, in_channels)
                # X holds exactly one sample, so the chosen action is the argmax of its prediction.
                Y = np.argmax(model.predict(X)[0])
                act(Y)
                screen_box = list()
                print(f'\rLoop took {round(time.time()-last_time, 3):>5} seconds. {Y}', end='')
            screen_box.append(screen_roi)

            if 'P' in keys:    # stop
                break
    print('\nDone!')
# Launch the live model test: press T to start, P to stop.
test()