<h1> Detecting Whether the Inputs are a Sheik </h1>

We want to train a binary classifier to accurately predict whether a multi-channel time series (representing a Super Smash Bros. Melee player's inputs) was produced by a Sheik player. We first load our required libraries.

In [25]:
# pip install numba


Note: you may need to restart the kernel to use updated packages.


In [1]:
import os as os

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.optim as optim

import tqdm

import slippi as slp

from joblib import Parallel, delayed
from numba import jit, njit

<h2> Preliminary Functions </h2>

We use these functions to one-hot encode the button bitmask and get the frame data for a given port number and frames object.

In [50]:
# Set the number of time steps in the model inputs
frames_per_input = 60 * 12     # 12 seconds of gameplay

# @jit()
def one_hot_encode(bitmask):
    labels = ['DPAD_LEFT', 'DPAD_RIGHT', 'DPAD_DOWN', 'DPAD_UP', 'Z', 'R', 'L', 'A', 'B', 'X', 'Y', 'START']
    encoded_values = [1, 2, 4, 8, 16, 32, 64, 256, 512, 1024, 2048, 4096]

    # Create a dictionary mapping labels to their encoded values
    label_to_value = dict(zip(labels, encoded_values))

    # Initialize a list to store the one-hot encoded values
    one_hot_encoded = [0] * len(labels)

    # Iterate through labels and set the corresponding one-hot encoded value
    for label, value in label_to_value.items():
        if bitmask & value:
            one_hot_encoded[labels.index(label)] = 1

    return one_hot_encoded
# @jit()
def get_frame_data(frames, port):
    sheik_inputs = np.empty((0, 18))  # Initialize an empty NumPy array

    for i, frame in enumerate(frames[300: 300 + frames_per_input]):   # Take frames_per_input frames. skips first 5 seconds.
        buttons = one_hot_encode(frame.ports[port].leader.pre.buttons.physical.value)
        j_x = frame.ports[port].leader.pre.joystick.x
        j_y = frame.ports[port].leader.pre.joystick.y
        c_x = frame.ports[port].leader.pre.cstick.x
        c_y = frame.ports[port].leader.pre.cstick.y
        t_l = frame.ports[port].leader.pre.triggers.physical.l
        t_r = frame.ports[port].leader.pre.triggers.physical.r

        frame_data = np.array(buttons + [j_x, j_y, c_x, c_y, t_l, t_r]).reshape(1, -1)
        sheik_inputs = np.vstack((sheik_inputs, frame_data))

    return sheik_inputs

<h2> Data Loading </h2>

We begin by iterating through the Slippi Public Dataset, extracting replays of Sheik-Fox games:

In [51]:
import os
# import slp
import tqdm
import pandas as pd
from joblib import Parallel, delayed
from multiprocessing import Manager

# Function to process a single SLP file and append to shared lists
def process_slp_file(slp_file, dataset_path, time_series_list, label_list, ids):
    try:
        file_path = os.path.join(dataset_path, slp_file)
        game = slp.Game(file_path)
        frames = game.frames

        if len(frames) < 300 + frames_per_input:  # Ignore games that are <3600 frames (i.e. <60 seconds)
            return

        # List occupied ports
        occupied_ports = [i for i, port in enumerate(game.start.players) if port is not None]
        port_1 = occupied_ports[0]
        port_2 = occupied_ports[1]

        if len(occupied_ports) > 2:  # Ignore games that aren't singles
            return

        # Determine characters playing
        port_1_character = game.start.players[port_1].character.name
        port_2_character = game.start.players[port_2].character.name

        frame_data = get_frame_data(frames, port_1)
        time_series_list.append(frame_data)
        label_list.append(1 if port_1_character == 'SHEIK' else 0)
        ids.append(slp_file)
        frame_data = get_frame_data(frames, port_2)
        time_series_list.append(frame_data)
        label_list.append(1 if port_2_character == 'SHEIK' else 0)
        ids.append(slp_file)
    except Exception as e:
        print(f"Error processing {slp_file}: {str(e)}")

# Set your dataset_path and frames_per_input
dataset_path = './Slippi_Public_Dataset_v3/'
# frames_per_input = ...

slp_files = [file for file in os.listdir(dataset_path) if file.endswith('.slp') and 'Sheik' in file and 'Fox' in file]
print(len(slp_files))

# Create shared lists to store results
manager = Manager()
time_series_list = manager.list()
label_list = manager.list()
ids = manager.list()

# Use joblib to parallelize processing of SLP files
Parallel(n_jobs=-1, verbose=10)(delayed(process_slp_file)(slp_file, dataset_path, time_series_list, label_list, ids) for slp_file in tqdm.tqdm(slp_files))

# Create a DataFrame from the results
df = pd.DataFrame({"TimeSeries": list(time_series_list), "Label": list(label_list), "FName": list(ids)})

print(df)


2362





[A[A[A[Parallel(n_jobs=-1)]: Using backend LokyBackend with 24 concurrent workers.



[A[A[A[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:    0.3s
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:    0.4s



[A[A[A[Parallel(n_jobs=-1)]: Done  37 tasks      | elapsed:    0.6s



[A[A[A[Parallel(n_jobs=-1)]: Done  50 tasks      | elapsed:    0.8s
[Parallel(n_jobs=-1)]: Done  65 tasks      | elapsed:    1.0s



[A[A[A[Parallel(n_jobs=-1)]: Done  80 tasks      | elapsed:    1.2s
[Parallel(n_jobs=-1)]: Done  97 tasks      | elapsed:    1.5s



[A[A[A[Parallel(n_jobs=-1)]: Done 114 tasks      | elapsed:    1.7s



[A[A[A[Parallel(n_jobs=-1)]: Done 133 tasks      | elapsed:    2.0s



[A[A[A[Parallel(n_jobs=-1)]: Done 152 tasks      | elapsed:    2.2s



[A[A[A[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:    2.6s



[A[A[A[Parallel(n_jobs=-1)]: Done 194 tasks      | elapsed:    2.

                                             TimeSeries  Label  \
0     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
1     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
2     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
3     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,...      1   
4     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
...                                                 ...    ...   
4511  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
4512  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
4513  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
4514  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,...      1   
4515  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,...      0   

                                                  FName  
0               02_29_27.272Z Sheik + [C2] Fox (YS).slp  
1                         10_33_38 Sheik + Fox (DL).slp  
2     01_01_39.822Z [EASY] Fox + 

In [53]:
print(df )

                                             TimeSeries  Label  \
0     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
1     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
2     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
3     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,...      1   
4     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
...                                                 ...    ...   
4511  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
4512  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
4513  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
4514  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,...      1   
4515  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,...      0   

                                                  FName  
0               02_29_27.272Z Sheik + [C2] Fox (YS).slp  
1                         10_33_38 Sheik + Fox (DL).slp  
2     01_01_39.822Z [EASY] Fox + 

In [54]:
# Get memory usage for each column in bytes
memory_usage = df.memory_usage(deep=True)

# Sum the memory usage values to get the total memory usage of the DataFrame
total_memory_usage = memory_usage.sum()

print(f"Total memory usage of the DataFrame: {total_memory_usage} bytes")

Total memory usage of the DataFrame: 1103482 bytes


In [None]:
dataset_path = './Slippi_Public_Dataset_v3/'

# List files in the dataset with Sheik
slp_files = [file for file in os.listdir(dataset_path) if file.endswith('.slp') and 'Sheik' in file and 'Fox' in file][:100]
print(len(slp_files))

time_series_list =[]
label_list = []
ids = []

# Load the .slp files
for i, slp_file in enumerate(tqdm.tqdm(slp_files)):
    
    # Get file path and store game variable
    file_path = os.path.join(dataset_path, slp_file)
    game = slp.Game(file_path)
    frames = game.frames

    if len(frames) < 300 + frames_per_input:          # Ignore games that are <3600 frames (i.e. <60 seconds)
        continue
    
    # List occupied ports
    occupied_ports = [i for i, port in enumerate(game.start.players) if port is not None]
    port_1 = occupied_ports[0]
    port_2 = occupied_ports[1]    

    if (len(occupied_ports)) > 2:   # Ignore games that aren't singles
            continue

    # Determine characters playing
    port_1_character = game.start.players[port_1].character.name
    port_2_character = game.start.players[port_2].character.name

    frame_data = get_frame_data(frames, port_1)
    time_series_list.append(frame_data)
    label_list.append(1 if port_1_character == 'SHEIK' else 0)
    ids.append(slp_file)
    frame_data = get_frame_data(frames, port_2)
    time_series_list.append(frame_data)
    label_list.append(1 if port_2_character == 'SHEIK' else 0)
    ids.append(slp_file)
    
df = pd.DataFrame({"TimeSeries": time_series_list, "Label": label_list, "FName": ids})

print(df)


100


100%|██████████| 100/100 [00:19<00:00,  5.22it/s]

                                            TimeSeries  Label  \
0    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
1    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
2    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
3    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
4    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
..                                                 ...    ...   
195  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
196  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      1   
197  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
198  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...      0   
199  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,...      1   

                                                FName  
0           00_37_01.564Z [314] Fox + Sheik (FoD).slp  
1           00_37_01.564Z [314] Fox + Sheik (FoD).slp  
2            00_40_07.217Z [314] Fox + Sheik (DL).s




In [61]:
pip install feather-format


Collecting feather-format
  Downloading feather-format-0.4.1.tar.gz (3.2 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Building wheels for collected packages: feather-format
  Building wheel for feather-format (setup.py): started
  Building wheel for feather-format (setup.py): finished with status 'done'
  Created wheel for feather-format: filename=feather_format-0.4.1-py3-none-any.whl size=2453 sha256=f4d3de5a42b814a1ef4c33b0019258f785fe91744d4607b1a2b646679e4ed932
  Stored in directory: c:\users\jaspa\appdata\local\pip\cache\wheels\77\5b\0e\0e63d10b6353208a085a321ea2eed2578f220a77bb8a4bd7ab
Successfully built feather-format
Installing collected packages: feather-format
Successfully installed feather-format-0.4.1
Note: you may need to restart the kernel to use updated packages.


In [63]:
import pandas as pd

# Specify the file path where you want to save the pickle file
pickle_file_path = 'example.pkl'

# Save the DataFrame as a pickle file
df.to_pickle(pickle_file_path)


<h1> Data Visualization </h1>

In [None]:
# import os
# import slp
# import pandas as pd
import multiprocessing

# Function to process each slp file
def process_slp_file(slp_file):
    file_path = os.path.join(dataset_path, slp_file)
    game = slp.Game(file_path)
    frames = game.frames

    if len(frames) < 300 + frames_per_input:          # Ignore games that are <3600 frames (i.e. <60 seconds)
        print(f"Ignoring {slp_file} - Game length is less than required.")
        return None
    
    # List occupied ports
    occupied_ports = [i for i, port in enumerate(game.start.players) if port is not None]
    port_1 = occupied_ports[0]
    port_2 = occupied_ports[1]    

    if len(occupied_ports) > 2:   # Ignore games that aren't singles
        print(f"Ignoring {slp_file} - Not a singles game.")
        return None

    # Determine characters playing
    port_1_character = game.start.players[port_1].character.name
    port_2_character = game.start.players[port_2].character.name

    frame_data1 = get_frame_data(frames, port_1)
    frame_data2 = get_frame_data(frames, port_2)

    print(f"Processed {slp_file}")
    return frame_data1, port_1_character, frame_data2, port_2_character, slp_file

if __name__ == '__main__':
    dataset_path = './Slippi_Public_Dataset_v3/'
    frames_per_input = 300  # You'll need to define this variable

    slp_files = [file for file in os.listdir(dataset_path) if file.endswith('.slp') and 'Sheik' in file and 'Fox' in file]
    print(f"Total files to process: {len(slp_files)}")

    time_series_list = []
    label_list = []
    ids = []

    # Create a pool of worker processes
    pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())

    # Process slp files in parallel
    results = pool.map(process_slp_file, slp_files)

    # Close the pool of worker processes
    pool.close()
    pool.join()

    # Filter out the None values from the results and extract the data
    results = [result for result in results if result is not None]
    for frame_data1, port_1_character, frame_data2, port_2_character, slp_file in results:
        time_series_list.append(frame_data1)
        label_list.append(1 if port_1_character == 'SHEIK' else 0)
        ids.append(slp_file)
        time_series_list.append(frame_data2)
        label_list.append(1 if port_2_character == 'SHEIK' else 0)
        ids.append(slp_file)

    df = pd.DataFrame({"TimeSeries": time_series_list, "Label": label_list, "FName": ids})
    print(df)


2362


  0%|          | 0/2362 [00:00<?, ?it/s]

In [None]:
df.hist()

In [None]:
df.info()

In [None]:
duplicate_rows = df[df.duplicated(subset = 'TimeSeries', keep = False)]

print(duplicate_rows)

<h1> Data Preprocessing </h1>

In [None]:
# Convert to PyTorch tensors
time_series_tensor = torch.tensor(np.array(time_series_list), dtype=torch.float32)
label_tensor = torch.tensor(label_list, dtype=torch.float32)

channels = 18

# Normalize each channel individually
scaler = StandardScaler()
time_series_normalized = torch.zeros(time_series_tensor.shape)

# Iterate over channels
# for i in range(channels):
#     time_series_normalized[:, :, i] = torch.tensor(scaler.fit_transform(time_series_tensor[:, :, i]))

# print(time_series_sensor.shape)
# print(time_series_normalized.shape)

train_data, test_data, train_labels, test_labels = train_test_split(time_series_normalized, label_tensor, test_size = 0.2, shuffle = True, stratify = label_tensor)

print(torch.isnan(time_series_normalized).any())
print(torch.isnan(label_tensor).any())

