In [None]:
# Reload edited modules (e.g. lplr_llm) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import glog
# Raise the glog threshold to WARN so INFO-level messages are not printed.
glog.setLevel("WARN")

In [None]:
from lplr_llm.activation_aware.weight_compression import *
from lplr_llm.activation_aware.layer_quantization import *

In [None]:
import os

import torch

BASE_MODEL = "meta-llama/Llama-2-7b-hf"
# Presumably the precomputed Hessians are read from here, since every compressor
# below is built with compute_hessians=False.
HESSIAN_SAVE_PATH = "../../data/hessians/llama-2-7b"
# GPU used for quantization. Override with the LPLR_DEVICE environment variable
# instead of editing the notebook; the default keeps the original device.
DEVICE = os.environ.get("LPLR_DEVICE", "cuda:7")
RANK = 128        # passed as rank= to every ActivationAwareQuantParams below
QLR_ITERS = 5     # passed as iters= for the short (verbose) runs
LPLR_ITERS = 10   # passed as lplr_iters= where used
SEED = 0

# Every run below uses rand_svd=True, so fix the RNG to make the error traces
# reproducible across Restart & Run All.
# NOTE(review): assumes lplr_llm's randomized SVD draws from torch's global RNG -- confirm.
torch.manual_seed(SEED)

LAYER = 23
SUBLAYER = TransformerSubLayers.VALUE

### Regular 4B Factors (no Gaussian Transform)

For speed of testing, we use lattice quantization instead of LDLQ (`lattice_quant_LR=True`).

In [None]:
# Baseline: 2-bit Q with 4-bit L/R factors and no Hadamard transform.
# lattice_quant_LR=True matches the note above and the longer run of this section,
# so the short and long traces use the same factor quantizer.
weight_comp_default_4B_factors = ActivationAwareWeightCompressor(
    model_params=ModelParameters(
        base_model=BASE_MODEL
    ),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=ActivationAwareQuantParams(
        Q_bits=2,
        L_bits=4, R_bits=4,
        lattice_quant_LR=True,
        rank=RANK,
        activation_aware_Q=False,
        activation_aware_LR=True,
        hadamard_transform=False,
        compute_quantized_component=True,
        iters=QLR_ITERS,
        lplr_iters=1,
        rand_svd=True,
        update_order=["Q", "LR"],
        verbose=True
    ),
    compute_hessians=False,
    quant_device=DEVICE,
)

In [None]:
# Get the quantizer for layer LAYER and compress SUBLAYER with the settings configured above.
layer_quant = weight_comp_default_4B_factors.get_layer_quantizer(LAYER)
layer_quant.compress_sublayer(SUBLAYER)

In [None]:
# Errors stored under the "LR" key for this sublayer, rounded to 4 decimal places.
default_lr_errors = layer_quant.sublayer_info[SUBLAYER].errors["LR"]
print([round(err, 4) for err in default_lr_errors])

In [None]:
# Presumably the smallest error reached over the Q/LR iterations for SUBLAYER -- confirm in lplr_llm.
layer_quant.min_error(SUBLAYER)

In [None]:
# Plot the errors recorded during compress_sublayer for SUBLAYER (presumably one point per iteration).
layer_quant.plot_errors(SUBLAYER)

In [None]:
# Longer (iters=30), not verbose.
# Bound to its own name so the short-run compressor above is not silently overwritten.
weight_comp_default_4B_factors_longer = ActivationAwareWeightCompressor(
    model_params=ModelParameters(
        base_model=BASE_MODEL
    ),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=ActivationAwareQuantParams(
        Q_bits=2,
        L_bits=4, R_bits=4,
        lattice_quant_LR=True,
        rank=RANK,
        activation_aware_Q=False,
        activation_aware_LR=True,
        hadamard_transform=False,
        compute_quantized_component=True,
        iters=30,
        lplr_iters=1,
        rand_svd=True,
        update_order=["Q", "LR"],
    ),
    compute_hessians=False,
    quant_device=DEVICE,
)
layer_quant_longer = weight_comp_default_4B_factors_longer.get_layer_quantizer(LAYER)
layer_quant_longer.compress_sublayer(SUBLAYER)

In [None]:
# Presumably the smallest error reached over the 30 Q/LR iterations -- confirm in lplr_llm.
layer_quant_longer.min_error(SUBLAYER)

In [None]:
# Plot the errors recorded for SUBLAYER during the longer run.
layer_quant_longer.plot_errors(SUBLAYER)

### 4B Factors with Hadamard

In [None]:
# 4-bit L/R factors with hadamard_transform_L/R enabled (presumably a Hadamard transform
# on the L and R factors only); the top-level hadamard_transform stays False as in the
# first section.
# NOTE(review): unlike the first section, compute_quantized_component is not set here --
# confirm its default is what this comparison needs.
weight_comp_4B_factors_incoh = ActivationAwareWeightCompressor(
    model_params=ModelParameters(
        base_model=BASE_MODEL
    ),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=ActivationAwareQuantParams(
        Q_bits=2,
        L_bits=4, R_bits=4,
        lattice_quant_LR=True,
        rank=RANK,
        activation_aware_Q=False,
        activation_aware_LR=True,
        hadamard_transform=False,
        hadamard_transform_L=True,
        hadamard_transform_R=True,
        iters=QLR_ITERS,  # was the literal 5; use the shared short-run constant
        lplr_iters=1,
        rand_svd=True,
        update_order=["Q", "LR"],
        verbose=True
    ),
    compute_hessians=False,
    quant_device=DEVICE,
)

In [None]:
# Get the quantizer for layer LAYER and compress SUBLAYER with the Hadamard-on-L/R settings.
layer_quant2 = weight_comp_4B_factors_incoh.get_layer_quantizer(LAYER)
layer_quant2.compress_sublayer(SUBLAYER)

In [None]:
# Errors stored under the "LR" key for this sublayer, rounded to 4 decimal places.
hadamard_lr_errors = layer_quant2.sublayer_info[SUBLAYER].errors["LR"]
print([round(err, 4) for err in hadamard_lr_errors])

In [None]:
# Presumably the smallest error reached over the Q/LR iterations -- confirm in lplr_llm.
layer_quant2.min_error(SUBLAYER)

In [None]:
# plot_first_iter=True presumably keeps the first iteration's error in the plot -- confirm.
layer_quant2.plot_errors(SUBLAYER, plot_first_iter=True)

In [None]:
# Longer (iters=30), not verbose.
# Bound to its own name so the short-run compressor above is not silently overwritten.
weight_comp_4B_factors_incoh_longer = ActivationAwareWeightCompressor(
    model_params=ModelParameters(
        base_model=BASE_MODEL
    ),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=ActivationAwareQuantParams(
        Q_bits=2,
        L_bits=4, R_bits=4,
        lattice_quant_LR=True,
        rank=RANK,
        activation_aware_Q=False,
        activation_aware_LR=True,
        hadamard_transform=False,
        hadamard_transform_L=True,
        hadamard_transform_R=True,
        iters=30,
        lplr_iters=1,
        rand_svd=True,
        update_order=["Q", "LR"]
    ),
    compute_hessians=False,
    quant_device=DEVICE,
)
layer_quant_longer2 = weight_comp_4B_factors_incoh_longer.get_layer_quantizer(LAYER)
layer_quant_longer2.compress_sublayer(SUBLAYER)

In [None]:
# Presumably the smallest error reached over the 30 Q/LR iterations -- confirm in lplr_llm.
layer_quant_longer2.min_error(SUBLAYER)

In [None]:
# Plot the errors recorded for SUBLAYER during the longer Hadamard run.
layer_quant_longer2.plot_errors(SUBLAYER)

### 4B Factors with Haar

In [None]:
# Haar experiment, short verbose run. The transform flags mirror this section's longer
# run (incoherence_process_LR=True, Haar_transform_L=True), so the two traces differ
# only in iters/verbose. Previously this cell set Haar_transform_L=False, so the Haar
# transform this section is about was never enabled.
# NOTE(review): Haar_transform_L presumably selects a Haar (instead of Hadamard)
# rotation for L -- confirm against lplr_llm.
weight_comp_4B_factors_incoh_2 = ActivationAwareWeightCompressor(
    model_params=ModelParameters(
        base_model=BASE_MODEL
    ),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=ActivationAwareQuantParams(
        Q_bits=2,
        L_bits=4, R_bits=4,
        lattice_quant_LR=True,
        rank=RANK,
        activation_aware_Q=False,
        activation_aware_LR=True,
        hadamard_transform=True,
        incoherence_process_LR=True,
        Haar_transform_L=True,
        iters=QLR_ITERS,
        lplr_iters=LPLR_ITERS,
        rand_svd=True,
        update_order=["Q", "LR"],
        verbose=True
    ),
    compute_hessians=False,
    quant_device=DEVICE,
)

In [None]:
# Get the quantizer for layer LAYER and compress SUBLAYER with the Haar-section settings.
layer_quant3 = weight_comp_4B_factors_incoh_2.get_layer_quantizer(LAYER)
layer_quant3.compress_sublayer(SUBLAYER)

In [None]:
# Presumably the smallest error reached over the Q/LR iterations -- confirm in lplr_llm.
layer_quant3.min_error(SUBLAYER)

In [None]:
# Errors stored under the "LR" key for this sublayer, rounded to 4 decimal places.
haar_lr_errors = layer_quant3.sublayer_info[SUBLAYER].errors["LR"]
print([round(err, 4) for err in haar_lr_errors])

In [None]:
# plot_first_iter=True presumably keeps the first iteration's error in the plot -- confirm.
layer_quant3.plot_errors(SUBLAYER, plot_first_iter=True)

In [None]:
# Longer (iters=50), not verbose.
# Bound to its own name so the short-run compressor above is not silently overwritten.
weight_comp_4B_factors_incoh_2_longer = ActivationAwareWeightCompressor(
    model_params=ModelParameters(
        base_model=BASE_MODEL
    ),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=ActivationAwareQuantParams(
        Q_bits=2,
        L_bits=4, R_bits=4,
        lattice_quant_LR=True,
        rank=RANK,
        activation_aware_Q=False,
        activation_aware_LR=True,
        hadamard_transform=True,
        incoherence_process_LR=True,
        Haar_transform_L=True,
        iters=50,
        lplr_iters=LPLR_ITERS,
        rand_svd=True,
        update_order=["Q", "LR"],
        verbose=False
    ),
    compute_hessians=False,
    quant_device=DEVICE,
)
layer_quant_longer3 = weight_comp_4B_factors_incoh_2_longer.get_layer_quantizer(LAYER)
layer_quant_longer3.compress_sublayer(SUBLAYER)

In [None]:
# Presumably the smallest error reached over the 50 Q/LR iterations -- confirm in lplr_llm.
# NOTE(review): unlike the other sections, no plot_errors cell follows this longer run.
layer_quant_longer3.min_error(SUBLAYER)

### For comparison: 16B Factors

In [None]:
# Reference point: 2-bit Q with unquantized-width (16-bit) L/R factors, no lattice
# quantization and no incoherence processing of the factors.
weight_comp_16B_factors = ActivationAwareWeightCompressor(
    model_params=ModelParameters(base_model=BASE_MODEL),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=ActivationAwareQuantParams(
        # bit widths and rank
        Q_bits=2, L_bits=16, R_bits=16, rank=RANK,
        # quantizer / activation-awareness choices
        lattice_quant_LR=False,
        activation_aware_Q=False, activation_aware_LR=True,
        # transforms
        incoherence_process_LR=False,
        # optimisation schedule
        update_order=["Q", "LR"],
        iters=50, lplr_iters=1,
        rand_svd=True,
    ),
    compute_hessians=False,
    quant_device=DEVICE,
)

In [None]:
# Get the quantizer for layer LAYER and compress SUBLAYER with the 16-bit-factor settings.
layer_quant4 = weight_comp_16B_factors.get_layer_quantizer(LAYER)
layer_quant4.compress_sublayer(SUBLAYER)

In [None]:
# Presumably the smallest error reached over the Q/LR iterations -- confirm in lplr_llm.
layer_quant4.min_error(SUBLAYER)

In [None]:
# plot_first_iter=True presumably keeps the first iteration's error in the plot -- confirm.
layer_quant4.plot_errors(SUBLAYER, plot_first_iter=True)

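### What about 2B Factors?

Same 2-bit `Q` as above, but with 2-bit `L` and `R` factors. Note that, unlike the earlier runs, this configuration also sets `activation_aware_Q=True`.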

In [None]:
# 2-bit Q with 2-bit L/R factors. Unlike the 4B/16B runs, activation_aware_Q is True here.
weight_comp_2B_factors_incoh = ActivationAwareWeightCompressor(
    model_params=ModelParameters(base_model=BASE_MODEL),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=ActivationAwareQuantParams(
        # bit widths and rank
        Q_bits=2, L_bits=2, R_bits=2, rank=RANK,
        # quantizer / activation-awareness choices
        lattice_quant_LR=True,
        activation_aware_Q=True, activation_aware_LR=True,
        # transforms
        hadamard_transform=True, incoherence_process_LR=True,
        # optimisation schedule
        update_order=["Q", "LR"],
        iters=QLR_ITERS, lplr_iters=LPLR_ITERS,
        rand_svd=True,
    ),
    compute_hessians=False,
    quant_device=DEVICE,
)

In [None]:
layer_quant4 = weight_comp_2B_factors_incoh.get_layer_quantizer(LAYER)
layer_quant4.compress_sublayer(SUBLAYER)

In [None]:
layer_quant4.min_error(SUBLAYER)

In [None]:
layer_quant4.plot_errors(SUBLAYER, plot_first_iter=True)