# 🧠 BERT Text Classification
Fine-tune BERT (`bert-base-uncased`) for text classification on the IMDb sentiment dataset using Hugging Face Transformers.

## 📦 Install dependencies

In [None]:
!pip install transformers datasets scikit-learn torch

## 🔍 Load dataset

In [None]:
from datasets import load_dataset

# IMDb is a binary sentiment task; swap in AG News for a multi-class example.
# Every split is shuffled with a fixed seed so runs are reproducible.
dataset = load_dataset('imdb').shuffle(seed=42)
# Trailing bare expression: the notebook renders the first training example.
dataset['train'][0]

## 🧹 Preprocess text

In [None]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize(batch):
    """Tokenize a batch of rows coming from ``dataset.map(..., batched=True)``.

    Every sequence is truncated/padded to one fixed length (the tokenizer's
    maximum input length, since no ``max_length`` is given), so all rows share
    the same shape. ``padding=True`` would only pad to the longest sequence of
    each ``map`` batch, leaving rows from different map batches with different
    lengths. The Trainer below is not given a tokenizer, so it batches with
    ``default_data_collator``, which stacks tensors without re-padding and
    would fail on such mixed-length batches.

    Note: for short-text datasets this pads more than strictly necessary; the
    alternative is dynamic padding via ``DataCollatorWithPadding``.
    """
    return tokenizer(batch['text'], padding='max_length', truncation=True)

encoded = dataset.map(tokenize, batched=True)
# Keep only the model inputs and the target, returned as PyTorch tensors.
encoded.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

## 🧠 Define model

In [None]:
from transformers import BertForSequenceClassification

# Derive the number of classes (and their names) from the dataset instead of
# hard-coding 2, so switching to a multi-class dataset such as AG News works
# unchanged. Assumes 'label' is a datasets ClassLabel feature (true for imdb
# and ag_news) — confirm if you swap in another dataset.
label_feature = dataset['train'].features['label']
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=label_feature.num_classes,
    id2label=dict(enumerate(label_feature.names)),
    label2id={name: idx for idx, name in enumerate(label_feature.names)},
)

## 🏋️ Fine-tune

In [None]:
from transformers import TrainingArguments, Trainer
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(pred):
    """Compute accuracy and F1 for a Trainer evaluation pass.

    Args:
        pred: the evaluation result handed over by ``Trainer``;
            ``pred.predictions`` holds the logits and ``pred.label_ids`` the
            gold labels.

    Returns:
        dict with ``'accuracy'`` and ``'f1'``. F1 is binary F1 (positive
        class 1) when there are two logits per example and macro-averaged F1
        otherwise, because sklearn's default ``average='binary'`` raises on
        multi-class targets.
    """
    logits = pred.predictions
    # Some models return a tuple of outputs; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    labels = pred.label_ids
    preds = np.argmax(logits, axis=-1)
    average = 'binary' if logits.shape[-1] == 2 else 'macro'
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1_score(labels, preds, average=average),
    }

args = TrainingArguments(
    output_dir='bert-output',
    # `evaluation_strategy` was renamed `eval_strategy` (transformers 4.41);
    # newer releases, which the unpinned install pulls in, reject the old name.
    eval_strategy='epoch',
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=10,
    save_total_limit=1,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded['train'].select(range(1000)),  # use a subset for speed
    eval_dataset=encoded['test'].select(range(500)),
    compute_metrics=compute_metrics,
)

trainer.train()

## 📈 Evaluate

In [None]:
# Run evaluation on the trainer's eval subset, keep the result in `metrics`,
# and show it.
print(metrics := trainer.evaluate())

## 🔮 Predict on new text

In [None]:
import torch  # used below but never imported elsewhere in the notebook

text = "This movie was fantastic!"
model.eval()  # inference mode: disables dropout so predictions are deterministic
# Trainer may have moved the model to a GPU; put the inputs on the same device.
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(model.device)
with torch.no_grad():  # no gradients needed for inference
    outputs = model(**inputs)
# Argmax over the class dimension, then take the single example's label id.
label = torch.argmax(outputs.logits, dim=-1).item()
print("Predicted label:", label)