In [3]:
import os
from getpass import getpass

# Directory where all HF files (models, tokenizers, caches, etc.) should live.
# Leave empty to keep whatever HF_HOME / default location is already in effect;
# exporting an empty string as HF_HOME is not a usable cache location.
HF_HOME_DIR = ""
if HF_HOME_DIR:
    os.environ["HF_HOME"] = HF_HOME_DIR

# Never hardcode the access token in the notebook: take it from the environment
# when it is already exported, otherwise ask for it without echoing it.
# NOTE(review): a token was previously committed in this cell; treat it as
# compromised and revoke it in the Hugging Face account settings.
if not os.environ.get("HF_TOKEN"):
    os.environ["HF_TOKEN"] = getpass("Hugging Face access token: ")

inference

In [1]:
#!/usr/bin/env python
"""
Stable inference for Gemma-3-12B-IT + AG-News LoRA
--------------------------------------------------
Weights   : 4-bit NF4 (base)  +  float32 (LoRA)
Compute   : float32
Hardware  : RTX 3090 GPUs (24 GiB each); this cell pins devices 2 and 3
Returns   : 'World' | 'Sports' | 'Business' | 'Sci/Tech'
"""

# 0 ─ Environment -----------------------------------------------------------
import os

# Pin the visible GPUs *before* torch is imported, so the device mask is
# already in place whenever CUDA gets initialised.  Edit the list to match the
# devices available on your machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

import torch

# Allow TF32 kernels for float32 matmuls and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# 1 ─ Libraries -------------------------------------------------------------
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# 2 ─ Constants -------------------------------------------------------------
# Identifier of the base checkpoint handed to `from_pretrained`.
BASE_ID = "google/gemma-3-12b-it"

# Location of the LoRA adapter handed to `PeftModel.from_pretrained`.
ADAPTER = "gemma3-agnews-lora/adapter"

# Class names; the list position is the index used when decoding a prediction.
LABELS = [
    "World",
    "Sports",
    "Business",
    "Sci/Tech",
]

# 3 ─ 4-bit base model (NF4-quantised weights, fp32 compute) ----------------
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit              = True,
    bnb_4bit_quant_type       = "nf4",
    bnb_4bit_use_double_quant = True,
    bnb_4bit_compute_dtype    = torch.float32,   # quantised layers compute in fp32
)

base = AutoModelForCausalLM.from_pretrained(
    BASE_ID,
    # dtype requested for the load; with `quantization_config` set, the
    # quantised weights themselves are stored in 4-bit NF4, not fp32.
    torch_dtype         = torch.float32,
    attn_implementation = "eager",
    device_map          = "auto",
    quantization_config = bnb_cfg,
    # `trust_remote_code=True` was dropped: it permits running Python code
    # shipped inside the model repository, which this load should not need.
    # NOTE(review): confirm the installed transformers release supports this
    # checkpoint natively before relying on that.
)

# Attach the LoRA adapter to the quantised base and switch to inference mode.
# NOTE(review): `torch_dtype` is passed through as an extra keyword here;
# verify the installed peft version actually honours it.
model = PeftModel.from_pretrained(base, ADAPTER, torch_dtype=torch.float32)
model.eval()

# Reports the dtype of the first parameter only.
print("✓ model loaded   (param dtype =", next(model.parameters()).dtype, ")")

# 4 ─ Tokeniser & label-IDs -----------------------------------------------
tok = AutoTokenizer.from_pretrained(BASE_ID)
tok.pad_token = tok.eos_token   # reuse EOS as the padding token


def _first_token_id(label):
    """Return the id of the first token `tok` produces for `label`.

    Raises:
        ValueError: If the label tokenises to an empty sequence.
    """
    piece_ids = tok(label, add_special_tokens=False).input_ids
    if not piece_ids:
        raise ValueError(f"Label {label!r} tokenises to an empty sequence.")
    return piece_ids[0]


# Every label is represented by its *first* token only, so those ids must be
# pairwise distinct -- otherwise two classes would share one logit and could
# never be told apart.
LABEL_IDS = [_first_token_id(lbl) for lbl in LABELS]
if len(set(LABEL_IDS)) != len(LABELS):
    raise ValueError(
        f"Labels {LABELS} do not map to distinct first-token ids: {LABEL_IDS}"
    )

