In [None]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
# Query the CUDA compute capability of the current device as (major, minor).
# NOTE(review): this runs when the cell executes, so a CUDA device is presumably required — confirm.
major_version, minor_version = torch.cuda.get_device_capability()
# A major compute capability of 8 or higher is treated here as "bfloat16 is available".
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
def _F(c):
    """Return the line number recorded for frame `c` (pass `_C()` for the current frame)."""
    return getframeinfo(c).lineno

def WARN(x):
    """Print `x` wrapped in the ANSI escape codes for red text."""
    print(f"\033[31m{x}\033[0m")

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    """Return the name bound to `var` in the caller's local scope.

    The caller's frame locals are scanned in order and the first name whose
    value *is* `var` (identity, not equality) wins. An empty string is
    returned when no local in the caller refers to that object.
    """
    caller_scope = inspect.currentframe().f_back.f_locals
    for candidate, bound_value in caller_scope.items():
        if bound_value is var:
            return candidate
    return ""

def assert_same(x, y, line, dtype):
    """Assert that tensors `x` and `y` match and that `x` has the expected dtype.

    Args:
        x: tensor under test; its dtype must equal `dtype`.
        y: reference tensor compared with `torch.testing.assert_close`
            (strides are compared as well).
        line: line number to report in the failure message.
        dtype: dtype that `x` is required to have.

    Raises:
        AssertionError: if `x.dtype` differs from `dtype`.
        RuntimeError: if the comparison fails. The message names the variables
            the *caller* used for `x` and `y` (empty when the argument was a
            temporary) and carries the original comparison report.
    """
    assert x.dtype == dtype, f"expected dtype {dtype}, got {x.dtype}"
    try:
        torch.testing.assert_close(x, y, check_stride = True)
    except Exception as error:
        # Resolve the names in the caller's frame. Looking them up from inside
        # this function would always yield the parameter names "x" and "y".
        caller = inspect.currentframe().f_back
        caller_locals = caller.f_locals if caller is not None else {}
        labels = [
            next((name for name, value in caller_locals.items() if value is tensor), "")
            for tensor in (x, y)
        ]
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {labels[0]}, {labels[1]}\n{str(error)}"
        ) from error

# Ask the Hugging Face Hub client to use its `hf_transfer` download path.
# NOTE(review): this is set after `transformers` has already been imported above;
# confirm the hub client still reads the variable at this point rather than at import time.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [None]:
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
def unsloth_dequantize(weight):
    """Dequantize a 4-bit layer with Unsloth's `fast_dequantize`.

    `weight` is the layer module; its `.weight` parameter and that parameter's
    `quant_state` are what get forwarded.
    """
    packed = weight.weight
    return fast_dequantize(packed, packed.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    """Build a bitsandbytes `Linear4bit` layer mapping `hd` features to `m`.

    The layer is created with `bias = None`, the "nf4" quantization type,
    `compress_statistics = True`, and `dtype` as its compute dtype.
    """
    layer_options = dict(
        bias                = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )
    return Linear4bit(hd, m, **layer_options)

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    """Check that a quantized layer carries the expected 4-bit quantization state.

    Verifies, in order: the packed weight is uint8; the quant_state reports
    `dtype`; its absmax is uint8 while code and offset are float32; its
    blocksize is 64; and the nested `state2` has float32 absmax/code with a
    blocksize of 256. Any mismatch raises AssertionError.
    """
    packed = weight.weight
    assert packed.dtype == torch.uint8

    state = packed.quant_state
    assert state.dtype == dtype
    assert state.absmax.dtype == torch.uint8
    assert state.code.dtype == torch.float32
    assert state.offset.dtype == torch.float32
    assert state.blocksize == 64

    nested = state.state2
    assert nested.absmax.dtype == torch.float32
    assert nested.code.dtype == torch.float32
    assert nested.blocksize == 256

class MLP(nn.Module):
    """Gated MLP whose gate / up / down projections are 4-bit `Linear4bit` layers.

    `hd` is the input/output width, `m` the intermediate width and `dtype` the
    compute dtype handed to every projection. All projections are moved to "cuda".
    """
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        # Creation order (gate, up, down) is kept fixed; test_dequantize seeds
        # the RNG right before building the module.
        self.gate_proj = self._make_proj(hd, m, dtype)
        self.up_proj   = self._make_proj(hd, m, dtype)
        self.down_proj = self._make_proj(m, hd, dtype)
        self.act_fn = ACT2FN["silu"]

    @staticmethod
    def _make_proj(in_features, out_features, dtype):
        """Create one projection on "cuda" and record `dtype` on its quant_state."""
        layer = bnb_Linear4bit(in_features, out_features, dtype = dtype).to("cuda")
        # [NEW] as at 18th Feb 2025
        layer.weight.quant_state.dtype = dtype
        return layer

    def forward(self, x):
        """Compute down_proj(act_fn(gate_proj(x)) * up_proj(x))."""
        gated  = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)

