In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from numba import njit

In [None]:
# Read the pickled input-data DataFrame from disk.
# NOTE(review): unpickling can execute arbitrary code - only load files from a trusted source.
df = pd.read_pickle('./data/Sheik_vs_Fox_full_input_data.pkl')

# Names of the 18 controller inputs.
# Presumably these follow the column order of each 'TimeSeries' array - TODO confirm.
labels = [
    'DPAD_LEFT', 'DPAD_RIGHT', 'DPAD_DOWN', 'DPAD_UP',
    'Z', 'R', 'L', 'A', 'B', 'X', 'Y', 'START',
    'J_X', 'J_Y', 'C_X', 'C_Y', 'T_L', 'T_R',
]

# `df` now holds the contents of the pickle file as a DataFrame

In [None]:
# Each entry of df['TimeSeries'] is a 2D numpy array with one column per controller input.
# Drop the inputs that are not analysed further: the four D-pad directions,
# the digital R and L shoulder buttons and START.
#
# The indices are looked up by name in `labels` instead of being hard-coded.
# NOTE(review): this assumes `labels` lists the inputs in the same order as the
# array columns - confirm against the code that produced the pickle.
# The previous hard-coded list [0, 1, 2, 3, 6, 7, 11] removed 'L' and 'A' but kept
# 'R' under that ordering, so the column later named 'A' actually held the 'R' data.
dropped_inputs = ['DPAD_LEFT', 'DPAD_RIGHT', 'DPAD_DOWN', 'DPAD_UP', 'R', 'L', 'START']
columns_to_remove = [labels.index(name) for name in dropped_inputs]

# Rebuild the 'TimeSeries' column with the selected columns deleted from every array.
# This overwrites df['TimeSeries'], so re-running the cell without reloading `df`
# would delete further columns.
df['TimeSeries'] = [np.delete(array, columns_to_remove, axis=1) for array in df['TimeSeries']]


In [None]:
# Names given, in order, to the columns of each 'TimeSeries' array
new_column_labels = ['Z', 'A', 'B', 'X', 'Y', 'J_X', 'J_Y', 'C_X', 'C_Y', 'T_L', 'T_R']

# For every name, store that column of each game's array as its own DataFrame column.
# The position is bound as a default argument so each lambda keeps its own index.
for position, label in enumerate(new_column_labels):
    df[label] = df['TimeSeries'].apply(lambda series, col=position: series[:, col])


In [None]:
# Derived per-game features. Every cell of the source columns holds one numpy array,
# so each feature is computed array-by-array and stored as a new column of arrays.
# The columns are paired with zip() rather than a row-wise df.apply(axis=1), which
# would build a full row Series for every game just to read two of its cells.

# 'max_X_Y': element-wise maximum of the 'X' and 'Y' arrays
df['max_X_Y'] = [np.maximum(x, y) for x, y in zip(df['X'], df['Y'])]

# 'max_T': element-wise maximum of the 'T_L' and 'T_R' arrays
df['max_T'] = [np.maximum(left, right) for left, right in zip(df['T_L'], df['T_R'])]

# Convert the (J_X, J_Y) and (C_X, C_Y) pairs to polar coordinates, creating
# 'J_theta' / 'J_radius' and 'C_theta' / 'C_radius'
for stick in ('J', 'C'):
    stick_x, stick_y = df[f'{stick}_X'], df[f'{stick}_Y']
    df[f'{stick}_theta'] = [np.arctan2(y, x) for x, y in zip(stick_x, stick_y)]
    df[f'{stick}_radius'] = [np.sqrt(x ** 2 + y ** 2) for x, y in zip(stick_x, stick_y)]

# Show only the first rows; printing the whole frame dumps every array-valued cell
df.head()

In [None]:
# List the DataFrame's column labels, including the columns added in the cells above
df.columns

In [None]:


