In [None]:
from phantominator import shepp_logan
from matplotlib import pyplot as plt
import numpy as np

import torch

from loguru import logger
import pathlib
from natsort import natsorted, ns

import pandas as pd

from transformers import AutoModelForCausalLM

import csv

### Notes 01/18
- Hadamard transform during finetuning
- Change the optimization scheme for LoftQ --- try a projected ADMM algorithm (usually does better than just alternating optimization)
- Q + LR + Sparse, optimize via ADMM. Make sparsity a constraint instead of an $\ell_1$ term, with a pre-determined sparsity level

For quant: $$\min_{Q \in \mathcal{C},\; Q' \in \mathbb{R}^{n\times m},\; L,\, R} \lVert W - Q' - LR \rVert_F \quad \text{s.t. } Q'=Q$$ (here $W$ denotes the original weight matrix)

- Use the spectrum to determine which scheme to use

### Notes 01/11
- Preventing Diverging LoftQ optimization w/ "momentum":
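$B_{k+1} = Q((1-\alpha) B_k + \alpha \hat{B}_{k+1})$, where $\hat{B}_{k+1}$ is the un-damped update (with $B_k$ in both terms the step would collapse to $B_{k+1} = Q(B_k)$)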
- Question: does NF work or not; convergence issue
- Try sketching
- Can rotate $L$ and $R$ with some unitary matrix $H$ that makes the data more NF-quantizable (more Gaussian)
- Add F-norm regularization on $L$, $R$
- Apply some normalization and shifting
- Sparse $Q$
- Make $Q$ have a Kronecker structure ($\exists$ papers on this)
- Replace $Q$ with $QD$, where $D$ is full-precision and diagonal
- Try changing the objective function (type of norm)?
- Use data-aware optimization: right-multiply by a batch of training data (see: GPTQ)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from lplr_llm.quantization import *
from lplr_llm.weight_compressors import *
from lplr_llm.hyperparameter_sweeps import *
from lplr_llm.benchmarkers import *
from peft.utils.loftq_utils import loftq_init
from peft.utils.loftq_lplr_utils import loftq_lplr_init
from peft.utils.quantization_utils import NFQuantizerFactory
from lplr_llm.enums import *

In [None]:
# Device used for every experiment in this notebook. The preferred device is
# GPU index 2; when that GPU is not present (fewer CUDA devices, or none at
# all) fall back to the CPU instead of failing at the first `.to(...)` call.
_PREFERRED_GPU_INDEX = 2
if torch.cuda.is_available() and torch.cuda.device_count() > _PREFERRED_GPU_INDEX:
    DEFAULT_DEVICE = f"cuda:{_PREFERRED_GPU_INDEX}"
else:
    DEFAULT_DEVICE = "cpu"
    print(f"cuda:{_PREFERRED_GPU_INDEX} is not available; falling back to CPU.")

In [None]:
# Load the pretrained Mistral-7B base checkpoint. Its weight matrices are the
# inputs for the compression experiments in the cells below.
mistral = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

In [None]:
# Print the name and shape of every parameter tensor in mistral.
# The shape is read straight from the parameter: copying each tensor to the
# experiment device first is unnecessary work just to print its shape.
for layer_name, layer_weight in mistral.named_parameters():
    print(f"Layer name: {layer_name}: Shape:{layer_weight.shape}")

In [None]:
# Use the parameter at index 10 as the running example matrix, keeping a
# detached copy of it on the experiment device.
named_params = list(mistral.named_parameters())
layer_name = named_params[10][0]
X_mis = named_params[10][1].detach().to(DEFAULT_DEVICE)
print(X_mis.shape)

In [None]:
# Estimate the density of the weights from their empirical quantiles.
# The quantiles are spaced 1/200 apart in probability, so two quantiles that
# are 10 positions apart enclose a probability mass of 10/200; dividing that
# mass by the width of the interval gives the density estimate.
quantiles = np.quantile(X_mis.flatten().cpu(), np.arange(200)/200)
interval_mass = 10 / 200
interval_width = quantiles[20:-10] - quantiles[10:-20]
# The numerator used to be a constant 2, which scaled the curve by 40x and
# meant the plotted values were not an actual PDF.
approx_pdf = interval_mass / interval_width
# Each estimate is placed at the quantile in the middle of its interval.
approx_pdf_idxs = quantiles[15:-15]
plt.figure(figsize=(15, 3))
plt.title("Approximate distribution of Weights")
plt.xlabel("Weight Value")
plt.ylabel("PDF")
plt.plot(approx_pdf_idxs, approx_pdf, color="blue")
plt.show()

In [None]:
print(X_mis.shape)
# Only the singular values are plotted, so compute just those instead of a
# full SVD that would also build the singular vectors.
S = torch.linalg.svdvals(X_mis.float())

# Plot the singular values
plt.figure(figsize=(6, 3))
plt.plot(S.cpu(), marker='o', linestyle='-', color='b')
plt.title(f'Singular Values of Mistral layer {layer_name}')
plt.xlabel('Index', color="white")
plt.ylabel('Singular Value', color="white")
plt.yscale('log')
plt.grid(True)
plt.show()

In [None]:
# Generate a phantominator matrix
# X = torch.Tensor(shepp_logan(2048))
# plt.imshow(X, cmap="gray", interpolation="nearest")
# plt.show()
# _, S, _ = torch.linalg.svd(X.float(), full_matrices=False)

# # Plot the singular values
# plt.figure(figsize=(11, 3))
# plt.plot(S, marker='o', linestyle='-', color='b')
# plt.title('Singular Values of Phantom(1024)')
# plt.xlabel('Index', color="white")
# plt.ylabel('Singular Value', color="white")
# plt.yscale('log')
# plt.grid(True)
# plt.show()

In [None]:
def _collect_compression_errors(X, config):
    """Run the algorithm selected by `config` on `X` and return its logged errors.

    A copy of `config.algorithm_kwargs` is used (with `log_errors=True` added)
    so the caller's config is never modified. The errors are taken from the
    last element of whatever the selected algorithm returns.
    """
    kwargs = config.algorithm_kwargs.copy()
    kwargs["log_errors"] = True
    algorithm_type = config.algorithm_type

    if config.hadamard:
        # The Hadamard wrapper takes a config object rather than raw kwargs.
        result = hadamard_weight_compression(
            X=X, config=WeightCompressionConfig(
                algorithm_type=algorithm_type,
                algorithm_kwargs=kwargs
            )
        )
        return result[-1]

    if algorithm_type == AlgorithmType.ALTERNATING_MIXED_LPLR:
        kwargs["X"] = X
        _, _, errors = alternating_mixed_lplr(**kwargs)
    elif algorithm_type == AlgorithmType.DIRECT_SVD_LPLR:
        kwargs["X"] = X
        _, _, errors = direct_svd_mixed_lplr(**kwargs)
    elif algorithm_type == AlgorithmType.LOFTQ:
        kwargs["weight"] = X
        _, _, _, errors = loftq_init(**kwargs)
    else: ## Loftq-LPLR
        kwargs["weight"] = X
        _, _, _, errors = loftq_lplr_init(**kwargs)
    return errors


def test_iterative_weight_compression(
    X = None,
    weight_comp_configs: list[WeightCompressionConfig]=[
        WeightCompressionConfig(
            algorithm_type=AlgorithmType.ALTERNATING_MIXED_LPLR,
            algorithm_kwargs={
                "k": 64, "r1": 0, "r2": 0,
                "B1": 8, "B2": 8
            }
        )
    ],
    plot_title = "Frobenius Norm Errors over Iterations",
    seed=42
):
    """Plot per-iteration relative errors for several compression configs.

    Every config in `weight_comp_configs` is run on `X` with error logging
    enabled. The logged errors are divided by the Frobenius norm of `X`,
    printed, and drawn as one curve per config on a log-scale plot. A legend
    key mapping "Param Set i" to the config's kwargs is printed as well.

    Args:
        X: Matrix to compress. Required; the `None` default exists only to
            keep the keyword-style call signature.
        weight_comp_configs: Configurations to compare, one curve each.
        plot_title: Title of the resulting figure.
        seed: Seed passed to `torch.manual_seed` once, before the first run.
            NOTE(review): later configs therefore start from whatever RNG
            state the earlier runs left behind.

    Raises:
        ValueError: If `X` is None.
    """
    if X is None:
        raise ValueError("X must be provided: there is no matrix to compress.")

    plot_colors = ["b", "r", "g", "c", "m", "k"]
    plot_markers = ["o", "X", "*"]
    # Set random seed for reproducibility
    torch.manual_seed(seed)

    # The normaliser is the same for every config, so compute it only once.
    fro_norm_X = torch.norm(X, p="fro").item()

    plt.figure(figsize=(11, 3))
    for i, config in enumerate(weight_comp_configs):
        errors = _collect_compression_errors(X, config)
        relative_errors = np.array(errors) / fro_norm_X

        print(relative_errors)

        # Plot errors over iterations. Colours cycle first; the marker only
        # changes once all colours have been used.
        plt.plot(
            range(1, len(relative_errors) + 1),
            relative_errors,
            marker=plot_markers[(i // len(plot_colors)) % len(plot_markers)],
            linestyle="-",
            markersize=4,
            color=plot_colors[i % len(plot_colors)],
            label=f"Param Set {i+1}*")

    print("-"*80, "\n* Legend Key")
    for i, config in enumerate(weight_comp_configs):
        print(f"Param Set {i+1}: ", config.algorithm_kwargs)
        print("\tusing algorithm type ", config.algorithm_type)
        if config.hadamard:
            print("\twith randomized Hadamard transform")

    plt.title(plot_title)
    plt.xlabel("Iteration")
    plt.ylabel("Error")
    plt.yscale("log")
    plt.grid(True)
    plt.legend()
    plt.show()

In [None]:
## WITHOUT HADAMARD SKETCH
# LoftQ (num_bits=4, reduced_rank=64, "normal" quantizer factory) applied
# directly to the transposed example layer.
loftq_plain_config = WeightCompressionConfig(
    algorithm_type=AlgorithmType.LOFTQ,
    algorithm_kwargs=dict(
        num_bits=4,
        reduced_rank=64,
        num_iter=30,
        quantizer_factory=QuantizerFactory("normal"),
    ),
    hadamard=False,
)
# Alternative config kept for reference (LoftQ-LPLR, no Hadamard):
# WeightCompressionConfig(
#     algorithm_kwargs={
#         "num_bits": 4, "num_bits_factors": 8, "reduced_rank": 64,
#         "num_iter": 50, "num_iter_lplr": 30, "quantizer_factory": QuantizerFactory("normal")
#     },
#     algorithm_type=AlgorithmType.LOFTQ_LPLR,
#     hadamard=False
# )
test_iterative_weight_compression(
    X=X_mis.T,
    weight_comp_configs=[loftq_plain_config],
    plot_title="Frobenius Norm Errors over Iterations",
)

In [None]:
## WITH HADAMARD
# Same LoftQ setup as a (num_bits=4, reduced_rank=64) run, but with the
# Hadamard option of the config switched on and 50 iterations.
loftq_hadamard_config = WeightCompressionConfig(
    algorithm_type=AlgorithmType.LOFTQ,
    algorithm_kwargs=dict(
        num_bits=4,
        reduced_rank=64,
        num_iter=50,
        quantizer_factory=QuantizerFactory("normal"),
    ),
    hadamard=True,
)
# Alternative config kept for reference (LoftQ-LPLR, with Hadamard):
# WeightCompressionConfig(
#     algorithm_kwargs={
#         "num_bits": 4, "num_bits_factors": 8, "reduced_rank": 64,
#         "num_iter": 50, "num_iter_lplr": 30, "quantizer_factory": QuantizerFactory("normal")
#     },
#     algorithm_type=AlgorithmType.LOFTQ_LPLR,
#     hadamard=True
# )
test_iterative_weight_compression(
    X=X_mis.T,
    weight_comp_configs=[loftq_hadamard_config],
    plot_title="Frobenius Norm Errors over Iterations",
    seed=42,
)

In [None]:
# LoftQ with the "uniform_clipped" quantizer factory, once with and once
# without the Hadamard option, so the two curves can be compared directly.
uniform_clipped_configs = [
    WeightCompressionConfig(
        algorithm_kwargs={
            "num_bits": 4,
            "reduced_rank": 64,
            "num_iter": 50,
            "quantizer_factory": QuantizerFactory("uniform_clipped"),
        },
        algorithm_type=AlgorithmType.LOFTQ,
        hadamard=use_hadamard,
    )
    for use_hadamard in (True, False)
]
test_iterative_weight_compression(
    X=X_mis.T,
    weight_comp_configs=uniform_clipped_configs,
    plot_title="Frobenius Norm Errors over Iterations",
)

In [None]:
### Bit budget for the hyperparameter sweeps on the example layer X_mis:
### (rows * columns) entries at `average_bit_level` bits each.
average_bit_level = 3
n_rows, n_cols = X_mis.shape[:2]
budget = n_rows * n_cols * average_bit_level

In [None]:
# Sweep alpha and B for alternating mixed LPLR under the bit budget.
alternating_lplr_config = WeightCompressionConfig(
    algorithm_type=AlgorithmType.ALTERNATING_MIXED_LPLR,
    algorithm_kwargs={
        "quantizer_factory": NFQuantizerFactory("normal"),
        "iters": 30,
    },
)
mtxs, alpha, beta, B, error = lplr_sweep_alpha_and_B(
    X=X_mis.T,
    budget=budget,
    weight_comp_config=alternating_lplr_config,
    debug=True,
)

In [None]:
# Sweep alpha and B for LoftQ-LPLR under the same bit budget.
loftq_lplr_sweep_config = WeightCompressionConfig(
    algorithm_type=AlgorithmType.LOFTQ_LPLR,
    algorithm_kwargs={
        "quantizer_factory": QuantizerFactory("normal"),
        "num_bits": 2,
        "num_iter": 20,
        "num_iter_lplr": 20,
    },
)
mtxs, alpha, beta, B, error = lplr_sweep_alpha_and_B(
    X=X_mis.T,
    budget=budget,
    weight_comp_config=loftq_lplr_sweep_config,
    debug=True,
)

In [None]:
class BenchmarkerComparison():
    """Run several weight-compression benchmarkers side by side.

    Every benchmarker is used through four members only: `label`, `errors`,
    `reset_errors()` and `run(X, budget)`. After each matrix, the latest error
    of every benchmarker can be appended to a CSV file, and an earlier CSV can
    be continued so that already-benchmarked layers are skipped.
    """

    # Bits per entry used to size the budget when no budget is enforced.
    # NOTE(review): presumably just "large enough to never bind" -- confirm.
    _UNENFORCED_BIT_LEVEL = 256

    def __init__(
        self,
        # Quoted so the `None` default can be expressed without a typing import.
        benchmarkers: "list[WeightCompressionBenchmarker] | None" = None,
        enforce_budget = True,
        average_bit_level = 4,
        save_to_csv = False,
        continue_csv = False,
        save_file = None,
        reset_error_lists = True
    ):
        """
        Args:
            benchmarkers: Benchmarkers to compare. Defaults to a fresh empty
                list per instance.
            enforce_budget: If True, the budget is `n * d * average_bit_level`
                and is recorded in the CSV as "Bit Budget".
            average_bit_level: Bits per matrix entry used to compute the budget
                when `enforce_budget` is True.
            save_to_csv: Append one row per benchmarked matrix to `save_file`.
            continue_csv: Continue an existing `save_file` (implies saving);
                layers already listed in it are skipped. If the file is
                missing or empty, a new one is started instead.
            save_file: Path of the CSV file. Required when saving.
            reset_error_lists: Call `reset_errors()` on every benchmarker first.

        Raises:
            ValueError: If saving is requested but `save_file` is None.
        """
        # A literal `[]` default would be one list shared by all instances.
        if benchmarkers is None:
            benchmarkers = []

        if reset_error_lists:
            for benchmarker in benchmarkers:
                benchmarker.reset_errors()

        self.benchmarkers = benchmarkers
        self.save_to_csv = save_to_csv or continue_csv
        self.continue_csv = continue_csv
        self.save_file = save_file
        self.enforce_budget = enforce_budget
        self.average_bit_level = (
            average_bit_level if enforce_budget else self._UNENFORCED_BIT_LEVEL
        )
        # Always defined so lookups never depend on the constructor flags.
        self.prev_layer_names = []

        if not self.save_to_csv:
            return
        if save_file is None:
            raise ValueError(
                "save_file must be given when save_to_csv or continue_csv is set."
            )

        save_path = pathlib.Path(save_file)
        can_resume = (
            continue_csv and save_path.exists() and save_path.stat().st_size > 0
        )
        if can_resume:
            df = pd.read_csv(save_file)
            self.prev_layer_names = list(df["Layer Name"])
        else:
            # Fresh run, or nothing to continue from yet: start a new file.
            self._write_header()

    def _write_header(self):
        """Create (or truncate) the CSV file and write its header row."""
        first_headers = ["Layer Name", "n", "d"]
        if self.enforce_budget:
            first_headers.append("Bit Budget")
        # newline='' is what the csv module requires of files it writes to.
        with open(self.save_file, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(
                first_headers + [benchmarker.label for benchmarker in self.benchmarkers]
            )

    def write_latest_data(self, layer_name, n, d, budget=0):
        """Append one CSV row holding the latest error of every benchmarker."""
        with open(self.save_file, 'a', newline='') as csvfile:
            writer = csv.writer(csvfile)
            first_items = [layer_name, n, d]
            if self.enforce_budget:
                first_items.append(budget)
            writer.writerow(first_items + [benchmarker.errors[-1] for benchmarker in self.benchmarkers])

    def run_on_matrix(self, layer_name, X):
        """Run every benchmarker on `X`, unless the layer was already recorded.

        `X` must provide `size()` returning the pair `(n, d)`.
        """
        if self.continue_csv and layer_name in self.prev_layer_names:
            return
        print(f"Benchmarking {layer_name}")
        n, d = X.size()
        budget = n*d*self.average_bit_level

        for benchmarker in self.benchmarkers:
            benchmarker.run(X, budget)

        if self.save_to_csv:
            self.write_latest_data(layer_name, n, d, budget)

    def print_errors(self):
        """Print the full error list of every benchmarker."""
        for benchmarker in self.benchmarkers:
            print(f"{benchmarker.label}: {benchmarker.errors}")

    def plot_errors(self):
        """Plot each benchmarker's errors as one curve on a log-scale axis."""
        plot_colors = ["b", "r", "g", "c", "m", "k"]
        plot_markers = ["o", "X", "*"]

        plt.figure(figsize=(15, 4))
        for i, benchmarker in enumerate(self.benchmarkers):
            # Colours cycle first; the marker changes once they are exhausted.
            plt.plot(
                benchmarker.errors,
                marker=plot_markers[(i // len(plot_colors)) % len(plot_markers)],
                linestyle="-",
                markersize=4,
                color=plot_colors[i % len(plot_colors)],
                label=benchmarker.label
            )
        plt.title("Relative Frobenius Error")
        plt.xlabel("Layer")
        plt.ylabel("Error")
        plt.yscale("log")
        plt.grid(True)
        plt.legend()
        plt.show()


In [None]:
class BenchmarkerComparisonList(BenchmarkerComparison):
    """Benchmarker comparison over an explicit list of matrices."""

    def __init__(
        self,
        # Quoted so the `None` default can be expressed without a typing import.
        X_list: "list[torch.Tensor] | None" = None,
        **kwargs
    ):
        """
        Args:
            X_list: Matrices to benchmark, in order. Defaults to a fresh empty
                list per instance (a literal `[]` default would be shared by
                every instance).
            **kwargs: Forwarded unchanged to `BenchmarkerComparison`.
        """
        self.X_list = [] if X_list is None else X_list
        super().__init__(**kwargs)

    def run(self):
        """Benchmark every matrix, naming them "Matrix 0", "Matrix 1", ..."""
        for i, X in enumerate(self.X_list):
            self.run_on_matrix(f"Matrix {i}", X)

In [None]:
class BenchmarkerComparisonModel(BenchmarkerComparison):
    """Benchmarker comparison over the 2-D parameters of a model."""

    def __init__(
        self,
        model,
        device = "cpu",
        layer_limit:int = -1, # Limits the number of weight matrices used
                              # for benchmarking (mainly for debugging purposes).
                              # -1 means no limit.
        max_num_cols:int = -1, # Excludes weight matrices that are too large.
        **kwargs
    ):
        """
        Args:
            model: Object whose `named_parameters()` yields `(name, tensor)`.
            device: Device each matrix is moved to before benchmarking.
            layer_limit: Maximum number of matrices to benchmark per `run()`;
                any value <= 0 means no limit.
            max_num_cols: Matrices whose larger dimension exceeds this are
                skipped; any value <= 0 means no limit. (Despite the name, the
                check is made on the larger dimension, see `run`.)
            **kwargs: Forwarded unchanged to `BenchmarkerComparison`.
        """
        self.model = model
        self.device = device
        self.layer_limit = layer_limit if layer_limit > 0 else float('inf')
        self.max_num_cols = max_num_cols if max_num_cols > 0 else float('inf')
        super().__init__(**kwargs)

    def run(self):
        """Benchmark the model's 2-D parameters, honouring the limits.

        Returns:
            The number of matrices actually benchmarked by this call. This is
            returned both when the limit stops the run and when the parameters
            are exhausted.
        """
        layers_benchmarked = 0
        for layer_name, X in self.model.named_parameters():
            # Checked here as well so that already-recorded layers do not
            # count against `layer_limit`.
            if self.continue_csv and layer_name in self.prev_layer_names:
                continue
            if X.dim() != 2:
                continue
            # Orient the matrix so that the first dimension is the larger one.
            if X.shape[0] < X.shape[1]:
                X = X.T
            if X.shape[0] > self.max_num_cols:
                logger.info("Layer larger than maximum size specified, skipping.")
                continue

            if layers_benchmarked >= self.layer_limit:
                logger.info("Reached layer limit, exiting.")
                break

            # Detach first: benchmarking needs the values only, not gradients.
            self.run_on_matrix(layer_name, X.detach().float().to(self.device))
            layers_benchmarked += 1

        return layers_benchmarked
            

In [None]:
# Active benchmarkers: LoftQ and LoftQ-LPLR, both with num_bits=8, rank 64 and
# the "normal" quantizer factory.
loftq_8b_benchmarker = LoftqBenchmarker(
    WeightCompressionConfig(
        algorithm_type=AlgorithmType.LOFTQ,
        algorithm_kwargs={
            "num_iter": 20,
            "reduced_rank": 64,
            "num_bits": 8,
            "quantizer_factory": QuantizerFactory("normal"),
        },
    ),
    label="LoftQ (8b)",
)
loftq_lplr_8b_benchmarker = LplrBenchmarker(
    WeightCompressionConfig(
        algorithm_type=AlgorithmType.LOFTQ_LPLR,
        algorithm_kwargs={
            "quantizer_factory": QuantizerFactory("normal"),
            "num_bits": 8,
            "num_bits_factors": 8,
            "reduced_rank": 64,
            "num_iter": 20,
            "num_iter_lplr": 30,
        },
    ),
    label="Loftq-LPLR (8b)",
    run_hyper_parameter_sweep=False,
)

benchmarkers = [loftq_8b_benchmarker, loftq_lplr_8b_benchmarker]

# Disabled alternatives, kept for reference.
# NOTE(review): these use a `quant_type` kwarg where the active entries use
# `quantizer_factory` -- check the current API before re-enabling them.
#
# LoftqBenchmarker(
#     WeightCompressionConfig(
#         algorithm_kwargs={
#             "num_iter": 20,
#             "reduced_rank": 64,
#             "num_bits": 2,
#             "quant_type": QuantType.UNIFORM
#         },
#         algorithm_type=AlgorithmType.LOFTQ
#     ),
#     fixed_rank=True,
#     label="LoftQ (2b)"
# ),
# LoftqBenchmarker(
#     WeightCompressionConfig(
#         algorithm_kwargs={
#             "num_iter": 20,
#             "reduced_rank": 64,
#             "num_bits": 4,
#             "quant_type": QuantType.UNIFORM
#         },
#         algorithm_type=AlgorithmType.LOFTQ
#     ),
#     fixed_rank=True,
#     label="LoftQ (4b)"
# ),
# LplrBenchmarker(
#     WeightCompressionConfig(
#         algorithm_kwargs={
#             "quant_type": QuantType.UNIFORM,
#             "num_bits": 2,
#             "num_bits_factors": 8,
#             "reduced_rank": 64,
#             "num_iter": 20,
#             "num_iter_lplr": 30
#         },
#         algorithm_type=AlgorithmType.LOFTQ_LPLR
#     ),
#     label="Loftq-LPLR (2b)",
#     run_hyper_parameter_sweep=False
# ),
# LplrBenchmarker(
#     WeightCompressionConfig(
#         algorithm_kwargs={
#             "quant_type": QuantType.UNIFORM,
#             "num_bits": 4,
#             "num_bits_factors": 8,
#             "reduced_rank": 64,
#             "num_iter": 20,
#             "num_iter_lplr": 30
#         },
#         algorithm_type=AlgorithmType.LOFTQ_LPLR
#     ),
#     label="Loftq-LPLR (4b)",
#     run_hyper_parameter_sweep=False
# ),

In [None]:
## For debugging, we set limits on how many layers we use
## and their sizes
# Results are appended to this CSV; layers already listed in it are skipped.
results_csv = "/home/nsagan/experiments/results/Loftq_LPLR_comparsion_2.csv"
comparison_settings = dict(
    model=mistral,
    device=DEFAULT_DEVICE,
    layer_limit=5,
    max_num_cols=15000,
    benchmarkers=benchmarkers,
    enforce_budget=False,
    continue_csv=True,
    save_file=results_csv,
)
comparison_object = BenchmarkerComparisonModel(**comparison_settings)

In [None]:
# Benchmark the model's 2-D weight matrices with every configured benchmarker.
comparison_object.run()