<a href="https://colab.research.google.com/github/pashtetttt/noMAD-attention/blob/main/NoMAD_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers torch datasets accelerate bitsandbytes

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from 

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
import numpy as np

In [3]:
model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", torch_dtype=torch.float16)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/4.40G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

In [4]:
test_data = load_dataset("wikitext", "wikitext-2-v1", split="test")
test_texts = [text for text in test_data["text"] if len(text) > 0]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/685k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.07M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/618k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [5]:
def compute_perplexity(model, tokenizer, texts, max_length=512):
    nlls = []
    for text in texts[:20]:  # Use first 20 samples for speed
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            nlls.append(outputs.loss.item())
    return np.exp(np.mean(nlls))

In [6]:
orig_ppl = compute_perplexity(model, tokenizer, test_texts)
print(f"Original Perplexity: {orig_ppl:.2f}")

Original Perplexity: 71.02


Mock NoMAD-Attention

In [7]:
import numpy as np
from sklearn.cluster import KMeans


class NoMADAttentionWrapper:
    def __init__(self, model, n_centroids=16, d_sub=64):
        self.model = model
        self.n_centroids = n_centroids
        self.d_sub = d_sub
        self.centroids = None

    def train_centroids(self, calibration_texts):
        # Extract key vectors
        key_vectors = []
        for text in calibration_texts[:50]:  # Small calibration set
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)
                key_vectors.append(outputs.hidden_states[-1].mean(dim=1).numpy())  # [1, d_model]
        key_vectors = np.vstack(key_vectors)  # [n_samples, d_model]

        # Train centroids per sub-quantizer
        S = key_vectors.shape[1] // self.d_sub
        self.centroids = []
        for s in range(S):
            sub_vecs = key_vectors[:, s*self.d_sub : (s+1)*self.d_sub]
            kmeans = KMeans(n_clusters=self.n_centroids).fit(sub_vecs)
            self.centroids.append(kmeans.cluster_centers_)

    def nomad_forward(self, input_ids):
        # Mock: Replace dot-products with lookup-based approximation
        inputs = tokenizer(input_ids, return_tensors="pt", return_attention_mask=False)
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits


In [8]:

nomad_model = NoMADAttentionWrapper(model)
nomad_model.train_centroids(test_texts[:50])

In [9]:
def compute_noMAD_perplexity(nomad_model, tokenizer, texts):
    nlls = []
    for text in texts[:20]:  # Same samples as baseline
        logits = nomad_model.nomad_forward(text)
        # Mock loss calculation (real impl. requires full attention replacement)
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            loss = torch.nn.functional.cross_entropy(
                logits[:, :-1].reshape(-1, logits.shape[-1]),
                inputs["input_ids"][:, 1:].reshape(-1))
            nlls.append(loss.item())
    return np.exp(np.mean(nlls))

nomad_ppl = compute_noMAD_perplexity(nomad_model, tokenizer, test_texts)
print(f"NoMAD Perplexity: {nomad_ppl:.2f}")

NoMAD Perplexity: 70.97


In [10]:
print(f"Original PPL: {orig_ppl:.2f} | NoMAD PPL: {nomad_ppl:.2f}")
print(f"Relative Change: {((nomad_ppl - orig_ppl) / orig_ppl * 100):.1f}%")

Original PPL: 71.02 | NoMAD PPL: 70.97
Relative Change: -0.1%


inference time

In [11]:
import time

def benchmark_inference(model, tokenizer, texts, method="original"):
    times = []
    for text in texts[:20]:  # Same 20 samples
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        start = time.time()
        if method == "original":
            with torch.no_grad():
                _ = model(**inputs)
        else:  # NoMAD
            _ = nomad_model.nomad_forward(text)
        times.append(time.time() - start)
    return np.mean(times) * 1000

In [12]:
# Benchmark
orig_time = benchmark_inference(model, tokenizer, test_texts, "original")
nomad_time = benchmark_inference(nomad_model, tokenizer, test_texts, "nomad")

print(f"Original Attention Time: {orig_time:.1f}ms")
print(f"NoMAD Time: {nomad_time:.1f}ms")
print(f"Speedup: {orig_time / nomad_time:.1f}x")

Original Attention Time: 20244.1ms
NoMAD Time: 20126.3ms
Speedup: 1.0x
