diff --git a/dtreeviz/shadow.py b/dtreeviz/shadow.py index 6e085b7f..3e4730fe 100644 --- a/dtreeviz/shadow.py +++ b/dtreeviz/shadow.py @@ -40,7 +40,7 @@ def __init__(self, tree_model, self.class_names = class_names self.class_weight = tree_model.class_weight - if getattr(tree_model, 'tree_') is None: # make sure model is fit + if not hasattr(tree_model, 'tree_'): # make sure model is fit tree_model.fit(X_train, y_train) if tree_model.tree_.n_classes > 1: