In [10]:
!pip install -U bitsandbytes accelerate transformers peft datasets


Collecting peft
  Downloading peft-0.18.0-py3-none-any.whl.metadata (14 kB)
Collecting datasets
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Downloading peft-0.18.0-py3-none-any.whl (556 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m556.4/556.4 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading datasets-4.4.1-py3-none-any.whl (511 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m511.6/511.6 kB[0m [31m42.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl (47.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyarrow, datasets, peft
  Attempting uninstall: pyarrow
    Found existing installation: pyarrow 18.1.0
    Uninstalling pyar

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

In [2]:
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [3]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,                    # QLoRA 4-bit quantization
    device_map="auto"
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [4]:
model = prepare_model_for_kbit_training(model)

In [5]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj","k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

In [6]:
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 3,063,808 || all params: 1,103,112,192 || trainable%: 0.2777


In [7]:
dataset = load_dataset("tatsu-lab/alpaca", split="train[:2000]")

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001-a09b74b3ef9c3b(…):   0%|          | 0.00/24.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/52002 [00:00<?, ? examples/s]

In [8]:
dataset

Dataset({
    features: ['instruction', 'input', 'output', 'text'],
    num_rows: 2000
})

In [9]:
def format_sample(sample):
    prompt = f"### Instruction:\n{sample['instruction']}\n\n### Response:\n"
    return {"text": prompt + sample["output"]}

In [10]:
dataset = dataset.map(format_sample)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [11]:
def tokenize(batch):
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=256
    )

In [12]:
tokenized = dataset.map(tokenize, batched=True)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [13]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

In [14]:
training_args = TrainingArguments(
    output_dir="./qlora-output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=20,
    save_strategy="epoch"
)


In [15]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    data_collator=data_collator
)

In [16]:
trainer.train()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33m[0m ([33m-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss
20,1.6553
40,1.4031
60,1.3277
80,1.3142
100,1.3407
120,1.3171
140,1.3615
160,1.2949
180,1.3398
200,1.3066


TrainOutput(global_step=250, training_loss=1.347245433807373, metrics={'train_runtime': 1841.3835, 'train_samples_per_second': 1.086, 'train_steps_per_second': 0.136, 'total_flos': 3187434061824000.0, 'train_loss': 1.347245433807373, 'epoch': 1.0})

In [17]:
prompt = "Explain what is machine learning in simple words."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

In [18]:
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=100)

print(tokenizer.decode(output[0], skip_special_tokens=True))

Caching is incompatible with gradient checkpointing in LlamaDecoderLayer. Setting `past_key_values=None`.
  return fn(*args, **kwargs)


Explain what is machine learning in simple words.
< ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ###
