In [None]:
%%capture
# Install dependencies. Colab is detected by checking whether any environment
# variable name contains the substring "COLAB_".
import os
if "COLAB_" not in "".join(os.environ.keys()):
    # Not running on Colab: a plain install of unsloth.
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # The first and last commands pass --no-deps, so pip does not resolve
    # dependencies for those packages; xformers is the only pinned version.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [None]:
import logging

# Turn on verbose logging: DEBUG for the root logger (timestamped format),
# and DEBUG set explicitly on the torch.distributed logger and its FSDP child.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.DEBUG,
)
dist_logger = logging.getLogger("torch.distributed")
# getChild("fsdp") resolves to the logger named "torch.distributed.fsdp".
fsdp_logger = dist_logger.getChild("fsdp")
dist_logger.setLevel(logging.DEBUG)
fsdp_logger.setLevel(logging.DEBUG)

In [None]:
import os
# Show the current values (None when unset), then force the backend variable.
# NOTE(review): nothing in this notebook reads these variables back; confirm
# that the distributed stack actually honours TORCH_DISTRIBUTED_BACKEND.
print("NCCL_BACKEND:", os.getenv("NCCL_BACKEND"))
print("TORCH_DISTRIBUTED_BACKEND:", os.getenv("TORCH_DISTRIBUTED_BACKEND"))
os.environ.update({"TORCH_DISTRIBUTED_BACKEND": "nccl"})  # or "gloo"

In [None]:
# ==================================================

In [None]:
# Plan: as part of patching the forward function, it would need to call the patched version of dequantize_4bit,
# and possibly a version of it that also does transposing
from bitsandbytes import functional

# Keep a handle on the unpatched implementation. If a previous run of this
# cell already saved one, that saved value is kept, so re-running the cell
# never captures an already-patched function as the "original".
functional.original_dequantize_4bit = getattr(
    functional, 'original_dequantize_4bit', functional.dequantize_4bit
)

In [None]:
from kernels import fused_dequantize
import torch

def augmented_dequantize_4bit(
    A, # torch.Tensor
    quant_state = None, # Optional[QuantState]
    absmax = None, # Optional[torch.Tensor]
    out = None, # Optional[torch.Tensor]
    blocksize = 64, # int
    quant_type = "fp4",
):
    """Replacement for ``functional.dequantize_4bit`` backed by ``fused_dequantize``.

    The parameter list mirrors the function it replaces, but only the call
    form ``dequantize_4bit(A, quant_state)`` is supported: ``quant_state``
    must be given and every other optional argument must keep its default.

    Returns:
        ``fused_dequantize(A, quant_state)`` transposed with ``.t()``.

    Raises:
        AssertionError: if ``quant_state`` is None, or ``absmax``/``out`` are
            supplied, or ``blocksize``/``quant_type`` differ from their defaults.
    """
    # Guard the supported call form; anything else would need the original
    # implementation (functional.original_dequantize_4bit) instead.
    assert quant_state is not None
    assert absmax is None
    assert out is None
    assert blocksize == 64
    assert quant_type == "fp4"

    dequantized = fused_dequantize(A, quant_state)
    return dequantized.t()

# Rebind the module attribute: code that looks up `functional.dequantize_4bit`
# at call time now gets the fused wrapper defined in this cell.
functional.dequantize_4bit = augmented_dequantize_4bit

In [None]:
# Plan: patch the forward function of nn.Linear4bit so that it can be torch.compile(d);
# this should mean that the backward pass can also be derived automatically.
# First, let's store the original function before patching. This cell should not be modified.
import bitsandbytes as bnb
# Remember the stock forward exactly once: a value saved by an earlier run of
# this cell wins, so the saved "original" is never an already-patched forward.
bnb.nn.Linear4bit.original_forward = getattr(
    bnb.nn.Linear4bit, 'original_forward', bnb.nn.Linear4bit.forward
)

