diff --git a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py index 54ae7535208..6630566d3ee 100644 --- a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py +++ b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py @@ -70,7 +70,7 @@ def matmul_kernel( assert quantize_activation assert q_x_scratch is not None assert x_scale_scratch is not None - quant = out_idx == 0 + quant = (out_idx == 0) else: assert q_x_scratch is None assert x_scale_scratch is None @@ -78,8 +78,8 @@ def matmul_kernel( if save_acc: assert acc_scratch is not None - is_first_step = in_idx == 0 - is_last_step = in_idx == n_in - 1 + is_first_step = (in_idx == 0) + is_last_step = (in_idx == (n_in - 1)) else: assert acc_scratch is None is_first_step = True