In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
from PIL import Image
from reward import SimulatedReward
from helper import detect, capture, pad_inner_array, SocketListener
from pynput import keyboard
from pynput.keyboard import Controller, Key
from concurrent.futures import ThreadPoolExecutor
from environment import OsuEnvironment
import mss
import time
import pathlib
import os
import json
import random
import logging
import warnings

In [9]:
if not os.path.exists('yolov5'):
    !git clone https://github.com/ultralytics/yolov5
    !pip install -r yolov5/requirements.txt

warnings.simplefilter("ignore", FutureWarning)
logging.getLogger('ultralytics').setLevel(logging.ERROR)

In [10]:
class AC_Net(nn.Module):
  """Recurrent actor-critic network with one actor head per action dimension.

  The observation is flattened, passed through two ReLU-activated linear
  layers followed by dropout, and then through a single-layer LSTM. The LSTM
  output feeds every actor head and the critic head.

  Args:
    input: Observation shape. Only ``input[1] * input[2]`` is used, as the
      number of input features of the first linear layer. Because ``forward``
      flattens the whole observation, the total number of elements passed to
      ``forward`` must equal that product.
    action_space: Iterable of discrete action spaces; each element must
      expose an ``n`` attribute giving the number of choices for that head.
  """

  def __init__(self, input, action_space):
    super(AC_Net, self).__init__()
    self.fc1 = nn.Linear(input[1] * input[2], 128)
    self.fc2 = nn.Linear(128, 128)
    self.lstm = nn.LSTM(128, 128, batch_first=True)
    self.actors = nn.ModuleList([nn.Linear(128, action.n) for action in action_space])
    self.critic = nn.Linear(128, 1)

    # The LSTM keeps PyTorch's default initialisation; only the linear layers
    # are re-initialised (Xavier-uniform weights, zero biases).
    for layer in [self.fc1, self.fc2, *self.actors, self.critic]:
      nn.init.xavier_uniform_(layer.weight)
      nn.init.constant_(layer.bias, 0)

  def forward(self, x, hx=None):
    """Run a single observation through the network.

    Args:
      x: Observation tensor; it is flattened to 1-D before the first layer.
      hx: Optional LSTM state ``(h, c)`` returned by a previous call, or
        ``None`` to start from the LSTM's default (zero) state.

    Returns:
      A tuple ``(logits, value, hx)`` where ``logits`` is a list holding one
      tensor of shape ``(1, n_i)`` per actor head (unnormalised scores),
      ``value`` is the critic output of shape ``(1, 1)``, and ``hx`` is the
      updated LSTM state.
    """
    x = x.flatten()
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    # Functional dropout tied to self.training: a module created inline with
    # nn.Dropout(0.3)(x) is always in training mode, so dropout would remain
    # active even after calling .eval() on this network.
    x = F.dropout(x, p=0.3, training=self.training)

    # Add a leading dimension so the LSTM receives a 2-D (sequence, features)
    # input, i.e. an unbatched sequence of length one.
    x = x.unsqueeze(0)
    x, hx = self.lstm(x, hx)
    return [actor(x) for actor in self.actors], self.critic(x), hx

In [11]:
# Create the environment, then size the actor-critic network from the
# environment's observation shape and action space.
osu_env = OsuEnvironment(monitor_id=2, max_notes=4, num_frame=1)
ac_net = AC_Net(input=osu_env.observation_space.shape, action_space=osu_env.action_space)

Downloading: "https://github.com/ultralytics/yolov5/zipball/master" to C:\Users\tiany/.cache\torch\hub\master.zip
YOLOv5  2024-11-30 Python-3.10.6 torch-2.5.1+cu118 CUDA:0 (NVIDIA GeForce RTX 3050 Ti Laptop GPU, 4096MiB)

Fusing layers... 
Model summary: 157 layers, 7018216 parameters, 0 gradients, 15.8 GFLOPs
Adding AutoShape... 


[WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted


Traceback (most recent call last):
  File "c:\Users\tiany\Desktop\Work\Dev\UB CLASS\FALL 2024\RL\final project\helper.py", line 112, in _listen
    self.sock.bind((self.server, self.port))
OSError: [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted


In [12]:
# Restore the trained weights. map_location='cpu' lets a checkpoint that was
# saved from a CUDA device load on a machine without a GPU, and
# weights_only=True restricts unpickling to tensors and plain containers so
# the checkpoint file cannot execute arbitrary code while loading.
ac_net.load_state_dict(torch.load('./models/ac_net_2.pth', map_location='cpu', weights_only=True))

<All keys matched successfully>

In [None]:
def test(
  ac_net,
  osu_env,
  max_episode
):
  ac_net.eval()
  total_rewards = []
  hx = None

  for episode in range(max_episode):
    done = False
    states, rewards = [], []

    state = torch.tensor(np.array(osu_env.reset()).squeeze(), dtype=torch.float32)
    osu_env.pick_random_song()
    while not done and osu_env.checking_connection():
      with torch.no_grad():
        if osu_env.song_begin():
          probs, value, hx = ac_net(state)
          action = []

          for prob in probs:
            prob = F.softmax(prob, dim=-1).squeeze()
            action.append(torch.argmax(prob).item())

          next_state, reward, terminated, truncated = osu_env.step(action, train=False)
          done = terminated or truncated

          states.append(state)
          rewards.append(reward)

          state = torch.tensor(np.array(next_state).squeeze(), dtype=torch.float32)
          print(state)
          
        if osu_env.lost_connection():
          break
        
    osu_env.return_to_song_selection_after_song()

    total_rewards.append(sum(rewards))
    print(f'Episode {episode + 1} - Reward: {total_rewards[-1]}')
    
  return total_rewards

In [14]:
# Play three evaluation episodes and keep the per-episode reward totals.
total_rewards = test(ac_net=ac_net, osu_env=osu_env, max_episode=3)

Episode 1 - Reward: 0
Episode 2 - Reward: 0


KeyboardInterrupt: 