In [None]:
import numpy as np
import torch
from datasets import Dataset, load_metric
from transformers import (
    DataCollatorWithPadding,
    DistilBertForSequenceClassification,
    DistilBertTokenizerFast,
    Trainer,
    TrainingArguments,
)

In [None]:
# Wrap each raw split as a Hugging Face Dataset with "text" and "label" columns.
# NOTE(review): train_/valid_/test_ texts and labels must already be defined
# (they are not created anywhere in this notebook section).
train_ds = Dataset.from_dict(dict(text=train_texts, label=train_labels))
valid_ds = Dataset.from_dict(dict(text=valid_texts, label=valid_labels))
test_ds = Dataset.from_dict(dict(text=test_texts, label=test_labels))

In [None]:
# Tokenizing: load the tokenizer for the same checkpoint the model cell uses
# ("distilbert-base-uncased"; the earlier id "distil-bert-uncased" is not a
# valid checkpoint name) and tokenize every split.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")


def tokenize(batch):
    """Tokenize a batch of examples for DistilBERT.

    Args:
        batch: mapping with a "text" key holding a list of strings, as passed
            by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer's encoding for the batch, with truncation enabled.
    """
    # padding=True pads only to the longest example inside each map batch, so
    # lengths can still differ between map batches; use a padding-aware data
    # collator (e.g. DataCollatorWithPadding) when forming training batches.
    return tokenizer(batch["text"], padding=True, truncation=True)


# batched=True passes many rows to `tokenize` per call, which is much faster
# than tokenizing one example at a time.
train_ds = train_ds.map(tokenize, batched=True)
valid_ds = valid_ds.map(tokenize, batched=True)
test_ds = test_ds.map(tokenize, batched=True)

In [None]:
# Loading the model: DistilBERT with a sequence-classification head, placed on
# the GPU when one is available.
# num_labels is derived from the training labels instead of relying on the
# config default (2), so multi-class data gets a head of the right size;
# binary 0/1 labels still yield 2.
# NOTE(review): assumes labels are integer class ids 0..K-1 — confirm.
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=int(max(train_labels)) + 1,
).to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Accuracy metric used by the Trainer during evaluation.
# NOTE(review): datasets.load_metric is deprecated (removed in datasets 3.0);
# evaluate.load("accuracy") from the `evaluate` library is the replacement.
metric = load_metric("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy from the Trainer's evaluation outputs.

    Args:
        eval_pred: (logits, labels) pair supplied by ``Trainer`` at evaluation time.

    Returns:
        The result dict produced by ``metric.compute`` (an "accuracy" entry
        for this metric).
    """
    logits, labels = eval_pred
    # Predicted class = index of the largest logit along the last (class) axis.
    predictions = np.argmax(logits, axis=-1)
    # The keyword is `references` (plural); `reference=` is not a valid argument.
    return metric.compute(predictions=predictions, references=labels)

In [None]:
from transformers import TrainingArguments

# Training configuration: 3 epochs at 16 examples per device, evaluating and
# checkpointing once per epoch, logging every 100 steps.
# NOTE(review): recent transformers releases renamed `evaluation_strategy` to
# `eval_strategy`; keep whichever spelling the installed version accepts.
training_args = TrainingArguments(
    output_dir="./results",          # checkpoints and outputs go here
    num_train_epochs=3,
    per_device_train_batch_size=16,  # per GPU/CPU device
    logging_steps=100,               # log training metrics every 100 steps
    evaluation_strategy="epoch",     # evaluate after every epoch ...
    save_strategy="epoch",           # ... and checkpoint on the same schedule
)


In [None]:
# Build the Trainer. An explicit DataCollatorWithPadding pads each batch to its
# longest sequence; without it (and with no tokenizer passed in) the default
# collator only stacks features, which can fail when tokenized lengths differ
# between examples.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)

# Train; valid_ds is evaluated on the schedule set in training_args.
trainer.train()