In [1]:
import torch
from torchtext.datasets import AG_NEWS

In [2]:
train_iter = AG_NEWS(root='../../dat/pyt', split='train')

In [3]:
test_iter = AG_NEWS(root='../../dat/pyt', split='test')

In [4]:
next(train_iter)

(3,
 "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")

In [5]:
next(train_iter)

(3,
 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.')

In [6]:
next(train_iter)

(3,
 "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.")

In [7]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [8]:
tokenizer = get_tokenizer('basic_english')

In [9]:
def yield_tokens(data_iter):
  for _, text in data_iter:
    yield tokenizer(text)

In [10]:
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

In [11]:
len(vocab)

95808

In [12]:
vocab(['here', 'is', 'an', 'example'])

[475, 21, 30, 5297]

In [13]:
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

In [15]:
text_pipeline('here is an example')

[475, 21, 30, 5297]

In [16]:
label_pipeline('9')

8

In [17]:
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [28]:
def collate_batch(batch):
  ys, xs, offsets = [], [], [0]
  for (y, x) in batch:
    ys.append(label_pipeline(y))
    ts = torch.tensor(text_pipeline(x), dtype=torch.int64)
    xs.append(ts)
    offsets.append(ts.size(0))
  ys = torch.tensor(ys, dtype=torch.int64)
  offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
  xs = torch.cat(xs)
  return ys.to(device), xs.to(device), offsets.to(device)

In [29]:
train_iter = AG_NEWS(root='../../dat/pyt', split='train')
training_data_loader = DataLoader(train_iter, batch_size=32, shuffle=False, collate_fn=collate_batch)

In [19]:
len(training_data_loader)

3750

In [20]:
len(training_data_loader.dataset)

120000

In [24]:
from torch import nn

class NeuralNetwork(nn.Module):
  def __init__(self, vocab_size, embed_dim, num_class):
    super(NeuralNetwork, self).__init__()
    self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
    self.linear = nn.Linear(embed_dim, num_class)
    self.init_weights()

  def init_weights(self):
    self.embedding.weight.data.uniform_(-0.5, 0.5)
    self.linear.weight.data.uniform_(-0.5, 0.5)
    self.linear.bias.data.zero_()
  
  def forward(self, text, offsets):
    embeded = self.embedding(text, offsets) # compute mean vectors of all words in the text
    return self.linear(embeded)

In [30]:
num_class = len(set([label for (label, _) in train_iter]))

In [31]:
# 1 : World
# 2 : Sports
# 3 : Business
# 4 : Sci/Tec
print(num_class)

4


In [32]:
train_iter = AG_NEWS(root='../../dat/pyt', split='train')
next(train_iter)

(3,
 "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")

In [34]:
vocab_size = len(vocab)
print(vocab_size)

95808


In [35]:
model = NeuralNetwork(vocab_size, 64, num_class).to(device)

In [36]:
model

NeuralNetwork(
  (embedding): EmbeddingBag(95808, 64, mode=mean)
  (linear): Linear(in_features=64, out_features=4, bias=True)
)

In [37]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [40]:
criterion = nn.CrossEntropyLoss() # useful when the output is an unnormalized score vector.

In [39]:
import time

In [41]:
def train(dataloader, model, criterion, optimizer):
  N = len(dataloader.dataset)
  model.train()
  start_time = time.time()
  for batch, (ys, xs, offsets) in enumerate(dataloader):
    zs = model(xs, offsets)
    loss = criterion(zs, ys)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if batch % 500 == 0:
      elapsed = time.time() - start_time
      loss, current = loss.item(), batch * len(xs)
      print(f"loss: {loss:>7f} [{current:>5d}/{N:>5d}], elapsed: {elapsed}")
      start_time = time.time()

In [42]:
def test(dataloader, model, criterion):
  N = len(dataloader.dataset)
  num_batch = len(dataloader)
  model.eval()
  test_loss, correct = 0, 0
  with torch.no_grad():
    for (ys, xs, offsets) in dataloader:
      zs = model(xs, offsets)
      test_loss += criterion(zs, ys).item()
      correct += (zs.argmax(1) == ys).type(torch.float).sum().item()
  test_loss /= num_batch
  correct /= N
  print(f"Test accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [43]:
epochs = 20 # 5 minutes on my MBP
for t in range(epochs):
  train_iter = AG_NEWS(root='../../dat/pyt', split='train')
  test_iter = AG_NEWS(root='../../dat/pyt', split='test')
  training_data_loader = DataLoader(train_iter, batch_size=32, shuffle=False, collate_fn=collate_batch)
  test_data_loader = DataLoader(test_iter, batch_size=32, shuffle=False, collate_fn=collate_batch)  
  print(f"Epoch {t+1}\n-------")
  train(training_data_loader, model, criterion, optimizer)
  test(test_data_loader, model, criterion)
print("Done.")

Epoch 1
-------
loss: 1.426393 [    0/120000], elapsed: 0.007916688919067383
loss: 1.370969 [641000/120000], elapsed: 2.0563488006591797
loss: 1.347753 [1315000/120000], elapsed: 1.908263921737671
loss: 1.324668 [2068500/120000], elapsed: 1.9025630950927734
loss: 1.374274 [2690000/120000], elapsed: 1.9186592102050781
loss: 1.381827 [3127500/120000], elapsed: 1.8854007720947266
loss: 1.351782 [4032000/120000], elapsed: 1.878354787826538
loss: 1.333173 [4788000/120000], elapsed: 1.878736972808838
Test accuracy: 30.2%, Avg loss: 1.371654 

Epoch 2
-------
loss: 1.388052 [    0/120000], elapsed: 0.004431247711181641
loss: 1.359332 [641000/120000], elapsed: 1.9160959720611572
loss: 1.324045 [1315000/120000], elapsed: 1.9002797603607178
loss: 1.305135 [2068500/120000], elapsed: 1.9554660320281982
loss: 1.360818 [2690000/120000], elapsed: 1.9280998706817627
loss: 1.357424 [3127500/120000], elapsed: 1.9295012950897217
loss: 1.333590 [4032000/120000], elapsed: 1.9556457996368408
loss: 1.316975 