## LPLR-Q: Adding an Activation Between L and R

### Imports

In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are re-imported
# before each cell runs and the kernel does not need a restart after changing project code.
%load_ext autoreload
%autoreload 2

In [None]:
import glog
# Set the glog logging threshold to "WARN" to keep the notebook output short.
glog.setLevel("WARN")

In [None]:
import gc
from enum import Enum

# NOTE: the wildcard imports are deliberately kept ahead of the explicit imports below, so the
# explicit names (torch, nn, tqdm, ...) win if a wildcard exports the same identifier.
from lplr_llm.activation_aware.weight_compression import *
from lplr_llm.activation_aware.layer_quantization import *
from lplr_llm.activation_aware.layer_input_data import get_sublayer_input
from torch import nn
import torch
from lib.utils import get_hadK
from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda
from tqdm import tqdm
from lib.algo.quip import RHT_W, RHT_H

### Experiment Parameters

In [None]:
# Model identifier; passed as `base_model` to ModelParameters and to get_sublayer_input.
BASE_MODEL = "meta-llama/Llama-2-7b-hf"
# HESSIAN_SAVE_PATH = "/media/hdd1/lplr-q-hessians/llama-2-7b"
# Passed as `hessian_save_path` to the weight compressor (which is created with compute_hessians=False).
HESSIAN_SAVE_PATH = "../../data/hessians/llama-2-7b/"
# Device used for quantization (`quant_device`) and for every tensor moved with .to(DEVICE) below.
DEVICE = "cuda:0"
# Rank of the low-rank factors L and R; also the truncation rank of the regression SVD below.
RANK = 128
# Passed as `iters` to ActivationAwareQuantParams (Option 1 hard-codes iters=3 instead).
QLR_ITERS = 30
# Passed as `lplr_iters` to ActivationAwareQuantParams.
LPLR_ITERS = 10

# Index of the layer to compress and the sublayer within it.
LAYER = 5
SUBLAYER = TransformerSubLayers.QUERY

### Get the Initial Condition for Q, L, and R using LPLR-Q

The three weight compressor options below compute $Q + LR$, where $Q$ has 2 bits of precision and $L$, $R$ are half precision. Below, there are three options for quantization of $Q$ (LDLQ, lattice quantization, and RTN uniform quantization).

**Option 1**: LPLR-LDLQ with Hessian Downdate (should be a very good approximation)

In [None]:
# weight_comp = ActivationAwareWeightCompressor(
#     model_params=ModelParameters(
#         base_model=BASE_MODEL
#     ),
#     data_params=DataParameters(),
#     hessian_save_path=HESSIAN_SAVE_PATH,
#     quant_params=ActivationAwareQuantParams(
#         Q_bits=2,
#         L_bits=16, R_bits=16,
#         lattice_quant_LR=False,
#         rank=RANK,
#         activation_aware_Q=True,
#         activation_aware_LR=True,
#         hadamard_transform=True,
#         iters=3,
#         lplr_iters=LPLR_ITERS,
#         rand_svd=True,
#         Q_hessian_downdate=True,
#         update_order=["Q", "LR"]
#     ),
#     compute_hessians=False,
#     quant_device=DEVICE,
# )

**Option 2**: LPLR-Lattice Quant (slightly worse approximation)

In [None]:
# weight_comp = ActivationAwareWeightCompressor(
#     model_params=ModelParameters(
#         base_model=BASE_MODEL
#     ),
#     data_params=DataParameters(),
#     hessian_save_path=HESSIAN_SAVE_PATH,
#     quant_params=ActivationAwareQuantParams(
#         Q_bits=2,
#         L_bits=16, R_bits=16,
#         lattice_quant_LR=False,
#         rank=RANK,
#         activation_aware_Q=False,
#         activation_aware_LR=True,
#         hadamard_transform=True,
#         iters=QLR_ITERS,
#         lplr_iters=LPLR_ITERS,
#         rand_svd=True,
#         update_order=["Q", "LR"]
#     ),
#     compute_hessians=False,
#     quant_device=DEVICE,
# )

**Option 3**: LPLR-Uniform ($Q$ should be a pretty bad approximation of $W - LR$)

