In [1]:
import random

import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

In [2]:
# Shared notebook/app state. The functions below read and overwrite these
# module-level names; everything starts out empty until a CSV is fitted.
data_df = None

# Columns kept from the uploaded CSV: the numeric pitch measurements used as
# model features, followed by the `pitch_name` target column.
important_columns = [
    "release_speed", "release_pos_x", "release_pos_y", "release_pos_z",
    "pfx_x", "pfx_z", "spin_axis", "release_spin_rate",
    "plate_x", "plate_z", "effective_speed",
    "pitch_name",
]

# Full feature matrix, encoded labels, and the code -> pitch-name decode map.
X = y = None
label_map = None

# Train/test split produced by prepare_split_and_fit_dataset.
x_train = x_test = y_train = y_test = None

# Fitted classifier (set by fit_model_get_acc) and a slot for its torch wrapper
# (NOTE(review): never assigned at module level in this notebook).
clf = wrapped_clf = None

## Data Preprocessing

In [3]:
def csv_to_df(file, columns=None):
    """Load a pitch CSV and keep only the rows usable for pitch classification.

    Parameters
    ----------
    file : str | path-like | file-like
        Anything accepted by ``pd.read_csv``.
    columns : list[str] | None
        Columns to keep, in output order; must include ``"pitch_name"``.
        Defaults to the module-level ``important_columns``.

    Returns
    -------
    pd.DataFrame
        The selected columns. Rows whose ``pitch_name`` is missing, blank,
        ``"Pitch Out"`` or ``"Intentional Ball"`` are removed. Every column other
        than ``pitch_name`` is coerced to numeric; unparseable values become NaN.

    Raises
    ------
    KeyError
        If a requested column is not present in the CSV.
    """
    if columns is None:
        columns = important_columns

    df = pd.read_csv(file)[list(columns)]

    names = df["pitch_name"]
    # read_csv turns empty cells into NaN (not ''), so test for missing values
    # explicitly; whitespace-only names are treated as blank too.
    blank = names.fillna("").astype(str).str.strip().eq("")
    excluded = names.isin(["Pitch Out", "Intentional Ball"])
    df = df.loc[~blank & ~excluded].copy()

    # Coerce only the feature columns; the label column stays as text. The
    # previous `.fillna(df)` trick restored every unparseable string, undoing
    # the coercion.
    feature_cols = df.columns.drop("pitch_name")
    df[feature_cols] = df[feature_cols].apply(pd.to_numeric, errors="coerce")
    return df

In [4]:
def get_x_y_label_map(df):
    """Turn a cleaned pitch frame into model inputs, encoded targets and a decode map.

    Besides returning them, the three results are also written to the
    module-level ``X``, ``y`` and ``label_map``.

    Parameters
    ----------
    df : pd.DataFrame
        Frame containing a ``pitch_name`` column plus the feature columns.

    Returns
    -------
    tuple
        ``(X, y, label_map)``. ``X`` is the mean-imputed feature matrix returned
        by ``SimpleImputer.fit_transform``; ``y`` is a one-column DataFrame of
        integer codes; ``label_map`` maps each code back to its pitch name.
    """
    global X, y, label_map

    # Every non-label column is a feature; non-numeric cells become NaN and are
    # then replaced by the column mean.
    features = df.drop(columns=['pitch_name']).apply(pd.to_numeric, errors='coerce')
    X = SimpleImputer(strategy='mean').fit_transform(features)

    encoder = LabelEncoder()
    y = df[['pitch_name']].apply(encoder.fit_transform)

    # encoder.classes_[i] is the pitch name encoded as integer i.
    label_map = dict(enumerate(encoder.classes_))

    return X, y, label_map

In [5]:
def fit_model_get_acc():
    """Train a fresh random forest on the current split and report test accuracy.

    Reads the module-level ``x_train``/``y_train``/``x_test``/``y_test`` and
    stores the fitted model in the global ``clf``.

    Returns
    -------
    float
        ``clf.score`` on the held-out test split.
    """
    global x_train, y_train, clf

    clf = RandomForestClassifier(n_estimators=100, random_state=2024)
    # Labels are kept as a one-column DataFrame; the estimator wants a flat 1-D array.
    flat_labels = y_train.values.ravel()
    clf.fit(x_train, flat_labels)

    test_accuracy = clf.score(x_test, y_test)
    return test_accuracy

