# Qwen2-0.5B Quantization Experiments

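This notebook measures how quantizing hidden-state activations affects the perplexity of Qwen2-0.5B on the WikiText-2 test split. A simulated symmetric INT4 quantization is applied, at a chosen decoder layer, to the token positions ranked least important by several attention-based importance strategies, and the resulting perplexity is compared across strategies, layers, and quantization ratios.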

In [None]:
# Install a pinned transformers release. The wrapper class defined later in this
# notebook calls the decoder layers and rotary embedding directly, which depends on
# library internals — presumably the reason for the pin (TODO confirm).
!pip install transformers==4.53



In [None]:
# Upgrade `datasets` to the newest available release.
# NOTE(review): this install is unpinned, so a later re-run may pick up a different
# version than the one these results were produced with — consider pinning.
!pip install -U datasets



In [None]:
# Requisite imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import json
import pickle
from tqdm import tqdm
from datasets import load_dataset
import torch
from google.colab import files
import math

In [None]:
# Evaluation corpus: the raw WikiText-2 test split.
wikitext = load_dataset(
    "Salesforce/wikitext",
    "wikitext-2-raw-v1",
    split="test",
)



In [None]:
# Tokenizer belonging to the checkpoint that is evaluated below.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-0.5B')

# Concatenate every row of the split into one document (rows separated by a blank
# line) and tokenize it in a single call; the evaluation loop later slices
# windows out of this one long tensor.
corpus_text = "\n\n".join(wikitext["text"])
encodings = tokenizer(corpus_text, return_tensors="pt")

Token indices sequence length is longer than the specified maximum sequence length for this model (299078 > 32768). Running this sequence through the model will result in indexing errors


In [None]:
class QwenPointFiveBModel:
    """Qwen2-0.5B wrapper that runs the decoder stack one layer at a time.

    Running the layers manually makes it possible to overwrite the hidden states
    of selected token positions after a chosen layer with a quantize-dequantize
    round trip, and then measure the effect on the next-token loss.
    """

    def __init__(self, device) -> None:
        """Load the model and tokenizer and move the model to ``device``."""
        self.model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2-0.5B')
        tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-0.5B')
        self.tokenizer = tokenizer
        self.model.to(device)
        self.device = device
        self.num_layers = len(self.model.base_model.layers)
        # Shortcuts to the sub-modules used by the manual forward pass.
        self.qwen = self.model.base_model
        self.final_layer_norm = self.model.model.norm
        self.rotary_emb = self.qwen.rotary_emb

    def move_to_device(self):
        """Move the model to the device given at construction time."""
        self.model.to(self.device)

    def remove_from_device(self):
        """Move the model to the CPU (e.g. to free accelerator memory)."""
        self.model.to("cpu")

    @staticmethod
    def _fake_quantize(values, num_bits=4):
        """Symmetric integer quantize-dequantize round trip with one shared scale.

        The scale is the largest absolute value in ``values``; values are mapped to
        the signed integer grid of ``num_bits`` bits, rounded, and mapped back.

        NOTE: this is a simplified simulation of symmetric integer quantization;
        real quantization involves more nuances.

        Args:
            values: floating-point tensor to quantize.
            num_bits: bit width of the simulated signed integer type (>= 2).

        Returns:
            A tensor of the same shape as ``values``. ``values`` itself is returned
            when it is empty or all zero, since there is nothing to scale.

        Raises:
            ValueError: if ``num_bits`` is smaller than 2.
        """
        if num_bits < 2:
            raise ValueError(f"num_bits must be >= 2, got {num_bits}")
        # torch.max raises on an empty tensor, so bail out before computing the scale.
        if values.numel() == 0:
            return values
        max_val = values.abs().max()
        # An all-zero slice would otherwise divide by zero and turn into NaN.
        if max_val == 0:
            return values
        q_max = 2 ** (num_bits - 1) - 1      # e.g. 7 for 4 bits
        q_min = -(2 ** (num_bits - 1))       # e.g. -8 for 4 bits
        scaled = torch.clamp(values / max_val * q_max, q_min, q_max)
        return torch.round(scaled) / q_max * max_val

    def _run_and_score(self, batched_input_tokens, target, quantize_at=None,
                       token_positions=None, num_bits=4):
        """Run the decoder layer by layer and return the mean next-token loss.

        Args:
            batched_input_tokens: token ids, shape (batch, seq_len).
            target: label ids aligned with the inputs; entries equal to -100 are
                ignored by the loss.
            quantize_at: index of the layer whose output is fake-quantized, or
                ``None`` to run the model unmodified.
            token_positions: 1-D index tensor of sequence positions to quantize
                (only read when ``quantize_at`` is not ``None``).
            num_bits: bit width forwarded to ``_fake_quantize``.

        Returns:
            Scalar tensor: cross entropy averaged over the non-ignored targets.
        """
        batched_input_tokens = batched_input_tokens.to(self.device)

        with torch.no_grad():
            hidden_states = self.qwen.embed_tokens(batched_input_tokens)
            batch_size, seq_len = batched_input_tokens.shape
            # Position ids 0 .. seq_len-1, repeated for every row of the batch.
            position_ids = torch.arange(
                seq_len, dtype=torch.long, device=batched_input_tokens.device
            ).unsqueeze(0).expand(batch_size, -1)
            position_embeddings = self.rotary_emb(hidden_states, position_ids)

            for i in range(self.num_layers):
                layer = self.qwen.layers[i]
                hidden_states = layer(hidden_states, position_embeddings=position_embeddings)[0]
                if quantize_at is not None and i == quantize_at:
                    selected = hidden_states[:, token_positions, :]
                    hidden_states[:, token_positions, :] = self._fake_quantize(selected, num_bits)

            post_norm = self.final_layer_norm(hidden_states)
            logits = self.model.lm_head(post_norm)
            # Targets are the inputs themselves, so drop the first target and the
            # last logit to line up "prediction at t" with "token at t+1".
            shifted_targets = target[:, 1:].to(logits.device)
            shifted_logits = logits[:, :-1, :]

            # reshape (not view): the slices above are non-contiguous once batch > 1.
            cross_entropy = torch.nn.functional.cross_entropy(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_targets.reshape(-1),
                ignore_index=-100,
            )

        return cross_entropy

    def activation_quantization(self, batched_input_tokens, target, importance_values, ratio,
                                layer_of_interest, num_bits=4):
        """Loss after fake-quantizing the least important tokens at one layer.

        Args:
            batched_input_tokens: token ids, shape (batch, seq_len).
            target: label ids aligned with the inputs (-100 entries are ignored).
            importance_values: 1-D tensor with one score per sequence position;
                lower means less important.
            ratio: fraction of positions to quantize. ``ratio <= 0``, or a ratio so
                small that it selects zero positions, leaves the model untouched.
            layer_of_interest: index of the decoder layer whose output is quantized.
            num_bits: simulated integer bit width (default 4, i.e. INT4).

        Returns:
            Scalar tensor with the mean negative log likelihood.
        """
        num_selected = int(ratio * batched_input_tokens.size(1))
        if ratio <= 0 or num_selected <= 0:
            return self._run_and_score(batched_input_tokens, target)

        # E.g. ratio 0.1 picks the 10% of positions with the lowest importance.
        least_important_token_positions = torch.argsort(
            importance_values, descending=False
        )[:num_selected].to(self.device)
        return self._run_and_score(
            batched_input_tokens,
            target,
            quantize_at=layer_of_interest,
            token_positions=least_important_token_positions,
            num_bits=num_bits,
        )

    def layer_by_layer_impl(self, batched_input_tokens, target):
        """Loss of the unmodified model, computed with the manual layer loop.

        Args:
            batched_input_tokens: token ids, shape (batch, seq_len).
            target: label ids aligned with the inputs (-100 entries are ignored).

        Returns:
            Scalar tensor with the mean negative log likelihood.
        """
        return self._run_and_score(batched_input_tokens, target)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing when the model is moved.
device = "cuda" if torch.cuda.is_available() else "cpu"
qwen_model = QwenPointFiveBModel(device)

In [None]:
# Fractions of token positions whose activations get quantized (0 = baseline, 1 = all).
ratios = [0., 0.25, 0.5, 0.75, 1]

# Sample the layers to study from a seeded generator so that a kernel restart
# reproduces the same experiment instead of silently picking different layers.
LAYER_SEED = 0
NUM_MODEL_LAYERS = 24        # the loaded Qwen2-0.5B model has 24 decoder layers
NUM_LAYERS_TO_SAMPLE = 5
layers_of_interest = np.random.default_rng(LAYER_SEED).choice(
    NUM_MODEL_LAYERS, size=NUM_LAYERS_TO_SAMPLE, replace=False
)
print(layers_of_interest)

# Importance strategies understood by `get_importance_order`.
methods = ["regular_importance", "weighted_importance", "last_row", "aggregate_till"]


[22 18  3 23 11]


In [None]:
# Load the weights
# Read the pickled attention-head weights from the working directory.
# NOTE(review): pickle.load can execute arbitrary code — only load a file you trust.
with open('attention_head_weights.pkl', 'rb') as weights_file:
    attention_head_weights = pickle.load(weights_file)

# Sanity check: length of the outer container, then of its first entry.
print(len(attention_head_weights))
print(len(attention_head_weights[0]))

24
14


In [None]:
def get_importance_order(method, attention_map, num_layers, attention_head_weights):
  """Compute one importance score per token position for every layer.

  Args:
    method: one of "regular_importance", "weighted_importance", "last_row",
      "aggregate_till".
    attention_map: per-layer sequence of attention tensors, each indexed as
      batch x heads x seq x seq.
    num_layers: number of leading layers of `attention_map` to process.
    attention_head_weights: indexed as [layer][head]; each entry is multiplied
      into that head's attention map (only read by "weighted_importance").

  Returns:
    A list with one tensor per layer holding the scores for that layer. The
    batch dimension is squeezed away when it has size 1.

  Raises:
    ValueError: if `method` is not one of the names above (raised while
      processing the first layer).
  """
  per_layer_scores = []
  running_total = 0  # accumulator for "aggregate_till"

  for layer_idx in range(num_layers):
    if method == "last_row":
      # Attention paid by the final query position, averaged over heads.
      final_query_rows = attention_map[layer_idx][:, :, -1, :]
      per_layer_scores.append(final_query_rows.mean(dim=1).squeeze(0))

    elif method == "aggregate_till":
      # Head-averaged map -> drop batch -> column-wise mean, then running mean
      # over all layers seen so far.
      head_mean = attention_map[layer_idx].mean(dim=1).squeeze(0)
      running_total = running_total + head_mean.mean(dim=0)
      per_layer_scores.append(running_total / (layer_idx + 1))

    elif method == "regular_importance":
      # Average over heads (-> batch x seq x seq), then over rows so every
      # column gets its mean received attention (-> batch x seq).
      head_mean = attention_map[layer_idx].mean(dim=1)
      per_layer_scores.append(head_mean.mean(dim=1).squeeze(0))

    elif method == "weighted_importance":
      heads = attention_map[layer_idx]
      head_weights = attention_head_weights[layer_idx]
      # Weighted sum of the per-head maps, accumulated in place.
      weighted_total = torch.zeros_like(heads[:, 0, :, :])
      for head_idx in range(heads.shape[1]):
        weighted_total += heads[:, head_idx, :, :] * head_weights[head_idx]
      # Column-wise mean of the combined map.
      per_layer_scores.append(weighted_total.mean(dim=1).squeeze(0))

    else:
      raise ValueError(f"Unknown method: {method}")

  return per_layer_scores

In [None]:
# Second copy of the model that uses the eager attention implementation; the
# evaluation loop requests attention maps from this copy (output_attentions=True).
enthu_qwen = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen2-0.5B',
    attn_implementation="eager",
).eval()
enthu_qwen.to(device)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

In [None]:
# Sliding-window evaluation: each window holds up to `max_length` tokens and
# advances by `stride`; only the tokens not scored by an earlier window
# contribute to the loss (the rest serve as context).
max_length = 512
stride = 32
seq_len = encodings.input_ids.size(1)
prev_end_loc = 0

total_num_layers = qwen_model.num_layers

# total_nll[method][layer][ratio] accumulates the summed negative log likelihood.
total_nll = [[[0 for _ in range(len(ratios))] for __ in range(len(layers_of_interest))] for ___ in range(len(methods))]
n_tokens = 0

iterations = 0


def save_progress(nll_totals, token_count):
    """Write the running totals to disk so a partial run can still be analysed."""
    with open('total_nll.pkl', 'wb') as f:
        pickle.dump(nll_totals, f)
    with open('n_tokens.pkl', 'wb') as f:
        pickle.dump(token_count, f)


for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # Mask the context tokens that an earlier window already scored.
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        # Eager attention is needed to obtain the attention maps used for the
        # importance scores; the quantization passes below use the other model copy.
        base_output = enthu_qwen(input_ids=input_ids, output_attentions=True)
        attention_map = base_output.attentions

    # Importance scores are computed once per method for this window.
    importance_values_dict = {
        method: get_importance_order(method, attention_map, total_num_layers, attention_head_weights)
        for method in methods
    }

    # The returned loss is a mean over the non-ignored entries of target_ids[:, 1:]
    # (the model wrapper drops the first label). Count exactly those entries:
    # subtracting one per row is only right for the first window, where the
    # dropped label is a real token rather than an already-masked -100.
    num_loss_tokens = (target_ids[:, 1:] != -100).sum().item()

    # ratio <= 0 quantizes nothing, so the result is the same for every method
    # and layer: run the unmodified model once per window.
    baseline_nll = qwen_model.layer_by_layer_impl(input_ids, target_ids).item()
    # ratio >= 1 quantizes every position, so the importance ranking (and thus
    # the method) is irrelevant: compute once per layer and reuse.
    full_quantization_nll = {}

    for m, method in enumerate(methods):
        importance_values = importance_values_dict[method]
        for l, layer_of_interest in enumerate(layers_of_interest):
            for r, ratio in enumerate(ratios):
                if ratio <= 0:
                    nll = baseline_nll
                elif ratio >= 1:
                    if l not in full_quantization_nll:
                        full_quantization_nll[l] = qwen_model.activation_quantization(
                            input_ids, target_ids, importance_values[layer_of_interest],
                            ratio, layer_of_interest).item()
                    nll = full_quantization_nll[l]
                else:
                    nll = qwen_model.activation_quantization(
                        input_ids, target_ids, importance_values[layer_of_interest],
                        ratio, layer_of_interest).item()
                # Convert the per-token mean back into a sum for this window.
                total_nll[m][l][r] += (nll * num_loss_tokens)

    n_tokens += num_loss_tokens

    prev_end_loc = end_loc

    if iterations % 1000 == 0:
        print(f"Processed {iterations} chunks")
        print(f"Total NLL: {total_nll}")
        print(f"Total tokens: {n_tokens}")
        # Save the total_nll and n_tokens to a file
        save_progress(total_nll, n_tokens)

    iterations += 1

    if end_loc == seq_len:
        break

# Persist the final totals as well; the periodic checkpoint above would otherwise
# miss every window processed after the last multiple of 1000.
save_progress(total_nll, n_tokens)


  0%|          | 1/9347 [00:16<41:36:40, 16.03s/it]

Processed 0 chunks
Total NLL: [[[1273.4980659484863, 1777.0852043628693, 2095.0119886398315, 2431.899836063385, 2731.8689646720886], [1273.4980659484863, 1524.619915008545, 1758.036422252655, 1927.6400637626648, 8231.311679840088], [1273.4980659484863, 1293.243848323822, 1309.5264372825623, 1364.405492067337, 7810.692914009094], [1273.4980659484863, 2196.9604301452637, 2999.7950868606567, 3815.591682434082, 4426.585940361023], [1273.4980659484863, 1293.444261789322, 1323.2758975028992, 1337.4595665931702, 6552.683093070984]], [[1273.4980659484863, 1835.323408126831, 2157.7349462509155, 2433.209041595459, 2731.8689646720886], [1273.4980659484863, 1422.7816095352173, 1731.9001915454865, 1943.0374221801758, 8231.311679840088], [1273.4980659484863, 1289.2568995952606, 1325.2110753059387, 1365.5066087245941, 7810.692914009094], [1273.4980659484863, 2249.862518787384, 3092.446536540985, 3869.968180656433, 4426.585940361023], [1273.4980659484863, 1300.781465768814, 1318.8445060253143, 1372.36

 11%|█         | 1001/9347 [4:32:37<37:54:16, 16.35s/it]

Processed 1000 chunks
Total NLL: [[[81561.27448298037, 150374.7779841423, 166651.6217069626, 174516.14770436287, 187345.64291787148], [81561.27448298037, 100817.95765119791, 113576.13408380747, 123648.90030604601, 507172.86792850494], [81561.27448298037, 83283.25349000096, 85354.20199272037, 88933.27927100658, 503339.1490621567], [81561.27448298037, 241659.6267094612, 280371.6874599457, 291834.346244812, 298685.15542411804], [81561.27448298037, 84788.90491971374, 87219.74526897073, 89662.54157668352, 397787.1414384842]], [[81561.27448298037, 145163.03697133064, 164832.23598647118, 174532.04102754593, 187345.64291787148], [81561.27448298037, 94267.92607817054, 110569.03363758326, 123383.88972973824, 507172.86792850494], [81561.27448298037, 82871.2812808752, 85049.00743149221, 88847.55956968665, 503339.1490621567], [81561.27448298037, 211544.75926089287, 265814.05832099915, 291390.39578294754, 298685.15542411804], [81561.27448298037, 84813.60305318236, 87409.80615136027, 89925.9847141206

 15%|█▍        | 1364/9347 [6:11:29<36:11:46, 16.32s/it]

In [None]:
# Perplexity = exp(summed NLL / number of scored tokens), computed for every
# (method, layer, ratio) cell while keeping the nesting of `total_nll`.
avg_ppl_results = [
    [
        [math.exp(cell_nll / n_tokens) for cell_nll in layer_row]
        for layer_row in method_rows
    ]
    for method_rows in total_nll
]

In [None]:
# Show the perplexities as a nested list indexed [method][layer][ratio].
print(avg_ppl_results)