To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
</div>

To install Unsloth on your local device, follow [our guide](https://docs.unsloth.ai/get-started/install-and-update). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)


### News


Unsloth's [Docker image](https://hub.docker.com/r/unsloth/unsloth) is here! Start training with no setup & environment issues. [Read our Guide](https://docs.unsloth.ai/new/how-to-train-llms-with-unsloth-and-docker).

[gpt-oss RL](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning) is now supported with the fastest inference & lowest VRAM. Try our [new notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) which creates kernels!

Introducing [Vision](https://docs.unsloth.ai/new/vision-reinforcement-learning-vlm-rl) and [Standby](https://docs.unsloth.ai/basics/memory-efficient-rl) for RL! Train Qwen, Gemma etc. VLMs with GSPO - even faster with less VRAM.

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [1]:
%%capture
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    import torch; v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2
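
The Colab branch of the install cell picks an xformers pin from the installed torch version. A minimal sketch of that selection logic in isolation; the helper name `pick_xformers` is hypothetical (it is not part of Unsloth) and the sample version strings are chosen only for illustration:

```python
import re

def pick_xformers(torch_version: str) -> str:
    """Mirror the install cell: keep the leading digits/dots of the torch version."""
    v = re.match(r"[0-9\.]{3,}", torch_version).group(0)
    # Same rule as the install cell: torch 2.8.0 gets the newer xformers pin.
    return "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")

print(pick_xformers("2.8.0+cu126"))  # local build suffix (+cu126) is dropped by the regex
print(pick_xformers("2.6.0+cu124"))  # any other version falls back to the older pin
```

Note that the comparison is an exact string match, so any torch version other than `2.8.0` falls through to the older pin.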

### Unsloth

If you want to finetune Llama-3 2x faster and use 70% less VRAM, go to our [finetuning notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Alpaca.ipynb)!

In [2]:
from unsloth import FastLanguageModel

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
    "unsloth/gemma-7b-it-bnb-4bit",
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",

] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B-Instruct",
    max_seq_length = 8192,
    load_in_4bit = True,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.9: Fast Llama patching. Transformers: 4.56.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/5.96G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

In [3]:
from transformers import TextStreamer
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
    mapping = {"role" : "from", "content" : "value", "user" : "human", "assistant" : "gpt"}, # ShareGPT style
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096, padding_idx=128004)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((409

Change the "value" part to call the model!

Unsloth makes inference natively 2x faster!! No need to change or do anything!
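
As a rough illustration of what "the value part" means: with the ShareGPT mapping set on the tokenizer, each turn is a dict with a `from` role (`human` or `gpt`) and a `value` text. The helper name `make_sharegpt_turns` and the sample question are hypothetical. The sketch only builds the message list, and assumes you then pass it to `tokenizer.apply_chat_template(..., add_generation_prompt=True)` as the evaluation cells in this notebook do.

```python
def make_sharegpt_turns(user_text, history=()):
    """Build a ShareGPT-style conversation: roles are 'human' and 'gpt'."""
    # Earlier turns come in as (role, text) pairs.
    turns = [{"from": role, "value": text} for role, text in history]
    # The new user message always goes last, so the model answers it.
    turns.append({"from": "human", "value": user_text})
    return turns

messages = make_sharegpt_turns("Is aspirin an NSAID? Answer yes, no, or maybe.")
print(messages)
```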

# Just predict yes, no, or maybe

In [50]:
# PubMedQA: single-token label + Brier from logits (Unsloth one-by-one)

import re, gc, torch
from datasets import load_dataset
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, f1_score, classification_report

DEVICE  = "cuda" if torch.cuda.is_available() else "cpu"
MAX_INP = 768      # lower -> faster prefill
SYSTEM  = "Answer ONLY with one word: yes, no, or maybe."
CLASSES = ["yes","no","maybe"]
CLASS_SET = set(CLASSES)

# ---------- prompt: make the first generated token be the label ----------
def build_messages(q, ctx):
    if isinstance(ctx, list): ctx = " ".join(ctx)
    return [{
        "from": "human",
        "value": (
            f"{SYSTEM}\n\nQuestion: {q}\n\nAbstract:\n{ctx}\n\n"
            "Answer strictly with one word.\n\n"
            "Label: "   # <- next token will be yes|no|maybe
        )
    }]

def prompt_text(row):
    return tokenizer.apply_chat_template(
        build_messages(row["question"], row["context"]),
        tokenize=False, add_generation_prompt=True
    )

# ---------- tiny helpers ----------
def parse_label_token(tok_text: str) -> str:
    lab = tok_text.strip().strip(",.?:;!").lower()
    return lab if lab in CLASS_SET else "maybe"

def _first_token_ids(strings):
    out = []
    for s in strings:
        ids = tokenizer(s, add_special_tokens=False).input_ids
        if ids: out.append(ids[0])
    return out

# include space/no-space + case variants (Llama tokenizers often use leading-space tokens)
CAND_IDS = {
    "yes":   _first_token_ids([" yes","Yes","yes"]),
    "no":    _first_token_ids([" no","No","no"]),
    "maybe": _first_token_ids([" maybe","Maybe","maybe"]),
}

def probs_from_first_step_logits(out_struct):
    logits = out_struct.scores[0][0]    # (vocab,)
    pv = torch.softmax(logits, dim=-1)
    mass = {
        lab: float(pv[torch.tensor(ids, device=pv.device)].sum().item()) if ids else 0.0
        for lab, ids in CAND_IDS.items()
    }
    Z = sum(mass.values()) + 1e-12
    return {k: v/Z for k, v in mass.items()}

def brier_multiclass_sum(prob_dict, gold_label, classes=CLASSES):
    # Sum version ranges [0, 2] for 3 classes (0 is perfect)
    return sum((prob_dict[c] - (1.0 if c == gold_label else 0.0))**2 for c in classes)

# ---------- data ----------
ds   = load_dataset("qiaojin/PubMedQA", "pqa_labeled")
pmqa = ds["train"]       # pqa_labeled ships a single "train" split (the 1,000 rows scored below)
N    = len(pmqa)        # set smaller for a smoke test

preds, golds = [], []
brier_probs, brier_vals = [], []

for i in tqdm(range(N), desc="Label-only + Brier (one-by-one)", ncols=100):
    row  = pmqa[i]
    gold = row["final_decision"].lower()
    golds.append(gold)

    prompt = prompt_text(row)
    enc = tokenizer(prompt, return_tensors="pt",
                    padding=False, truncation=True, max_length=MAX_INP).to(DEVICE)

    with torch.inference_mode():
        out = model.generate(
            **enc,
            max_new_tokens=1,            # exactly the label token
            do_sample=False,             # greedy decoding; temperature is ignored when not sampling
            use_cache=False,             # lower KV mem
            pad_token_id=tokenizer.eos_token_id,
            output_scores=True,          # <- logits for first step
            return_dict_in_generate=True
        )

    # decode ONLY the new token
    new_tok = out.sequences[0, enc.input_ids.shape[1]:]
    label_text = tokenizer.decode(new_tok, skip_special_tokens=True)
    pred = parse_label_token(label_text)
    preds.append(pred)

    # probs -> Brier
    probs = probs_from_first_step_logits(out)
    brier_probs.append(probs)
    brier_vals.append(brier_multiclass_sum(probs, gold))

    del enc, out, new_tok
    if (i+1) % 100 == 0 and torch.cuda.is_available():
        torch.cuda.empty_cache(); gc.collect()

# ---------- metrics ----------
print(f"\nAccuracy:  {accuracy_score(golds, preds):.4f}")
print(f"Macro-F1:  {f1_score(golds, preds, average='macro'):.4f}\n")
print(classification_report(golds, preds, digits=4))

print(f"\nMean Brier (sum, 0–2): {sum(brier_vals)/len(brier_vals):.6f}")

# peek a few
for j in range(min(5, N)):
    print(f"[{j:03d}] gold={golds[j]:<6} pred={preds[j]:<6}  probs={brier_probs[j]}")


Label-only + Brier (one-by-one):   0%|                                     | 0/1000 [00:00<?, ?it/s]


Accuracy:  0.7550
Macro-F1:  0.5323

              precision    recall  f1-score   support

       maybe     0.0833    0.0091    0.0164       110
          no     0.8292    0.6893    0.7528       338
         yes     0.7369    0.9438    0.8276       552

    accuracy                         0.7550      1000
   macro avg     0.5498    0.5474    0.5323      1000
weighted avg     0.6962    0.7550    0.7131      1000


Mean Brier (sum, 0–2): 0.379200
[000] gold=yes    pred=yes     probs={'yes': 0.9913655215201216, 'no': 0.006861966567378376, 'maybe': 0.0017725119114997648}
[001] gold=no     pred=yes     probs={'yes': 0.6560849789535165, 'no': 0.31466205518141327, 'maybe': 0.029252965864069204}
[002] gold=yes    pred=yes     probs={'yes': 0.9299441903633273, 'no': 0.054342478162795284, 'maybe': 0.01571333147287286}
[003] gold=no     pred=no      probs={'yes': 0.2179722313618752, 'no': 0.6223282827320895, 'maybe': 0.15969948590503305}
[004] gold=yes    pred=yes     probs={'yes': 0.9763041
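
To make the Brier numbers easier to interpret, here is a small standalone sketch of the two steps the cell performs after the forward pass: renormalising the probability mass that landed on the label tokens, then scoring it against the gold label. The `raw` masses are hypothetical values, not taken from the run above, and the function names are illustrative.

```python
CLASSES = ["yes", "no", "maybe"]

def renormalise(mass):
    """Rescale per-label probability mass so the three labels sum to 1."""
    z = sum(mass.values()) + 1e-12  # epsilon guards against all-zero mass
    return {k: v / z for k, v in mass.items()}

def brier_sum(probs, gold):
    """Sum-form multiclass Brier score: 0 is perfect, 2 is the worst case."""
    return sum((probs[c] - (1.0 if c == gold else 0.0)) ** 2 for c in CLASSES)

# Hypothetical softmax mass on the label tokens; the rest of the vocabulary holds the remainder.
raw = {"yes": 0.60, "no": 0.30, "maybe": 0.05}
probs = renormalise(raw)
print(probs)
print(round(brier_sum(probs, "no"), 4))

# Boundary cases: a confident correct answer and a confident wrong one.
print(brier_sum({"yes": 1.0, "no": 0.0, "maybe": 0.0}, "yes"))  # 0.0
print(brier_sum({"yes": 1.0, "no": 0.0, "maybe": 0.0}, "no"))   # 2.0
```

Because the score is computed on renormalised probabilities, it ignores any mass the model put on tokens outside the three labels.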

# Predict + compute readability, Brier score, and ROUGE-1

In [16]:
pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... done
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=4e3a81ad309d5a461f21fbf1f16a2fb1b04b8a3781f7f9a8f894aec72f7e250e
  Stored in directory: /root/.cache/pip/wheels/85/9d/af/01feefbe7d55ef5468796f0c68225b6788e85d9d0a281e7a70
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [52]:
import re, gc, torch
from datasets import load_dataset
from tqdm.auto import tqdm
from statistics import mean

# Optional deps (install once in Colab if missing)
try:
    from evaluate import load as load_metric
except Exception:
    !pip -q install evaluate
    from evaluate import load as load_metric

try:
    import textstat
except Exception:
    !pip -q install textstat
    import textstat

DEVICE  = "cuda" if torch.cuda.is_available() else "cpu"
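
For orientation, the two Flesch scores that `textstat` reports are simple functions of word, sentence, and syllable counts. The sketch below applies the standard published formulas to counts you supply. The function names and sample counts are made up for illustration, and `textstat`'s own syllable counting and rounding may give slightly different numbers on real text.

```python
def flesch_reading_ease(words, sentences, syllables):
    """Higher is easier; scores around 30-50 read as college-level text."""
    return 206.835 - 1.015 * (words / sentences) - 84.6 * (syllables / words)

def flesch_kincaid_grade(words, sentences, syllables):
    """Approximate US school grade needed to follow the text."""
    return 0.39 * (words / sentences) + 11.8 * (syllables / words) - 15.59

# Hypothetical counts: 100 words, 5 sentences, 170 syllables.
print(round(flesch_reading_ease(100, 5, 170), 2))
print(round(flesch_kincaid_grade(100, 5, 170), 2))
```

Both scores depend only on average sentence length and average syllables per word, so shorter sentences and plainer vocabulary move them toward "easier".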


In [59]:
# ==== PubMedQA rationale-only eval (ROUGE-1 + readability) ====
# Assumes:
#   - model, tokenizer already loaded (Unsloth)
#   - tokenizer has chat template set (e.g., get_chat_template(..., "llama-3.1"))
#   - FastLanguageModel.for_inference(model) already called



# ---- knobs (no abstract shrinking) ----
MAX_INP = 1024    # cap for (prompt + abstract) tokens
MAX_NEW = 160     # generation budget for rationale

# ---- Single-pass prompt: ONLY rationale required ----
INSTR = (
    "You are answering PubMedQA. "
    "Write a concise explanation in plain language based only on the abstract. "
    "End with: 'This is not medical advice.'\n\n"
    "Return answers in this EXACT format:\n"
    "Reason:\n"
    "<your explanation>"
)

def build_messages(q, ctx):
    if isinstance(ctx, list): ctx = " ".join(ctx)
    return [{
        "from": "human",
        "value": f"{INSTR}\n\nQuestion: {q}\n\nAbstract:\n{ctx}\n\nReason:\n"
    }]

def apply_tpl(msgs):
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

# ---- helpers to clean and extract the rationale ----
ASSIST_RE = re.compile(
    r'^(?:<\|assistant\|>|<\|start_header_id\|>\s*assistant\s*<\|end_header_id\|>|assistant:?)[\s\r\n]*',
    re.IGNORECASE
)
def strip_assistant_header(text: str) -> str:
    text = text.lstrip()
    text = ASSIST_RE.sub("", text)
    lines = [ln.strip() for ln in text.splitlines()]
    while lines and re.fullmatch(r'(?:assistant:?|<\|assistant\|>)', lines[0], re.IGNORECASE):
        lines.pop(0)
    return "\n".join(lines).strip()

def parse_reason(out_text: str) -> str:
    t = strip_assistant_header(out_text)
    m = re.search(r'(?mi)^Reason:\s*(.*)$', t, flags=re.DOTALL)
    if m:
        return m.group(1).strip()
    # fallback: everything after the last "Reason:"
    cut = t.lower().rfind("reason:")
    return t[cut+len("reason:"):].strip() if cut != -1 else t

# ---- load data ----
ds   = load_dataset("qiaojin/PubMedQA", "pqa_labeled")
pmqa = ds["train"]      # pqa_labeled ships a single "train" split
N    = len(pmqa)        # set smaller for a smoke test, e.g., N = 100

# ---- prebuild prompts (saves a little time) ----
prompts = [apply_tpl(build_messages(ex["question"], ex["context"])) for ex in pmqa]

# ---- generation loop ----
refs_long, hyps_long = [], []
for i in tqdm(range(N), desc="Generating rationales", ncols=100):
    ex = pmqa[i]
    refs_long.append((ex.get("long_answer") or "").strip())

    p = prompts[i]
    enc = tokenizer(
        p, return_tensors="pt",
        padding=False, truncation=True, max_length=MAX_INP
    ).to(DEVICE)

    with torch.inference_mode():
        out = model.generate(
            **enc,
            max_new_tokens=MAX_NEW,
            do_sample=False,             # greedy decoding; temperature is ignored when not sampling
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
        )

    # decode only new tokens
    new_tokens = out.sequences[0, enc.input_ids.shape[1]:]
    gen_text   = tokenizer.decode(new_tokens, skip_special_tokens=True)
    rationale  = parse_reason(gen_text)
    hyps_long.append(rationale)

    # free per-iteration
    del enc, out, new_tokens
    if (i+1) % 100 == 0 and torch.cuda.is_available():
        torch.cuda.empty_cache(); gc.collect()

# ---- evaluation: ROUGE-1 + readability ----
rouge = load_metric("rouge")
r = rouge.compute(predictions=hyps_long, references=refs_long, use_stemmer=True)
print(f"\nROUGE-1 (rationales vs gold long_answer), n={len(hyps_long)}: {float(r['rouge1']):.6f}")

fre  = [textstat.flesch_reading_ease(h) for h in hyps_long]
fk   = [textstat.flesch_kincaid_grade(h) for h in hyps_long]
smog = [textstat.smog_index(h) for h in hyps_long]
print("\nReadability of generated rationales (mean):")
print(f"  Flesch Reading Ease:   {mean(fre):.2f}")
print(f"  Flesch-Kincaid Grade:  {mean(fk):.2f}")
print(f"  SMOG Index:            {mean(smog):.2f}")

# ---- show first 3 model answers for sanity ----
for j in range(min(3, N)):
    ex  = pmqa[j]
    q   = ex["question"]
    ref = refs_long[j]
    hyp = hyps_long[j]
    print(f"\n[{j:03d}]")
    print("Q:", q)
    print("Gold (first 220ch):", (ref[:220] + "…") if len(ref) > 220 else ref)
    print("Hyp  (first 220ch):", (hyp[:220] + "…") if len(hyp) > 220 else hyp)


Generating rationales:   0%|                                                | 0/100 [00:00<?, ?it/s]


ROUGE-1 (rationales vs gold long_answer), n=100: 0.318035

Readability of generated rationales (mean):
  Flesch Reading Ease:   36.79
  Flesch-Kincaid Grade:  12.71
  SMOG Index:            14.38

[000]
Q: Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?
Gold (first 220ch): Results depicted mitochondrial dynamics in vivo as PCD progresses within the lace plant, and highlight the correlation of this organelle with other organelles during developmental PCD. To the best of our knowledge, this …
Hyp  (first 220ch): Mitochondria play a role in remodelling lace plant leaves during programmed cell death by changing their dynamics and distribution as the cells undergo cell death. This process involves a gradient of mitochondrial change…

[001]
Q: Landolt C and snellen e acuity: differences in strabismus amblyopia?
Gold (first 220ch): Using the charts described, there was only a slight overestimation of visual acuity by the Snellen E compared to th
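
ROUGE-1 measures unigram overlap between a generated text and a reference. The sketch below is a simplified stand-in (lowercased alphanumeric tokens, no stemming), so its numbers will not exactly match the `evaluate` metric used in the cell, which was run with `use_stemmer=True`. The function name and sample sentences are illustrative only.

```python
import re
from collections import Counter

def rouge1_f1(hypothesis, reference):
    """Unigram-overlap F1 between two strings (simplified ROUGE-1)."""
    tok = lambda s: re.findall(r"[a-z0-9]+", s.lower())
    hyp, ref = Counter(tok(hypothesis)), Counter(tok(reference))
    # Counter intersection keeps the minimum count of each shared token.
    overlap = sum((hyp & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(hyp.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

score = rouge1_f1("mitochondria change during cell death",
                  "mitochondria dynamics change as cell death progresses")
print(round(score, 4))
```

Since the metric rewards shared words rather than shared meaning, a rationale can paraphrase the gold `long_answer` well and still score low.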

In [56]:
ref

'"Aquagenic maladies" could be a pediatric form of the aquagenic urticaria.'

And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!

Some other links:
1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

<div class="align-center">
  <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
  <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
  <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>

  Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐

  This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme)
</div>
