HMASynthesizer throws an error when sampling multi table models with three levels of depths #1600

Closed
portovep opened this issue Sep 21, 2023 · 6 comments · Fixed by #1608
Labels
bug Something isn't working

@portovep

Environment Details

Please indicate the following details about the environment in which you found the bug:

  • SDV version: 1.4.1.dev0
  • Python version: Python 3.11.5
  • Operating System: All

Error Description

@elisaherrmann and I are getting an error when using the HMASynthesizer to sample a multi-table model with three levels of depth and multiple root parent nodes.

The model we are trying to sample:

[Image: schema diagram of the model; the same tables and relationships are defined in the test below.]

Steps to reproduce

We created an integration test to reproduce:

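    import pandas as pd
    from sdv.metadata import MultiTableMetadata
    from sdv.multi_table import HMASynthesizer


    class TestHMASynthesizer:
        def test_hma_two_rootparents_two_children_one_grandchild(self):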
            """Test it works on a 'parent-children-grandchild' dataset."""
            # Setup
            child1 = pd.DataFrame(data={
                'child_ID': ['a', 'b', 'c', 'd', 'e'],
                'parent_ID1': [0, 1, 2, 3, 3],
                'data': ['0', '1', '2', '3', '4']
            })
            child2 = pd.DataFrame(data={
                'child_ID': ['a', 'b', 'c', 'd', 'e'],
                'parent_ID2': [0, 1, 2, 3, 4],
                'data': ['0', '1', '2', '3', '4']
            })
            root_parent1 = pd.DataFrame(data={
                'parent_ID1': [0, 1, 2, 3, 4],
                'data': [True, False, False, False, True]
            })
            root_parent2 = pd.DataFrame(data={
                'parent_ID2': [0, 1, 2, 3, 4],
                'data': ['Yes', 'Yes', 'Maybe', 'No', 'No']
            })
            grandchild = pd.DataFrame(data={
                'grandchild_ID': ['a', 'b', 'c', 'd', 'e'],
                'child_ID1': ['a', 'b', 'c', 'd', 'e'],
                'child_ID2': ['a', 'b', 'c', 'd', 'e'],
                'data': ['0', '1', '2', '3', '4']
            })
            data = {'parent1': root_parent1, 'child1': child1, 'child2': child2, 'parent2': root_parent2, 'grandchild': grandchild}
            metadata = MultiTableMetadata.load_from_dict({
                'tables': {
                    'parent1': {
                        'primary_key': 'parent_ID1',
                        'columns': {
                            'parent_ID1': {'sdtype': 'id'},
                            'data': {'sdtype': 'categorical'}
                        }
                    },
                    'parent2': {
                        'primary_key': 'parent_ID2',
                        'columns': {
                            'parent_ID2': {'sdtype': 'id'},
                            'data': {'sdtype': 'categorical'}
                        }
                    },
                    'child1': {
                        'primary_key': 'child_ID',
                        'columns': {
                            'child_ID': {'sdtype': 'id'},
                            'parent_ID1': {'sdtype': 'id'},
                            'data': {'sdtype': 'categorical'}
                        }
                    },
                    'child2': {
                        'primary_key': 'child_ID',
                        'columns': {
                            'child_ID': {'sdtype': 'id'},
                            'parent_ID2': {'sdtype': 'id'},
                            'data': {'sdtype': 'categorical'}
                        }
                    },
                    'grandchild': {
                        'primary_key': 'grandchild_ID',
                        'columns': {
                            'grandchild_ID': {'sdtype': 'id'},
                            'child_ID1': {'sdtype': 'id'},
                            'child_ID2': {'sdtype': 'id'},
                            'data': {'sdtype': 'categorical'}
                        }
                    },
                },
                'relationships': [
                    {
                        'parent_table_name': 'parent1',
                        'parent_primary_key': 'parent_ID1',
                        'child_table_name': 'child1',
                        'child_foreign_key': 'parent_ID1'
                    },
                    {
                        'parent_table_name': 'parent2',
                        'parent_primary_key': 'parent_ID2',
                        'child_table_name': 'child2',
                        'child_foreign_key': 'parent_ID2'
                    },
                    {
                        'parent_table_name': 'child1',
                        'parent_primary_key': 'child_ID',
                        'child_table_name': 'grandchild',
                        'child_foreign_key': 'child_ID1'
                    },
                    {
                        'parent_table_name': 'child2',
                        'parent_primary_key': 'child_ID',
                        'child_table_name': 'grandchild',
                        'child_foreign_key': 'child_ID2'
                    },
                ]
            })
            synthesizer = HMASynthesizer(metadata)

            # Run
            synthesizer.fit(data)
            samples = synthesizer.sample(scale=1)

            # Assert tables are the same
            assert set(samples) == set(data)

            # Assert columns are the same
            for table_name, table in samples.items():
                assert set(table.columns) == set(data[table_name].columns)

            # Assert data values all exist in the original tables
            for table_name, table in samples.items():
                assert table['data'].isin(data[table_name]['data']).all()

When we run the integration test above, we get this error:

            if (parent_name, child_name) not in added_relationships:
                self._add_foreign_key_columns(
>                   sampled_data[child_name],
                    sampled_data[parent_name],
                    child_name,
                    parent_name
                )
E               KeyError: 'child1'

sdv/sampling/hierarchical_sampler.py:213: KeyError

We think the implemented graph traversal algorithm contains a bug: the last child node to be traversed is never sampled by the BaseHierarchicalSampler. When the BaseHierarchicalSampler later adds relationships, that child (child1 in the example above) is not found in the sampled_data dictionary, which causes the KeyError.
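The failure mode can be illustrated with a small, hypothetical sketch (this is not SDV's code): collect every table named in a relationship and compare that set against the keys of a sampled_data-like dict. Any table that is referenced but missing would raise the same KeyError once foreign keys are added. The helper name find_unsampled_tables, the RELATIONSHIPS list, and the use of None in place of sampled DataFrames are assumptions made for illustration.

```python
from typing import Dict, List, Tuple

# Relationships from the reproduction test, as (parent_table, child_table) pairs.
RELATIONSHIPS: List[Tuple[str, str]] = [
    ("parent1", "child1"),
    ("parent2", "child2"),
    ("child1", "grandchild"),
    ("child2", "grandchild"),
]


def find_unsampled_tables(
    relationships: List[Tuple[str, str]], sampled_data: Dict[str, object]
) -> List[str]:
    """Hypothetical helper: tables named in a relationship but absent from sampled_data."""
    referenced = {table for pair in relationships for table in pair}
    return sorted(referenced - set(sampled_data))


# Simulate the reported state: every table was sampled except child1.
# None stands in for the sampled DataFrames (an illustrative assumption).
sampled_data = {name: None for name in ["parent1", "parent2", "child2", "grandchild"]}
missing = find_unsampled_tables(RELATIONSHIPS, sampled_data)
print(missing)  # ['child1']

try:
    sampled_data["child1"]  # the lookup that fails when foreign keys are added
except KeyError as error:
    print(f"KeyError: {error}")  # KeyError: 'child1'
```

A check like this, run before adding foreign keys, would turn the bare KeyError into an explicit list of tables the traversal skipped.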

Notes

We switched to the latest development version (1.4.1.dev0) because we found a similar error when sampling this model with the latest stable version (1.4.0). We observed that in version 1.4.0 the hierarchical sampler was not able to sample more than one level of depth.

@portovep added the bug (Something isn't working) and new (Automatic label applied to new issues) labels on Sep 21, 2023
@portovep
Author

We are working on a potential fix for the current graph traversal algorithm implementation. We will create a pull request with the fix.

@npatki
Contributor

npatki commented Sep 21, 2023

Hi @portovep and @elisaherrmann, very nice to meet you and thanks for filing such a detailed issue.

Please hold off on any such pull request, as we already have a fix that was merged into the main branch two weeks ago. See #1562.

I don't believe any release candidates were made after this merge. So if you'd like to test it, I'd recommend installing directly from the main branch. If you're still encountering issues on main, do let us know. We plan to release these features in the next few weeks in SDV 1.5.0.

@npatki added the under discussion (Issue is currently being discussed) label and removed the new (Automatic label applied to new issues) label on Sep 21, 2023
@portovep
Author

portovep commented Sep 21, 2023

Hi @npatki, nice to meet you too. Thanks for replying so quickly; we are glad to hear that support for multi-table sampling of models with 3+ levels of depth will be added in SDV 1.5.0. Our current use case requires this feature, so we are very pleased to see it coming in the next release.

The error described on this issue was encountered while running the latest code from the main branch, which includes changes introduced in #1562.

We found the error while traversing graphs with the following characteristics:

  • A graph with 3 levels of depth where the child on level 3 has two parents on level 2 (see the graph attached to this issue).
  • A graph with 3 levels of depth and three root nodes, where a child has two root parents.
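To make the first shape concrete, here is a hypothetical sketch that assigns each table a level (roots are level 1; every other table is one level deeper than its deepest parent), using the relationships from the test above. The second shape is not fully specified in this thread, so it is not modeled. table_levels and RELATIONSHIPS are illustrative names, and the sketch assumes the schema has no cycles.

```python
from typing import Dict, List, Tuple

# Relationships from the reproduction test, as (parent_table, child_table) pairs.
RELATIONSHIPS: List[Tuple[str, str]] = [
    ("parent1", "child1"),
    ("parent2", "child2"),
    ("child1", "grandchild"),
    ("child2", "grandchild"),
]


def table_levels(relationships: List[Tuple[str, str]]) -> Dict[str, int]:
    """Hypothetical helper: level 1 for roots, else 1 + the deepest parent's level."""
    parents: Dict[str, List[str]] = {}
    for parent, child in relationships:
        parents.setdefault(parent, [])
        parents.setdefault(child, []).append(parent)

    levels: Dict[str, int] = {}

    def level(table: str) -> int:
        # Memoized recursion; assumes the relationship graph is acyclic.
        if table not in levels:
            levels[table] = 1 + max((level(p) for p in parents[table]), default=0)
        return levels[table]

    for table in parents:
        level(table)
    return levels


levels = table_levels(RELATIONSHIPS)
print(levels)  # {'parent1': 1, 'child1': 2, 'parent2': 1, 'child2': 2, 'grandchild': 3}
```

The output shows the three levels and that grandchild (level 3) has two parents, child1 and child2, both on level 2.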

You can see the proposed fix and two integration tests covering the scenarios above in this branch of our fork:
main...portovep:SDV:issue-1600-fix-hierarchical-sampler-graph-traversal-algorithm

Let me know if you would like me to raise a PR so you can check if the proposed fix is valid.

@npatki
Contributor

npatki commented Sep 21, 2023

Hi @portovep, no problem. I had hoped to avoid any duplicate efforts but I realize that you were already using the up-to-date code. Indeed, I can replicate the problem on the main branch using the code you've provided.

It seems the test never even reaches the asserts and fails on the sample call itself. I'm attaching a stack trace here for future reference.

stack_trace.txt

Next Steps

Supporting these types of schemas is important to us. Our team has been actively working on this area, and we'd like to clean up more of the traversal code to make it simpler and prevent these edge cases. We'll also check that it works with other parts of the SDV software, so there's no need to submit any fixes yet.
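The edge case in this issue is a child with more than one parent (grandchild has both child1 and child2). One standard way to handle it is Kahn's algorithm: count each table's unprocessed parents and release a table only when that count reaches zero. The sketch below is a hypothetical illustration over the test schema, not the SDV team's planned implementation; sampling_order_kahn and RELATIONSHIPS are illustrative names.

```python
from collections import deque
from typing import Dict, List, Tuple

# Relationships from the reproduction test, as (parent_table, child_table) pairs.
RELATIONSHIPS: List[Tuple[str, str]] = [
    ("parent1", "child1"),
    ("parent2", "child2"),
    ("child1", "grandchild"),
    ("child2", "grandchild"),
]


def sampling_order_kahn(relationships: List[Tuple[str, str]]) -> List[str]:
    """Kahn's algorithm: each table is placed after *all* of its parents."""
    children: Dict[str, List[str]] = {}
    pending_parents: Dict[str, int] = {}
    for parent, child in relationships:
        children.setdefault(parent, []).append(child)
        children.setdefault(child, [])
        pending_parents.setdefault(parent, 0)
        pending_parents[child] = pending_parents.get(child, 0) + 1

    # Start from the root tables (tables with no parents).
    queue = deque(table for table, count in pending_parents.items() if count == 0)
    order: List[str] = []
    while queue:
        table = queue.popleft()
        order.append(table)
        for child in children[table]:
            pending_parents[child] -= 1
            # Release a child only once every one of its parents is ordered.
            if pending_parents[child] == 0:
                queue.append(child)

    if len(order) != len(pending_parents):
        raise ValueError("relationships contain a cycle")
    return order


print(sampling_order_kahn(RELATIONSHIPS))
# ['parent1', 'parent2', 'child1', 'child2', 'grandchild']
```

Because grandchild is only queued after both child1 and child2 have been processed, no table can be reached while one of its parents is still unsampled.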

In the meantime, you are welcome to continue using your fork for personal use if it's working for you. We'll reach out if we need any PRs.

@portovep
Author

@npatki, cleaning up the traversal code to make it simpler and prevent edge cases makes sense. Perhaps an established graph traversal algorithm like depth-first search (DFS) could help here.
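A minimal sketch of the DFS idea, assuming it is applied to the parent links of the schema (this is not SDV's code or the SDV paper's exact method): visit all of a table's parents depth-first before emitting the table, so every table lands after all of its parents. sampling_order_dfs and RELATIONSHIPS are hypothetical names; the in-progress set detects cycles.

```python
from typing import Dict, List, Set, Tuple

# Relationships from the reproduction test, as (parent_table, child_table) pairs.
RELATIONSHIPS: List[Tuple[str, str]] = [
    ("parent1", "child1"),
    ("parent2", "child2"),
    ("child1", "grandchild"),
    ("child2", "grandchild"),
]


def sampling_order_dfs(relationships: List[Tuple[str, str]]) -> List[str]:
    """DFS post-order over parent links: a table is emitted after all its parents."""
    parents: Dict[str, List[str]] = {}
    for parent, child in relationships:
        parents.setdefault(parent, [])
        parents.setdefault(child, []).append(parent)

    order: List[str] = []
    done: Set[str] = set()
    in_progress: Set[str] = set()

    def visit(table: str) -> None:
        if table in done:
            return
        if table in in_progress:
            raise ValueError(f"cycle detected at table {table!r}")
        in_progress.add(table)
        # Recurse into every parent first, so all of them are emitted before this table.
        for parent in parents[table]:
            visit(parent)
        in_progress.remove(table)
        done.add(table)
        order.append(table)

    for table in parents:
        visit(table)
    return order


print(sampling_order_dfs(RELATIONSHIPS))
# ['parent1', 'child1', 'parent2', 'child2', 'grandchild']
```

The recursion depth grows with the depth of the schema, which is small for typical relational schemas such as this one.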

Thanks for your help. We will wait until the fix gets released as part of a future version and use a work-around in the meantime.

@npatki
Contributor

npatki commented Sep 22, 2023

My pleasure, @portovep. We hope to have that fix soon.

> Perhaps an established graph traversal algorithm like depth-first search (DFS) could help here.

Yes indeed, the initial SDV paper used a DFS approach.

3 participants