# 5 ─ Classifier -----------------------------------------------------------
@torch.no_grad()
def classify_article(text: str) -> str:
    """Classify a news article into one of the entries of ``LABELS``.

    The article is wrapped in a system + user chat prompt, the model is run
    once, and the logits at the last position are read for the token ids in
    ``LABEL_IDS``; the label whose id has the largest logit is returned.

    Args:
        text: Article text to classify.

    Returns:
        The winning entry of ``LABELS``.

    Raises:
        ValueError: If ``text`` is not a string or contains only whitespace.
        RuntimeError: If any of the label logits is NaN.
    """
    # Reject unusable input before spending a forward pass on it.
    if not isinstance(text, str) or not text.strip():
        raise ValueError(
            "text must be a non-empty string containing the article to classify."
        )

    system_msg = {"role":"system","content":[{"type":"text",
                  "text":"You are a helpful assistant. Answer with exactly one label "
                         "from [World, Sports, Business, Sci/Tech]."}]}
    user_msg = {"role":"user","content":[{"type":"text",
                "text":f"Classify the following news article:\n\n{text}"}]}

    bundle = tok.apply_chat_template(
        [system_msg, user_msg], tokenize=True, add_generation_prompt=True,
        return_dict=True, return_tensors="pt"
    )
    # Move every input tensor to the device holding the model's first parameter.
    device = next(model.parameters()).device
    bundle = {k: v.to(device) for k, v in bundle.items()}

    # Only the logits at the final position (the next token to generate) matter.
    next_token_logits = model(**bundle).logits[:, -1, :]          # [1, vocab]
    label_logits = next_token_logits[0, LABEL_IDS]                # raw logits, not probabilities
    if torch.isnan(label_logits).any():
        raise RuntimeError(
            "Model produced NaN logits for the label tokens; check the adapter "
            "weights and the compute dtype."
        )

    return LABELS[label_logits.argmax().item()]

# 6 ─ Demo -----------------------------------------------------------------
# Smoke test: classify one hard-coded headline and print the predicted label.
if __name__ == "__main__":
    art = (
        "Nvidia’s quarterly revenue soared 265 % year-on-year "
        "thanks to AI demand."
    )
    prediction = classify_article(art)
    print("Prediction:", prediction)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

✓ model loaded   (param dtype = torch.float32 )
Prediction: Sci/Tech


In [3]:
# 5 ─ Classifier -----------------------------------------------------------
# NOTE(review): this definition duplicates the `classify_article` from the
# inference cell above and silently shadows it; consider deleting this copy and
# simply calling the existing function.
@torch.no_grad()
def classify_article(text: str) -> str:
    """Classify a news article into one of the entries of ``LABELS``.

    The article is wrapped in a system + user chat prompt, the model is run
    once, and the logits at the last position are read for the token ids in
    ``LABEL_IDS``; the label whose id has the largest logit is returned.

    Args:
        text: Article text to classify.

    Returns:
        The winning entry of ``LABELS``.

    Raises:
        ValueError: If ``text`` is not a string or contains only whitespace.
        RuntimeError: If any of the label logits is NaN.
    """
    # Reject unusable input before spending a forward pass on it.
    if not isinstance(text, str) or not text.strip():
        raise ValueError(
            "text must be a non-empty string containing the article to classify."
        )

    system_msg = {"role":"system","content":[{"type":"text",
                  "text":"You are a helpful assistant. Answer with exactly one label "
                         "from [World, Sports, Business, Sci/Tech]."}]}
    user_msg = {"role":"user","content":[{"type":"text",
                "text":f"Classify the following news article:\n\n{text}"}]}

    bundle = tok.apply_chat_template(
        [system_msg, user_msg], tokenize=True, add_generation_prompt=True,
        return_dict=True, return_tensors="pt"
    )
    # Move every input tensor to the device holding the model's first parameter.
    device = next(model.parameters()).device
    bundle = {k: v.to(device) for k, v in bundle.items()}

    # Only the logits at the final position (the next token to generate) matter.
    next_token_logits = model(**bundle).logits[:, -1, :]          # [1, vocab]
    label_logits = next_token_logits[0, LABEL_IDS]                # raw logits, not probabilities
    if torch.isnan(label_logits).any():
        raise RuntimeError(
            "Model produced NaN logits for the label tokens; check the adapter "
            "weights and the compute dtype."
        )

    return LABELS[label_logits.argmax().item()]

