In [1]:
import torch

# Check if MPS (Metal Performance Shaders) is available and set the device accordingly
device = torch.device("mps" if torch.backends.mps.is_built() else "cpu")
print("Using device:", device)

Using device: mps


### Note
In last course, we used a pre-trained model `distilbert-base-uncased-finetuned-sst-2-english` which was fine-tuned on Stanford Sentiment Treebank dataset. So it has a good generalization on binary classification. In this course, we use a BERT before fine-tune to learn the process. 

In [2]:
from datasets import load_dataset

dataset = load_dataset("imdb")
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


In [3]:
from transformers import AutoTokenizer

model_id = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_id)

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

encoded_dataset = dataset.map(preprocess_function, batched=True)

In [4]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [5]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=50,
    push_to_hub=False,
    dataloader_pin_memory=False,
)

In [7]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

  trainer = Trainer(


In [8]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.6949,0.693148


TrainOutput(global_step=12500, training_loss=0.6971036328125, metrics={'train_runtime': 81821.5351, 'train_samples_per_second': 0.306, 'train_steps_per_second': 0.153, 'total_flos': 3311684966400000.0, 'train_loss': 0.6971036328125, 'epoch': 1.0})

In [22]:
rawModel = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2).to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
test_sentences = ["I hate and fuck this movie!", "This was a terrible film."]

# Tokenize
inputs = tokenizer(test_sentences, return_tensors="pt", padding=True, truncation=True)

# Move to same device as model
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Predict
trainedModel = trainer.model
outputs1 = model(**inputs)
outputs2 = rawModel(**inputs)

print("Outputs from the trained model:", torch.softmax(outputs1.logits, dim=-1))
print("Outputs from the raw model:", torch.softmax(outputs2.logits, dim=-1))

Outputs from the trained model: tensor([[0.4993, 0.5007],
        [0.4993, 0.5007]], device='mps:0', grad_fn=<SoftmaxBackward0>)
Outputs from the raw model: tensor([[0.5018, 0.4982],
        [0.4969, 0.5031]], device='mps:0', grad_fn=<SoftmaxBackward0>)
