TVAE: Adapt batch_size to data size #135

Closed
csala opened this issue Mar 8, 2021 · 0 comments · Fixed by #136

csala commented Mar 8, 2021

Environment Details

Please indicate the following details about the environment in which you found the bug:

  • CTGAN version: 0.4.0

Error Description

TVAE currently fails to fit on datasets that have fewer rows than the batch size, and it does so silently.

The culprit is the drop_last=True in this line:

loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)

As a consequence, if the data has fewer rows than batch_size, the loader drops its only (incomplete) batch and yields nothing. TVAE then performs no training steps, and the resulting model keeps its initialization weights, unoptimized.
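
To make the failure mode concrete, here is a minimal sketch in plain Python that counts the batches produced per epoch. It assumes the loader follows PyTorch's documented DataLoader length rule (floor division when drop_last=True, ceiling otherwise). The helper name num_batches and the sizes 100 and 500 are illustrative assumptions, not part of CTGAN or PyTorch.

```python
import math


def num_batches(data_length, batch_size, drop_last):
    """Count the batches a DataLoader-style iterator yields in one epoch.

    Hypothetical helper for illustration only.
    """
    if drop_last:
        # Incomplete trailing batch is discarded.
        return data_length // batch_size
    # Incomplete trailing batch is kept as a smaller batch.
    return math.ceil(data_length / batch_size)


# 100 rows with a batch size of 500 (illustrative values).
print(num_batches(100, 500, drop_last=True))   # 0: the training loop body never runs
print(num_batches(100, 500, drop_last=False))  # 1: one batch of 100 rows
```

With zero batches per epoch, the loop over the loader completes without a single optimizer step, which is why no error is raised.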

Fix

The indicated line could be changed to something like the following, which:

  • Reduces the batch size to the data length if the data has fewer rows than the batch size
  • Drops the last batch only if the data size is not divisible by the batch size

    data_length = len(table_data)
    batch_size = min(self.batch_size, data_length)
    drop_last = (data_length % batch_size) != 0
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
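
The effect of the proposed rule on a few data sizes can be checked without PyTorch. The sketch below isolates the two computed values in a hypothetical helper, loader_params, which is not part of CTGAN. The guard for empty data is an added assumption, since the modulo would otherwise raise ZeroDivisionError when the table has no rows.

```python
def loader_params(data_length, configured_batch_size):
    """Return (batch_size, drop_last) following the rule proposed above.

    Hypothetical helper for illustration only.
    """
    if data_length == 0:
        # Assumption: fail loudly instead of dividing by zero below.
        raise ValueError("cannot fit on an empty table")
    # Shrink the batch to the data length when the data is shorter.
    batch_size = min(configured_batch_size, data_length)
    # Drop the trailing batch only when it would be incomplete.
    drop_last = (data_length % batch_size) != 0
    return batch_size, drop_last


print(loader_params(100, 500))   # (100, False): one full batch of 100 rows
print(loader_params(1000, 500))  # (500, False): two full batches
print(loader_params(1200, 500))  # (500, True): the trailing 200 rows are dropped each epoch
```

Under this rule every batch has the same size, at the cost of skipping some rows per epoch when the data length is not a multiple of the batch size. Because the loader shuffles, the skipped rows differ between epochs.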
@csala csala added the bug Something isn't working label Mar 8, 2021
@csala csala added this to the 0.5.0 milestone Mar 8, 2021
@fealho fealho mentioned this issue Mar 10, 2021
fealho added a commit that referenced this issue Mar 11, 2021
* Bump version: 0.4.1.dev0 → 0.4.1.dev1

* Set drop_last to False

* Add drop_last variable

* Fix lint

* Changed drop_last to False
@pvk-developer pvk-developer modified the milestones: 0.5.0, 0.4.1 Mar 29, 2021