In [1]:
import dataclasses
import itertools
import json
import logging
import os

import torch
from IPython.core.interactiveshell import InteractiveShell
from llama_head import CEL_only_forward
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.trainer_callback import ProgressCallback
from transformers.trainer_pt_utils import _secs2timedelta
from trl import SFTTrainer

import unsloth.utils.data as data_utils
import unsloth.utils.memory as memory_utils
import unsloth.utils.testing as test_utils
from unsloth.kernels import fused_cel
from unsloth.kernels.fused_cel import patch_model as patch_model_fused_cel
from unsloth.models._utils import patch_tokenizer, prepare_model_for_kbit_training
from unsloth.models.llama import FastLlamaModel
from unsloth.utils.profiling import MetricsCallBack

logging.basicConfig(level=logging.WARNING)

# Choose one precision for the whole run. Quantized compute and model weights
# used to be hard-coded to bfloat16 while TrainingArguments switched to fp16 on
# GPUs without bf16 support, so model and trainer precision could disagree.
use_bf16 = torch.cuda.is_bf16_supported()
compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

# 4-bit NF4 weights, no double quantization, matmuls in `compute_dtype`.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=compute_dtype,
)

# Within the visible cells, `model_config` is only used by the commented-out
# unquantized alternative below.
model_config = LlamaConfig.from_pretrained("./llama-10m.json")
model = AutoModelForCausalLM.from_pretrained(
    "./llama-10m", quantization_config=quant_config, torch_dtype=compute_dtype
)
# Unquantized alternative built straight from the config:
# model = LlamaForCausalLM(model_config).to("cuda")

# Llama-2 tokenizer with right padding. The local model has no pad token, and
# patch_tokenizer supplies one (this cell's output shows it falling back to <unk>).
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", model_max_length=4096, padding_side="right"
)
model, tokenizer = patch_tokenizer(model, tokenizer)

# Sequence length actually used for batching / SFT (far below model_max_length).
max_seq_length = 256

training_args = TrainingArguments(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    # NOTE(review): warmup_steps == max_steps, so with the linear schedule the
    # learning rate is still warming up for the entire 5-step run.
    warmup_steps=5,
    max_steps=5,
    learning_rate=2e-4,
    fp16=not use_bf16,
    bf16=use_bf16,
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    output_dir="outputs",
    overwrite_output_dir=True,
    # Metrics
    skip_memory_metrics=False,
    include_num_input_tokens_seen=True,
    include_tokens_per_second=True,
)

# LoRA targets: every attention and MLP projection in the Llama blocks.
accepted_modules = frozenset(
    (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ),
)

dataset = data_utils.get_alpaca(tokenizer)

peft_config = LoraConfig(
    # peft normalizes a *list* of target modules to a set, and when saving the
    # adapter config it converts only `set` values (not `frozenset`) back to
    # lists before json.dumps. Passing the frozenset directly would make saving
    # fail, so hand peft a deterministic sorted list instead.
    target_modules=sorted(accepted_modules),
    lora_alpha=8,
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)

    PyTorch 2.3.0+cu121 with CUDA 1201 (you have 2.4.0.dev20240507+cu121)
    Python  3.11.9 (you have 3.11.4)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
`low_cpu_mem_usage` was None, now set to True since model is quantized.
./llama-10m does not have a padding token! Will use pad_token = <unk>.


In [2]:
from unsloth.utils.data import get_data_loader

In [3]:
# Build the project DataLoader over `dataset` (same helper as the imported
# `get_data_loader`, reached through the `data_utils` module alias).
dataloader = data_utils.get_data_loader(dataset, tokenizer, max_seq_length)

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [4]:
# Materialize every batch; the Map progress bar above reports only 10 examples.
batches = list(dataloader)

In [10]:
# First batch, used below to compare the stock and fused cross-entropy losses.
batch = batches[0]

In [13]:
# Reference loss from the stock model vs. the loss from the fused cross-entropy patch,
# both on the same batch.
out = model(**batch)
fused_out = patch_model_fused_cel(model, use_fused_cel=True)(**batch)

# Per the recorded output, the reference loss is on CPU and the fused loss on cuda:0,
# so compare detached float32 copies on CPU instead of eyeballing the printed digits.
torch.testing.assert_close(
    fused_out.loss.detach().float().cpu(),
    out.loss.detach().float().cpu(),
    rtol=1e-4,
    atol=1e-4,
)
out.loss, fused_out.loss



(tensor(10.3927, grad_fn=<ToCopyBackward0>),
 tensor(10.3927, device='cuda:0',
        grad_fn=<FusedCrossEntropyLossFunctionBackward>))

In [12]:
# NOTE(review): this cell's execution count (12) is lower than the cell that
# defines `out` (13), so it displayed an `out` left over from an earlier run.
out.loss

