import tilelang
import tilelang.language as T

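# Each group of VEC_SIZE consecutive elements along the last dimension of `a` shares one absmax value.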
VEC_SIZE = 32

@tilelang.jit
def test_kernel(
    B: int,
    M: int,
    N: int,
    BLOCK_MN: int,
    BLOCK_K: int,
):
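    # The tile reshape and per-row grouping below assume BLOCK_K is a multiple of VEC_SIZE.
    assert BLOCK_K % VEC_SIZE == 0
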
    @T.prim_func
    def main(
        a: T.Buffer((B, M, N), "bfloat16"),
        amax: T.Buffer((B, M, N // VEC_SIZE), "float32"),
    ):
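        # Launch one program per (BLOCK_MN, BLOCK_K) tile of each batch entry, 128 threads each.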
        with T.Kernel(
            T.ceildiv(M, BLOCK_MN),
            T.ceildiv(N, BLOCK_K),
            B,
            threads=128
        ) as (pid_m, pid_n, pid_b):
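            # Staging buffers: the bf16 tile in shared memory, an fp32 copy viewed as
            # (groups, VEC_SIZE), and one fp32 slot per group for the reduced absmax.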
            a_shared = T.alloc_shared((BLOCK_MN, BLOCK_K), "bfloat16")
            a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
            a_max_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE,), "float32")
            offs_m = pid_m * BLOCK_MN
            offs_n = pid_n * BLOCK_K
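            # Cooperatively load the tile into shared memory; out-of-range elements are zero-filled.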
            for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
                if offs_m + i < M and offs_n + j < N:
                    a_shared[i, j] = a[pid_b, offs_m + i, offs_n + j]
                else:
                    a_shared[i, j] = T.cast(0.0, "bfloat16")
            
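            # View the tile as (groups, VEC_SIZE) and upcast to fp32 so the reduction runs in full precision.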
            a_shared_reshape = T.reshape(a_shared, (BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE))
            for i, j in T.Parallel(BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE):
                a_fp32_local[i, j] = T.cast(a_shared_reshape[i, j], "float32")
            
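            # Row-wise absolute-maximum reduction: one fp32 value per VEC_SIZE-wide group.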
            T.reduce_absmax(a_fp32_local, a_max_local, dim=-1, clear=True)
            
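            # Store the per-group maxima, guarding against out-of-range rows and groups.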
            offs_n_amax = pid_n * (BLOCK_K // VEC_SIZE)
            for i, j in T.Parallel(BLOCK_MN, BLOCK_K // VEC_SIZE):
                if offs_m + i < M and offs_n_amax + j < (N // VEC_SIZE):
                    amax[pid_b, offs_m + i, offs_n_amax + j] = a_max_local[i * (BLOCK_K // VEC_SIZE) + j]
    return main

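# Smoke test: pad a tiny (1, 1, 33) input with zeros up to one full (32, 64) tile and run the kernel.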
import torch
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

B, M_orig, N_orig = 1, 1, 33
M_padded, N_padded = 32, 64

a_orig = torch.randn(B, M_orig, N_orig, dtype=torch.bfloat16).cuda()
a_padded = torch.zeros(B, M_padded, N_padded, dtype=torch.bfloat16).cuda()
a_padded[:, :M_orig, :N_orig] = a_orig

BLOCK_MN = 32
BLOCK_K = 64
kernel = test_kernel(B, M_padded, N_padded, BLOCK_MN, BLOCK_K)
amax_shape = (B, M_padded, N_padded // VEC_SIZE)
amax = torch.empty(amax_shape, dtype=torch.float32).cuda()

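# Dump the generated kernel source for inspection.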
code = kernel.get_kernel_source()
print(code, flush=True)

kernel(a_padded, amax)

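# Print the last original input element and the computed per-group absolute maxima.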
print(a_padded[0, 0, 32], flush=True)
print(amax, flush=True)
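
# A minimal reference check (a sketch, assuming the kernel groups VEC_SIZE consecutive
# elements along N): compute per-group absolute maxima with torch and compare.
ref = a_padded.float().abs().reshape(B, M_padded, N_padded // VEC_SIZE, VEC_SIZE).amax(dim=-1)
print(torch.allclose(amax, ref), flush=True)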