diff --git a/docs/helion_puzzles.rst b/docs/helion_puzzles.rst
index bd463a7ab..3045df48c 100644
--- a/docs/helion_puzzles.rst
+++ b/docs/helion_puzzles.rst
@@ -433,7 +433,7 @@ A scalar version of FlashAttention.
         scores = q_tile[:, None] * k_tile[None, :]
 
         # Find max for numerical stability
-        batch_max = torch.max(scores, dim=1)[0]
+        batch_max = torch.amax(scores, dim=1)
         new_max = torch.maximum(max_val, batch_max)
 
         # Scale old accumulations
@@ -468,7 +468,7 @@ A batched 2D convolution.
 .. code-block:: python
 
     def conv2d_spec(x: Float32[Tensor, "4 8 8"], k: Float32[Tensor, "4 4"]) -> Float32[Tensor, "4 8 8"]:
-        z = torch.zeros(4, 8, 8)
+        z = torch.zeros(4, 8, 8, device=x.device)
         x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0)
         for i in range(8):
             for j in range(8):
@@ -495,7 +495,7 @@ A batched 2D convolution.
                 # Extract the patch
                 patch = x_padded[tile_batch, i:i+kh, j:j+kw]
                 # Apply the kernel
-                out[tile_batch, i, j] = (k[tile_batch] * patch).sum([1, 2])
+                out[tile_batch, i, j] = (k[tile_batch,:,:] * patch).sum([1, 2])
 
         return out
 
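
For reference, a minimal standalone sketch (the shape and variable name here are illustrative, not taken from the patch) of why ``torch.amax(scores, dim=1)`` is a drop-in replacement for ``torch.max(scores, dim=1)[0]``: ``torch.max`` with a ``dim`` argument returns a ``(values, indices)`` pair, while ``torch.amax`` returns only the reduced values.

.. code-block:: python

    import torch

    scores = torch.randn(4, 8)  # illustrative shape, not from the patch
    # torch.max(..., dim=...) returns (values, indices); [0] selects the values.
    # torch.amax(..., dim=...) returns only the values, so the two agree.
    assert torch.equal(torch.amax(scores, dim=1), torch.max(scores, dim=1)[0])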