Michael Han edited this page May 11, 2024 · 18 revisions

Updating Unsloth without dependency updates

pip uninstall unsloth -y
pip install --upgrade --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

Loading LoRA adapters for continued finetuning

If you saved a LoRA adapter through Unsloth, you can continue training from your LoRA weights. Note that the optimizer state is reset in this case. To also restore the optimizer state, see the section "Finetuning from your last checkpoint".

from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "LORA_MODEL_NAME",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
trainer = Trainer(...)
trainer.train()

Finetuning the lm_head and embed_tokens matrices:

Don't forget to resize your embedding matrices if you added new tokens!

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    modules_to_save = ["lm_head", "embed_tokens",],
)

Finetuning from your last checkpoint

You must first edit the Trainer to add save_strategy and save_steps. The example below saves a checkpoint every 50 steps to the folder outputs.

trainer = SFTTrainer(
    ....
    args = TrainingArguments(
        ....
        output_dir = "outputs",
        save_strategy = "steps",
        save_steps = 50,
    ),
)

Then start training with:

trainer_stats = trainer.train(resume_from_checkpoint = True)

This loads the latest checkpoint from output_dir and continues training from it.
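
The Hugging Face Trainer writes each checkpoint into a folder named checkpoint-<step> inside output_dir, and resuming picks the one with the highest step. The sketch below only illustrates that selection rule so you can check which folder would be used. find_latest_checkpoint is a hypothetical helper written for this page, not the Trainer's own code.

```python
import os
import re
import tempfile

# Checkpoint folders are named "checkpoint-<global step>".
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")

def find_latest_checkpoint(output_dir):
    """Return the checkpoint folder with the highest step number, or None."""
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = CHECKPOINT_RE.match(name)
        path = os.path.join(output_dir, name)
        if match and os.path.isdir(path):
            step = int(match.group(1))  # compare numerically, not as strings
            if step > best_step:
                best_step, best_path = step, path
    return best_path

# Demo on a temporary folder standing in for "outputs".
with tempfile.TemporaryDirectory() as output_dir:
    for step in (50, 100, 1000):
        os.mkdir(os.path.join(output_dir, f"checkpoint-{step}"))
    latest = find_latest_checkpoint(output_dir)
    empty_result = find_latest_checkpoint(os.path.join(output_dir, "checkpoint-50"))
    print(os.path.basename(latest))
```

If no checkpoint folder exists yet, the helper returns None; in that situation train normally first so that at least one checkpoint is written.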

Saving models to 16bit for VLLM

To save to 16bit for VLLM, use:

model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

To merge to 4bit for loading on HuggingFace, first call merged_4bit. Then use merged_4bit_forced if you are certain you want to merge to 4bit. I highly discourage this unless you know what you are going to do with the 4bit model (e.g. DPO training, or HuggingFace's online inference engine).

model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

To save just the LoRA adapters, either use:

model.save_pretrained(...) AND tokenizer.save_pretrained(...)

Or just use our builtin function to do that:

model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")

Saving to safetensors, not bin format in Colab

In Colab we save to .bin because it is about 4x faster. To force saving to .safetensors instead, set safe_serialization = None: model.save_pretrained(..., safe_serialization = None) or model.push_to_hub(..., safe_serialization = None).

Saving to GGUF

To save to GGUF, use the below to save locally:

model.save_pretrained_gguf("dir", tokenizer, quantization_method = "q4_k_m")
model.save_pretrained_gguf("dir", tokenizer, quantization_method = "q8_0")
model.save_pretrained_gguf("dir", tokenizer, quantization_method = "f16")

To push to the Hub:

model.push_to_hub_gguf("hf_username/dir", tokenizer, quantization_method = "q4_k_m")
model.push_to_hub_gguf("hf_username/dir", tokenizer, quantization_method = "q8_0")

All supported quantization options for quantization_method are listed below:

# https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19
# From https://mlabonne.github.io/blog/posts/Quantize_Llama_2_models_using_ggml.html
ALLOWED_QUANTS = \
{
    "not_quantized"  : "Recommended. Fast conversion. Slow inference, big files.",
    "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
    "quantized"      : "Recommended. Slow conversion. Fast inference, small files.",
    "f32"     : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
    "f16"     : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
    "q8_0"    : "Fast conversion. High resource use, but generally acceptable.",
    "q4_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
    "q5_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
    "q2_k"    : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
    "q3_k_l"  : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
    "q3_k_m"  : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
    "q3_k_s"  : "Uses Q3_K for all tensors",
    "q4_0"    : "Original quant method, 4-bit.",
    "q4_1"    : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
    "q4_k_s"  : "Uses Q4_K for all tensors",
    "q4_k"    : "alias for q4_k_m",
    "q5_k"    : "alias for q5_k_m",
    "q5_0"    : "Higher accuracy, higher resource usage and slower inference.",
    "q5_1"    : "Even higher accuracy, resource usage and slower inference.",
    "q5_k_s"  : "Uses Q5_K for all tensors",
    "q6_k"    : "Uses Q8_K for all tensors",
    "iq2_xxs" : "2.06 bpw quantization",
    "iq2_xs"  : "2.31 bpw quantization",
    "iq3_xxs" : "3.06 bpw quantization",
    "q3_k_xs" : "3-bit extra small quantization",
}
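
Since two of the entries above are aliases, it can help to validate a name before starting a long conversion. The following is a minimal sketch, assuming the key set shown above. resolve_quantization_method is a hypothetical helper, not part of Unsloth, and lowercasing the input is my own convenience, not documented behaviour.

```python
# Keys copied from the ALLOWED_QUANTS listing above.
ALLOWED = {
    "not_quantized", "fast_quantized", "quantized",
    "f32", "f16", "q8_0", "q4_k_m", "q5_k_m", "q2_k",
    "q3_k_l", "q3_k_m", "q3_k_s", "q4_0", "q4_1", "q4_k_s",
    "q4_k", "q5_k", "q5_0", "q5_1", "q5_k_s", "q6_k",
    "iq2_xxs", "iq2_xs", "iq3_xxs", "q3_k_xs",
}

# Only the two aliases stated in the listing.
ALIASES = {"q4_k": "q4_k_m", "q5_k": "q5_k_m"}

def resolve_quantization_method(name):
    """Validate a quantization name and replace an alias by its target."""
    key = name.strip().lower()  # assumption: treat names case-insensitively
    if key not in ALLOWED:
        raise ValueError(f"Unknown quantization_method: {name!r}")
    return ALIASES.get(key, key)

print(resolve_quantization_method("q4_k"))
print(resolve_quantization_method("Q8_0"))
```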

Manually saving to GGUF

First save your model to 16bit:

model.save_pretrained_merged("merged_model", tokenizer, save_method = "merged_16bit",)

Then use the terminal and do:

git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp && make clean && LLAMA_CUBLAS=1 make all -j
pip install gguf protobuf

Then follow the steps at https://rentry.org/llama-cpp-conversions#merging-loras-into-a-model using the model name "merged_model" to merge to GGUF.

Evaluation Loop - also OOM or crashing.

Set the trainer settings for evaluation to:

SFTTrainer(
    args = TrainingArguments(
        fp16_full_eval = True,
        per_device_eval_batch_size = 2,
        eval_accumulation_steps = 4,
        evaluation_strategy = "steps",
        eval_steps = 1,
    ),
    train_dataset = train_dataset,
    eval_dataset = eval_dataset,
)

These settings should avoid OOMs and make evaluation somewhat faster, since nothing is upcast to float32.

Chat Templates

Assuming your dataset is a list of lists of dictionaries like the one below:

[
    [{'from': 'human', 'value': 'Hi there!'},
     {'from': 'gpt', 'value': 'Hi how can I help?'},
     {'from': 'human', 'value': 'What is 2+2?'}],
    [{'from': 'human', 'value': "What's your name?"},
     {'from': 'gpt', 'value': "I'm Daniel!"},
     {'from': 'human', 'value': 'Ok! Nice!'},
     {'from': 'gpt', 'value': 'What can I do for you?'},
     {'from': 'human', 'value': 'Oh nothing :)'},],
]

You can use our get_chat_template to format it. Set chat_template to any of zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth, and use mapping to map your dataset's keys and values (from, value, human, gpt) onto role, content, user and assistant. map_eos_token allows you to map <|im_end|> to EOS without any training.

from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml", # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping = {"role" : "from", "content" : "value", "user" : "human", "assistant" : "gpt"}, # ShareGPT style
    map_eos_token = True, # Maps <|im_end|> to </s> instead
)

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }
pass

from datasets import load_dataset
dataset = load_dataset("philschmid/guanaco-sharegpt-style", split = "train")
dataset = dataset.map(formatting_prompts_func, batched = True,)
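
To make the mapping argument concrete, here is a plain-Python sketch of what it expresses: ShareGPT-style keys and speaker names on one side, the role/content convention on the other. This is only an illustration of the correspondence. normalize_conversation is a hypothetical helper, and this is not how get_chat_template is implemented internally.

```python
# Same dictionary as passed to get_chat_template above (ShareGPT style).
mapping = {"role": "from", "content": "value", "user": "human", "assistant": "gpt"}

def normalize_conversation(convo, mapping):
    """Rewrite ShareGPT-style turns into {'role': ..., 'content': ...} turns."""
    role_key = mapping["role"]        # dataset key holding the speaker
    content_key = mapping["content"]  # dataset key holding the text
    role_names = {mapping["user"]: "user", mapping["assistant"]: "assistant"}
    normalized = []
    for turn in convo:
        speaker = turn[role_key]
        normalized.append({
            "role": role_names.get(speaker, speaker),  # unknown speakers pass through
            "content": turn[content_key],
        })
    return normalized

convo = [
    {"from": "human", "value": "Hi there!"},
    {"from": "gpt", "value": "Hi how can I help?"},
]
normalized = normalize_conversation(convo, mapping)
print(normalized)
```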

You can also make your own custom chat templates! For example, the internal chat template we use is shown below. You must pass in a tuple of (custom_template, eos_token), where the eos_token must be used inside the template.

unsloth_template = \
    "{{ bos_token }}"\
    "{{ 'You are a helpful assistant to the user\n' }}"\
    "{% for message in messages %}"\
        "{% if message['role'] == 'user' %}"\
            "{{ '>>> User: ' + message['content'] + '\n' }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"\
        "{% endif %}"\
    "{% endfor %}"\
    "{% if add_generation_prompt %}"\
        "{{ '>>> Assistant: ' }}"\
    "{% endif %}"
unsloth_eos_token = "eos_token"

tokenizer = get_chat_template(
    tokenizer,
    chat_template = (unsloth_template, unsloth_eos_token,), # You must provide a template and EOS token
    mapping = {"role" : "from", "content" : "value", "user" : "human", "assistant" : "gpt"}, # ShareGPT style
    map_eos_token = True, # Maps <|im_end|> to </s> instead
)
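
To see what text this template is meant to produce, the sketch below re-implements its branches in plain Python rather than Jinja. It is an illustration only: render_unsloth_style is a hypothetical function, and the "<s>" and "</s>" token strings are assumed defaults that depend on your tokenizer.

```python
def render_unsloth_style(messages, bos_token="<s>", eos_token="</s>",
                         add_generation_prompt=False):
    """Plain-Python mirror of the template: fixed system line, then one line per turn."""
    text = bos_token + "You are a helpful assistant to the user\n"
    for message in messages:
        if message["role"] == "user":
            text += ">>> User: " + message["content"] + "\n"
        elif message["role"] == "assistant":
            # The EOS token closes every assistant turn.
            text += ">>> Assistant: " + message["content"] + eos_token + "\n"
    if add_generation_prompt:
        # Leave the assistant turn open so the model continues from here.
        text += ">>> Assistant: "
    return text

messages = [
    {"role": "user", "content": "Hi there!"},
    {"role": "assistant", "content": "Hi how can I help?"},
    {"role": "user", "content": "What is 2+2?"},
]
rendered = render_unsloth_style(messages, add_generation_prompt=True)
print(rendered)
```

Use add_generation_prompt = False when formatting training data and True when building a prompt for inference.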

2x Faster Inference

Unsloth natively supports 2x faster inference. All QLoRA, LoRA and non-LoRA inference paths are 2x faster. This requires no code changes and no new dependencies.

from unsloth import FastLanguageModel
from transformers import TextStreamer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "lora_model", # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 64)

NotImplementedError: A UTF-8 locale is required. Got ANSI

See https://github.com/googlecolab/colabtools/issues/3409

In a new cell, run the below:

import locale
locale.getpreferredencoding = lambda: "UTF-8"

Use Unsloth LoRA Adapter with Ollama

Read this 3-step guide, which details how to use llama.cpp to convert an Unsloth LoRA adapter to GGML (.bin) and use it in Ollama: https://medium.com/p/edadb6d9e0f0

This article was written by Sarin Suriyakoon.

Ollama Guide - Unsloth FastLanguageModel

This guide shows how to take the model we fine-tuned with unsloth in a Google Colab training notebook, set it up locally, and call it via the Ollama CLI.

This Ollama guide was written by Jed Tiotuico

Prerequisites

To successfully run the fine-tuned model, we need:

  1. Hugging Face account
  2. A base unsloth model. For this guide, we have chosen unsloth/tinyllama as the base model
  3. A basic understanding of the unsloth FastLanguageModel, in particular of fine-tuning unsloth/tinyllama. We recommend their Google Colab training notebooks on huggingface for more information on the training data
  4. The LoRA adapters that were saved online via the huggingface hub
  5. A working local ollama installation. At the time of writing we used 0.1.32, but later versions should work too. To check your version:
    • ollama --version
    • ollama version is 0.1.32

Training

To recap, here is some of the training code we used with the unsloth FastLanguageModel. Note that we can log in to huggingface on Google Colab by storing our API token as a secret labeled "HF_TOKEN".

import os
from google.colab import userdata
hf_token = userdata.get("HF_TOKEN")
os.environ['HF_TOKEN'] = hf_token

We then run the CLI command below to log in

!huggingface-cli login --token $HF_TOKEN

To check our token is working, run

!huggingface-cli whoami

Below is a sample training code from the Unsloth notebook

from unsloth import FastLanguageModel
import torch
max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/tinyllama", # "unsloth/tinyllama" for 16bit loading
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

Moreover, we used the training code below. We provided dataset and eval_dataset for our training data, which had only one text column.

from trl import SFTTrainer
from transformers import TrainingArguments
from transformers.utils import logging
logging.set_verbosity_info()

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = eval_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = True, # Packs short sequences together to save time!
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_ratio = 0.1,
        num_train_epochs = 2,
        learning_rate = 2e-5,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.1,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

trainer_stats = trainer.train()
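
As a rough sanity check on these settings, the arithmetic below estimates the effective batch size and the number of optimizer steps. It is a back-of-the-envelope sketch: the dataset size of 1000 rows and the single GPU are assumptions, packing changes the real row count, and the Trainer's own rounding may differ slightly.

```python
import math

def estimate_schedule(num_rows, per_device_train_batch_size=2,
                      gradient_accumulation_steps=4, num_train_epochs=2,
                      warmup_ratio=0.1, num_gpus=1):
    """Estimate effective batch size, optimizer steps and warmup steps."""
    effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
    steps_per_epoch = math.ceil(num_rows / effective_batch)
    total_steps = steps_per_epoch * num_train_epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }

# Hypothetical dataset of 1000 (already packed) rows on one GPU.
schedule = estimate_schedule(num_rows=1000)
print(schedule)
```

With the values above, each optimizer step sees 2 x 4 = 8 sequences, so raising gradient_accumulation_steps is the cheapest way to get a larger effective batch without using more memory.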

Then, we should be able to run our inference, as shown below.

FastLanguageModel.for_inference(model)
inputs = tokenizer(
[
"""
<s>
Q:
What is the capital of France?
A:
"""
], return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 1000, use_cache = True)

print(tokenizer.batch_decode(outputs))

Lastly, below, we demonstrate how to save the model online via huggingface

model.push_to_hub_merged("myhfusername/my-model", tokenizer, save_method = "lora")

Installation

Parts of this guide are taken directly from the page below: https://rentry.org/llama-cpp-conversions#setup

1. Build llama.cpp

Clone the llama.cpp repository using

git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp

llama.cpp has Python scripts that we need to run, so we need to pip install its dependencies

pip install -r requirements.txt

Now, let us build our local llama.cpp

make clean && make all -j

For anyone with NVIDIA GPUs, build with CUDA enabled instead:

make clean && LLAMA_CUDA=1 make all -j

2. Clone our base model and the LoRA adapters from the huggingface hub (the adapters we uploaded earlier with the push_to_hub_merged() function)

From the llama.cpp folder let us clone our base model.

git clone https://huggingface.co/unsloth/tinyllama

Next, we clone our Lora model

git clone https://huggingface.co/myhfusername/my-model

3. GGUF conversion

We now need to convert both the base model and the LoRA adapters to GGUF/GGML. We start with the base model:

python convert.py tinyllama --outtype f16 --outfile tinyllama.f16.gguf

4. GGUF conversion of Lora adapters

python convert-lora-to-ggml.py my-model

If the conversion succeeds, the last lines from our output should be

Converted my-model/adapter_config.json and my-model/adapter_model.safetensors to my-model/ggml-adapter-model.bin

5. Merge our GGUF base model and adapter model using the export-lora command

    • --model-base is the GGUF base model
    • --model-out is the new merged GGUF model
    • --lora is the adapter model

export-lora --model-base tinyllama.f16.gguf --model-out tinyllama-my-model.gguf --lora my-model/ggml-adapter-model.bin

Lastly we quantize the merged model

quantize tinyllama-my-model.gguf tinyllama-my-model.Q8_0.gguf Q8_0
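
Steps 3 to 5 chain several file names together, so a typo in one name breaks the next command. The sketch below only assembles the same four command lines from three names. build_conversion_commands is a hypothetical helper, nothing is executed, and the tool names are taken as written in this guide (depending on how you built llama.cpp you may need to call them as ./export-lora and ./quantize).

```python
import shlex

def build_conversion_commands(base_dir, lora_dir, merged_name, quant="Q8_0"):
    """Return the conversion, merge and quantize commands as argument lists."""
    base_gguf = f"{base_dir}.f16.gguf"
    adapter_bin = f"{lora_dir}/ggml-adapter-model.bin"  # written by convert-lora-to-ggml.py
    merged_gguf = f"{merged_name}.gguf"
    quantized_gguf = f"{merged_name}.{quant}.gguf"
    return [
        ["python", "convert.py", base_dir, "--outtype", "f16", "--outfile", base_gguf],
        ["python", "convert-lora-to-ggml.py", lora_dir],
        ["export-lora", "--model-base", base_gguf,
         "--model-out", merged_gguf, "--lora", adapter_bin],
        ["quantize", merged_gguf, quantized_gguf, quant],
    ]

commands = build_conversion_commands("tinyllama", "my-model", "tinyllama-my-model")
for argv in commands:
    print(shlex.join(argv))  # printed only; run them yourself from the llama.cpp folder
```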

6. Create ollama Modelfile

FROM tinyllama-my-model.gguf

### Set the system message
SYSTEM """
You are a super helpful helper.
"""

PARAMETER stop <s>
PARAMETER stop </s>
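
If you prefer to generate the Modelfile from a script (for example, to switch between the merged and the quantized GGUF file), a small sketch follows. build_modelfile is a hypothetical helper that only reproduces the text shown above; it does not call Ollama.

```python
def build_modelfile(gguf_path, system_message, stop_tokens):
    """Return Modelfile text with a FROM line, a SYSTEM block and stop parameters."""
    lines = [
        f"FROM {gguf_path}",
        "",
        "### Set the system message",
        'SYSTEM """',
        system_message,
        '"""',
        "",
    ]
    lines += [f"PARAMETER stop {token}" for token in stop_tokens]
    return "\n".join(lines) + "\n"

modelfile_text = build_modelfile(
    "tinyllama-my-model.gguf",
    "You are a super helpful helper.",
    ["<s>", "</s>"],
)
print(modelfile_text)
# To save it next to the GGUF file:
# with open("Modelfile", "w") as f:
#     f.write(modelfile_text)
```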

7. Create the Ollama model from the Modelfile

ollama create my-model -f Modelfile

8. Test command

ollama run my-model "<s>\nQ: \nWhat is the capital of France?\nA:\n"