In [1]:
import os
import warnings
import logging
import time
import mss
import torch
import numpy as np
from pynput import keyboard
from pynput.keyboard import Controller, Key
from concurrent.futures import ThreadPoolExecutor
import pathlib
from helper import SocketListener
if not os.path.exists('yolov5'):
    !git clone https://github.com/ultralytics/yolov5
    !pip install -r yolov5/requirements.txt

warnings.simplefilter("ignore", FutureWarning)
logging.getLogger('ultralytics').setLevel(logging.ERROR)

executor = ThreadPoolExecutor(max_workers=8)

In [2]:
pathlib.PosixPath = pathlib.WindowsPath # https://github.com/ultralytics/yolov5/issues/10240#issuecomment-1662573188
model = torch.hub.load('ultralytics/yolov5', 'custom', path='./models/best.pt', force_reload=True)  

monitor = mss.mss().monitors[1]
t, l, w, h = monitor['top'], monitor['left'], monitor['width'], monitor['height']
region = {'left': l+int(w * 0.338), 'top': t, 'width': w-int(w * 0.673), 'height': h} 

Downloading: "https://github.com/ultralytics/yolov5/zipball/master" to C:\Users\bohui/.cache\torch\hub\master.zip
YOLOv5  2024-11-23 Python-3.11.5 torch-2.5.0+cu118 CUDA:0 (NVIDIA GeForce RTX 3050 Ti Laptop GPU, 4096MiB)

Fusing layers... 
Model summary: 157 layers, 7018216 parameters, 0 gradients, 15.8 GFLOPs
Adding AutoShape... 


In [3]:
def detect(img, model):
  lanes = {
    0 : (10, 180),
    1 : (150, 320),
    2 : (300, 470),
    3 : (440, 610)
  }
  ret = []  
  res = model(img)
  
  for box in res.xyxy[0]:
    # Confidence level is less than 50%
    if box[4] < 0.50:
      continue
    
    x_center = int((box[0] + box[2]) / 2)
    y_center = int((box[1] + box[3]) / 2)
    class_id = int(box[5]) # classes are 0: end_hold, 1: note, 2: start_hold
    
    # Identify the lane of the note based on x_center
    for lane, (start, end) in lanes.items():
      if start <= x_center <= end:
        break
    
    ret.append([class_id, lane, y_center])
  
  if ret:
    ret = sorted(ret, key=lambda note: (-note[2], note[1]))

  return ret 

def capture(region):
  with mss.mss() as sct:
    return sct.grab(region)

In [4]:
# Store pressed keys
pressed_keys = set()

def on_press(key):
    try:
        pressed_keys.add(key.char)  # Record alphanumeric keys
    except AttributeError:
        pressed_keys.add(str(key))  # Record special keys

def on_release(key):
    try:
        pressed_keys.discard(key.char)  # Remove released alphanumeric keys
    except AttributeError:
        pressed_keys.discard(str(key))  # Remove released special keys

def record_key():
    # Record the current frame's keys
    current_keys = list(pressed_keys) 
    return current_keys


In [5]:
# preprocess the actions into shape [4,4,4,4]
def preprocess_actions(actions):
    key_hold = [False] * 4
    map_key = {"s": 0, "d" : 1, "k" : 2, "l" : 3}
    clean_actions = []
    for i in range(len(actions)):
        chars = actions[i]
        action = [0, 0, 0, 0]

        if chars == []:
            clean_actions.append(action)
            continue

        for char in chars:
            key = map_key.get(char)
            if key is None:
                continue
            
            if not key_hold[key]:
                action[key] = 1
            else:
                action[key] = 2

            if i + 1 < len(actions):
                if char not in actions[i+1]:
                    if key_hold[key]:
                        action[key] = 3            
                else:
                    key_hold[key] = True
                    action[key] = 2
                
        clean_actions.append(action)
    
    return clean_actions

In [6]:
currently_hold = [False] * 4
invalid = False
key_pressed = False
def keyboard_action(lane, key, action):
    match action:
        case 0: # do nothing
            if currently_hold[lane]:
                invalid = True
            else:
                return
        case 1: # press
            if currently_hold[lane]:
                invalid = True
            else:
                keyboard_controller.press(key)
                keyboard_controller.release(key)
                key_pressed = True
        case 2: # hold
            if currently_hold[lane]:
                return
            keyboard_controller.press(key)
            currently_hold[lane] = True
            key_pressed = True
        case 3: # release
            if currently_hold[lane]:
                keyboard_controller.release(key)
                currently_hold[lane] = False
            else:
                invalid = True

In [7]:
def perform_action(action):
    keys = ["s", "d", "k", "l"]
    threads = []
    for lane in range(len(action)):
        key = executor.submit(keyboard_action, lane, keys[lane], action[lane])
    threads.append(key)

    for key in threads:
        key.result()

In [8]:
# Capture keyboard input at x FPS
FPS = 15
frame_duration = 1 / FPS

# track notes and key input
frame = []
actions = []
hit_type = []

In [9]:
key_listener = keyboard.Listener(on_press=on_press, on_release=on_release)
key_listener.start()
listener = SocketListener()
listener.start()

keyboard_controller = Controller()

Listening on 127.0.0.1:5555


In [10]:
# placehold for action_fuc
def action_fuc():
    return

In [11]:
song_begin = False
while listener.is_first_connection or listener.is_listening:
    if listener.has_connection:
        song_begin = True
        time_start = time.time()
        vision_thread = executor.submit(capture, region)
        image = vision_thread.result()

        vision_thread = executor.submit(detect, np.array(image), model)
        notes = vision_thread.result() 

        keys = record_key()
        data = listener.fetch_data(action_fuc, 0.01)

        frame.append(notes)
        actions.append(keys)
        hit_type.append(data)

        elapsed_time = time.time() - time_start
        if elapsed_time < frame_duration:
                time.sleep(frame_duration-elapsed_time)
                
    elif song_begin:
         break

In [12]:
actions = preprocess_actions(actions)
# frame = frame[:len(hit_type)]
# actions = actions[:len(hit_type)]

In [13]:
def saveData(name):
    os.makedirs('./expert_demo', exist_ok=True)

    data = {"frame": frame,
            "action": actions,
            "hit_type": hit_type}
    
    torch.save(data, f"./expert_demo/{name}_{time.time()}.pth")

In [14]:
print("frame: ", len(frame))
print("action: ", len(actions))
print("hit type: ", len(hit_type))

frame:  1030
action:  1030
hit type:  1030


In [15]:
saveData("replay")

In [None]:
#test in game
# song_begin = False
# while listener.is_listening or listener.is_first_connection:
#     if listener.has_connection:
#         song_begin = True
#         time_start = time.time()
#         vision_thread = executor.submit(capture, region)
#         image = vision_thread.result()

#         vision_thread = executor.submit(detect, np.array(image), model)
#         notes = vision_thread.result() 
        
#         perform_action(actions[i])

#         elapsed_time = time.time() - time_start
#         if elapsed_time < frame_duration:
#                 time.sleep(frame_duration-elapsed_time)

#     elif song_begin:
#          break