In [None]:
from torch.compiler import allow_in_graph
# If this cell is disabled, the run should still show the custom dequantize function working.
# Transposition trick: avoid having the transposition done by calling Params4bit.t().
# Do NOT transpose in Linear4bit.forward; instead, transpose inside the MatMul4Bit-style autograd function.
# Rather than calling bitsandbytes' MatMul4Bit, the forward calls a custom torch.autograd.Function (defined below).

import bitsandbytes.functional as F
from typing import Optional
from math import prod
from bitsandbytes.nn.modules import fix_4bit_weight_quant_state_from_module, Params4bit
ENABLE_ASSERTIONS = False

class TransposeBMatMul4Bit(torch.autograd.Function):
    """Dequantize-then-matmul autograd function that transposes ``B`` internally.

    ``apply(A, B, out, bias, quant_state)`` dequantizes ``B`` with
    ``F.dequantize_4bit(B, quant_state)``, transposes the result (see the
    ``NOTE`` lines) and feeds it to ``torch.nn.functional.linear``. The
    transposition happens here so that the caller does not have to call
    ``Params4bit.t()``.

    ``out`` is accepted only to keep the five-argument call shape; it is never
    read. Gradients are produced for ``A`` and ``bias`` only; ``B``, ``out``
    and ``quant_state`` always receive ``None`` (except for the empty-input
    case, where ``B`` receives zeros).

    NOTE(review): when ``functional.dequantize_4bit`` has been replaced by
    ``augmented_dequantize_4bit`` (which itself returns a transposed result),
    the ``.t()`` applied here is a second transposition -- confirm that the
    resulting orientation is the intended one.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
        """Compute ``linear(A, dequantize(B).t(), bias)``.

        For an input ``A`` with zero elements, an uninitialised tensor of the
        matching output shape is returned without dequantizing anything.
        """
        # default of pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            # The output shape is derived from the shape recorded in the
            # quant state, not from the shape of the packed tensor B.
            B_shape = quant_state.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Dequantize
        # 2. Matmul
        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) # NOTE the transposition

        # 3. Save state
        ctx.state = quant_state
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        # Only the packed weights are kept for backward (they are
        # re-dequantized there); A itself is never needed.
        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (None, B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(A, B, out, bias, quant_state)``."""
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
        _, B = ctx.tensors

        grad_A, grad_B, grad_bias = None, None, None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype.
            # The bias is broadcast over every leading dimension of the
            # output, so its gradient is the sum over all of them. Summing
            # only dim 0 is correct for 2-D inputs but yields the wrong shape
            # for e.g. (batch, seq, out) gradients.
            leading_dims = tuple(range(grad_output.dim() - 1))
            if leading_dims:
                grad_bias = grad_output.sum(leading_dims, dtype=ctx.dtype_bias)
            else:
                # 1-D grad_output: nothing to reduce, only the dtype changes.
                grad_bias = grad_output.to(ctx.dtype_bias)

        # not supported by PyTorch. TODO: create work-around
        # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
        if req_gradA:
            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) # NOTE the transposition

        return grad_A, grad_B, None, grad_bias, None

@torch.compile(fullgraph = True)
def get_data_transposed(param: Params4bit):
    """Return ``param.some_data.t()``, unwrapping one level of ``Params4bit`` first.

    If ``param.some_data`` is itself a ``Params4bit``, that inner object is
    used in place of ``param``. A second level of nesting is rejected.

    Raises:
        RuntimeError: if ``some_data`` is still a ``Params4bit`` after the
            single unwrapping step.

    NOTE(review): ``some_data`` is not assigned anywhere in the visible
    notebook (``convert_fn`` inside ``train_model`` writes ``t.data``).
    Confirm where this attribute comes from, or whether ``data`` was intended.
    """
    # Unwrap at most one level of nesting.
    if isinstance(param.some_data, Params4bit):
        param = param.some_data
    # assert(False)
    # print(param.some_data)
    # print(type(param.data))  # See what Dynamo sees
    # Still nested after one unwrap: fail loudly rather than transposing it.
    if isinstance(param.some_data, Params4bit):
        raise RuntimeError("Expected some_data NOT to be Params4bit")
    # assert isinstance(param.some_data, torch.Tensor)
    return param.some_data.t()

