# **Distilling Step by Step**

In [16]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
from datasets import load_dataset

### **Loading Student Model**

In [30]:
# Checkpoint used as the student model (the teacher LLM is configured in a later cell).
STUDENT_MODEL = "google-bert/bert-base-uncased"

from transformers import AutoModelForSequenceClassification

# Load the checkpoint with a sequence-classification head on top. The classifier
# weights are newly initialized (see the warning printed by this cell), so the
# model has to be fine-tuned before its predictions mean anything.
model = AutoModelForSequenceClassification.from_pretrained(STUDENT_MODEL)
# Bare expression: renders the module tree as the cell output.
model

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [18]:
# Load only the first 10 training rows of SST-2 as a tiny demo set
# (change the dataset name / split slice to use your own data).
# The loaded rows expose the columns 'idx', 'sentence' and 'label'.
dataset = load_dataset("sst2", split="train[:10]")

Using the latest cached version of the dataset since sst2 couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /Users/simranjeetsingh1497/.cache/huggingface/datasets/sst2/default/0.0.0/8d51e7e4887a4caaa95b3fbebbf53c0490b58bbb (last modified on Thu Feb 27 14:11:45 2025).


In [42]:
dataset

Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 10
})

The history saving thread hit an unexpected error (OperationalError('unable to open database file')).History will not be written to the database.


### TEACHER MODEL

In [19]:
from langchain_ollama import ChatOllama

# Initialize the teacher chat model through LangChain's Ollama integration.
# NOTE(review): assumes an Ollama server is listening on base_url and that the
# "deepseek-r1:1.5b" model has already been pulled -- confirm before running.
llm_engine = ChatOllama(
    model="deepseek-r1:1.5b",
    base_url="http://localhost:11434",
    temperature=0.3
)


def generate_rationale(input_text):
    """
    Ask the teacher LLM (the module-level `llm_engine`) for step-by-step reasoning.

    The input is embedded in a fixed instruction prompt and sent with LangChain's
    `invoke`. The reply's `content` attribute is returned when it exists;
    otherwise the raw reply object is returned unchanged.
    """
    reply = llm_engine.invoke(
        f"Explain step-by-step reasoning before answering: {input_text}"
    )
    # Fall back to the reply itself when it carries no `.content`.
    return getattr(reply, "content", reply)

# Smoke test: send one prompt to the teacher and print the raw reply
# (the recorded output includes the model's <think> section).
print(generate_rationale("Explain AI in one sentence."))

<think>
Okay, so I need to explain AI in one sentence. Hmm, where do I start? Let me think about what AI is. AI stands for Artificial Intelligence, right? It's like a machine that can perform tasks that typically require human intelligence. So maybe something like "Artificial Intelligence enables machines to perform tasks that typically require human intelligence." That sounds good, but let me check if it covers all aspects.

Wait, does it cover things like learning, problem-solving, creativity? Yeah, AI uses various techniques to learn from data and make decisions or take actions without being explicitly programmed. So the sentence should mention that AI allows machines to learn, adapt, and perform tasks that require human-like reasoning.

I think I got it. The key points are: AI enables machines to do things like learning, problem-solving, creativity, and decision-making. So putting it all together in one sentence would be something like "Artificial Intelligence enables machines to p

In [20]:
# Tokenizer that matches the student checkpoint. process_data depends on it, and
# it must exist before the map step runs on a fresh kernel.
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)


# Prepare dataset with rationales
def process_data(example, max_length=256):
    """
    Turn one raw example into tokenized student inputs plus a tokenized rationale.

    Args:
        example: mapping with a "sentence" text field and a "label" field
            (change the keys here if your dataset uses different column names).
        max_length: length every sequence is padded / truncated to.

    Returns:
        dict with "input_ids" / "attention_mask" for the sentence, "labels" for
        the class label, and "rationale_ids" / "rationale_mask" for the
        teacher-generated rationale.

    Note:
        Each call performs one teacher LLM request through generate_rationale,
        so this function is slow.
    """
    input_text = example["sentence"]
    rationale = generate_rationale(input_text)
    label = example["label"]

    # Tokenize input and rationale to the same fixed length.
    input_enc = tokenizer(input_text, truncation=True, padding="max_length", max_length=max_length)
    rationale_enc = tokenizer(rationale, truncation=True, padding="max_length", max_length=max_length)

    return {
        "input_ids": input_enc["input_ids"],
        "attention_mask": input_enc["attention_mask"],
        "labels": label,
        "rationale_ids": rationale_enc["input_ids"],
        "rationale_mask": rationale_enc["attention_mask"]
    }

# Apply process_data to every example. Each example triggers one teacher LLM
# call, so this is the slow step of the notebook (the recorded run needed
# roughly 50 s per example).
processed_dataset = dataset.map(process_data)

Map: 100%|██████████| 10/10 [08:26<00:00, 50.66s/ examples]


In [45]:
processed_dataset

Dataset({
    features: ['idx', 'sentence', 'label', 'input_ids', 'attention_mask', 'labels', 'rationale_ids', 'rationale_mask'],
    num_rows: 10
})