# 6 ─ Demo -----------------------------------------------------------------
# Second smoke test: a very short, business-flavoured input.
if __name__ == "__main__":
    art = "Economy is booming"
    prediction = classify_article(art)
    print("Prediction:", prediction)

Prediction: Business


test & debug

UI

In [1]:
#!/usr/bin/env python
"""
Gradio web demo for Gemma‑3‑12B‑IT + AG‑News LoRA classifier
-----------------------------------------------------------
Weights   : 4‑bit NF4 (base)  +  float32 (LoRA)
Hardware  : 3× RTX 3090 (24 GiB each)
Returns   : one of 4 labels with probabilities

Run:
    conda activate science   # or your venv
    pip install gradio transformers peft bitsandbytes datasets accelerate "torch>=2.2"
    CUDA_VISIBLE_DEVICES=1,2,3 python gemma_agnews_gradio.py
"""

# 0 ─ Environment -----------------------------------------------------------
import os

# Keep a CUDA_VISIBLE_DEVICES value supplied by the caller and fall back to
# GPU 0 otherwise.  This runs *before* torch is imported, so the device mask is
# already in place whenever CUDA gets initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")          # edit if needed

import torch
import gradio as gr

# Allow TF32 kernels for float32 matmuls and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# 1 ─ Libraries -------------------------------------------------------------
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# 2 ─ Constants -------------------------------------------------------------
# Identifier of the base checkpoint handed to `from_pretrained`.
BASE_ID = "google/gemma-3-12b-it"

# Location of the LoRA adapter handed to `PeftModel.from_pretrained`.
ADAPTER = "gemma3-agnews-lora/adapter"

# Class names; the list order is the order in which probabilities are reported.
LABELS = [
    "World",
    "Sports",
    "Business",
    "Sci/Tech",
]

# 3 ─ 4‑bit base model (NF4-quantised weights, fp32 compute) ----------------
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit              = True,
    bnb_4bit_quant_type       = "nf4",
    bnb_4bit_use_double_quant = True,
    bnb_4bit_compute_dtype    = torch.float32,   # quantised layers compute in fp32
)

print("Loading Gemma‑3‑12B base …")
base = AutoModelForCausalLM.from_pretrained(
    BASE_ID,
    # dtype requested for the load; with `quantization_config` set, the
    # quantised weights themselves are stored in 4-bit NF4, not fp32.
    torch_dtype         = torch.float32,
    attn_implementation = "eager",
    device_map          = "auto",
    quantization_config = bnb_cfg,
    # `trust_remote_code=True` was dropped: it permits running Python code
    # shipped inside the model repository, which this load should not need.
    # NOTE(review): confirm the installed transformers release supports this
    # checkpoint natively before relying on that.
)

# Attach the LoRA adapter to the quantised base and switch to inference mode.
# NOTE(review): `torch_dtype` is passed through as an extra keyword here;
# verify the installed peft version actually honours it.
model = PeftModel.from_pretrained(base, ADAPTER, torch_dtype=torch.float32)
model.eval()

# Reports the device of the first parameter only; with device_map="auto" other
# parameters may be placed elsewhere.
print("✓ Model & LoRA adapter ready →", next(model.parameters()).device)

# 4 ─ Tokeniser & label‑IDs --------------------------------------------------
tok = AutoTokenizer.from_pretrained(BASE_ID)
tok.pad_token = tok.eos_token   # reuse EOS as the padding token


def _first_token_id(label):
    """Return the id of the first token `tok` produces for `label`.

    Raises:
        ValueError: If the label tokenises to an empty sequence.
    """
    piece_ids = tok(label, add_special_tokens=False).input_ids
    if not piece_ids:
        raise ValueError(f"Label {label!r} tokenises to an empty sequence.")
    return piece_ids[0]


