In [462]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
from PIL import Image
from reward import SimulatedReward
from helper import detect, capture, pad_inner_array, SocketListener
from pynput import keyboard
from pynput.keyboard import Controller, Key
from concurrent.futures import ThreadPoolExecutor
import mss
import time
import pathlib
import os
import json
import random
import logging
import warnings

In [463]:
if not os.path.exists('yolov5'):
    !git clone https://github.com/ultralytics/yolov5
    !pip install -r yolov5/requirements.txt

warnings.simplefilter("ignore", FutureWarning)
logging.getLogger('ultralytics').setLevel(logging.ERROR)

In [464]:
class AC_Net(nn.Module):
  def __init__(self, input, action_space):
    super(AC_Net, self).__init__()
    self.fc1 = nn.Linear(input[0] * input[1], 128)
    self.fc2 = nn.Linear(128, 128)
    self.lstm = nn.LSTM(128, 128, batch_first=True)
    self.actors = nn.ModuleList([nn.Linear(128, action.n) for action in action_space])
    self.critic = nn.Linear(128, 1)
    
    for layer in [self.fc1, self.fc2, *self.actors, self.critic]:
      nn.init.xavier_uniform_(layer.weight)
      nn.init.constant_(layer.bias, 0)
    
  def forward(self, x, hx=None):
    x = x.flatten()
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = nn.Dropout(0.3)(x)
    x, hx = self.lstm(x.unsqueeze(0), hx)
    return [actor(x) for actor in self.actors], self.critic(x), hx

