In [None]:
import torch
from torch import nn

In [None]:
# Two independent leaf tensors; only `a` takes part in the computation below.
a = torch.rand(10, requires_grad=True)
b = torch.rand(10, requires_grad=True)

# output = sum(2 * a), so d(output)/da is 2 for every element of `a`.
output = torch.sum(2 * a)
torch.autograd.grad(outputs=output, inputs=(a,))
# Requesting b's gradient as well raises a RuntimeError, because `b` never
# enters the graph of `output` (unless allow_unused=True is passed):
# torch.autograd.grad(output, (a, b))

In [None]:
def transformation_function(batch, linear, labels):
    """Project `batch` through `linear` and reduce to a mean cross-entropy loss.

    Args:
        batch: Input tensor passed to `linear`; must have `requires_grad` set.
        linear: Callable (e.g. an `nn.Linear`) producing logits whose last
            dimension indexes the classes.
        labels: Integer class indices, one per logit row once all leading
            dimensions are flattened.

    Returns:
        A scalar tensor: the cross-entropy averaged over all flattened rows.

    Raises:
        AssertionError: If `batch`, or the output of `linear`, does not
            require grad (for example when called under `torch.no_grad()`).
    """
    assert batch.requires_grad, "Batch lacks requires_grad"
    # Cast the logits to float32 before computing the loss.
    logits = linear(batch).float()
    assert logits.requires_grad, "x lacks requires_grad after linear"
    # reshape (rather than view) also accepts non-contiguous tensors, e.g.
    # transposed or strided label slices, and is a no-copy view otherwise.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = labels.reshape(-1)
    # Down projection to a single scalar.
    return nn.CrossEntropyLoss(reduction="mean")(flat_logits, flat_labels)

In [None]:
# The example in the task statement was given as:
#   bsz = 4, qlen = 4096, hd = 4096, vocab = 128K
# This notebook uses a similarly shaped but smaller problem.
BATCH_SIZE = 4      # samples per batch
SEQ_LEN = 8192      # context length (tokens per sample)
HIDDEN_DIM = 2048   # hidden dimension fed into the linear layer
VOCAB_SIZE = 4096   # number of classes produced by the linear layer
SEED = 0            # fixed seed so Restart & Run All reproduces the same numbers

# Fall back to the CPU when no GPU is present so the notebook still runs end to end.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)

# example input: BATCH_SIZE samples, SEQ_LEN tokens each, HIDDEN_DIM features per token
input = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, device=device, requires_grad=True)

# linear (dense) layer mapping HIDDEN_DIM features to VOCAB_SIZE logits
linear = nn.Linear(HIDDEN_DIM, VOCAB_SIZE).to(device)

# target labels: one class index per token for each of the BATCH_SIZE samples
labels = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)

In [None]:
# Flip to False to skip the full-batch reference computation.
TEST_TRANSFORMATION_FUNCTION = True

if TEST_TRANSFORMATION_FUNCTION:
    # Reference value: the loss over the whole batch, a single scalar.
    loss = transformation_function(batch=input, linear=linear, labels=labels)
    print(loss)

In [None]:
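# Scratch inspection lines, kept commented out. NOTE: `forward` is not defined
# in the visible notebook; presumably these referred to `linear` (confirm).
# forward.weight.data
# forward.bias.data
# labels.requires_grad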

In [None]:
# Manually split the batch and its labels into two pieces along dim 0.
X = input
X0, X1 = X.chunk(2, dim=0)
labels_0, labels_1 = labels.chunk(2, dim=0)

In [None]:
# Loss of the first chunk on its own.
Z0 = transformation_function(batch=X0, linear=linear, labels=labels_0)

In [None]:
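# To be very clear about the terminology used here:
# X is the input to the memory-efficient linear function.
# Y is the output of the linear layer, Y = X @ W.T + b, where W is the weight
#   matrix and b is the bias of the linear layer.
# Z is f(Y), where f is the reduction applied after the linear layer; in this
#   notebook the linear layer and f are both evaluated inside `forward_function`.
# The output is expected to be a single scalar value.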

class MemoryEfficientLinear(torch.autograd.Function):
    """Evaluate `forward_function(X, linear, labels)` one batch chunk at a time.

    The batch is split into (up to) two chunks along dim 0. In `forward` each
    chunk's loss is computed and its autograd graph is dropped immediately, so
    only one chunk's intermediates are alive at a time. In `backward` each
    chunk is recomputed and differentiated on its own.

    Usage: `loss = MemoryEfficientLinear.apply(X, linear, labels, forward_function)`

    Assumptions and caveats:
      * `forward_function(x, linear, labels)` returns a scalar that is a mean
        over the samples it is given; chunk losses are combined with weights
        proportional to chunk size, which for two equal halves is (Z0 + Z1) / 2.
      * `linear` is passed as a module, not as tensor inputs, so gradients of
        its parameters cannot be returned from `backward`; they are accumulated
        directly into each parameter's `.grad` instead. They are therefore
        visible after `loss.backward()` but not to `torch.autograd.grad`.
      * `backward` only runs when `X` requires grad.
      * `forward_function` is executed twice per chunk (forward and backward);
        if it is stochastic the two evaluations may differ.
      * Double backward is not supported (`once_differentiable`).
    """

    @staticmethod
    def _weighted_chunks(X, labels):
        """Split X and labels into aligned chunks along dim 0.

        Both tensors go through `torch.chunk`, so they are always split at the
        same positions, including odd batch sizes and a batch of one.

        Returns a list of `(x_chunk, labels_chunk, weight)` where `weight` is
        the fraction of the batch held by that chunk.
        """
        n_batch = X.shape[0]
        x_chunks = torch.chunk(X, chunks=2, dim=0)
        label_chunks = torch.chunk(labels, chunks=2, dim=0)
        return [
            (x_chunk, labels_chunk, x_chunk.shape[0] / n_batch)
            for x_chunk, labels_chunk in zip(x_chunks, label_chunks)
        ]

    @staticmethod
    def _chunk_loss(forward_function, x_chunk, linear, labels_chunk):
        """Run `forward_function` on one chunk with autograd recording enabled.

        The chunk is detached into a fresh leaf so the recorded graph is local
        to this chunk. Returns `(x_leaf, loss)`.
        """
        # forward()/backward() of an autograd.Function run with grad disabled,
        # so recording has to be switched back on explicitly.
        with torch.enable_grad():
            x_leaf = x_chunk.detach().requires_grad_(True)
            loss = forward_function(x_leaf, linear, labels_chunk)
        return x_leaf, loss

    @staticmethod
    def forward(ctx, X, linear, labels, forward_function):
        """Return the size-weighted mean of the per-chunk losses.

        Args:
            ctx: Autograd context.
            X: Input batch; dim 0 is the batch dimension.
            linear: Module applied inside `forward_function`.
            labels: Targets aligned with `X` along dim 0.
            forward_function: Callable `(x, linear, labels) -> scalar tensor`.

        Raises:
            ValueError: If the batch is empty or `X` and `labels` disagree on
                the batch size.
        """
        n_batch = X.shape[0]
        if n_batch == 0:
            raise ValueError("MemoryEfficientLinear needs at least one sample")
        if labels.shape[0] != n_batch:
            raise ValueError(
                f"X has {n_batch} samples but labels has {labels.shape[0]}"
            )

        # Only the inputs are saved; chunk graphs are rebuilt in backward.
        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.forward_function = forward_function

        output = None
        for x_chunk, labels_chunk, weight in MemoryEfficientLinear._weighted_chunks(X, labels):
            # Detach inside the same statement so the chunk's graph is released
            # before the next chunk is evaluated.
            chunk_value = MemoryEfficientLinear._chunk_loss(
                forward_function, x_chunk, linear, labels_chunk
            )[1].detach()
            contribution = chunk_value * weight
            output = contribution if output is None else output + contribution
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Recompute each chunk and backpropagate through it.

        Returns the gradient for `X` (or None when it is not needed) and None
        for the three non-differentiable arguments. Gradients for the trainable
        parameters of `linear` are accumulated into their `.grad` attributes.
        """
        X, labels = ctx.saved_tensors
        linear = ctx.linear
        params = [param for param in linear.parameters() if param.requires_grad]

        grad_x_chunks = []
        param_grad_sums = [None] * len(params)
        for x_chunk, labels_chunk, weight in MemoryEfficientLinear._weighted_chunks(X, labels):
            x_leaf, chunk_loss = MemoryEfficientLinear._chunk_loss(
                ctx.forward_function, x_chunk, linear, labels_chunk
            )
            # One backward pass per chunk yields the input and parameter
            # gradients together and frees the chunk's graph afterwards.
            grads = torch.autograd.grad(
                chunk_loss,
                [x_leaf, *params],
                grad_outputs=grad_output * weight,
                allow_unused=True,
            )
            grad_x = grads[0]
            grad_x_chunks.append(
                grad_x if grad_x is not None else torch.zeros_like(x_chunk)
            )
            for index, grad_param in enumerate(grads[1:]):
                if grad_param is None:
                    continue
                if param_grad_sums[index] is None:
                    param_grad_sums[index] = grad_param
                else:
                    param_grad_sums[index] = param_grad_sums[index] + grad_param

        # The parameters are not inputs of this Function, so their gradients
        # have to be accumulated by hand, the way autograd does for leaves.
        for param, grad_sum in zip(params, param_grad_sums):
            if grad_sum is None:
                continue
            if param.grad is None:
                param.grad = grad_sum
            else:
                param.grad.add_(grad_sum)

        # Assemble the full gradient w.r.t. X in the original sample order.
        grad_X = torch.cat(grad_x_chunks, dim=0) if ctx.needs_input_grad[0] else None

        # One entry per forward() argument: X, linear, labels, forward_function.
        return grad_X, None, None, None

In [None]:
# Chunked loss over the full batch. autograd.Function.apply takes its
# arguments positionally: (X, linear, labels, forward_function).
loss = MemoryEfficientLinear.apply(
    input,
    linear,
    labels,
    transformation_function,
)

In [None]:
# Display the most recently assigned `loss` (the value returned by
# MemoryEfficientLinear.apply when the cells are run in order).
loss

In [None]:
# Backpropagate from the scalar loss; when `loss` comes from
# MemoryEfficientLinear.apply this invokes MemoryEfficientLinear.backward
# and fills in `input.grad`.
loss.backward()

In [None]:
# Display the input tensor itself.
# NOTE(review): right after loss.backward(), `input.grad` is presumably what
# was meant to be inspected here; confirm.
input

In [None]:
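# TODO: run tests to check that the loss and gradients from MemoryEfficientLinear
# match those of the plain, un-chunked transformation_function.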