Misaligned address in JSD kernel #755

@yf225

Description

Source: #733 (comment)

Repro:

import torch
import helion
import helion.language as hl

@helion.kernel(
    config=helion.Config(
        block_sizes=[4, 256],
        indexing="tensor_descriptor",
        num_stages=4,
        num_warps=4,
        pid_type="flat",
        range_flattens=[None, False],
        range_multi_buffers=[None, False],
        range_num_stages=[0, 4],
        range_unroll_factors=[0, 0],
        range_warp_specializes=[],
    ),
    static_shapes=True,
)
def jsd_forward_kernel(
    _input: torch.Tensor,
    target: torch.Tensor,
    shift_labels: torch.Tensor | None = None,
    beta: float = 0.5,
    ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor]:
    BT, V = _input.shape
    assert target.shape == _input.shape, (
        f"Shape mismatch: {target.shape} != {_input.shape}"
    )
    block_size_n = hl.register_block_size(V)
    block_size_m = hl.register_block_size(BT)

    # Create output tensor for accumulating loss
    loss = torch.zeros([BT], dtype=torch.float32, device=_input.device)
    dX = torch.empty_like(loss)

    one_minus_beta = 1 - beta

    # Count non-ignored elements
    n_non_ignore = float(BT)
    if shift_labels is not None:
        n_non_ignore = float((shift_labels != ignore_index).sum().item())
        if n_non_ignore == 0:
            return torch.zeros(
                [], dtype=_input.dtype, device=_input.device
            ), torch.zeros_like(_input)

    # Process each sequence position
    for tile_bt in hl.tile(BT, block_size=block_size_m):
        # Check for label masking
        if shift_labels is not None:
            if shift_labels[tile_bt] == ignore_index:
                for tile_X in hl.tile(V):
                    dX[tile_bt, tile_X] = 0.0
                continue
        intermediate_loss = hl.zeros([tile_bt, block_size_n], dtype=torch.float32)
        intermediate_dX = hl.zeros([tile_bt, block_size_n], dtype=_input.dtype)
        for tile_v in hl.tile(V, block_size=block_size_n):
            # Load log probabilities and convert to float32
            X = _input[tile_bt, tile_v]
            Y = target[tile_bt, tile_v]

            if beta == 0.0:  # Forward KL: KL(P || Q)
                Y_max = torch.amax(Y, dim=0)
                Y_shift = Y - Y_max
                Y_prob = torch.exp(Y_shift) * torch.exp(Y_max)  # Compensate for the shift
                intermediate_loss += Y_prob * (Y - X)
                intermediate_dX += -Y_prob
            elif beta == 1.0:  # Reverse KL: KL(Q || P)
                X_max = torch.amax(X, dim=0)
                X_shift = X - X_max
                X_prob = torch.exp(X_shift) * torch.exp(X_max)  # Compensate for the shift
                intermediate_loss += X_prob * (X - Y)
                intermediate_dX += intermediate_loss + X_prob
            else:  # General JSD: beta*KL(P||M) + (1-beta)*KL(Q||M)
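                # Generalized JSD with mixture M = beta * P + (1 - beta) * Q:
                #   loss = beta * KL(P || M) + (1 - beta) * KL(Q || M)
                # where X = log Q (_input) and Y = log P (target), both log-softmaxed.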
                Q = torch.exp(X)  # recover Q from its log-probabilities
                P = torch.exp(Y)  # recover P from its log-probabilities

                beta_P = beta * P
                one_minus_beta_Q = one_minus_beta * Q
                M = beta_P + one_minus_beta_Q
                log_M = torch.log(M)
                x_minus_log_m = X - log_M
                kl_q_m = one_minus_beta_Q * x_minus_log_m

                intermediate_loss += beta_P * (Y - log_M) + kl_q_m
                intermediate_dX += kl_q_m

        # Accumulate over vocabulary dimension
        scale = 1.0 / n_non_ignore
        loss[tile_bt] = torch.sum(intermediate_loss * scale, dim=1)
        dX[tile_bt] = torch.sum(intermediate_dX * scale, dim=1)

    # Final reduction over sequence positions runs on the host to match liger_kernel
    # (normalization by n_non_ignore already happened via `scale` in the loop)
    final_loss = torch.sum(loss)
    return final_loss, dX

DEVICE = "cuda"  # assumed; the Triton kernel needs a CUDA device

vocab = 512
batch = 512
log_q = torch.randn(batch, vocab, device=DEVICE).log_softmax(dim=-1)
log_p = torch.randn(batch, vocab, device=DEVICE).log_softmax(dim=-1)

# With this config, the generated Triton kernel crashes with a misaligned
# address error at launch. An ExpectedFailure entry records the regression
# until the alignment fix lands.
jsd_forward_kernel(log_q, log_p)
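To localize the fault, running the repro under NVIDIA's Compute Sanitizer (assuming a CUDA build; repro.py here stands for whatever file holds the snippet above) reports the exact misaligned load/store in the generated Triton kernel, and CUDA_LAUNCH_BLOCKING=1 makes the failing launch synchronous so the Python traceback points at the offending call:

CUDA_LAUNCH_BLOCKING=1 compute-sanitizer --tool memcheck python repro.py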

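For checking numerics once the crash is fixed, here is a minimal eager-mode sketch of the same forward pass (reusing the imports above; it assumes the general-beta path with shift_labels=None, and jsd_forward_ref is a made-up name, not part of the repro):

def jsd_forward_ref(
    _input: torch.Tensor,  # log Q, shape [BT, V]
    target: torch.Tensor,  # log P, shape [BT, V]
    beta: float = 0.5,
) -> tuple[torch.Tensor, torch.Tensor]:
    X, Y = _input.float(), target.float()
    Q, P = torch.exp(X), torch.exp(Y)
    M = beta * P + (1 - beta) * Q
    log_M = torch.log(M)
    scale = 1.0 / _input.shape[0]  # mirrors the kernel's 1 / n_non_ignore
    kl_q_m = (1 - beta) * Q * (X - log_M)
    per_pos = (beta * P * (Y - log_M) + kl_q_m).sum(dim=1) * scale
    dX = kl_q_m.sum(dim=1) * scale
    return per_pos.sum(), dX

# e.g. ref_loss, ref_dX = jsd_forward_ref(log_q, log_p)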