In [1]:
# Reload edited project modules (e.g. trainer.distillation) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# This notebook presumably lives one directory below the repository root, where the
# `trainer` package sits. Only step up when we are not already at the root, so that
# re-running this cell does not keep climbing the directory tree (a bare `%cd ..` would).
if not os.path.isdir("trainer"):
    os.chdir("..")

# Make the repository root itself importable, so `trainer.*` resolves no matter how
# the kernel's sys.path was initialised. Avoid duplicate entries on re-run.
REPO_ROOT = os.getcwd()
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
REPO_ROOT

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, pipeline
from datasets import load_dataset
import evaluate

from trainer.distillation import DistillationTrainer, DistillationTrainingArguments

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
clinc = load_dataset("clinc_oos", "plus")
# The `intent` feature of the test split supplies the label vocabulary; its class
# count is used later to size the student's classification head.
test_features = clinc["test"].features
intents = test_features["intent"]
num_labels = intents.num_classes

Found cached dataset clinc_oos (/Users/Tony/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Teacher: BERT-base already fine-tuned on CLINC. Student: plain DistilBERT to be trained.
teacher_ckpt, student_ckpt = (
    "transformersbook/bert-base-uncased-finetuned-clinc",
    "distilbert-base-uncased",
)

In [6]:
# Inputs are tokenized with the student's tokenizer, since the student is the model being trained.
student_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=student_ckpt)

def tokenize_text(batch):
    """Tokenize a batched slice of the CLINC dataset with the student tokenizer.

    Intended for ``Dataset.map(..., batched=True)``, so ``batch["text"]`` is a list
    of utterances. Over-long sequences are truncated; no padding is applied here.
    """
    utterances = batch["text"]
    return student_tokenizer(utterances, truncation=True)

# Tokenize every split, drop the raw text, and rename the target column to the
# `labels` name that the Trainer expects.
clinc_enc = (
    clinc
    .map(tokenize_text, batched=True, remove_columns=["text"])
    .rename_column("intent", "labels")
)

Loading cached processed dataset at /Users/Tony/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-be08a2d98145e176.arrow
Loading cached processed dataset at /Users/Tony/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-1c9e99ec23fdf840.arrow
Loading cached processed dataset at /Users/Tony/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-008f33dd2111f5e7.arrow


In [7]:
batch_size = 48

finetuned_ckpt = "distilbert-base-uncased-finetuned-clinc"
# student_training_args = DistillationTrainingArguments(
#     output_dir=finetuned_ckpt,
#     evaluation_strategy="epoch",
#     num_train_epochs=1,
#     learning_rate=2e-5,
#     per_device_train_batch_size=batch_size,
#     per_device_eval_batch_size=batch_size,
#     alpha=1,
#     weight_decay=0.01)

student_training_args = DistillationTrainingArguments(
    output_dir=finetuned_ckpt,
    evaluation_strategy="steps",
    max_steps=50,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    alpha=1,
    weight_decay=0.01)


student_training_args.logging_steps = len(clinc_enc['train']) // batch_size
student_training_args.disable_tqdm = False
student_training_args.save_steps = 1e9
# student_training_args.log_level = 40

%env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [8]:
# Teacher model used for distillation, moved to the training device.
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_ckpt).to(device)
# NOTE(review): this pipeline loads a second copy of the teacher weights; in this
# notebook it is only read for its config's label maps — consider dropping it.
pipe = pipeline("text-classification", model=teacher_ckpt)

In [9]:
# Take the label maps from the teacher model that is actually passed to the trainer.
# This way the student's class indices are labelled exactly like the teacher's outputs.
id2label = teacher_model.config.id2label
label2id = teacher_model.config.label2id

# The student head is sized from the dataset (`num_labels`). It must agree with the
# teacher's head, presumably so the distillation loss can compare logits class-by-class.
assert len(id2label) == num_labels, (
    f"teacher has {len(id2label)} labels but the dataset has {num_labels}"
)

In [10]:
# Student config: DistilBERT with a classification head sized to the CLINC intents.
# It is labelled with the teacher's id/label mapping so both models name classes alike.
student_config = AutoConfig.from_pretrained(
    student_ckpt,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_ckpt, config=student_config
).to(device)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.bias', 'pre_classifi

In [11]:
# Accuracy metric from the `evaluate` hub, used by compute_metrics below.
acc_metric = evaluate.load(path="accuracy")

def compute_metrics(pred):
    """Score argmax predictions against the reference labels.

    Args:
        pred: an unpackable ``(logits, labels)`` pair, presumably the Trainer's
            ``EvalPrediction``. Classes are taken along axis 1 of ``logits``.

    Returns:
        The dict returned by ``acc_metric.compute`` (e.g. ``{"accuracy": ...}``).
    """
    logits, references = pred
    predicted_ids = np.argmax(logits, axis=1)
    return acc_metric.compute(predictions=predicted_ids, references=references)

In [12]:
# Trainer that fits the DistilBERT student on CLINC, with the fine-tuned BERT as teacher.
distilbert_trainer = DistillationTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=student_training_args,
    tokenizer=student_tokenizer,
    train_dataset=clinc_enc["train"],
    eval_dataset=clinc_enc["validation"],
    compute_metrics=compute_metrics,
)

In [13]:
# Baseline validation metrics before training (the student's classifier head is freshly initialised).
baseline_metrics = distilbert_trainer.evaluate()
baseline_metrics

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


[34m[1mwandb[0m: Currently logged in as: [33mtw581[0m. Use [1m`wandb login --relogin`[0m to force relogin


{'eval_loss': 5.023814678192139,
 'eval_accuracy': 0.0064516129032258064,
 'eval_runtime': 37.0674,
 'eval_samples_per_second': 83.631,
 'eval_steps_per_second': 1.754}

In [14]:
# Train the student for the configured number of steps.
train_output = distilbert_trainer.train()
train_output



Step,Training Loss,Validation Loss


TrainOutput(global_step=50, training_loss=5.007057189941406, metrics={'train_runtime': 58.5683, 'train_samples_per_second': 40.978, 'train_steps_per_second': 0.854, 'total_flos': 13261186118880.0, 'train_loss': 5.007057189941406, 'epoch': 0.16})

In [15]:
# Validation metrics after training, for comparison with the baseline above.
post_train_metrics = distilbert_trainer.evaluate()
post_train_metrics

{'eval_loss': 4.98302698135376,
 'eval_accuracy': 0.024516129032258065,
 'eval_runtime': 37.3429,
 'eval_samples_per_second': 83.014,
 'eval_steps_per_second': 1.741,
 'epoch': 0.16}