In [1]:
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

In [2]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
# Compute capability (major, minor) reported by torch for the current CUDA device.
major_version, minor_version = torch.cuda.get_device_capability()
# bfloat16 is treated as available when the major compute capability is 8 or higher.
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
# _F(frame) returns the line number recorded for `frame`; `_F(_C())` yields the caller's current line.
_F = lambda c: getframeinfo(c).lineno # Gets line number
# WARN(msg) prints `msg` wrapped in the ANSI escape sequence for red text.
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings

# Request the hf_transfer download backend for Hugging Face Hub downloads.
# NOTE(review): `transformers` is already imported above; if huggingface_hub reads
# this variable at import time the assignment may come too late — confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [3]:
import sys

def remove_patched_module(package_name):
    """Drop `package_name` and every one of its submodules from `sys.modules`.

    A module is selected when its name equals `package_name` exactly or starts
    with `package_name + "."`, so packages that merely share a prefix
    (e.g. `foo_bar` when removing `foo`) are left in place.

    Args:
        package_name: Top-level (or dotted) package name to evict.

    Returns:
        None. `sys.modules` is modified in place.
    """
    submodule_prefix = package_name + "."
    # Collect first, delete afterwards: the dict must not change size while
    # it is being iterated.
    stale_names = []
    for module_name in sys.modules:
        if module_name == package_name or module_name.startswith(submodule_prefix):
            stale_names.append(module_name)
    for module_name in stale_names:
        del sys.modules[module_name]

# Evict these packages (and all their submodules) from sys.modules; a later
# `import` of any of them will then execute the package afresh instead of
# reusing the module objects loaded so far.
remove_patched_module("trl")
remove_patched_module("transformers")
remove_patched_module("peft")
remove_patched_module("bitsandbytes")

In [4]:
from triton import jit
import triton
import triton.language as tl

