Skip to content

Commit

Permalink
Bug fix for calling BatchRandTransforms
Browse files Browse the repository at this point in the history
  • Loading branch information
warner-benjamin committed Feb 15, 2023
1 parent af36a26 commit 1c9168e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fastxtend/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def before_call(self,
):
"Randomly select `self.idxs` and set `self.do` based on `self.p` if not valid `split_idx`"
self.idxs = self.bernoulli.sample((find_bs(b),)).bool() if not split_idx and self.p<1. else torch.ones(find_bs(b)).bool()
self.do = self.p==1. or self.idxs.shape[-1] > 0
self.do = self.p==1. or self.idxs.sum() > 0

def __call__(self,
b:Tensor|tuple[Tensor,...], # Batch item(s)
Expand Down
2 changes: 1 addition & 1 deletion nbs/transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
" ):\n",
" \"Randomly select `self.idxs` and set `self.do` based on `self.p` if not valid `split_idx`\"\n",
" self.idxs = self.bernoulli.sample((find_bs(b),)).bool() if not split_idx and self.p<1. else torch.ones(find_bs(b)).bool()\n",
" self.do = self.p==1. or self.idxs.shape[-1] > 0\n",
" self.do = self.p==1. or self.idxs.sum() > 0\n",
"\n",
" def __call__(self,\n",
" b:Tensor|tuple[Tensor,...], # Batch item(s)\n",
Expand Down

0 comments on commit 1c9168e

Please sign in to comment.