In [11]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

In [12]:
# Example: 100 input features -> 10 output features
width = 512
linear = nn.Linear(width, width)

# Input after activation (sparse)
x = torch.zeros(width)
x[5] = 0.8
x[72] = 0.3

start_time = time.time()
x = F.relu(x)

# 1. Get non-zero indices and values
nonzero_indices = x.nonzero(as_tuple=True)[0]
nonzero_values = x[nonzero_indices]  # shape: (k,)

# 2. Select corresponding weight columns (W has shape [10, 100])
W = linear.weight  # shape: [10, 100]
W_selected = W[:, nonzero_indices]  # shape: [10, k]

for i in range(100):
    # 3. Compute output using only non-zero elements
    # Equivalent to W[:, idx] @ x[idx] + b
    output = W_selected @ nonzero_values + linear.bias  # shape: [10]

end_time = time.time()
print(output.shape)  # torch.Size([10])
print(f'Time taken: {end_time - start_time}')


start_time = time.time()
x = F.relu(x)
for i in range(100):
    output = linear.forward(x)
end_time = time.time()
print(output.shape)  # torch.Size([10])
print(f'Time taken: {end_time - start_time}')

torch.Size([512])
Time taken: 0.0006811618804931641
torch.Size([512])
Time taken: 0.0013890266418457031