In [None]:
# Option 3 compressor: Q is quantized to 2 bits with the "uniform" quantizer factory
# (lattice quantization and activation-aware Q both switched off), while L and R keep 16 bits.
# NOTE(review): the exact meaning of each flag is defined by ActivationAwareQuantParams in
# lplr_llm; the summary above only restates the values passed here.
weight_comp = ActivationAwareWeightCompressor(
    model_params=ModelParameters(
        base_model=BASE_MODEL
    ),
    data_params=DataParameters(),
    hessian_save_path=HESSIAN_SAVE_PATH,
    quant_params=ActivationAwareQuantParams(
        Q_bits=2,
        L_bits=16, R_bits=16,
        lattice_quant_LR=False,
        rank=RANK,
        activation_aware_Q=False,
        lattice_quant_Q=False,
        quant_factory_Q=QuantizerFactory("uniform"),
        activation_aware_LR=True,
        hadamard_transform=True,
        iters=QLR_ITERS,
        lplr_iters=LPLR_ITERS,
        rand_svd=True,
        Q_hessian_downdate=True,
        update_order=["Q", "LR"]
    ),
    # Hessians are not recomputed here -- presumably they are read from HESSIAN_SAVE_PATH; confirm.
    compute_hessians=False,
    quant_device=DEVICE,
)

In [None]:
# Get the quantizer for layer LAYER and compress the selected sublayer.
# The results are read back later through layer_quant.best_sublayer_info[SUBLAYER].
layer_quant = weight_comp.get_layer_quantizer(LAYER)
layer_quant.compress_sublayer(SUBLAYER)

In [None]:
# min_error(SUBLAYER) is multiplied by 100 and rounded to 8 decimals for display.
# NOTE(review): this assumes min_error returns a relative error as a fraction -- confirm.
print(f"The Frobenius norm error ||(W - LR - Q)X||_F^2 is {round(layer_quant.min_error(SUBLAYER) * 100, 8)} percent")

### Sample Calibration Data and Compute the Layer Input

See `src/lplr_llm/activation_aware/layer_input_data` for more details. It is very similar to the code that QuIP# uses for Hessian computation.
The function `get_sublayer_input` passes the calibration datapoints through the original model and captures the input to the specific layer and sublayer that we are trying to compress. 

_Note_: this might not be the data used to compute the Hessian matrix (as the Hessian computation does not currently include saving the input embeddings)

In [None]:
# Batch size used for sampling calibration data, for the target computation, and for the DataLoader.
BATCH_SIZE = 2

In [None]:
# Capture the input of the chosen sublayer for the calibration samples.
# The first return value is discarded; attention_mask and position_ids are returned as well
# but are not used by the cells below.
# NOTE(review): later cells index sublayer_input as (num_samples, seq_len, hidden_dim) -- confirm.
_, sublayer_input, attention_mask, position_ids = \
    get_sublayer_input(
        LAYER, layer_quant.sublayer_info[SUBLAYER].out_key,
        BASE_MODEL, data_params=DataParameters(batch_size=BATCH_SIZE),
        device=DEVICE
    )

### Define the Approximation as a Neural Network

This neural network layer does the following:
1. Performs an inverse Hadamard transform on the input, $X$. The Hadamard-transformed output will be denoted $X_H$. This is necessary because we perform the decomposition $Q + LR$ after incoherence processing: $H_2 W H_1^\top \approx Q + LR$, where $H_1$ and $H_2$ are randomized Hadamard transform matrices, $Q$ is quantized, and $L, R$ are low-rank.

2. Applies $Q$: $\text{output} \gets X_H Q^\top$.
3. Applies the low-rank factors as follows:
    - If `network_type` is `REGULAR_MLP`, $\text{output} \gets \text{output} + \sigma(X_H R^\top) L^\top$, where $\sigma$ is an activation, by default ReLU.
    - If `network_type` is `SPLIT_LR`, we split $L = \begin{pmatrix} \\ L_1 & L_2 \\ \\ \end{pmatrix}$ and $R = \begin{pmatrix} && R_1 && \\ && R_2 && \end{pmatrix}$, where $L_1$ and $L_2$ are of equal size and the same holds for $R_1$ and $R_2$.
    <br/> The output is updated as follows:  $\text{output} \gets \text{output} + \sigma(X_H R_1^\top) L_1^\top + X_H R_2^\top L_2^\top$.
    - If `network_type` is `WITH_RESIDUAL`, we add skip connections without splitting the $L$ and $R$ matrices: $\text{output} \gets \text{output} + \frac{1}{2}\sigma(X_H R^\top) L^\top + \frac{1}{2}X_H R^\top L^\top$
    - If you pass in `batchnorm=True` when instantiating the network, then a batch norm (via the default pytorch implementation) is applied after the activation: $\sigma(X_H R^\top) L^\top$ becomes $\text{BN}(\sigma(X_H R^\top)) L^\top$, where $\text{BN}$ is the batch norm function.

