<h1> Detecting Whether the Inputs are a Sheik </h1>

We want to train a binary classifier to accurately predict whether a multi-channel time series (representing a Super Smash Bros. Melee player's inputs) was produced by a Sheik player. We first load our required libraries.

In [16]:
import os as os

import numpy as np
import pandas as pd

# from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import StandardScaler

# import torch
# import torch.nn as nn
# import torch.optim as optim

import tqdm

import slippi as slp

from joblib import Parallel, delayed


<h2> Preliminary Functions </h2>

We use these functions to one-hot encode the button bitmask and get the frame data for a given port number and frames object.

In [17]:
# Set the number of time steps in the model inputs
frames_per_input = 60 * 12     # 12 seconds of gameplay


def one_hot_encode(bitmask):
    labels = ['DPAD_LEFT', 'DPAD_RIGHT', 'DPAD_DOWN', 'DPAD_UP', 'Z', 'R', 'L', 'A', 'B', 'X', 'Y', 'START']
    encoded_values = [1, 2, 4, 8, 16, 32, 64, 256, 512, 1024, 2048, 4096]

    # Create a dictionary mapping labels to their encoded values
    label_to_value = dict(zip(labels, encoded_values))

    # Initialize a list to store the one-hot encoded values
    one_hot_encoded = [0] * len(labels)

    # Iterate through labels and set the corresponding one-hot encoded value
    for label, value in label_to_value.items():
        if bitmask & value:
            one_hot_encoded[labels.index(label)] = 1

    return one_hot_encoded


## Append a new row to a numpy list
# def get_frame_data(frames, port):
#     sheik_inputs = np.empty((0, 18))  # Initialize an empty NumPy array

#     for i, frame in enumerate(frames[300: ]):   # Take frames_per_input frames. skips first 5 seconds.
#         buttons = one_hot_encode(frame.ports[port].leader.pre.buttons.physical.value)
#         j_x = frame.ports[port].leader.pre.joystick.x
#         j_y = frame.ports[port].leader.pre.joystick.y
#         c_x = frame.ports[port].leader.pre.cstick.x
#         c_y = frame.ports[port].leader.pre.cstick.y
#         t_l = frame.ports[port].leader.pre.triggers.physical.l
#         t_r = frame.ports[port].leader.pre.triggers.physical.r

#         frame_data = np.array(buttons + [j_x, j_y, c_x, c_y, t_l, t_r]).reshape(1, -1)
#         sheik_inputs = np.vstack((sheik_inputs, frame_data))

#     return sheik_inputs

## Create a numpy list that is the correct size and fill it with a loop
def get_frame_data(frames, port):
    sheik_inputs = np.empty((len(frames)-300, 18))  # Initialize an empty NumPy array
        
    for i, frame in enumerate(frames[300:  ]):   # Take frames_per_input frames. skips first 5 seconds.
        buttons = one_hot_encode(frame.ports[port].leader.pre.buttons.physical.value)
        j_x = frame.ports[port].leader.pre.joystick.x
        j_y = frame.ports[port].leader.pre.joystick.y
        c_x = frame.ports[port].leader.pre.cstick.x
        c_y = frame.ports[port].leader.pre.cstick.y
        t_l = frame.ports[port].leader.pre.triggers.physical.l
        t_r = frame.ports[port].leader.pre.triggers.physical.r

        frame_data = buttons + [j_x, j_y, c_x, c_y, t_l, t_r]
        sheik_inputs[i] = frame_data

    return sheik_inputs

## List comprehension
# def get_frame_data(frames, port):
#     return np.array([
#         one_hot_encode(frame.ports[port].leader.pre.buttons.physical.value) +
#         [frame.ports[port].leader.pre.joystick.x,
#          frame.ports[port].leader.pre.joystick.y,
#          frame.ports[port].leader.pre.cstick.x,
#          frame.ports[port].leader.pre.cstick.y,
#          frame.ports[port].leader.pre.triggers.physical.l,
#          frame.ports[port].leader.pre.triggers.physical.r]
#         # for frame in frames[300: 300 + frames_per_input]
#         for frame in frames[300:]
#     ])

# Append to a numpy list is vastly slower than the other two.
# Getting the frames in the second two algorithms take the same amount of time,
# but the third option takes longer because it seems to have more to do after it is done to get the dataframe.

<h2> Data Loading </h2>

We begin by iterating through the Slippi Public Dataset, extracting replays of Sheik-Fox games:

In [19]:
import os
# import slp
import tqdm
import pandas as pd
from joblib import Parallel, delayed
from multiprocessing import Manager

# Function to process a single SLP file and append to shared lists
def process_slp_file(slp_file, dataset_path, time_series_list, label_list, ids):
    try:
        file_path = os.path.join(dataset_path, slp_file)
        game = slp.Game(file_path)
        frames = game.frames

        if len(frames) < 300 + frames_per_input:  # Ignore games that are <3600 frames (i.e. <60 seconds)
            return

        # List occupied ports
        occupied_ports = [i for i, port in enumerate(game.start.players) if port is not None]
        port_1 = occupied_ports[0]
        port_2 = occupied_ports[1]

        if len(occupied_ports) > 2:  # Ignore games that aren't singles
            return

        # Determine characters playing
        port_1_character = game.start.players[port_1].character.name
        port_2_character = game.start.players[port_2].character.name

        frame_data = get_frame_data(frames, port_1)
        time_series_list.append(frame_data)
        label_list.append(1 if port_1_character == 'SHEIK' else 0)
        ids.append(slp_file)
        frame_data = get_frame_data(frames, port_2)
        time_series_list.append(frame_data)
        label_list.append(1 if port_2_character == 'SHEIK' else 0)
        ids.append(slp_file)
    except Exception as e:
        print(f"Error processing {slp_file}: {str(e)}")

# Set your dataset_path and frames_per_input
dataset_path = './Slippi_Public_Dataset_v3/'
# frames_per_input = ...

slp_files = [file for file in os.listdir(dataset_path) if file.endswith('.slp') and 'Sheik' in file and 'Fox' in file]
print(len(slp_files))

# Create shared lists to store results
manager = Manager()
time_series_list = manager.list()
label_list = manager.list()
ids = manager.list()

# Use joblib to parallelize processing of SLP files
Parallel(n_jobs=-1, verbose=10)(delayed(process_slp_file)(slp_file, dataset_path, time_series_list, label_list, ids) for slp_file in tqdm.tqdm(slp_files))

# Create a DataFrame from the results
df = pd.DataFrame({"TimeSeries": list(time_series_list), "Label": list(label_list), "FName": list(ids)})


df.sort_values(by=['FName','Label'],inplace=True)
df.reset_index(drop=True)
print(df)


2362


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 24 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:    0.5s
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:    0.7s
[Parallel(n_jobs=-1)]: Done  37 tasks      | elapsed:    1.0s
[Parallel(n_jobs=-1)]: Done  50 tasks      | elapsed:    1.2s
[Parallel(n_jobs=-1)]: Done  65 tasks      | elapsed:    1.6s
[Parallel(n_jobs=-1)]: Done  80 tasks      | elapsed:    1.9s
[Parallel(n_jobs=-1)]: Done  97 tasks      | elapsed:    2.4s
[Parallel(n_jobs=-1)]: Done 114 tasks      | elapsed:    2.7s
[Parallel(n_jobs=-1)]: Done 133 tasks      | elapsed:    3.2s
[Parallel(n_jobs=-1)]: Done 152 tasks      | elapsed:    3.6s
[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:    4.2s
[Parallel(n_jobs=-1)]: Done 194 tasks      | elapsed:    4.6s
[Parallel(n_jobs=-1)]: Done 217 tasks      | elapsed:    5.3s
[Parallel(n_jobs=-1)]: Done 240 tasks      | elapsed:  

                                             TimeSeries  Label  \
18    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
33    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
30    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
44    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
13    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
...                                                 ...    ...   
4497  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
4508  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
4505  [[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,...      1   
4510  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
4509  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   

                                                  FName  
18            00_37_01.564Z [314] Fox + Sheik (FoD).slp  
33            00_37_01.564Z [314] Fox + Sheik (FoD).slp  
30             00_40_07.217Z [314

In [26]:
# Get memory usage for each column in bytes
memory_usage = df.memory_usage(deep=True)

# Sum the memory usage values to get the total memory usage of the DataFrame
total_memory_usage = memory_usage.sum()

print(f"Total memory usage of the DataFrame: {total_memory_usage} bytes")

Total memory usage of the DataFrame: 1103462 bytes


(4261, 18)
(5702, 18)
(7032, 18)
(10314, 18)


In [33]:
import pandas as pd

# Specify the file path where you want to save the pickle file
pickle_file_path = './data/Sheik_vs_Fox_full_input_data.pkl'

# Save the DataFrame as a pickle file
df.to_pickle(pickle_file_path)


<h1> Data Visualization </h1>

In [22]:
# df.hist()

In [23]:
# df.info()

In [24]:
# duplicate_rows = df[df.duplicated(subset = 'TimeSeries', keep = False)]

# print(duplicate_rows)

<h1> Data Preprocessing </h1>

In [25]:
# # Convert to PyTorch tensors
# time_series_tensor = torch.tensor(np.array(time_series_list), dtype=torch.float32)
# label_tensor = torch.tensor(label_list, dtype=torch.float32)

# channels = 18

# # Normalize each channel individually
# scaler = StandardScaler()
# time_series_normalized = torch.zeros(time_series_tensor.shape)

# # Iterate over channels
# # for i in range(channels):
# #     time_series_normalized[:, :, i] = torch.tensor(scaler.fit_transform(time_series_tensor[:, :, i]))

# # print(time_series_sensor.shape)
# # print(time_series_normalized.shape)

# train_data, test_data, train_labels, test_labels = train_test_split(time_series_normalized, label_tensor, test_size = 0.2, shuffle = True, stratify = label_tensor)

# print(torch.isnan(time_series_normalized).any())
# print(torch.isnan(label_tensor).any())