# Every label is represented by its *first* token only, so those ids must be
# pairwise distinct -- otherwise two classes would share one logit and always
# receive identical probabilities.
LABEL_IDS = [_first_token_id(lbl) for lbl in LABELS]
if len(set(LABEL_IDS)) != len(LABELS):
    raise ValueError(
        f"Labels {LABELS} do not map to distinct first-token ids: {LABEL_IDS}"
    )

# 5 ─ Classifier -------------------------------------------------------------
@torch.no_grad()
def classify_article(text: str) -> dict:
    """Return ``{label: probability}`` for a news article (for ``gradio.Label``).

    The article is wrapped in a system + user chat prompt, the model is run
    once, and a softmax is taken over the last-position logits of the token
    ids in ``LABEL_IDS`` only, so the reported values sum to 1 across the
    entries of ``LABELS``.

    Args:
        text: Article text pasted by the user.

    Returns:
        A dict mapping every entry of ``LABELS`` to a float probability.

    Raises:
        gr.Error: If the input is missing or shorter than 10 characters after
            stripping, or if the model produces NaN logits for the labels.
    """
    # Raise a UI error instead of returning the message as a plain string:
    # a string handed to the Label output would be shown as if it were the
    # predicted class.
    if not text or len(text.strip()) < 10:
        raise gr.Error("Please paste a news article ≥ 10 characters.")

    system_msg = {"role":"system","content":[{"type":"text",
                  "text":"You are a helpful assistant. Answer with exactly one label "
                         "from [World, Sports, Business, Sci/Tech]."}]}
    user_msg = {"role":"user","content":[{"type":"text",
                "text":f"Classify the following news article:\n\n{text}"}]}

    bundle = tok.apply_chat_template(
        [system_msg, user_msg], tokenize=True, add_generation_prompt=True,
        return_dict=True, return_tensors="pt"
    )
    # Move every input tensor to the device holding the model's first parameter.
    device = next(model.parameters()).device
    bundle = {k: v.to(device) for k, v in bundle.items()}

    # Only the logits at the final position (the next token to generate) matter.
    next_token_logits = model(**bundle).logits[:, -1, :]          # [1, vocab]
    label_logits = next_token_logits[0, LABEL_IDS]
    if torch.isnan(label_logits).any():
        # NaN scores would otherwise be displayed as meaningless probabilities.
        raise gr.Error("The model returned NaN scores for this input.")
    probs = torch.softmax(label_logits, dim=-1)

    # Build dict for gradio.Label
    return {lbl: float(p) for lbl, p in zip(LABELS, probs)}

# 6 ─ Gradio Interface -------------------------------------------------------
# Input widget: a 12-line text box for the article.
_article_box = gr.Textbox(lines=12, label="Paste a news article")

# Output widget: shows all four classes together with their scores.
_label_view = gr.Label(num_top_classes=4, label="Predicted class")

# One-click example inputs offered below the text box.
_example_rows = [
    ["Nvidia’s quarterly revenue soared 265 % year‑on‑year thanks to AI demand."],
    ["Real Madrid lifted their 15th Champions League trophy after a 2‑0 win."],
]

demo = gr.Interface(
    fn=classify_article,
    inputs=_article_box,
    outputs=_label_view,
    title="Gemma‑3 News Classifier",
    description=(
        "A lightweight 4‑bit Gemma‑3‑12B model fine‑tuned via QLoRA on \n"
        "AG‑News (500‑sample demo). Paste any English news article to get the\n"
        "predicted category (World / Sports / Business / Sci/Tech)."
    ),
    examples=_example_rows,
    cache_examples=False,   # examples are not pre-computed at start-up
)

# 7 ─ Launch -----------------------------------------------------------------
# NOTE(review): "0.0.0.0" presumably makes the demo reachable from other
# machines on the network, and no authentication is configured here -- confirm
# that this exposure is intended.
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)


Loading Gemma‑3‑12B base …


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

✓ Model & LoRA adapter ready → cuda:0
* Running on local URL:  http://0.0.0.0:7860
* To create a public link, set `share=True` in `launch()`.
