<h1> Detecting Whether the Inputs Come From a Sheik Player </h1>

We want to train a binary classifier to accurately predict whether a multi-channel time series (representing a Super Smash Bros. Melee player's inputs) was produced by a Sheik player. We first load our required libraries.

In [1]:
import os as os

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.optim as optim

import tqdm

import slippi as slp

<h2> Preliminary Functions </h2>

We use these functions to decode the button bitmask into one 0/1 flag per button (several flags can be set at the same time) and to get the frame data for a given port number and frames object.

In [2]:
# Set the number of time steps (frames) in each model input window.
# The product below is 720 frames; the factor of 60 assumes the game runs at
# 60 frames per second, so one window covers 12 seconds of gameplay.
frames_per_input = 60 * 12     # 12 seconds of gameplay

def one_hot_encode(bitmask):
    """Expand a button bitmask into a list of twelve 0/1 flags.

    Args:
        bitmask: Integer whose set bits mark which buttons are held.

    Returns:
        A list of 12 ints ordered DPAD_LEFT, DPAD_RIGHT, DPAD_DOWN, DPAD_UP,
        Z, R, L, A, B, X, Y, START. An entry is 1 when that button's bit is
        set in ``bitmask`` and 0 otherwise; several entries may be 1 at once.
    """
    # (button name, bit value) pairs. A pair's position in this table is the
    # position of its flag in the returned list. Bit 128 has no entry.
    button_bits = (
        ('DPAD_LEFT', 1),
        ('DPAD_RIGHT', 2),
        ('DPAD_DOWN', 4),
        ('DPAD_UP', 8),
        ('Z', 16),
        ('R', 32),
        ('L', 64),
        ('A', 256),
        ('B', 512),
        ('X', 1024),
        ('Y', 2048),
        ('START', 4096),
    )

    return [1 if bitmask & bit else 0 for _name, bit in button_bits]

def get_frame_data(frames, port):
    """Collect one player's controller inputs over a fixed window of frames.

    The window starts at frame index 300 and spans ``frames_per_input``
    frames (a module-level setting); it is shorter if ``frames`` runs out.

    Args:
        frames: Sliceable sequence of frame objects. Each frame must expose
            ``ports[port].leader.pre`` with ``buttons.physical.value``,
            ``joystick``, ``cstick`` and ``triggers.physical`` data.
        port: Index into each frame's ``ports``.

    Returns:
        A float64 NumPy array of shape ``(n, 18)``, one row per frame in the
        window: the 12 button flags from ``one_hot_encode`` followed by
        joystick x, joystick y, c-stick x, c-stick y, physical trigger l and
        physical trigger r. ``n`` is 0 when the window is empty.
    """
    # Skip the first 300 frames, then take up to frames_per_input frames.
    window = frames[300: 300 + frames_per_input]

    rows = []
    for frame in window:
        pre = frame.ports[port].leader.pre   # look the input state up once per frame
        rows.append(
            one_hot_encode(pre.buttons.physical.value)
            + [
                pre.joystick.x,
                pre.joystick.y,
                pre.cstick.x,
                pre.cstick.y,
                pre.triggers.physical.l,
                pre.triggers.physical.r,
            ]
        )

    # Build the array once instead of np.vstack-ing per frame, which recopied
    # the whole array on every iteration. The reshape keeps the (0, 18) shape
    # for an empty window.
    return np.array(rows, dtype=np.float64).reshape(-1, 18)

<h2> Data Loading </h2>

We begin by iterating through the Slippi Public Dataset, extracting replays of Sheik-Fox games:

In [5]:
dataset_path = './Slippi_Public_Dataset_v3/'

# Keep only replays whose file name mentions both Sheik and Fox
slp_files = [
    file for file in os.listdir(dataset_path)
    if file.endswith('.slp') and 'Sheik' in file and 'Fox' in file
]
print(len(slp_files))

time_series_list = []
label_list = []
ids = []

# Load the .slp files
for slp_file in tqdm.tqdm(slp_files):

    # Get file path and parse the replay
    file_path = os.path.join(dataset_path, slp_file)
    game = slp.Game(file_path)
    frames = game.frames

    # get_frame_data reads frames 300 .. 300 + frames_per_input, so a shorter
    # game cannot fill a complete input window.
    if len(frames) < 300 + frames_per_input:
        continue

    # List occupied ports
    occupied_ports = [idx for idx, player in enumerate(game.start.players) if player is not None]

    # Keep singles games only. This check must come before the ports are
    # indexed: a replay with fewer than two players would otherwise raise
    # IndexError.
    if len(occupied_ports) != 2:
        continue

    # One sample per player: that player's inputs, labelled 1 for Sheik
    for port in occupied_ports:
        character = game.start.players[port].character.name
        time_series_list.append(get_frame_data(frames, port))
        label_list.append(1 if character == 'SHEIK' else 0)
        ids.append(slp_file)

df = pd.DataFrame({"TimeSeries": time_series_list, "Label": label_list, "FName": ids})

print(df)


2362


 23%|██▎       | 537/2362 [02:02<06:55,  4.39it/s]


KeyboardInterrupt: 

<h1> Data Visualization </h1>

In [3]:
# import os
# import slp
# import pandas as pd
import multiprocessing

# Function to process each slp file
def process_slp_file(slp_file):
    """Parse one replay and extract the input windows of both players.

    Reads the module-level ``dataset_path`` and ``frames_per_input``.

    Args:
        slp_file: File name of a replay inside ``dataset_path``.

    Returns:
        ``None`` when the replay is skipped (too short to fill an input
        window, or not exactly two players). Otherwise the tuple
        ``(frame_data1, port_1_character, frame_data2, port_2_character,
        slp_file)``, where the frame data come from ``get_frame_data`` and
        the characters are the players' character names.
    """
    file_path = os.path.join(dataset_path, slp_file)
    game = slp.Game(file_path)
    frames = game.frames

    # get_frame_data reads frames 300 .. 300 + frames_per_input, so a shorter
    # game cannot fill a complete input window.
    if len(frames) < 300 + frames_per_input:
        print(f"Ignoring {slp_file} - Game length is less than required.")
        return None

    # List occupied ports
    occupied_ports = [idx for idx, player in enumerate(game.start.players) if player is not None]

    # Keep singles games only. This check must come before the ports are
    # indexed: a replay with fewer than two players would otherwise raise
    # IndexError.
    if len(occupied_ports) != 2:
        print(f"Ignoring {slp_file} - Not a singles game.")
        return None

    port_1, port_2 = occupied_ports

    # Determine characters playing
    port_1_character = game.start.players[port_1].character.name
    port_2_character = game.start.players[port_2].character.name

    frame_data1 = get_frame_data(frames, port_1)
    frame_data2 = get_frame_data(frames, port_2)

    print(f"Processed {slp_file}")
    return frame_data1, port_1_character, frame_data2, port_2_character, slp_file

if __name__ == '__main__':
    dataset_path = './Slippi_Public_Dataset_v3/'
    # NOTE: frames_per_input is deliberately not reassigned here. It is set
    # once next to get_frame_data (60 * 12 frames); the placeholder value of
    # 300 that used to live here silently shrank every input window.

    # Keep only replays whose file name mentions both Sheik and Fox
    slp_files = [
        file for file in os.listdir(dataset_path)
        if file.endswith('.slp') and 'Sheik' in file and 'Fox' in file
    ]
    print(f"Total files to process: {len(slp_files)}")

    # Parse the replays in parallel, one worker per CPU. pool.map blocks until
    # every file is processed, so leaving the with-block (which shuts the
    # workers down) cannot lose results.
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        results = pool.map(process_slp_file, slp_files)

    time_series_list = []
    label_list = []
    ids = []

    # process_slp_file returns None for skipped replays; every other result
    # carries one sample per player, labelled 1 for Sheik.
    for result in results:
        if result is None:
            continue
        frame_data1, port_1_character, frame_data2, port_2_character, slp_file = result
        for frame_data, character in ((frame_data1, port_1_character),
                                      (frame_data2, port_2_character)):
            time_series_list.append(frame_data)
            label_list.append(1 if character == 'SHEIK' else 0)
            ids.append(slp_file)

    df = pd.DataFrame({"TimeSeries": time_series_list, "Label": label_list, "FName": ids})
    print(df)


2362


  0%|          | 0/2362 [00:00<?, ?it/s]

In [None]:
# Plot a histogram for each numeric column of df. Presumably only "Label"
# qualifies, since "TimeSeries" holds arrays and "FName" holds strings (TODO confirm).
df.hist()

In [None]:
# Print a summary of df: its columns, non-null counts, dtypes and memory usage.
df.info()

In [None]:
# NumPy arrays are unhashable, so df.duplicated(subset='TimeSeries') raises
# TypeError on this column. Compare each array's raw bytes instead; two rows
# match when their arrays hold the same bytes.
series_fingerprints = df['TimeSeries'].map(lambda series: series.tobytes())
duplicate_rows = df[series_fingerprints.duplicated(keep=False)]

print(duplicate_rows)

<h1> Data Preprocessing </h1>

In [None]:
# Stack the per-player windows into one (samples, time steps, channels) array
time_series_array = np.array(time_series_list, dtype=np.float32)
label_array = np.array(label_list, dtype=np.float32)

# Convert to PyTorch tensors
time_series_tensor = torch.tensor(time_series_array, dtype=torch.float32)
label_tensor = torch.tensor(label_array, dtype=torch.float32)

channels = 18

# Split sample indices first so the scaler below only ever sees training data
sample_indices = np.arange(len(label_array))
train_idx, test_idx = train_test_split(sample_indices, test_size = 0.2, shuffle = True, stratify = label_array)

# Normalize each channel individually: flatten (samples, time steps) into rows
# so StandardScaler learns one mean/std per channel, fitted on the training
# samples only. Previously the normalization was commented out, which left
# time_series_normalized as all zeros going into the split.
scaler = StandardScaler()
scaler.fit(time_series_array[train_idx].reshape(-1, channels))

normalized_array = scaler.transform(time_series_array.reshape(-1, channels)).reshape(time_series_array.shape)
time_series_normalized = torch.tensor(normalized_array, dtype=torch.float32)

train_rows = torch.as_tensor(train_idx, dtype=torch.long)
test_rows = torch.as_tensor(test_idx, dtype=torch.long)

train_data, test_data = time_series_normalized[train_rows], time_series_normalized[test_rows]
train_labels, test_labels = label_tensor[train_rows], label_tensor[test_rows]

# Sanity check: neither the inputs nor the labels should contain NaN
print(torch.isnan(time_series_normalized).any())
print(torch.isnan(label_tensor).any())

