bsr_dense_mm Triton kernel: fix out kwarg #96648
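This PR fixes the out kwarg of the bsr_dense_mm Triton kernel in torch/sparse/_triton_ops.py. Previously, out was allocated only after the zero-nnz short circuit, so a call with bsr._nnz() == 0 returned a freshly allocated tensor and never touched the caller's out; the non-trivial path ended with out.transpose(-3, -2).reshape(...), which in general produces a new tensor rather than the tensor passed in. The fix allocates (or zeroes) out up front, returns it from the short circuit, and keeps a reference to the caller's tensor in out_backup so that exactly that tensor is returned after the kernel has written through the squashed view.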

Closed · wants to merge 1 commit
torch/sparse/_triton_ops.py (19 changes: 10 additions & 9 deletions)
@@ -451,9 +451,15 @@ def check(cond, msg):
             "should be True.",
         )
 
+        # Allocate out
+        if out is None:
+            out = dense.new_zeros(original_batch_dims_broadcasted + (m, n))
+        else:
+            out.zero_()
+
         # Short circuit if lhs is zero
         if bsr._nnz() == 0:
-            return dense.new_zeros(original_batch_dims_broadcasted + (m, n))
+            return out
 
         # TODO: insert switch
         if is_sparse_rowspace_mode is None:
@@ -486,10 +492,6 @@ def make_triton_contiguous(t):
         dense_batch_dims = dense.shape[:-2]
         batch_dims_broadcasted = torch.broadcast_shapes(bsr_batch_dims, dense_batch_dims)
 
-        # Allocate out
-        if out is None:
-            out = dense.new_zeros(batch_dims_broadcasted + (m, n))
-
         # Broadcast batch dimensions and squash
         def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
             return t.broadcast_to(batch_dims + invariant_dims).flatten(
@@ -520,6 +522,8 @@ def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
         dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])
 
         # NOTE: out is contiguous, so batch_broadcast_and_squash will create a view
+        # out is rebound to that view below; keep a reference to the original tensor to return.
+        out_backup = out
         out = batch_broadcast_and_squash(out, batch_dims_broadcasted, out.shape[-2:])
 
         # NOTE: this function will ALWAYS create a view
@@ -570,10 +574,7 @@ def valid_grid_dim(g, mg):
 
         kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)
 
-        # Block dims need to rejoin with the corresponding block dimensions
-        # prior to reshape so that blocks do not end up being transposed.
-        # NB: type checker is not able to narrow Optional[Tensor] to tensor by this point
-        return out.transpose(-3, -2).reshape(original_batch_dims_broadcasted + (m, n))  # type: ignore[union-attr]
+        return out_backup
 else:
     bsr_dense_mm = None  # type: ignore[assignment]
 
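For context, a minimal usage sketch of the fixed semantics. Assumptions not taken from the PR: a CUDA build with Triton available (so bsr_dense_mm is not None), and the 64x64 shapes, float16 dtype, and 16x16 block size, which are illustrative only.

import torch
from torch.sparse._triton_ops import bsr_dense_mm

# A 64x64 BSR operand with 16x16 blocks, and a dense rhs.
a = torch.randn(64, 64, dtype=torch.float16, device="cuda")
bsr = a.to_sparse_bsr(blocksize=(16, 16))
dense = torch.randn(64, 64, dtype=torch.float16, device="cuda")

# Preallocate the result. With this fix the kernel first zeroes a supplied
# out, writes the product into it in place, and is expected to return that
# same tensor, including on the short-circuit path where bsr has no
# non-zero blocks.
out = torch.empty(64, 64, dtype=torch.float16, device="cuda")
res = bsr_dense_mm(bsr, dense, out=out)
assert res is out

Before this change the assert would fail: the nnz == 0 path returned a fresh tensor, and on the main path the final transpose/reshape generally returned a tensor detached from the storage of out.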