# Imports

In [24]:
!pip install datasets transformers --quiet



[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import BertTokenizerFast

# Check device


In [2]:
# Prefer the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Device:", device)


Device: cuda


# Load AG News Dataset

In [3]:
# Load the AG News dataset by its "ag_news" identifier; the result is indexed
# by split name ('train' / 'test') in the cells below.
dataset = load_dataset("ag_news")


In [4]:
# Pull the two splits out of the loaded dataset, in train/test order.
train_set, test_set = (dataset[split_name] for split_name in ('train', 'test'))

In [5]:
# Report how many examples each split holds, then show the first raw training record.
n_train, n_test = len(train_set), len(test_set)
print(f"Train size: {n_train}, Test size: {n_test}")
print("Example:", train_set[0])

Train size: 120000, Test size: 7600
Example: {'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}


In [6]:
# (rows, columns) of the training split.
train_set.shape

(120000, 2)

In [7]:
# Names of the columns available in the training split.
train_set.column_names

['text', 'label']

# Tokenizer

In [8]:
# Load the pretrained fast tokenizer for the 'bert-base-uncased' checkpoint;
# `tokenize` below reads this module-level object.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Tokenization + Padding + Truncation


In [9]:
def tokenize(batch, max_length=20):
    """Encode a batch of raw texts into fixed-length tokenizer outputs.

    Written for ``Dataset.map(tokenize, batched=True)``. Every text is padded
    or truncated to exactly ``max_length`` tokens, so all rows have the same
    length and can be stacked into tensors.

    Args:
        batch: Mapping with a ``'text'`` entry holding the strings to encode.
        max_length: Length each encoded sequence is padded/truncated to.
            Defaults to 20, the value that used to be hard-coded. Override it
            when mapping with ``fn_kwargs={'max_length': ...}``.

    Returns:
        The output of the module-level ``tokenizer`` for ``batch['text']``.
    """
    return tokenizer(
        batch['text'],
        padding='max_length',   # pad short texts up to max_length
        truncation=True,        # cut long texts down to max_length
        max_length=max_length,
    )

## Tokenize both splits with `Dataset.map`

In [10]:
# Run the batched `tokenize` function over each split; `map` returns the
# processed dataset, leaving `train_set` / `test_set` themselves untouched.
train_dataset, test_dataset = (
    split.map(tokenize, batched=True) for split in (train_set, test_set)
)


# Set Format for PyTorch

In [11]:
# Expose only the model inputs and the label, returned as PyTorch tensors.
# Each dataset gets its own list so the two formats share no mutable state.
TORCH_COLUMNS = ('input_ids', 'attention_mask', 'label')
for tokenized_split in (train_dataset, test_dataset):
    tokenized_split.set_format(type='torch', columns=list(TORCH_COLUMNS))

# DataLoader


In [12]:
BATCH_SIZE = 32
SEED = 42  # fixes the training shuffle order so sampled batches are reproducible

# A dedicated generator seeds only this loader's shuffling and leaves the
# global RNG alone. Re-running this cell rebuilds it, so the first batch drawn
# after a re-run is always the same.
shuffle_generator = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=shuffle_generator,
)
# The test loader is not shuffled, so its order is already fixed.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Check a batch


In [13]:
# Draw a single mini-batch from the training loader and display it.
batch_iterator = iter(train_loader)
batch = next(batch_iterator)
batch

{'label': tensor([0, 2, 3, 3, 2, 2, 3, 2, 3, 0, 1, 1, 1, 0, 1, 0, 2, 3, 0, 3, 0, 1, 0, 0,
         0, 1, 3, 1, 0, 1, 2, 0]),
 'input_ids': tensor([[  101,  2845,  4946, 10544, 18516, 13593,  2291,  1037,  4595, 14548,
           4244,  1010,  2030,  2055,  1032,  1002,  4090,  1010,  2001,   102],
         [  101,  3514,  2379,  2015,  1032,  1002,  2753,  2006,  7387,  4425,
          15508,  2414,  1006, 26665,  1007,  1011,  2152,  1011,  3909,   102],
         [  101, 26408,  1010,  9980,  2136,  2039,  2006,  3036,  9980,  1001,
           4464,  1025,  1055, 14841,  6767,  3669,  3036, 12646,  3208,   102],
         [  101,  3604,  5930,  1024, 16396,  3436,  2250,  3604,  1005,  1055,
          16635,  4254,  3795, 12959,  2003,  8701,  3604, 14345,  4969,   102],
         [  101,  2813,  2395,  2275,  2000,  2330,  2091,  1006, 26665,  1007,
          26665,  1011,  1057,  1012,  1055,  1012,  6661,  2020,  3517,   102],
         [  101,  2057, 10532, 25293, 20330,  3822,  1011

In [14]:
# Print the tensor shapes of the token ids and the labels in the sampled batch.
for caption, key in (("Input IDs shape:", 'input_ids'), ("Labels shape:", 'label')):
    print(caption, batch[key].shape)

Input IDs shape: torch.Size([32, 20])
Labels shape: torch.Size([32])
