<a href="https://colab.research.google.com/github/teymour-aldridge/NN/blob/master/text/pos_tagger/GRU_tagger.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab filesystem; the model weights are later
# saved to a path under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
import torch
from torch import nn
import torch.nn.functional as F
import nltk
from nltk.corpus import brown

In [0]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression so the notebook displays which device was selected.
device

# Lookup table

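The `Lang` class is essentially a lookup table that maps each distinct word to an integer index (and each index back to its word).

The table is built by going through every word in the corpus: the first time a word is seen it is assigned the next unused index, and every later occurrence only increases that word's frequency count.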

In [0]:
class Lang:
  """Vocabulary that numbers tokens in order of first appearance.

  Attributes:
    word2index: token -> integer index (0-based, assigned on first sight).
    word2count: token -> number of times the token has been added.
    index2word: integer index -> token (inverse of ``word2index``).
    n_words: number of distinct tokens registered so far.
  """

  def __init__(self):
    # Forward mapping: token -> index
    self.word2index = {}
    # Occurrence counter: token -> frequency
    self.word2count = {}
    # Reverse mapping: index -> token
    self.index2word = {}
    # Size of the vocabulary; also the next index to hand out
    self.n_words = 0

  def add_word(self, word):
    """Register one occurrence of ``word``, indexing it if it is new."""
    # Known token: only its frequency changes
    if word in self.word2index:
      self.word2count[word] += 1
      return
    # New token: give it the next free index and start its count at 1
    next_index = self.n_words
    self.word2index[word] = next_index
    self.index2word[next_index] = word
    self.word2count[word] = 1
    self.n_words = next_index + 1

  def add_sentence(self, sentence):
    """Register every token of ``sentence`` (an iterable of tokens) in order."""
    for token in sentence:
      self.add_word(token)

In [5]:
# Download the Brown corpus via NLTK (the recorded output shows the package
# is reported as up to date when it is already present)
nltk.download('brown')
# Load the tagged sentences; downstream code reads each sentence as a
# sequence of (word, tag) pairs
brown_dataset = brown.tagged_sents()

[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Package brown is already up-to-date!


In [0]:
from torch.utils.data import Dataset
# Characters stripped from both ends of every token in POSDataset.from_corpus;
# tokens that are empty after stripping are dropped together with their tag.
# NOTE(review): the digit run stops at 7 (8 and 9 are not stripped) and "-"
# appears twice -- confirm whether this is intentional.
REMOVE_TOKENS = "',:!01234567.-_+=-?" + '"'
class POSDataset(Dataset):
  """Map-style dataset of sentences and their part-of-speech tags.

  Each item is a pair of 1-D ``torch.long`` tensors: the vocabulary indices
  of the sentence's tokens and the indices of the corresponding tags. Both
  tensors are created on the notebook-level ``device``.

  Attributes:
    sentences: list of token lists.
    tagged_sentences: list of tag lists, aligned with ``sentences``.
    sents: vocabulary object exposing ``word2index`` for tokens.
    pos_tags: vocabulary object exposing ``word2index`` for tags.
  """

  def __init__(self, sentences, tagged_sentences, sentence_language, tagged_language):
    self.sentences = sentences
    self.tagged_sentences = tagged_sentences
    self.sents = sentence_language
    self.pos_tags = tagged_language

  def __len__(self):
    """Return the number of sentences, so len() and samplers can be used."""
    return len(self.sentences)

  def __getitem__(self, i):
    """Return ``(token_indices, tag_indices)`` for sentence ``i``.

    Raises:
      IndexError: if ``i`` is out of range (this also ends plain iteration).
      KeyError: if a token or tag is missing from its vocabulary.
    """
    words, tags = self.sentences[i], self.tagged_sentences[i]
    word_ids = [self.sents.word2index[word] for word in words]
    tag_ids = [self.pos_tags.word2index[tag] for tag in tags]
    return (
      torch.tensor(word_ids, dtype=torch.long, device=device),
      torch.tensor(tag_ids, dtype=torch.long, device=device),
    )

  @classmethod
  def from_corpus(cls, corpus):
    """Build a dataset and both vocabularies from a tagged corpus.

    Args:
      corpus: iterable of sentences, each an iterable of pairs whose first
        element is the word and whose second element is its tag.

    Returns:
      An instance of ``cls`` (so subclasses get an instance of themselves).

    Every word has the characters in ``REMOVE_TOKENS`` stripped from both
    ends; words that end up empty are dropped together with their tag.
    Sentences that become empty are kept as empty lists.
    """
    sentences = []
    tagged_sentences = []
    for tagged_sentence in corpus:
      words, tags = [], []
      # Single pass: strip each word once and keep word/tag aligned
      for pair in tagged_sentence:
        word = pair[0].strip(REMOVE_TOKENS)
        if word:
          words.append(word)
          tags.append(pair[1])
      sentences.append(words)
      tagged_sentences.append(tags)
    sents, pos_tags = Lang(), Lang()
    for words in sentences:
      sents.add_sentence(words)
    for tags in tagged_sentences:
      pos_tags.add_sentence(tags)
    return cls(sentences, tagged_sentences, sents, pos_tags)

In [0]:
# Build the index-encoded dataset and both vocabularies from the raw tagged corpus.
# NOTE: this rebinds `brown_dataset` from the raw corpus to a POSDataset, so
# re-running this cell on its own passes a POSDataset (not the corpus) to
# from_corpus; re-run the corpus-loading cell first.
brown_dataset = POSDataset.from_corpus(brown_dataset)

In [0]:
class POSTagger(nn.Module):
  """GRU-based part-of-speech tagger that scores one sentence at a time.

  Pipeline: token indices -> embedding -> GRU over the sentence -> linear
  projection to tag scores -> log-softmax over tags.

  Args:
    embedding_dim: size of each token embedding.
    hidden_dim: size of the GRU hidden state.
    vocab_size: number of distinct tokens (rows of the embedding table).
    tagset_size: number of distinct tags (output classes).
    n_layers: number of stacked GRU layers.
    device: optional device to move the module to on construction. When
      ``None`` (default) the module stays where PyTorch created it and the
      caller may move it with ``.to(...)``.
  """

  def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, n_layers, device=None):
    super(POSTagger, self).__init__()
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    # batch_first is left at its default (False): input is (seq_len, batch, features)
    self.gru = nn.GRU(embedding_dim, hidden_dim, n_layers)
    self.word2tag = nn.Linear(hidden_dim, tagset_size)
    if device is not None:
      self.to(device)

  def forward(self, x):
    """Return log-probabilities over tags for every token of one sentence.

    Args:
      x: 1-D LongTensor of token indices, shape ``(seq_len,)``.

    Returns:
      Tensor of shape ``(seq_len, tagset_size)`` whose rows are
      log-probabilities (suitable for ``nn.NLLLoss``).
    """
    # Tensor.to is not in-place: keep the result, and follow the module's
    # own parameters so the input always lands on the model's device.
    x = x.to(self.embedding.weight.device)
    seq_len = x.size(0)
    embedded = self.embedding(x)  # (seq_len, embedding_dim)
    # Insert the batch axis at position 1 -> (seq_len, 1, embedding_dim), so
    # the GRU steps through the words of the sentence. Putting it at position
    # 0 would present each word as a separate length-1 sequence with no context.
    gru_out, _ = self.gru(embedded.unsqueeze(1))  # (seq_len, 1, hidden_dim)
    tag_scores = self.word2tag(gru_out.view(seq_len, -1))  # (seq_len, tagset_size)
    return F.log_softmax(tag_scores, dim=1)

In [0]:
import os
import time
# Where the weights are checkpointed after every epoch (on the mounted Drive)
model_save_path = '/content/gdrive/My Drive/Computing/ML/Text/POS tagger/pos_tagger_weights.pt'

# Hyperparameters
N_EPOCHS = 100
EMBEDDING_DIM = 256
HIDDEN_DIM = 12
LEARNING_RATE = 0.001

# Create the checkpoint directory up front so a missing folder is caught
# before training starts rather than at the first save.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)

vocab_size = len(brown_dataset.sents.word2index)
tagset_size = len(brown_dataset.pos_tags.word2index)
model = POSTagger(EMBEDDING_DIM, HIDDEN_DIM, vocab_size, tagset_size, 1).to(device)
loss_fn = nn.NLLLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

for epoch in range(N_EPOCHS):
  # Time and loss are tracked per epoch, not per sentence
  epoch_start = time.time()
  total_loss = 0.0
  n_samples = 0
  for sentence, tags in brown_dataset:
    # Skip sentences that became empty after token stripping
    if len(sentence) == 0 or len(tags) == 0:
      continue
    optimizer.zero_grad()

    pred = model(sentence)

    loss = loss_fn(pred, tags)

    loss.backward()
    optimizer.step()

    # Accumulate a detached copy so the autograd graph is not kept alive
    total_loss = total_loss + loss.detach()
    n_samples += 1
  # Mean loss over the sentences actually trained on; NaN if all were skipped
  mean_loss = float(total_loss) / n_samples if n_samples else float("nan")
  print("Loss: {}, Time: {}".format(mean_loss, time.time() - epoch_start))
  torch.save(model.state_dict(), model_save_path)