In [None]:
import torch
from torch import nn

In [None]:
def transformation_function(batch, linear, labels):
    """Reference (non-chunked) loss: project ``batch`` with ``linear`` and
    return the mean cross-entropy against ``labels``.

    Args:
        batch: Input activations whose last dimension matches ``linear``'s
            input features (e.g. ``[bsz, qlen, hidden]``).
        linear: Projection module, e.g. ``nn.Linear(hidden, vocab)``.
        labels: Integer class indices; after flattening they must line up
            one-to-one with the rows of the flattened logits.

    Returns:
        A 0-dim tensor: cross-entropy with ``reduction="mean"`` over the
        flattened positions.
    """
    # Project to vocabulary logits; upcast to float32 before the softmax.
    logits = linear(batch).float()
    loss_fn = nn.CrossEntropyLoss(reduction="mean")
    # Collapse every leading dimension into one row axis. ``reshape`` (unlike
    # ``view``) also accepts non-contiguous tensors, e.g. sliced/transposed labels.
    return loss_fn(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))

In [None]:
# Example from the problem statement: bsz = 4, qlen = 4096, hd = 4096, vocab = 128K.
# Below is a similar setup with a smaller hidden size and vocabulary.

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)  # reproducible example tensors

# Example input: 4 samples, ctx = 8192 tokens, hidden dimension of 2048.
# NOTE: `input` shadows the Python builtin; kept because later cells use this name.
input = torch.randn(4, 8192, 2048, device=device, requires_grad=True)

# Linear (dense) layer mapping 2048 hidden dimensions to a vocabulary of 4096.
forward = nn.Linear(2048, 4096).to(device)

# Target token ids for the 4 samples in the batch: one id per token, ctx = 8192 tokens each.
labels = torch.randint(0, 4096, (4, 8192), device=device)

In [None]:
# Toggle for a one-off run of the reference (non-chunked) loss on the example tensors.
TEST_TRANSFORMATION_FUNCTION = False

if TEST_TRANSFORMATION_FUNCTION:
    loss = transformation_function(input, forward, labels)  # reduction="mean" -> 0-dim tensor
    print(loss)

In [None]:
class MemoryEfficientLinear(torch.autograd.Function):
    """Chunked loss ``forward_function(X, linear, labels)`` that never holds the
    logits for the whole batch at once.

    Forward splits ``X``/``labels`` along dim 0, evaluates the loss chunk by
    chunk (Function.forward runs without building an autograd graph), and
    combines the chunk losses into the full-batch mean. Backward recomputes
    each chunk with autograd enabled and backpropagates it immediately, so at
    most one chunk's logits are alive at a time (the same recompute pattern as
    reentrant ``torch.utils.checkpoint``).

    Usage: ``MemoryEfficientLinear.apply(X, linear, labels, forward_function)``.

    Assumptions:
        * ``forward_function(X_chunk, linear, labels_chunk)`` returns a scalar
          *mean* loss over the chunk (as ``transformation_function`` does).
        * Every label position carries equal weight (e.g. no ``ignore_index``
          targets), so chunk means are weighted by ``labels_chunk.numel()``.
        * ``linear`` is a module, not a tensor input, so its parameter
          gradients are accumulated directly into ``.grad`` during backward
          (also when the outer call is ``torch.autograd.grad``).
    """

    # Number of pieces the batch is split into (torch.chunk may give fewer
    # for small batches, e.g. a batch of 1 yields a single chunk).
    NUM_CHUNKS = 2

    @staticmethod
    def _chunks(X, labels):
        """Pair up matching batch slices of ``X`` and ``labels``."""
        n = MemoryEfficientLinear.NUM_CHUNKS
        return zip(torch.chunk(X, n, dim=0), torch.chunk(labels, n, dim=0))

    @staticmethod
    def forward(ctx, X, linear, labels, forward_function):
        if X.shape[0] == 0 or labels.numel() == 0:
            raise ValueError("MemoryEfficientLinear needs a non-empty batch")
        if X.shape[0] != labels.shape[0]:
            raise ValueError(
                f"batch size mismatch: X has {X.shape[0]} rows, "
                f"labels has {labels.shape[0]}"
            )
        total = labels.numel()
        weighted = []
        for X_chunk, L_chunk in MemoryEfficientLinear._chunks(X, labels):
            chunk_loss = forward_function(X_chunk, linear, L_chunk)
            # Weight each chunk mean by its share of positions so uneven chunks
            # still reproduce the full-batch mean.
            weighted.append(chunk_loss * (L_chunk.numel() / total))
        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.forward_function = forward_function
        # torch.stack keeps the result on the same device as the chunk losses.
        return torch.stack(weighted).sum()

    @staticmethod
    def backward(ctx, dY):
        X, labels = ctx.saved_tensors
        linear, forward_function = ctx.linear, ctx.forward_function
        total = labels.numel()
        grad_chunks = []
        for X_chunk, L_chunk in MemoryEfficientLinear._chunks(X, labels):
            # Re-run this chunk on a detached leaf so autograd yields dL/dX_chunk.
            X_leaf = X_chunk.detach().requires_grad_(True)
            with torch.enable_grad():
                chunk_loss = forward_function(X_leaf, linear, L_chunk)
            # Backprop this chunk's share of dY right away so its graph (and
            # logits) can be freed before the next chunk is recomputed. This
            # also accumulates gradients into `linear`'s parameters.
            torch.autograd.backward(chunk_loss, dY * (L_chunk.numel() / total))
            grad = X_leaf.grad
            grad_chunks.append(grad if grad is not None else torch.zeros_like(X_leaf))
        # One gradient per forward input: X, linear, labels, forward_function.
        return torch.cat(grad_chunks, dim=0), None, None, None

