## Notebook to try out the CALDERA decomposition on a random matrix

In [None]:
import torch

import sys
import os

In [None]:
# The notebook is expected to run from a subdirectory of the repository, so the
# parent of the current working directory is added to the import path; this is
# what lets the `src.caldera...` imports below resolve. The append is skipped if
# the directory is already present, so re-running the cell adds no duplicates.
src_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))

if src_dir in sys.path:
    pass  # already importable; nothing to do
else:
    sys.path.append(src_dir)

In [None]:
from src.caldera.decomposition.dataclasses import CalderaParams
from src.caldera.utils.quantization import QuantizerFactory
from src.caldera.decomposition.alg import caldera

In [None]:
# Quantizer factories: one for the Q component and one for the low-rank
# factors (L, R). Both are configured with method="uniform" and block_size=64.
quant_factory_Q = QuantizerFactory(method="uniform", block_size=64)
quant_factory_LR = QuantizerFactory(method="uniform", block_size=64)
# Backward-compatible alias for the previous, misspelled name so that any
# existing references to `quant_factor_LR` keep working.
quant_factor_LR = quant_factory_LR

In [None]:
quant_params = CalderaParams(
    # --- which components to compute ---
    compute_quantized_component=True,
    compute_low_rank_factors=True,
    # --- bit budgets per component ---
    Q_bits=4,
    L_bits=4,
    R_bits=4,
    # --- low-rank settings ---
    rank=16,
    rand_svd=False,
    # --- iteration counts ---
    iters=20,
    lplr_iters=5,
    # --- activation-aware toggles (only the LR side is enabled here) ---
    activation_aware_Q=False,
    activation_aware_LR=True,
    # --- lattice quantization toggles ---
    lattice_quant_Q=False,
    lattice_quant_LR=False,
    # --- update schedule and quantizer factories ---
    update_order=["Q", "LR"],
    quant_factory_Q=quant_factory_Q,
    quant_factory_LR=quant_factor_LR,
)

In [None]:
# Fix the RNG so the random inputs are reproducible across runs. The draw order
# matters: W is sampled first, then X.
torch.manual_seed(42)

# W: 1024x1024 matrix with entries drawn uniformly from [0, 1).
W = torch.rand((1024, 1024))
# X: 1024x2048 matrix with standard-normal entries.
X = torch.randn((1024, 2048))
# H = X X^T, a 1024x1024 symmetric positive semi-definite matrix.
H = X @ X.T

In [None]:
# Run the CALDERA decomposition of W on the CPU, passing H and the
# configuration built above. A tqdm progress bar is enabled, and scale_W=True
# is forwarded unchanged.
caldera_decomposition = caldera(
    # inputs
    W=W,
    H=H,
    # configuration
    quant_params=quant_params,
    scale_W=True,
    # execution
    device="cpu",
    use_tqdm=True,
)

In [None]:
# Report the shape of each returned component, in the order Q, L, R.
for component in ("Q", "L", "R"):
    shape = getattr(caldera_decomposition, component).shape
    print(f"caldera_decomposition.{component}.shape: {shape}")