In [1]:
import torch
import spacy
import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torchtext import data, datasets

In [2]:
spacy_en = spacy.load("en")

In [3]:
TEXT = data.Field(tokenize="spacy", batch_first=True)
LABEL = data.LabelField()

In [4]:
train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=False)

In [5]:
len(train_data), len(test_data)

(5452, 500)

In [6]:
for d in train_data:
    print(vars(d)["text"])
    print(vars(d)["label"])
    break

['How', 'did', 'serfdom', 'develop', 'in', 'and', 'then', 'leave', 'Russia', '?']
DESC


In [7]:
TEXT.build_vocab(train_data, max_size=25000)
LABEL.build_vocab(train_data)

In [8]:
LABEL.vocab.stoi

defaultdict(<function torchtext.vocab._default_unk_index>,
            {'ABBR': 5, 'DESC': 2, 'ENTY': 0, 'HUM': 1, 'LOC': 4, 'NUM': 3})

In [9]:
class Net(nn.Module):
    def __init__(self, vocab_size, embedding_size, num_filters, filter_sizes, output_size, p, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.conv_layers = nn.ModuleList([nn.Conv1d(in_channels=embedding_size, out_channels=num_filters, 
                                                    kernel_size=fs) for fs in filter_sizes])
        
        self.fc = nn.Linear(num_filters * len(filter_sizes), output_size)

        self.dropout = nn.Dropout(p)

    def forward(self, text):
        embedded = self.embedding(text).permute(0, 2, 1)

        conved_n = [F.relu(conv(embedded)) for conv in self.conv_layers]
        pooled_n = [F.max_pool1d(conved, kernel_size=conved.shape[2]).squeeze(2) for conved in conved_n]

        pooled = self.dropout(torch.cat(pooled_n, dim=1))
        output = self.fc(pooled)

        return output

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10
batch_size = 64
vocab_size = len(TEXT.vocab)
embedding_size = 100
num_filters = 100
filter_sizes = [2, 3, 4]
output_size = len(LABEL.vocab)
p = 0.5
pad_idx = TEXT.vocab.stoi["<pad>"]

In [11]:
device

device(type='cpu')

In [12]:
train_batches, test_batches = data.BucketIterator.splits((train_data, test_data), batch_size=batch_size, device = device)

In [13]:
for batch in train_batches:
    print(batch.text.shape, batch.label.shape)
    break

torch.Size([64, 20]) torch.Size([64])


In [14]:
net = Net(vocab_size, embedding_size, num_filters, filter_sizes, output_size, p, pad_idx).to(device)
net

Net(
  (embedding): Embedding(9343, 100)
  (conv_layers): ModuleList(
    (0): Conv1d(100, 100, kernel_size=(2,), stride=(1,))
    (1): Conv1d(100, 100, kernel_size=(3,), stride=(1,))
    (2): Conv1d(100, 100, kernel_size=(4,), stride=(1,))
  )
  (fc): Linear(in_features=300, out_features=6, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [15]:
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [16]:
def get_accuracy(preds, y):
    preds = preds.argmax(dim=1, keepdim=True)
    correct = preds.squeeze(1).eq(y)
    acc = correct.sum() / torch.FloatTensor([y.shape[0]])

    return acc.item()

In [17]:
def loop(net, batches, train):
    batch_losses = []
    batch_accs = []

    if train:
        print("Train Loop:")
        net.train()
        for batch in tqdm.tqdm(batches, total=len(batches)):
            texts = batch.text.to(device)
            labels = batch.label.to(device)

            preds = net(texts)
            loss = loss_fn(preds, labels)
            acc = get_accuracy(preds, labels)

            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_losses.append(loss.item())
            batch_accs.append(acc)

    else:
        print("Inference Loop:")
        net.eval()
        with torch.no_grad():
            for batch in tqdm.tqdm(batches, total=len(batches)):
                texts = batch.text.to(device)
                labels = batch.label.to(device)

                preds = net(texts)
                loss = loss_fn(preds, labels)
                acc = get_accuracy(preds, labels)

                batch_losses.append(loss.item())
                batch_accs.append(acc) 

    print("")
    print("")
    
    return sum(batch_losses) / len(batch_losses), sum(batch_accs) / len(batch_accs)

In [18]:
def predict(net, text, min_len=4):
    net.eval()
    tokens = [t.text for t in spacy_en.tokenizer(text)]
    if len(tokens) < min_len:
        tokens += ["<pad>"] * (min_len - len(tokens))

    indices = [TEXT.vocab.stoi[t] for t in tokens]
    indices = torch.LongTensor(indices).unsqueeze(0).to(device)
    
    preds = net(indices)
    preds = preds.argmax(dim=1)
    
    print(LABEL.vocab.itos[preds.item()])

In [19]:
text = "what are you doing?"

In [26]:
for epoch in range(epochs):
    train_loss, train_acc = loop(net, train_batches, True)
    val_loss, val_acc = loop(net, test_batches, False)
    
    print(f"epoch: {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}")
    predict(net, text)
    print("")

  2%|▏         | 2/86 [00:00<00:05, 14.49it/s]

Train Loop:


100%|██████████| 86/86 [00:04<00:00, 21.04it/s]
100%|██████████| 8/8 [00:00<00:00, 164.23it/s]
  3%|▎         | 3/86 [00:00<00:03, 23.64it/s]



Inference Loop:


epoch: 0 | train_loss: 1.3760 | train_acc: 0.4542 | val_loss: 0.9114 | val_acc: 0.6883
DESC

Train Loop:


100%|██████████| 86/86 [00:03<00:00, 22.17it/s]
100%|██████████| 8/8 [00:00<00:00, 175.42it/s]
  2%|▏         | 2/86 [00:00<00:04, 19.96it/s]



Inference Loop:


epoch: 1 | train_loss: 0.8555 | train_acc: 0.6857 | val_loss: 0.6416 | val_acc: 0.7895
DESC

Train Loop:


100%|██████████| 86/86 [00:03<00:00, 22.35it/s]
100%|██████████| 8/8 [00:00<00:00, 177.88it/s]
  2%|▏         | 2/86 [00:00<00:05, 15.04it/s]



Inference Loop:


epoch: 2 | train_loss: 0.6569 | train_acc: 0.7552 | val_loss: 0.5609 | val_acc: 0.8137
DESC

Train Loop:


100%|██████████| 86/86 [00:03<00:00, 22.46it/s]
100%|██████████| 8/8 [00:00<00:00, 176.23it/s]
  3%|▎         | 3/86 [00:00<00:03, 22.33it/s]



Inference Loop:


epoch: 3 | train_loss: 0.5112 | train_acc: 0.8229 | val_loss: 0.4939 | val_acc: 0.8431
DESC

Train Loop:


100%|██████████| 86/86 [00:03<00:00, 22.25it/s]
100%|██████████| 8/8 [00:00<00:00, 185.69it/s]
  3%|▎         | 3/86 [00:00<00:03, 21.83it/s]



Inference Loop:


epoch: 4 | train_loss: 0.4305 | train_acc: 0.8565 | val_loss: 0.4778 | val_acc: 0.8334
DESC

Train Loop:


100%|██████████| 86/86 [00:03<00:00, 22.51it/s]
100%|██████████| 8/8 [00:00<00:00, 170.33it/s]
  2%|▏         | 2/86 [00:00<00:04, 17.11it/s]



Inference Loop:


epoch: 5 | train_loss: 0.3339 | train_acc: 0.8932 | val_loss: 0.4594 | val_acc: 0.8486
DESC

Train Loop:


100%|██████████| 86/86 [00:03<00:00, 22.70it/s]
100%|██████████| 8/8 [00:00<00:00, 172.94it/s]
  2%|▏         | 2/86 [00:00<00:04, 19.85it/s]



Inference Loop:


epoch: 6 | train_loss: 0.2694 | train_acc: 0.9145 | val_loss: 0.4487 | val_acc: 0.8486
DESC

Train Loop:


100%|██████████| 86/86 [00:03<00:00, 22.08it/s]
100%|██████████| 8/8 [00:00<00:00, 176.92it/s]
  2%|▏         | 2/86 [00:00<00:04, 19.63it/s]



Inference Loop:


epoch: 7 | train_loss: 0.2232 | train_acc: 0.9322 | val_loss: 0.4298 | val_acc: 0.8579
DESC

Train Loop:


100%|██████████| 86/86 [00:03<00:00, 21.68it/s]
100%|██████████| 8/8 [00:00<00:00, 175.85it/s]
  2%|▏         | 2/86 [00:00<00:04, 17.87it/s]



Inference Loop:


epoch: 8 | train_loss: 0.1757 | train_acc: 0.9442 | val_loss: 0.4583 | val_acc: 0.8520
DESC

Train Loop:


100%|██████████| 86/86 [00:03<00:00, 22.33it/s]
100%|██████████| 8/8 [00:00<00:00, 172.33it/s]



Inference Loop:


epoch: 9 | train_loss: 0.1493 | train_acc: 0.9547 | val_loss: 0.4930 | val_acc: 0.8394
DESC






In [29]:
LABEL.vocab.stoi

defaultdict(<function torchtext.vocab._default_unk_index>,
            {'ABBR': 5, 'DESC': 2, 'ENTY': 0, 'HUM': 1, 'LOC': 4, 'NUM': 3})

In [27]:
predict(net, "how many seconds are there in a minute?")

NUM


In [31]:
predict(net, "what is the full form of NASA?")

ABBR


In [33]:
predict(net, "what is your name?")

ENTY


In [35]:
predict(net, "where is the head quarters of CERN?")

LOC
