# Fine-Tuning BERT, DistilBERT, MobileBERT and TinyBERT for Fake News Detection

## 1. Load the Fake News Dataset

In [None]:
# Install / upgrade the notebook's third-party dependencies (IPython shell magics).
# NOTE(review): bertviz, umap-learn and seaborn are not used in any of the visible cells.
!pip install -U transformers
!pip install -U accelerate
!pip install -U datasets
!pip install -U bertviz
!pip install -U umap-learn
!pip install seaborn --upgrade

# Excel engine used by pandas when reading the .xlsx dataset below.
!pip install -U openpyxl

In [None]:
# Suppress every Python warning for the rest of the session to keep cell output readable.
import warnings
warnings.simplefilter("ignore")

In [None]:
import pandas as pd

# Fake-news dataset (.xlsx) hosted on GitHub; read straight from the URL into a DataFrame.
DATA_URL = "https://github.com/laxmimerit/All-CSV-ML-Data-Files-Download/raw/master/fake_news.xlsx"
df = pd.read_excel(DATA_URL)

In [None]:
# Preview the raw dataframe (rendered because it is the cell's last expression).
df

In [None]:
# Report missing values per column, drop incomplete rows, then confirm none remain.
# Only a cell's last expression is rendered, so the "before" counts are printed
# explicitly; otherwise they would be computed and silently discarded.
print(df.isnull().sum())
df = df.dropna()

df.isnull().sum()

In [None]:
# Inspect the dataframe after rows with missing values were dropped.
df

In [None]:
# Print the shape explicitly (a non-final expression in a cell is not displayed),
# then show how many rows carry each label.
print(df.shape)

df['label'].value_counts()

## 2. Dataset Analysis

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Horizontal bar chart of the number of rows per label, smallest class first.
label_counts = df['label'].value_counts(ascending=True)
bar_ax = label_counts.plot.barh()
bar_ax.set_title("Frequency of Classes")
plt.show()

In [None]:
# Rough token-count estimate per row: ~1.5 sub-word tokens per whitespace-separated word.
# Vectorised pandas .str methods replace the per-row lambda. astype(str) keeps the
# estimate from raising AttributeError when a cell holds a non-string value
# (e.g. a purely numeric title that read_excel loaded as a number).
TOKENS_PER_WORD = 1.5
df['title_tokens'] = df['title'].astype(str).str.split().str.len() * TOKENS_PER_WORD
df['text_tokens'] = df['text'].astype(str).str.split().str.len() * TOKENS_PER_WORD

# Side-by-side histograms of the estimated title and body lengths.
fig, ax = plt.subplots(1, 2, figsize=(15, 5))

ax[0].hist(df['title_tokens'], bins=50, color='skyblue')
ax[0].set_title("Title Tokens")

ax[1].hist(df['text_tokens'], bins=50, color='orange')
ax[1].set_title("Text Tokens")

plt.show()

## 3. Train / Test / Validation Split and Hugging Face `DatasetDict`

In [None]:
from sklearn.model_selection import train_test_split

# 70% train / 20% test / 10% validation: hold out 30%, then give one third of that
# hold-out (10% of the total) to validation. Stratifying on the label keeps the class
# ratio equal across splits. A fixed random_state makes the split, and therefore the
# model benchmark later in the notebook, reproducible across runs.
RANDOM_STATE = 42
train, test = train_test_split(
    df, test_size=0.3, stratify=df['label'], random_state=RANDOM_STATE
)
test, validation = train_test_split(
    test, test_size=1/3, stratify=test['label'], random_state=RANDOM_STATE
)

train.shape, test.shape, validation.shape, df.shape

In [None]:
from datasets import Dataset, DatasetDict

# Wrap each pandas split as a Hugging Face Dataset. preserve_index=False keeps the
# (shuffled) pandas index from being stored as an extra column.
pandas_splits = {"train": train, "test": test, "validation": validation}
dataset = DatasetDict({
    split_name: Dataset.from_pandas(split_df, preserve_index=False)
    for split_name, split_df in pandas_splits.items()
})

dataset

## 4. Data Tokenization using DistilBERT


In [None]:
from transformers import AutoTokenizer

# Load DistilBERT's tokenizer, then look at how it splits a sample sentence.
model_ckpt = "distilbert-base-uncased"
distilbert_tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

text = "We are trying to tokenize this text :)"
distilbert_tokens = distilbert_tokenizer.tokenize(text)


In [None]:
# Sub-word pieces produced for the sample sentence.
distilbert_tokens

In [None]:
# Display the tokenizer object and its configuration.
distilbert_tokenizer

In [None]:
def tokenize(batch):
    """Encode the ``title`` column of a batched ``datasets`` example dict.

    Uses ``distilbert_tokenizer`` with ``padding=True`` (pad to the longest title
    in the batch) and ``truncation=True``. Returns the tokenizer's encoding
    (e.g. ``input_ids`` and ``attention_mask``), which ``Dataset.map`` merges
    into the dataset as new columns.
    """
    return distilbert_tokenizer(batch['title'], padding=True, truncation=True)


