In [1]:
pip install datasets

Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.metrics import accuracy_score, classification_report
from datasets import load_dataset
import numpy as np
import nltk
from nltk.tokenize import word_tokenize

# Download NLTK data
nltk.download('punkt')

# Load the AG News dataset
dataset = load_dataset("ag_news")

# NLTK Tokenizer Function
def nltk_tokenizer(text):
    return word_tokenize(text)

# Convert to PyTorch tensors
class AGNewsDataset(Dataset):
    def __init__(self, texts, labels, vocab, max_length):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        tokenized_text = [self.vocab.get(token, self.vocab['<UNK>']) for token in text]
        if len(tokenized_text) < self.max_length:
            tokenized_text += [self.vocab['<PAD>']] * (self.max_length - len(tokenized_text))
        else:
            tokenized_text = tokenized_text[:self.max_length]
        return torch.tensor(tokenized_text, dtype=torch.long), torch.tensor(label, dtype=torch.long)

# Build vocabulary
def build_vocab(dataset, tokenizer):
    vocab = {'<PAD>': 0, '<UNK>': 1}
    for example in dataset:
        tokens = tokenizer(example['text'])
        for token in tokens:
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab

# Tokenize and build vocab
train_texts = [nltk_tokenizer(example['text']) for example in dataset['train']]
train_labels = [example['label'] for example in dataset['train']]
vocab = build_vocab(dataset['train'], nltk_tokenizer)

# Set max length for padding
max_length = 128

# Create dataset
full_dataset = AGNewsDataset(train_texts, train_labels, vocab, max_length)

# Split training set into training and validation sets
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Prepare test dataset
test_texts = [nltk_tokenizer(example['text']) for example in dataset['test']]
test_labels = [example['label'] for example in dataset['test']]
test_dataset = AGNewsDataset(test_texts, test_labels, vocab, max_length)


[nltk_data] Downloading package punkt to /home/IAIS/rrao/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
class RNNModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super(RNNModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = self.embedding(x)
        output, hidden = self.rnn(x)
        return self.fc(hidden[-1])


In [3]:
# Training Parameters
BATCH_SIZE = 64
EPOCHS = 10
EMBED_DIM = 128
HIDDEN_DIM = 256
OUTPUT_DIM = 4
LR = 0.001

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model, loss function, and optimizer
model = RNNModel(len(vocab), EMBED_DIM, HIDDEN_DIM, OUTPUT_DIM).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)



In [4]:
# Training loop
for epoch in range(EPOCHS):
    model.train()
    for texts, labels in train_loader:
        texts, labels = texts.to(device), labels.to(device)
        
        outputs = model(texts)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_labels = []
    val_preds = []
    with torch.no_grad():
        for texts, labels in val_loader:
            texts, labels = texts.to(device), labels.to(device)
            outputs = model(texts)
            _, preds = torch.max(outputs, 1)
            val_labels.extend(labels.cpu().numpy())
            val_preds.extend(preds.cpu().numpy())
    
    val_accuracy = accuracy_score(val_labels, val_preds)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Validation Accuracy: {val_accuracy:.4f}')

# Final evaluation on test set
model.eval()
test_labels = []
test_preds = []
with torch.no_grad():
    for texts, labels in test_loader:
        texts, labels = texts.to(device), labels.to(device)
        outputs = model(texts)
        _, preds = torch.max(outputs, 1)
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())

overall_accuracy = accuracy_score(test_labels, test_preds)
class_report = classification_report(test_labels, test_preds, target_names=['World', 'Sports', 'Business', 'Sci/Tech'])

print(f'Test Accuracy: {overall_accuracy:.4f}')
print('Classification Report:')
print(class_report)


Epoch 1/10, Validation Accuracy: 0.2477
Epoch 2/10, Validation Accuracy: 0.2405
Epoch 3/10, Validation Accuracy: 0.2455
Epoch 4/10, Validation Accuracy: 0.2500
Epoch 5/10, Validation Accuracy: 0.2359
Epoch 6/10, Validation Accuracy: 0.2547
Epoch 7/10, Validation Accuracy: 0.2642
Epoch 8/10, Validation Accuracy: 0.2630
Epoch 9/10, Validation Accuracy: 0.2533
Epoch 10/10, Validation Accuracy: 0.2387
Test Accuracy: 0.2330
Classification Report:
              precision    recall  f1-score   support

       World       0.00      0.00      0.00      1900
      Sports       0.24      0.49      0.32      1900
    Business       0.00      0.00      0.00      1900
    Sci/Tech       0.22      0.45      0.30      1900

    accuracy                           0.23      7600
   macro avg       0.12      0.23      0.16      7600
weighted avg       0.12      0.23      0.16      7600



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