In [6]:
def prepare_split_and_fit_dataset(file, test_size=0.2, random_state=2024):
    """Load ``file``, build features/labels, split them and fit a new model.

    Updates the module-level ``data_df``, ``X``, ``y``, ``label_map``,
    ``x_train``, ``x_test``, ``y_train`` and ``y_test`` (and, through
    ``fit_model_get_acc``, ``clf``).

    Parameters
    ----------
    file : str | path-like | file-like | None
        The uploaded CSV. The Gradio ``File`` input passes ``None`` when
        nothing has been uploaded.
    test_size, random_state
        Forwarded to ``train_test_split``; defaults match the original fixed values.

    Returns
    -------
    float
        Test-split accuracy from ``fit_model_get_acc``.

    Raises
    ------
    ValueError
        If no file was provided.
    """
    global data_df, X, y, label_map, x_train, x_test, y_train, y_test

    # Fail early with an actionable message instead of an opaque pd.read_csv error.
    if file is None:
        raise ValueError("No CSV uploaded: choose a file before clicking 'Fit Model'.")

    data_df = csv_to_df(file)
    X, y, label_map = get_x_y_label_map(data_df)
    x_train, x_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    return fit_model_get_acc()

## Generate Random Test Data

In [7]:
def get_random_data(data=None):
    """Pick one random held-out sample to populate the input grid.

    Parameters
    ----------
    data : numpy.ndarray of shape (n_samples, n_features), optional
        Rows to sample from; defaults to the global ``x_test`` set by
        ``prepare_split_and_fit_dataset``.

    Returns
    -------
    numpy.ndarray of shape (1, n_features)

    Raises
    ------
    ValueError
        If there is nothing to sample from (no model fitted yet, or an empty split).
    """
    rows = x_test if data is None else data
    if rows is None or len(rows) == 0:
        raise ValueError("No test data available: upload a CSV and click 'Fit Model' first.")

    # randrange(n) draws uniformly from 0..n-1, the same range randint(0, n - 1) covers.
    random_index = random.randrange(len(rows))
    return rows[random_index].reshape(1, -1)

## Generate Prediction

In [8]:
def get_prediction(df, model=None, labels=None):
    """Return pitch-type probabilities for the first row of ``df``.

    Parameters
    ----------
    df : pd.DataFrame
        Input grid from the Gradio app, with one column per model feature in
        training order. Only the first row is scored; non-numeric cells
        become NaN.
    model : fitted classifier, optional
        Object exposing ``predict_proba`` and ``classes_``; defaults to the
        global ``clf``.
    labels : dict[int, str], optional
        Encoded-label -> pitch-name map; defaults to the global ``label_map``.

    Returns
    -------
    dict[str, float]
        Pitch name -> predicted probability (the format ``gr.Label`` expects).

    Raises
    ------
    ValueError
        If ``df`` is missing or has no rows.
    """
    if df is None or df.empty:
        raise ValueError("No input data: enter a row of pitch measurements or click 'Get Random Test Data'.")
    model = clf if model is None else model
    labels = label_map if labels is None else labels

    # Score only the first row: reshaping several rows to (1, -1) would build a
    # single over-wide sample that predict_proba rejects.
    row = df.iloc[[0]].apply(pd.to_numeric, errors="coerce")
    probs = model.predict_proba(row.to_numpy(dtype=float))[0]

    # predict_proba columns follow model.classes_, which contains only the encoded
    # labels seen in the training split. Map through it rather than through the
    # column position, or labels shift whenever a class is absent from training.
    return {labels[int(code)]: float(p) for code, p in zip(model.classes_, probs)}

## Generate Scatter Plot

