diff --git a/fiftyone/core/dataset.py b/fiftyone/core/dataset.py index 672ce1a36b..651a58e338 100644 --- a/fiftyone/core/dataset.py +++ b/fiftyone/core/dataset.py @@ -7605,14 +7605,16 @@ def _merge_samples_python( dynamic=False, num_samples=None, ): + if dataset.media_type == fom.GROUP: + dst = dataset.select_group_slices(_allow_mixed=True) + else: + dst = dataset + if ( isinstance(samples, foc.SampleCollection) and samples.media_type == fom.GROUP ): samples = samples.select_group_slices(_allow_mixed=True) - dst = dataset.select_group_slices(_allow_mixed=True) - else: - dst = dataset if num_samples is None: try: diff --git a/tests/unittests/group_tests.py b/tests/unittests/group_tests.py index dfb54b480a..9bba38883f 100644 --- a/tests/unittests/group_tests.py +++ b/tests/unittests/group_tests.py @@ -1046,6 +1046,26 @@ def test_merge_groups5(self): self.assertEqual(dataset.count("frames"), 4) self.assertEqual(len(set(samples.values("frames.id", unwind=True))), 4) + @drop_datasets + def test_merge_groups6(self): + dataset = _make_group_dataset() + + view = dataset.select_group_slices(_allow_mixed=True) + samples = list(view) + + self.assertEqual(len(dataset), 2) + self.assertEqual(dataset.count("frames"), 2) + self.assertEqual(len(view), 6) + + key_fcn = lambda sample: sample.filepath + dataset.merge_samples(samples, key_fcn=key_fcn) + + view = dataset.select_group_slices(_allow_mixed=True) + + self.assertEqual(len(dataset), 2) + self.assertEqual(dataset.count("frames"), 2) + self.assertEqual(len(view), 6) + @drop_datasets def test_indexes(self): dataset = _make_group_dataset()