In [None]:
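# ChatGPT-drafted notebook: distill an LLM teacher (gpt-oss) into a small intent classifier (distilroberta-base)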

In [None]:
# Show the kernel's current working directory (IPython shell escape; debugging aid, safe to remove).
!pwd

In [None]:
#!pip install transformers
#!pip install bitsandbytes
#!pip install accelerate

In [16]:
import math
import os
import random
import warnings
from typing import List, Dict

import numpy as np
import pandas as pd
import torch, torch.nn, torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,               # teacher
    AutoModelForSequenceClassification,  # student
    AutoTokenizer,
    AutoConfig,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup
)


In [38]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TEACHER_MODEL_NAME = "openai/gpt-oss-20b"  # or "openai/gpt-oss-120b"
STUDENT_MODEL_NAME = "distilroberta-base"  # tiny options: "prajjwal1/bert-tiny", "distilbert-base-uncased"

MAX_LEN = 256      # intended maximum student input length, in tokens
BATCH_SIZE = 8
# NOTE: the distillation cell further down redefines LR and EPOCHS.
LR = 3e-5
EPOCHS = 3

# Seed every RNG the notebook touches (Python, torch CPU/CUDA, NumPy) so the student's
# freshly initialised classification head and DataLoader shuffling are reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)


In [32]:
INCAR_SAMPLES_CSV = "s3://data-daizika-com/incar_assist/data/intent_classification/incar_assist_samples.csv"
INCAR_LABELS_CSV = "s3://data-daizika-com/incar_assist/data/intent_classification/incar_assist_labels.csv"


def get_incar_assist_dataset(samples_path: str = INCAR_SAMPLES_CSV,
                             labels_path: str = INCAR_LABELS_CSV):
    """Load the in-car assistant intent dataset and its label vocabulary.

    Args:
        samples_path: CSV with a ``Label`` column (intent name) and exactly one
            other column holding the utterance text.
        labels_path: CSV with ``label`` (intent name) and ``id`` (class id) columns.

    Returns:
        ``(examples, label2id, id2label)`` where ``examples`` is a list of
        ``{"text": ..., "label": int}`` dicts in sample-file order.

    Raises:
        ValueError: if the samples CSV does not have exactly one column besides
            ``Label``, or if a sample's intent is missing from the labels CSV.
    """
    samples_df = pd.read_csv(samples_path)
    labels_df = pd.read_csv(labels_path)

    # to_dict(orient="records") yields native Python scalars, which keeps the
    # mappings JSON-serialisable when they are stored in a Hugging Face config.
    label_records = labels_df.to_dict(orient="records")
    labels2id_dict = {rec["label"]: rec["id"] for rec in label_records}
    id2labels_dict = {rec["id"]: rec["label"] for rec in label_records}

    text_cols = [col for col in samples_df.columns if col != "Label"]
    if len(text_cols) != 1:
        raise ValueError(
            f"Expected exactly one text column besides 'Label' in {samples_path}, got {text_cols}"
        )
    text_col = text_cols[0]

    # Unknown intents would otherwise surface much later as NaN labels.
    label_ids = samples_df["Label"].map(labels2id_dict)
    unknown = samples_df.loc[label_ids.isna(), "Label"].unique()
    if len(unknown):
        raise ValueError(f"Intents missing from {labels_path}: {sorted(map(str, unknown))}")

    df_dict = [
        {"text": text, "label": int(label_id)}
        for text, label_id in zip(samples_df[text_col].tolist(), label_ids.tolist())
    ]
    return df_dict, labels2id_dict, id2labels_dict
    
 # Example dataset: list of dicts with 'text' and 'label'
#train_examples = [{"text":"book me a flight","label":3}, ...]
all_examples, label2id, id2label = get_incar_assist_dataset()
all_texts = [rec['text'] for rec in all_examples]
all_labels = [rec['label'] for rec in all_examples]
verbalizers = {key:id2label[key].replace(" ", "-") for key in id2label}

train_texts, val_texts, train_labels, val_labels = train_test_split(all_texts, all_labels, test_size=0.20, random_state=42)


In [37]:
# -----------------------------
# 2) Dataset wrapper for the in-car intent examples
# -----------------------------
class IntentDataset(Dataset):
    """Map-style dataset pairing each utterance with its integer intent id.

    Items are ``{"text": str, "label": int}`` dicts. The dataset itself does no
    tokenization; that happens in the DataLoader's collate function.
    """

    def __init__(self, texts: List[str], labels: List[int]):
        # Catch misaligned inputs up front instead of failing at an arbitrary index later.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {"text": self.texts[idx], "label": self.labels[idx]}

# Wrap the train / validation splits (train is constructed first, as before).
train_ds, val_ds = IntentDataset(train_texts, train_labels), IntentDataset(val_texts, val_labels)


In [None]:
# -----------------------------
# 3) Load teacher & student
# -----------------------------

teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_NAME)
# Teacher prompts are batched with padding=True, which fails when the tokenizer
# defines no pad token; fall back to EOS in that case.
if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

# The teacher is queried for NEXT-TOKEN logits over its vocabulary
# (teacher_probs_for_batch indexes out.logits by sequence position), so it must be
# loaded with its causal-LM head. A sequence-classification head would be randomly
# initialised and return [B, num_labels] logits that cannot be indexed that way.
# torch_dtype="auto" keeps the checkpoint's dtype instead of upcasting ~20B params to fp32.
teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_MODEL_NAME, torch_dtype="auto")
teacher_model.to(DEVICE).eval()
teacher_model.requires_grad_(False)  # frozen: only the student is trained

student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL_NAME, use_fast=True)
student_model = AutoModelForSequenceClassification.from_pretrained(
    STUDENT_MODEL_NAME,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
student_model.to(DEVICE)


MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.17G [00:00<?, ?B/s]

model-00000-of-00002.safetensors:   0%|          | 0.00/4.79G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.80G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:



# -----------------------------
# 4) Prompting & teacher soft labels
# -----------------------------
# We’ll treat distillation as next-token classification:
#   prompt = build_prompt(text): instructions + "Valid labels: ...\n\nCommand: {text}\nLabel:"
# We read the logits for the *next token* and map them to our verbalizer token ids.
# NOTE: This is an approximation. For better fidelity, score a short phrase (sum logprobs across tokens).

# Map each label id to the teacher-vocabulary id of the FIRST token of its verbalizer.
# The prompt ends with "Label:" and the answer word typically follows a space; with a
# byte-level BPE tokenizer (assumed for gpt-oss — verify for other teachers) " open"
# and "open" are different ids, so tokenize the verbalizer with that leading space.
label_first_token_ids: Dict[int, int] = {}
for label_id, phrase in verbalizers.items():
    piece_ids = teacher_tokenizer(" " + phrase, add_special_tokens=False)["input_ids"]
    if not piece_ids:
        raise ValueError(f"Verbalizer {phrase!r} produced no teacher tokens")
    label_first_token_ids[label_id] = piece_ids[0]

# Labels whose verbalizers share a first token are scored from the same logit, so the
# teacher gives them identical probabilities; surface that instead of hiding it.
_labels_by_first_token: Dict[int, List[int]] = {}
for label_id, token_id in label_first_token_ids.items():
    _labels_by_first_token.setdefault(token_id, []).append(label_id)
_shared_first_tokens = [ids for ids in _labels_by_first_token.values() if len(ids) > 1]
if _shared_first_tokens:
    warnings.warn(
        "Teacher cannot distinguish these label groups (shared first token): "
        + "; ".join(", ".join(verbalizers[i] for i in ids) for ids in _shared_first_tokens)
    )

def build_prompt(text: str) -> str:
    """Wrap a raw command in the teacher's classification prompt.

    Label names are listed in id order (0 .. len(id2label) - 1), and the prompt ends
    with "Label:" so the teacher's next token is its answer.
    """
    ordered_names = [id2label[idx] for idx in range(len(id2label))]
    choices = ", ".join(ordered_names)
    prompt = (
        "You are an expert intent classifier. "
        "Respond with exactly one label token from the set below.\n"
        f"Valid labels: {choices}\n\n"
        f"Command: {text}\nLabel:"
    )
    return prompt

@torch.no_grad()
def teacher_probs_for_batch(batch_texts: List[str]) -> torch.Tensor:
    """
    Return the teacher's probability distribution over intent labels for each text.

    Each text is wrapped with build_prompt(), the teacher runs once on the padded
    batch, and the next-token logits at each prompt's LAST REAL (non-pad) token are
    read. Those logits are restricted to the candidate ids in label_first_token_ids
    (column j <-> label id j) and softmax-normalised over just those candidates.

    Args:
        batch_texts: raw user commands.

    Returns:
        float32 tensor of shape [B, num_labels] whose rows sum to 1.
    """
    num_labels = len(id2label)
    if not batch_texts:
        return torch.empty(0, num_labels, device=DEVICE)

    prompts = [build_prompt(t) for t in batch_texts]
    enc = teacher_tokenizer(
        prompts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)

    logits = teacher_model(**enc).logits  # [B, seq_len, vocab_size]

    # With right padding, position -1 is a pad token for every prompt shorter than the
    # longest one. Pick each row's last attended position instead; this is correct for
    # left and right padding alike. (positions are distinct, so argmax is unambiguous.)
    mask = enc["attention_mask"].to(logits.device)
    positions = torch.arange(mask.size(1), device=logits.device)
    last_idx = (mask * positions).argmax(dim=1)  # [B]
    rows = torch.arange(mask.size(0), device=logits.device)
    next_token_logits = logits[rows, last_idx].float()  # [B, vocab_size]

    candidate_ids = torch.tensor(
        [label_first_token_ids[i] for i in range(num_labels)], device=logits.device
    )
    # index_select keeps the columns in exactly this order: column j <-> label id j.
    candidate_logits = next_token_logits.index_select(dim=1, index=candidate_ids)  # [B, num_labels]
    # Softmax over the candidate logits equals the full-vocab softmax gathered and
    # renormalised, without dividing by a possibly tiny sum of low-precision probabilities.
    return F.softmax(candidate_logits, dim=-1)

# -----------------------------
# 5) Dataloaders
# -----------------------------
class Collator:
    """Collate raw ``{"text", "label"}`` examples into padded student inputs.

    The returned batch holds the tokenizer's padded tensors plus ``labels``
    (LongTensor) and the raw ``texts`` list, which the teacher needs for prompting.

    Args:
        tokenizer: tokenizer used both to encode and to pad the texts.
        max_length: truncation length in tokens; None uses the tokenizer's default.
    """

    def __init__(self, tokenizer, max_length=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pad = DataCollatorWithPadding(tokenizer)

    def __call__(self, features):
        texts = [f["text"] for f in features]
        labels = torch.tensor([f["label"] for f in features], dtype=torch.long)
        # Encode with the tokenizer this collator was built with (not a global).
        enc = self.pad(self.tokenizer(texts, truncation=True, max_length=self.max_length))
        enc["labels"] = labels
        enc["texts"] = texts  # keep raw for teacher
        return enc

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=Collator(student_tokenizer, max_length=MAX_LEN))
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=Collator(student_tokenizer, max_length=MAX_LEN))

# -----------------------------
# 6) Distillation training loop
# -----------------------------
# NOTE: these EPOCHS / LR values override the ones in the configuration cell
# (3 and 3e-5); the values below are the ones actually used for training.
EPOCHS = 4
LR = 5e-5
WARMUP_STEPS = 0
ALPHA_KD = 0.9       # weight on KD loss
TEMPERATURE = 2.0    # temperature for KD
ALPHA_CE = 1.0 - ALPHA_KD  # weight on hard-label cross-entropy

optimizer = torch.optim.AdamW(student_model.parameters(), lr=LR)
# kd_step advances optimizer and scheduler once per batch, so total = epochs * batches.
total_steps = EPOCHS * len(train_loader)
scheduler = get_linear_schedule_with_warmup(optimizer, WARMUP_STEPS, total_steps)

# The import cell binds `torch` but no `nn` name (`import torch.nn` only binds `torch`),
# so reference the loss through `torch.nn`.
kl_div = torch.nn.KLDivLoss(reduction="batchmean")  # input: log-probs, target: probs

def kd_step(batch) -> Dict[str, float]:
    """Run one knowledge-distillation optimisation step on a single batch.

    Loss = ALPHA_KD * T^2 * KL(teacher_T || student_T) + ALPHA_CE * CE(student, labels),
    where ``_T`` denotes distributions softened with TEMPERATURE. Performs backward,
    gradient clipping (max norm 1.0), an optimizer step and a scheduler step.

    Args:
        batch: a Collator batch with input_ids, attention_mask, labels and texts.

    Returns:
        Floats "loss", "kd", "ce" and "acc" (training accuracy of the pre-update logits).
    """
    student_model.train()
    input_ids = batch["input_ids"].to(DEVICE)
    attention_mask = batch["attention_mask"].to(DEVICE)
    labels = batch["labels"].to(DEVICE)
    texts = batch["texts"]

    # 1) Teacher soft labels
    with torch.no_grad():
        # The teacher may run in bf16; do the KD arithmetic in float32 on the student's device.
        t_probs = teacher_probs_for_batch(texts).to(device=DEVICE, dtype=torch.float32)  # [B, num_labels]
        # Re-temper: softmax(log p / T) == p^(1/T) renormalised.
        t_probs_T = F.softmax(torch.log(t_probs + 1e-12) / TEMPERATURE, dim=-1)

    # 2) Student forward
    out = student_model(input_ids=input_ids, attention_mask=attention_mask)
    s_logits = out.logits  # [B, num_labels]

    # 3) KD loss (student log-softmax at T vs teacher probs at T); T^2 keeps gradient scale.
    s_log_probs_T = F.log_softmax(s_logits / TEMPERATURE, dim=-1)
    loss_kd = kl_div(s_log_probs_T, t_probs_T) * (TEMPERATURE ** 2)

    # 4) CE on hard labels
    loss_ce = F.cross_entropy(s_logits, labels)

    loss = ALPHA_KD * loss_kd + ALPHA_CE * loss_ce
    loss.backward()
    torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad(set_to_none=True)

    with torch.no_grad():
        preds = s_logits.argmax(dim=-1)
        acc = (preds == labels).float().mean().item()

    return {"loss": loss.item(), "kd": loss_kd.item(), "ce": loss_ce.item(), "acc": acc}

@torch.no_grad()
def evaluate() -> float:
    """Return the student's accuracy on val_loader (a fraction; 0.0 when it is empty).

    Switches the student to eval mode and leaves it there.
    """
    student_model.eval()
    n_seen = 0
    n_right = 0
    for batch in val_loader:
        gold = batch["labels"].to(DEVICE)
        logits = student_model(
            input_ids=batch["input_ids"].to(DEVICE),
            attention_mask=batch["attention_mask"].to(DEVICE),
        ).logits
        n_seen += gold.size(0)
        n_right += (logits.argmax(dim=-1) == gold).sum().item()
    return n_right / max(1, n_seen)

for epoch in range(1, EPOCHS + 1):
    logs = []
    num_steps = len(train_loader)
    for step, batch in enumerate(train_loader, start=1):
        info = kd_step(batch)
        logs.append(info)
        if step % 10 == 0 or step == num_steps:
            print(f"Epoch {epoch} | step {step}/{num_steps} "
                  f"loss={info['loss']:.4f} kd={info['kd']:.4f} ce={info['ce']:.4f} acc={info['acc']:.3f}")
    # Per-step numbers are noisy; also report the epoch means collected in `logs`.
    if logs:
        means = {key: sum(entry[key] for entry in logs) / len(logs) for key in logs[0]}
        print(f"Epoch {epoch} train mean: loss={means['loss']:.4f} kd={means['kd']:.4f} "
              f"ce={means['ce']:.4f} acc={means['acc']:.3f}")
    val_acc = evaluate()
    print(f"Epoch {epoch} done. Val acc: {val_acc:.3f}")

# -----------------------------
# 7) Inference helper
# -----------------------------
@torch.no_grad()
def predict_intent(texts: List[str]) -> List[Dict]:
    """Classify raw commands with the distilled student.

    Args:
        texts: raw user commands.

    Returns:
        One dict per input: "text", "pred_id", "pred_label", and "probs"
        (label name -> softmax probability). Empty input returns [].
    """
    if not texts:
        return []
    # kd_step calls student_model.train(); make sure dropout is off for inference.
    student_model.eval()
    enc = student_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
    # Move probabilities to CPU once instead of syncing per element below.
    probs = F.softmax(student_model(**enc).logits, dim=-1).cpu()
    preds = probs.argmax(dim=-1).tolist()
    num_labels = len(id2label)
    return [
        {
            "text": t,
            "pred_id": p,
            "pred_label": id2label[p],
            "probs": {id2label[i]: float(row[i]) for i in range(num_labels)},
        }
        for t, p, row in zip(texts, preds, probs)
    ]

# Smoke-test the distilled student on a few hand-written commands.
examples = [
    "please open bluetooth",
    "switch off the bluetooth",
    "show me the camera",
    "close the door",
    "open the window",
]
print(predict_intent(texts=examples))
