Bug in chunking? #439

Closed
@cinjon

Description

🐛 Describe the bug

I'm a bit confused by the chunked_loss implementation in src/liger_kernel/chunked_loss/fused_linear_preference.py. Namely, it seems more like a batched loss than a chunked loss.

My expectation is that it will chunk on the tokens, à la https://pytorch.org/torchtune/0.3/generated/torchtune.modules.loss.CEWithChunkedOutputLoss.html. Instead, it chunks on the batch: it first separates the chosen from the rejected inputs, then forms the chunks along the batch dimension.

Is this intended?

Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len)
....

        len_chosen = target.shape[0] // 2
        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

Reproduce

No response

Versions

Environment Report:

Operating System: Linux-6.1.100+-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.5.0+cu124
CUDA version: 12.4
Triton version: 3.1.0
Transformers version: 4.42.3
