Description
🐛 Describe the bug
I'm a bit confused by the chunked loss implementation in src/liger_kernel/chunked_loss/fused_linear_preference.py. It seems more like a batched loss than a chunked loss.
My expectation was that it would chunk along the token dimension, a la https://pytorch.org/torchtune/0.3/generated/torchtune.modules.loss.CEWithChunkedOutputLoss.html. Instead, it chunks along the batch dimension: it first separates the chosen samples from the rejected ones, then splits each half into chunks along dim 0.
Is this intended? For reference, here is the relevant snippet (a sketch contrasting the two strategies follows it):
Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len)
....
len_chosen = target.shape[0] // 2
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
_chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
_chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
_rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
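To make the distinction concrete, here is a minimal, self-contained sketch (hypothetical shapes and a made-up CHUNK_SIZE, not actual liger_kernel code) showing what each strategy produces:

import torch

B, T, H = 8, 512, 64             # batch, seq_len, hidden_size (made up)
CHUNK_SIZE = 2                   # made up for illustration
_input = torch.randn(B, T, H)

# Batch-dim chunking, as in the snippet above: each chunk is a slice of
# whole sequences, so every chunk still spans the full seq_len.
len_chosen = B // 2
chunks = max(1, B // (2 * CHUNK_SIZE))
batch_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
print([tuple(c.shape) for c in batch_chunks])   # [(2, 512, 64), (2, 512, 64)]

# Token-dim chunking, a la torchtune's CEWithChunkedOutputLoss: each chunk
# is a slice of tokens, so only seq_len / num_chunks tokens' worth of
# logits need to be materialized at a time.
num_token_chunks = 4
token_chunks = torch.chunk(_input, chunks=num_token_chunks, dim=1)
print([tuple(c.shape) for c in token_chunks])   # [(8, 128, 64)] * 4

The memory characteristics differ: with batch-dim chunking the logits materialized per chunk still scale with seq_len, whereas token-dim chunking caps them at seq_len / num_chunks, which is what I expected "chunked" to mean here.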
Reproduce
No response
Versions
Environment Report:
Operating System: Linux-6.1.100+-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.5.0+cu124
CUDA version: 12.4
Triton version: 3.1.0
Transformers version: 4.42.3