This notebook loads a pre-trained HyperLightspeedBench (HLB) neural network for image classification on the CIFAR-10 dataset and experiments with low-precision low-rank (LPLR) approximation of the network's weight matrices.

In [None]:
import sys
import os

# Switch the working directory to the parent folder so that the relative
# paths used later in this notebook resolve from there.
# NOTE(review): presumably the notebook lives one level below the repository
# root -- confirm before running from a different location.
os.chdir("..")
# Make the (now current) directory's "src" folder importable.
sys.path.extend([os.path.abspath("src")])

from math import ceil, floor
from typing import Callable


In [None]:
import torch
import torch.nn as nn
import numpy as np
from loguru import logger

In [None]:
from src.hlb.speedyresnet import (
    SpeedyResNet,
    Conv,
    ConvGroup,
    TemperatureScaler,
    Linear,
    BatchNorm,
    FastGlobalMaxPooling,
)
from src.hlb.config import hyp
from src.hlb.utils import get_batches

from src.lplr.quantizers import quantize
from src.lplr.compressors import lplr

In [None]:
def evaluate_models(model: nn.Module, benchmark_model):
    """Evaluate ``model`` and ``benchmark_model`` on the "eval" split.

    Both networks are moved to the evaluation device, cast to float32 and put
    in eval mode (these changes are made in place, so they persist on the
    objects passed in). Every batch yielded by ``get_batches`` is fed to both
    networks; the per-batch mean loss and accuracy are collected and then
    averaged over batches (an unweighted mean of per-batch means).

    Args:
        model: The network under test (e.g. a compressed model).
        benchmark_model: The reference network evaluated on the same batches
            for comparison; its metrics are only logged.

    Returns:
        A ``(accuracy, loss)`` tuple of Python floats for ``model`` only.

    Note:
        Targets are reduced with ``argmax(-1)``, so they are presumably
        one-hot / soft label vectors -- confirm against ``get_batches``.
    """
    from collections import defaultdict

    # Prefer the first GPU but fall back to the CPU so the evaluation still
    # runs on machines without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = model.to(device).float()
    benchmark_model = benchmark_model.to(device).float()
    # Put BOTH networks in eval mode: a module left in training mode would use
    # batch statistics (and update running statistics) in normalisation layers
    # even under ``torch.no_grad()``, skewing the comparison.
    model.eval()
    benchmark_model.eval()

    data = torch.load(hyp["misc"]["data_location"], map_location=device)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2, reduction="none")

    eval_batchsize = 2500
    model_types = ("Test model", "Benchmark Model")
    networks = dict(zip(model_types, (model, benchmark_model)))
    loss_list_val, acc_list = defaultdict(list), defaultdict(list)

    with torch.no_grad():
        for inputs, targets in get_batches(
            data, key="eval", batchsize=eval_batchsize
        ):
            input_tensors = inputs.float()

            for mt, mm in networks.items():
                outputs = mm(input_tensors)
                loss_val = loss_fn(outputs, targets).float().mean()
                acc_val = (outputs.argmax(-1) == targets.argmax(-1)).float().mean()
                logger.trace(f"loss {loss_val:.3f} acc {acc_val:.3f} for {mt}")

                loss_list_val[mt].append(loss_val)
                acc_list[mt].append(acc_val)

    # Average the per-batch metrics once per model and reuse them for both
    # the log line and the return value.
    averages = {}
    for mt in model_types:
        avg_val_acc = torch.mean(torch.tensor(acc_list[mt])).item()
        avg_val_loss = torch.mean(torch.tensor(loss_list_val[mt])).item()
        averages[mt] = (avg_val_acc, avg_val_loss)

        logger.debug(f"Avg Validation Accuracy: {avg_val_acc:.2f} and Avg Validation Loss {avg_val_loss} for {mt}")
    return averages["Test model"]

In [None]:
# for (name, param) in model.named_parameters():
#     print(name, param.shape, param.requires_grad)

In [None]:
# model_param = model.get_parameter("net_dict.initial_block.whiten.weight").to("cpu").detach().numpy()
# param_shape = model_param.shape
# reshaped_param = model_param.reshape(param_shape[0], param_shape[1], -1)
# rp_copy = np.copy(reshaped_param)
# for idxs in range(reshaped_param.shape[-1]):
#     P = reshaped_param[:, :, idxs]
#     # P = np.interp(P, (P.min(), P.max()), (-1, 1))
#     rp_copy[:, :, idxs] = lplr(P, 3, 32, 32)

# rp_final = rp_copy.reshape(param_shape)

# errs = []
# for row in range(param_shape[-2]):
#     for col in range(param_shape[-1]):
#         errs.append(error(model_param[:, :, row, col], rp_final[:, :, row, col]))