In [None]:
# Chunked loss on the example tensors defined above.
loss = MemoryEfficientLinear.apply(
    input,
    forward,
    labels,
    transformation_function,
)

In [None]:
# Display the loss tensor as this cell's output.
loss

In [None]:
# Backpropagate from the scalar loss (equivalent to loss.backward()).
torch.autograd.backward(loss)

In [None]:
# Display the example input tensor (torch abbreviates the repr of large tensors).
# NOTE(review): after the backward cell, `input.grad` may be what was meant to
# be inspected here — confirm.
input

In [None]:
import torch
import torch.nn.functional as F

class MemoryEfficientLinearFunction(torch.autograd.Function):
    """
    Custom autograd function computing ``Y = f(X @ W)`` in chunks along the
    batch dimension, recomputing ``Z = X @ W`` per chunk during backward
    instead of saving it.

    Only ``X`` and ``W`` are saved for backward, so neither the full ``Z`` nor
    ``f``'s intermediate activations are kept between the passes. The full
    output ``Y`` is still materialized because it is the return value.

    Args:
        X (torch.Tensor): Input tensor of shape [batch_size, sequence_length, hidden_dim]
        W (torch.Tensor): Weight matrix of shape [hidden_dim, vocab_size]
        f (callable): Transformation applied to Z (e.g. ``lambda x: x`` for logits).
            It is evaluated on each batch chunk separately and the results are
            concatenated along dim 0, so it must treat batch rows independently
            and keep the batch dimension.
        num_chunks (int): Number of chunks to split X into along the batch
            dimension (``torch.chunk`` may produce fewer for small batches).

    Returns:
        Y (torch.Tensor): ``f(X @ W)`` with the chunk results concatenated along dim 0.
    """

    @staticmethod
    def forward(ctx, X, W, f, num_chunks):
        # Ensure inputs are in the correct format
        assert X.dim() == 3, "X must be 3D: [batch_size, sequence_length, hidden_dim]"
        assert W.dim() == 2, "W must be 2D: [hidden_dim, vocab_size]"
        assert callable(f), "f must be a callable function"

        # Autograd is off inside Function.forward, so each chunk's Z is freed
        # as soon as f has consumed it.
        Y_chunks = [f(X_chunk @ W) for X_chunk in torch.chunk(X, num_chunks, dim=0)]
        Y = torch.cat(Y_chunks, dim=0)

        # Save only the inputs; Z is recomputed chunk by chunk in backward.
        ctx.save_for_backward(X, W)
        ctx.f = f
        ctx.num_chunks = num_chunks
        return Y

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass computing gradients w.r.t. X and W chunk by chunk.

        Args:
            grad_output (torch.Tensor): Gradient of loss w.r.t. Y (dL/dY)

        Returns:
            grad_X (torch.Tensor or None): dL/dX, or None if X needs no gradient
            grad_W (torch.Tensor or None): dL/dW, or None if W needs no gradient
            None, None: No gradients for f or num_chunks
        """
        X, W = ctx.saved_tensors
        f = ctx.f
        need_grad_X, need_grad_W = ctx.needs_input_grad[:2]

        # X and grad_output share the batch size, so both split identically.
        X_chunks = torch.chunk(X, ctx.num_chunks, dim=0)
        grad_output_chunks = torch.chunk(grad_output, ctx.num_chunks, dim=0)

        grad_X_chunks = []
        grad_W = torch.zeros_like(W) if need_grad_W else None

        for X_chunk, grad_Y_chunk in zip(X_chunks, grad_output_chunks):
            # Recompute Z for this chunk as a fresh leaf so autograd yields dL/dZ.
            Z_chunk = (X_chunk @ W).detach().requires_grad_(True)
            with torch.enable_grad():
                Y_chunk = f(Z_chunk)
            # No retain_graph: each chunk's graph is used exactly once and is
            # released immediately, which is the point of chunking.
            (grad_Z_chunk,) = torch.autograd.grad(
                outputs=Y_chunk,
                inputs=Z_chunk,
                grad_outputs=grad_Y_chunk,
            )

            if need_grad_X:
                grad_X_chunks.append(grad_Z_chunk @ W.t())  # dL/dX = dL/dZ @ W^T
            if need_grad_W:
                # dL/dW = X^T @ dL/dZ, summed over the batch and sequence axes
                grad_W += torch.einsum('bsh,bsv->hv', X_chunk, grad_Z_chunk)

        grad_X = torch.cat(grad_X_chunks, dim=0) if need_grad_X else None
        return grad_X, grad_W, None, None  # No gradients for f or num_chunks

# Example usage function
def test_custom_function():
    # Define problem dimensions
    batch_size, seq_len, hidden_dim, vocab_size = 4, 2, 3, 5
    num_chunks = 2
    
    # Create input tensors
    X = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float32, requires_grad=True)
    W = torch.randn(hidden_dim, vocab_size, dtype=torch.float32, requires_grad=True)
    
    # Define a transformation function (e.g., identity for logits or ReLU)
    f = lambda x: torch.relu(x)  # Example: ReLU as a nonlinearity
    
    # Apply the custom function
    Y = MemoryEfficientLinearFunction.apply(X, W, f, num_chunks)
    
    # Simulate a loss (e.g., sum of squares)
    loss = Y.pow(2).sum()
    loss.backward()
    
    # Print results
    print("Input X shape:", X.shape)
    print("Weight W shape:", W.shape)
    print("Output Y shape:", Y.shape)
    print("Gradient dL/dX shape:", X.grad.shape)
    print("Gradient dL/dW shape:", W.grad.shape)
    
    # Validate with standard computation
    Z_standard = X @ W
    Y_standard = f(Z_standard)
    loss_standard = Y_standard.pow(2).sum()
    loss_standard.backward()
    
    print("\nGradient match check:")
    print("dL/dX matches:", torch.allclose(X.grad, X.grad.clone(), atol=1e-5))
    print("dL/dW matches:", torch.allclose(W.grad, W.grad.clone(), atol=1e-5))

# Test cross-entropy loss
def test_cross_entropy():
    """Check MemoryEfficientLinearFunction + cross-entropy against plain autograd.

    Uses the identity as ``f`` (raw logits) and prints the gradient shapes and
    whether the loss and gradients match a non-chunked reference. Returns None.
    """
    batch_size, seq_len, hidden_dim, vocab_size = 4, 2, 3, 5
    num_chunks = 2

    X = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float32, requires_grad=True)
    W = torch.randn(hidden_dim, vocab_size, dtype=torch.float32, requires_grad=True)

    # Use identity function for logits
    f = lambda x: x

    # Compute output
    Y = MemoryEfficientLinearFunction.apply(X, W, f, num_chunks)

    # Simulate targets for cross-entropy
    targets = torch.randint(0, vocab_size, (batch_size, seq_len))
    loss = F.cross_entropy(Y.reshape(-1, vocab_size), targets.reshape(-1))
    loss.backward()

    print("\nCross-entropy test:")
    print("Gradient dL/dX shape:", X.grad.shape)
    print("Gradient dL/dW shape:", W.grad.shape)

    # Reference: same loss through plain autograd on fresh leaves.
    X_ref = X.detach().clone().requires_grad_(True)
    W_ref = W.detach().clone().requires_grad_(True)
    loss_ref = F.cross_entropy((X_ref @ W_ref).reshape(-1, vocab_size), targets.reshape(-1))
    loss_ref.backward()

    print("loss matches:", torch.allclose(loss, loss_ref, atol=1e-5))
    print("dL/dX matches:", torch.allclose(X.grad, X_ref.grad, atol=1e-5))
    print("dL/dW matches:", torch.allclose(W.grad, W_ref.grad, atol=1e-5))

if __name__ == "__main__":
    # Run each demo in turn, printing its banner first.
    for banner, run_demo in (
        ("Testing with ReLU:", test_custom_function),
        ("\nTesting with Cross-Entropy:", test_cross_entropy),
    ):
        print(banner)
        run_demo()

In [None]:
from torch import nn

# nn.CrossEntropyLoss expects logits of shape (N, C) and class-index targets of
# shape (N,) with values in [0, C). Here N = 3 samples and C = 5 classes.
loss = nn.CrossEntropyLoss(reduction="mean")
input = torch.randn(3, 5, requires_grad=True)
print(input)
target = torch.empty(3, dtype=torch.long).random_(5)  # one class id in [0, 5) per sample
print(target)
output = loss(input, target)
print(output)

In [None]:
from torch import nn

# Identity logits: row i scores class i with 1 and every other class with 0.
input = torch.eye(5)
print(input)
# The correct class for row i is i, i.e. targets 0..4.
target = torch.arange(5)
print(target)
# "mean" reduction averages the five per-row losses into a single scalar.
loss = nn.CrossEntropyLoss(reduction="mean")
output = loss(input, target)
print(output)

In [None]:
from torch import nn

# Identity logits with matching targets, but keep one loss value per row.
input = torch.eye(5)
print(input)
target = torch.arange(5)  # row i's correct class is i
print(target)
# "none" reduction returns the unreduced per-sample losses (shape (5,)).
loss = nn.CrossEntropyLoss(reduction="none")
output = loss(input, target)
print(output)