# Sanity check on the first two training rows.
print(tokenize(dataset['train'][:2]))

In [None]:
# Tokenise every split. batch_size=None passes each whole split to `tokenize` as one
# batch, so padding=True pads every example to the longest title in its split.
encoded_dataset = dataset.map(tokenize, batch_size=None, batched=True)

In [None]:
# Show the splits together with the new tokenizer columns.
encoded_dataset

## 5. Build the DistilBERT Model

In [None]:
from transformers import AutoModelForSequenceClassification, AutoConfig
import torch

# Label <-> id mapping stored in the model config so predictions decode to names.
label2id = {"Real": 0, "Fake": 1}
id2label = {idx: name for name, idx in label2id.items()}
num_labels = len(label2id)

# Swap in "google/mobilebert-uncased" or "huawei-noah/TinyBERT_General_4L_312D"
# to fine-tune one of the smaller models instead.
model_ckpt = "distilbert-base-uncased"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = AutoConfig.from_pretrained(model_ckpt, label2id=label2id, id2label=id2label)
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, config=config).to(device)


In [None]:
# Inspect the model configuration, including the label2id / id2label mapping set above.
model.config

## 6. Fine-tune the model

In [None]:
# The `evaluate` library provides the accuracy metric used by the Trainer below.
!pip install evaluate


In [None]:
import evaluate
import numpy as np

# Accuracy metric loaded through the `evaluate` library.
accuracy = evaluate.load("accuracy")


def compute_metrics_evaluate(eval_pred):
    """Trainer metric hook: convert logits to class ids and score them with `accuracy`.

    ``eval_pred`` unpacks to ``(predictions, label_ids)``; the arg-max over axis 1
    picks the predicted class for each example.
    """
    logits, label_ids = eval_pred
    predicted_ids = np.argmax(logits, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=label_ids)

In [None]:
import os

# Also disable Weights & Biases logging through its environment variable.
os.environ.update(WANDB_DISABLED="true")

In [None]:
from transformers import TrainingArguments

batch_size = 32
training_dir = "train_dir"

# Hyper-parameters for the fine-tuning run. Evaluation happens once per epoch, and
# report_to="none" switches off all experiment-tracker integrations.
training_args = TrainingArguments(
    output_dir=training_dir,
    overwrite_output_dir=True,
    report_to="none",
    eval_strategy="epoch",
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
)

In [None]:
from transformers import Trainer

# `args=training_args` is required: without it Trainer silently builds a default
# TrainingArguments (output_dir="tmp_trainer", library-default epochs, learning rate
# and batch size, no per-epoch evaluation) and ignores the configuration above.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics_evaluate,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['validation'],
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
    # `processing_class=`; switch once the installed version supports it.
    tokenizer=distilbert_tokenizer,
)



In [None]:
# Fine-tune the model on the train split, evaluating on the validation split.
trainer.train()

In [None]:
# Persist the fine-tuned model under a Google Drive path.
# NOTE(review): no visible cell mounts Google Drive (drive.mount('/content/drive'));
# confirm it is mounted first, otherwise the save may not reach Drive.
trainer.save_model("/content/drive/MyDrive/llm_finetuning_transformers/Fake_News_Detection/distilbert-base-uncased-news-detection-model")

## 7. Evaluate the model

In [None]:
# Run inference on the held-out test split; the result carries predictions, label_ids and metrics.
preds_output = trainer.predict(encoded_dataset['test'])


In [None]:
# Test-set loss / accuracy / runtime as reported by Trainer.predict.
preds_output.metrics

In [None]:
# Predicted class = arg-max over the logits. The gold labels come from the same
# predict() call (label_ids), so they are aligned with the predictions. This avoids
# re-materialising the entire tokenised test split with [:] just to read one column.
y_pred = np.argmax(preds_output.predictions, axis=1)
y_true = preds_output.label_ids

In [None]:
from sklearn.metrics import classification_report

# Derive class ids and names from id2label so every name is paired with its id
# regardless of dict ordering. Passing `labels` explicitly also keeps the report
# valid when a class is absent from both y_true and y_pred; without it sklearn
# raises because target_names would outnumber the classes present.
class_ids = sorted(id2label)
print(classification_report(
    y_true,
    y_pred,
    labels=class_ids,
    target_names=[id2label[i] for i in class_ids],
))

## 8. Benchmarking (BERT, DistilBERT, MobileBERT, TinyBERT)

In [None]:
from sklearn.metrics import accuracy_score, f1_score


def compute_metrics(pred):
    """Trainer metric hook returning accuracy and class-weighted F1.

    ``pred.predictions`` holds the logits, and their arg-max over the last axis
    gives the predicted class. ``pred.label_ids`` holds the gold labels.
    """
    true_labels = pred.label_ids
    predicted = pred.predictions.argmax(-1)
    return {
        "accuracy": accuracy_score(true_labels, predicted),
        "f1": f1_score(true_labels, predicted, average="weighted"),
    }

