In [1]:
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split
import plotly.graph_objects as go
import numpy as np
import plotly.express as px
import plotly

In [2]:
data = pd.read_csv("../Datasets/star_classification.csv")

In [3]:
data.head()

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,class,redshift,plate,MJD,fiber_ID
0,1.237661e+18,135.689107,32.494632,23.87882,22.2753,20.39501,19.16573,18.79371,3606,301,2,79,6.543777e+18,GALAXY,0.634794,5812,56354,171
1,1.237665e+18,144.826101,31.274185,24.77759,22.83188,22.58444,21.16812,21.61427,4518,301,5,119,1.176014e+19,GALAXY,0.779136,10445,58158,427
2,1.237661e+18,142.18879,35.582444,25.26307,22.66389,20.60976,19.34857,18.94827,3606,301,2,120,5.1522e+18,GALAXY,0.644195,4576,55592,299
3,1.237663e+18,338.741038,-0.402828,22.13682,23.77656,21.61162,20.50454,19.2501,4192,301,3,214,1.030107e+19,GALAXY,0.932346,9149,58039,775
4,1.23768e+18,345.282593,21.183866,19.43718,17.58028,16.49747,15.97711,15.54461,8102,301,3,137,6.891865e+18,GALAXY,0.116123,6121,56187,842


In [4]:
y = data["class"]
X = data.drop("class", axis=1)

In [5]:
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

In [6]:
classifier = RandomForestClassifier()

classifier.fit(x_train, y_train)

In [7]:
predictions = classifier.predict(x_test)

In [8]:
print(f"Accuracy {accuracy_score(y_test, predictions)}")
print(f"Precision {f1_score(y_test, predictions, average='macro')}")

Accuracy 0.9777333333333333
Precision 0.9739857982076933


In [9]:
arr = np.arange(1, 101, 10)
acc_arr = []
p_arr = []
r_arr = []
for i in arr:
    cls = RandomForestClassifier(n_estimators=i)
    cls.fit(x_train, y_train)
    predict = cls.predict(x_test)
    acc_arr.append(accuracy_score(y_test, predict))
    p_arr.append(precision_score(y_test, predict, average="macro"))
    r_arr.append(recall_score(y_test, predict, average="macro"))

In [10]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=arr, y=acc_arr, name="accuracy"))
fig.add_trace(go.Scatter(x=arr, y=p_arr, name="precision"))
fig.add_trace(go.Scatter(x=arr, y=r_arr, name="Recall"))
fig.update_layout(
    title="Varying trees in Random Forest on Stellar Dataset",
    xaxis_title="Number of trees/Bootstrap samples",
    yaxis_title="Score",
    legend_title="Metrics",
    font = dict(
        family="Courier new, monospace"
    )
)
fig.show()

In [11]:
plotly.io.write_image(fig, 'stars-varytrees.pdf', format='pdf')