# NF4 Dequantize

In [2]:
# Run @triton.jit kernels in Triton's interpreter mode, which lets `set_trace`
# (see `_b` below) stop inside kernel code. It is set before `import triton` on
# purpose; presumably Triton reads it at import/launch time — confirm for your
# Triton version.
# NOTE(review): with this on, `test_dequantize` presumably times interpreted
# kernels rather than compiled GPU code; unset it before benchmarking.
import os; os.environ['TRITON_INTERPRET'] = '1'
import torch
import triton
import triton.language as tl
from IPython.core.debugger import set_trace

# Wider console lines when printing tensors.
torch.set_printoptions(linewidth=120)

def _b(*pids):
    "Drop into set_trace only in the program whose ids equal `pids` (axis 0, 1, ... in order)."
    for axis, wanted in enumerate(pids):
        if not (tl.program_id(axis) == wanted):
            return
    set_trace()

def cdiv(x, y):
    """Ceiling division: how many size-`y` chunks are needed to cover `x` items."""
    padded = x + y - 1
    return padded // y

In [3]:
# Helper functions used throughout the notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
# CUDA compute capability (major, minor) of the current device.
(major_version, minor_version) = torch.cuda.get_device_capability()
# bfloat16 is treated as available from compute capability 8.x onwards.
HAS_BFLOAT16 = bool(major_version >= 8)
from inspect import currentframe as _C, getframeinfo
_F = lambda c: getframeinfo(c).lineno # Gets line number
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    """Return the first local name in the *calling* frame bound to `var` (identity match), or "".

    Only the immediate caller's frame is searched, so this must be called
    directly from the scope that owns the variable.
    """
    caller_locals = inspect.currentframe().f_back.f_locals
    return next(
        (name for name, value in caller_locals.items() if value is var),
        "",
    )

def assert_same(x, y, line, dtype):
    """Assert that `x` has dtype `dtype` and matches `y` in values and strides.

    Args:
        x: tensor under test.
        y: reference tensor.
        line: caller's source line number (e.g. `_F(_C())`), used in messages.
        dtype: dtype that `x` is required to have.

    Raises:
        AssertionError: if `x.dtype != dtype`. It is raised explicitly, so the
            check survives `python -O`.
        RuntimeError: if `torch.testing.assert_close(x, y, check_stride=True)`
            fails. The message names the caller's variables bound to `x`/`y`
            (empty for temporaries). The original error is chained as
            `__cause__`.
    """
    if x.dtype != dtype:
        raise AssertionError(
            f"Wrong dtype at line [{line}]: expected {dtype}, got {x.dtype}"
        )
    try: torch.testing.assert_close(x, y, check_stride = True)
    except Exception as error:
        # Calling NAME(x) from here would search *this* frame and always report
        # "x"/"y". Resolve the names in our caller's frame instead.
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        caller_locals = caller.f_locals if caller is not None else {}
        x_name = next((k for k, v in caller_locals.items() if v is x), "")
        y_name = next((k for k, v in caller_locals.items() if v is y), "")
        del frame, caller  # avoid keeping a frame -> local -> frame reference cycle
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {x_name}, {y_name}\n{str(error)}"
        ) from error

# Opt in to huggingface_hub's hf_transfer download backend. Presumably this is
# only effective when the `hf_transfer` package is installed; confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# ---

from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from peft.utils.integrations import dequantize_module_weight as peft_dequantize

