diff --git a/benchmarks/run.py b/benchmarks/run.py
index 4195b16e8..3a4ead1d4 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -280,6 +280,11 @@ class RunResult:
         "examples.jagged_sum",
         "jagged_sum_tritonbench",
     ),
+    "template_attention": (
+        "tritonbench.operators.template_attention.operator",
+        "examples.template_attention",
+        "template_attention_tritonbench",
+    ),
 }
@@ -538,6 +543,12 @@ class RunResult:
         "helion_fp8_gemm_tritonbench-speedup": "helion_speedup",
         "helion_fp8_gemm_tritonbench-accuracy": "helion_accuracy",
     },
+    "template_attention": {
+        "test_with_exp2-speedup": "triton_speedup",
+        "test_with_exp2-accuracy": "triton_accuracy",
+        "helion_template_attention_tritonbench-speedup": "helion_speedup",
+        "helion_template_attention_tritonbench-accuracy": "helion_accuracy",
+    },
 }
diff --git a/examples/template_attention.py b/examples/template_attention.py
new file mode 100644
index 000000000..fc7ba647a
--- /dev/null
+++ b/examples/template_attention.py
@@ -0,0 +1,382 @@
+"""
+Template Attention Example
+=========================
+
+This code implements a templated attention kernel using Helion that mirrors the Triton template attention implementation.
+It demonstrates masked causal attention with configurable parameters and optimization features.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
+
+# %%
+# Template Attention Kernel with Causal Masking
+# --------------------------------------------
+@helion.kernel(config=helion.Config(block_sizes=[128, 64]))
+def template_attention_causal(
+    q_in: torch.Tensor,
+    k_in: torch.Tensor,
+    v_in: torch.Tensor,
+) -> tuple[torch.Tensor]:
+    """
+    Computes scaled dot-product attention with causal masking.
+
+    Based on the Triton template attention implementation, this kernel:
+    - Uses causal masking (queries can only attend to keys at or before their position)
+    - Implements flash attention algorithm for memory efficiency
+    - Uses online softmax for numerical stability
+
+    Args:
+        q_in: Query tensor of shape [batch, heads, seq_len, D]
+        k_in: Key tensor of shape [batch, heads, seq_len, D]
+        v_in: Value tensor of shape [batch, heads, seq_len, D]
+
+    Returns:
+        Output tensor of shape [batch, heads, seq_len, D]
+    """
+    M = q_in.size(-2)  # seq_len
+    N = k_in.size(-2)  # seq_len
+    assert v_in.size(-2) == N
+    D = hl.specialize(q_in.size(-1))
+    assert D == k_in.size(-1) == v_in.size(-1)
+
+    # Reshape to [batch*heads, seq_len, D]
+    q_view = q_in.reshape([-1, M, D])
+    v_view = v_in.reshape([-1, N, D])
+    k_view = k_in.reshape([-1, N, D]).transpose(1, 2)  # [batch*heads, D, seq_len]
+
+    out = torch.empty_like(q_view)
+
+    # Scale factor (no exp2 conversion yet)
+    # template_attention does not use 1.0 / math.sqrt(D)
+    qk_scale = 1.0
+
+    # Process in tiles: [batch*heads, seq_len_q]
+    block_size_m = hl.register_block_size(M)
+    block_size_n = hl.register_block_size(N)
+    for tile_m, tile_b in hl.tile(
+        [M, q_view.size(0)], block_size=[block_size_m, 1]
+    ):  # BLOCK_M = 128
+        # Initialize flash attention statistics
+        m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
+        l_i = hl.zeros([tile_b, tile_m], dtype=torch.float32)
+        acc = hl.zeros([tile_b, tile_m, D], dtype=torch.float32)
+
+        # Load query block
+        q = q_view[tile_b, tile_m, :] * qk_scale
+
+        # Iterate over key/value blocks
+        for tile_n in hl.tile(N, block_size=block_size_n):  # BLOCK_N = 64
+            # Load key and value
+            k = k_view[tile_b, :, tile_n]  # [batch, D, block_n]
+            v = v_view[tile_b, tile_n, :]  # [batch, block_n, D]
+
+            # Compute attention scores: [batch, block_m, block_n]
+            qk = torch.bmm(q, k) * qk_scale
+
+            # Apply causal mask
+            # Create indices for this tile
+            q_indices = (tile_m.begin + hl.arange(tile_m.block_size))[:, None]
+            k_indices = (tile_n.begin + hl.arange(tile_n.block_size))[None, :]
+
+            # Causal condition: query_pos >= key_pos (can attend to current and previous)
+            causal_mask = q_indices >= k_indices
+
+            # Boundary mask
+            tmp0 = hl.full([1], 1024, torch.int64)
+            tmp1 = (q_indices) <= tmp0
+            tmp2 = (k_indices) <= tmp0
+            tmp3 = tmp1 & tmp2
+            mask = tmp3 | causal_mask
+
+            # Apply mask by setting invalid positions to -inf
+            qk = torch.where(mask, qk, float("-inf"))
+
+            # Online softmax (flash attention)
+            row_max = torch.amax(qk, dim=-1)  # Row max
+            m_i_new = torch.maximum(m_i, row_max)
+
+            # Compute exponentials
+            alpha = torch.exp(m_i - m_i_new)
+            p = torch.exp(qk - m_i_new[:, :, None])
+
+            # Update statistics
+            l_i_new = l_i * alpha + torch.sum(p, dim=-1)
+
+            # Update accumulator
+            acc = acc * alpha[:, :, None]
+            p = p.to(v.dtype)
+            acc = torch.baddbmm(acc, p, v)
+
+            # Update running statistics
+            l_i = l_i_new
+            m_i = m_i_new
+
+        # Normalize and store output
+        acc = acc / l_i[:, :, None]
+        out[tile_b, tile_m, :] = acc.to(out.dtype)
+
+    return (out.view(q_in.size()),)
+
+
+# %%
+# Template Attention with exp2 optimization
+# ---------------------------------------
+@helion.kernel(config=helion.Config(block_sizes=[128, 64]))
+def template_attention_causal_exp2(
+    q_in: torch.Tensor,
+    k_in: torch.Tensor,
+    v_in: torch.Tensor,
+) -> tuple[torch.Tensor]:
+    """
+    Optimized version using exp2 for better performance, matching triton_tem_fused_with_exp2.
+
+    This version includes optimizations from the Triton implementation:
+    - Uses exp2 instead of exp for better numerical properties
+    - Scales by log2(e) to convert between bases
+    - Includes compiler optimization hints
+
+    Args:
+        q_in: Query tensor of shape [batch, heads, seq_len, D]
+        k_in: Key tensor of shape [batch, heads, seq_len, D]
+        v_in: Value tensor of shape [batch, heads, seq_len, D]
+
+    Returns:
+        Output tensor of shape [batch, heads, seq_len, D]
+    """
+    # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the
+    # change of base out of the loop
+    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
+    # is not masked out? If so, we can skip an extra safety check
+    SCORE_MOD_IS_LINEAR = False
+    ROWS_GUARANTEED_SAFE = False
+
+    M = q_in.size(-2)  # seq_len
+    N = k_in.size(-2)  # seq_len
+    assert v_in.size(-2) == N
+    D = hl.specialize(q_in.size(-1))
+    assert D == k_in.size(-1) == v_in.size(-1)
+
+    # Reshape to [batch*heads, seq_len, D]
+    q_view = q_in.reshape([-1, M, D])
+    v_view = v_in.reshape([-1, N, D])
+    k_view = k_in.reshape([-1, N, D]).transpose(1, 2)  # [batch*heads, D, seq_len]
+
+    out = torch.empty_like(q_view)
+
+    # Scale by log_2(e) for exp2 optimization
+    # template_attention does not use 1.0 / math.sqrt(D)
+    qk_scale = 1.0
+
+    # Process in tiles: [batch*heads, seq_len_q]
+    block_size_m = hl.register_block_size(M)
+    block_size_n = hl.register_block_size(N)
+    for tile_m, tile_b in hl.tile(
+        [M, q_view.size(0)], block_size=[block_size_m, 1]
+    ):  # BLOCK_M = 128
+        # Initialize flash attention statistics
+        m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
+        l_i = hl.zeros([tile_b, tile_m], dtype=torch.float32)
+        acc = hl.zeros([tile_b, tile_m, D], dtype=torch.float32)
+
+        # Load and scale query by log_2(e)
+        if SCORE_MOD_IS_LINEAR:
+            qk_scale *= 1.44269504
+        q = q_view[tile_b, tile_m, :] * qk_scale
+
+        # Iterate over key/value blocks
+        for tile_n in hl.tile(N, block_size=block_size_n):  # BLOCK_N = 64
+            # Load key and value
+            k = k_view[tile_b, :, tile_n]  # [batch, D, block_n]
+            v = v_view[tile_b, tile_n, :]  # [batch, block_n, D]
+
+            # Compute attention scores: [batch, block_m, block_n]
+            qk = torch.bmm(q, k)
+
+            # Apply causal mask
+            # Create indices for this tile
+            q_indices = (tile_m.begin + hl.arange(tile_m.block_size))[:, None]
+            k_indices = (tile_n.begin + hl.arange(tile_n.block_size))[None, :]
+
+            # Causal condition: query_pos >= key_pos (can attend to current and previous)
+            causal_mask = q_indices >= k_indices
+
+            # Boundary mask
+            tmp0 = hl.full([1], 1024, torch.int64)
+            tmp1 = (q_indices) <= tmp0
+            tmp2 = (k_indices) <= tmp0
+            tmp3 = tmp1 & tmp2
+            mask = tmp3 | causal_mask
+
+            # Apply mask by setting invalid positions to -inf
+            qk = torch.where(mask, qk, float("-inf"))
+            if not SCORE_MOD_IS_LINEAR:
+                qk *= 1.44269504
+
+            # Online softmax with exp2 (flash attention)
+            row_max = torch.amax(qk, dim=-1)  # Row max
+            m_i_new = torch.maximum(m_i, row_max)
+            masked_out_rows = m_i_new == float("-inf")
+
+            # Compute exponentials using exp2
+            alpha = torch.exp2(m_i - m_i_new)
+            p = torch.exp2(qk - m_i_new[:, :, None])
+            if not ROWS_GUARANTEED_SAFE:
+                alpha = torch.where(masked_out_rows, 0, alpha)
+                p = torch.where(masked_out_rows[:, :, None], 0, p)
+
+            # Update statistics
+            l_i_new = l_i * alpha + torch.sum(p, dim=-1)
+
+            # Update accumulator
+            acc = acc * alpha[:, :, None]
+            p = p.to(v.dtype)
+            acc = torch.baddbmm(acc, p, v)
+
+            # Update running statistics
+            l_i = l_i_new
+            m_i = m_i_new
+
+        # Normalize and store output
+        acc = acc / l_i[:, :, None]
+        out[tile_b, tile_m, :] = acc.to(out.dtype)
+
+    return (out.view(q_in.size()),)
+
+
+# %%
+# Testing Functions
+# --------------
+def ref_causal_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    dtype: torch.dtype = torch.float16,
+) -> tuple[torch.Tensor]:
+    """Reference causal attention implementation with boundary mask"""
+    # scale = 1.0 / math.sqrt(D)
+    scale = 1.0
+    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+
+    # Apply combined boundary and causal mask (matching lines 306-315 in triton_attention.py)
+    seq_len = q.size(-2)
+
+    # Create index tensors for query and key positions
+    q_indices = torch.arange(seq_len, device=q.device)[:, None]
+    k_indices = torch.arange(seq_len, device=q.device)[None, :]
+
+    # Boundary condition: both query and key must be <= 1024
+    boundary_threshold = 1024
+    tmp1 = q_indices <= boundary_threshold
+    tmp2 = k_indices <= boundary_threshold
+    tmp3 = tmp1 & tmp2
+
+    # Causal condition: query_pos >= key_pos
+    tmp4 = q_indices >= k_indices
+
+    # Combined mask: (both within boundary) OR (causal condition satisfied)
+    tmp5 = tmp3 | tmp4
+
+    # Apply mask by setting invalid positions to -inf
+    scores = scores.masked_fill(~tmp5, float("-inf"))
+
+    attn_weights = torch.softmax(scores, dim=-1).to(dtype)
+    return (torch.matmul(attn_weights, v),)
+
+
+def test_template_attention(
+    batch_size: int,
+    num_heads: int,
+    seq_len: int,
+    D: int,
+    dtype: torch.dtype = torch.float16,
+    device: torch.device | str = "cuda",
+) -> None:
+    """
+    Test the template attention kernels against reference implementations.
+
+    Args:
+        batch_size: Batch size
+        num_heads: Number of attention heads
+        seq_len: Sequence length
+        D: Head dimension
+        dtype: Data type for tensors
+        device: Device to run on
+    """
+    # Create test tensors
+    q, k, v = [
+        torch.randn((batch_size, num_heads, seq_len, D), dtype=dtype, device=device)
+        for _ in range(3)
+    ]
+
+    # Create wrappers that extract only the first element of the tuple
+    def baseline_template_attention_wrapper(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        """Wrapper that extracts first element from ref_causal_attention output"""
+        return ref_causal_attention(q, k, v)[0]
+
+    def template_attention_causal_wrapper(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        """Wrapper that extracts first element from template_attention_causal output"""
+        return template_attention_causal(q, k, v)[0]
+
+    def template_attention_causal_exp2_wrapper(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        """Wrapper that extracts first element from template_attention_causal_exp2 output"""
+        return template_attention_causal_exp2(q, k, v)[0]
+
+    print("Testing template_attention_causal:")
+    run_example(
+        template_attention_causal_wrapper,
+        baseline_template_attention_wrapper,
+        (q, k, v),
+    )
+
+    print("\nTesting template_attention_causal_exp2:")
+    run_example(
+        template_attention_causal_exp2_wrapper,
+        baseline_template_attention_wrapper,
+        (q, k, v),
+    )
+
+
+# %%
+# Tritonbench Integration
+# -----------------------
+def template_attention_tritonbench(
+    tb_op: object, p1: torch.Tensor, p2: torch.Tensor, p3: torch.Tensor
+) -> Callable:
+    return lambda: template_attention_causal_exp2(p1, p2, p3)
+
+
+# %%
+# Main Function
+# -----------
+def main() -> None:
+    """
+    Main entry point that runs template attention tests.
+    Tests with parameters similar to the Triton benchmark: batch=16, heads=16, 4096 sequence length, 64 head dimension.
+    """
+    # Test with smaller sizes first for debugging
+    test_template_attention(2, 8, 512, 64)
+
+    # Test with tritonbench parameters: batch=16, heads=16, seq_len=4096, D=64
+    test_template_attention(16, 16, 4096, 64)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_examples.expected b/test/test_examples.expected
index d7d69b397..321b3ae95 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -3811,6 +3811,139 @@ def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_swiglu, (triton.cdiv(total_elements, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, a_flat.stride(0), b_flat.stride(0), out_flat.stride(0), total_elements, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return out
+--- assertExpectedJournal(TestExamples.test_template_attention)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_template_attention_causal_exp2(q_view, k_view, v_view, out, k_view_stride_0, k_view_stride_1, k_view_stride_2, out_stride_0, out_stride_1, out_stride_2, q_view_stride_0, q_view_stride_1, q_view_stride_2, v_view_stride_0, v_view_stride_1, v_view_stride_2, M, N, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_3: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < M
+    offset_2 = pid_1
+    indices_2 = offset_2 + tl.zeros([1], tl.int32)
+    indices_4 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)
+    m_i = tl.full([_BLOCK_SIZE_2, _BLOCK_SIZE_0], float('-inf'), tl.float32)
+    l_i = tl.full([_BLOCK_SIZE_2, _BLOCK_SIZE_0], 0.0, tl.float32)
+    acc = tl.full([_BLOCK_SIZE_2, _BLOCK_SIZE_0, 64], 0.0, tl.float32)
+    load = tl.load(q_view + (indices_2[:, None, None] * q_view_stride_0 + indices_0[None, :, None] * q_view_stride_1 + indices_4[None, None, :] * q_view_stride_2), mask_0[None, :, None], other=0)
+    v_0 = 1.0
+    v_1 = load * v_0
+    for offset_3 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_3 < N
+        v_1_copy = v_1
+        m_i_copy = m_i
+        l_i_copy = l_i
+        acc_copy = acc
+        v_1_copy_0 = v_1_copy
+        m_i_copy_0 = m_i_copy
+        l_i_copy_0 = l_i_copy
+        acc_copy_0 = acc_copy
+        k = tl.load(k_view + (indices_2[:, None, None] * k_view_stride_0 + indices_4[None, :, None] * k_view_stride_1 + indices_3[None, None, :] * k_view_stride_2), mask_1[None, None, :], other=0)
+        v = tl.load(v_view + (indices_2[:, None, None] * v_view_stride_0 + indices_3[None, :, None] * v_view_stride_1 + indices_4[None, None, :] * v_view_stride_2), mask_1[None, :, None], other=0)
+        qk = tl.reshape(tl.dot(tl.reshape(tl.cast(v_1_copy_0, tl.float16), [_BLOCK_SIZE_0, 64]), tl.reshape(tl.cast(k, tl.float16), [64, _BLOCK_SIZE_1]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_1])
+        add = 1 + _BLOCK_SIZE_0
+        iota = tl.arange(0, _BLOCK_SIZE_0)
+        v_2 = tl.cast(offset_0, tl.int32)
+        v_3 = iota + v_2
+        q_indices = v_3[:, None]
+        iota_1 = tl.arange(0, _BLOCK_SIZE_1)
+        v_4 = tl.cast(offset_3, tl.int32)
+        v_5 = iota_1 + v_4
+        k_indices = v_5[None, :]
+        v_6 = q_indices >= k_indices
+        tmp0 = tl.full([1], 1024, tl.int64)
+        v_7 = tl.cast(q_indices, tl.int64)
+        v_8 = tmp0[None, :]
+        v_9 = v_7 <= v_8
+        v_10 = tl.cast(k_indices, tl.int64)
+        v_11 = tmp0[None, :]
+        v_12 = v_10 <= v_11
+        v_13 = v_9 & v_12
+        v_14 = v_13 | v_6
+        v_15 = float('-inf')
+        v_16 = v_14[None, :, :]
+        v_17 = v_15[None, None, None]
+        v_18 = tl.where(v_16, qk, v_17)
+        v_19 = 1.44269504
+        v_20 = v_18 * v_19
+        _mask_to_2 = tl.where(tl.broadcast_to(mask_0[None, :, None] & mask_1[None, None, :], [_BLOCK_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_1]), v_20, tl.full([], float('-inf'), tl.float16))
+        row_max = tl.cast(tl.max(_mask_to_2, 2), tl.float16)
+        v_21 = tl.cast(row_max, tl.float32)
+        v_22 = triton_helpers.maximum(m_i_copy_0, v_21)
+        v_23 = float('-inf')
+        v_24 = v_22 == v_23
+        v_25 = m_i_copy_0 - v_22
+        v_26 = libdevice.exp2(v_25)
+        subscript_2 = v_22[:, :, None]
+        v_27 = tl.cast(v_20, tl.float32)
+        v_28 = v_27 - subscript_2
+        v_29 = libdevice.exp2(v_28)
+        v_30 = 0.0
+        v_31 = v_30[None, None]
+        v_32 = tl.where(v_24, v_31, v_26)
+        subscript_3 = v_24[:, :, None]
+        v_33 = 0.0
+        v_34 = v_33[None, None, None]
+        v_35 = tl.where(subscript_3, v_34, v_29)
+        v_36 = l_i_copy_0 * v_32
+        _mask_to_3 = tl.where(tl.broadcast_to(mask_0[None, :, None] & mask_1[None, None, :], [_BLOCK_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_1]), v_35, tl.full([], 0, tl.float32))
+        sum_1 = tl.cast(tl.sum(_mask_to_3, 2), tl.float32)
+        v_37 = v_36 + sum_1
+        subscript_4 = v_32[:, :, None]
+        v_38 = acc_copy_0 * subscript_4
+        v_39 = tl.cast(v_35, tl.float16)
+        _mask_to_4 = tl.where(tl.broadcast_to(mask_0[None, :, None] & mask_1[None, None, :], [_BLOCK_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_1]), v_39, tl.full([], 0, tl.float16))
+        acc = tl.reshape(tl.dot(tl.reshape(tl.cast(_mask_to_4, tl.float16), [_BLOCK_SIZE_0, _BLOCK_SIZE_1]), tl.reshape(tl.cast(v, tl.float16), [_BLOCK_SIZE_1, 64]), acc=tl.reshape(v_38, [_BLOCK_SIZE_0, 64]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_2, _BLOCK_SIZE_0, 64])
+        l_i = v_37
+        m_i = v_22
+    subscript = l_i[:, :, None]
+    v_40 = acc / subscript
+    v_41 = tl.cast(v_40, tl.float16)
+    tl.store(out + (indices_2[:, None, None] * out_stride_0 + indices_0[None, :, None] * out_stride_1 + indices_4[None, None, :] * out_stride_2), v_41, mask_0[None, :, None])
+
+def template_attention_causal_exp2(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _launcher=_default_launcher):
+    """
+    Optimized version using exp2 for better performance, matching triton_tem_fused_with_exp2.
+
+    This version includes optimizations from the Triton implementation:
+    - Uses exp2 instead of exp for better numerical properties
+    - Scales by log2(e) to convert between bases
+    - Includes compiler optimization hints
+
+    Args:
+        q_in: Query tensor of shape [batch, heads, seq_len, D]
+        k_in: Key tensor of shape [batch, heads, seq_len, D]
+        v_in: Value tensor of shape [batch, heads, seq_len, D]
+
+    Returns:
+        Output tensor of shape [batch, heads, seq_len, D]
+    """
+    M = q_in.size(-2)
+    N = k_in.size(-2)
+    assert v_in.size(-2) == N
+    D = 64
+    assert D == k_in.size(-1) == v_in.size(-1)
+    q_view = q_in.reshape([-1, M, D])
+    v_view = v_in.reshape([-1, N, D])
+    k_view = k_in.reshape([-1, N, D]).transpose(1, 2)
+    out = torch.empty_like(q_view)
+    _BLOCK_SIZE_0 = 128
+    _RDIM_SIZE_3 = 64
+    _BLOCK_SIZE_1 = 64
+    _launcher(_helion_template_attention_causal_exp2, (triton.cdiv(M, _BLOCK_SIZE_0) * q_view.size(0),), q_view, k_view, v_view, out, k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), M, N, _BLOCK_SIZE_0, _RDIM_SIZE_3, 1, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return (out.view(q_in.size()),)
+
 --- assertExpectedJournal(TestExamples.test_template_via_closure0)
 from __future__ import annotations
diff --git a/test/test_examples.py b/test/test_examples.py
index 42b80963a..14f7e770f 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -1332,6 +1332,31 @@ def test_exp_bwd(self):
             )
         )
 
+    def test_template_attention(self):
+        batch_size, num_heads, seq_len, D = 16, 16, 1024, 64
+        q, k, v = [
+            torch.randn(
+                (batch_size, num_heads, seq_len, D), dtype=torch.float16, device=DEVICE
+            )
+            for _ in range(3)
+        ]
+
+        args = (q, k, v)
+
+        # Import and use the reference implementation
+        mod = import_path(EXAMPLES_DIR / "template_attention.py")
+        expected = mod.ref_causal_attention(q, k, v)
+
+        self.assertExpectedJournal(
+            check_example(
+                "template_attention",
+                args,
+                expected,
+                fn_name="template_attention_causal_exp2",
+                block_sizes=[128, 64],
+            )
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
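For reference, a minimal standalone sketch (not part of the patch; assumes only PyTorch) of the base-change identity the exp2 kernel relies on, exp(x) = exp2(x * log2(e)), with log2(e) approximately 1.44269504 as hard-coded in template_attention_causal_exp2:

# Sketch: verify exp(x) == exp2(x * log2(e)), the identity behind the exp2 path.
import torch

LOG2_E = 1.44269504  # log2(e), matching the constant used in the kernels above
x = torch.randn(1024, dtype=torch.float32)
assert torch.allclose(torch.exp(x), torch.exp2(x * LOG2_E), rtol=1e-5, atol=1e-6)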