def mlp_forward(X, mlp, fx):
    """Run the MLP forward pass with weights obtained through `fx`.

    `fx(layer)` is called once per projection, in the order up, gate, down.
    Each returned weight is transposed before the matmul. The result is
    (act_fn(X @ gate.T) * (X @ up.T)) @ down.T.
    """
    def project(inputs, layer):
        return inputs @ fx(layer).t()

    up_out   = project(X, mlp.up_proj)
    gate_out = project(X, mlp.gate_proj)
    hidden   = mlp.act_fn(gate_out) * up_out
    return project(hidden, mlp.down_proj)

def mlp_dequantize(X, mlp, fx):
    """Dequantize the three MLP projections with `fx`, syncing CUDA after each.

    Returns the transposed results as a tuple in the order (up, gate, down).
    `X` is accepted but not used by this function.
    """
    transposed = []
    for layer in (mlp.up_proj, mlp.gate_proj, mlp.down_proj):
        transposed.append(fx(layer).t())
        # Wait for outstanding GPU work before moving on to the next projection.
        torch.cuda.synchronize()
    return tuple(transposed)

def test_dequantize(dequantize_fx):
    """Validate `dequantize_fx` against reference results and time it.

    For each (batch, seq_len, hidden, intermediate, seed, dtype) configuration
    a freshly seeded `MLP` is built. The warm-up then checks that:
      * an MLP forward pass using `dequantize_fx` weights matches `mlp(X)`,
      * the quantization state of every projection is as expected,
      * the dequantized weights match those from `unsloth_dequantize`.
    Afterwards 1000 calls of `mlp_dequantize` are timed.

    Args:
        dequantize_fx: callable taking a quantized layer and returning its
            dequantized weight.

    Returns:
        Total benchmark time in seconds, summed over all configurations
        (warm-up and validation are not included).

    Raises:
        AssertionError / RuntimeError: from the validation helpers when a check fails.
    """
    elapsed = 0.0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking
        torch.cuda.synchronize()
        # perf_counter is the monotonic, high-resolution clock meant for
        # measuring intervals; time.time() can jump if the system clock changes.
        start = time.perf_counter()
        for _ in range(1000): mlp_dequantize(X, mlp, dequantize_fx)
        # Make sure all queued GPU work is done before the clock is read, so the
        # measurement does not depend on the callee synchronizing internally.
        torch.cuda.synchronize()
        elapsed += time.perf_counter() - start
    return elapsed

In [None]:
from unsloth.kernels.utils import fast_dequantize
# NOTE(review): this repeats the `unsloth_dequantize` definition from the helper
# cell above verbatim; one of the two definitions could be dropped.
def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)
# Validate and benchmark Unsloth's dequantizer; the elapsed seconds returned by
# test_dequantize are the value of this cell.
test_dequantize(unsloth_dequantize)

In [None]:
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
# Validate and benchmark PEFT's `dequantize_module_weight` with the same harness.
test_dequantize(peft_dequantize)

In [None]:
from triton import jit
import triton
import triton.language as tl

# Placeholder Triton kernel for the NF4 dequantization exercise: it takes no
# arguments and returns immediately until an implementation is filled in.
@triton.jit
def _your_dequantize_nf4_kernel():
    ### TRITON CODE GOES HERE
    return