@triton.jit
def _your_dequantize_nf4_kernel(#pointers
                                absmax2_ptr, out2_ptr, absmax1_ptr, out1_ptr, weight_ptr, code2_ptr, code1_ptr, offset1,
                                #Dimensions
                                blocksize1, blocksize2, num_elems_1, num_elems_2, shift1, shift2,
                                #meta param
                                BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_2:tl.constexpr
                                ):
    """Two-stage dequantization: rebuild the per-block scales, then expand packed 4-bit codes.

    Stage 1 (BLOCK_SIZE entries per program): entry `i` of `absmax1_ptr` is
    used as an index into `code2_ptr`, multiplied by
    `absmax2_ptr[i >> shift2]`, offset by the scalar stored at `offset1`, and
    written to `out2_ptr[i]`.

    Stage 2 (BLOCK_SIZE_2 bytes per program): byte `i` of `weight_ptr` is split
    into its high and low nibble; each nibble indexes `code1_ptr`, the looked-up
    value is multiplied by `out2_ptr[i >> shift1]`, and the two results are
    written to `out1_ptr[2*i]` (high nibble) and `out1_ptr[2*i + 1]` (low nibble).

    Args:
        absmax2_ptr, absmax1_ptr, code2_ptr, offset1: inputs of stage 1.
        out2_ptr: output of stage 1 and scale input of stage 2.
        weight_ptr, code1_ptr: inputs of stage 2.
        out1_ptr: final output, two elements per byte of `weight_ptr`.
        blocksize1, blocksize2: accepted for the caller's benefit; not read here.
        num_elems_1: number of elements in `out1_ptr` (the caller passes
            `out1.numel()`), i.e. twice the number of packed bytes.
        num_elems_2: number of entries in `absmax1_ptr` / `out2_ptr`.
        shift1, shift2: right-shift amounts mapping a byte index (stage 2) or an
            absmax index (stage 1) to the index of its scale.
        BLOCK_SIZE, BLOCK_SIZE_2: per-program tile sizes of stage 1 / stage 2.

    NOTE(review): stage 2 of a program reads `out2_ptr` right after stage 1 of
    the *same* program wrote it, separated only by `tl.debug_barrier()`. That
    is only sufficient when `BLOCK_SIZE_2 >> shift1 == BLOCK_SIZE`, so that a
    program consumes exactly the scales it produced — confirm for every
    block-size combination the caller can pass.
    """

    #init
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask_for_abs2 = offsets < num_elems_2
    indexs = offsets >> shift2

    #de_doublequant
    absmax2_max = tl.load(absmax2_ptr + indexs,  mask = mask_for_abs2, other = 128)
    absmax1_values = tl.load(absmax1_ptr + offsets, mask = mask_for_abs2, other = 128)
    x = tl.load(code2_ptr + tl.cast(absmax1_values, tl.int32), mask = mask_for_abs2, other = 128)
    val_offset = tl.fma(x, absmax2_max, tl.load(offset1))

    tl.store(out2_ptr + offsets, val_offset, mask = mask_for_abs2, eviction_policy='evict_last')
    # Make the scales written above visible to the whole program before stage 2 reads them back.
    tl.debug_barrier()


    new_offsets = pid * BLOCK_SIZE_2 + tl.arange(0, BLOCK_SIZE_2)

    # Output positions of the two values packed into byte `new_offsets`.
    out_hi2 =  new_offsets  * 2
    out_lo2 =  new_offsets  * 2 + 1

    # Bounds are checked against the OUTPUT indices: `num_elems_1` counts output
    # elements, and every byte yields two of them. Comparing the raw byte index
    # with `num_elems_1` would let the last (partial) program read past the end
    # of `weight_ptr` / `out2_ptr` and write past the end of `out1_ptr`.
    mask_for_abs1 = out_hi2 < num_elems_1
    # The low nibble of the final byte has no output slot when `num_elems_1` is odd.
    mask_lo = out_lo2 < num_elems_1

    #indexs = new_offsets >> shift1
    # Same shift as the commented line above, expressed as packed PTX: four
    # offsets ($4-$7) are shifted by four copies of `shift1` ($8-$11).
    indexs = tl.inline_asm_elementwise(
        asm ="""  
        shr.b32 $0, $4, $8;
        shr.b32 $1, $5, $9;
        shr.b32 $2, $6, $10;
        shr.b32 $3, $7, $11;
        """,
         constraints=(
            "=r,=r,=r,=r,"
            "r,r,r,r,r,r,r,r"),
        args=[new_offsets, shift1],
        dtype=(tl.int32),
        is_pure=True,
        pack=4,
    )

    gathered_max = tl.load(out2_ptr + indexs, mask = mask_for_abs1, other = 128)
    weights_values = tl.load(weight_ptr + new_offsets, mask = mask_for_abs1, other=128)

    #high_bits = weights_values >> 4
    # Four bytes arrive packed in one 32-bit register. Outputs $0-$3 receive the
    # high nibbles and $4-$7 the low nibbles. (The PTX comment inside the string
    # mentions `$0`; the packed input operand is `$8`, as the `mov.b32` shows.)
    high_bits,  lowtem = tl.inline_asm_elementwise(
        asm="""
        {
          // --- Unpack the 32-bit input into 4 bytes ---
          .reg .b8 a<4>;
          // The input packed value is in $0.
          mov.b32 {a0, a1, a2, a3}, $8;
          
          // --- Convert each 8-bit value to a 32-bit integer ---
          cvt.u32.u8 $4, a0;
          cvt.u32.u8 $5, a1;
          cvt.u32.u8 $6, a2;
          cvt.u32.u8 $7, a3;
        }
        // --- Compute high nibble for each (byte >> 4) ---
        shr.b32 $0, $4, 4;
        shr.b32 $1, $5, 4;
        shr.b32 $2, $6, 4;
        shr.b32 $3, $7, 4;
        
        // --- Compute low nibble for each (byte & 0x0F) ---
        and.b32 $4, $4, 0x0F;
        and.b32 $5, $5, 0x0F;
        and.b32 $6, $6, 0x0F;
        and.b32 $7, $7, 0x0F;
        
    """,
         constraints=(
            "=r,=r,=r,=r,=r,=r,=r,=r,"
            "r"),
        args=[weights_values],
        dtype=(tl.int32, tl.int32),
        is_pure=True,
        pack=4,
    )

    # Code-book lookup for each nibble, scaled by the block's dequantized absmax.
    x = tl.load(code1_ptr + high_bits, mask = mask_for_abs1, other=4)
    hi_val = x * (gathered_max)

    x = tl.load(code1_ptr + lowtem, mask = mask_for_abs1, other=4)
    lo_val = x  * (gathered_max)

    #storing
    tl.store(out1_ptr + out_hi2, hi_val  , mask = mask_for_abs1 , eviction_policy='evict_last')
    tl.store(out1_ptr + out_lo2, lo_val  , mask = mask_lo , eviction_policy='evict_last')
  