@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
    """Dequantize a packed NF4 weight `W` with bitsandbytes' C kernels (reference path).

    Args:
        W: packed 4-bit weight tensor (e.g. a `Linear4bit` module's `.weight`).
        quant_state: bitsandbytes quantization state. This is either the
            class-based `QuantState` or the legacy nested-list form. If None,
            `W` is returned unchanged.
        out: optional preallocated output. It must have `quant_state`'s shape
            and dtype. It is ignored when `use_global_buffer` is True.
        use_global_buffer: reuse the module-level `WEIGHT_BUFFER` /
            `ABSMAX_BUFFER` across calls. The returned tensor then aliases the
            shared buffer and is overwritten by the next call.

    Returns:
        A tensor of `quant_state.shape` / `quant_state.dtype`, returned
        transposed when `W.shape[0] == 1`.

    NOTE(review): `get_ptr`, `ctypes_c_int` and the `cdequantize_*` kernels are
    not defined in the cells shown; they must be brought into scope elsewhere.
    """
    if quant_state is None: return W
    if type(quant_state) is not list:
        # New quant_state as a class
        # https://github.com/TimDettmers/bitsandbytes/pull/763/files
        absmax     = quant_state.absmax
        shape      = quant_state.shape
        dtype      = quant_state.dtype
        blocksize  = quant_state.blocksize
        offset     = quant_state.offset
        state2     = quant_state.state2
        absmax2    = state2.absmax
        code2      = state2.code
        blocksize2 = state2.blocksize
    else:
        # Old quant_state as a list of lists
        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
        offset, state2 = compressed_stats
        absmax2, code2, blocksize2, _, _, _, _ = state2

    n_elements_absmax = absmax.numel()

    # Create weight matrix
    if use_global_buffer:
        global WEIGHT_BUFFER, ABSMAX_BUFFER
        # Use same buffers for faster inference
        size = shape[0]*shape[1]
        # The globals may never have been created; treat "missing" like None.
        weight_buffer = globals().get("WEIGHT_BUFFER")
        absmax_buffer = globals().get("ABSMAX_BUFFER")
        # Reallocate on dtype change too, otherwise a bf16 request after an fp16
        # one would silently return an fp16 tensor.
        if weight_buffer is None or weight_buffer.dtype != dtype:
            WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False)
        # cdequantize_blockwise_fp32 writes float32 absmax values (as in the
        # non-global branch below), so this buffer must be float32 regardless of `dtype`.
        if absmax_buffer is None or absmax_buffer.dtype != torch.float32:
            ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False)

        if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
        if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)

        out = WEIGHT_BUFFER[:size].view(shape)
        out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
    else:
        if out is None:
            out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False)
        else:
            # Explicit raises (not `assert`) so the checks survive `python -O`.
            if out.shape != shape:
                raise AssertionError(f"out.shape {tuple(out.shape)} != expected {tuple(shape)}")
            if out.dtype != dtype:
                raise AssertionError(f"out.dtype {out.dtype} != expected {dtype}")
        out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False)

    # Do dequantization
    # Step 1: the 8-bit absmax values are dequantized via code2/absmax2 into
    # float32 out_absmax, then shifted by `offset`.
    ptr_out_absmax = get_ptr(out_absmax)
    cdequantize_blockwise_fp32(
        get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
        ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax),
    )
    out_absmax += offset

    # Step 2: the NF4 codes in W are expanded using the float32 absmax from step 1.
    fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
         cdequantize_blockwise_bf16_nf4
    fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
       ctypes_c_int(blocksize), ctypes_c_int(out.numel()),)

    # Careful returning transposed data
    return out.t() if W.shape[0] == 1 else out

