In [None]:
import torch
from torch import nn

In [None]:
# Two independent leaf tensors; only `a` takes part in the computation below.
a, b = (torch.rand(10, requires_grad=True) for _ in range(2))

# Scalar function of `a` alone: sum(2 * a_i), so d(output)/d(a_i) == 2.
output = a.mul(2).sum()

# Ask autograd for the gradient with respect to `a` only.
torch.autograd.grad(outputs=output, inputs=(a,))
# Variant that also requests `b`, which `output` does not depend on:
# torch.autograd.grad(output, (a, b))

In [None]:
def transformation_function(batch, linear, labels):
    """Project `batch` through `linear` and reduce it to a mean cross-entropy loss.

    Args:
        batch: Input tensor; must have `requires_grad` set.
        linear: Callable applied to `batch` (e.g. an `nn.Linear`). Its output is
            upcast with `.float()` and its last dimension is used as the class axis.
        labels: Target class indices; flattened to 1-D before the loss.

    Returns:
        The cross-entropy loss averaged over all flattened positions.

    Raises:
        AssertionError: if `batch`, or the output of `linear`, lacks `requires_grad`.
    """
    assert batch.requires_grad, "Batch lacks requires_grad"
    logits = linear(batch).float()
    assert logits.requires_grad, "x lacks requires_grad after linear"
    # Down projection to small space: collapse every leading dimension so that
    # each row is one prediction, then average the per-row losses.
    flat_logits = logits.view(-1, logits.shape[-1])
    flat_labels = labels.view(-1)
    return nn.functional.cross_entropy(flat_logits, flat_labels, reduction="mean")

In [None]:
from unsloth_zoo.rl_replacements import UnslothEfficientGRPO
from functools import partial

# given a batch, computes a reward for each of the items in the batch
# def example_reward_function(batch) -> torch.Tensor:
#     pass

# given a batch, computes a loss by passing it through linear and calculating rewards
# def is_this_grpo(batch, linear, reward_functions):
#     assert batch.requires_grad, "Batch lacks requires_grad"
#     x = linear(batch).float()
#     assert x.requires_grad, "x lacks requires_grad after linear"
#     from torch.nn import CrossEntropyLoss
#     down_projection_function = CrossEntropyLoss(reduction = "mean")
#     # Down projection to small space
#     loss = down_projection_function(x.view(-1, x.shape[-1]), labels.view(-1))
#     return loss

# we are not supposed to re-implement GRPO, I think we should try to call UnslothEfficientGRPO, and pass in a curried version. the currying should provide most of the args
def curried_grpo_function(old_hidden_states, completion_input_ids, completion_mask, advantages, beta):
    """Bind the GRPO-specific arguments and return a `(batch, linear, labels)` callable.

    The returned function has the same three-argument shape as
    `transformation_function`, so it can be handed to code that expects a
    `forward_function(batch, linear, labels)`.

    Args:
        old_hidden_states: forwarded as `_old_hidden_states`.
        completion_input_ids: forwarded as `_input_ids` (author's note: the
            concatenated input ids of the prompt & completion).
        completion_mask: forwarded as `_mask` (author's note: the concatenated
            mask of the prompt & completion).
        advantages: forwarded as `_advantages` (author's note: the advantage
            scores of the new completions compared to the old ones).
        beta: forwarded as `beta`.

    Returns:
        `inner_fn(batch, linear, labels=None)`, which returns whatever
        `UnslothEfficientGRPO.apply` returns.
    """
    def inner_fn(
        batch,          # new_hidden_states
        linear,         # lm_head
        labels = None,  # not used
    ):
        # torch.autograd.Function.apply does not forward keyword arguments to
        # forward(), so everything is passed positionally. The order mirrors
        # the direct UnslothEfficientGRPO.apply(...) call made elsewhere in
        # this notebook: new hidden states first, then old hidden states.
        return UnslothEfficientGRPO.apply(
            batch,                 # _new_hidden_states
            old_hidden_states,     # _old_hidden_states
            linear,                # lm_head
            completion_input_ids,  # _input_ids
            completion_mask,       # _mask
            advantages,            # _advantages
            beta,                  # beta
            None,                  # scaler: seems to be optional
            1,                     # n_chunks: the chunking is done by the caller of this function
        )
    return inner_fn


