# Phi-3-mini PTQ + QAT ternary demo

This notebook shows a compact, reproducible workflow for measuring the compression and latency
effects of balanced ternary quantization on a real LLM checkpoint. It runs on CPU (in FP32) and
automatically switches to a CUDA GPU (in bfloat16) when one is available.

**Model:** `microsoft/Phi-3-mini-4k-instruct` (3.8B params)

**Goals:**
- Measure baseline (BF16 on GPU, FP32 on CPU) size, perplexity, and latency.
- Apply PTQ via `t81` and compare size + speed.
- (Optional) run a short QAT loop for quality retention.


## Results summary (fill after a run)

| Mode | Size (GiB) | Compression | Perplexity | Tok/s |
| --- | --- | --- | --- | --- |
| BF16/FP32 | - | - | - | - |
| PTQ (ternary) | - | - | - | - |
| QAT (ternary) | - | - | - | - |

Phi-3 GGUF benchmark (llama.cpp, TQ1_0, CPU-only):

| Bundle | Size (MiB) | Peak RSS (MiB) | Prompt ms/token | Prompt tok/s | Eval ms/token | Eval tok/s |
| --- | --- | --- | --- | --- | --- | --- |
| phi3-tq1-fixed12.gguf | 1481.96 | 2260.02 | 54.35 | 18.40 | 56.22 | 17.79 |
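
The tok/s columns are the reciprocal of llama.cpp's per-token timings (1000 ms divided by ms/token). A small helper like the one below, a sketch rather than part of `t81` or llama.cpp, reproduces the throughput figures in the table from its ms/token columns.

```python
def ms_per_token_to_tok_per_sec(ms_per_token: float) -> float:
    """Convert a per-token latency in milliseconds to throughput in tokens/second."""
    if ms_per_token <= 0:
        raise ValueError("ms_per_token must be positive")
    return 1000.0 / ms_per_token


# Timings taken from the phi3-tq1-fixed12.gguf row above.
for label, ms in [("prompt", 54.35), ("eval", 56.22)]:
    print(f"{label}: {ms_per_token_to_tok_per_sec(ms):.2f} tok/s")
```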

Small-model baseline (GPT-2, CPU, tiny eval/QAT):

| Mode | Size (GiB) | Compression | Perplexity | Tok/s |
| --- | --- | --- | --- | --- |
| BF16/FP32 | 0.46 | - | 163.64 | 34.81 |
| PTQ (ternary) | 0.01 | 38.7x | nan | 0.04 |
| QAT (ternary) | 0.01 | 38.7x | nan | 0.07 |
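
The Compression column divides baseline parameter bytes by packed bytes. The notebook's `ternary_weight_bytes` helper sizes each converted layer as `rows * ceil(cols / 48) * 16` bytes, i.e. 48 trits per 16-byte limb. The sketch below applies that same formula to a single weight matrix; the 3072 x 3072 shape is illustrative only, and the packing constants are taken from that helper rather than from `t81` documentation.

```python
def packed_ternary_bytes(rows: int, cols: int, trits_per_limb: int = 48, bytes_per_limb: int = 16) -> int:
    """Packed size of a rows x cols ternary matrix, using the notebook's limb formula."""
    limbs_per_row = -(-cols // trits_per_limb)  # ceiling division
    return rows * limbs_per_row * bytes_per_limb


def dense_bytes(rows: int, cols: int, bytes_per_elem: int) -> int:
    """Size of the same matrix stored densely (4 bytes for FP32, 2 for BF16)."""
    return rows * cols * bytes_per_elem


rows, cols = 3072, 3072  # illustrative shape
packed = packed_ternary_bytes(rows, cols)
print(f"packed: {packed} bytes")
for name, width in [("FP32", 4), ("BF16", 2)]:
    print(f"{name}: {dense_bytes(rows, cols, width) / packed:.1f}x")
```

At this density each weight costs 128/48 ≈ 2.67 bits, so the per-matrix ceiling is 12x against FP32 and 6x against BF16 (lower once biases and row padding are included). A whole-model ratio above those figures means some parameters are missing from the packed total.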


## Setup

Install dependencies (run once):

```bash
pip install ".[torch]" transformers datasets accelerate
```

Notes:
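- CPU runs work without a GPU, but the model is loaded in FP32 there, so the machine needs enough
  RAM for the full-precision weights; expect much longer runtimes.
- If a CUDA GPU is available, the same cells use it automatically (in bfloat16) and run faster.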
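
To check whether a machine can hold the checkpoint, a rough lower bound is parameter count times bytes per parameter. The sketch below uses the 3.8B figure from the **Model** line; activations, the KV cache, and framework overhead add to it, so treat the result as a floor rather than a requirement.

```python
def weights_gib(num_params: float, bytes_per_param: int) -> float:
    """Memory for the raw weights alone, in GiB (1 GiB = 1024**3 bytes)."""
    return num_params * bytes_per_param / 1024**3


PHI3_MINI_PARAMS = 3.8e9  # approximate count from the Model line
for dtype, width in [("FP32 (CPU path)", 4), ("BF16 (GPU path)", 2)]:
    print(f"{dtype}: {weights_gib(PHI3_MINI_PARAMS, width):.1f} GiB")
```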


In [None]:
import time
from dataclasses import dataclass

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling

import t81 as t8

MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
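THRESHOLD = 0.45  # Threshold passed to t8.convert and to the QAT arguments.
MAX_EVAL_TOKENS = 2048  # WikiText-2 test tokens used for the perplexity check.
MAX_TOKENS_GENERATE = 128  # New tokens requested per timed generate() call.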
RUN_QAT = False  # Set True to run a short QAT loop.

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

print(f"Device: {DEVICE}, dtype: {DTYPE}")


In [None]:
def bytes_to_gib(num_bytes: int) -> float:
    return num_bytes / (1024 ** 3)


def model_param_bytes(model: torch.nn.Module) -> int:
    return sum(p.numel() * p.element_size() for p in model.parameters())


def ternary_weight_bytes(model: torch.nn.Module) -> int:
    total = 0
    for module in model.modules():
        if isinstance(module, t8.Linear):
            rows, cols = module.weight.shape
            k_limbs = (cols + 47) // 48
            total += rows * k_limbs * 16
            if module.bias is not None:
                total += module.bias.numel() * module.bias.element_size()
    return total


def measure_generate_latency(model, tokenizer, prompt: str, max_new_tokens: int) -> dict:
    inputs = tokenizer(prompt, return_tensors="pt").to(next(model.parameters()).device)
    prompt_len = inputs["input_ids"].shape[-1]
    # Warm-up call so one-time costs (e.g. t81's packing on first inference) are not timed.
    _ = model.generate(**inputs, max_new_tokens=1)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    # Count tokens actually generated: generation can stop early at EOS.
    tokens = output.shape[-1] - prompt_len
    return {"seconds": elapsed, "tokens": tokens, "tok_per_sec": tokens / max(elapsed, 1e-6)}


@torch.inference_mode()
def perplexity(model, tokenizer, dataset, max_tokens: int) -> float:
    model.eval()
    text = "\n\n".join(dataset["text"])
    # Truncate during tokenization instead of tokenizing the whole split and slicing.
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_tokens)
    input_ids = enc["input_ids"].to(next(model.parameters()).device)
    # labels=input_ids makes the model compute the shifted next-token loss internally.
    outputs = model(input_ids=input_ids, labels=input_ids)
    # Perplexity = exp(mean token-level negative log-likelihood).
    return torch.exp(outputs.loss).item()


In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto" if DEVICE == "cuda" else "cpu",
)
model.eval()

print("Loaded model")


In [None]:
baseline_bytes = model_param_bytes(model)
print(f"Baseline param size: {bytes_to_gib(baseline_bytes):.2f} GiB")

prompt = "Summarize balanced ternary quantization in one paragraph."
latency = measure_generate_latency(model, tokenizer, prompt, MAX_TOKENS_GENERATE)
print(f"Baseline latency: {latency['tok_per_sec']:.2f} tok/s")

wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
baseline_ppl = perplexity(model, tokenizer, wikitext, MAX_EVAL_TOKENS)
print(f"Baseline perplexity: {baseline_ppl:.2f}")


## PTQ: swap in ternary layers

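`t81.convert` replaces every `nn.Linear` with `t81.nn.Linear` and caches the packed ternary weights
on the first forward pass. The warm-up call in `measure_generate_latency` triggers that packing
before the timer starts, so the PTQ latency measured below excludes it.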
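
How `threshold` maps float weights to trits is internal to `t81`. For intuition only, here is one common scheme, threshold ternarization in the style of Ternary Weight Networks, which may differ from what `t81.convert` actually does: weights whose magnitude falls below `threshold * mean(|W|)` become 0, the rest become ±1, and a single per-tensor scale `alpha` (the mean magnitude of the surviving weights) restores their size. The `ternarize` function is a hypothetical name.

```python
import torch


def ternarize(weight: torch.Tensor, threshold: float = 0.45) -> tuple[torch.Tensor, torch.Tensor]:
    """TWN-style ternarization; an illustrative assumption, not t81's implementation.

    Returns int8 trits in {-1, 0, +1} and a per-tensor float scale alpha,
    so that alpha * trits approximates `weight`.
    """
    delta = threshold * weight.abs().mean()  # magnitude cutoff
    mask = weight.abs() > delta  # weights that survive as +1 / -1
    trits = (torch.sign(weight) * mask).to(torch.int8)
    # Mean magnitude of surviving weights; clamp avoids dividing by zero if none survive.
    alpha = weight.abs()[mask].sum() / mask.sum().clamp(min=1)
    return trits, alpha


w = torch.tensor([[0.9, -0.05, -0.6], [0.1, 0.0, -1.2]])
trits, alpha = ternarize(w, threshold=0.45)
print(trits.tolist())  # → [[1, 0, -1], [0, 0, -1]]
print(round(alpha.item(), 4))  # → 0.9
```

Here the cutoff is 0.45 × 0.475 ≈ 0.214, so 0.1 and -0.05 are zeroed and `alpha * trits` reconstructs the survivors as ±0.9.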


In [None]:
model = t8.convert(model, threshold=THRESHOLD, inplace=True)
model.eval()

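ternary_bytes = ternary_weight_bytes(model)
# Parameters that stayed in floating point (embeddings, norms, any non-Linear layers).
# Skip every parameter owned by a t8.Linear: its packed weight and bias are counted above.
ternary_param_ids = {
    id(p) for m in model.modules() if isinstance(m, t8.Linear) for p in m.parameters(recurse=False)
}
other_bytes = sum(
    p.numel() * p.element_size() for p in model.parameters() if id(p) not in ternary_param_ids
)
total_bytes = ternary_bytes + other_bytes
print(f"Ternary packed size (weights + biases): {bytes_to_gib(ternary_bytes):.2f} GiB")
print(f"Unconverted params: {bytes_to_gib(other_bytes):.2f} GiB")
print(f"Compression ratio (whole model): {baseline_bytes / max(total_bytes, 1):.1f}x")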

latency_ptq = measure_generate_latency(model, tokenizer, prompt, MAX_TOKENS_GENERATE)
print(f"PTQ latency: {latency_ptq['tok_per_sec']:.2f} tok/s")

ptq_ppl = perplexity(model, tokenizer, wikitext, MAX_EVAL_TOKENS)
print(f"PTQ perplexity: {ptq_ppl:.2f}")


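## Optional QAT (short run)

Set `RUN_QAT = True` to run a short QAT loop: 100 optimizer steps at batch size 1 on the first 1%
of the WikiText-2 training split, followed by the same perplexity and latency checks.
The run is intentionally short to keep runtimes manageable.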
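
QAT typically keeps full-precision "shadow" weights, quantizes them in the forward pass, and passes gradients through the non-differentiable rounding with a straight-through estimator (STE). The sketch below shows that idea in plain PyTorch. `TernarySTE` and `FakeTernaryLinear` are hypothetical names, the quantizer assumes the TWN-style cutoff (threshold times mean |W|), and none of this is claimed to be how `t8.TernaryTrainer` works internally.

```python
import torch
import torch.nn.functional as F


class TernarySTE(torch.autograd.Function):
    """Ternarize in forward; pass gradients straight through in backward (hypothetical)."""

    @staticmethod
    def forward(ctx, weight: torch.Tensor, threshold: float) -> torch.Tensor:
        delta = threshold * weight.abs().mean()
        mask = weight.abs() > delta
        alpha = weight.abs()[mask].sum() / mask.sum().clamp(min=1)
        return alpha * torch.sign(weight) * mask  # values in {-alpha, 0, +alpha}

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # STE: treat quantization as identity for the weight gradient;
        # the float threshold argument gets no gradient.
        return grad_output, None


class FakeTernaryLinear(torch.nn.Linear):
    """Linear layer that trains float weights but computes with their ternary version."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True, threshold: float = 0.45):
        super().__init__(in_features, out_features, bias=bias)
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w_q = TernarySTE.apply(self.weight, self.threshold)
        return F.linear(x, w_q, self.bias)


torch.manual_seed(0)
layer = FakeTernaryLinear(8, 4)
loss = layer(torch.randn(2, 8)).pow(2).sum()
loss.backward()
print(layer.weight.grad.shape)  # → torch.Size([4, 8])
```

Because the backward pass ignores the quantizer, the optimizer updates the float weights directly, and the ternary pattern can change between steps as weights cross the cutoff.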


In [None]:
if RUN_QAT:
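    train_split = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
    # Drop blank lines: they give (near-)empty sequences with no next-token target,
    # which can produce a NaN loss at batch size 1.
    train_split = train_split.filter(lambda x: len(x["text"].strip()) > 0)
    tokenized = train_split.map(
        lambda x: tokenizer(x["text"], truncation=True), batched=True, remove_columns=["text"]
    )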
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    args = t8.TernaryTrainingArguments(
        output_dir="phi3-qat",
        per_device_train_batch_size=1,
        max_steps=100,
        logging_steps=20,
        save_steps=100,
        learning_rate=5e-5,
        ternary_threshold=THRESHOLD,
        ternary_warmup_steps=50,
    )

    trainer = t8.TernaryTrainer(
        model=model,
        args=args,
        train_dataset=tokenized,
        data_collator=data_collator,
    )
    trainer.train()
    model = trainer.model
    model.eval()

    qat_ppl = perplexity(model, tokenizer, wikitext, MAX_EVAL_TOKENS)
    latency_qat = measure_generate_latency(model, tokenizer, prompt, MAX_TOKENS_GENERATE)

    print(f"QAT perplexity: {qat_ppl:.2f}")
    print(f"QAT latency: {latency_qat['tok_per_sec']:.2f} tok/s")


## Summary table

Fill in the values (or capture them programmatically) to compare baseline vs PTQ vs QAT.

| Mode | Size (GiB) | Compression | Perplexity | Tok/s |
| --- | --- | --- | --- | --- |
| BF16/FP32 | - | - | - | - |
| PTQ (ternary) | - | - | - | - |
| QAT (ternary) | - | - | - | - |
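
To capture the table programmatically instead of by hand, a small formatter like the hypothetical `results_table` helper below can render whatever values you collected; `None` prints as `-`, matching the placeholders. The example row reuses the GPT-2 baseline numbers reported earlier.

```python
from typing import Optional, Sequence, Tuple

Row = Tuple[str, Optional[float], Optional[float], Optional[float], Optional[float]]


def _fmt(value: Optional[float], spec: str, suffix: str = "") -> str:
    """Format a number, or return the '-' placeholder for a missing value."""
    return "-" if value is None else f"{value:{spec}}{suffix}"


def results_table(rows: Sequence[Row]) -> str:
    """Render (mode, size_gib, compression, perplexity, tok_per_sec) rows as markdown."""
    lines = [
        "| Mode | Size (GiB) | Compression | Perplexity | Tok/s |",
        "| --- | --- | --- | --- | --- |",
    ]
    for mode, size_gib, compression, ppl, tok_s in rows:
        cells = [
            mode,
            _fmt(size_gib, ".2f"),
            _fmt(compression, ".1f", "x"),
            _fmt(ppl, ".2f"),
            _fmt(tok_s, ".2f"),
        ]
        lines.append("| " + " | ".join(cells) + " |")
    return "\n".join(lines)


print(results_table([
    ("BF16/FP32", 0.46, None, 163.64, 34.81),
    ("PTQ (ternary)", None, None, None, None),
]))
```

In the notebook, the baseline row would come from `bytes_to_gib(baseline_bytes)`, `baseline_ppl`, and `latency["tok_per_sec"]`, and the PTQ and QAT rows from the matching `ptq_*`, `qat_*`, and `latency_*` variables.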


## CLI equivalent

You can reproduce the PTQ conversion + GGUF export via CLI:

```bash
t81 convert microsoft/Phi-3-mini-4k-instruct phi3-t81 --threshold 0.45 --force-cpu-device-map

t81 gguf phi3-tq1.gguf --from-t81 phi3-t81 --quant TQ1_0 --validate
```
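
As an independent sanity check on the exported file, you can read its fixed-size GGUF header directly. The sketch below assumes the layout described in the GGUF specification in the ggml repository for little-endian v2+ files: the 4-byte magic `GGUF`, a uint32 version, then uint64 tensor and metadata key/value counts. `read_gguf_header` is a hypothetical helper, not part of `t81`.

```python
import struct


def read_gguf_header(path: str) -> dict:
    """Read magic, version, tensor count, and metadata KV count from a GGUF file.

    Assumes a little-endian GGUF v2+ file (64-bit counts); v1 used 32-bit counts.
    """
    with open(path, "rb") as f:
        raw = f.read(24)
    if len(raw) < 24 or raw[:4] != b"GGUF":
        raise ValueError(f"{path}: not a GGUF file (bad magic or truncated header)")
    (version,) = struct.unpack("<I", raw[4:8])
    if version < 2:
        raise ValueError(f"{path}: GGUF v{version} header not handled by this sketch")
    tensor_count, metadata_kv_count = struct.unpack("<QQ", raw[8:24])
    return {"version": version, "tensor_count": tensor_count, "metadata_kv_count": metadata_kv_count}
```

After the `t81 gguf` command above, `read_gguf_header("phi3-tq1.gguf")` should return a dict with a non-zero `tensor_count`.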


## CLI follow-up

Use the unified CLI to export a GGUF bundle or inspect available flags:


In [None]:
# Optional: inspect CLI helpers
!t81 --help

# Convert + export GGUF (adjust paths as needed)
!t81 convert microsoft/Phi-3-mini-4k-instruct phi3-t81 --threshold 0.45 --force-cpu-device-map
!t81 gguf phi3-tq1.gguf --from-t81 phi3-t81 --quant TQ1_0 --validate
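

The `!` lines in the cell above are IPython shell escapes and only work inside Jupyter. From a plain Python script, the same commands can be run with the standard library's `subprocess` module; the `run_cli` helper below is a hypothetical convenience wrapper, not part of `t81`.

```python
import shlex
import subprocess
from typing import Sequence, Union


def run_cli(cmd: Union[str, Sequence[str]]) -> str:
    """Run a command without a shell, raise on a non-zero exit, and return its stdout."""
    # Split a string command the way a POSIX shell would; pass lists through unchanged.
    args = shlex.split(cmd) if isinstance(cmd, str) else list(cmd)
    result = subprocess.run(args, check=True, capture_output=True, text=True)
    return result.stdout
```

For example, `run_cli("t81 gguf phi3-tq1.gguf --from-t81 phi3-t81 --quant TQ1_0 --validate")` mirrors the last line of the cell. Capturing output hides progress bars; drop `capture_output=True` to stream them instead.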