@torch.compile(fullgraph=True)
def _your_dequantize_nf4(weight, quant_state):
    """Allocate the output buffers and launch `_your_dequantize_nf4_kernel`.

    Args:
        weight: Packed weight tensor handed to the kernel as `weight_ptr`.
        quant_state: Object exposing `absmax`, `offset`, `code`, `blocksize`,
            `shape`, `dtype`, and a nested `state2` with `absmax`, `blocksize`
            and `code`.

    Returns:
        A new tensor of shape `quant_state.shape` and dtype `quant_state.dtype`
        on `weight.device`, filled by the kernel.
    """
    nested_state = quant_state.state2
    target_dtype = quant_state.dtype
    packed_absmax = quant_state.absmax

    # Scratch buffer for the dequantized absmax values, and the final output.
    absmax_out = torch.empty(packed_absmax.shape, dtype=target_dtype, device=packed_absmax.device)
    dequantized = torch.empty(quant_state.shape, dtype=target_dtype, device=weight.device)

    # Each packed byte carries two values, hence half the block size in bytes.
    bytes_per_block = quant_state.blocksize//2
    nested_blocksize = nested_state.blocksize

    total_outputs = dequantized.numel()
    total_absmax = absmax_out.numel()

    # bit_length() - 1 is log2 for powers of two, letting the kernel shift
    # instead of divide (exact only when the block sizes are powers of two).
    weight_shift = bytes_per_block.bit_length() - 1
    absmax_shift = nested_blocksize.bit_length() - 1

    # One program per BLOCK_SIZE absmax entries.
    launch_grid = lambda meta: (triton.cdiv(total_absmax, meta['BLOCK_SIZE']),)

    _your_dequantize_nf4_kernel[launch_grid](
        # pointers
        nested_state.absmax, absmax_out, packed_absmax, dequantized, weight,
        nested_state.code, quant_state.code, quant_state.offset,
        # dimensions
        bytes_per_block, nested_blocksize, total_outputs, total_absmax, weight_shift, absmax_shift,
        # meta parameters
        BLOCK_SIZE = 64, BLOCK_SIZE_2 = 64*32
    )

    return dequantized

def your_dequantize_nf4(weight, quant_state=None):
    """Dequantize a 4-bit weight with the Triton kernel.

    Two call forms are accepted:

    * `your_dequantize_nf4(module)` — `module` exposes the packed parameter as
      `module.weight` (with `.data` and `.quant_state`). This is the original
      single-argument form.
    * `your_dequantize_nf4(packed_tensor, quant_state)` — the packed tensor and
      its quantization state are passed directly. This is the form used by
      `dequantize_bnb_weight_temp_patched`, which previously failed with a
      `TypeError` because only one positional argument was accepted.

    Args:
        weight: Either a module holding the quantized parameter, or the packed
            tensor itself when `quant_state` is given.
        quant_state: Quantization state belonging to `weight`; `None` selects
            the module form.

    Returns:
        The tensor produced by `_your_dequantize_nf4`.
    """
    if quant_state is None:
        packed_param = weight.weight
        return _your_dequantize_nf4(packed_param.data, packed_param.quant_state)
    return _your_dequantize_nf4(weight, quant_state)