In [None]:
import os
# NOTE(review): presumably this flag must be set before unsloth_zoo is imported — TODO confirm.
# An earlier cell already imports unsloth_zoo.rl_replacements, so on a top-to-bottom run
# the module has been imported once before this assignment executes.
os.environ['UNSLOTH_IS_PRESENT'] = '1'
from unsloth_zoo.rl_replacements import UnslothEfficientGRPO
import torch

# Smoke test: call UnslothEfficientGRPO.apply directly on random CUDA tensors.
# Both hidden-state tensors are (6, 241, 2048) bfloat16; the ids and the mask are (6, 240),
# i.e. one position shorter than the hidden states.
old_hidden_states = torch.randn(6, 241, 2048, dtype=torch.bfloat16, device="cuda", requires_grad=True)
new_hidden_states = torch.randn(6, 241, 2048, dtype=torch.bfloat16, device="cuda", requires_grad=True)
# Random (128256, 2048) matrix passed as `lm_head`: 2048 matches the hidden width above and
# 128256 matches the exclusive upper bound of the random token ids below.
lm_head = torch.randn(128256, 2048, dtype=torch.bfloat16, device="cuda", requires_grad=True)
completion_input_ids = torch.randint(0, 128256, (6, 240), dtype=torch.int64, device="cuda")

# filter out 128004
# NOTE(review): nothing in this cell filters token 128004; the line above looks like a leftover TODO.
# completion_mask = torch.randint(0, 2, (6, 240), dtype=torch.int64, device="cuda")
# All-ones mask with the same shape, dtype and device as the ids.
completion_mask = torch.ones_like(completion_input_ids)
# NOTE(review): this random draw is overwritten by the all-zeros tensor on the next line.
advantages = torch.randn(6, dtype=torch.float32, device="cuda")
advantages = torch.zeros(6, dtype=torch.float32, device="cuda")
beta = 0.04
scaler = None
n_chunks = 6

# Arguments are passed positionally: new hidden states first, then old hidden states.
UnslothEfficientGRPO.apply(
    new_hidden_states,
    old_hidden_states,
    lm_head,
    completion_input_ids,
    completion_mask,
    advantages,
    beta,
    scaler,
    n_chunks,
)

In [None]:
# example in the statement was given as follows:
# bsz = 4, qlen = 4096, hd = 4096, vocab = 128K
# let's do something similar, with a smaller vocabulary

# example input with 4 samples, ctx = 8192 tokens, and hidden dimension of 2048
# NOTE(review): `input` shadows the Python builtin of the same name; later cells read this name.
input = torch.randn(4, 8192, 2048, device="cuda", requires_grad=True)

# linear (dense) layer which accepts an input of 2048 hidden dimensions, and outputs to a vocabulary of 4096
linear = nn.Linear(2048, 4096).to("cuda")

# target class ids for the 4 samples in the batch, one per token (ctx = 8192), drawn from [0, 4096)
labels = torch.randint(0, 4096, (4, 8192), device="cuda")

In [None]:
# Toggle for the plain (un-chunked) sanity check below.
TEST_TRANSFORMATION_FUNCTION = True
if TEST_TRANSFORMATION_FUNCTION:
    # Evaluate the loss on the full batch in one go and show the result.
    loss = transformation_function(batch=input, linear=linear, labels=labels)
    print(loss)

In [None]:
# forward.weight.data
# forward.bias.data
# labels.requires_grad

In [None]:
# Split the example batch and its labels into two pieces along dim 0.
X = input
X0, X1 = X.chunk(2, dim=0)
labels_0, labels_1 = labels.chunk(2, dim=0)