tensor(10.3927, grad_fn=<ToCopyBackward0>)

In [2]:
# Small synthetic inputs for probing the model piece by piece.
bs, seqlen, in_features = 1, 16, 4096
dtype = torch.bfloat16
# NOTE(review): `in_features` is unused in the visible cells; the hidden width
# actually comes from the model config here.
hidden_dim = model.config.hidden_size

# Draw random inputs from a dedicated CUDA generator seeded like the trainer, so a
# restart-and-run-all reproduces the same tensors without consuming the global RNG.
input_gen = torch.Generator(device="cuda").manual_seed(training_args.seed)
input_ids = torch.randint(
    0, model.config.vocab_size, (bs, seqlen), device="cuda", generator=input_gen
)
# NOTE(review): `hidden_states` is not used by any visible cell.
hidden_states = torch.randn(
    bs,
    seqlen,
    hidden_dim,
    dtype=dtype,
    device="cuda",
    requires_grad=True,
    generator=input_gen,
)
labels = input_ids.detach().clone()
attention_mask = torch.ones((bs, seqlen), device="cuda")

# ref_out = model(input_ids, labels=labels, attention_mask=attention_mask)
# ref_head = model.lm_head

In [4]:
# Embedding lookup alone, to inspect what the decoder layers receive.
embed_out = model.model.embed_tokens(input_ids)
# TODO: an unfinished `decoder_out = model.model.` statement stood here; it was a
# SyntaxError that kept this whole cell from running. The full decoder pass is
# computed in the next cell as `llama_out`.

In [8]:
# Run only the inner decoder stack (`model.model`), without the LM head.
llama_out = model.model(input_ids=input_ids, attention_mask=attention_mask)

In [14]:
# Does the decoder output carry autograd history? (Recorded output: True.)
llama_out.last_hidden_state.requires_grad

True

In [2]:
# fused_cel only works when the embedding outputs require grad, and this helper is
# used to arrange that. Gradient checkpointing follows training_args, which does
# not set it above.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=training_args.gradient_checkpointing,
)

In [4]:
# Swap in the fused cross-entropy loss before handing the model to TRL.
# NOTE(review): an earlier cell also called patch_model_fused_cel on `model`;
# confirm that patching twice is harmless.
patched_model = patch_model_fused_cel(model, use_fused_cel=True)

trainer = SFTTrainer(
    model=patched_model,
    args=training_args,
    tokenizer=tokenizer,
    peft_config=peft_config,
    train_dataset=dataset,
    dataset_text_field="text",
    dataset_num_proc=2,
    max_seq_length=max_seq_length,
    # Packing is disabled; enabling it can make training ~5x faster for short sequences.
    packing=False,
)

# Training is left disabled here; uncomment to replace the progress bar with
# MetricsCallBack and run the max_steps=5 smoke test.
# trainer.remove_callback(ProgressCallback)
# _ = trainer.add_callback(MetricsCallBack())
# train_stats = trainer.train()

max_steps is given, it will override any value given in num_train_epochs


In [5]:
# Dataset returned by data_utils.get_alpaca; its "text" column feeds SFTTrainer.
dataset

Dataset({
    features: ['output', 'input', 'instruction', 'text'],
    num_rows: 51760
})

In [7]:
# Pull out the DataLoader the trainer builds for training so its batches can be inspected.
train_loader = trainer.get_train_dataloader()

In [8]:
# Take only the first few trainer batches for inspection instead of the whole loader.
# `itertools` is imported with the other stdlib modules in the first cell.
NUM_PREVIEW_BATCHES = 10
batches = list(itertools.islice(train_loader, NUM_PREVIEW_BATCHES))

In [9]:
# First trainer batch (a dict including `input_ids`); displays the full tensors.
batches[0]

{'input_ids': tensor([[    1, 13866,   338,   385, 15278,   393, 16612,   263,  3414, 29892,
          3300,  2859,   411,   385,  1881,   393,  8128,  4340,  3030, 29889,
         14350,   263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009,
         29889,    13,    13,  2277, 29937,  2799,  4080, 29901,    13, 29907,
         20440,   675,   322,  3858,   278,  2183,  1426,   408,  2845,   263,
          2114,   470,   385,  9426, 29889,    13,    13,  2277, 29937, 10567,
         29901,    13,  1576, 14064,   471,  8031, 29889,    13,    13,  2277,
         29937, 13291, 29901,    13, 11746,   262,   291, 29901,   450,  3229,
           376,  1576, 14064,   471,  8031, 29908,   338,   385,  9426,  1363,
           372,   338,   263,  4967,   573, 24284,  2729,   373,  7333, 21737,
         29892, 21779, 29892,   470, 24583, 29889,     2,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  

batches = [b for b in accepted_modules
