# UNSLOTH CHALLENGE 5 SUBMISSION - Memory Efficient Backprop

## Problem statement

---
<a name="MATH"></a>
## E) Memory Efficient Backprop [Difficulty: Medium to Hard] [Max points: 10]

In LLMs, the last layer is a projection matrix used to calculate the probabilities of the next token, i.e. $\sigma(XW)$. However, if the vocabulary size is very large, say 128K, then materializing the logits causes VRAM spikes.

For example, if `bsz = 4, qlen = 4096, hd = 4096, vocab = 128K`, then the logits in bfloat16 take 4GB of memory. In the worst case, we might even need to upcast the logits to float32, so 8GB is needed.
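The arithmetic behind those figures can be checked with a few lines of plain Python. This is only a sketch: the helper name `logits_bytes` is made up for illustration, and "128K" is assumed to mean 128 × 1024 = 131072 (the test code later in this notebook uses 128000, which gives slightly less). Note that `hd` does not affect the size of the logits.

```python
def logits_bytes(bsz: int, qlen: int, vocab: int, bytes_per_elem: int) -> int:
    """Bytes needed to materialize a (bsz, qlen, vocab) logits tensor."""
    return bsz * qlen * vocab * bytes_per_elem

GIB = 1024 ** 3
vocab = 128 * 1024  # assumption: "128K" read as 131072

bf16_gib = logits_bytes(4, 4096, vocab, 2) / GIB  # bfloat16 = 2 bytes per element
fp32_gib = logits_bytes(4, 4096, vocab, 4) / GIB  # float32 = 4 bytes per element
print(bf16_gib, fp32_gib)  # 4.0 8.0
```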

In Unsloth, we utilize [Apple's Cut Cross Entropy Loss](https://machinelearning.apple.com/research/cut-your-losses) to reduce VRAM usage, by allowing a Triton kernel to create the logits on the fly to calculate the cross entropy loss. But this does not generalize well to other functions.

Our goal is to generalize this ultimately, but directly creating logits on the fly will be hard. Instead, let's take a slightly less complex approach. Let's first review some stuff. We first notice that during the normal case after forming the intermediate logits for 2 batches, we then do a gather function to aggregate the intermediate results into a single column:
$$
\begin{align}
\begin{bmatrix} x_1 \\ x_2 \end{bmatrix} \times W &= \begin{bmatrix} x_1 W \\ x_2 W \end{bmatrix} \\
f \bigg( \begin{bmatrix} x_1 W \\ x_2 W \end{bmatrix} \bigg) &= \begin{pmatrix} y_1 \\ y_2 \end{pmatrix}
\end{align}
$$

So, if we can somehow skip the materialization of the intermediate logits, and just output the output of `f`, we can save a lot of VRAM!

Notice during backpropagation we can use the chain rule:
$$
\begin{align}
\frac{dL}{dX} &= \frac{dL}{dy} \frac{dy}{dX} ; \frac{dL}{dW} = \frac{dL}{dy} \frac{dy}{dW} \\
\frac{dL}{dy} &= \text{Downstream from backprop} \\
\frac{dy}{dX} &= W^T \\
\frac{dy}{dW} &= X^T \\
\frac{dL}{dX} &= \frac{dL}{dy} W^T \\
\frac{dL}{dW} &= X^T \frac{dL}{dy} \\
\end{align}
$$

If we simply compute the intermediate tensors on the fly via batches, say we do batch 1, then batch 2, we can reduce VRAM usage from 4GB to 2GB!

$$
\begin{align}
\frac{dL}{dX} &= \begin{bmatrix} \frac{dL_1}{dy_1} W^T \\ \frac{dL_2}{dy_2} W^T \end{bmatrix} \\
\frac{dL}{dW} &= \bigg( X_1^T \frac{dL_1}{dy_1} + X_2^T  \frac{dL_2}{dy_2} \bigg)
\end{align}
$$
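The two chunked formulas can be sanity-checked numerically. The sketch below uses tiny made-up shapes, float64 on CPU, and a random stand-in for the upstream gradient `dL_dy`; it follows the $y = XW$ layout of the equations, so `W` here is `(hidden, vocab)`.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 8, dtype=torch.float64)       # 4 rows, hidden size 8
W = torch.randn(8, 16, dtype=torch.float64)      # hidden 8 -> "vocab" 16
dL_dy = torch.randn(4, 16, dtype=torch.float64)  # stand-in upstream gradient

# Full-batch gradients from the formulas above.
dX_full = dL_dy @ W.T
dW_full = X.T @ dL_dy

# The same gradients, computed two rows at a time.
dX_chunks, dW_sum = [], torch.zeros_like(W)
for X_c, g_c in zip(X.split(2), dL_dy.split(2)):
    dX_chunks.append(g_c @ W.T)  # rows of dL/dX are independent, so they stack
    dW_sum += X_c.T @ g_c        # dL/dW accumulates across chunks
dX_chunked = torch.cat(dX_chunks)

print(torch.allclose(dX_full, dX_chunked), torch.allclose(dW_full, dW_sum))
```

Only one chunk of `dL_dy`-sized data is needed at a time, which is where the memory saving comes from.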

1. Your goal is to write a `torch.autograd.Function` with a `forward` and `backward` pass showcasing this memory efficient implementation.

2. You must NOT hard code the derivatives - move the transformation from the logits / intermediate tensors to a smaller tensor into a separate function, so that `autograd` can pass through it.

3. As a hint, look at `torch.utils.checkpoint` at https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py. Also, don't forget about the upstream gradients! We need to multiply them with the current gradients!

4. Make the Cross Entropy Loss work. You must show other functions working as well.
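To make the checkpointing hint concrete, here is a minimal sketch of the recompute-in-backward idea using `torch.utils.checkpoint.checkpoint` directly. It is not the custom `torch.autograd.Function` the task asks for; the helper names `chunk_loss` and `checkpointed_ce` and the tiny shapes are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def chunk_loss(x, weight, labels):
    # Logits exist only inside this function; reduction="sum" lets chunks add up.
    return F.cross_entropy(F.linear(x, weight).float(), labels, reduction="sum")

def checkpointed_ce(x, weight, labels, n_chunks=2):
    total = x.new_zeros((), dtype=torch.float32)
    for x_c, y_c in zip(x.chunk(n_chunks), labels.chunk(n_chunks)):
        # The chunk's logits are not kept for backward; they are recomputed then.
        total = total + checkpoint(chunk_loss, x_c, weight, y_c, use_reentrant=False)
    return total / labels.numel()

torch.manual_seed(0)
x = torch.randn(6, 8, requires_grad=True)   # (tokens, hidden)
w = torch.randn(16, 8, requires_grad=True)  # (vocab, hidden)
y = torch.randint(0, 16, (6,))

loss = checkpointed_ce(x, w, y)
loss.backward()

# Reference: ordinary mean cross entropy over all tokens at once.
ref = F.cross_entropy(F.linear(x.detach(), w.detach()), y)
print(torch.allclose(loss, ref, atol=1e-5))
```

Autograd handles the upstream gradient automatically here; in a hand-written `backward` it has to be multiplied in explicitly.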

## Evaluation parameters

## Marking Criteria for E) Max points = 10
```python
if attemped_E:
    E_score = 0
    if VRAM_50_percent_reduction: E_score += 2
    if remove_float32_upcast: E_score = 0
    if show_ce_loss_works: E_score += 1
    if show_other_functions_work: E_score += 1
    if hardcoded_gradients: E_score = 0
    if allows_dynamic_chunk_sizes: E_score += 1
    if llama_1B_training_loss_matches: E_score += 1
    else: E_score = 0
    if GRPO_memory_efficient_linear_works: E_score += 4
    final_score += E_score
else:
    final_score += 0
```

Let's start by installing the libraries needed for this.

```python
!pip install torch
```



All you need is torch and you are good to go :)

Starting out with a basic, working chunked backprop:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def transformation_function(batch, weight, labels):
    """
    Compute the logits and then the cross entropy loss in sum-reduction mode.
    Note: batch is expected to be (B, S, D) and labels (B, S).
    """
    # Compute logits: shape (B, S, vocab), upcast to float32 for the loss
    x = F.linear(batch, weight).float()
    loss_fct = nn.CrossEntropyLoss(reduction="sum")
    # Flatten so that loss_fct sees (B*S, vocab) and (B*S,).
    # reshape (not view): a label chunk sliced along the sequence can be non-contiguous.
    loss = loss_fct(x.view(-1, x.size(-1)), labels.reshape(-1))
    return loss

class MemoryEfficientLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, weight, labels, forward_function, batch_chunk_size, seq_chunk_size):
        """
        X: Tensor of shape (B, S, D)
        weight: Projection weight of shape (vocab, D)
        labels: Tensor of shape (B*S,) - will be reshaped to (B, S)
        forward_function: a function that computes the (sum-reduced) loss for a given chunk.
        batch_chunk_size: how many examples in the batch to process at once.
        seq_chunk_size: how many tokens (sequence length) to process per chunk.
        """
        # Save for backward
        ctx.save_for_backward(X, weight, labels)
        ctx.batch_chunk_size = batch_chunk_size
        ctx.seq_chunk_size = seq_chunk_size
        ctx.forward_function = forward_function

        total_loss = 0.0
        total_tokens = 0

        B, S, _ = X.shape
        # Reshape labels into (B, S) for easier chunking
        labels_reshaped = labels.view(B, S)
        for i in range(0, B, batch_chunk_size):
            X_batch = X[i:i+batch_chunk_size]
            labels_batch = labels_reshaped[i:i+batch_chunk_size]
            for j in range(0, S, seq_chunk_size):
                X_chunk = X_batch[:, j:j+seq_chunk_size]
                labels_chunk = labels_batch[:, j:j+seq_chunk_size]
                # Compute chunk loss (using sum reduction); only this chunk's logits exist
                chunk_loss = forward_function(X_chunk, weight, labels_chunk)
                total_loss += chunk_loss
                total_tokens += X_chunk.size(0) * X_chunk.size(1)

        # Average the total loss over tokens.
        if total_tokens == 0:
            total_loss_tensor = torch.tensor(0.0, device=X.device)
        else:
            total_loss_tensor = total_loss / total_tokens

        ctx.total_tokens = total_tokens
        return total_loss_tensor

    @staticmethod
    def backward(ctx, d_loss):
        X, W, labels = ctx.saved_tensors
        batch_chunk_size = ctx.batch_chunk_size
        seq_chunk_size = ctx.seq_chunk_size
        forward_function = ctx.forward_function
        total_tokens = ctx.total_tokens
        B, S, _ = X.shape

        # Allocate gradients only for the inputs that need them
        d_X = torch.zeros_like(X) if ctx.needs_input_grad[0] else None
        d_W = torch.zeros_like(W) if ctx.needs_input_grad[1] else None
        # Detached view of W (no copy) so recomputation works even if W is frozen
        W_d = W.detach().requires_grad_(True)

        # Reshape labels to (B, S)
        labels_reshaped = labels.view(B, S)

        # Loop over both dimensions
        for i in range(0, B, batch_chunk_size):
            X_batch = X[i:i+batch_chunk_size]
            labels_batch = labels_reshaped[i:i+batch_chunk_size]
            for j in range(0, S, seq_chunk_size):
                # Detach the chunk and set requires_grad for recomputation
                X_chunk = X_batch[:, j:j+seq_chunk_size].detach().requires_grad_(True)
                labels_chunk = labels_batch[:, j:j+seq_chunk_size]
                with torch.enable_grad():
                    # Recompute the loss for the chunk
                    chunk_loss = forward_function(X_chunk, W_d, labels_chunk)
                    # Scale the loss contribution as in the forward pass
                    local_loss = chunk_loss / total_tokens
                    # Let autograd derive the gradients; without retain_graph the
                    # chunk's graph (and its logits) is freed right after this call.
                    gX, gW = torch.autograd.grad(local_loss, (X_chunk, W_d))
                # Multiply by the upstream gradient d_loss
                if d_X is not None:
                    d_X[i:i+batch_chunk_size, j:j+seq_chunk_size] += gX * d_loss
                if d_W is not None:
                    d_W += gW * d_loss

        # Return gradients for all six inputs: (d_X, d_W, None, None, None, None)
        return d_X, d_W, None, None, None, None

# Example usage and test
if __name__ == "__main__":
    device = 'cuda'
    # Test parameters as given:
    bsz, qlen, hd, vocab = 4, 4096, 4096, 128000
    X = torch.randn(bsz, qlen, hd, dtype=torch.bfloat16, device=device, requires_grad=True)
    W = torch.randn(vocab, hd, dtype=torch.bfloat16, device=device, requires_grad=True)
    labels = torch.randint(0, vocab, (bsz * qlen,), device=device)

    # Call using positional arguments (do not use keyword arguments)
    loss = MemoryEfficientLinear.apply(X, W, labels, transformation_function, 1, 1024)
    loss.backward()

    print("Loss:", loss.item())
    print("Gradients for X computed:", X.grad is not None)
    print("Gradients for W computed:", W.grad is not None)
```

```
Loss: 283.1170959472656
Gradients for X computed: True
Gradients for W computed: True
```
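The run above only shows that gradients exist, not that they are right. One way to check correctness, and to try a second loss function and several chunk sizes, is to compare against the ordinary full-logit computation on small float64 tensors on CPU. The sketch below is self-contained, so it uses a condensed, hypothetical stand-in (`ChunkedLinearLoss`, chunking over a flattened token dimension only) rather than `MemoryEfficientLinear` itself; `smoothed_ce_sum` and `max_grad_error` are likewise illustrative names.

```python
import torch
import torch.nn.functional as F

class ChunkedLinearLoss(torch.autograd.Function):
    """Condensed, hypothetical stand-in for MemoryEfficientLinear (token chunks only)."""

    @staticmethod
    def forward(ctx, X, W, labels, fn, chunk):
        ctx.save_for_backward(X, W, labels)
        ctx.fn, ctx.chunk = fn, chunk
        total = sum(fn(x, W, y) for x, y in zip(X.split(chunk), labels.split(chunk)))
        return total / labels.numel()

    @staticmethod
    def backward(ctx, d_loss):
        X, W, labels = ctx.saved_tensors
        dX, dW = torch.zeros_like(X), torch.zeros_like(W)
        scale = d_loss / labels.numel()  # upstream gradient times the 1/N of the mean
        for i, (x, y) in enumerate(zip(X.split(ctx.chunk), labels.split(ctx.chunk))):
            x = x.detach().requires_grad_(True)
            w = W.detach().requires_grad_(True)
            with torch.enable_grad():
                gx, gw = torch.autograd.grad(ctx.fn(x, w, y), (x, w))
            dX[i * ctx.chunk:(i + 1) * ctx.chunk] = gx * scale
            dW += gw * scale
        return dX, dW, None, None, None

def ce_sum(x, w, y):
    return F.cross_entropy(F.linear(x, w), y, reduction="sum")

def smoothed_ce_sum(x, w, y):
    # A second transformation: label-smoothed cross entropy, still sum-reduced.
    return F.cross_entropy(F.linear(x, w), y, reduction="sum", label_smoothing=0.1)

def max_grad_error(fn, chunk):
    torch.manual_seed(0)
    X = torch.randn(10, 8, dtype=torch.float64, requires_grad=True)
    W = torch.randn(16, 8, dtype=torch.float64, requires_grad=True)
    y = torch.randint(0, 16, (10,))
    ChunkedLinearLoss.apply(X, W, y, fn, chunk).backward()
    gX, gW = X.grad.clone(), W.grad.clone()
    X.grad = W.grad = None
    (fn(X, W, y) / y.numel()).backward()  # reference: full logits, plain autograd
    return max((gX - X.grad).abs().max().item(), (gW - W.grad).abs().max().item())

for fn in (ce_sum, smoothed_ce_sum):
    for chunk in (3, 5, 10):  # 3 does not divide 10, so the last chunk is ragged
        print(fn.__name__, chunk, max_grad_error(fn, chunk))
```

The same comparison could be pointed at `MemoryEfficientLinear` by reshaping the inputs to `(B, S, D)`; in bfloat16 the tolerance would need to be much looser than in float64.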