def _your_dequantize_nf4(weight, quant_state):
    """Host-side stub where the Triton launch for NF4 dequantization belongs.

    Receives the packed weight data and its quant_state (see
    `your_dequantize_nf4`). It currently returns None because no launch code
    has been written yet.
    """
    ### SETUP TRITON LAUNCH HERE
    return None

def your_dequantize_nf4(weight):
    """Dequantize a layer by handing its weight data and quant_state to `_your_dequantize_nf4`.

    `weight` is the layer module; `weight.weight.data` and
    `weight.weight.quant_state` are what get forwarded.
    """
    param = weight.weight
    return _your_dequantize_nf4(param.data, param.quant_state)

In [None]:
import torch
import torch._inductor.config as inductor_config

# Enable Inductor debug output (which may include Triton if used)
# NOTE(review): confirm every key below exists in `torch._inductor.config` for the
# installed torch version (in particular `verbose`); assigning an unknown key may raise.
inductor_config.debug = True
inductor_config.verbose = True
inductor_config.trace.enabled = True  # Enable tracing for Inductor
inductor_config.triton.cudagraphs = False  # Disable cudagraphs for clearer output

# Minimal function compiled with the Inductor backend, so the debug settings
# configured above have something small to report on.
@torch.compile(backend="inductor")
def multiply_2(x):
    """Return `x` multiplied by 2."""
    return x * 2

# Test with a sample input: 10 random values created on the CUDA device.
x = torch.randn(10, device="cuda")
# Call the compiled function and print the doubled values.
y = multiply_2(x)
print(y)

In [None]:
import triton
import triton.language as tl
import torch