4. Performs a Hadamard transform on $\text{output}$ and returns the final result.

#### Explanation of the incoherence processing (steps 1 and 4 above)
For the purposes of illustration, assume that $W_H := H_2 W H_1^\top = Q + LR$ (the decomposition is exact), and that the network just computes $X_H (Q + LR)^\top = X_H W_H^\top$.

The network first computes $X_H = X H_1^\top$ in step 1.

Then, it computes $\text{output} = X_H W_H^\top = X H_1^\top H_1 W^\top H_2^\top = X W^\top H_2^\top$.

Step 4 computes $\text{output} H_2 = X W^\top H_2^\top H_2 = XW^\top$, which is the desired neural network output.

In [None]:
class LRNetworkType(Enum):
    """How the low-rank factors L and R are combined with the activation.

    REGULAR_MLP (0): the whole low-rank branch goes through the activation.
    SPLIT_LR (1): half of L/R goes through the activation, the other half is applied linearly.
    WITH_RESIDUAL (2): the activated branch and a linear skip branch are both applied.
    """

    # Members are numbered consecutively from 0, in declaration order.
    REGULAR_MLP, SPLIT_LR, WITH_RESIDUAL = range(3)

In [None]:
class QPlusTwoLayerNN(nn.Module):
    """Frozen matrix Q plus a trainable low-rank branch with an activation between R and L.

    On an input ``x`` whose last dimension is ``len(SU)``, the layer computes (in the
    Hadamard-transformed domain) ``x_H @ Q.T + low_rank(x_H)`` and transforms the result back,
    yielding a tensor whose last dimension is ``len(SV)``. ``low_rank`` depends on
    ``network_type``:

    * ``REGULAR_MLP``: ``norm(act(x_H @ R.T)) @ L.T``
    * ``SPLIT_LR``: the first half of the rank goes through ``act`` (and the norm), the second
      half (``R_feedthrough`` / ``L_feedthrough``) is applied linearly.
    * ``WITH_RESIDUAL``: ``L`` is halved at construction and both the activated branch and the
      linear branch ``x_H @ R.T @ L.T`` are added.

    The batch norm is applied only when ``do_batchnorm`` is true.

    Args:
        Q: frozen matrix applied as ``x_H @ Q.T``; stored as a float32 parameter without gradient.
        L, R: low-rank factors, applied as ``(x_H @ R.T) @ L.T``. They are copied, so training
            this module never modifies the tensors passed in.
        SU, SV: sign vectors of the randomized Hadamard transforms on the input and output side.
        global_scale: scalar multiplied onto the output after the output-side transform.
        scaleWH: optional tensor; if given, the input is divided by it and the output multiplied
            by it. Pass ``None`` to disable.
        act: activation module applied to ``x_H @ R.T``.
        batchnorm: whether to apply ``BatchNorm1d`` (no affine parameters) after the activation.
        network_type: one of ``LRNetworkType``.
    """

    def __init__(
        self,
        Q, L, R,
        SU, SV,
        global_scale,
        scaleWH,
        act=nn.ReLU(),
        batchnorm=True,
        network_type=LRNetworkType.REGULAR_MLP
    ):
        super().__init__()

        # Scaling parameters -- from QuIP#
        self.global_scale = global_scale
        self.scaleWH = scaleWH
        if self.scaleWH is not None:
            self.scaleWH = nn.Parameter(self.scaleWH.float(), requires_grad=False)

        # Quantized matrix. For the purpose of this experiment, Q is floating point, i.e.,
        # only simulated quantization was performed.
        self.Q = nn.Parameter(Q.float(), requires_grad=False)

        # L and R are trained, so the module must own its copies: `.float()` returns the very
        # same tensor when the input is already float32, and wrapping that in nn.Parameter would
        # let optimizer steps overwrite the caller's tensors in place.
        L = L.detach().to(dtype=torch.float32, copy=True)
        R = R.detach().to(dtype=torch.float32, copy=True)

        if network_type is LRNetworkType.WITH_RESIDUAL:
            # Activated branch and skip branch each contribute half of LR
            L = L / 2
        if network_type is LRNetworkType.SPLIT_LR:
            # Split L and R, as described in the markdown above
            l_split = L.shape[1] // 2
            r_split = R.shape[0] // 2
            self.L = nn.Parameter(L[:, :l_split].clone(), requires_grad=True)
            self.R = nn.Parameter(R[:r_split, :].clone(), requires_grad=True)

            self.L_feedthrough = nn.Parameter(L[:, l_split:].clone(), requires_grad=True)
            self.R_feedthrough = nn.Parameter(R[r_split:, :].clone(), requires_grad=True)
            self.norm = nn.BatchNorm1d(l_split, affine=False).to(self.L.device)
        else:
            self.L = nn.Parameter(L, requires_grad=True)
            self.R = nn.Parameter(R, requires_grad=True)
            self.norm = nn.BatchNorm1d(L.shape[1], affine=False).to(self.L.device)

        self.network_type = network_type
        self.do_batchnorm = batchnorm

        self.act = act

        # Diagonals of the randomized Hadamard transform
        self.SU = nn.Parameter(SU.float(), requires_grad=False)
        self.SV = nn.Parameter(SV.float(), requires_grad=False)

        # Hadamard matrices and sizes of the Hadamard transform
        had_left, K_left = get_hadK(len(SU))
        had_right, K_right = get_hadK(len(SV))
        self.had_left = nn.Parameter(had_left, requires_grad=False)
        self.had_right = nn.Parameter(had_right, requires_grad=False)
        self.K_left = K_left
        self.K_right = K_right

    def set_activation(self, act):
        """Replace the activation applied between R and L."""
        self.act = act

    def forward(self, x):
        """Apply the layer to ``x`` (last dim ``len(SU)``); returns a tensor with last dim ``len(SV)``."""
        shape = x.shape
        # Flatten all leading dimensions; reshape also accepts non-contiguous inputs.
        x = x.reshape(-1, len(self.SU))

        # Randomized Hadamard transform (see description in markdown text above)
        x = x * self.SU
        if self.scaleWH is not None:
            x = x / self.scaleWH
        x = matmul_hadUt_cuda(x, self.had_left, self.K_left)

        # Apply Q
        output = x @ self.Q.T

        # Different options for applying L and R, as described in the markdown above
        hidden = self.act(x @ self.R.T)
        if self.do_batchnorm:
            hidden = self.norm(hidden)
        output = output + hidden @ self.L.T
        if self.network_type is LRNetworkType.SPLIT_LR:
            output = output + x @ self.R_feedthrough.T @ self.L_feedthrough.T
        elif self.network_type is LRNetworkType.WITH_RESIDUAL:
            output = output + x @ self.R.T @ self.L.T  # skip over the activation

        # Another randomized Hadamard transform
        output = matmul_hadU_cuda(output, self.had_right, self.K_right) * self.global_scale
        output = output * self.SV
        if self.scaleWH is not None:
            output = output * self.scaleWH

        return output.reshape(*shape[:-1], len(self.SV))


