# UNSLOTH CHALLENGE 1 SUBMISSION: Convert nf4 to Triton

## Problem statement

---
<a name="NF4"></a>
## A) Convert `nf4` to Triton. [Difficulty: Hard] [Max points: 14]

1. Goal: Convert an `nf4` quantized tensor into `fp16` or `bf16` in a *single* Triton kernel. The double dequant of the `absmax` and the weight forming must both be done in 1 Triton kernel. Must work on Tesla T4.
2. Must be faster than Unsloth's `fast_dequantize` by 1.15x or more, and not use large intermediate memory buffers.
3. Must not use `torch.compile`, but can use `trace.enabled` to help with writing Triton kernels.
4. Good material: [Unsloth `fast_dequantize` function](https://github.com/unslothai/unsloth/blob/main/unsloth/kernels/utils.py#L128), also [bitsandbytes `dequantize_blockwise`](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/86b6c37a8ad448230cedb60753f63150b603a112/bitsandbytes/functional.py#L958)
5. Use `test_dequantize_function` to test your implementation.
6. No CUDA allowed. Custom CUDA inside of Triton is allowed.
7. Watch Tim's videos on YouTube: [8-bit Optimizers](https://www.youtube.com/watch?v=2ETNONas068)
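
To make the "double dequant" in item 1 concrete, here is a minimal plain-Python sketch of the arithmetic, written without Triton so it can be checked on a CPU. It assumes the layout that bitsandbytes appears to use: two 4-bit codes per byte with the high nibble first, one 8-bit `absmax` code per block of 64 weights, and a second-level scale per 256 `absmax` entries plus a scalar offset. The function name, the parameter names, and the default block sizes are illustrative assumptions, not the library's API.

```python
# The 16 NF4 codebook values (as published in bitsandbytes).
NF4_CODE = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]

def dequantize_nf4_reference(packed, absmax_u8, code2, absmax2, offset,
                             blocksize=64, blocksize2=256):
    """Hypothetical scalar reference for NF4 double dequantization.

    packed    : list of bytes, two 4-bit codes each (assumed high nibble first)
    absmax_u8 : list of 8-bit codes, one per block of `blocksize` weights
    code2     : lookup table mapping an 8-bit code to a float
    absmax2   : float scales, one per `blocksize2` entries of absmax_u8
    offset    : scalar added back after the first dequant
    """
    # Dequant 1: recover one float scale per weight block.
    absmax = [code2[q] * absmax2[i // blocksize2] + offset
              for i, q in enumerate(absmax_u8)]

    # Dequant 2: look up each 4-bit code and scale it by its block's absmax.
    out = []
    for byte in packed:
        for nibble in (byte >> 4, byte & 0x0F):
            j = len(out)  # index of the weight being produced
            out.append(NF4_CODE[nibble] * absmax[j // blocksize])
    return out

# Tiny example: two packed bytes become four weights sharing one scale.
example = dequantize_nf4_reference(
    packed=[0xF0, 0x7F], absmax_u8=[1], code2=[0.0, 0.5],
    absmax2=[4.0], offset=1.0,
)
print(example)
```

A fused Triton kernel would need to perform both steps per output element without writing the intermediate `absmax` list to global memory, which is what the "no large intermediate memory buffers" requirement rules out.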

## Evaluation parameters 

## Marking Criteria for A) Max points = 14
```python
if attempted_A:
    A_score = 0
    if single_triton_kernel: A_score += 3
    speedup = old_time / new_time
    if speedup <= 1.00: A_score -= 3
    if speedup >= 1.05: A_score += 1
    if speedup >= 1.10: A_score += 2
    if speedup >= 1.15: A_score += 2
    if kernel_works_in_torch_compile: A_score += 1
    else: A_score -= 1
    if custom_asm_works: A_score += 3
    if uses_cache_eviction: A_score += 1
    if tested_in_f16_and_bf16: A_score += 1
    else: A_score -= 1
    final_score += A_score
else:
    final_score += 0
```
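
As a quick sanity check on the rubric, the sketch below wraps the same rules in a function so the totals can be computed for a given outcome. The function name `score_a` and its argument order are hypothetical; the point values are copied from the criteria above. Note that the speedup bonuses are cumulative, so reaching 1.15x earns all three.

```python
def score_a(single_triton_kernel, speedup, kernel_works_in_torch_compile,
            custom_asm_works, uses_cache_eviction, tested_in_f16_and_bf16):
    """Hypothetical helper that mirrors the marking criteria for part A."""
    score = 0
    if single_triton_kernel:
        score += 3
    if speedup <= 1.00:
        score -= 3
    # Cumulative bonuses: each threshold met adds its own points.
    if speedup >= 1.05:
        score += 1
    if speedup >= 1.10:
        score += 2
    if speedup >= 1.15:
        score += 2
    score += 1 if kernel_works_in_torch_compile else -1
    if custom_asm_works:
        score += 3
    if uses_cache_eviction:
        score += 1
    score += 1 if tested_in_f16_and_bf16 else -1
    return score

# Best case: every box ticked and the 1.15x target met.
best = score_a(True, 1.15, True, True, True, True)
# Single kernel at 1.15x, tested in both dtypes, but none of the extras.
baseline = score_a(True, 1.15, False, False, False, True)
print(best, baseline)
```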

Let's load up the basic libraries.

In [1]:
!pip install triton

Collecting triton
  Downloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.1 MB)
Installing collected packages: triton
Successfully installed triton-3.2.0


Phase one: get it to compile.

In [None]:
import torch
from triton import jit, cdiv
import triton.language as tl

# Kernel: Traverses the input tensor in blocks and copies each element after a simple cast.
@jit
# num_elements is a runtime argument so the kernel is not recompiled for every tensor size.
def _your_dequantize_nf4_kernel(weight_ptr, quant_state_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    # Get the unique program (block) ID.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    # Create a vector of indices for the block.
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to avoid out-of-bound accesses.
    mask = offsets < num_elements

    # Load values from the input weight tensor.
    # (Note: Currently, we ignore quant_state in the arithmetic.)
    values = tl.load(weight_ptr + offsets, mask=mask)
    # For now, simply cast the values to float16 and store in output.
    tl.store(output_ptr + offsets, tl.cast(values, tl.float16), mask=mask)

def _your_dequantize_nf4(weight, quant_state):
    # Debug prints on the host side
    print("=== Starting _your_dequantize_nf4 ===")
    print("Weight tensor shape:", weight.shape)
    print("Weight tensor dtype:", weight.dtype)
    try:
        print("Quant state shape:", quant_state.shape)
    except AttributeError:
        print("Quant state does not have a shape attribute (likely not a tensor).")
    
    # Total number of elements in the weight tensor.
    num_elements = weight.numel()
    print("Total number of elements:", num_elements)
    
    # Allocate an output tensor on the same device, with fp16 precision.
    # (This skeleton keeps a 1:1 mapping; a real NF4 dequantize yields two outputs per packed byte.)
    output = torch.empty(num_elements, dtype=torch.float16, device=weight.device)
    
    # Determine grid size based on BLOCK_SIZE.
    grid = lambda meta: (cdiv(num_elements, meta['BLOCK_SIZE']),)
    
    # Launch the kernel.
    _your_dequantize_nf4_kernel[grid](weight, quant_state, output, num_elements, BLOCK_SIZE=1024)
    
    # Synchronize and print a small sample of the output for debugging.
    torch.cuda.synchronize()
    print("Kernel execution complete. Output sample (first 10 elements):", output[:10])
    print("=== Finished _your_dequantize_nf4 ===")
    return output

def your_dequantize_nf4(weight):
    # weight is expected to be an object with attributes 'weight.data' and 'weight.quant_state'
    print(">>> Entering your_dequantize_nf4")
    output = _your_dequantize_nf4(weight.weight.data, weight.weight.quant_state)
    print(">>> Exiting your_dequantize_nf4")
    return output

# For debugging purposes, you can create dummy inputs as follows:
if __name__ == '__main__':
    # Create a dummy quantized weight tensor (simulate with uint8 data).
    # torch.randint's upper bound is exclusive, so use 256 to cover the full byte range.
    dummy_weight = torch.randint(0, 256, (1024,), dtype=torch.uint8, device="cuda")
    # Create a dummy quant_state tensor (its content is not used in this phase).
    dummy_quant_state = torch.empty((1,), dtype=torch.uint8, device="cuda")
    
    # Wrap dummy_weight in an object with the expected attributes.
    class DummyWeight:
        def __init__(self, weight, quant_state):
            self.weight = type("W", (), {"data": weight, "quant_state": quant_state})
    
    dummy_obj = DummyWeight(dummy_weight, dummy_quant_state)
    
    # Run our dequantization kernel skeleton.
    output = your_dequantize_nf4(dummy_obj)
    print("Final output (first 10 elements):", output[:10])
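
The grid, offset, and mask logic in the kernel above can be checked without a GPU. The following is a small plain-Python emulation, assuming one program per block exactly as in the launch above; `emulate_launch` is a hypothetical helper, and the local `cdiv` only mirrors the ceiling division that `triton.cdiv` performs.

```python
def cdiv(a, b):
    # Ceiling division: number of blocks of size b needed to cover a elements.
    return (a + b - 1) // b

def emulate_launch(num_elements, block_size):
    """Return every index that some program would touch, in launch order."""
    touched = []
    for pid in range(cdiv(num_elements, block_size)):  # one iteration per program
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]
        mask = [off < num_elements for off in offsets]  # guards the ragged last block
        touched.extend(off for off, keep in zip(offsets, mask) if keep)
    return touched

# 10 elements with blocks of 4 need 3 programs; the last one is half masked.
indices = emulate_launch(10, 4)
print(cdiv(10, 4), indices)
```

Each index is visited exactly once and nothing past `num_elements` is touched, which is the property the `mask` argument to `tl.load` and `tl.store` is there to guarantee.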
