Skip to content

Commit

Permalink
make _task vs task be version sensitive.
Browse files Browse the repository at this point in the history
Signed-off-by: Terence Parr <parrt@antlr.org>
  • Loading branch information
parrt committed Jan 28, 2023
1 parent ffda2a0 commit 2f10aaa
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions dtreeviz/models/tensorflow_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Mapping

import numpy as np
import tensorflow_decision_forests
from tensorflow_decision_forests.component.py_tree.node import LeafNode
from tensorflow_decision_forests.keras import RandomForestModel
from tensorflow_decision_forests.tensorflow.core import Task
Expand Down Expand Up @@ -55,6 +56,8 @@ def get_children_right(self):
return self.children_right

def is_classifier(self) -> bool:
if tensorflow_decision_forests.__version__<'1.2.0':
return self.model._task == Task.CLASSIFICATION
return self.model.task == Task.CLASSIFICATION

def get_class_weights(self):
Expand Down

0 comments on commit 2f10aaa

Please sign in to comment.