In [None]:
class SimOsuEnvironment(gym.Env):
  def __init__(self, max_notes=8, render_mode=False):
    self.max_notes = max_notes
    self.keys_reference = ['s', 'd', 'k', 'l'] # Used for keyboard input with index corresponding to a lane
    self.observation_space = gym.spaces.Box(low=0, high=np.inf, shape=(self.max_notes, 3))
    self.action_space = gym.spaces.MultiDiscrete([4] * len(self.keys_reference))
    self.observation = np.zeros((self.max_notes, 3))

    self.custom_reward = [50, 100, 2, -1, -15, -2, -2, -150, -1] # Custom reward function to be used with SimulatedReward
    self.reward_scale = 100 # Scaling reward to avoid extremely small values
    self.custom_reward_scaled = []  # Scaled custom reward, to be set in reset
    self.invalid_action_penalty = -5
    self.invalid_action_penalty_scaled = 0 # Scaled invalid action penalty, to be set in reset
    self.sim_reward = SimulatedReward()
    self.executor = ThreadPoolExecutor()
    
    self.model = None 
    self.capture_region = None
    self._vision_setup()
    
    self.chosen_folder = None # A randomly chosen folder from available_folders, to be set in reset
    self.available_folders = os.listdir('./frames/rl_simulated_training_sample')
    self.frames = [] # A list of frames in the chosen folder, to be set in reset
    self.frame_notes = [] # A note representation of each frame
    self.metadata = {} # Metadata of the chosen song, to be set in reset
    self.frame_step = 0
    self.note_count = 0
    
    self.invalid_actions = 0 # Counts the number of invalid actions in a step
    self.terminated = False
    self.render_mode = render_mode
    
    self.total_invalid_actions = 0
    self.actions_taken = np.zeros((4, 4))
    
    # Render mode requires the game to be running and the game window to be visible
    if self.render_mode:
      self.listener = SocketListener()
      self.listener.start()
      self.keyboard = Controller()
    
  def reset(self):
    self.observation = np.zeros((self.max_notes, 3))
    self.invalid_actions = 0
    self.total_invalid_actions = 0
    self.actions_taken = np.zeros((4, 4))
    
    if not self.render_mode:
      random_folder = random.choice(self.available_folders)
      self.chosen_folder = './frames/rl_simulated_training_sample/' + random_folder
      self.frames = os.listdir(self.chosen_folder)[:-1]
      self.frame_notes = []
      self.frame_step = 0
      
      self.metadata = json.load(open(self.chosen_folder + '/metadata.json'))
      self.note_count = self.metadata['note_count'] + self.metadata['hold_note_count'] * 2
      
      self.custom_reward_scaled = [(r / self.note_count) * self.reward_scale for r in self.custom_reward]
      self.invalid_action_penalty_scaled = (self.invalid_action_penalty / self.note_count) * self.reward_scale
      self.sim_reward.set_custom_rewards(self.custom_reward_scaled)
    
    return self.observation
  
  def step(self, multi_actions):
    if self.render_mode:
      # Wait for the first connection to be established when starting a song in-game
      while self.listener.is_first_connection and not self.listener.has_connection:
        pass
      
      if self.listener.has_connection:
        img = capture(self.capture_region)
      
      # Terminates when the song ends
      self.terminated = not self.listener.is_first_connection and not self.listener.has_connection
    else:
      img = Image.open(self.chosen_folder + '/' + self.frames[self.frame_step])
      self.frame_step += 1
      self.terminated = self.frame_step >= len(self.frames)
    
    if img:
      vision_thread = self.executor.submit(detect, img, self.model)
      self.observation = vision_thread.result()
      self.frame_notes.append(self.observation)
      
    actions, action_types = self._parse_multi_actions(multi_actions)
    reward = self.sim_reward.get_simulated_reward(actions, self.frame_notes[-2] if len(self.frame_notes) > 3 else []) + self.invalid_action_penalty_scaled * self.invalid_actions
    self.total_invalid_actions += self.invalid_actions
    self.invalid_actions = 0
    
    if len(actions) > 0 and self.render_mode:
      self.executor.submit(self._perform_keyboard_action, actions, action_types)
    
    info = {}
    self.observation = pad_inner_array([self.observation], [0, 0 ,0], self.max_notes)[0]
    
    
    return self.observation, reward, self.terminated, info
    
  def get_meta_data(self):
    '''
    Returns the metadata of the current song in the form of 
    {
      'song_name' (str), \\
      'song_duration' (int) in seconds, \\
      'note_count' (int,) \\
      'hold_note_count' (int), \\
      'difficulty' (float)
    }
    '''
    return self.metadata
  
  def _parse_multi_actions(self, multi_actions):
    '''
    Parses the multi actions into a list of actions and their types
    
    '''
    parsed_actions = []
    action_type = []
    
    for action_lane, action in enumerate(multi_actions):
      match action:
        case 0: # Do nothing
          self.actions_taken[action_lane][0] += 1
        case 1: # Release
          self.actions_taken[action_lane][1] += 1
          
          if not self.render_mode:
            if not self.sim_reward.get_key_held(action_lane):
              self.invalid_actions += 1
              continue
            
            self.sim_reward.update_keys_held(action_lane, False)
          
          parsed_actions.append(action_lane)
        case 2: # Press
          self.actions_taken[action_lane][2] += 1
          
          if self.sim_reward.get_key_held(action_lane) and not self.render_mode:
            self.invalid_actions += 1
            
          self.sim_reward.update_keys_held(action_lane, False)
          parsed_actions.append(action_lane)
        case 3: # Hold
          self.actions_taken[action_lane][3] += 1
          
          if not self.sim_reward.get_key_held(action_lane):
            parsed_actions.append(action_lane)
          
          if not self.render_mode:
            self.sim_reward.update_keys_held(action_lane, True)

      if action_type != 0:
        action_type.append(action)
          
    return parsed_actions, action_type
  
  def _perform_keyboard_action(self, parsed_actions, action_type):
    for action, action_type in zip(parsed_actions, action_type):
      if action_type in [2, 3]:
        self.keyboard.press(self.keys_reference[action])
        if action_type == 3:
          self.sim_reward.update_keys_held(action, True)
          
      time.sleep(0.04)
      
      if action_type in [1, 2]:
        self.keyboard.release(self.keys_reference[action])
        if action_type == 1:
          self.sim_reward.update_keys_held(action, False)
  
  def _vision_setup(self):
    if os.name == 'nt':
      pathlib.PosixPath = pathlib.WindowsPath # https://github.com/ultralytics/yolov5/issues/10240#issuecomment-1662573188
    
    if self.render_mode:
      monitor = mss.mss().monitors[-1]
      t, l, w, h = monitor['top'], monitor['left'], monitor['width'], monitor['height']
      self.capture_region = {'left': l+int(w * 0.338), 'top': t, 'width': w-int(w * 0.673), 'height': h} 
      
    self.model = torch.hub.load('ultralytics/yolov5', 'custom', path='./models/best.pt', force_reload=True)
    

In [466]:
osu_env = SimOsuEnvironment(max_notes=4)
ac_net = AC_Net(osu_env.observation_space.shape, osu_env.action_space)
print(ac_net)

Downloading: "https://github.com/ultralytics/yolov5/zipball/master" to C:\Users\tiany/.cache\torch\hub\master.zip
YOLOv5  2024-11-29 Python-3.10.6 torch-2.5.1+cu118 CUDA:0 (NVIDIA GeForce RTX 3050 Ti Laptop GPU, 4096MiB)

Fusing layers... 
Model summary: 157 layers, 7018216 parameters, 0 gradients, 15.8 GFLOPs
Adding AutoShape... 