In [None]:
def quantize_layers(model: nn.Module, compressor: Callable[[np.ndarray], np.ndarray] = lplr) -> nn.Module:
    """Return a deep copy of ``model`` with its parameters compressed.

    The input model is not modified. For every parameter of the copy:

    * ``ndim >= 2``: the array is viewed as ``(shape[0], shape[1], -1)`` and
      ``compressor`` is applied independently to each 2-D slice along the
      last axis (for a 2-D weight there is a single slice). The results are
      written into an array of the parameter's original dtype and reshaped
      back to the original shape.
    * ``ndim == 1``: the whole vector is passed through ``quantize``; the
      resulting dtype is whatever ``quantize`` returns.
    * ``ndim == 0``: left unchanged.

    Args:
        model: The network whose parameters should be compressed.
        compressor: Callable taking one 2-D ``np.ndarray`` and returning an
            array assignable to a slice of the same shape.
            NOTE(review): elsewhere in this file ``lplr`` is called with extra
            rank / bit-width arguments; confirm the bare default can be
            called with a single argument.

    Returns:
        The compressed copy of ``model``.
    """
    from copy import deepcopy

    output_model = deepcopy(model)
    for name, param in output_model.named_parameters():
        model_param = param.to("cpu").detach().numpy()
        param_shape = model_param.shape
        logger.trace(f"Applying LPLR on {name} with shape {param_shape}")

        if param.ndim >= 2:
            # Collapse all trailing axes so each index on the last axis
            # selects one (shape[0], shape[1]) matrix to compress.
            reshaped_param = model_param.reshape(param_shape[0], param_shape[1], -1)
            out_param = np.zeros_like(reshaped_param)
            for idxs in range(reshaped_param.shape[-1]):
                out_param[:, :, idxs] = compressor(reshaped_param[:, :, idxs])
            new_values = out_param.reshape(param_shape)
        elif param.ndim == 1:
            new_values = quantize(model_param)
        else:
            # Scalar (0-dim) parameters are deliberately left as they are.
            continue

        # Keep the new values on the parameter's own device; ``from_numpy``
        # always yields a CPU tensor, which would otherwise leave a model
        # that lives on another device with mixed-device parameters.
        param.data = torch.from_numpy(new_values).to(param.device)
    return output_model


In [None]:
# Checkpoint path, relative to the current working directory (which was
# changed to the parent folder at the top of the notebook).
model_location = "artifacts/hlb/checkpoints/trained-speedyresnet.pt"
# Load the checkpoint onto the CPU. The result is used below as an nn.Module
# (it is passed to quantize_layers / evaluate_models), so the file presumably
# holds a fully pickled model rather than a state dict -- confirm.
# NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints
# from a trusted source.
model = torch.load(model_location, map_location="cpu")

In [None]:
# pname = "net_dict.initial_block.whiten.weight"
# torch.linalg.norm(quantized_model.get_parameter(pname) - model.get_parameter(pname)) / torch.linalg.norm(model.get_parameter(pname))

In [None]:
from functools import partial
def comp(param, fraction):
    """Compress ``param`` with ``lplr`` at a rank scaled by ``fraction``.

    The target rank is ``ceil(fraction * min(param.shape))``; ``lplr`` is then
    called with that rank and the fixed arguments ``32, 32``.

    Args:
        param: Array-like with a ``shape`` attribute (a 2-D matrix when
            called from ``quantize_layers``).
        fraction: Multiplier applied to the smallest dimension of ``param``.

    Returns:
        Whatever ``lplr`` returns for the given matrix and target rank.

    Raises:
        AssertionError: If the computed target rank is not an ``int``.
    """
    from math import ceil

    full_rank = np.min(param.shape)
    try:
        target_rank = ceil(fraction * full_rank)
        assert isinstance(target_rank, int)
    except AssertionError:
        # Log the offending value, then let the same exception propagate.
        logger.error(f"Wrong Out rank: {target_rank}")
        raise
    else:
        logger.debug(f"Shape: {param.shape} Input Rank: {full_rank} Output Rank: {target_rank}")
        return lplr(param, target_rank, 32, 32)

In [None]:
# Build one compressed copy of the model per rank fraction. The keys are the
# fraction strings; each compressor is ``comp`` with that fraction bound.
quantized_models = {
    frac_label: quantize_layers(
        model,
        compressor=partial(comp, fraction=float(frac_label)),
    )
    for frac_label in ("0.9", "0.99", "0.999", "1.0")
}

In [None]:
# Evaluate every compressed model against the uncompressed ``model`` (passed
# as the benchmark) and log the compressed model's accuracy and loss.
for f, qm in quantized_models.items():
    # evaluate_models returns (accuracy, loss) for its first argument only.
    (acc, loss) = evaluate_models(qm, model)
    logger.info(f"Computing fraction {f} with accuracy {acc:.3f} and loss {loss:.3f}")