@@ -601,8 +601,6 @@ def __init__(
601
601
self .datasets : dict [str , SupervisedDataset ] = {
602
602
ds .outcome_names [0 ]: ds for ds in datasets
603
603
}
604
- self .feature_names = datasets [0 ].feature_names
605
- self .outcome_names = list (self .datasets .keys ())
606
604
self .parameter_decomposition = parameter_decomposition
607
605
self .metric_decomposition = metric_decomposition
608
606
self ._validate_datasets ()
@@ -614,6 +612,14 @@ def __init__(
614
612
}
615
613
self .group_indices = None
616
614
615
+ @property
616
+ def feature_names (self ) -> list [str ]:
617
+ return self .datasets [self .outcome_names [0 ]].feature_names
618
+
619
+ @property
620
+ def outcome_names (self ) -> list [str ]:
621
+ return list (self .datasets .keys ())
622
+
617
623
@property
618
624
def X (self ) -> Tensor :
619
625
return self .datasets [self .outcome_names [0 ]].X
@@ -737,6 +743,14 @@ def _validate_decompositions(self) -> None:
737
743
f"{ outcome } is missing in metric_decomposition."
738
744
)
739
745
746
+ def __eq__ (self , other : Any ) -> bool :
747
+ return (
748
+ type (other ) is type (self )
749
+ and self .datasets == other .datasets
750
+ and self .parameter_decomposition == other .parameter_decomposition
751
+ and self .metric_decomposition == other .metric_decomposition
752
+ )
753
+
740
754
def clone (
741
755
self , deepcopy : bool = False , mask : Tensor | None = None
742
756
) -> ContextualDataset :
0 commit comments