def inner_transpose_forward(self, x: torch.Tensor):
    """Replacement ``Linear4bit.forward`` that routes through ``TransposeBMatMul4Bit``.

    The input is cast to ``self.compute_dtype`` (when one is set), multiplied
    against the layer's packed weight via ``TransposeBMatMul4Bit`` and cast
    back to the dtype ``x`` arrived in. Instead of ``self.weight.t()``, the
    weight storage returned by ``get_data_transposed`` is passed as matrix B;
    the GEMV path of ``bnb.matmul_4bit`` is assumed never to be needed.
    """
    fix_4bit_weight_quant_state_from_module(self)

    # The bias is not cast automatically, so align it with the input dtype by hand.
    if not (self.bias is None or self.bias.dtype == x.dtype):
        self.bias.data = self.bias.data.to(x.dtype)

    # Let the module pick its compute type the first time it sees an input.
    if not self.compute_type_is_set:
        self.set_compute_type(x)
        self.compute_type_is_set = True

    original_dtype = x.dtype
    activations = x if self.compute_dtype is None else x.to(self.compute_dtype)

    if self.bias is None:
        compute_bias = None
    else:
        compute_bias = self.bias.to(self.compute_dtype)

    # Matrix B is the weight's underlying storage, handed over directly.
    packed_weight_t = get_data_transposed(self.weight)

    # Example shapes seen while debugging: A torch.Size([1, 100, 2048]), B torch.Size([2097152, 1])
    if ENABLE_ASSERTIONS:
        assert packed_weight_t.shape[1] == 1

    result = TransposeBMatMul4Bit.apply(
        activations,
        packed_weight_t,
        None,  # `out` is unused
        compute_bias,
        self.weight.quant_state,
    )
    return result.to(original_dtype)

# Install the replacement on the class so every Linear4bit layer uses it.
bnb.nn.Linear4bit.forward = inner_transpose_forward

In [None]:
import torch
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : True,
    "shape_padding"     : True,
    "trace.enabled"     : True,
    "triton.cudagraphs" : False,
}

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def compiled_llama_mlp(self, x):
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    return down_proj

import transformers.models.llama.modeling_llama
# Swap the class-level forward so LlamaMLP instances run the compiled version.
# NOTE(review): train_model() deletes `transformers` from sys.modules before
# importing it again; a re-imported LlamaMLP would be a different class object
# that does not carry this patch. Confirm the patch is still active in training.
transformers.models.llama.modeling_llama.LlamaMLP.forward = compiled_llama_mlp

In [None]:
# ==================================================

In [None]:
def monkey_patch_custom(module, fn_name, replacement):
    """Route calls to ``module.<fn_name>`` through ``replacement``, once.

    The current attribute is saved as ``module.original_<fn_name>`` and
    ``module.<fn_name>`` is replaced by a wrapper that calls
    ``replacement(original_fn, *args, **kwargs)``.

    The patch is idempotent: if ``module.original_<fn_name>`` already exists,
    nothing happens, so re-running a notebook cell does not stack wrappers
    (and a different ``replacement`` passed later is ignored).

    Args:
        module: Any object supporting getattr/setattr (module, class, ...).
        fn_name: Name of the attribute to wrap.
        replacement: Callable receiving the original function followed by the
            call's own positional and keyword arguments.

    Raises:
        AttributeError: if ``module`` has no attribute ``fn_name``.
        AssertionError: if that attribute is ``None``.
    """
    import functools  # local import keeps this cell self-contained

    original_fn_name = f'original_{fn_name}'
    if hasattr(module, original_fn_name):
        # Already patched by an earlier call; leave the existing wrapper alone.
        return

    original_fn = getattr(module, fn_name)
    assert original_fn is not None

    # functools.wraps carries over __name__, __doc__ and __wrapped__, so the
    # wrapper still looks like the original to tracebacks and introspection.
    @functools.wraps(original_fn)
    def new_fn(*args, **kwargs):
        return replacement(original_fn, *args, **kwargs)

    setattr(module, original_fn_name, original_fn)
    setattr(module, fn_name, new_fn)

