In [90]:
# Imports
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
import imodels
from imodels import OptimalTreeClassifier
import matplotlib.pyplot as plt

In [91]:
# Load data
data = pd.read_csv('Womens_Clothing_E-Commerce_Reviews.csv')

# Drop missing values in 'Review Text' and 'Recommended IND'
data = data.dropna(subset=['Review Text', 'Recommended IND'])

In [102]:
tfidf_vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')
X_tfidf = tfidf_vectorizer.fit_transform(data['Review Text'].values)
y = np.array(data['Recommended IND'].values)

feature_names = tfidf_vectorizer.get_feature_names_out()

In [103]:
# Splitting the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_tfidf, y, test_size=0.2, random_state=42)

# Initialize and fit the optimal decision tree model
optimal_tree_model = OptimalTreeClassifier()  # Initialize without feature_names
optimal_tree_model.fit(X_train, y_train)

train_predictions = optimal_tree_model.predict(X_train)
test_predictions = optimal_tree_model.predict(X_test)




In [104]:
tree_structure_str = str(optimal_tree_model)

# Replace the generic feature labels with actual words
for i, feature_name in enumerate(feature_names):
    tree_structure_str = tree_structure_str.replace(f'feature_{i}', feature_name)

# Now, tree_structure_str contains the tree structure with actual feature words
print(tree_structure_str)

# Calculate and print training and test accuracy
train_predictions = optimal_tree_model.predict(X_train)
test_predictions = optimal_tree_model.predict(X_test)
training_accuracy = np.mean(train_predictions == y_train)
test_accuracy = np.mean(test_predictions == y_test)

print(f"Training Accuracy: {training_accuracy:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")

> ------------------------------
> Greedy CART Tree:
> 	Prediction is made by looking at the value in the appropriate leaf of the tree
> ------------------------------
|--- 1038 <= 0.12
|   |--- 12508 <= 0.05
|   |   |--- 1182 <= 0.12
|   |   |   |--- 11573 <= 0.16
|   |   |   |   |--- 11572 <= 0.18
|   |   |   |   |   |--- 12507 <= 0.19
|   |   |   |   |   |   |--- 1100 <= 0.14
|   |   |   |   |   |   |   |--- 0p49 <= 0.17
|   |   |   |   |   |   |   |   |--- 11571 <= 0.15
|   |   |   |   |   |   |   |   |   |--- 12533 <= 0.12
|   |   |   |   |   |   |   |   |   |   |--- 1190 <= 0.07
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 206
|   |   |   |   |   |   |   |   |   |   |--- 1190 >  0.07
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 94
|   |   |   |   |   |   |   |   |   |--- 12533 >  0.12
|   |   |   |   |   |   |   |   |   |   |--- 1190 <= 0.09
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 19
| 