In [22]:
from datasets import Dataset

# Persist the processed dataset (tokenized inputs and rationales) to the local
# 'preprocessed_dataset' directory so the slow LLM-backed map step can be skipped later.
processed_dataset.save_to_disk('preprocessed_dataset')

Saving the dataset (1/1 shards): 100%|██████████| 10/10 [00:00<00:00, 564.36 examples/s]


In [23]:
from datasets import load_from_disk

# Reload the processed dataset from the 'preprocessed_dataset' directory written
# by save_to_disk; this rebinds processed_dataset to the on-disk copy.
processed_dataset = load_from_disk('preprocessed_dataset')

In [25]:
processed_dataset

Dataset({
    features: ['idx', 'sentence', 'label', 'input_ids', 'attention_mask', 'labels', 'rationale_ids', 'rationale_mask'],
    num_rows: 10
})

In [32]:
# Hyperparameters for fine-tuning the student model.
training_args = TrainingArguments(
    output_dir="./results",  # Directory to save the model and checkpoints
    eval_strategy="epoch",  # evaluation schedule is tied to epochs
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    save_strategy="epoch",  # checkpoint schedule matches the evaluation schedule
    push_to_hub=False  # keep all artifacts local
)

In [40]:
class MultiTaskTrainer(Trainer):
    """Trainer whose loss is the label loss plus a weighted rationale loss.

    The student is a sequence classifier, so it produces one logit vector per
    example. The auxiliary task therefore feeds the teacher's rationale text
    through the same classifier and supervises it with the same class label.
    Token-level rationale targets would need a model with a generative / LM
    head and cannot be scored against classification logits.

    NOTE(review): with the Trainer default `remove_unused_columns=True`, dataset
    columns that `model.forward` does not accept are presumably dropped before
    batching, in which case `rationale_ids` never reaches this method and only
    the label loss is used -- confirm against the Trainer configuration.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Compute the combined training loss for one batch.

        Args:
            model: the student model being trained.
            inputs: batch dict; must contain "labels" and the model inputs, and
                may contain "rationale_ids" / "rationale_mask".
            return_outputs: when True, also return the model outputs of the
                main (input sentence) forward pass.
            num_items_in_batch: accepted for Trainer compatibility; unused.

        Returns:
            The scalar loss, or `(loss, outputs)` when `return_outputs` is True.
        """
        labels = inputs.pop("labels")
        # Remove BOTH rationale tensors so they are not forwarded to the model
        # as unexpected keyword arguments by `model(**inputs)`.
        rationale_ids = inputs.pop("rationale_ids", None)
        rationale_mask = inputs.pop("rationale_mask", None)

        outputs = model(**inputs)

        loss_fn = torch.nn.CrossEntropyLoss()
        label_loss = loss_fn(outputs.logits, labels)

        if rationale_ids is not None:
            # Use the rationale's own padding mask; only fall back to the input
            # mask when no rationale mask was provided with the batch.
            if rationale_mask is None:
                rationale_mask = inputs["attention_mask"]
            rationale_outputs = model(input_ids=rationale_ids, attention_mask=rationale_mask)
            # Targets are the class labels: the logits have shape
            # (batch, num_labels), so token ids are not valid targets here.
            rationale_loss = loss_fn(rationale_outputs.logits, labels)
            loss = label_loss + 0.5 * rationale_loss  # Weighted loss
        else:
            loss = label_loss

        return (loss, outputs) if return_outputs else loss

# Build the trainer for the student model.
# NOTE(review): eval_dataset is the same object as train_dataset, so the
# per-epoch evaluation is measured on the training examples, not held-out data.
trainer = MultiTaskTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
    eval_dataset=processed_dataset
)

In [39]:
processed_dataset

Dataset({
    features: ['idx', 'sentence', 'label', 'input_ids', 'attention_mask', 'labels', 'rationale_ids', 'rationale_mask'],
    num_rows: 10
})

In [None]:
# Fine-tune the student, then write the final model to ./results
# (the same directory configured as output_dir in training_args).
trainer.train()
# NOTE(review): no tokenizer is handed to the trainer, so presumably only the
# model weights/config are saved here -- confirm if the tokenizer is needed too.
trainer.save_model("./results")
print("✅ Distillation Complete! Smaller model saved.")

# For AutoTrain

In [46]:
import pandas as pd

# Build the three CSV columns in parallel: the raw sentence, the teacher's
# rationale for it, and the class label. Every example costs one teacher LLM
# call through generate_rationale.
texts, rationales, targets = [], [], []

for record in dataset:
    sentence, sentiment = record["sentence"], record["label"]
    explanation = generate_rationale(sentence)

    texts.append(sentence)
    rationales.append(explanation)
    targets.append(sentiment)

# Column-oriented payload in the layout written to train.csv.
data = {"text": texts, "rationale": rationales, "target": targets}

# One row per example, written without the index column.
df = pd.DataFrame(data)
df.to_csv("train.csv", index=False)