In this notebook, we run experiments on quantizing KV vectors and LLMs using the QuaRot framework (https://github.com/spcl/QuaRot/tree/main).

We made some modifications to the QuaRot repository; our version is stored here: https://github.com/tianhua2/my_quarot.git

We also edited the Llama code in Hugging Face Transformers to add KV vector quantization and eviction. Our modified version is stored as a separate package here: https://github.com/tianhua2/my_huggingface.git
The KV vector eviction code is based on H2O (https://github.com/FMInference/H2O), with some modifications.

Environment Settings.

In [None]:
# Environment setup (Colab): authenticate with Hugging Face, then install the
# modified transformers fork, lm-evaluation-harness, the QuaRot fork, and the
# remaining Python dependencies.
#
# SECURITY: an access token must never be hardcoded in the notebook. It is read
# from the HF_TOKEN environment variable, or prompted for when that is unset.
# Any token that was previously committed in this cell should be revoked.
import os
from getpass import getpass

if not os.environ.get("HF_TOKEN"):
    os.environ["HF_TOKEN"] = getpass("Hugging Face access token: ")
!huggingface-cli login --token "$HF_TOKEN"

# Modified Hugging Face package (adds KV quantization / eviction to the Llama code).
%cd /content/
!rm -rf my_huggingface
!git clone https://github.com/tianhua2/my_huggingface.git
%cd /content/my_huggingface
%pip install .
%cd /content/

# lm-evaluation-harness; remove any stale checkout first so the cell can be re-run.
!rm -rf lm-evaluation-harness
!git clone https://github.com/EleutherAI/lm-evaluation-harness
%cd lm-evaluation-harness
%pip install .
%cd /content/

# Modified QuaRot repository.
!rm -rf my_quarot
!git clone https://github.com/tianhua2/my_quarot.git
%cd /content/my_quarot
%pip install .
%cd /content/

# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install datasets
%pip install evaluate
%pip install -U accelerate
%pip install --upgrade pyarrow
import numpy as np
import pandas as pd
import pickle

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity

import matplotlib.pyplot as plt

from bokeh.io import show
from bokeh.plotting import figure
from bokeh.layouts import column
from bokeh.io import output_notebook
from bokeh.palettes import Category20# select a palette
import itertools
output_notebook()

import evaluate
from evaluate import load

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

from transformers import TrainingArguments
from transformers import AutoModelForSequenceClassification
from transformers import Trainer
import numpy as np

# Copy llama2_7b_w4a4 from Google Drive to local storage (presumably a saved
# W4A4-quantized Llama-2-7B model, judging by the file name — confirm).
# NOTE(review): the source path is under /content/drive, so Drive presumably has to
# be mounted before this cell runs; no mount step is visible in this notebook.
!cp /content/drive/MyDrive/kv_cache/llama2_7b_w4a4 /content/llama2_7b_w4a4

# Switch to the fake_quant directory, from which the later cells invoke main.py.
%cd /content/my_quarot/fake_quant/

llama2-7b fp16

In [None]:
# Baseline: Llama-2-7B with all bit-widths set to 16, evaluated on c4 and four lm-eval tasks.
%cd /content/my_quarot/fake_quant/
!python main.py --model meta-llama/Llama-2-7b-hf \
    --a_bits 16 --v_bits 16 --k_bits 16 --w_bits 16 \
    --bsz 4 --w_clip \
    --lm_eval --lm_eval_batch_size 16 \
    --tasks "piqa" "arc_easy" "arc_challenge" "winogrande" \
    --eval_dataset 'c4'

llama2-7b W8A8KV4

In [None]:
# Llama-2-7B with 8-bit weights and activations (--w_bits 8, --a_bits 8),
# 4-bit KV settings (--KV_BITS1..4 4), and the --H2O option enabled.
# Fix: the last task was misspelled "lamabada"; the lm-evaluation-harness
# task name is "lambada".
%cd /content/my_quarot/fake_quant/
!python main.py --model meta-llama/Llama-2-7b-hf --rotate --a_bits 8 --v_bits 16 --k_bits 16 --w_bits 8 --bsz 1 --w_clip \
--DYNQ --KRON --KV_BITS1 4 --KV_BITS2 4 --KV_BITS3 4 --KV_BITS4 4 --heavy_budget_ratio1 0.15 --heavy_budget_ratio2 0.3 --heavy_budget_ratio3 0.5 \
--TH_H 1e-3 --TH_L 1e-2 --TH_H2 1e-3 --TH_L2 1e-2 --TH_H3 1e-3 --TH_L3 1e-2 --TH_H4 1e-3 --TH_L4 1e-2 \
--H2O --heavy_budget_ratio 0.3 --recent_budget_ratio 0.1 --score_coeff 0 --CACHE_SIZE 1024 \
--lm_eval --lm_eval_batch_size 16 --tasks "piqa" "arc_easy" "arc_challenge" "winogrande" "lambada"

llama2-13b fp16

In [None]:
%cd /content/my_quarot/fake_quant/
!python main.py --model meta-llama/Llama-2-13b-hf --a_bits 16 --v_bits 16 --k_bits 16 --w_bits 16 --bsz 1 --w_clip \
--lm_eval --lm_eval_batch_size 16 --tasks "piqa" "arc_easy" "arc_challenge" "winogrande"

llama2-13b W8A8KV4

In [None]:
%cd /content/my_quarot/fake_quant/
!python main.py --model meta-llama/Llama-2-13b-hf --rotate --a_bits 8 --v_bits 16 --k_bits 16 --w_bits 8 --bsz 1 --w_clip \
--DYNQ --KRON --KV_BITS1 4 --KV_BITS2 4 --KV_BITS3 4 --KV_BITS4 4 --heavy_budget_ratio1 0.15 --heavy_budget_ratio2 0.3 --heavy_budget_ratio3 0.5\
--TH_H 1e-3 --TH_L 1e-2 --TH_H2 1e-3 --TH_L2 1e-2 --TH_H3 1e-3 --TH_L3 1e-2 --TH_H4 1e-3 --TH_L4 1e-2\
--H2O --heavy_budget_ratio 0.3 --recent_budget_ratio 0.1 --score_coeff 0 --CACHE_SIZE 1024\
--lm_eval --lm_eval_batch_size 16 --tasks "piqa" "arc_easy" "arc_challenge" "winogrande" "lamabada"

llama3-8b fp16

In [None]:
%cd /content/my_quarot/fake_quant/
!python main.py --model meta-llama/Meta-Llama-3-8B --a_bits 16 --v_bits 16 --k_bits 16 --w_bits 16 --bsz 1 --w_clip \
--lm_eval --lm_eval_batch_size 16 --tasks "piqa" "arc_easy" "arc_challenge" "winogrande"

llama3-8b W4A8KV4

In [None]:
%cd /content/my_quarot/fake_quant/
!python main.py --model meta-llama/Meta-Llama-3-8B --rotate --a_bits 8 --v_bits 16 --k_bits 16 --w_bits 4 --bsz 1 --w_clip\
--DYNQ --KRON --KV_BITS1 4 --KV_BITS2 4 --KV_BITS3 4 --KV_BITS4 4 --heavy_budget_ratio1 0.03 --heavy_budget_ratio2 0.15 --heavy_budget_ratio3 0.8\
--TH_H 1e-3 --TH_L 1e-2 \
--H2O --heavy_budget_ratio 0.3 --recent_budget_ratio 0.1 --score_coeff 0 --CACHE_SIZE 256\
--lm_eval --lm_eval_batch_size 16 --tasks "piqa" "arc_easy" "arc_challenge" "winogrande" \
--load_qmodel_path "/content/llama3_8b_w4"