In [None]:
import torch
from torch import nn

In [None]:
def transformation_function(batch, linear, labels):
    """Project `batch` through `linear` and reduce it to a mean cross-entropy loss.

    Args:
        batch: Input passed directly to `linear`.
        linear: Callable producing the logits; its output is cast to float32.
        labels: Target class indices; flattened to 1-D before the loss call.

    Returns:
        The cross-entropy loss averaged over all flattened positions
        (``reduction="mean"``).
    """
    # Up projection to the large (logit) space, in full precision.
    logits = linear(batch).float()

    # Collapse every leading dimension so each row is one prediction.
    flat_logits = logits.view(-1, logits.shape[-1])
    flat_targets = labels.view(-1)

    # Down projection to a single scalar.
    criterion = nn.CrossEntropyLoss(reduction="mean")
    return criterion(flat_logits, flat_targets)

In [None]:
# The example in the problem statement was given as follows:
# bsz = 4, qlen = 4096, hd = 4096, vocab = 128K
# Let's set up something similar.

# example input with 4 samples, ctx = 8192 tokens, and a hidden dimension of 2048
# NOTE(review): `input` shadows the Python builtin of the same name; later cells
# reference it, so the name is left unchanged here.
input = torch.randn(4, 8192, 2048, device="cuda", requires_grad=True)

# linear (dense) layer which accepts an input of 2048 hidden dimensions and outputs to a vocabulary of 4096
forward = nn.Linear(2048, 4096).to("cuda")

# correct labels for the 4 samples in the batch; each sample needs a label for each of its ctx = 8192 tokens
labels = torch.randint(0, 4096, (4, 8192), device="cuda")

In [None]:
# Toggle for a sanity check of the reference loss on the whole batch in a single
# call; disabled by default.
TEST_TRANSFORMATION_FUNCTION = False
if TEST_TRANSFORMATION_FUNCTION:
    # calculate the loss, which should result in a single scalar value
    loss = transformation_function(input, forward, labels)
    print(loss)

In [None]:
class MemoryEfficientLinear(torch.autograd.Function):
    """Evaluate `forward_function` chunk by chunk instead of on the whole batch.

    The batch (dim 0) is split into `NUM_CHUNKS` pieces. The forward pass
    evaluates the loss of each piece without recording an autograd graph, so
    only one piece's intermediate activations exist at a time. The backward
    pass recomputes each piece with autograd enabled and collects its
    gradients.

    Notes:
        * `linear` is passed as a non-tensor argument, so autograd cannot
          return gradients for its parameters. They are instead accumulated
          directly into each parameter's `.grad` during `backward`.
        * For the same reason the result only requires grad when `X` does.
        * Chunk losses are combined weighted by their share of the batch
          rows. This equals the full-batch value when `forward_function`
          returns a mean in which every row contributes equally (e.g. mean
          cross-entropy with no ignored labels) -- TODO confirm for other
          loss functions.
    """

    # Number of pieces the batch dimension is split into.
    NUM_CHUNKS = 2

    @staticmethod
    def forward(ctx, X, linear, labels, forward_function):
        """Compute the batch loss as a row-weighted sum of per-chunk losses.

        Args:
            ctx: Autograd context used to stash what `backward` needs.
            X: Input tensor; split along dim 0.
            linear: Layer/callable forwarded unchanged to `forward_function`.
            labels: Targets; split along dim 0 in lock-step with `X`.
            forward_function: Callable ``(x_chunk, linear, label_chunk)``
                returning a scalar loss tensor.

        Returns:
            A 0-dim loss tensor on the device `forward_function` produced it on.

        Raises:
            ValueError: If the batch is empty or `X` and `labels` disagree on
                the size of dim 0.
        """
        batch_size = X.shape[0]
        if batch_size == 0:
            raise ValueError("MemoryEfficientLinear needs at least one sample")
        if labels.shape[0] != batch_size:
            raise ValueError(
                f"X has {batch_size} samples but labels has {labels.shape[0]}"
            )

        num_chunks = MemoryEfficientLinear.NUM_CHUNKS
        # torch.chunk returns fewer / uneven pieces for small or odd batches;
        # iterating over whatever comes back handles those cases uniformly.
        x_chunks = torch.chunk(X, chunks=num_chunks, dim=0)
        label_chunks = torch.chunk(labels, chunks=num_chunks, dim=0)

        loss = None
        # No graph is recorded here: the chunks are recomputed in backward.
        with torch.no_grad():
            for x_chunk, label_chunk in zip(x_chunks, label_chunks):
                # Each chunk is paired with the labels of the *same* rows.
                weight = x_chunk.shape[0] / batch_size
                partial = forward_function(x_chunk, linear, label_chunk) * weight
                loss = partial if loss is None else loss + partial

        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.forward_function = forward_function
        ctx.num_chunks = num_chunks
        return loss

    @staticmethod
    def backward(ctx, dY):
        """Recompute each chunk with autograd enabled and gather its gradients.

        Args:
            ctx: Context populated by `forward`.
            dY: Upstream gradient of the scalar loss.

        Returns:
            ``(dX, None, None, None)`` -- only `X` receives a gradient through
            autograd; gradients of `linear`'s trainable parameters are
            accumulated into their `.grad` attributes as a side effect.
        """
        X, labels = ctx.saved_tensors
        linear = ctx.linear
        forward_function = ctx.forward_function
        batch_size = X.shape[0]

        # Parameters can only be discovered when `linear` is an nn.Module.
        if isinstance(linear, torch.nn.Module):
            params = [p for p in linear.parameters() if p.requires_grad]
        else:
            params = []

        dX = torch.empty_like(X)
        x_chunks = torch.chunk(X, chunks=ctx.num_chunks, dim=0)
        label_chunks = torch.chunk(labels, chunks=ctx.num_chunks, dim=0)
        dx_chunks = torch.chunk(dX, chunks=ctx.num_chunks, dim=0)  # views into dX

        for x_chunk, label_chunk, dx_chunk in zip(x_chunks, label_chunks, dx_chunks):
            # backward runs with grad mode off, so re-enable it for the recompute.
            with torch.enable_grad():
                x_leaf = x_chunk.detach().requires_grad_(True)
                chunk_loss = forward_function(x_leaf, linear, label_chunk)
                # Same weighting as in forward, scaled by the upstream gradient.
                scaled = chunk_loss * (x_chunk.shape[0] / batch_size) * dY

            grads = torch.autograd.grad(scaled, [x_leaf, *params], allow_unused=True)

            x_grad = grads[0]
            if x_grad is None:
                dx_chunk.zero_()
            else:
                dx_chunk.copy_(x_grad)

            for param, param_grad in zip(params, grads[1:]):
                if param_grad is None:
                    continue
                if param.grad is None:
                    param.grad = param_grad
                else:
                    param.grad.add_(param_grad)

        return dX, None, None, None

In [None]:
# Compute the loss through the custom autograd Function. `forward` here is the
# nn.Linear layer defined earlier and `transformation_function` is the loss
# that MemoryEfficientLinear calls on each half of the batch.
loss = MemoryEfficientLinear.apply(input, forward, labels, transformation_function)

In [None]:
# Display the loss returned by MemoryEfficientLinear.apply.
loss

In [None]:
# Backpropagate from the loss; this invokes MemoryEfficientLinear.backward.
loss.backward()

In [None]:
# Show the input tensor after the backward pass.
# NOTE(review): this displays the values of `input`, not its gradient
# (`input.grad`) -- confirm which one was intended.
input