def unsloth_dequantize(weight):
    """Reference dequantizer: unpack a Linear4bit module's weight with `fast_dequantize`.

    `weight` is the module (not its parameter); its `.weight.quant_state` is used.
    """
    packed = weight.weight
    return fast_dequantize(packed, packed.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    """Build a bias-free NF4 `Linear4bit(hd -> m)` with compressed (nested) statistics.

    `dtype` is used as the layer's compute dtype.
    """
    options = dict(
        bias                = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )
    return Linear4bit(hd, m, **options)

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    """Assert that a Linear4bit module has the quant-state layout this notebook relies on.

    Checks:
        - The packed weight is uint8.
        - `quant_state.dtype == dtype`.
        - The absmax is uint8, with a float32 code book, a float32 offset and
          blocksize 64.
        - The nested `state2` has float32 absmax/code and blocksize 256.

    Raises AssertionError on the first mismatch.
    """
    param = weight.weight
    assert param.dtype == torch.uint8
    qs = param.quant_state
    assert qs.dtype == dtype
    assert qs.absmax.dtype == torch.uint8
    assert qs.code.dtype == torch.float32
    assert qs.offset.dtype == torch.float32
    assert qs.blocksize == 64
    assert qs.state2.absmax.dtype == torch.float32
    assert qs.state2.code.dtype == torch.float32
    assert qs.state2.blocksize == 256

class MLP(nn.Module):
    """Gated MLP made of three bias-free NF4 `Linear4bit` layers on CUDA.

    forward(x) = down_proj(silu(gate_proj(x)) * up_proj(x)). Here `hd` is the
    input/output width and `m` the intermediate width.
    """
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        # Keep the construction order (gate, up, down) fixed. The layers are
        # presumably initialised from the global RNG seeded by test_dequantize.
        self.gate_proj = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype = dtype).to("cuda")
        # [NEW] as at 18th Feb 2025: pin each quant_state's dtype to `dtype`.
        # assert_correct_bnb checks it, and fast_dequantize uses it as the output dtype.
        for proj in (self.gate_proj, self.up_proj, self.down_proj):
            proj.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        activated = self.act_fn(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    """Recompute the MLP forward pass with weights materialised by `fx`.

    Computes down(act(gate(X)) * up(X)). `fx(module)` must return the module's
    (out, in) weight, which is applied as `inputs @ weight.t()`. Weights are
    requested in the order up, gate, down.
    """
    def project(inputs, module):
        # Dequantize the module's weight and apply it as a bias-free linear map.
        return inputs @ fx(module).t()

    up = project(X, mlp.up_proj)
    gate = project(X, mlp.gate_proj)
    return project(mlp.act_fn(gate) * up, mlp.down_proj)

def mlp_dequantize(X, mlp, fx):
    """Dequantize the up, gate and down projections with `fx`; return their transposes.

    `X` is unused; it is accepted so the signature mirrors `mlp_forward`.
    `torch.cuda.synchronize()` runs after each call, so timings include the GPU work.
    """
    transposed = []
    for module in (mlp.up_proj, mlp.gate_proj, mlp.down_proj):
        transposed.append(fx(module).t())
        torch.cuda.synchronize()
    return tuple(transposed)

def test_dequantize(dequantize_fx):
    """Check `dequantize_fx` against the bnb/`unsloth_dequantize` reference, then time it.

    A fresh NF4 `MLP` is built on CUDA for each (bsz, qlen, hd, m, seed, dtype)
    config. Two warm-up rounds assert three things:
      1. An MLP forward using `dequantize_fx` matches the module's own forward.
      2. The bnb quant-state layout is as expected.
      3. Each dequantized weight equals `unsloth_dequantize`'s.
    After that, 1000 calls of `mlp_dequantize` are timed.

    Args:
        dequantize_fx: callable that takes a `Linear4bit` module and returns
            its dequantized weight.

    Returns:
        Total benchmark time in seconds, summed over all configs.
    """
    elapsed = 0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup (doubles as a correctness check)
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)
        # Release the last warm-up weights so they do not occupy GPU memory
        # while the benchmark runs.
        del a, b, c, A, B, C

        # Benchmarking. mlp_dequantize synchronizes after every dequantize, so
        # the measured interval covers the GPU work. perf_counter is monotonic
        # and high-resolution, unlike time.time().
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(1000): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.perf_counter() - start

        # Free this config's model and input before the next one is built.
        # Otherwise both would be alive at once during construction.
        del mlp, X
    return elapsed

ModuleNotFoundError: No module named 'transformers'

