In [1]:
import sys

sys.path.append("..")

from logger import get_logger

%load_ext autoreload
%autoreload 2

In [2]:
logger = get_logger(__name__)
logger.debug("This is a debug message")

DEBUG:__main__:This is a debug message


In [3]:
import torch

In [8]:
x = torch.nn.Parameter(torch.randn(5, 5), requires_grad=True)

y = 5 * x + 2

vals, indices = torch.topk(y, 2)

print(vals.requires_grad, indices.requires_grad)

loss = vals.norm()
loss.backward()

True False


In [6]:
vals._grad

  vals._grad


In [3]:
import torch
import torch.nn as nn


class TopKSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, k):
        # Get the values and indices of the top k elements
        vals, indices = torch.topk(input, k, dim=-1)
        ctx.save_for_backward(indices, torch.tensor(input.shape))
        return vals, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        indices, input_shape = ctx.saved_tensors
        grad_input = torch.zeros(tuple(input_shape), device=grad_output.device)
        grad_input.scatter_(-1, indices, grad_output)
        return grad_input, None


# Test function
def test_topk_ste():
    # Create a simple model that uses TopKSTE
    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(5, 5)
            self.topk_ste = TopKSTE.apply

        def forward(self, x):
            x = self.linear(x)
            values, indices = self.topk_ste(x, 3)
            return values.sum(), indices

    # Create an instance of the model
    model = SimpleModel()

    # Create a random input tensor
    x = torch.randn(1, 5, requires_grad=True)

    # Forward pass
    output, indices = model(x)

    # Backward pass
    output.backward()

    # Print results
    print("Input:", x)
    print("Output (sum of top-3 values):", output)
    print("Top-3 indices:", indices)
    print("Input gradient:", x.grad)
    print("Linear layer weight gradient:", model.linear.weight.grad)


# Run the test
test_topk_ste()

Input: tensor([[-0.3632, -0.4344, -0.5378, -0.0366, -1.7030]], requires_grad=True)
Output (sum of top-3 values): tensor(1.5141, grad_fn=<SumBackward0>)
Top-3 indices: tensor([[1, 0, 3]])
Input gradient: tensor([[ 0.1855, -0.7287, -0.2749,  0.0936, -0.5795]])
Linear layer weight gradient: tensor([[-0.3632, -0.4344, -0.5378, -0.0366, -1.7030],
        [-0.3632, -0.4344, -0.5378, -0.0366, -1.7030],
        [-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
        [-0.3632, -0.4344, -0.5378, -0.0366, -1.7030],
        [-0.0000, -0.0000, -0.0000, -0.0000, -0.0000]])
