6 changes: 3 additions & 3 deletions docs/helion_puzzles.rst
@@ -433,7 +433,7 @@ A scalar version of FlashAttention.
scores = q_tile[:, None] * k_tile[None, :]

# Find max for numerical stability
-batch_max = torch.max(scores, dim=1)[0]
+batch_max = torch.amax(scores, dim=1)
new_max = torch.maximum(max_val, batch_max)

# Scale old accumulations
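
The switch to ``torch.amax`` works because ``torch.max(scores, dim=1)`` returns a ``(values, indices)`` namedtuple, which is why the old line needed the trailing ``[0]``; ``torch.amax`` computes only the values. A minimal eager-mode sketch of the equivalence (names are illustrative, not from the PR):

.. code-block:: python

    import torch

    scores = torch.randn(4, 8)

    # torch.max along a dim returns (values, indices); [0] keeps the values.
    via_max = torch.max(scores, dim=1)[0]

    # torch.amax returns just the values, skipping the index computation.
    via_amax = torch.amax(scores, dim=1)

    assert torch.equal(via_max, via_amax)
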
@@ -468,7 +468,7 @@ A batched 2D convolution.
.. code-block:: python

def conv2d_spec(x: Float32[Tensor, "4 8 8"], k: Float32[Tensor, "4 4"]) -> Float32[Tensor, "4 8 8"]:
-z = torch.zeros(4, 8, 8)
+z = torch.zeros(4, 8, 8, device=x.device)
x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0)
for i in range(8):
for j in range(8):
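
The added ``device=x.device`` matters because ``torch.zeros`` allocates on the CPU by default, so a reference spec built that way would hit a device mismatch as soon as ``x`` lives on a GPU. A minimal sketch of the fix (hypothetical snippet, not part of the PR):

.. code-block:: python

    import torch

    # Fall back to CPU when no GPU is present; the point is only that the
    # reference buffer follows x's device rather than defaulting to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4, 8, 8, device=device)

    z = torch.zeros(4, 8, 8, device=x.device)
    assert z.device == x.device
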
@@ -495,7 +495,7 @@
# Extract the patch
patch = x_padded[tile_batch, i:i+kh, j:j+kw]
# Apply the kernel
-out[tile_batch, i, j] = (k[tile_batch] * patch).sum([1, 2])
+out[tile_batch, i, j] = (k[tile_batch,:,:] * patch).sum([1, 2])

return out
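
In eager PyTorch, ``k[tile_batch]`` and ``k[tile_batch, :, :]`` select the same data; the explicit trailing slices just spell out that every kernel row and column is kept for each indexed element (and presumably keep the indexing unambiguous for Helion's tile-based lowering, though that motivation is an assumption, not stated in the diff). A quick equivalence check on a stand-in tensor:

.. code-block:: python

    import torch

    t = torch.randn(4, 4, 4)    # stand-in tensor; shapes are illustrative
    idx = torch.tensor([0, 2])  # stand-in for a tile of batch indices

    # Both forms select the whole 4x4 slice for each chosen leading index.
    assert torch.equal(t[idx], t[idx, :, :])
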