In [None]:
@triton.jit
def _your_dequantize_nf4_kernel(
    W_ptr,          # packed 4-bit codes, two per byte: high nibble -> element 2*i, low -> 2*i+1
    code_ptr,       # 4-bit code book (quant_state.code; 16 float32 values for NF4)
    absmax_ptr,     # uint8 quantized absmax, one per `blocksize` output elements
    code2_ptr,      # float32 code book for the absmax values (quant_state.state2.code)
    absmax2_ptr,    # float32 scale, one per `blocksize2` absmax entries
    offset_ptr,     # 0-dim float32 tensor added to every dequantized absmax
    out_ptr,        # output buffer with 2 * n_packed elements of the target dtype
    n_packed,       # number of bytes in W
    blocksize: tl.constexpr,
    blocksize2: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program expands BLOCK_SIZE packed bytes into 2 * BLOCK_SIZE outputs.
    # int64 indices so that 2 * byte index cannot overflow on very large weights.
    pid = tl.program_id(0).to(tl.int64)
    byte_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = byte_idx < n_packed

    packed = tl.load(W_ptr + byte_idx, mask = mask, other = 0).to(tl.int32)
    hi = packed >> 4     # code of output element 2*i
    lo = packed & 0xF    # code of output element 2*i + 1

    # Both elements of a byte share one quantization block (blocksize is even).
    block = (2 * byte_idx) // blocksize
    # Nested statistics: absmax = code2[q_absmax] * absmax2[block // blocksize2] + offset (float32).
    q_absmax = tl.load(absmax_ptr + block, mask = mask, other = 0).to(tl.int32)
    scale2 = tl.load(absmax2_ptr + block // blocksize2, mask = mask, other = 0.0)
    offset = tl.load(offset_ptr)
    absmax = tl.load(code2_ptr + q_absmax, mask = mask, other = 0.0) * scale2 + offset

    v_hi = tl.load(code_ptr + hi, mask = mask, other = 0.0) * absmax
    v_lo = tl.load(code_ptr + lo, mask = mask, other = 0.0) * absmax
    tl.store(out_ptr + 2 * byte_idx,     v_hi.to(out_ptr.dtype.element_ty), mask = mask)
    tl.store(out_ptr + 2 * byte_idx + 1, v_lo.to(out_ptr.dtype.element_ty), mask = mask)

def _your_dequantize_nf4(weight, quant_state):
    """Dequantize the packed 4-bit `weight` with a Triton kernel, given its bnb `quant_state`.

    Mirrors `fast_dequantize`. It returns a tensor of `quant_state.shape` and
    `quant_state.dtype`, transposed when `weight.shape[0] == 1`.

    Assumes the class-based bitsandbytes QuantState with nested statistics
    (`compress_statistics=True`), as built by `bnb_Linear4bit`.

    Raises:
        NotImplementedError: if `quant_state` has no nested `state2`.
        ValueError: if `weight`'s byte count does not match `quant_state.shape`.
    """
    state2 = getattr(quant_state, "state2", None)
    if state2 is None:
        raise NotImplementedError(
            "only nested (compress_statistics=True) 4-bit quant states are supported"
        )
    n_packed = weight.numel()
    out = torch.empty(quant_state.shape, dtype = quant_state.dtype,
                      device = weight.device, requires_grad = False)
    if out.numel() != 2 * n_packed:
        raise ValueError(
            f"weight has {n_packed} packed bytes but quant_state.shape needs {out.numel()} values"
        )

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_packed, BLOCK_SIZE),)
    _your_dequantize_nf4_kernel[grid](
        weight, quant_state.code, quant_state.absmax,
        state2.code, state2.absmax, quant_state.offset, out,
        n_packed,
        blocksize = quant_state.blocksize,
        blocksize2 = state2.blocksize,
        BLOCK_SIZE = BLOCK_SIZE,
    )
    # Same transpose convention as fast_dequantize.
    return out.t() if weight.shape[0] == 1 else out

def your_dequantize_nf4(weight):
    """Dequantize a Linear4bit module's weight with the Triton implementation.

    `weight` is the module; its packed parameter data and quant_state are forwarded.
    """
    param = weight.weight
    return _your_dequantize_nf4(param.data, param.quant_state)