# UNSLOTH CHALLENGE 1 SUBMISSION : Convert nf4 to Triton.

## Problem statement

---
---
---
<a name="NF4"></a>
## A) Convert `nf4` to Triton. [Difficulty: Hard] [Max points: 14]

1. Goal: Convert an `nf4` quantized tensor into `fp16` or `bf16` in a *single* Triton kernel. The double dequant of the `absmax` and weight forming must be done in 1 Triton kernel. Must work on Tesla T4.
2. Must be faster than Unsloth's `fast_dequantize` by 1.15x or more, and not use large intermediate memory buffers.
3. Must not use `torch.compile`, but can use `trace.enabled` to help on writing Triton kernels.
4. Good material: [Unsloth `fast_dequantize` function](https://github.com/unslothai/unsloth/blob/main/unsloth/kernels/utils.py#L128), also [bitsandbytes `dequantize_blockwise`](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/86b6c37a8ad448230cedb60753f63150b603a112/bitsandbytes/functional.py#L958)
5. Use `test_dequantize_function` to test your implementation.
6. No CUDA allowed. Custom CUDA inside of the Triton is allowed.
7. Watch Tim's videos on Youtube: [8-bit Optimizers](https://www.youtube.com/watch?v=2ETNONas068)

## Evaluation parameters 

## Marking Criteria for A) Max points = 14
```python
if attemped_A:
    A_score = 0
    if single_triton_kernel: A_score += 3
    speedup = old_time / new_time
    if speedup <= 1.00: A_score -= 3
    if speedup >= 1.05: A_score += 1
    if speedup >= 1.10: A_score += 2
    if speedup >= 1.15: A_score += 2
    if kernel_works_in_torch_compile: A_score += 1
    else: A_score -= 1
    if custom_asm_works: A_score += 3
    if uses_cache_eviction: A_score += 1
    if tested_in_f16_and_bf16: A_score += 1
    else: A_score -= 1
    final_score += A_score
else:
    final_score += 0
```

Let's load up the basic libraries.

In [1]:
!pip install triton

Collecting triton
  Downloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.1/253.1 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.2.0


phase one - get it to compile

In [None]:
import torch
from triton import jit, cdiv
import triton.language as tl

# Kernel: Traverses the input tensor in blocks and copies each element after a simple cast.
@jit
def _your_dequantize_nf4_kernel(weight_ptr, quant_state_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    """Phase-one placeholder kernel: copy the input to the output after a float16 cast.

    No dequantization happens here; ``quant_state_ptr`` is accepted but never read.

    Args:
        weight_ptr: Pointer to the input elements.
        quant_state_ptr: Unused in this phase.
        output_ptr: Pointer that receives each input element cast to float16.
        num_elements: Number of elements to process. This is a runtime
            argument rather than ``tl.constexpr`` so that calling the kernel
            with a different tensor size does not compile a new specialization.
        BLOCK_SIZE: Number of elements handled by one program instance
            (compile-time constant, required by ``tl.arange``).
    """
    # Each program instance owns one contiguous chunk of BLOCK_SIZE elements.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # The last chunk may run past the end of the tensor; mask those lanes off.
    mask = offsets < num_elements

    # Load the raw values (quant_state is ignored in the arithmetic for now).
    values = tl.load(weight_ptr + offsets, mask=mask)
    # Plain element-wise cast to float16, then write back.
    tl.store(output_ptr + offsets, tl.cast(values, tl.float16), mask=mask)

def _your_dequantize_nf4(weight, quant_state):
    """Launch the phase-one kernel on ``weight`` and return a flat float16 tensor.

    Prints debugging information before and after the launch. ``quant_state``
    is forwarded to the kernel unchanged; only its ``shape`` is inspected here,
    and only for the debug print.

    Args:
        weight: Tensor whose elements are passed through the kernel.
        quant_state: Object forwarded as the kernel's second argument.

    Returns:
        A 1-D ``torch.float16`` tensor with ``weight.numel()`` elements on
        ``weight.device``.
    """
    # Host-side diagnostics.
    print("=== Starting _your_dequantize_nf4 ===")
    print("Weight tensor shape:", weight.shape)
    print("Weight tensor dtype:", weight.dtype)
    try:
        print("Quant state shape:", quant_state.shape)
    except AttributeError:
        print("Quant state does not have a shape attribute (likely not a tensor).")

    total = weight.numel()
    print("Total number of elements:", total)

    # Destination buffer: one fp16 value per input element, same device.
    result = torch.empty(total, dtype=torch.float16, device=weight.device)

    def launch_grid(meta):
        # One program instance per BLOCK_SIZE-sized chunk (rounded up).
        return (cdiv(total, meta['BLOCK_SIZE']),)

    _your_dequantize_nf4_kernel[launch_grid](weight, quant_state, result, total, BLOCK_SIZE=1024)

    # Wait for the kernel so the sample printed below is valid.
    torch.cuda.synchronize()
    print("Kernel execution complete. Output sample (first 10 elements):", result[:10])
    print("=== Finished _your_dequantize_nf4 ===")
    return result

def your_dequantize_nf4(weight):
    """Public entry point: unwrap ``weight.weight`` and run ``_your_dequantize_nf4``.

    Args:
        weight: Object exposing ``weight.weight.data`` and
            ``weight.weight.quant_state``.

    Returns:
        Whatever ``_your_dequantize_nf4`` returns for those two values.
    """
    print(">>> Entering your_dequantize_nf4")
    param = weight.weight
    dequantized = _your_dequantize_nf4(param.data, param.quant_state)
    print(">>> Exiting your_dequantize_nf4")
    return dequantized

# For debugging purposes, you can create dummy inputs as follows:
if __name__ == '__main__':
    # Simulated quantized weight: random uint8 data on the GPU.
    # torch.randint's upper bound is exclusive, so 256 is needed for the full
    # 0..255 byte range (255 would never produce the byte 0xFF).
    dummy_weight = torch.randint(0, 256, (1024,), dtype=torch.uint8, device="cuda")
    # Placeholder quant_state tensor; its content is not used in this phase.
    dummy_quant_state = torch.empty((1,), dtype=torch.uint8, device="cuda")

    # Minimal stand-in exposing `.weight.data` and `.weight.quant_state`,
    # which is all your_dequantize_nf4 reads.
    class DummyWeight:
        def __init__(self, weight, quant_state):
            self.weight = type("W", (), {"data": weight, "quant_state": quant_state})

    dummy_obj = DummyWeight(dummy_weight, dummy_quant_state)

    # Run the kernel skeleton end to end and show a sample of the result.
    output = your_dequantize_nf4(dummy_obj)
    print("Final output (first 10 elements):", output[:10])


double dequantisation

In [3]:
import torch
from triton import jit, cdiv
import triton.language as tl

@jit
def _your_dequantize_nf4_kernel(
    weight_ptr,
    quant_absmax_ptr,
    quant_code_ptr,
    quant_offset_ptr,
    state2_absmax_ptr,
    state2_code_ptr,
    output_ptr,
    N,  # Total number of dequantized elements (i.e. weight.numel() * 2)
    BLOCK_SIZE: tl.constexpr,
):
    """Unpack 4-bit values from packed bytes and apply two per-block scales.

    For output index ``i`` the kernel computes
    ``(nibble_i - offset[i // 64]) * (absmax[i // 64] / code[i // 64])
    * (state2_absmax[i // 256] / state2_code[i // 256])``
    and stores it at ``output_ptr[i]``.

    NOTE(review): this indexing treats ``code``/``offset`` as one entry per
    64-element block and ``state2.code`` as one entry per 256-element block,
    which matches the dummy inputs built in this cell. bitsandbytes' NF4
    quant state presumably stores ``code`` as a 16-entry lookup table indexed
    by the nibble, ``offset`` as a scalar, and ``state2.code`` as a lookup
    table indexed by the uint8 ``absmax`` -- confirm before running this on
    real quantized weights.

    Args:
        weight_ptr: Pointer to packed bytes, two 4-bit values per byte.
        quant_absmax_ptr: First-level absmax, indexed by ``i // 64``.
        quant_code_ptr: First-level divisor, indexed by ``i // 64``.
        quant_offset_ptr: Value subtracted from the nibble, indexed by ``i // 64``.
        state2_absmax_ptr: Second-level absmax, indexed by ``i // 256``.
        state2_code_ptr: Second-level divisor, indexed by ``i // 256``.
        output_ptr: Destination pointer; results are cast to its element type.
        N: Number of output elements. This is a runtime argument rather than
            ``tl.constexpr`` so that a different tensor size does not compile a
            new specialization.
        BLOCK_SIZE: Output elements handled per program instance (compile-time).
    """
    # Global output indices owned by this program instance.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    # Two output values share one packed byte.
    packed_indices = offsets // 2
    nibble_selector = offsets % 2  # 0 => lower nibble, 1 => upper nibble

    byte_val = tl.load(weight_ptr + packed_indices, mask=mask)

    # Even output index takes the low 4 bits, odd index the high 4 bits.
    lower = byte_val & 0xF
    higher = byte_val >> 4
    q_val = tl.where(nibble_selector == 0, lower, higher)

    # Per-element block indices for the two parameter sets.
    block1 = offsets // 64   # quant_state.absmax / code / offset
    block2 = offsets // 256  # state2.absmax / state2.code

    amax = tl.load(quant_absmax_ptr + block1, mask=mask)
    code_val = tl.load(quant_code_ptr + block1, mask=mask)
    offset_val = tl.load(quant_offset_ptr + block1, mask=mask)
    s2_amax = tl.load(state2_absmax_ptr + block2, mask=mask)
    s2_code = tl.load(state2_code_ptr + block2, mask=mask)

    # amax is promoted to float32 before dividing (it is uint8 in this cell's
    # dummy inputs).
    scale1 = tl.cast(amax, tl.float32) / code_val
    scale2 = s2_amax / s2_code

    # Shift the nibble by the offset, then apply both scales.
    result = (tl.cast(q_val, tl.float32) - offset_val) * scale1 * scale2

    # Cast straight to the output buffer's element type so the same kernel
    # serves fp16 and bf16 outputs without rounding through fp16 first.
    tl.store(output_ptr + offsets, result.to(output_ptr.dtype.element_ty), mask=mask)

def _your_dequantize_nf4(weight, quant_state):
    """Launch the double-dequant kernel and return a flat float16 tensor.

    Prints debugging information before and after the launch.

    Args:
        weight: Packed tensor; each element is treated as holding two 4-bit
            values, so the output has ``weight.numel() * 2`` elements.
        quant_state: Object providing ``absmax``, ``code``, ``offset`` and a
            nested ``state2`` with ``absmax`` and ``code``.

    Returns:
        A 1-D ``torch.float16`` tensor of length ``weight.numel() * 2`` on
        ``weight.device``.
    """
    print("=== Starting _your_dequantize_nf4 ===")
    print("Weight tensor shape:", weight.shape)
    print("Weight tensor dtype:", weight.dtype)
    try:
        print("Quant state absmax shape:", quant_state.absmax.shape)
    except AttributeError:
        print("Quant state missing attributes!")

    # Every packed element expands into two output values.
    n_values = weight.numel() * 2
    print("Total number of dequantized elements:", n_values)

    # Destination buffer lives on the same device as the packed weights.
    result = torch.empty(n_values, dtype=torch.float16, device=weight.device)

    # Collect the quantization parameters in kernel-argument order.
    first_absmax = quant_state.absmax.contiguous()
    first_code = quant_state.code.contiguous()
    first_offset = quant_state.offset.contiguous()
    nested = quant_state.state2
    second_absmax = nested.absmax.contiguous()
    second_code = nested.code.contiguous()
    kernel_params = (first_absmax, first_code, first_offset, second_absmax, second_code)

    def launch_grid(meta):
        # One program instance per BLOCK_SIZE outputs, rounded up.
        return (cdiv(n_values, meta['BLOCK_SIZE']),)

    _your_dequantize_nf4_kernel[launch_grid](
        weight,
        *kernel_params,
        result,
        n_values,
        BLOCK_SIZE=1024,
    )

    # Wait for the kernel so the sample printed below is valid.
    torch.cuda.synchronize()
    print("Kernel execution complete. Output sample (first 10 elements):", result[:10])
    print("=== Finished _your_dequantize_nf4 ===")
    return result

def your_dequantize_nf4(weight):
    """Public entry point: unwrap ``weight.weight`` and run ``_your_dequantize_nf4``.

    Args:
        weight: Object exposing ``weight.weight.data`` and
            ``weight.weight.quant_state``.

    Returns:
        Whatever ``_your_dequantize_nf4`` returns for those two values.
    """
    print(">>> Entering your_dequantize_nf4")
    param = weight.weight
    dequantized = _your_dequantize_nf4(param.data, param.quant_state)
    print(">>> Exiting your_dequantize_nf4")
    return dequantized

# ----------------- Debug/Test Harness -----------------

if __name__ == '__main__':
    # Dummy packed weight: 1024 random bytes, two 4-bit values per byte.
    # torch.randint's upper bound is exclusive, so 256 is needed for the full
    # 0..255 byte range (255 would never produce the byte 0xFF).
    dummy_weight = torch.randint(0, 256, (1024,), dtype=torch.uint8, device="cuda")

    # Total dequantized elements = 1024 * 2 = 2048
    num_dequantized = dummy_weight.numel() * 2
    # Number of parameter blocks for the first and second quantization states
    num_blocks1 = (num_dequantized + 63) // 64     # block size = 64
    num_blocks2 = (num_dequantized + 255) // 256   # block size = 256

    # Dummy first-level parameters (simple values for testing).
    # absmax starts at 1 so no block gets a zero scale; the exclusive upper
    # bound of 256 lets it reach 255.
    dummy_quant_absmax = torch.randint(1, 256, (num_blocks1,), dtype=torch.uint8, device="cuda")
    dummy_quant_code = torch.rand(num_blocks1, dtype=torch.float32, device="cuda") * 0.1 + 0.9  # in [0.9, 1.0)
    dummy_quant_offset = torch.rand(num_blocks1, dtype=torch.float32, device="cuda") * 0.1  # small offsets

    # Dummy second-level parameters.
    dummy_state2_absmax = torch.rand(num_blocks2, dtype=torch.float32, device="cuda") * 2.0 + 0.5  # in [0.5, 2.5)
    dummy_state2_code = torch.rand(num_blocks2, dtype=torch.float32, device="cuda") * 0.1 + 0.9     # in [0.9, 1.0)

    # Bare attribute holder standing in for the nested state2 object.
    class DummyState2:
        pass
    dummy_state2 = DummyState2()
    dummy_state2.absmax = dummy_state2_absmax
    dummy_state2.code = dummy_state2_code
    dummy_state2.blocksize = 256

    # Bare attribute holder standing in for quant_state, with state2 nested inside.
    class DummyQuantState:
        pass
    dummy_quant_state = DummyQuantState()
    dummy_quant_state.absmax = dummy_quant_absmax
    dummy_quant_state.code = dummy_quant_code
    dummy_quant_state.offset = dummy_quant_offset
    dummy_quant_state.blocksize = 64
    dummy_quant_state.state2 = dummy_state2

    # Minimal stand-in exposing `.weight.data` and `.weight.quant_state`,
    # which is all your_dequantize_nf4 reads.
    class DummyWeight:
        def __init__(self, weight, quant_state):
            self.weight = type("W", (), {"data": weight, "quant_state": quant_state})

    dummy_obj = DummyWeight(dummy_weight, dummy_quant_state)

    # Run the dequantization kernel and show a sample of the result.
    output = your_dequantize_nf4(dummy_obj)
    print("Final output (first 10 elements):", output[:10])


>>> Entering your_dequantize_nf4
=== Starting _your_dequantize_nf4 ===
Weight tensor shape: torch.Size([1024])
Weight tensor dtype: torch.uint8
Quant state absmax shape: torch.Size([32])
Total number of dequantized elements: 2048
Kernel execution complete. Output sample (first 10 elements): tensor([ 1.7860e+03, -4.2456e-01,  2.7425e+02,  5.4900e+02,  6.8650e+02,
         8.2400e+02,  6.8650e+02,  1.7860e+03,  9.6150e+02,  1.0990e+03],
       device='cuda:0', dtype=torch.float16)
=== Finished _your_dequantize_nf4 ===
>>> Exiting your_dequantize_nf4
Final output (first 10 elements): tensor([ 1.7860e+03, -4.2456e-01,  2.7425e+02,  5.4900e+02,  6.8650e+02,
         8.2400e+02,  6.8650e+02,  1.7860e+03,  9.6150e+02,  1.0990e+03],
       device='cuda:0', dtype=torch.float16)
