In [None]:
import pickle
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the pickled dataset. The context manager guarantees the file handle is
# closed even if unpickling fails (the bare open() call previously leaked it).
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load a data.pickle that you produced yourself or otherwise trust.
with open("data.pickle", "rb") as data_file:
    data_dict = pickle.load(data_file)

# The pickle holds a dict; its "data" and "labels" entries are converted to
# NumPy arrays for use by the split/training cells.
data = np.asarray(data_dict["data"])
labels = np.asarray(data_dict["labels"])

# Identify the unique classes in the labels
unique_labels = np.unique(labels)

In [None]:
# Hold out 20% of the samples as a test set. Stratifying on the labels keeps
# the class proportions aligned between the two partitions, and the fixed
# random_state makes the shuffle repeatable.
x_train, x_test, y_train, y_test = train_test_split(
    data,
    labels,
    test_size=0.2,
    shuffle=True,
    random_state=42,
    stratify=labels,
)

In [None]:
# Train a seeded random forest on the training partition. fit() returns the
# estimator itself, so construction and training are chained.
model = RandomForestClassifier(random_state=42).fit(x_train, y_train)

# Predicted class for every held-out sample; consumed by the evaluation cells.
y_pred = model.predict(x_test)

In [None]:
# Fraction of held-out samples whose predicted label equals the true label.
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print("Accuracy:", accuracy)

In [None]:
# Persist the fitted classifier to model.p, wrapped in a dict under the
# "model" key so a loader can retrieve it by name.
model_payload = {"model": model}
with open("model.p", "wb") as model_file:
    pickle.dump(model_payload, model_file)

In [None]:
# Build the per-class metrics table for the held-out predictions. Rows follow
# the order of unique_labels, and each row is named by the label's string form.
class_names = unique_labels.astype(str)
report_text = classification_report(
    y_test,
    y_pred,
    labels=unique_labels,
    target_names=class_names,
)

print("\nClassification Report:\n")
print(report_text)


In [None]:
# Rank the features by the forest's importance scores, most important first.
importances = model.feature_importances_
indices = np.argsort(importances)[::-1]

# One bar per feature, drawn in ranked order; each tick is labelled with the
# original column index of the feature occupying that rank.
n_features = x_train.shape[1]
bar_positions = range(n_features)

fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title("Feature Importances")
ax.bar(bar_positions, importances[indices], align="center")
ax.set_xticks(bar_positions)
ax.set_xticklabels(indices)
ax.set_xlim([-1, n_features])
ax.set_xlabel('Feature Index')
ax.set_ylabel('Importance')
plt.show()

In [None]:
# Repeat the split/train/evaluate cycle for several split seeds to see how
# sensitive the test accuracy is to the particular train/test partition.
#
# Loop-local names (seed_*) are used on purpose: rebinding the notebook-level
# model, x_train/x_test, y_train/y_test and y_pred here would make the
# report, save and feature-importance cells silently use the last seed's
# split and model if they were re-run after this cell.
seeds = range(10)
accuracy_list = []
for seed in seeds:
    print("Seed:", seed)
    seed_x_train, seed_x_test, seed_y_train, seed_y_test = train_test_split(
        data, labels, test_size=0.2, shuffle=True, random_state=seed, stratify=labels)
    # The forest's own random_state stays fixed at 42; only the split varies.
    seed_model = RandomForestClassifier(random_state=42)
    seed_model.fit(seed_x_train, seed_y_train)
    seed_y_pred = seed_model.predict(seed_x_test)
    accuracy_list.append(accuracy_score(seed_y_test, seed_y_pred))

# One point per split seed; the x-axis is the random_state passed to
# train_test_split, taken from the same `seeds` range the loop iterated over.
plt.figure(figsize=(10, 6))
plt.plot(seeds, accuracy_list, marker='o', linestyle='-')
plt.title('Accuracy over Different Random States')
plt.xlabel('Random State')
plt.ylabel('Accuracy')
plt.show()