In [5]:
def dequantize_bnb_weight_temp_patched(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
    """
    Helper function to dequantize 4bit or 8bit bnb weights.

    If the weight is not a bnb quantized weight, it will be returned as is.

    Args:
        weight: Parameter to dequantize. The class *name* decides the path:
            "Params4bit" uses the Triton kernel, "Int8Params" uses the int8
            row-wise scale, anything else is returned unchanged.
        dtype: dtype the dequantized tensor is cast to before returning.
        state: Optional object with an `SCB` attribute (int8 path only). When
            its `SCB` is `None` it is filled in from `weight.SCB`. When `state`
            itself is `None`, `weight.SCB` is used directly.

    Returns:
        `weight` itself for non-bnb parameters, otherwise a dequantized tensor
        of dtype `dtype`.

    Raises:
        TypeError: If `weight` is not an `nn.Parameter`.
    """
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

    cls_name = weight.__class__.__name__
    if cls_name not in ("Params4bit", "Int8Params"):
        return weight

    if cls_name == "Params4bit":
        # Call the tensor-level entry point: `your_dequantize_nf4` takes a
        # module, not a (tensor, quant_state) pair.
        output_tensor = _your_dequantize_nf4(weight.data, weight.quant_state)
        # No module-level `logger` exists in this notebook; fetch the
        # transformers logger locally so `warning_once` is available.
        from transformers.utils import logging as hf_logging
        logger = hf_logging.get_logger("transformers.integrations.bitsandbytes")
        logger.warning_once(
            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
        )
        return output_tensor.to(dtype)

    # --- Int8Params path ---
    if state is None:
        # `state` is optional in the signature; fall back to the scale stored
        # on the parameter instead of failing on `None.SCB`.
        scb = weight.SCB
    else:
        if state.SCB is None:
            state.SCB = weight.SCB
        scb = state.SCB

    # `bnb` is not imported at notebook level; import it here, and fall back
    # to the pure-torch formula when it is unavailable.
    try:
        import bitsandbytes as bnb
    except ImportError:
        bnb = None

    if bnb is not None and hasattr(bnb.functional, "int8_vectorwise_dequant"):
        # Use bitsandbytes API if available (requires v0.45.0+)
        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, scb)
    else:
        # Multiply by (scale/127) to dequantize.
        dequantized = weight.data * scb.view(-1, 1) * 7.874015718698502e-3

    return dequantized.to(dtype)

In [6]:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
import transformers.integrations.bitsandbytes
# Replace transformers' bnb dequantization helper with the patched version defined in the previous cell.
# NOTE(review): this rebinds the attribute on the module object only; any code that already did
# `from transformers.integrations.bitsandbytes import dequantize_bnb_weight` keeps the
# original function — confirm the intended callers look it up through the module.
transformers.integrations.bitsandbytes.dequantize_bnb_weight = dequantize_bnb_weight_temp_patched


In [7]:
import os
import torch
# Request the hf_transfer download backend for Hugging Face Hub downloads (also set in an earlier cell).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Restrict this process to GPUs 0 and 1.
# NOTE(review): an earlier cell already calls torch.cuda.get_device_capability(); CUDA-related
# environment variables set after CUDA has been initialised are presumably ignored — confirm,
# and move these assignments ahead of the first CUDA call if so.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# Settings string for PyTorch's CUDA caching allocator, assembled from two literals.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,"\
    "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import os
# Sequence-length limit; passed to SFTConfig in a later cell.
max_seq_length = 2048
# Make float16 the default dtype for floating-point tensors created from here on.
torch.set_default_dtype(torch.float16)
# Hub identifier of the checkpoint to load.
model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit"
# Compute dtype handed to the bitsandbytes config below.
dtype = torch.float16
# 4-bit NF4 quantization with double quantization enabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit              = True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type       = "nf4",
    bnb_4bit_compute_dtype    = dtype,
)
# Load the causal LM with the quantization config, automatic device placement, and SDPA attention.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map = "auto",
    attn_implementation = "sdpa",
    quantization_config = bnb_config,
)
# Matching tokenizer; padding is applied on the right-hand side.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