In [9]:
def plot_scatter():
    """Plot release speed against spin rate for the loaded dataset, coloured by pitch type.

    Reads the module-level ``data_df`` set by ``prepare_split_and_fit_dataset``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    global data_df

    # A constant per-point size list together with size_max draws every marker
    # at the same size.
    marker_sizes = [1] * len(data_df)

    fig = px.scatter(
        data_df,
        x="release_speed",
        y="release_spin_rate",
        color="pitch_name",
        size=marker_sizes,
        size_max=8,
        color_discrete_sequence=px.colors.qualitative.Light24,
    )

    axis_style = {"type": "linear"}
    fig.update_layout(
        title="Scatter Plot of Release Speed vs Release Spin Rate",
        title_x=0.5,
        xaxis_title="release speed (mph)",
        yaxis_title="release spin rate (rpm)",
        xaxis=dict(axis_style),
        yaxis=dict(axis_style),
    )
    return fig

## Generate Captum Bar Chart

In [10]:
import torch
from captum.attr import FeatureAblation

class WrappedModel(torch.nn.Module):
    """Expose a fitted scikit-learn classifier as a ``torch.nn.Module``.

    Captum attribution methods call a PyTorch forward function. This adapter
    converts the incoming tensor to NumPy, delegates to the estimator's
    ``predict_proba`` and hands the class probabilities back as a float32 tensor.
    """

    def __init__(self, rf_model):
        super().__init__()
        # Only a `predict_proba` method is required of the wrapped estimator.
        self.rf_model = rf_model

    def forward(self, x):
        """Return ``rf_model.predict_proba(x)`` as a float32 tensor (detached; no gradients)."""
        features = x.detach().numpy()
        probabilities = self.rf_model.predict_proba(features)
        return torch.tensor(probabilities, dtype=torch.float32)

In [11]:
def get_captum_barchart(df):
    """Bar-chart Captum ``FeatureAblation`` attributions for the first row of ``df``.

    The attribution target is the class the global ``clf`` considers most likely
    for that row. The ``FeatureAblation`` instance is also stored in the global
    ``feature_ablation``.

    Parameters
    ----------
    df : pd.DataFrame
        Input grid from the Gradio app, with one column per model feature in
        training order. Only the first row is explained; non-numeric cells
        become NaN.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``df`` is missing or has no rows.
    """
    global clf, feature_ablation
    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive until explicitly closed, and this runs on every click of a
    # long-lived app.
    from matplotlib.figure import Figure

    if df is None or df.empty:
        raise ValueError("No input data: enter a row of pitch measurements or click 'Get Random Test Data'.")

    wrapped_clf = WrappedModel(clf)
    feature_ablation = FeatureAblation(wrapped_clf)

    # Same single-row handling as get_prediction: reshape(1, -1) on a multi-row
    # grid would otherwise yield one over-wide sample.
    row = df.iloc[[0]].apply(pd.to_numeric, errors="coerce")
    input_data = row.to_numpy(dtype=float)

    # Captum's `target` indexes the columns of WrappedModel's output
    # (predict_proba), which follow clf.classes_. clf.predict returns the class
    # *label*, which equals that column index only if every class reached the
    # training split, so use the highest-probability column directly.
    target_col = int(clf.predict_proba(input_data)[0].argmax())

    inputs = torch.tensor(input_data, dtype=torch.float32)
    attributions = feature_ablation.attribute(inputs, target=target_col)
    contributions = attributions[0].numpy()

    fig = Figure(figsize=(20, 5))
    ax = fig.subplots()
    ax.bar(row.columns, contributions, width=0.5)
    ax.set(xlabel='Feature', ylabel='Contribution', title='Feature Contribution')

    return fig

## Gradio App

In [12]:
# HTML header rendered by gr.Markdown at the top of the app.
title = """
    <center>
        <h1> ⚾️ MLB Pitch Classifier </h1>
        <b> Upload a dataset and the model will recognize the pitch type for the given data. </b>
    </center>
    """

In [13]:
import gradio as gr

# Headers for the input grid: every model feature from `important_columns`
# (everything except the `pitch_name` target), in the order the model was trained on.
columns = [col for col in important_columns if col != "pitch_name"]

with gr.Blocks(theme="citrus") as app:
    gr.Markdown(title)
    gr.Markdown("Step 1: Upload a CSV file and the model will fit the dataset")

    with gr.Row(equal_height=True):
        with gr.Column():
            file_input = gr.File(label="Upload CSV", type="filepath")

        with gr.Column():
            acc_display = gr.Textbox(label="Accuracy")
            fit_dataset_button = gr.Button("Fit Model")

    with gr.Row():
        scatter_plot = gr.Plot(label="Scatter Plot")

    gr.Markdown("Step 2: Input the data and get the model's prediction")

    with gr.Row():
        data_input = gr.DataFrame(headers=columns, datatype="number", wrap=False)

    with gr.Row():
        random_data_button = gr.Button("Get Random Test Data")
        predict_button = gr.Button("Predict")

    with gr.Row():
        predict_output = gr.Label(num_top_classes=3, label="Prediction")

    with gr.Row():
        captum_output = gr.Plot(label="Feature Importance Bar Chart")

    # Redraw the scatter plot after every successful fit. Chaining on the click,
    # rather than listening to `acc_display.change`, also refreshes it when a new
    # upload happens to yield the same accuracy text, and skips it if fitting fails.
    fit_dataset_button.click(
        prepare_split_and_fit_dataset, inputs=file_input, outputs=acc_display
    ).success(plot_scatter, outputs=scatter_plot)

    random_data_button.click(get_random_data, outputs=data_input)

    # Both listeners fire on the same click and read the same grid.
    predict_button.click(get_prediction, inputs=data_input, outputs=predict_output)
    predict_button.click(get_captum_barchart, inputs=data_input, outputs=captum_output)

# NOTE(review): share=True publishes a temporary public *.gradio.live URL; anyone
# with the link can upload files and run the model in this kernel. Drop it for
# local-only use.
app.launch(share=True)

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://66dbe67f1c3cd41383.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