In [None]:
def train_model():
    """Per-process training entry point (passed to ``notebook_launcher``).

    Steps, in order:
      1. Create an ``Accelerator`` and patch
         ``FullyShardedDataParallelPlugin.set_auto_wrap_policy`` with a custom
         wrap policy.
      2. Drop ``trl``/``transformers``/``peft``/``bitsandbytes`` from
         ``sys.modules`` and import them again.
      3. Load the 4-bit Llama model, attach LoRA adapters and mark only the
         LoRA weights as trainable.
      4. Run every ``Params4bit`` through ``convert_uint8_to_float32``.
      5. Train for 10 steps with ``SFTTrainer`` (FSDP auto-wrap) and print the
         number of FSDP-wrapped modules before and after training.

    Takes no arguments and returns nothing; all results are printed.
    """
    # Helpful functions used through the entire notebook
    from accelerate import Accelerator
    accelerator = Accelerator()
    process_index = accelerator.process_index

    from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
    # Replacement for FullyShardedDataParallelPlugin.set_auto_wrap_policy.
    # `original_fn` is supplied by monkey_patch_custom but never called: the
    # plugin's own policy selection is bypassed entirely.
    def my_set_auto_wrap_policy(original_fn, self, model):
        import torch.nn as nn
        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, LlamaRMSNorm, LlamaRotaryEmbedding
        # instance_types = [nn.Embedding, LlamaAttention, LlamaMLP, LlamaRMSNorm, LlamaRotaryEmbedding]
        # Matched as substrings of str(type(module)), not with isinstance
        # ('LlamaAttention' appears twice in the list; the duplicate is harmless).
        instance_types = ['Embedding', 'LlamaAttention', 'LlamaMLP', 'LlamaRMSNorm', 'LlamaRotaryEmbedding', 'LlamaSdpaAttention', 'LlamaAttention']
        def my_module_wrap_policy(
            module: nn.Module,
            recurse: bool,
            nonwrapped_numel: int,
        ) -> bool:
            if recurse: # always recurse
                return True

            # Wrap any module whose type string contains one of the names above.
            for name in instance_types:
                if name in str(type(module)):
                    return True

            # Additionally wrap the LM head of the model being prepared.
            if module == model.lm_head:
                print('Found LM_HEAD')
                return True

            return False
        self.auto_wrap_policy = my_module_wrap_policy

    monkey_patch_custom(FullyShardedDataParallelPlugin, 'set_auto_wrap_policy', my_set_auto_wrap_policy)

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    import torch
    import torch.distributed as dist
    # Report the distributed backend, but only when a process group exists.
    if dist.is_available() and dist.is_initialized():
        backend = dist.get_backend()
        print(f"Rank {process_index}: Using backend {backend}")
        print(f"Rank {process_index}: NCCL available: {dist.is_nccl_available()}")
        print(f"Rank {process_index}: CUDA available: {torch.cuda.is_available()}")

    import torch
    import torch.nn as nn
    from transformers import set_seed
    import time
    import inspect
    import os
    import gc
    # HAS_BFLOAT16 is computed here but not read anywhere else in this function.
    major_version, minor_version = torch.cuda.get_device_capability()
    HAS_BFLOAT16 = (major_version >= 8)
    from inspect import currentframe as _C, getframeinfo
    _F = lambda c: getframeinfo(c).lineno # Gets line number
    WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings
    
    # https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
    # Returns the first local name in the caller's frame bound to `var`
    # (identity comparison), or "" when there is none.
    def NAME(var):
        callers_local_vars = inspect.currentframe().f_back.f_locals.items()
        names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
        return names[0] if len(names) != 0 else ""
    
    # Asserts x has the given dtype and that x and y are close (strides
    # included); a mismatch is re-raised as RuntimeError tagged with `line`.
    def assert_same(x, y, line, dtype):
        assert(x.dtype == dtype)
        try: torch.testing.assert_close(x, y, check_stride = True)
        except Exception as error:
            raise RuntimeError(
                f"Failed allclose at line [{line}]: {NAME(x)}, {NAME(y)}\n{str(error)}"
            )
    
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

    # HELPFUL functions to undo Unsloth patches:
    import sys
    
    # Removes `package_name` and all of its submodules from sys.modules so
    # that the next import executes the package afresh.
    def remove_patched_module(package_name):
        modules_to_delete = [
            name for name in sys.modules
            if name == package_name or name.startswith(package_name + ".")
        ]
        for name in modules_to_delete: del sys.modules[name]
    
    # NOTE(review): earlier notebook cells patch bitsandbytes and transformers
    # classes/functions in place. Re-importing these packages after the purge
    # creates new module and class objects, so those patches may no longer be
    # in effect for the model built below. Confirm this is intended.
    remove_patched_module("trl")
    remove_patched_module("transformers")
    
    remove_patched_module("peft")
    remove_patched_module("bitsandbytes")

    # CUSTOM
    # rank = torch.distributed.get_rank()
    # world_size = torch.distributed.get_world_size()
    print(f'Training on {process_index}')
    
    import os
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    # NOTE(review): torch.cuda.get_device_capability() has already been called
    # above, so CUDA may be initialised by now; these two variables might be
    # set too late to take effect. Confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,"\
        "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"
    
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import get_peft_model, LoraConfig, TaskType
    
    max_seq_length = 2048
    torch.set_default_dtype(torch.float32)
    model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit"
    # `dtype` is assigned but not used below.
    dtype = torch.float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit              = True,
        bnb_4bit_use_double_quant = True,
        bnb_4bit_quant_type       = "nf4",
        bnb_4bit_quant_storage    = torch.float32,
        bnb_4bit_compute_dtype    = torch.float32,
    )
    # Each process loads the full model onto the GPU matching its index.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map = f"cuda:{process_index}",
        attn_implementation = "sdpa",
        quantization_config = bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.padding_side = "right"
    
    lora_config = LoraConfig(
        r = 64,
        lora_alpha = 128,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"],
        lora_dropout = 0,
        bias = "none",
        task_type = TaskType.CAUSAL_LM,
    )
    
    # Get LoRA and setup model
    model = get_peft_model(model, lora_config)
    # Only parameters whose name contains ".lora_A." or ".lora_B." are trainable.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if ".lora_A." in name or ".lora_B." in name: param.requires_grad_(True)
            else: param.requires_grad_(False)
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    # ==================================================
    from bitsandbytes.nn import Params4bit
    from bitsandbytes import functional as F
    from converter_kernel import convert_uint8_to_float32
    # params4bit_count = 0
    # Callback for model._apply: anything that is not a Params4bit passes
    # through untouched; a Params4bit gets its .data replaced by the output of
    # convert_uint8_to_float32 (moved back to the original device) and its
    # quant_storage set to float32.
    # NOTE(review): presumably the converter re-packs the uint8 storage as
    # float32 -- TODO confirm against converter_kernel.
    def convert_fn(t):
        # gc.collect()
        # torch.cuda.empty_cache()
        # if t.dtype == torch.float16:
        #     print(f'GPU{process_index}: converting {t.numel()} to f32')
        #     return t.to(torch.float32)

        if not isinstance(t, Params4bit):
            return t
    
        assert not isinstance(t.data, Params4bit)
        # print(t.quant_storage)
        t.data = convert_uint8_to_float32(t.data.cuda()).to(t.data.device)
        t.quant_storage = torch.float32
        return t

    # NOTE: _apply is an underscore-prefixed (private) method of the model.
    model._apply(convert_fn)
    # print(f'Number of Params4bit converted: {params4bit_count}')
    # ==================================================
    gc.collect()
    torch.cuda.empty_cache()
    # ==================================================
    # Surely at this point everything inside LlamaAttention should be fp32 right?
    # Diagnostic only: prints the name of every non-float32 parameter found
    # inside a LlamaAttention module; nothing is modified.
    from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaAttention, LlamaRotaryEmbedding, LlamaDecoderLayer
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            # print(f'GPU{process_index}: examining LlamaAttention\n', end='')
            for name, param in module.named_parameters():
                if param.dtype != torch.float32:
                    print(f'GPU{process_index}: invalid param in module {name}\n', end='')
    # ==================================================
    
    # Get dataset
    from datasets import load_dataset
    from trl import SFTTrainer, SFTConfig
    url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
    dataset = load_dataset("json", data_files = {"train" : url}, split = "train[:10%]")

    from torch.distributed.fsdp.wrap import ModuleWrapPolicy
    import torch.nn as nn
    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, LlamaRMSNorm, LlamaRotaryEmbedding
    # Only referenced by the commented-out ModuleWrapPolicy line further down.
    instance_types = {nn.Embedding, LlamaAttention, LlamaMLP, LlamaRMSNorm, LlamaRotaryEmbedding}

    print(model)
    
    trainer = SFTTrainer(
        model = model,
        train_dataset = dataset,
        processing_class = tokenizer,
        args = SFTConfig(
            per_device_train_batch_size = 2,
            gradient_accumulation_steps = 4,
            warmup_steps = 1,
            max_steps = 10,
            logging_steps = 1,
            output_dir = "outputs",
            seed = 3407,
            max_seq_length = max_seq_length,
            fp16 = False,
            bf16 = False,
            report_to = "none", # For W&B
            dataset_num_proc = 4,

            # custom args for FSDP2
            fsdp="auto_wrap",  # Enable FSDP
            fsdp_config={
                "offload_to_cpu": False,
                "activation_checkpointing": False,
                "limit_all_gathers": False,
                "forward_prefetch": True,
                "backward_prefetch": "backward_pre",
                "sync_module_states": False,
                # "transformer_layer_cls_to_wrap": [
                #     "Embedding",
                #     "LlamaAttention",
                #     "LlamaMLP",
                #     "LlamaRMSNorm",
                #     "LlamaRotaryEmbedding",
                # ],
                # replaced with monkey patch as we also want to wrap LM_HEAD
                # also this causes unnecessary wrapping of inner Linear
            },
        ),
    )

    # Count FSDP-wrapped modules before training, for comparison with the
    # count taken after trainer.train().
    # fsdp_count_before=0
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    fsdp_count_before = 0
    for module in model.modules():
        if isinstance(module, FSDP):
            fsdp_count_before += 1
    print(f'{fsdp_count_before=}')

    # trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = ModuleWrapPolicy(instance_types)
    # NOTE: ^ is overwritten by FullyShardedDataParallelPlugin.set_auto_wrap_policy called from prepare_model

    trainer.train()

    print(trainer.model_wrapped)

    # Same count as above, taken after training (values seen in earlier runs
    # are recorded in the two comments below).
    # fsdp_count_after=579
    # fsdp_count_after=131
    fsdp_count_after = 0
    for module in model.modules():
        if isinstance(module, FSDP):
            fsdp_count_after += 1
    print(f'{fsdp_count_after=}')


In [None]:
from accelerate import notebook_launcher

In [None]:
# Launch train_model in two processes from inside the notebook.
notebook_launcher(train_model, num_processes=2)