diff --git a/caliban_toolbox/dataset_splitter.py b/caliban_toolbox/dataset_splitter.py index c009755..603e50a 100644 --- a/caliban_toolbox/dataset_splitter.py +++ b/caliban_toolbox/dataset_splitter.py @@ -55,7 +55,7 @@ def __init__(self, seed=0, splits=None): self.splits = splits def _validate_dict(self, train_dict): - if 'X' not in train_dict.keys() or 'y' not in train_dict.keys(): + if 'X' not in train_dict or 'y' not in train_dict: raise ValueError('X and y must be keys in the training dictionary') def split(self, train_dict):