In [None]:
class RelativeFroLoss(nn.Module):
    """Relative Frobenius error ``||target - output||_F / ||target||_F``.

    The norm is taken over the last two dimensions; for batched inputs the per-matrix
    ratios are averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        # torch.linalg.matrix_norm uses the Frobenius norm unless told otherwise
        error_norm = torch.linalg.matrix_norm(target - output)
        reference_norm = torch.linalg.matrix_norm(target)
        return torch.mean(error_norm / reference_norm)

### Get Target Layer Outputs
Here, we compute $X W^\top$, for each $X$ in the calibration set. We will be computing the Frobenius norm error of the network output with respect to these target outputs.

In [None]:
# Best compression result recorded for this sublayer; the cells below read W, Q, SU, SV,
# global_scale and scaleWH from it and overwrite its L and R.
sublayer_info = layer_quant.best_sublayer_info[SUBLAYER]

In [None]:
# W^T in half precision on the device. It does not change between batches, so it is built
# once here instead of being transferred and cast again for every batch.
W_T_half = sublayer_info.W.T.to(DEVICE).half()

# Target outputs X W^T, kept on the CPU in the same dtype as the captured inputs.
targets = torch.zeros(
    sublayer_input.shape[0], sublayer_input.shape[1], len(sublayer_info.SV),
    dtype=sublayer_input.dtype
)
for i in tqdm(range(0, sublayer_input.shape[0], BATCH_SIZE)):
    targets[i:i+BATCH_SIZE] = (sublayer_input[i:i+BATCH_SIZE].to(DEVICE) @ W_T_half).cpu()

    gc.collect()
    torch.cuda.empty_cache()

# Release the device copy of the weight
del W_T_half

In [None]:
# Pair each captured input with its target output X W^T.
train = torch.utils.data.TensorDataset(sublayer_input, targets)
# shuffle=True with no seed set in this notebook: the batch order differs between runs.
train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)

### First, Perform Rank-Constrained Regression on $L$, $R$ with the Sampled Data
The data is not necessarily the data used to compute the Hessian, and we want to make sure we start with an optimal $L$, $R$ so that we have a fair comparison.

In [None]:
# First, compute the Hessian
H = torch.zeros(sublayer_input.shape[2], sublayer_input.shape[2]).to(DEVICE)
for i in tqdm(range(sublayer_input.shape[0])):
    x = sublayer_input[i].to(DEVICE).to(torch.float64)
    H += x.T @ x

    gc.collect()
    torch.cuda.empty_cache()
H /= H.abs().max()

In [None]:
# Normalization and Regularization (taken from QuIP#)
H.div_(torch.diag(H).mean())
H = regularize_H(H, H.shape[0], layer_quant.quant_params.quip_args.sigma_reg)

In [None]:
# Incoherence Process
Hr = RHT_H(H, sublayer_info.SU)
Wr = RHT_W(sublayer_info.W.to(DEVICE) / sublayer_info.global_scale, sublayer_info.SU, sublayer_info.SV)

# Compute the symmetric square root: ||(W - LR - Q)X.T||_F^2 = ||(W - LR - Q) H^{1/2}||_F^2
eigH = torch.linalg.eigh(Hr)
H_sqrt = (eigH.eigenvectors @
            torch.diag(torch.sqrt(eigH.eigenvalues)) @
            eigH.eigenvectors.T)

# Rank-constrained regression
residual = Wr - sublayer_info.Q
Y = residual @ H_sqrt @ eigH.eigenvectors
U, Sigma, Vh = torch.linalg.svd(Y, full_matrices=False)

rank_const_regression_L = U[:, :RANK]
rank_const_regression_R = torch.diag(Sigma[:RANK]) @ \
    Vh[:RANK, :] @ \
    torch.diag(1 / eigH.eigenvalues.sqrt()) @ eigH.eigenvectors.T

sublayer_info.L = rank_const_regression_L
sublayer_info.R = rank_const_regression_R

### Initialize the NN

In [None]:
# Architecture of the low-rank branch (see LRNetworkType) and whether a batch norm follows
# the activation; both are forwarded to QPlusTwoLayerNN when the layer is instantiated.
network_type = LRNetworkType.REGULAR_MLP
batchnorm = True

In [None]:
# This is necessary, otherwise rank_const_regression_L will be modified by updates to layer.L, e.g.
sublayer_info.L = torch.Tensor(rank_const_regression_L.tolist()).to(DEVICE)
sublayer_info.R = torch.Tensor(rank_const_regression_R.tolist()).to(DEVICE)

# Instantiate layer
layer = QPlusTwoLayerNN(
    Q=sublayer_info.Q, L=sublayer_info.L, R=sublayer_info.R,
    SU=sublayer_info.SU, SV=sublayer_info.SV,
    global_scale=sublayer_info.global_scale,
    scaleWH=sublayer_info.scaleWH,
    network_type=network_type,
    batchnorm=batchnorm
)
loss_fn = RelativeFroLoss()

For comparison, compute the activation-aware error for the linear model.

In [None]:
# Temporarily turn the network into the plain linear model Q + LR: identity activation and
# no batch norm. The previous settings are remembered and restored afterwards (also if the
# loop is interrupted), so the configured `batchnorm` choice is not silently reset to True.
old_act = layer.act
old_do_batchnorm = layer.do_batchnorm
layer.set_activation(nn.Identity())
layer.do_batchnorm = False

total_loss = 0
n = 0
try:
    with torch.no_grad():
        for x, target in tqdm(train_loader):
            output = layer(x.to(DEVICE).float())
            total_loss += loss_fn(output, target.to(DEVICE).float())
            n += 1
finally:
    layer.set_activation(old_act)
    layer.do_batchnorm = old_do_batchnorm

print(f"The error for the linear model is {(total_loss / n).item()}")

### Experiment 1: PReLU with initial slope=1
This should be strictly better than the linear model, but by how much?

In [None]:
# PReLU with initial negative slope 1, i.e. the identity function at initialization.
# Its slope is a learnable parameter; because the activation is assigned as a submodule of
# `layer`, the optimizer built from layer.named_parameters() in the next cell trains it too.
layer.set_activation(nn.PReLU(init=1).to(DEVICE))

In [None]:
# Optimize only the parameters that require gradients; Q, SU, SV, scaleWH and the Hadamard
# matrices are created with requires_grad=False and are therefore excluded.
optimizer = torch.optim.AdamW(
    params=[param for _, param in layer.named_parameters() if param.requires_grad],
    lr=1e-4,
    amsgrad=True
)
# Gradient scaler used by the training loop below. The loop feeds float32 tensors and does not
# open an autocast context.
scaler = torch.cuda.amp.GradScaler(enabled=True)

In [None]:
# Print the mean training loss every PRINT_FREQ epochs
PRINT_FREQ=2

# Gradient tracking is enabled globally (a later evaluation cell disables it again)
torch.set_grad_enabled(True)
for epoch in range(50):
    # Running count of batches and sum of per-batch losses for this epoch
    n = 0
    total_loss = 0
    for x, target in tqdm(train_loader):
        optimizer.zero_grad()
        output = layer(x.to(DEVICE).float())
        loss = loss_fn(output, target.to(DEVICE).float())
        total_loss += loss.item()
        n += 1

        # Backpropagate the scaled loss, step the optimizer through the scaler, then update the scale
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    if epoch % PRINT_FREQ == 0:
        print(total_loss / n)

In [None]:
torch.set_grad_enabled(False)

# Evaluate in eval mode: in training mode BatchNorm1d normalizes with the statistics of the
# current batch and keeps updating its running estimates, so the reported error would depend
# on the batch composition and the evaluation pass would modify the layer being measured.
was_training = layer.training
layer.eval()

total_loss = 0
n = 0
try:
    with torch.no_grad():
        for x, target in tqdm(train_loader):
            output = layer(x.to(DEVICE).float())
            total_loss += loss_fn(output, target.to(DEVICE).float())
            n += 1
finally:
    # Put the layer back into the mode it was in before the evaluation
    layer.train(was_training)

# Mean relative Frobenius error over all batches (displayed as the cell output)
(total_loss / n).item()

In [None]:
# Learned PReLU slope; it was initialized to 1 (identity), so any deviation from 1 shows how
# far training moved the activation away from the linear model.
layer.act.weight

### Experiment 2: ReLU

In [None]:
# This is necessary, otherwise original_L will be modified by updates to layer.L, e.g.
sublayer_info.L = torch.Tensor(rank_const_regression_L.tolist()).to(DEVICE)
sublayer_info.R = torch.Tensor(rank_const_regression_R.tolist()).to(DEVICE)

# Instantiate layer
layer = QPlusTwoLayerNN(
    Q=sublayer_info.Q, L=sublayer_info.L, R=sublayer_info.R,
    SU=sublayer_info.SU, SV=sublayer_info.SV,
    global_scale=sublayer_info.global_scale,
    scaleWH=sublayer_info.scaleWH
)
loss_fn = RelativeFroLoss()

In [None]:
# New optimizer for the freshly created layer; only parameters with requires_grad=True are trained.
# Unlike Experiment 1, amsgrad is left at its default here.
optimizer = torch.optim.AdamW(
    params=[param for _, param in layer.named_parameters() if param.requires_grad],
    lr=1e-4,
)
# Gradient scaler used by the training loop below (float32 inputs, no autocast context).
scaler = torch.cuda.amp.GradScaler(enabled=True)

In [None]:
# Print the mean training loss every PRINT_FREQ epochs
PRINT_FREQ=2

# Re-enable gradient tracking, which the Experiment 1 evaluation cell switched off globally
torch.set_grad_enabled(True)
# Same loop as in Experiment 1, but run for 20 epochs
for epoch in range(20):
    # Running count of batches and sum of per-batch losses for this epoch
    n = 0
    total_loss = 0
    for x, target in tqdm(train_loader):
        optimizer.zero_grad()
        output = layer(x.to(DEVICE).float())
        loss = loss_fn(output, target.to(DEVICE).float())
        total_loss += loss.item()
        n += 1

        # Backpropagate the scaled loss, step the optimizer through the scaler, then update the scale
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    if epoch % PRINT_FREQ == 0:
        print(total_loss / n)