@triton.jit
def fused_dequantize_kernel(
    a_ptr,              # Input: packed 4-bit tensor (uint8), two 4-bit codes per byte
    quant_absmax_ptr,   # Input: quant_state.absmax (uint8), one quantized scale per block
    state2_code_ptr,    # Input: quant_state.state2.code (float32)
    state2_absmax_ptr,  # Input: quant_state.state2.absmax (float32)
    code_ptr,           # Input: quant_state.code (float32)
    output_ptr,         # Output: dequantized result (stored in the output tensor's dtype)
    offset,             # Input: quant_state.offset as a scalar (float32)
    n_packed_elements,  # Number of uint8 elements in A
    blocksize,          # Elements per block (quant_state.blocksize)
    BLOCK_SIZE: tl.constexpr,  # Packed elements processed per program; must equal blocksize // 2
    STATE2_BLOCKSIZE: tl.constexpr = 256,  # quant_state.state2.blocksize
):
    # Dequantizes one block of `blocksize` output elements per program:
    #   scale  = state2.code[absmax[block]] * state2.absmax[block // STATE2_BLOCKSIZE] + offset
    #   out[j] = code[nibble_j] * scale
    # The high nibble of each byte is the even output element and the low
    # nibble the odd one.
    # NOTE(review): the index math assumes bitsandbytes' nested 4-bit layout
    # (one uint8 absmax per block, one float32 state2.absmax per
    # STATE2_BLOCKSIZE absmax entries) — confirm against the QuantState in use.

    # Program ID: each program processes one output block
    pid = tl.program_id(axis=0)
    out_block_idx = pid

    # Packed (uint8) range covered by this block
    packed_per_block = blocksize // 2  # Number of uint8 elements per block
    packed_start = out_block_idx * packed_per_block

    # Early return if the block starts past the end of the input. Comparing the
    # start offset (instead of a floor-divided block count) keeps a trailing
    # partial block alive; the masks below trim it.
    if packed_start >= n_packed_elements:
        return

    # Compute the scaling factor for this block.
    # There is one quantized absmax per block, so it is indexed by the block
    # number itself, not by the element offset of the block.
    quant_absmax_val = tl.load(quant_absmax_ptr + out_block_idx).to(tl.int32)
    code_val = tl.load(state2_code_ptr + quant_absmax_val)
    # The second-level scale is shared by STATE2_BLOCKSIZE consecutive absmax entries.
    state2_absmax_val = tl.load(state2_absmax_ptr + out_block_idx // STATE2_BLOCKSIZE)
    scaling = code_val * state2_absmax_val + offset

    # Process packed elements for this block
    packed_offsets = packed_start + tl.arange(0, BLOCK_SIZE)

    # Mask to prevent out-of-bounds access
    packed_mask = packed_offsets < n_packed_elements

    # Load packed uint8 values
    packed_vals = tl.load(a_ptr + packed_offsets, mask=packed_mask, other=0).to(tl.uint8)

    # Unpack 4-bit values
    val0 = (packed_vals >> 4).to(tl.int32)    # High 4 bits
    val1 = (packed_vals & 0b1111).to(tl.int32) # Low 4 bits

    # Lookup dequantized values
    result0 = tl.load(code_ptr + val0, mask=packed_mask, other=0.0)
    result1 = tl.load(code_ptr + val1, mask=packed_mask, other=0.0)

    # Apply scaling
    result0 = result0 * scaling
    result1 = result1 * scaling

    # Compute output offsets (interleaved: val0, val1, val0, val1, ...)
    out_start = out_block_idx * blocksize
    out_offsets0 = out_start + 2 * tl.arange(0, BLOCK_SIZE)
    out_offsets1 = out_offsets0 + 1
    out_mask = out_offsets1 < (n_packed_elements * 2)

    # Store results in the element type of the output buffer (bfloat16 when
    # the output tensor is bfloat16).
    out_dtype = output_ptr.dtype.element_ty
    tl.store(output_ptr + out_offsets0, result0.to(out_dtype), mask=out_mask)
    tl.store(output_ptr + out_offsets1, result1.to(out_dtype), mask=out_mask)

# Host function to launch the kernel
def triton_fused_dequantize(A, quant_state):
    """
    Fused dequantization of a 4-bit packed tensor using Triton.

    Args:
        A: CUDA torch.Tensor of dtype uint8 holding the packed 4-bit values
            (two per byte). Only the total number of bytes is used, so both
            [1, n_packed_elements] and [n_packed_elements, 1] layouts are accepted.
        quant_state: Object containing:
            - absmax: uint8 tensor of quantized per-block scales
            - state2.code: float32 lookup table used to dequantize absmax
            - state2.absmax: float32 second-level scales
            - code: float32 lookup table for the 4-bit values
            - offset: scalar added to every dequantized scale (a Python float
              or a one-element tensor)
            - blocksize: int, elements per block
            - shape: tuple, desired output shape (e.g., [M, N])

    Returns:
        torch.Tensor of shape quant_state.shape, dtype bfloat16. It is a
        transposed view of a contiguous buffer, so it is not contiguous itself.

    Raises:
        AssertionError: if A is not a CUDA tensor.
    """
    assert A.is_cuda, "Input tensor must be on CUDA"
    # Count bytes with numel() so the packed tensor may be [1, n] or [n, 1].
    n_packed_elements = A.numel()
    # Allocate on A's device rather than on whichever CUDA device is the default.
    # NOTE(review): the kernel fills this buffer linearly and the result is
    # returned as a transposed view; confirm this memory layout is the one
    # callers expect for non-square shapes.
    # NOTE(review): the output dtype is fixed to bfloat16 even when
    # quant_state.dtype is float16 — confirm that is intended.
    output = torch.empty(quant_state.shape, dtype=torch.bfloat16, device=A.device).t().contiguous()

    blocksize = quant_state.blocksize
    # Ceil division so a trailing partial block still gets a program.
    num_out_blocks = (n_packed_elements * 2 + blocksize - 1) // blocksize
    BLOCK_SIZE = blocksize // 2  # Number of packed elements per thread block

    # The kernel uses `offset` as a scalar in arithmetic. Passing a tensor would
    # hand it a pointer instead, so convert it to a Python float first.
    offset = float(quant_state.offset)

    grid = (num_out_blocks,)
    fused_dequantize_kernel[grid](
        A,
        quant_state.absmax,
        quant_state.state2.code,
        quant_state.state2.absmax,
        quant_state.code,
        output,
        offset,
        n_packed_elements,
        blocksize,
        BLOCK_SIZE=BLOCK_SIZE
    )
    return output.t()  # Reshape to original dimensions