## ANN vs softmax

The goal of this notebook is to assess whether it is possible to scale an LLM to a 500k+ token vocabulary without sacrificing inference speed.

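The main idea is that picking the most likely next token (greedy decoding) only needs the argmax of the output layer. The softmax does not change that argmax because it is strictly increasing:

$$\arg\max_i \operatorname{softmax}(Wx + b)_i = \arg\max_i \left(w_i^\top x + b_i\right)$$

where $x$ is the final embedding fed to the head and $w_i$ is the $i$-th row of the `nn.Linear` weight matrix $W$. Up to the bias term, this is a maximum inner product search over the rows of $W$.

So if the vocabulary grows too large for a full matrix multiplication, we could replace it with an approximate nearest neighbor (ANN) search over those rows.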
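A small NumPy sketch of this equivalence on a random toy head (the sizes, the seed and the 0.01 bias scale are arbitrary assumptions, far smaller than the notebook's 500k x 4096 layer): the token picked after the softmax, the argmax of the raw logits, and a brute-force inner-product search over the rows of `W` should all agree.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, dim, batch = 1_000, 64, 8  # toy sizes, not the notebook's 500k x 4096 head

W = rng.standard_normal((vocab, dim)).astype(np.float32)  # rows w_i of the output layer
b = rng.standard_normal(vocab).astype(np.float32) * 0.01  # bias b_i
x = rng.standard_normal((batch, dim)).astype(np.float32)  # final embeddings

logits = x @ W.T + b  # shape [batch, vocab]

# Numerically stable softmax over the vocabulary
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)

# Softmax is strictly increasing, so it never changes which index is largest
top_softmax = probs.argmax(axis=1)
top_logits = logits.argmax(axis=1)

# Brute-force maximum inner product search: score every row against each query
top_search = np.array([np.argmax(W @ q + b) for q in x])

print(np.array_equal(top_softmax, top_logits), np.array_equal(top_logits, top_search))
```

The brute-force search is exactly what `nn.Linear` followed by an argmax does; an ANN index tries to get the same answer while scoring only a fraction of the rows.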

In [1]:
#!pip install hnswlib jaxtyping torchinfo

In [2]:
import torch
import torch.utils.benchmark as benchmark
from jaxtyping import Float, Int
from torch import Tensor, nn
from torchinfo import summary
import numpy as np
import hnswlib

In [3]:
torch.cuda.get_device_name()

'NVIDIA GeForce RTX 3090'

In [4]:
device = "cuda"

In [5]:
torch.set_default_device(device)

In [6]:
torch.set_default_dtype(torch.bfloat16)

In [7]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fee9616e610>

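## nn.Linear vs exact vector search

Computing all logits with `nn.Linear` and taking their argmax is already an exact, brute-force vector search over the rows of the weight matrix. This section checks that skipping the softmax returns the same tokens and times both variants.
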
In [8]:
vocab_size = 500_000
embed_size = 4096

In [9]:
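class ClassificationHead(nn.Module):
    def __init__(self, vocab_size: int, embed_size: int):
        super().__init__()

        # Output layer: logits = x @ W.T + b, with W of shape [vocab_size, embed_size]
        self.linear = nn.Linear(embed_size, vocab_size)

    def forward(
        self, x: Float[Tensor, "batch embed_size"]
    ) -> Int[Tensor, "batch"]:
        # Reference path: full softmax over the vocabulary, then the most probable token
        y = self.linear(x)
        return y.softmax(dim=1).max(dim=1).indices

    def forward_nn(
        self, x: Float[Tensor, "batch embed_size"]
    ) -> Int[Tensor, "batch"]:
        # Softmax is strictly increasing, so the argmax of the raw logits is the same token
        y = self.linear(x)
        return y.max(dim=1).indices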

In [10]:
head = ClassificationHead(vocab_size, embed_size)

In [11]:
x = torch.rand(8, embed_size)

In [12]:
y1 = head.forward(x)
y2 = head.forward_nn(x)
torch.allclose(y1, y2)

True

In [13]:
t1 = benchmark.Timer(
    stmt="head.forward_nn(x)",
    globals={"x": x, "head":head}
)
t1.timeit(100)

<torch.utils.benchmark.utils.common.Measurement object at 0x7fee96a86250>
head.forward_nn(x)
  5.17 ms
  1 measurement, 100 runs , 1 thread

In [14]:
t2 = benchmark.Timer(
    stmt="head.forward(x)",
    globals={"x": x, "head":head}
)
t2.timeit(100)

<torch.utils.benchmark.utils.common.Measurement object at 0x7fee961617d0>
head.forward(x)
  5.27 ms
  1 measurement, 100 runs , 1 thread

In [15]:
summary(head, input_data=x, dtypes=[torch.bfloat16])

Layer (type:depth-idx)                   Output Shape              Param #
ClassificationHead                       [8]                       --
├─Linear: 1-1                            [8, 500000]               2,048,500,000
Total params: 2,048,500,000
Trainable params: 2,048,500,000
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 16.39
Input size (MB): 0.07
Forward/backward pass size (MB): 16.00
Params size (MB): 4097.00
Estimated Total Size (MB): 4113.07
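Both timings are close to the cost of simply reading the weights once. A back-of-the-envelope check, using the parameter count from the summary above and assuming the RTX 3090's published peak memory bandwidth of about 936 GB/s (a spec-sheet figure, not something measured in this notebook):

```python
batch = 8
weight_params = 4096 * 500_000   # nn.Linear(4096, 500_000).weight
bias_params = 500_000            # nn.Linear(4096, 500_000).bias
bytes_per_param = 2              # bfloat16

weight_bytes = (weight_params + bias_params) * bytes_per_param  # read once per forward pass
peak_bandwidth = 936e9           # bytes/s, assumed RTX 3090 spec-sheet value

min_time_ms = weight_bytes / peak_bandwidth * 1e3  # lower bound if purely memory-bound
measured_ms = 5.17                                 # forward_nn timing from the benchmark above

# Each weight takes part in one multiply and one add per batch row
flops = 2 * weight_params * batch
arithmetic_intensity = flops / weight_bytes        # FLOPs per byte loaded from memory

print(f"parameters: {weight_bytes / 1e6:.0f} MB")
print(f"bandwidth lower bound: {min_time_ms:.2f} ms (measured: {measured_ms} ms)")
print(f"share of peak bandwidth achieved: {min_time_ms / measured_ms:.0%}")
print(f"arithmetic intensity: {arithmetic_intensity:.1f} FLOP/byte")
```

If that spec value holds, the head runs at roughly 85% of peak bandwidth, and at about 8 FLOPs per byte it is limited by streaming ~4 GB of weights rather than by arithmetic. The softmax itself adds only about 0.1 ms (5.27 ms vs 5.17 ms), so an ANN index only pays off if it avoids reading most of the rows.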

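## nn.Linear vs ANN

Here the rows of the same weight matrix are indexed with hnswlib (an HNSW graph in inner-product space) and queried on the CPU.
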
In [16]:
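class ClassificationHead:
    """Approximate argmax over the rows of an nn.Linear weight matrix, using an HNSW index."""

    def __init__(self, linear: nn.Linear):
        # hnswlib expects a float32 NumPy array on the CPU
        weight_array = linear.weight.detach().to(torch.float32).cpu().numpy()
        # Note: only the weights are indexed; linear.bias is ignored, so rows are ranked by w_i . x

        num_elements, dim = weight_array.shape

        print("start indexing")
        # 'ip' = inner-product space, i.e. maximum inner product search over the rows
        self.index = hnswlib.Index(space='ip', dim=dim)
        self.index.init_index(max_elements=num_elements, ef_construction=200, M=16)

        print("add index")

        # Adding the weight vectors to the index (one item per vocabulary entry)
        self.index.add_items(weight_array)
        # Query-time search width: a higher ef gives better recall but slower queries
        self.index.set_ef(50)

    def forward(
        self, x: Float[np.ndarray, "batch embed_size"]
    ) -> Int[np.ndarray, "batch 1"]:
        # labels has shape [batch, k]; with k=1 it holds the approximate argmax per query
        labels, _ = self.index.knn_query(x, k=1)
        return labels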
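The index above only sees `linear.weight`, so it ranks rows by $w_i^\top x$ and drops the bias $b_i$ of `nn.Linear`. One standard way to keep the search exact with respect to the logits is to append the bias as an extra coordinate of every row and a constant 1 to every query, since $[w_i, b_i]^\top [x, 1] = w_i^\top x + b_i$. A NumPy sketch on toy data (the helper names `augment_rows` and `augment_queries` are hypothetical); the augmented rows could then be added to an `hnswlib.Index(space='ip', dim=dim + 1)` in place of the raw weights:

```python
import numpy as np

def augment_rows(weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Append each row's bias as an extra column: [w_i, b_i]."""
    return np.hstack([weight, bias[:, None]]).astype(np.float32)

def augment_queries(x: np.ndarray) -> np.ndarray:
    """Append a constant 1 to each query: [x, 1]."""
    ones = np.ones((x.shape[0], 1), dtype=x.dtype)
    return np.hstack([x, ones]).astype(np.float32)

# Toy head (arbitrary sizes); the bias is made large so its effect is visible
rng = np.random.default_rng(0)
W = rng.standard_normal((1_000, 64)).astype(np.float32)
b = rng.standard_normal(1_000).astype(np.float32)
x = rng.standard_normal((8, 64)).astype(np.float32)

logits = x @ W.T + b                                # what nn.Linear computes
scores = augment_queries(x) @ augment_rows(W, b).T  # plain inner products after augmentation

print(np.allclose(logits, scores, atol=1e-4))
print(np.array_equal(logits.argmax(axis=1), scores.argmax(axis=1)))
```

The extra dimension costs one float per row, so the index size barely changes, while the inner-product search now ranks rows by the same values the softmax sees.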
        

In [17]:
%%time
head_hnsw = ClassificationHead(head.linear)

start indexing
add index
CPU times: user 7h 48min 17s, sys: 39.2 s, total: 7h 48min 56s
Wall time: 30min 51s
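The index build is the expensive step. A quick reading of the `%%time` output above, as plain arithmetic on the reported numbers (the float32 figure counts only the raw copy of the weight matrix, which hnswlib also stores inside the index, not the graph links):

```python
cpu_seconds = 7 * 3600 + 48 * 60 + 56   # "total: 7h 48min 56s" (user + sys)
wall_seconds = 30 * 60 + 51             # "Wall time: 30min 51s"
num_rows, dim = 500_000, 4096

parallelism = cpu_seconds / wall_seconds      # average number of busy cores
rows_per_second = num_rows / wall_seconds     # insertion throughput
float32_gb = num_rows * dim * 4 / 1e9         # raw float32 copy of the weight matrix

print(f"~{parallelism:.1f} cores busy on average")
print(f"~{rows_per_second:.0f} rows/s inserted")
print(f"{float32_gb:.1f} GB of float32 vectors")
```

So every change to the head would mean about half an hour of indexing on this machine, with at least twice the bf16 head's size (8.2 GB vs 4.1 GB) held in host memory.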


In [18]:
x_hnsw = x.to(torch.float32).cpu().numpy()

In [19]:
%%time
y3 = head_hnsw.forward(x_hnsw)

CPU times: user 9.78 ms, sys: 0 ns, total: 9.78 ms
Wall time: 9.58 ms


In [20]:
y1_np = y2.to(torch.float32).cpu().numpy()

In [21]:
y1_np.shape

(8,)

In [22]:
y3.flatten().shape

(8,)

In [23]:
np.testing.assert_allclose(y1_np, y3.flatten())

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 8 / 8 (100%)
Max absolute difference: 388007.
Max relative difference: 2.90451677
 x: array([153840., 192137., 228301.,  71834., 123258., 423469.,  71834.,
        71834.], dtype=float32)
 y: array([ 44385, 228402,  58471, 298397, 468588, 124622,  67025, 459841],
      dtype=uint64)
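With `ef=50`, none of the 8 HNSW labels match the exact argmax, and in this single `%%time` run the CPU query (9.58 ms for 8 queries) is also slower than the full GPU head (about 5.2 ms). Comparing index arrays with `assert_allclose` only gives pass/fail; a more informative check is recall (how often the exact token is among the k ANN candidates) and how far the top ANN candidate's logit is from the best one. A minimal NumPy sketch; the function names are hypothetical, and in this notebook `exact_ids` would be `y1.cpu().numpy()`, `ann_ids` the labels returned by `head_hnsw.index.knn_query(x_hnsw, k=...)`, and `logits` something like `head.linear(x).float().cpu().numpy()`:

```python
import numpy as np

def recall_at_k(exact_ids: np.ndarray, ann_ids: np.ndarray) -> float:
    """Fraction of queries whose exact argmax appears among the ANN candidates.

    exact_ids: shape [batch]; ann_ids: shape [batch, k] (e.g. labels from knn_query).
    """
    exact = np.asarray(exact_ids).astype(np.int64).reshape(-1, 1)
    ann = np.asarray(ann_ids).astype(np.int64).reshape(exact.shape[0], -1)
    return float(np.mean(np.any(ann == exact, axis=1)))

def logit_gap(logits: np.ndarray, exact_ids: np.ndarray, ann_ids: np.ndarray) -> np.ndarray:
    """Per query: best logit minus the logit of the top ANN candidate (0 means a hit or a tie)."""
    exact = np.asarray(exact_ids).astype(np.int64).reshape(-1)
    top_ann = np.asarray(ann_ids).astype(np.int64).reshape(len(exact), -1)[:, 0]
    rows = np.arange(len(exact))
    return logits[rows, exact] - logits[rows, top_ann]

# Toy example: 4 queries over a vocabulary of 5
logits = np.array([[0.1, 0.9, 0.3, 0.0, 0.2],
                   [0.5, 0.1, 0.2, 0.8, 0.4],
                   [0.3, 0.2, 0.7, 0.1, 0.6],
                   [0.4, 0.4, 0.1, 0.2, 0.3]], dtype=np.float32)
exact_ids = logits.argmax(axis=1)  # first index wins on ties
ann_ids = np.array([[1, 2], [4, 3], [4, 2], [1, 0]], dtype=np.uint64)  # k=2 candidates

print(recall_at_k(exact_ids, ann_ids))         # 1.0
print(recall_at_k(exact_ids, ann_ids[:, :1]))  # 0.25
print(logit_gap(logits, exact_ids, ann_ids))
```

In the toy data the last query "misses" by index but has a gap of 0 because its two best logits are tied, which is the kind of case an exact index comparison cannot distinguish from a real error. Sweeping `set_ef` and `k` and tracking these two numbers would show how recall trades off against query time.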