# Tree methods

## Imports

In [None]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn import tree
from sklearn import ensemble
from sklearn.model_selection import train_test_split

from src.utils.const import DATA_DIR, SEED

### Useful paths to data

In [None]:
# Data locations are resolved relative to the parent of the current working
# directory, which is used as the project root.
ROOT_DIR = os.path.join(os.getcwd(), os.pardir)
PROCESSED_DIR = os.path.join(
    ROOT_DIR,
    DATA_DIR,
    'processed',
)

## Import final dataset

In [None]:
# Load the processed dataset from PROCESSED_DIR and preview its first rows.
final_path = os.path.join(PROCESSED_DIR, 'final.parquet')
final = pd.read_parquet(final_path)

final.head()

## Prepare features and target

In [None]:
target = 'rating_mean'

# Features: every column except the target.
is_feature = final.columns != target
X = final.loc[:, is_feature]

# Labels: the target column discretised into 10 equal-width bins; with
# labels=False pd.cut returns the integer bin index for each row.
target_values = final.loc[:, target]
y = pd.cut(target_values, bins=10, labels=False)

### Split into train, test and validation sets

In [None]:
def split(data, *, test_size=0.2, val_size=0.1, random_state=SEED):
    """Split ``data`` into train, test and validation subsets.

    The test subset is carved out first; the validation subset is then taken
    from what remains, so ``val_size`` is a fraction of the non-test data
    (with the defaults: 0.1 * 0.8 = 8% of the full input).

    Args:
        data: Any array-like accepted by ``train_test_split``.
        test_size: Share of ``data`` held out as the test subset.
        val_size: Share of the remaining data held out as the validation
            subset.
        random_state: Seed used for both shuffles. Calling ``split`` on
            several same-length inputs with the same seed is what keeps
            their rows aligned.

    Returns:
        A ``(train, test, val)`` tuple -- note that test comes before val.
    """
    train_tmp, test = train_test_split(
        data, test_size=test_size, random_state=random_state
    )
    train, val = train_test_split(
        train_tmp, test_size=val_size, random_state=random_state
    )

    return train, test, val

In [None]:
# split() returns (train, test, val). Features and labels are split by two
# separate calls; they stay row-aligned because both have the same length and
# split() shuffles with the same seed.
(X_train, X_test, X_val), (y_train, y_test, y_val) = split(X), split(y)

## DecisionTreeClassifier

In [None]:
# Fit a single decision tree on the training split. The estimator is seeded
# with the same SEED used for the data splits so that the fitted tree, and
# every score derived from it, is reproducible between runs (scikit-learn
# randomly permutes the candidate features at each split).
dtc = tree.DecisionTreeClassifier(random_state=SEED)
dtc.fit(X_train, y_train)

In [None]:
# tree.plot_tree(dtc, filled=True)
# print(tree.export_text(dtc))

In [None]:
# Mean accuracy on the held-out test split. score() runs predict() itself, so
# a separate predict() call whose result is thrown away is not needed.
print(f'Avg accuracy: {dtc.score(X_test, y_test)}')

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

# from_estimator(estimator, X, y) predicts on X and compares the result with
# y, so it must receive the test features and the matching test labels.
ConfusionMatrixDisplay.from_estimator(dtc, X_test, y_test)

## RandomForestClassifier

In [None]:
# Random forest baseline. Seeded with SEED so that bootstrap sampling and
# feature sub-sampling, and therefore the accuracy and the importance
# ranking below, are reproducible between runs.
rf = ensemble.RandomForestClassifier(random_state=SEED)
rf.fit(X_train, np.ravel(y_train))

# score() predicts internally; no separate predict() call is required.
print(f'Avg accuracy: {rf.score(X_test, y_test)}')

# Feature importances labelled by column name, most important first.
# The name rankVar is kept because the plotting cell reads it.
rankVar = pd.Series(rf.feature_importances_, index=X.columns).sort_values(ascending=False)
print(rankVar)

In [None]:
# Horizontal bar chart of the importance scores, one bar per variable.
importance_ax = sns.barplot(x=rankVar, y=rankVar.index)
importance_ax.set_xlabel('Variable Importance Score')
importance_ax.set_ylabel('Variables')
plt.show()

## GradientBoostingClassifier

In [None]:
# Gradient boosting baseline. Seeded with SEED so the fitted ensemble and its
# score are reproducible between runs.
gb = ensemble.GradientBoostingClassifier(random_state=SEED)
gb.fit(X_train, y_train)

# score() predicts internally; no separate predict() call is required.
score = gb.score(X_test, y_test)
print(f'Avg accuracy: {score}')