diff --git a/chefboost/training/Training.py b/chefboost/training/Training.py index cf4b047..2a8f231 100644 --- a/chefboost/training/Training.py +++ b/chefboost/training/Training.py @@ -556,11 +556,8 @@ def buildDecisionTree( # --------------------------- # add else condition in the decision tree - if df.Decision.dtypes == "object": # classification - pivot = pd.DataFrame(subdataset.Decision.value_counts()).sort_values( - by=["count"], ascending=False - ) + pivot = pd.DataFrame(df.Decision.value_counts()).sort_values(by=["count"], ascending=False) else_decision = f"return '{str(pivot.iloc[0].name)}'" if enableParallelism != True: @@ -588,7 +585,7 @@ def buildDecisionTree( decision_rules.append(sample_rule) else: # regression - else_decision = f"return {subdataset.Decision.mean()}" + else_decision = f"return {df.Decision.mean()}" if enableParallelism != True: functions.storeRule(file, (functions.formatRule(root), "else:"))