In [None]:
# Short name -> Hugging Face checkpoint for each model in the benchmark.
model_dict = {
    "bert-base": "bert-base-uncased",
    "distilbert": "distilbert-base-uncased",
    "mobilebert": "google/mobilebert-uncased",
    "tinybert": "huawei-noah/TinyBERT_General_4L_312D"
}


def train_model(model_name, args=None):
    """Fine-tune one checkpoint from ``model_dict`` and return its test-set metrics.

    The ``title`` column of the global ``dataset`` is tokenised with the
    checkpoint's own tokenizer. The model is trained on the train split, evaluated
    on the validation split, and used to predict the test split.

    Args:
        model_name: key into ``model_dict``.
        args: TrainingArguments for the run. Defaults to the global
            ``training_args`` so every benchmarked model shares the same
            hyper-parameters.

    Returns:
        ``trainer.predict(...).metrics`` for the test split (keys such as
        ``test_accuracy`` / ``test_f1`` / ``test_runtime``).
    """
    if args is None:
        # Without explicit args Trainer falls back to library defaults
        # (output_dir="tmp_trainer", default epochs/lr/batch size, no eval),
        # so the benchmark would not use the configured hyper-parameters.
        args = training_args

    model_ckpt = model_dict[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    config = AutoConfig.from_pretrained(model_ckpt, label2id=label2id, id2label=id2label)
    model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, config=config).to(device)

    def local_tokenizer(batch):
        # Same preprocessing as `tokenize`, but with this checkpoint's vocabulary.
        return tokenizer(batch['title'], padding=True, truncation=True)

    encoded_dataset = dataset.map(local_tokenizer, batched=True, batch_size=None)

    trainer = Trainer(
        model=model,
        args=args,
        compute_metrics=compute_metrics,
        train_dataset=encoded_dataset['train'],
        eval_dataset=encoded_dataset['validation'],
        tokenizer=tokenizer,
    )
    trainer.train()

    preds = trainer.predict(encoded_dataset['test'])
    return preds.metrics


import time

# Train and time every checkpoint. time.perf_counter() is a monotonic,
# high-resolution clock intended for measuring durations. time.time() follows
# (and can jump with) system wall-clock adjustments.
model_performance = {}
for model_name in model_dict:
    print("\n\n")
    print("Training Model: ", model_name)

    start = time.perf_counter()
    result = train_model(model_name)
    end = time.perf_counter()

    # Layout {name: test metrics, "time taken": seconds} is what the plotting cell reads.
    model_performance[model_name] = {model_name: result, "time taken": end - start}

In [None]:
# Raw benchmark results: per model, its test metrics plus the measured time taken.
model_performance

In [None]:
# Flatten the per-model results into one row per model.
# The comprehension variable is `name`, not `model`, so this cell no longer rebinds
# the global `model` (the fine-tuned DistilBERT) to a string.
rows = [
    {
        "Model": name,
        "Accuracy": results[name]['test_accuracy'],
        "F1 Score": results[name]['test_f1'],
        "Runtime (s)": results[name]['test_runtime'],
        "Training Time (s)": results['time taken'],
    }
    for name, results in model_performance.items()
]

# NOTE(review): this rebinds `df`, replacing the fake-news data frame; the zoomed
# line-chart cell that follows reads this results frame as `df`.
df = pd.DataFrame(rows)

# Three bar charts side by side: scores, test-set prediction runtime, training time.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

df.plot(x="Model", y=["Accuracy", "F1 Score"], kind="bar", ax=axes[0])
axes[0].set_title("Accuracy vs F1")
axes[0].set_ylabel("Score")

# test_runtime reported by trainer.predict on the test split.
df.plot(x="Model", y="Runtime (s)", kind="bar", color="orange", ax=axes[1], legend=False)
axes[1].set_title("Runtime Comparison")
axes[1].set_ylabel("Runtime (s)")

# Time measured around train_model: loading, tokenising, training and predicting.
df.plot(x="Model", y="Training Time (s)", kind="bar", color="green", ax=axes[2], legend=False)
axes[2].set_title("Training Time Comparison")
axes[2].set_ylabel("Training Time (s)")

# Same dashed horizontal grid on every panel (applied after pandas has drawn them).
for axis in axes:
    axis.grid(axis="y", linestyle="--", alpha=0.7)

plt.tight_layout()
plt.show()




In [None]:
# Zoomed line chart of Accuracy and F1 per model. The y-limits are derived from the
# data plus a small margin. A fixed 0.94-0.97 window would silently hide any score
# that falls outside it.
score_cols = ["Accuracy", "F1 Score"]
score_min = df[score_cols].min().min()
score_max = df[score_cols].max().max()
# Pad by 10% of the spread, with a floor so identical scores still get a visible band.
pad = max((score_max - score_min) * 0.1, 0.005)

plt.figure(figsize=(8, 5))
for col in score_cols:
    plt.plot(df["Model"], df[col], marker="o", label=col)

plt.ylim(score_min - pad, score_max + pad)
plt.title("Accuracy vs F1 (Zoomed)")
plt.ylabel("Score")
plt.grid(True, linestyle="--", alpha=0.7)
plt.legend()
plt.show()