Skip to content

Commit

Permalink
RF PT pad, small device fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Apr 25, 2024
1 parent 349ea73 commit a86ee40
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion returnn/torch/frontend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,11 @@ def pad(
for out_dim, middle, (left, right) in zip(out_dims, axes, padding):
if middle.need_masking() or (isinstance(left, Dim) and left.need_masking()):
if isinstance(right, Dim) or right > 0:
mask = rf.compare_bc(rf.range_over_dim(out_dim), "<", (left + middle).dyn_size_ext)
mask = rf.compare_bc(
rf.range_over_dim(out_dim, device=out.device),
"<",
rf.copy_to_device((left + middle).dyn_size_ext, out.device),
)
out.raw_tensor = torch.where(
mask.copy_compatible_to(out, check_dtype=False, check_sparse=False).raw_tensor,
out.raw_tensor,
Expand Down

0 comments on commit a86ee40

Please sign in to comment.