In [None]:
# Loss of the first chunk only, computed with the plain transformation function.
Z0 = transformation_function(batch=X0, linear=linear, labels=labels_0)

In [None]:
# to be very clear about the terminology here
# X is the input to the memory efficient linear function
# Y is X @ W.T + b (what nn.Linear computes), where W is the weight matrix and b is the bias of the linear layer
# Z is f(Y), where f is the transformation function
# the output is expected to be a single scalar value

class MemoryEfficientLinear(torch.autograd.Function):
    """Evaluate `forward_function(X, linear, labels)` chunk by chunk along the batch axis.

    The batch is split into two chunks (a single chunk when it has fewer than
    two rows). Each chunk is evaluated on its own and the chunk graph is
    dropped right away, so at most one chunk's autograd graph is alive at any
    time. `backward` re-runs `forward_function` per chunk to obtain gradients.

    Assumptions:
        * `forward_function(batch, linear, labels)` returns a scalar that is a
          mean in which every batch row contributes equally. The chunk values
          are combined weighted by their number of rows, which then equals the
          full-batch mean (and is the plain 0.5 / 0.5 average for even batches).
        * `forward_function` is deterministic, because it is evaluated again in
          `backward`.
        * `linear` exposes `parameters()` like an `nn.Module`.

    Gradients:
        * The gradient w.r.t. `X` is returned through autograd as usual.
        * `linear` is not a tensor argument, so autograd cannot route gradients
          to it. `backward` therefore accumulates the parameter gradients
          directly into each `param.grad`. This works with `loss.backward()`;
          `torch.autograd.grad(loss, linear.parameters())` will not see them.
    """

    @staticmethod
    def _chunk_plan(n_batch):
        """Return `(start, stop, weight)` triples tiling `range(n_batch)`; weights sum to 1."""
        if n_batch < 2:
            # Nothing to split: evaluate the whole batch as one chunk.
            return [(0, n_batch, 1.0)]
        mid = n_batch // 2
        return [
            (0, mid, mid / n_batch),
            (mid, n_batch, (n_batch - mid) / n_batch),
        ]

    @staticmethod
    def _run_chunk(X, linear, labels, forward_function, start, stop):
        """Evaluate rows `start:stop` with autograd enabled.

        Returns `(chunk, Z)`: `chunk` is a detached leaf view of `X` that
        requires grad, and `Z` is the value of `forward_function` on it.
        """
        # forward()/backward() of a Function run with grad mode disabled, so it
        # has to be re-enabled for forward_function to build a graph.
        with torch.enable_grad():
            # The same indices slice X and labels, so the two always line up.
            chunk = X[start:stop].detach().requires_grad_(True)
            Z = forward_function(chunk, linear, labels[start:stop])
        return chunk, Z

    @staticmethod
    def forward(ctx, X, linear, labels, forward_function):
        """Return the row-weighted mean of the per-chunk values as a float32 tensor.

        Raises:
            AssertionError: if `X` does not require grad.
        """
        assert X.requires_grad, "X lacks requires_grad"
        plan = MemoryEfficientLinear._chunk_plan(X.shape[0])

        total = None
        for start, stop, weight in plan:
            _, Z = MemoryEfficientLinear._run_chunk(X, linear, labels, forward_function, start, stop)
            # Accumulate in float64, as before, then cast once at the end.
            term = Z.detach().to(torch.float64) * weight
            del Z  # release this chunk's graph before the next chunk is evaluated
            total = term if total is None else total + term

        # Only Function inputs are saved; non-tensor state goes on ctx directly.
        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.forward_function = forward_function
        ctx.plan = plan
        return total.to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        """Recompute each chunk and return `(grad_X, None, None, None)`.

        Side effect: gradients of `linear`'s trainable parameters are
        accumulated into their `.grad` attributes.
        """
        X, labels = ctx.saved_tensors
        linear = ctx.linear
        params = [p for p in linear.parameters() if p.requires_grad]

        grad_chunks = []
        for start, stop, weight in ctx.plan:
            chunk, Z = MemoryEfficientLinear._run_chunk(
                X, linear, labels, ctx.forward_function, start, stop
            )
            # Chain rule through the weighted mean computed in forward().
            grads = torch.autograd.grad(
                Z,
                (chunk, *params),
                grad_outputs=(grad_output * weight).to(Z.dtype),
                allow_unused=True,
            )
            del Z

            # Gradient w.r.t. this chunk of X.
            grad_chunk = grads[0]
            grad_chunks.append(torch.zeros_like(chunk) if grad_chunk is None else grad_chunk)

            # Gradients w.r.t. the linear layer's parameters, summed over chunks.
            for param, grad in zip(params, grads[1:]):
                if grad is None:
                    continue
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad.add_(grad)

        # Assemble full gradient w.r.t. X
        grad_X = torch.cat(grad_chunks, dim=0)

        # Return gradients for all inputs (only X is a differentiable tensor input)
        return grad_X, None, None, None