# LoRA adapter configuration: rank 64, alpha 128, no dropout, no bias terms,
# applied to the seven projection modules named below.
lora_config = LoraConfig(
    task_type = TaskType.CAUSAL_LM,
    r = 64,
    lora_alpha = 128,
    lora_dropout = 0,
    bias = "none",
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# Wrap the base model with the adapters, then make exactly the LoRA A/B
# matrices trainable and freeze every other parameter.
model = get_peft_model(model, lora_config)
with torch.no_grad():
    for name, param in model.named_parameters():
        param.requires_grad_(".lora_A." in name or ".lora_B." in name)

# Gradient checkpointing, plus the hook that makes input embeddings require grad.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Get dataset
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
# JSON-lines file from the laion/OIG dataset repository on the Hugging Face Hub.
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
# Register the file as the "train" split and keep only the first 10% of it.
dataset = load_dataset("json", data_files = {"train" : url}, split = "train[:10%]")

config.json:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

unified_chip2.jsonl:   0%|          | 0.00/95.6M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [8]:
# FSDP settings; this dict is passed as `fsdp_config` to SFTConfig in the next cell.
# NOTE(review): the key layout (top-level launcher keys plus a nested "fsdp_config")
# looks like an Accelerate config file rather than the flat dict of FSDP options —
# confirm the consumer actually understands this structure.
fsdp_config = {
  "compute_environment": "LOCAL_MACHINE",
  "debug": False,
  "distributed_type": "FSDP",
  "downcast_bf16": "no",
  "enable_cpu_affinity": False,
  # Nested FSDP options: full sharding, transformer-based auto wrap, CPU parameter offload.
  "fsdp_config": {
    "fsdp_activation_checkpointing": True,
    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    "fsdp_backward_prefetch": "BACKWARD_PRE",
    "fsdp_cpu_ram_efficient_loading": True,
    "fsdp_forward_prefetch": True,
    "fsdp_offload_params": True,
    "fsdp_sharding_strategy": "FULL_SHARD",
    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
    "fsdp_sync_module_states": True,
    "fsdp_use_orig_params": True
  },
  "machine_rank": 0,
  "main_process_ip": "",
  "main_process_port": 8000,
  "main_training_function": "main",
  "mixed_precision": "fp16",
  # NOTE(review): 2 machines with 2 processes in total, yet "compute_environment" is
  # LOCAL_MACHINE and "main_process_ip" is empty — presumably this should be 1
  # machine with 2 GPUs; confirm.
  "num_machines": 2,
  "num_processes": 2,
  "rdzv_backend": "static",
  "same_network": True,
  "tpu_env": [],
  "tpu_use_cluster": False,
  "tpu_use_sudo": False,
  "use_cpu": False
}


In [9]:
# Training configuration: 10 optimisation steps, per-device batch size 2 with
# 4 gradient-accumulation steps, logging every step, checkpoints under "outputs".
training_args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 1,
        max_steps = 10,
        logging_steps = 1,
        output_dir = "outputs",
        seed = 3407,
        max_seq_length = max_seq_length,
        # Mixed-precision mode follows the dtype of the model's input embedding weights.
        fp16 = model.get_input_embeddings().weight.dtype == torch.float16,
        bf16 = model.get_input_embeddings().weight.dtype == torch.bfloat16,
        report_to = "none", # For W&B
        dataset_num_proc = 4,
       
        # NOTE(review): only `fsdp_config` is supplied and no `fsdp` argument is set —
        # confirm FSDP is actually enabled with this combination.
        fsdp_config=fsdp_config
    )

# Supervised fine-tuning on the dataset loaded earlier; the tokenizer is passed as the processing class.
trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    processing_class = tokenizer,
    args =training_args,
)
# Runs the training loop; the returned TrainOutput is displayed as the cell result.
trainer.train()

Converting train dataset to ChatML (num_proc=4):   0%|          | 0/21029 [00:00<?, ? examples/s]

Applying chat template to train dataset (num_proc=4):   0%|          | 0/21029 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=4):   0%|          | 0/21029 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=4):   0%|          | 0/21029 [00:00<?, ? examples/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,9.0842
2,9.0773
3,11.4712
4,8.7086
5,7.7551
6,7.1945
7,6.528
8,7.0119
9,6.2949
10,7.4489


TrainOutput(global_step=10, training_loss=8.057448768615723, metrics={'train_runtime': 74.4421, 'train_samples_per_second': 1.075, 'train_steps_per_second': 0.134, 'total_flos': 461650822987776.0, 'train_loss': 8.057448768615723})