In [None]:
import torch
import bitsandbytes
from typing import Any

# Original Params4bit class (for reference, you don't need to define this)
# class Params4bit:
#     def __init__(self, data: torch.Tensor, **kwargs):
#         self.data = data
#         self.__dict__.update(kwargs)
#     def some_method(self):
#         return self.data * 2

# Your wrapper class
# Record the *unpatched* class on the bitsandbytes.nn module exactly once (same guard
# pattern used for original_dequantize_4bit and original_forward). Without the guard,
# re-running this cell after the monkey patch below would capture Params4bitWrap as the
# "original", and Params4bitWrap.__init__ would then recurse into itself forever.
if not hasattr(bitsandbytes.nn, 'OriginalParams4bit'):
    bitsandbytes.nn.OriginalParams4bit = bitsandbytes.nn.Params4bit
OriginalParams4bit = bitsandbytes.nn.OriginalParams4bit
class Params4bitWrap:
    """Plain-Python proxy around an ``OriginalParams4bit`` (bitsandbytes ``Params4bit``).

    The wrapper stores the tensor in its own ``data`` attribute and forwards every
    other attribute read/write to the wrapped ``_params4bit`` object. Writing
    ``data`` updates both the wrapper and the wrapped object.

    ``copy.copy``, ``copy.deepcopy`` and ``from_prequantized`` are deliberately
    unsupported and raise ``NotImplementedError``.
    """

    def __init__(self, data: torch.Tensor, *args, **kwargs):
        """Build ``OriginalParams4bit(data, *args, **kwargs)`` and mirror ``data``.

        Raises:
            AssertionError: if ``data`` is not a ``torch.Tensor``.
        """
        assert isinstance(data, torch.Tensor)
        # Must be assigned before `data`: __setattr__('data') writes through to _params4bit.
        self._params4bit = OriginalParams4bit(data, *args, **kwargs)
        self.data = data

    def __getattr__(self, name: str) -> Any:
        """Forward reads of attributes the wrapper does not have to ``_params4bit``.

        Only called when normal lookup fails. ``data`` and ``_params4bit`` are never
        forwarded, so a partially initialised instance (e.g. mid-unpickling) raises
        AttributeError instead of recursing.
        """
        if name in ("data", "_params4bit"):
            return object.__getattribute__(self, name)
        return getattr(self._params4bit, name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Keep ``data`` in sync on wrapper and wrapped object; forward other writes."""
        if name == "data":
            assert isinstance(value, torch.Tensor)
            object.__setattr__(self, name, value)
            self._params4bit.data = value
        elif name == "_params4bit":
            object.__setattr__(self, name, value)
        else:
            setattr(self._params4bit, name, value)

    def __getstate__(self):
        """Pickle state is exactly the wrapped object's state."""
        return self._params4bit.__getstate__()

    def __setstate__(self, state):
        """Restore from the output of ``__getstate__``.

        pickle creates the instance with ``__new__`` and never runs ``__init__``, so
        the wrapped object may not exist yet. Create an empty one first, let it
        restore itself, then re-sync the wrapper's own ``data`` attribute.
        """
        if "_params4bit" not in self.__dict__:
            # NOTE(review): assumes OriginalParams4bit.__new__ works with no arguments
            # (presumably true for bitsandbytes' Params4bit, whose arguments all have
            # defaults) - confirm against the installed bitsandbytes version.
            object.__setattr__(self, "_params4bit", OriginalParams4bit.__new__(OriginalParams4bit))
        self._params4bit.__setstate__(state)
        object.__setattr__(self, "data", self._params4bit.data)

    @classmethod
    def from_prequantized(cls, data, *args, **kwargs):
        """Not supported yet.

        TODO: delegate to OriginalParams4bit.from_prequantized and wrap the result.
        """
        raise NotImplementedError("This method is not implemented")

    def __deepcopy__(self, memo):
        """Not supported: deep copies of the wrapper are refused explicitly."""
        raise NotImplementedError("This method is not implemented")

    def __copy__(self):
        """Not supported: shallow copies of the wrapper are refused explicitly."""
        raise NotImplementedError("This method is not implemented")

    def to_parameter(self):
        """Return a new ``torch.nn.Parameter`` built from ``self.data``."""
        return torch.nn.Parameter(self.data)


# Monkey patch: from this point on, `bitsandbytes.nn.Params4bit` resolves to the wrapper.
# Names already bound to the original class before this line (for example an earlier
# `from bitsandbytes.nn import Params4bit`) keep pointing at the original.
bitsandbytes.nn.Params4bit = Params4bitWrap


@torch.compile(fullgraph=True)
def my_function(param):
    """Smoke test for torch.compile(fullgraph=True): read ``param.data`` and add 999."""
    shifted = param.data + 999
    return shifted


# Simulate third-party code that imports Params4bit after the patch was applied.
from bitsandbytes.nn import Params4bit  # resolves to Params4bitWrap at this point

param = Params4bit(torch.randn(5))
result = my_function(param)
print(result)

In [None]:
# Plan: as part of patching the forward function, it would need to call the patched version of dequantize_4bit,
# and possibly a version of it that also does transposing
from bitsandbytes import functional

from functions import my_dequantize_4bit

# Remember the pristine dequantize_4bit exactly once, so re-running this cell after a
# replacement has been installed never records the replacement as the "original".
if not hasattr(functional, 'original_dequantize_4bit'):
    setattr(functional, 'original_dequantize_4bit', getattr(functional, 'dequantize_4bit'))

In [None]:
import os
import json
import pickle
import torch

# NOTE: the dumping patch is installed in this cell because it is unclear whether
# importing bitsandbytes.functional later would leave it un-patchable.
next_case_index = 0
DIR_NAME = 'dequantize_4bit_cases'
PREFIX = 'case_'

# Start from an empty dump directory so case indices restart at 0 on every run.
# Create the directory first: on a fresh run it does not exist yet and os.listdir
# would raise FileNotFoundError.
os.makedirs(DIR_NAME, exist_ok=True)
for file in os.listdir(DIR_NAME):
    stale_path = f'{DIR_NAME}/{file}'
    # Only regular files are ever written here; skip anything else instead of crashing.
    if os.path.isfile(stale_path):
        os.remove(stale_path)

def printing_dequantize_4bit(
    A, # torch.Tensor
    quant_state = None, # Optional[QuantState]
    absmax = None, # Optional[torch.Tensor]
    out = None, # Optional[torch.Tensor]
    blocksize = 64, # int
    quant_type = "fp4",
):
    """Drop-in for ``functional.dequantize_4bit`` that dumps every call to ``DIR_NAME``.

    For case number ``next_case_index`` it writes:
      * ``<DIR_NAME>/<PREFIX><i>_A.pt``: the input ``A``;
      * ``<DIR_NAME>/<PREFIX><i>_<key>.pt``: each tensor of ``quant_state.as_dict(packed=True)``;
      * ``<DIR_NAME>/<PREFIX><i>_qs.json``: that dict, with tensors replaced by their file paths.

    It then increments the case counter and returns the result of
    ``functional.original_dequantize_4bit``. Only the call pattern
    ``dequantize_4bit(A, quant_state)`` with every other argument at its default is
    accepted; anything else fails an assertion.
    """
    global next_case_index

    assert quant_state is not None and absmax is None and out is None
    assert blocksize == 64 and quant_type == "fp4"

    os.makedirs(DIR_NAME, exist_ok=True)
    case_stem = f'{DIR_NAME}/{PREFIX}{next_case_index}'
    torch.save(A, f'{case_stem}_A.pt')

    if quant_state is not None:
        packed_state = quant_state.as_dict(packed = True)
        # Swap every tensor for the path of the .pt file it was saved to, so the dict is JSON-serialisable.
        for key in list(packed_state):
            entry = packed_state[key]
            if not isinstance(entry, torch.Tensor):
                continue
            tensor_path = f'{case_stem}_{key}.pt'
            torch.save(entry, tensor_path)
            packed_state[key] = tensor_path

        with open(f'{case_stem}_qs.json', 'w') as handle:
            json.dump(packed_state, handle)

    next_case_index += 1
    return functional.original_dequantize_4bit(A, quant_state, absmax, out, blocksize, quant_type)

functional.dequantize_4bit = printing_dequantize_4bit

In [None]:
# Print the number of cases that have been dumped
# and the number that have quant_state
import os
import json
import pickle
import torch

# Each dumped case has a `<stem>_A.pt`; cases dumped with a quant_state also have `<stem>_qs.json`.
a_suffix = '_A.pt'
case_files = [entry for entry in os.listdir(DIR_NAME) if entry.endswith(a_suffix)]
num_cases = len(case_files)
num_qstate = sum(
    1
    for entry in case_files
    if os.path.exists(f'{DIR_NAME}/{entry[:-len(a_suffix)]}_qs.json')
)

print(f'Number of cases: {num_cases}')
print(f'Number of cases with quant_state: {num_qstate}')

In [None]:
# Example of reading back one dumped case and rebuilding its QuantState
import torch
CASE_INDEX = 3333
case_stem = f'{DIR_NAME}/{PREFIX}{CASE_INDEX}'
if os.path.exists(f'{case_stem}_A.pt'):
    A = torch.load(f'{case_stem}_A.pt')
    with open(f'{case_stem}_qs.json', 'r') as f:
        qs_dict = json.load(f)
    # String values ending in .pt are paths written by the dumping patch; load those tensors back.
    qs_dict = {
        key: torch.load(value) if isinstance(value, str) and value.endswith('.pt') else value
        for key, value in qs_dict.items()
    }
    quant_state = functional.QuantState.from_dict(qs_dict, 'cuda')
    print(quant_state)

In [None]:
import os
import json
import pickle

def augmented_dequantize_4bit(
    A, # torch.Tensor
    quant_state = None, # Optional[QuantState]
    absmax = None, # Optional[torch.Tensor]
    out = None, # Optional[torch.Tensor]
    blocksize = 64, # int
    quant_type = "fp4",
):
    """Replacement for ``functional.dequantize_4bit`` that routes to ``my_dequantize_4bit``.

    Only ``dequantize_4bit(A, quant_state)`` is supported: every other argument must
    keep its default, otherwise an assertion fails. The stock bitsandbytes
    implementation remains available as ``functional.original_dequantize_4bit``.
    """
    assert quant_state is not None and absmax is None and out is None
    assert blocksize == 64 and quant_type == "fp4"
    return my_dequantize_4bit(A, quant_state)

functional.dequantize_4bit = augmented_dequantize_4bit

In [None]:
# Plan: patch the forward function of nn.Linear4bit, such that it can be torch.compile(d)
# this should mean that the backward pass can also be automatically derived
# first lets store the original functions before patching, this cell should not be modified
import bitsandbytes as bnb
# Record the stock Linear4bit.forward only once, so later patches (and re-runs of this
# cell) never overwrite the saved original with a patched version.
if not hasattr(bnb.nn.Linear4bit, 'original_forward'):
    setattr(bnb.nn.Linear4bit, 'original_forward', bnb.nn.Linear4bit.forward)

In [None]:
def printing_4bit_forward(self, *args, **kwargs):
    """Debug forward for Linear4bit: print the weight, then run the saved original forward."""
    weight = self.weight
    print(weight)
    return self.original_forward(*args, **kwargs)

bnb.nn.Linear4bit.forward = printing_4bit_forward

In [None]:
# If this cell is disabled, then it should show the custom dequantize function working
# Transposition trick, to avoid having the transposition be done by calling Params4bit.t()
# NOT transpose in the Linear4bit.forward, BUT secretly transpose in the MatMul4Bit
# maybe instead of it calling MatMul4bit, it could call a custom implementation of torch Function

import bitsandbytes.functional as F
from typing import Optional
from math import prod
from bitsandbytes.nn.modules import fix_4bit_weight_quant_state_from_module
from functions import ENABLE_ASSERTIONS

class TransposeBMatMul4Bit(torch.autograd.Function):
    """Autograd function computing ``A @ W + bias`` with ``W = F.dequantize_4bit(B, quant_state)``.

    A variant of bitsandbytes' 4-bit matmul: the caller hands over the packed weight
    ``B`` directly, and the transposition happens inside ``forward``/``backward``
    instead of being done up front on a ``Params4bit``.
    Gradients are produced for ``A`` and ``bias`` only; ``grad_B`` is always ``None``
    (or zeros on the empty-input path).
    """

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
        """Compute ``linear(A, dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)``.

        Because ``linear`` multiplies by the transpose of its weight argument, this
        equals ``A @ W + bias`` with ``W = dequantize_4bit(B, quant_state).to(A.dtype)``.

        Args:
            ctx: autograd context.
            A: activations; their last dim is contracted.
            B: packed 4-bit weight, passed unchanged to ``F.dequantize_4bit``.
            out: accepted for signature compatibility; not used.
            bias: optional bias, forwarded to ``torch.nn.functional.linear``.
            quant_state: quantization state for ``B``. Its ``shape`` also sizes the
                output when ``A`` is empty.
        """
        # Empty input: return an empty result of the matching shape without dequantizing.
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            B_shape = quant_state.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Dequantize B and cast it to the activation dtype
        # 2. Matmul via linear()
        # NOTE the transposition: linear(A, W.t()) == A @ W. W's orientation depends on what
        # F.dequantize_4bit returns for this B layout - TODO confirm against the patched dequantize.
        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) # NOTE the transposition

        # 3. Save state for backward
        ctx.state = quant_state
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        # Keep B on ctx (rather than save_for_backward) so backward can dequantize it again;
        # A is never needed because grad_B is not computed.
        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (None, B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(A, B, out, bias, quant_state)``.

        ``grad_A = grad_output @ W.t()`` and ``grad_bias = grad_output.sum(0)``, each
        only when required. ``B``, ``out`` and ``quant_state`` get ``None``. On the
        empty-input path, zeros are returned for ``A``, ``B`` and (if present) ``bias``.
        """
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
        _, B = ctx.tensors

        grad_A, grad_B, grad_bias = None, None, None

        if req_gradBias:
            # Sums over dim 0 only, in the bias dtype. NOTE(review): for inputs with more than
            # 2 dims the remaining leading dims are left for autograd to reduce to bias.shape - confirm.
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # Gradient w.r.t. the quantized weight B is not computed (grad_B stays None).
        # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
        if req_gradA:
            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) # NOTE the transposition (same W as in forward)

        return grad_A, grad_B, None, grad_bias, None

def inner_transpose_forward(self, x: torch.Tensor):
    """Linear4bit.forward replacement that calls TransposeBMatMul4Bit directly.

    Follows the stock forward (quant-state fix-up, bias dtype cast, lazy compute-dtype
    selection, cast back to the input dtype). Instead of
    ``bnb.matmul_4bit(x, self.weight.t(), ...)``, it passes the raw packed weight
    ``self.weight.data`` (transposed) as matrix B. This assumes the GEMV path of
    ``bnb.matmul_4bit`` would never have been taken.
    """
    fix_4bit_weight_quant_state_from_module(self)

    # The bias is not part of the quantized weight, so align its dtype with the input by hand.
    if self.bias is not None and self.bias.dtype != x.dtype:
        self.bias.data = self.bias.data.to(x.dtype)

    # The compute dtype is chosen once, on the first call.
    if not self.compute_type_is_set:
        self.set_compute_type(x)
        self.compute_type_is_set = True

    inp_dtype = x.dtype
    compute_dtype = self.compute_dtype
    if compute_dtype is not None:
        x = x.to(compute_dtype)
    bias = self.bias.to(compute_dtype) if self.bias is not None else None

    weight = self.weight
    packed_weight = weight.data  # observed: x (1, 100, 2048), packed_weight (2097152, 1)
    quant_state = weight.quant_state
    if ENABLE_ASSERTIONS:
        assert(packed_weight.shape[1] == 1)

    # The `out` argument is always None on this path.
    product = TransposeBMatMul4Bit.apply(x, packed_weight.t(), None, bias, quant_state)
    return product.to(inp_dtype)

bnb.nn.Linear4bit.forward = inner_transpose_forward

In [None]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
# Query the current CUDA device (a GPU is required); HAS_BFLOAT16 is True for compute capability major >= 8.
major_version, minor_version = torch.cuda.get_device_capability()
HAS_BFLOAT16 = major_version >= 8
from inspect import currentframe as _C, getframeinfo
def _F(c):
    """Return the line number recorded in frame ``c``."""
    return getframeinfo(c).lineno

def WARN(x):
    """Print ``x`` wrapped in ANSI red colour codes."""
    print(f"\033[31m{x}\033[0m")

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var, depth = 1):
    """Best-effort variable name of ``var`` as seen from a calling frame.

    Looks in the locals of the frame ``depth`` levels above this one (``depth=1`` is
    the direct caller) for a name bound to exactly this object.

    Returns:
        The first matching name, or ``""`` if none matches or the frame does not exist.
    """
    frame = inspect.currentframe()
    try:
        for _ in range(depth):
            if frame is None:
                return ""
            frame = frame.f_back
        if frame is None:
            return ""
        callers_local_vars = frame.f_locals.items()
        names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
        return names[0] if len(names) != 0 else ""
    finally:
        # Drop the frame reference to avoid a reference cycle (see inspect docs).
        del frame

def assert_same(x, y, line, dtype):
    """Check that ``x`` has ``dtype`` and is allclose to ``y`` (strides included).

    Raises:
        AssertionError: if ``x.dtype != dtype``.
        RuntimeError: if ``torch.testing.assert_close`` fails. The message names the
            caller's variables and the ``line`` passed in; the original error is chained.
    """
    assert(x.dtype == dtype)
    try:
        torch.testing.assert_close(x, y, check_stride = True)
    except Exception as error:
        # depth=2 looks in *our caller's* frame; depth=1 would only ever find "x"/"y" here.
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {NAME(x, 2)}, {NAME(y, 2)}\n{str(error)}"
        ) from error

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [None]:
import torch
# Inductor options shared by the torch.compile'd helpers in this notebook.
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : True,
    "shape_padding"     : True,
    "trace.enabled"     : True,
    "triton.cudagraphs" : False,
}

@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def compiled_llama_mlp(self, x):
    """Compiled replacement for LlamaMLP.forward: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
    gated = self.act_fn(self.gate_proj(x))
    upscaled = self.up_proj(x)
    return self.down_proj(gated * upscaled)

import transformers.models.llama.modeling_llama
setattr(transformers.models.llama.modeling_llama.LlamaMLP, "forward", compiled_llama_mlp)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
    "expandable_segments:True,"\
    "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"
# CUDA_VISIBLE_DEVICES / PYTORCH_CUDA_ALLOC_CONF are generally only honoured if set before
# CUDA is initialised, and an earlier cell already calls torch.cuda.get_device_capability().
# Warn instead of silently ignoring them. To make them effective, set them in the first
# cell (or in the environment that launches the kernel).
if torch.cuda.is_initialized():
    print("WARNING: CUDA is already initialized; CUDA_VISIBLE_DEVICES and "
          "PYTORCH_CUDA_ALLOC_CONF set in this cell may have no effect.")

max_seq_length = 1024
torch.set_default_dtype(torch.float16)
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
dtype = torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit              = True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type       = "nf4",
    bnb_4bit_compute_dtype    = dtype,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map = "auto",
    attn_implementation = "sdpa",
    quantization_config = bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

lora_config = LoraConfig(
    r = 32,
    lora_alpha = 64,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_dropout = 0,
    bias = "none",
    task_type = TaskType.CAUSAL_LM,
)

# Get LoRA and setup model: only the LoRA adapter weights are trainable.
model = get_peft_model(model, lora_config)
with torch.no_grad():
    for name, param in model.named_parameters():
        if ".lora_A." in name or ".lora_B." in name: param.requires_grad_(True)
        else: param.requires_grad_(False)

# Gradient checkpointing currently causes torch.compile to be disabled, so leave it off.
# model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Get dataset
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train[:10%]")

In [None]:
# Must show all graph breaks are not seen with torch.compile
import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

import logging
torch._inductor.config.debug = True
torch._logging.set_logs(
    dynamo = logging.WARN,
    inductor = logging.WARN,
    graph_breaks = True,
    recompiles = True,
    recompiles_verbose = True,
    compiled_autograd_verbose = True,
    # aot_joint_graph = True, # Enable for more logs
    # aot_graphs = True,
)
torch._dynamo.config.verbose = True
torch._dynamo.config.suppress_errors = False

In [None]:
## Fast Iteration Cell, run this and the one below to quickly iterate on the forward function
def printing_4bit_forward(self, *args, **kwargs):
    """Verbose debug forward for Linear4bit.

    Between separator lines, prints the fields of ``self.weight.quant_state``, the
    weight, and ``self.weight.t()``. Then returns
    ``self.original_forward(*args, **kwargs)``.
    """
    print('==================================================')
    quant_state = self.weight.quant_state
    field_names = (
        'absmax', 'shape', 'code', 'dtype', 'blocksize',
        'quant_type', 'offset', 'state2', 'nested',
    )
    items = [getattr(quant_state, field) for field in field_names]
    # Candidate invariants (not enforced): quant_type == 'nf4', nested == False,
    # dtype == torch.bfloat16.
    print('quant_state', items)
    print(self.weight)
    print('--------------------------------------------------')
    print(self.weight.t())
    print('==================================================')
    return self.original_forward(*args, **kwargs)

# bnb.nn.Linear4bit.forward = printing_4bit_forward

In [None]:
# Mixed-precision mode follows the dtype the model's input embeddings were loaded in.
embedding_dtype = model.get_input_embeddings().weight.dtype
training_args = SFTConfig(
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 2,
    warmup_steps = 1,
    max_steps = 10,
    logging_steps = 1,
    output_dir = "outputs",
    seed = 3407,
    max_seq_length = max_seq_length,
    fp16 = embedding_dtype == torch.float16,
    bf16 = embedding_dtype == torch.bfloat16,
    report_to = "none", # For W&B
    dataset_num_proc = 4,
)
trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    processing_class = tokenizer,
    args = training_args,
)
trainer.train()