In [None]:
# Run the chunked loss on the example tensors defined in the earlier cells.
loss = MemoryEfficientLinear.apply(input, linear, labels, transformation_function)

In [None]:
# Display the value returned by MemoryEfficientLinear.apply.
loss

In [None]:
# Back-propagate from the chunked loss; this invokes MemoryEfficientLinear.backward.
loss.backward()

In [None]:
# NOTE(review): this displays the whole input tensor; `input.grad` may have been
# the intended thing to inspect after the backward call above — confirm.
input

In [None]:
# run tests to see if the outputs and the gradients match
for x in range(100):
    input = torch.randn(4, 8, 2, device="cuda", requires_grad=True)
    linear = nn.Linear(2, 4).to("cuda")
    labels = torch.randint(0, 4, (4, 8), device="cuda")

    # reference: plain forward/backward on the whole batch
    expected = transformation_function(input, linear, labels)
    expected.backward()
    # `.grad` tensors are accumulated into in place by later backward calls,
    # so take snapshots instead of keeping references to them
    gradI_expected = input.grad.clone()
    gradW_expected = linear.weight.grad.clone()
    gradB_expected = linear.bias.grad.clone()

    # start the second pass from empty gradients, otherwise it would add on
    # top of the reference gradients
    input.grad = None
    linear.weight.grad = None
    linear.bias.grad = None

    # chunked implementation: the forward value must match ...
    actual = MemoryEfficientLinear.apply(input, linear, labels, transformation_function)
    assert torch.allclose(expected, actual)

    # ... and so must every gradient
    actual.backward()
    gradI_actual = input.grad
    gradW_actual = linear.weight.grad
    gradB_actual = linear.bias.grad

    assert gradI_actual is not None, "no gradient reached the input"
    assert gradW_actual is not None, "no gradient reached linear.weight"
    assert gradB_actual is not None, "no gradient reached linear.bias"
    assert torch.allclose(gradI_expected, gradI_actual)
    assert torch.allclose(gradW_expected, gradW_actual)
    assert torch.allclose(gradB_expected, gradB_actual)

In [None]:
# Minimal autograd check: d/dv sum(2 * v) is 2 for every element of v.
silly_vec = torch.tensor([1.0, 2.0, 3.0]).requires_grad_(True)
loss = silly_vec.mul(2).sum()
loss.backward()
silly_vec.grad

In [None]:
# self score: 2 + 1 + 1 = 4
# if attemped_E:
#     E_score = 0
#     if VRAM_50_percent_reduction: E_score += 2
#     if remove_float32_upcast: E_score = 0
#     if show_ce_loss_works: E_score += 1
#     if show_other_functions_work: E_score += 1
#     if hardcoded_gradients: E_score = 0
#     if allows_dynamic_chunk_sizes: E_score += 1
#     if llama_1B_training_loss_matches: E_score += 1
#     else: E_score = 0
#     if GRPO_memory_efficient_linear_works: E_score += 4
#     final_score += E_score
# else:
#     final_score += 0