Skip to content

Commit

Permalink
Merge pull request #382 from torchsynth/feature/fix_floodir
Browse files Browse the repository at this point in the history
Fix floordiv
  • Loading branch information
turian committed Aug 19, 2022
2 parents a51a36c + 9b39832 commit 81b3422
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion torchsynth/synth.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,22 @@ def _batch_idx_to_is_train(
assert len(idxs) == self.batch_size
# As specified in our paper, the first 9x1024 samples
# are train, and the next 1024 are test.
is_train = (idxs // N_BATCHSIZE_FOR_TRAIN_TEST_REPRODUCIBILITY) % 10 != 9
# __floordiv__ is deprecated, and its behavior will
# change in a future version of pytorch. It currently
# rounds toward 0 (like the 'trunc' function NOT 'floor').
# This results in incorrect rounding for negative values.
# To keep the current behavior, use torch.div(a, b,
# rounding_mode='trunc'), or for actual floor division,
# use torch.div(a, b, rounding_mode='floor').
is_train = (
torch.div(
idxs,
N_BATCHSIZE_FOR_TRAIN_TEST_REPRODUCIBILITY,
rounding_mode="trunc",
)
% 10
!= 9
)
else:
is_train = None
return is_train
Expand Down

0 comments on commit 81b3422

Please sign in to comment.