## Training the Teacher Model
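The first step in the pipeline is to train a teacher model on the SST-2 dataset. Once trained, the teacher labels the synthetic sentences produced by the generator. For this proof of concept, the teacher is `prajjwal1/bert-small` fine-tuned on only the first 256 examples of the SST-2 training split.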
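Labeling synthetic data comes down to turning the teacher's raw scores into a class and a confidence. The sketch below makes some assumptions: the teacher returns one pair of logits per sentence, as a two-label sequence-classification head does, and labels follow GLUE's SST-2 convention of 0 = negative and 1 = positive. The helper `label_from_logits` and the example logits are hypothetical and only illustrate the idea.

```python
import math

# GLUE SST-2 label convention: 0 = negative, 1 = positive
ID2LABEL = {0: "negative", 1: "positive"}

def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def label_from_logits(logits):
    """Hypothetical helper: map one example's logits to (label id, label name, confidence)."""
    probs = softmax(logits)
    label_id = max(range(len(probs)), key=probs.__getitem__)
    return label_id, ID2LABEL[label_id], probs[label_id]

# Made-up logits for three synthetic sentences
batch_logits = [[2.0, -1.0], [-0.5, 1.5], [0.1, 0.0]]
for logits in batch_logits:
    print(label_from_logits(logits))
```

The last example shows why keeping the confidence can help. Its logits are nearly tied, so the label is close to a coin flip. One option is to filter out such low-confidence synthetic examples before using them.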


In [22]:
import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding

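# Proof-of-concept slice: the first 256 SST-2 training sentences
# (named `dataset` so it does not read like the `datasets` library)
dataset = load_dataset("glue", "sst2", split="train[:256]")
# The model's forward() expects the target column to be named "labels"
dataset = dataset.rename_column("label", "labels")

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-small", use_fast=True)

def tokenize_function(examples):
    # Truncate to 32 tokens; padding is deferred to DataCollatorWithPadding at batch time
    return tokenizer(examples["sentence"], truncation=True, max_length=32)

tokenized_datasets = dataset.map(tokenize_function, batched=True)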
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
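`tokenize_function` truncates but does not pad, so examples in a batch can have different lengths. `DataCollatorWithPadding`, which is passed to the `Trainer` later, pads each batch only to its longest member. The sketch below is a simplified pure-Python version of that behaviour. It assumes a pad token id of 0, which is `[PAD]` in the standard uncased BERT vocabulary; the notebook would use `tokenizer.pad_token_id`. The other token ids are made up. The real collator returns tensors and handles more fields. `pad_batch` is a hypothetical name.

```python
def pad_batch(batch, pad_token_id=0):
    """Pad input_ids to the longest example in the batch and build matching attention masks.

    `batch` is a list of dicts with "input_ids" (list of ints) and "labels" (int).
    """
    max_len = max(len(ex["input_ids"]) for ex in batch)
    input_ids, attention_mask, labels = [], [], []
    for ex in batch:
        ids = ex["input_ids"]
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding that attention should ignore
        attention_mask.append([1] * len(ids) + [0] * n_pad)
        labels.append(ex["labels"])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# Made-up ids; 101 and 102 are [CLS] and [SEP] in the uncased BERT vocabulary
batch = [
    {"input_ids": [101, 2023, 3185, 102], "labels": 1},
    {"input_ids": [101, 6659, 102], "labels": 0},
]
padded = pad_batch(batch)
print(padded["input_ids"])       # → [[101, 2023, 3185, 102], [101, 6659, 102, 0]]
print(padded["attention_mask"])  # → [[1, 1, 1, 1], [1, 1, 1, 0]]
```

Padding per batch rather than to a fixed length keeps batches of short SST-2 sentences small.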

In [24]:
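# Pretrained encoder plus a freshly initialized 2-way classification head
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-small", num_labels=2)
# Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"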
model.to(device)
print(f"Model loaded on {device}")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-small and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on mps
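The device line in this cell prefers CUDA, then MPS, then CPU. The sketch below expresses that precedence as a small pure function. Availability is passed in as booleans so the logic can be checked without a GPU. In the notebook, those flags come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`. `pick_device` is a hypothetical name.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return a device string with the notebook's precedence: cuda > mps > cpu."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

# The run above reported "mps": no CUDA device, MPS available
print(pick_device(cuda_available=False, mps_available=True))  # → mps
```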


In [26]:
from transformers import Trainer, TrainingArguments

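args = TrainingArguments(
    output_dir="tmp/teacher_poc",
    per_device_train_batch_size=2,  # 256 examples / 2 = 128 optimizer steps per epoch
    per_device_eval_batch_size=2,
    num_train_epochs=1,
    learning_rate=2e-5,
    eval_strategy="steps",  # named evaluation_strategy in older transformers releases
    eval_steps=50,          # evaluate every 50 steps (steps 50 and 100 in this run)
)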

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets,
    # 64 examples drawn from the training slice itself: an in-sample sanity check, not a held-out score
    eval_dataset=tokenized_datasets.shuffle(seed=0).select(range(64)),
    data_collator=DataCollatorWithPadding(tokenizer),
)
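These arguments determine the shape of the run. The sketch below works out the step counts. It assumes a single device and no gradient accumulation, since neither is configured above. It also uses the `transformers` default of `logging_steps=500`, because the cell leaves that argument unset. `schedule` is a hypothetical helper.

```python
import math

def schedule(n_examples, batch_size, epochs, eval_steps, logging_steps=500):
    """Estimate total optimizer steps and the steps that trigger eval and logging.

    Assumes one device, no gradient accumulation, and a dataloader that keeps
    the last partial batch.
    """
    steps_per_epoch = math.ceil(n_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    eval_at = list(range(eval_steps, total_steps + 1, eval_steps))
    log_at = list(range(logging_steps, total_steps + 1, logging_steps))
    return total_steps, eval_at, log_at

total, eval_at, log_at = schedule(n_examples=256, batch_size=2, epochs=1, eval_steps=50)
print(total, eval_at, log_at)  # → 128 [50, 100] []
```

The empty logging list explains the "No log" entries in the training-loss column. With only 128 steps, the default logging interval of 500 is never reached. Passing something like `logging_steps=50` to `TrainingArguments` should record the training loss alongside each evaluation.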

In [27]:
trainer.train()

| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50   | No log        | 0.669309        |
| 100  | No log        | 0.646186        |


TrainOutput(global_step=128, training_loss=0.6919310092926025, metrics={'train_runtime': 24.1107, 'train_samples_per_second': 10.618, 'train_steps_per_second': 5.309, 'total_flos': 349921897560.0, 'train_loss': 0.6919310092926025, 'epoch': 1.0})
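A quick check helps interpret these numbers. For a two-class problem, a model that always predicts 50/50 has a cross-entropy of ln 2 ≈ 0.693. A mean training loss of 0.692 therefore means the teacher is still close to chance after this short proof-of-concept epoch. The validation loss of 0.646 at step 100 is a small improvement, and it is measured on training examples. The throughput fields are consistent with 256 examples and 128 steps in about 24.1 seconds. The values below are copied from the output above.

```python
import math

# Figures copied from the TrainOutput above
train_loss = 0.6919310092926025
train_runtime = 24.1107  # seconds
n_examples, global_step = 256, 128

chance_loss = math.log(2)  # cross-entropy of a uniform prediction over 2 classes
print(f"ln 2 = {chance_loss:.4f}, train loss = {train_loss:.4f}")
print(f"samples/s = {n_examples / train_runtime:.3f}")   # compare train_samples_per_second
print(f"steps/s   = {global_step / train_runtime:.3f}")  # compare train_steps_per_second
```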