def overlay_charts_from_columns(df, column_lists, row_index,
                                figsize=(100, 10), xlabel='Sample index', ylabel='Value'):
    """
    Generate an overlaid chart from selected columns in a DataFrame row.

    Every named column of the selected row is drawn as one line on a single
    set of axes; the title is the comma-separated list of column names.

    Parameters:
    - df: DataFrame containing the data.
    - column_lists: List of lists where each sublist contains column names to overlay.
      The sublists are flattened; all columns end up on the same axes.
    - row_index: Positional row index (used with ``df.iloc``) of the data to plot.
    - figsize: (width, height) of the figure in inches. Defaults to the very wide
      (100, 10) the function has always used.
    - xlabel: Label for the x-axis. Each series is plotted against its element
      index, hence the default.
    - ylabel: Label for the y-axis.

    Returns:
    - None. The figure is displayed with ``plt.show()``.

    Example usage:
    overlay_charts_from_columns(df, [['max_X_Y', 'max_T']], 0)
    """

    # Get the row data for the specified index
    row_data = df.iloc[row_index]

    # Flatten the list of column names within column_lists
    flattened_column_names = [col for sublist in column_lists for col in sublist]

    # Explicit Axes interface instead of the implicit pyplot "current figure" state
    _, ax = plt.subplots(figsize=figsize)

    # Plot each requested column as its own labelled line
    for column_name in flattened_column_names:
        ax.plot(row_data[column_name], label=str(column_name))

    # Title and axis labels
    chart_title = ', '.join(flattened_column_names)
    ax.set_title(chart_title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # A legend with no labelled lines only produces a warning, so skip it then
    if flattened_column_names:
        ax.legend()

    # Show the chart
    plt.show()

# Example usage:
# overlay_charts_from_columns(df, [['max_X_Y', 'max_T']], 0)



In [None]:
n = 7

# Columns to overlay on each chart; any array-valued column of df can be used,
# e.g. [['max_X_Y', 'max_T']]
column_lists = [['J_Y','B']]

# Draw the same overlay for rows 2*n and 2*n + 1
# (presumably the two rows belong together, e.g. the same game - TODO confirm)
for row in (2*n, 2*n + 1):
    overlay_charts_from_columns(df, column_lists, row)

In [None]:
# Add a game length column: number of samples in each 'Z' array divided by 60
# (presumably 60 samples per second, giving seconds - TODO confirm)
samples_per_game = df['Z'].map(len)
df['game_length_sec'] = samples_per_game / 60
df['game_length_sec'].describe()


In [None]:
# Filter out games longer than 8 minutes
max_minutes = 8
filtered_df = df[df['game_length_sec'] <= max_minutes * 60]

# Histogram of the 'game_length_sec' column, converted to minutes.
# The range is pinned to 0..8 minutes so that the 8*15 = 120 bins are each exactly
# 4 seconds wide; without an explicit range the bins would span only the observed
# min..max and their width would depend on the data.
fig, ax = plt.subplots()
ax.hist(
    filtered_df['game_length_sec'] / 60,
    bins=max_minutes * 15,
    range=(0, max_minutes),
    color='blue',
    edgecolor='black',
)

# Customize the plot with labels and title
ax.set_xlabel('Game Length (minutes)')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of Game Length')

# Show the histogram
plt.show()

In [None]:
# Function to count value switches in a numpy array
@njit()
def count_value_switches(arr):
    """Return how many times an element of ``arr`` differs from its predecessor.

    Sequences with fewer than two elements have no consecutive pairs and give 0.
    """
    total = 0

    # range(1, len(arr)) is empty when there are fewer than two elements
    for idx in range(1, len(arr)):
        if arr[idx] != arr[idx - 1]:
            total += 1

    return total

# Number of value changes in each game's 'B' array. count_value_switches counts
# every change between consecutive samples, so both directions of a change count.
b_switch_counts = df['B'].apply(count_value_switches)
df['num_B'] = b_switch_counts
# Display the resulting Series
print(df['num_B'])

# Changes per second of game time
df['B_per_sec'] = df['num_B'].div(df['game_length_sec'])

In [None]:


# Split the DataFrame into two based on 'Label' values
label_1_data = df[df['Label'] == 1]
label_0_data = df[df['Label'] == 0]

# Create a figure with two subplots (side-by-side)
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Common value range for both histograms. Series.min()/max() skip NaN, whereas the
# builtin min()/max() give order-dependent results when a NaN is present
# (e.g. from a zero-length game, where B_per_sec is 0 / 0).
common_range = [df['B_per_sec'].min(), df['B_per_sec'].max()]

bins_count = 50

# One panel per label: (axes, data, bar colour, character name used in the title)
panels = [
    (axs[0], label_1_data, 'blue', 'Sheik'),
    (axs[1], label_0_data, 'red', 'Fox'),
]

# Plot the 'B_per_sec' histograms with identical bins and range, then title/label them
for ax, panel_data, bar_color, character in panels:
    ax.hist(panel_data['B_per_sec'], bins=bins_count, color=bar_color,
            edgecolor='black', alpha=0.7, range=common_range)
    ax.set_title(f'Histogram of B_per_sec ({character})')
    ax.set_xlabel('B_per_sec')
    ax.set_ylabel('Frequency')

# Set the same Y-axis range for both subplots so bar heights are directly comparable
y_max = max(max(ax.get_ylim()) for ax in axs)
for ax in axs:
    ax.set_ylim(0, y_max)

# Show the histograms
plt.tight_layout()
plt.show()