AC_Net(
  (fc1): Linear(in_features=12, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (lstm): LSTM(128, 128, batch_first=True)
  (actors): ModuleList(
    (0-3): 4 x Linear(in_features=128, out_features=4, bias=True)
  )
  (critic): Linear(in_features=128, out_features=1, bias=True)
)


In [467]:
print(osu_env.action_space)
print(osu_env.observation_space)
print(osu_env.reset())

MultiDiscrete([4 4 4 4])
Box(0.0, inf, (4, 3), float32)
[[          0           0           0]
 [          0           0           0]
 [          0           0           0]
 [          0           0           0]]


In [468]:
def train(
  ac_net,
  osu_env,
  optimizer,
  max_episode,
  gamma=0.99,
  beta=0.01,
  grad_clip = 1.0
):
  total_rewards = []

  for episode in range(max_episode):
    terminated = False
    states, multi_actions, rewards = [], [], []
    hx = None

    state = torch.tensor(osu_env.reset(), dtype=torch.float32)
    while not terminated:
      probs, value, hx = ac_net(state, hx)
      action = []

      for prob in probs:
        prob = F.softmax(prob, dim=-1).squeeze()
        action.append(torch.distributions.Categorical(prob).sample().item())

      next_state, reward, terminated, _ = osu_env.step(action)

      states.append(state)
      rewards.append(reward)
      multi_actions.append(action)

      state = torch.tensor(next_state, dtype=torch.float32)
    R = 0 if terminated else value

    for i in reversed(range(len(rewards))):
      R = torch.tensor(rewards[i], dtype=torch.float32) + gamma * R
      R = R.detach()

      probs, value, hx = ac_net(states[i], hx)
      
      advantage = R - value
      policy_loss = []
      for i_a, prob in enumerate(probs):
        actions = multi_actions[i]
        softmax_probs = F.softmax(prob, dim=-1).squeeze()

        categorical_dist = torch.distributions.Categorical(softmax_probs)
        entropy = categorical_dist.entropy().sum() * beta

        log_probs = torch.log(softmax_probs)
        policy_loss.append(-log_probs[actions[i_a]] * advantage + entropy)

    value_loss = F.mse_loss(value, R)

    loss = torch.sum(torch.stack(policy_loss)) + value_loss 
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(ac_net.parameters(), grad_clip)
    optimizer.step()

    total_rewards.append(sum(rewards))
    meta_data = osu_env.get_meta_data()
    print(f'Episode {episode:<5} {"[" + meta_data["song_name"] + "] " + str(meta_data["difficulty"]):<50} Reward: {sum(rewards):>10.4f}, loss: {loss.item():>10.4f}')
    print(osu_env.sim_reward.get_debug())
    print(osu_env.total_invalid_actions)
    print(osu_env.actions_taken)
    
  return total_rewards

In [469]:
def test(
  ac_net,
  osu_env,
  max_episode
):
  total_rewards = []
  hx = None

  for episode in range(max_episode):
    terminated = False
    truncated = False
    states, multi_actions, rewards = [], [], []

    state = torch.tensor(osu_env.reset(), dtype=torch.float32)
    while not terminated:
      probs, value, hx = ac_net(state)
      action = []

      for prob in probs:
        prob = F.softmax(prob, dim=-1).squeeze()
        action.append(torch.argmax(prob).item())

      next_state, reward, terminated, _ = osu_env.step(action)

      states.append(state)
      rewards.append(reward)
      multi_actions.append(action)

      state = torch.tensor(next_state, dtype=torch.float32)

    total_rewards.append(sum(rewards))
    meta_data = osu_env.get_meta_data()
    print(f'Episode {episode:<5} {"[" + meta_data["song_name"] + "] " + str(meta_data["difficulty"]):<50} Reward: {sum(rewards):>10.4f}')
    print(osu_env.sim_reward.get_debug())
    print(osu_env.total_invalid_actions)
    print(osu_env.actions_taken)
    
  return total_rewards

In [470]:
max_episode = 100
optimizer = torch.optim.Adam(ac_net.parameters(), lr=0.001)
total_rewards = train(ac_net, osu_env, optimizer, max_episode)

Episode 0     [Cirno Break] 2.11                                 Reward: -1120.0590, loss:  4800.7808
({'good_regular_notes': 460, 'good_end_holds': 0, 'good_hold': 53}, {'bad_hold': 68, 'broken_hold': 1693, 'bad_press': 699, 'bad_release': 17, 'missed_notes': 173, 'unnecessary_press': 0})
1597
[[        358         255         640         272]
 [        257         440         410         418]
 [        332         375         487         331]
 [        409         444         324         348]]
Episode 1     [Ahoy!! Warera Houshou Kaizokudan] 2.12            Reward:  -807.1525, loss:  2990.8501
({'good_regular_notes': 456, 'good_end_holds': 0, 'good_hold': 155}, {'bad_hold': 140, 'broken_hold': 1242, 'bad_press': 711, 'bad_release': 45, 'missed_notes': 209, 'unnecessary_press': 0})
1406
[[        365         265         573         174]
 [        297         387         417         276]
 [        260         340         437         340]
 [        409         380         301         28

KeyboardInterrupt: 

In [None]:
test_rewards = test(ac_net, osu_env, 1)

In [None]:
torch.save(ac_net.state_dict(), 'models/ac_net.pth')

In [None]:
%matplotlib inline
plt.plot(total_rewards)
plt.show()