In [None]:
from caldera.decomposition.layer_quantization import *
from caldera.decomposition.weight_compression import *
import torch
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the caldera package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# --- Model, saved Hessians and device ---
BASE_MODEL = "meta-llama/Llama-2-7b-hf"                 # passed to ModelParameters(base_model=...)
HESSIAN_SAVE_PATH = "../data/Hessians-Llama-2-7b-6144"  # passed as hessian_save_path
DEVICE = "cuda:0"                                       # passed as quant_device

# --- Layer / sublayer to compress ---
LAYER = 25                            # argument to get_layer_quantizer(...)
SUBLAYER = TransformerSubLayers.GATE  # argument to compress_sublayer / plot_errors / export

# --- Decomposition hyperparameters (forwarded to CalderaParams) ---
RANK = 256        # rank=
QLR_ITERS = 5     # iters=
LPLR_ITERS = 10   # lplr_iters=

In [None]:
# Used for both L_bits and R_bits of CalderaParams in the next cell.
LR_BITS = 4
# Forwarded as CalderaParams(Q_hessian_downdate=...).
DOWNDATE = False

In [None]:
# Step 1: instantiate the weight compressor from the model, data and
# CALDERA quantization settings chosen in the configuration cells.
weight_comp = ActivationAwareWeightCompressor(
    model_params=ModelParameters(base_model=BASE_MODEL),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=CalderaParams(
        # Rank and bit settings (L and R share LR_BITS).
        rank=RANK,
        Q_bits=2,
        L_bits=LR_BITS,
        R_bits=LR_BITS,
        lattice_quant_LR=True,
        # Activation-aware / transform switches.
        activation_aware_Q=True,
        activation_aware_LR=True,
        hadamard_transform=True,
        Q_hessian_downdate=DOWNDATE,
        # Iteration settings; the update order lists "LR" before "Q".
        update_order=["LR", "Q"],
        iters=QLR_ITERS,
        lplr_iters=LPLR_ITERS,
        rand_svd=True,
    ),
    # NOTE(review): with compute_hessians=False the Hessians are presumably
    # loaded from HESSIAN_SAVE_PATH instead of being recomputed — confirm.
    compute_hessians=False,
    quant_device=DEVICE,
)

In [None]:
# Then, get the layer quantizer for the particular layer you want to quantize
# (LAYER is set in the configuration cell above).
layer_quant = weight_comp.get_layer_quantizer(LAYER)

In [None]:
# Compress the sublayer selected by SUBLAYER (set in the configuration cell).
# NOTE(review): presumably this is the step that records the per-iteration
# errors inspected in the cells below — confirm.
layer_quant.compress_sublayer(SUBLAYER)

In [None]:
# Plot the errors for the compressed sublayer, using the default plot options.
layer_quant.plot_errors(SUBLAYER)

"LR" refers to the Frobenius norm error after the LPLR step, and "Q" refers to the Frobenius norm error after the LDLQ step. The first-iteration error for LR will be high, since Q is still set to zero, so you have the option of omitting the first iteration when plotting.

In [None]:
# Same plot as above, but with the first iteration omitted (plot_first_iter=False).
layer_quant.plot_errors(SUBLAYER, plot_first_iter=False)

Here is how to get the error arrays if you want to plot errors for different quantization parameters on the same plot.

In [None]:
# Fetch the errors stored under this sublayer's `caldera` entry; the bare
# name on the last line displays them as the cell output.
errors = layer_quant.sublayer_info[SUBLAYER].caldera.errors
errors

You can also export the errors as a JSON file.

In [None]:
# Output path handed to export_errors_json for this sublayer's errors.
OUTFILE = "errors.json" # change this!
layer_quant.export_errors_json(SUBLAYER, OUTFILE)