# Simple PyTorch Text Classification (AG News)

This notebook demonstrates a basic text classification pipeline using PyTorch, a Hugging Face Transformers tokenizer, and the Datasets library. It is designed to run in about 10 minutes on Google Colab and uses a simple custom model rather than a pre-trained one.

## 1. Installation

Install necessary libraries if they are not already available in the Colab environment.
- `datasets`: For loading datasets easily.
- `transformers`: For tokenizers (and potentially pre-trained models, though we use a simple one here).
- `scikit-learn`: For evaluation metrics.
- `accelerate`: Often helpful for optimizing PyTorch training, especially with Transformers (installed for convenience; this notebook does not import it directly).
- `pandas`, `seaborn`, `matplotlib`: For data handling and visualization.

In [None]:
# --- 1. Installation ---
print("Installing required packages...")
!pip install torch datasets transformers scikit-learn accelerate pandas seaborn matplotlib -q
print("Installation complete.")

## 2. Imports

Import libraries needed for the script.

In [None]:
# --- 2. Imports ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import time
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

print("Imports complete.")

## 3. Setup and Configuration

Select the device (GPU if available, otherwise CPU), define the hyperparameters and other configuration values, and start a timer to track total execution time.

In [None]:
# --- 3. Setup and Configuration ---
# Start timer to track execution time.
start_time = time.time()

# Set device to GPU (cuda) if available, otherwise use CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Configuration parameters
MODEL_NAME = "bert-base-uncased" # Using tokenizer from this model
DATASET_NAME = "ag_news"
MAX_LENGTH = 128 # Max sequence length for tokenizer
BATCH_SIZE = 64
EMBED_DIM = 100 # Dimension for our simple embedding layer
LEARNING_RATE = 5e-4 # Learning rate for the optimizer
EPOCHS = 3 # Number of training epochs (keep low for speed)
TRAIN_SUBSET_SIZE = 10000 # Use a subset for faster training
TEST_SUBSET_SIZE = 1000 # Use a subset for faster evaluation

print("Setup complete.")

## 4. Load Dataset

Load the AG News dataset (a standard text classification benchmark with 4 classes: World, Sports, Business, Sci/Tech) using the `datasets` library. For faster execution, we load only the first `TRAIN_SUBSET_SIZE` training examples and the first `TEST_SUBSET_SIZE` test examples.
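
Because a slice such as `train[:N]` takes the first N rows rather than a random sample, it can be worth checking that the subset still covers all four classes reasonably evenly. The following is a minimal sketch of such a check. It uses a short hand-written list of label IDs as a stand-in for `dataset['train']['label']`; the values and the helper name `class_shares` are illustrative assumptions, not real AG News data.

```python
from collections import Counter

# Hypothetical label IDs standing in for dataset['train']['label'].
sample_labels = [2, 3, 3, 0, 1, 2, 0, 1, 3, 2, 0, 1]
label_names = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}

def class_shares(labels):
    """Return the fraction of examples belonging to each class ID."""
    counts = Counter(labels)
    total = len(labels)
    return {cls: counts[cls] / total for cls in sorted(counts)}

shares = class_shares(sample_labels)
for cls, share in shares.items():
    print(f"{label_names[cls]:<9} {share:.2%}")
```

If one class dominated the output, shuffling the split before taking the subset would be a reasonable remedy.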

In [None]:
# --- 4. Load Dataset ---
print(f"Loading dataset '{DATASET_NAME}'...")
dataset = load_dataset(DATASET_NAME, split={
    'train': f'train[:{TRAIN_SUBSET_SIZE}]',
    'test': f'test[:{TEST_SUBSET_SIZE}]'
})

# AG News class labels: 0: World, 1: Sports, 2: Business, 3: Sci/Tech
# Create a mapping for easy reference
label_map = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}
num_classes = len(dataset['train'].unique("label"))
print(f"Dataset loaded. Number of classes: {num_classes}")
print(f"Class labels: {label_map}")

## 5. Explore the Data

Let's look at a few raw examples from the training set to understand the text and corresponding labels.

In [None]:
# --- 5. Explore the Data ---
print("\n--- Sample Training Data ---")
# Display first 5 samples using Pandas DataFrame for better formatting
df_samples = pd.DataFrame(dataset['train'][:5])
df_samples['label_name'] = df_samples['label'].map(label_map)
print(df_samples[['text', 'label', 'label_name']])

## 6. Load Tokenizer

Load a tokenizer from the `transformers` library. We use the tokenizer associated with `bert-base-uncased`, which converts raw text into token IDs that the model can consume. It is a WordPiece tokenizer, so words missing from its vocabulary are split into subword pieces; our simple model learns one embedding vector per token ID, whether that ID represents a whole word or a subword piece.
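
To make the subword idea concrete, here is a toy sketch of greedy longest-match-first splitting, the strategy WordPiece applies to each word. The function name `wordpiece_split` and the tiny vocabulary are assumptions made for illustration; the real tokenizer additionally lowercases text, separates punctuation, and adds special tokens such as `[CLS]` and `[SEP]`.

```python
def wordpiece_split(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of one lowercase word into subwords."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a '##' prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece fits, so the whole word is unknown
        pieces.append(match)
        start = end
    return pieces

# Hypothetical toy vocabulary; the real bert-base-uncased vocabulary has 30,522 entries.
toy_vocab = {"token", "##ization", "##ize", "##s", "play", "##ing"}

print(wordpiece_split("tokenization", toy_vocab))
print(wordpiece_split("playing", toy_vocab))
print(wordpiece_split("xyz", toy_vocab))
```

Each resulting piece maps to one token ID, and therefore to one row of the embedding table defined later.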

In [None]:
# --- 6. Load Tokenizer ---
print(f"\nLoading tokenizer '{MODEL_NAME}'...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
VOCAB_SIZE = tokenizer.vocab_size # Get the size of the tokenizer's vocabulary
print(f"Tokenizer loaded. Vocabulary size: {VOCAB_SIZE}")
print(f"Padding token ID: {tokenizer.pad_token_id}")

## 7. Preprocessing and Tokenization Visualization

Define a function to apply the tokenizer to our text data. This function will:
1.  Tokenize the text.
2.  Pad sequences to `MAX_LENGTH` so they all have the same size.
3.  Truncate sequences longer than `MAX_LENGTH`.

We then apply this function to the entire dataset using `.map()` for efficiency and set the format to PyTorch tensors. Finally, we visualize the tokenization output for one sample.
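
The effect of `padding='max_length'` together with `truncation=True` can be sketched in plain Python. This is a simplified illustration, not the tokenizer's actual implementation: the helper name `pad_or_truncate` is hypothetical, `pad_id=0` is assumed to match the `[PAD]` ID of `bert-base-uncased`, and the real tokenizer keeps the closing `[SEP]` token when it truncates, which this sketch does not.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    kept = token_ids[:max_length]              # truncation: drop anything past max_length
    n_pad = max_length - len(kept)             # number of padding positions needed
    input_ids = kept + [pad_id] * n_pad
    attention_mask = [1] * len(kept) + [0] * n_pad  # 1 = real token, 0 = padding
    return input_ids, attention_mask

# Short sequence: 101 and 102 are the [CLS] and [SEP] IDs; the middle IDs are illustrative.
short_ids, short_mask = pad_or_truncate([101, 2023, 2003, 102], max_length=6)
print(short_ids, short_mask)

# Long sequence: everything past max_length is cut and no padding is added.
long_ids, long_mask = pad_or_truncate(list(range(1, 9)), max_length=6)
print(long_ids, long_mask)
```

Every example ends up with the same length, which is what allows the `DataLoader` to stack examples into rectangular batch tensors.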

In [None]:
# --- 7. Preprocessing ---
# Define a function to tokenize the text data.
def preprocess_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=MAX_LENGTH)

# Apply the preprocessing function to the dataset.
print("\nPreprocessing dataset (tokenizing)...")
encoded_dataset = dataset.map(preprocess_function, batched=True)

# Set the format to PyTorch tensors.
encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

train_dataset = encoded_dataset['train']
test_dataset = encoded_dataset['test']
print("Preprocessing complete.")

# --- Tokenization Visualization ---
print("\n--- Sample Tokenized Data ---")
sample_processed = train_dataset[0]
print(f"Original Text: {dataset['train'][0]['text'][:100]}...")
print(f"Input IDs (sample): {sample_processed['input_ids'][:20]}...") # Show first 20 IDs
print(f"Attention Mask (sample): {sample_processed['attention_mask'][:20]}...") # 1 for real tokens, 0 for padding
# Decode the IDs back to tokens to see the result (convert the tensor slice to a plain list of ints first)
tokens = tokenizer.convert_ids_to_tokens(sample_processed['input_ids'][:20].tolist())
print(f"Tokens (sample): {tokens}")

## 8. Create DataLoaders

Create PyTorch `DataLoader` objects. These efficiently load data in batches during training and evaluation, and can automatically shuffle the training data each epoch.
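
As a quick sanity check on the batch counts printed by the next cell, the expected numbers can be worked out by hand. The sketch below assumes the subset sizes and batch size from the configuration cell (10,000 training examples, 1,000 test examples, batch size 64); `num_batches` is a hypothetical helper, not part of PyTorch.

```python
import math

def num_batches(num_examples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a dataset of known length."""
    if drop_last:
        return num_examples // batch_size       # incomplete final batch is discarded
    return math.ceil(num_examples / batch_size)  # incomplete final batch is kept

train_batches = num_batches(10_000, 64)          # 156 full batches plus one partial batch
test_batches = num_batches(1_000, 64)            # 15 full batches plus one partial batch
last_train_batch = 10_000 - 156 * 64             # size of the final, partial training batch

print(train_batches, test_batches, last_train_batch)
```

Since the `DataLoader` calls in this notebook leave `drop_last` at its default of `False`, the final smaller batch is still used.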

In [None]:
# --- 8. Create DataLoaders ---
print("\nCreating DataLoaders...")
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
print(f"DataLoaders created with batch size {BATCH_SIZE}.")
print(f"Number of batches in train_dataloader: {len(train_dataloader)}")
print(f"Number of batches in test_dataloader: {len(test_dataloader)}")

## 9. Define the Model

Define a simple text classification model using PyTorch's `nn.Module`.
1.  `nn.Embedding`: Converts input token IDs into dense vector representations (embeddings). Setting `padding_idx` keeps the padding token's embedding at zero and excludes it from gradient updates, so padding does not affect learning.
2.  Mean Pooling: Averages the embeddings of all non-padding tokens in a sequence to get a single vector representation for the entire sequence.
3.  `nn.Linear`: A fully connected layer that takes the pooled sequence representation and outputs scores for each of the possible classes.
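
The masked mean pooling in step 2 can be tried on a tiny hand-made tensor before reading the full model. The sketch below uses made-up embedding values and is not the notebook's model. In the actual model the padding rows are already zero because of `padding_idx`, so the mask mainly corrects the denominator; non-zero padding values are used here only to make the difference from a naive mean visible.

```python
import torch

# Two sequences of length 3 with 2-dimensional embeddings (made-up values).
toy_embedded = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],
                             [[2.0, 0.0], [4.0, 2.0], [6.0, 4.0]]])
toy_attention_mask = torch.tensor([[1, 1, 0],   # last position is padding
                                   [1, 1, 1]])

toy_mask = toy_attention_mask.unsqueeze(-1).float()   # shape (batch, seq_len, 1); broadcasts over embed_dim
summed = (toy_embedded * toy_mask).sum(dim=1)         # padded positions contribute zero
counts = toy_mask.sum(dim=1).clamp(min=1e-9)          # number of real tokens per sequence
masked_mean = summed / counts

naive_mean = toy_embedded.mean(dim=1)                 # averages over padding as well

print(masked_mean)
print(naive_mean)
```

For the first sequence the masked mean averages only the two real tokens, while the naive mean is pulled toward the padding values.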

In [None]:
# --- 9. Define the Model ---
class SimpleTextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=tokenizer.pad_token_id)
        # Linear layer
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, batch):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']

        # 1. Get Embeddings
        embedded = self.embedding(input_ids)

        # 2. Apply Masking and Mean Pooling
        mask_expanded = attention_mask.unsqueeze(-1).expand(embedded.size()).float()
        embedded = embedded * mask_expanded # Zero out padding embeddings
        sum_embeddings = torch.sum(embedded, 1)
        sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9) # Non-padding token count; clamp avoids division by zero
        mean_embeddings = sum_embeddings / sum_mask # Calculate mean

        # 3. Pass through Linear Layer
        return self.fc(mean_embeddings)

print("\nModel definition complete.")

## 10. Instantiate Model, Loss, and Optimizer

Create an instance of our `SimpleTextClassifier` model, define the loss function (`CrossEntropyLoss` for multi-class classification), and choose an optimizer (`Adam`). Move the model to the appropriate device (GPU or CPU).
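
The parameter count printed by the next cell can be predicted with simple arithmetic. The sketch below assumes a vocabulary of 30,522 tokens (the size of the `bert-base-uncased` vocabulary), `EMBED_DIM = 100`, and 4 classes; `count_parameters` is a hypothetical helper written for this illustration.

```python
def count_parameters(vocab_size, embed_dim, num_classes):
    """Parameters of an Embedding(vocab_size, embed_dim) plus a Linear(embed_dim, num_classes)."""
    embedding = vocab_size * embed_dim                  # one vector per token ID (padding row included)
    linear = embed_dim * num_classes + num_classes      # weight matrix plus one bias per class
    return embedding, linear, embedding + linear

emb, lin, total = count_parameters(vocab_size=30_522, embed_dim=100, num_classes=4)
print(f"embedding: {emb:,}  linear: {lin:,}  total: {total:,}")
```

Almost all of the parameters sit in the embedding table, which is typical for a bag-of-embeddings model with a large vocabulary.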

In [None]:
# --- 10. Instantiate Model, Loss, and Optimizer ---
print("\nInstantiating model, loss function, and optimizer...")
model = SimpleTextClassifier(VOCAB_SIZE, EMBED_DIM, num_classes).to(device)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Model, Loss, Optimizer instantiated.")

# Print model structure and parameter count
print("\nModel Architecture:")
print(model)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal trainable parameters: {total_params:,}")

## 11. Training Loop

Train the model for the specified number of epochs. In each epoch:
1.  Iterate through batches provided by the `train_dataloader`.
2.  Move the batch data to the device.
3.  Clear previous gradients (`optimizer.zero_grad()`).
4.  Perform a forward pass to get model predictions.
5.  Calculate the loss between predictions and true labels.
6.  Perform a backward pass to compute gradients (`loss.backward()`).
7.  Update the model weights using the optimizer (`optimizer.step()`).
8.  Track and report the loss.
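
The core of steps 3 to 7 can be seen in isolation on synthetic data. This is a toy sketch under stated assumptions: random feature vectors stand in for pooled sentence embeddings, a single `nn.Linear` stands in for the classifier, and plain SGD replaces Adam to keep the toy run predictable. All names prefixed with `toy_` are hypothetical and chosen so they do not overwrite the notebook's own variables.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-ins: 8 feature vectors of size 5, with labels in {0, 1, 2, 3}.
toy_features = torch.randn(8, 5)
toy_labels = torch.randint(0, 4, (8,))

toy_model = nn.Linear(5, 4)
toy_criterion = nn.CrossEntropyLoss()
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.1)

weights_before = toy_model.weight.detach().clone()
toy_losses = []
for step in range(20):
    toy_optimizer.zero_grad()                      # clear gradients from the previous step
    logits = toy_model(toy_features)               # forward pass: one score per class
    loss = toy_criterion(logits, toy_labels)       # compare scores with the true labels
    loss.backward()                                # backward pass: compute gradients
    toy_optimizer.step()                           # update the weights
    toy_losses.append(loss.item())

print(f"first loss: {toy_losses[0]:.4f}, last loss: {toy_losses[-1]:.4f}")
```

Skipping `zero_grad()` would make gradients accumulate across steps, which is why it appears at the top of every iteration.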

In [None]:
# --- 11. Training Loop ---
print("\n--- Starting Training ---")
model.train() # Set the model to training mode

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    epoch_loss = 0
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    for i, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        outputs = model(batch)
        loss = criterion(outputs, batch['label'])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        if (i + 1) % 50 == 0:
            print(f"  Batch {i+1}/{len(train_dataloader)}, Batch Loss: {loss.item():.4f}")

    epoch_end_time = time.time()
    avg_epoch_loss = epoch_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} completed in {epoch_end_time - epoch_start_time:.2f} seconds.")
    print(f"Average Training Loss: {avg_epoch_loss:.4f}")

print("\n--- Training Finished ---")

## 12. Evaluation

Evaluate the trained model on the test dataset.
1.  Set the model to evaluation mode (`model.eval()`).
2.  Disable gradient calculations (`torch.no_grad()`) to save memory and computation.
3.  Iterate through the `test_dataloader`.
4.  Get model predictions for each batch.
5.  Collect all predictions and true labels.
6.  Calculate and display evaluation metrics (accuracy, classification report, confusion matrix).
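
To clarify what the metrics in step 6 measure, here is a plain-Python sketch of a confusion matrix and the accuracy derived from it. It follows the same row and column convention as scikit-learn's `confusion_matrix` (rows are true classes, columns are predicted classes). The labels and predictions are made up, and the helper names are hypothetical.

```python
def confusion_counts(y_true, y_pred, num_classes):
    """Build a matrix where rows index the true class and columns the predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix

def accuracy_from_matrix(matrix):
    """Correct predictions sit on the diagonal of the confusion matrix."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total

# Made-up labels and predictions for illustration only.
toy_true = [0, 0, 1, 1, 2, 2, 3, 3]
toy_pred = [0, 2, 1, 1, 2, 3, 3, 3]

toy_matrix = confusion_counts(toy_true, toy_pred, num_classes=4)
toy_accuracy = accuracy_from_matrix(toy_matrix)

for row in toy_matrix:
    print(row)
print(f"accuracy: {toy_accuracy:.2f}")
```

Off-diagonal entries show which classes get confused with each other, which a single accuracy number cannot reveal.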

In [None]:
# --- 12. Evaluation ---
print("\n--- Starting Evaluation ---")
model.eval() # Set the model to evaluation mode

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(batch)
        predictions = torch.argmax(outputs, dim=1)
        all_preds.extend(predictions.cpu().numpy())
        all_labels.extend(batch['label'].cpu().numpy())

# Calculate metrics
accuracy = accuracy_score(all_labels, all_preds)
report = classification_report(all_labels, all_preds, target_names=list(label_map.values()), digits=4)
conf_matrix = confusion_matrix(all_labels, all_preds)

print("\n--- Evaluation Results ---")
print(f"Test Accuracy: {accuracy:.4f}")
print("\nClassification Report:")
print(report)

# Plot Confusion Matrix
print("\nConfusion Matrix:")
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=list(label_map.values()), yticklabels=list(label_map.values()))
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()

print("\n--- Evaluation Finished ---")

## 13. Final Timings

Calculate and display the total execution time for the notebook.

In [None]:
# --- 13. Final Timings ---
end_time = time.time()
total_time = end_time - start_time
print(f"\nTotal execution time: {total_time:.2f} seconds ({total_time/60:.2f} minutes).")

## 14. Optional: Example Prediction

Demonstrate how to use the trained model to classify a single, new sentence.
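
The model outputs raw scores (logits), and `argmax` only reports which score is largest. If a rough confidence value is also wanted, the scores can be turned into probabilities with a softmax. The sketch below does this in plain Python on made-up logits; for tensors, `torch.softmax(output, dim=1)` performs the same computation.

```python
import math

def softmax(scores):
    """Convert raw class scores (logits) into probabilities that sum to 1."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]   # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits for one sentence, ordered World, Sports, Business, Sci/Tech.
toy_logits = [0.3, -1.2, 2.1, 0.4]
toy_probs = softmax(toy_logits)
toy_best = max(range(len(toy_probs)), key=toy_probs.__getitem__)

print([round(p, 3) for p in toy_probs])
print(f"predicted class index: {toy_best}")
```

Softmax does not change which class wins, so the predicted index matches what `argmax` on the raw logits would return.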

In [None]:
# --- 14. Optional: Example Prediction ---
print("\n--- Example Prediction ---")
text = "The government announced new economic policies today." # Example sentence
model.eval() # Ensure model is in evaluation mode
with torch.no_grad(): # No need to track gradients for prediction
    # 1. Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=MAX_LENGTH)
    
    # 2. Move inputs to the correct device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # 3. Create the batch structure expected by the model
    batch = {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask']}
    
    # 4. Get model output (scores)
    output = model(batch)
    
    # 5. Get predicted class index
    prediction_idx = torch.argmax(output, dim=1).item()
    
    # 6. Map index to label name
    predicted_label = label_map.get(prediction_idx, 'Unknown')

    print(f"Sentence: '{text}'")
    print(f"Predicted class index: {prediction_idx}")
    print(f"Predicted label: {predicted_label}")
print("\n--- End of Notebook ---")