In [None]:
import json
import numpy as np
import pandas as pd
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.multiclass import OneVsRestClassifier

# Run configuration. These placeholders are assumed to be populated before the
# rest of the script executes; the empty defaults below are not usable as-is.
dsr_training_file = ""  # path passed to pd.read_csv to load the training data
dsr_predict_column = ""  # name of the target (label) column in that CSV
dsr_features = []  # names of the columns that get standardized before training
dsr_parameters_str = ""  # JSON-encoded GridSearchCV parameter grid; empty selects the default grid

# Load the training data and separate the target column from the predictors.
train = pd.read_csv(dsr_training_file)
target = train[dsr_predict_column]
train = train.drop([dsr_predict_column], axis=1)

# NOTE(review): every remaining column is passed to the model, but only the
# columns listed in dsr_features are standardized -- confirm this is intended.
x = pd.DataFrame(train)
col_names = pd.Index(dsr_features)

# Encode target labels as consecutive integers.
le = LabelEncoder()
y = le.fit_transform(target)

# Split BEFORE scaling so that statistics of the held-out rows never influence
# the transformation applied to the training rows (avoids data leakage).
# The row partition depends only on the number of rows and random_state.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=2439
)

# Work on explicit copies so the column assignments below do not touch `x`.
x_train = x_train.copy()
x_test = x_test.copy()

# Standardize the selected feature columns: fit on the training split only,
# then apply the same fitted transformation to both splits.
scaler = StandardScaler().fit(x_train[col_names].values)
x_train[col_names] = scaler.transform(x_train[col_names].values)
x_test[col_names] = scaler.transform(x_test[col_names].values)

# Downstream code works on plain NumPy arrays.
x_train = np.array(x_train)
x_test = np.array(x_test)

# Hyper-parameter grid: a caller-supplied JSON string takes precedence;
# otherwise sweep var_smoothing over 30 log-spaced values. The ``estimator__``
# prefix routes the parameter to the GaussianNB wrapped by OneVsRestClassifier.
param_grid = (
    json.loads(dsr_parameters_str)
    if dsr_parameters_str
    else {'estimator__var_smoothing': np.logspace(-9, -5, 30)}
)

# Five-fold, accuracy-scored grid search over a one-vs-rest Gaussian naive
# Bayes classifier.
grid_search = GridSearchCV(
    estimator=OneVsRestClassifier(GaussianNB()),
    param_grid=param_grid,
    scoring='accuracy',
    cv=5,
    n_jobs=-1,
)

# Run the search on the training split and keep the best model it found.
best_nb = grid_search.fit(x_train, y_train).best_estimator_

# Score the held-out split with the tuned model: hard labels for accuracy/F1,
# class probabilities for ROC AUC.
y_pred_nb = best_nb.predict(x_test)
y_pred_proba_nb = best_nb.predict_proba(x_test)

# Calculate metrics
accuracy_score_result = accuracy_score(y_test, y_pred_nb)

# ROC AUC. For a two-class target roc_auc_score expects a 1-D score vector
# (the probability of the class with the greater label) and raises on the
# two-column matrix returned by predict_proba, so only the second column is
# passed. With more than two classes the full probability matrix is scored
# with the macro-averaged one-vs-rest formulation.
if y_pred_proba_nb.shape[1] == 2:
    roc_auc_score_result = roc_auc_score(y_test, y_pred_proba_nb[:, 1])
else:
    roc_auc_score_result = roc_auc_score(
        y_test, y_pred_proba_nb, multi_class='ovr', average='macro'
    )

# Macro-averaged F1 weights every class equally (valid for binary targets too).
f1_score_result = f1_score(y_test, y_pred_nb, average='macro')

# Print results in the fixed "key: value" layout, terminated by the
# "--end--" sentinel line.
print("best_parameters:", json.dumps(grid_search.best_params_))
print("accuracy_score:", accuracy_score_result)
print("roc_auc_score:", roc_auc_score_result)
print("f1_score:", f1_score_result)
print("--end--")
