# LSTM POS tagger: char level
- A simple LSTM tagger that combines character-level LSTM features with word embeddings

In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# (e.g. `data`, `util`) are reloaded before each cell executes and code edits
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## load data

In [None]:
from data import load_penn_treebank_data

In [None]:
# Load the train and test splits. Each item is consumed below as a
# (sentence, tags) pair.
train_data, test_data = load_penn_treebank_data()

In [None]:
# Peek at the first training example: element 0, then element 1 of the pair.
first_example = train_data[0]
print('train_data[0][0]: {}'.format(first_example[0]))
print('train_data[0][1]: {}'.format(first_example[1]))

----

## Convert data to indices

In [None]:
from util import get_conversion_tables, prepare_sequence

In [None]:
# Build the character / word / tag lookup tables from the training data.
char_to_ix, word_to_ix, tag_to_ix = get_conversion_tables(train_data, min_count=1)

# Table sizes; these set the embedding and output layer sizes of the model.
n_chars, n_words, n_tags = len(char_to_ix), len(word_to_ix), len(tag_to_ix)

print('n_chars: {}'.format(n_chars))
print('n_words: {}'.format(n_words))
print('n_tags: {}'.format(n_tags))

----

# LSTM

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class LSTMTagger(nn.Module):
    """Sequence tagger combining a character-level LSTM with word embeddings.

    Each word is represented by the concatenation of
      * the last output of an LSTM run over the word's character embeddings
        (size ``char_lstm_dim``), and
      * a word embedding looked up by word index (size ``word_dim``).
    The resulting sequence is fed to a sentence-level LSTM whose outputs are
    projected to ``n_tags`` scores and normalised with log-softmax.

    Args:
        n_chars: Size of the character vocabulary.
        n_words: Size of the word vocabulary.
        n_tags: Number of output tags.
        char_dim: Dimension of the character embeddings.
        char_lstm_dim: Hidden size of the character-level LSTM.
        word_dim: Dimension of the word embeddings.
        hidden_dim: Hidden size of the sentence-level LSTM.
    """

    def __init__(self, n_chars, n_words, n_tags, char_dim, char_lstm_dim, word_dim, hidden_dim):
        super(LSTMTagger, self).__init__()

        self.n_chars = n_chars
        self.n_words = n_words
        self.n_tags = n_tags

        self.char_dim = char_dim
        self.char_lstm_dim = char_lstm_dim
        self.word_dim = word_dim

        self.char_embeddings = nn.Embedding(n_chars, char_dim)
        self.char_lstm = nn.LSTM(char_dim, char_lstm_dim)
        self.word_embeddings = nn.Embedding(n_words, word_dim)
        # The sentence LSTM consumes the char-LSTM *output* concatenated with
        # the word embedding, so its input size is char_lstm_dim + word_dim
        # (not char_dim + word_dim, which only works when the two are equal).
        self.lstm = nn.LSTM(char_lstm_dim + word_dim, hidden_dim)
        self.hidden2out = nn.Linear(hidden_dim, n_tags)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, words, sentence):
        """Compute per-token tag log-probabilities for one sentence.

        Args:
            words: Sequence with one tensor of character indices per word
                (each expected to be 1-D), in sentence order.
            sentence: Tensor of word indices, one per word.

        Returns:
            Tensor of shape ``(len(sentence), n_tags)`` holding
            log-probabilities over tags for every word.

        Raises:
            ValueError: If ``words`` and ``sentence`` differ in length.
        """
        if len(words) != len(sentence):
            raise ValueError(
                'words and sentence must have the same length, got {} and {}'.format(
                    len(words), len(sentence)))

        # Character level: run the char LSTM over each word (batch size 1)
        # and keep the output at the last character as the word's feature.
        char_features = []
        for word in words:
            char_embeds = self.char_embeddings(word)
            char_out, _ = self.char_lstm(char_embeds.view(len(word), 1, -1))
            char_features.append(char_out[-1, 0])
        char_features = torch.stack(char_features)

        # Word level: plain embedding table lookup.
        sent_embeds = self.word_embeddings(sentence)

        # Concatenate the char-derived and word-level features per word.
        embeds = torch.cat((char_features, sent_embeds), dim=1)

        # Sentence level: one sequence, batch size 1.
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space = self.hidden2out(lstm_out.view(len(sentence), -1))
        return self.softmax(tag_space)

----

# Train

In [None]:
# Hyperparameters passed to LSTMTagger and the training loop below.
CHAR_DIM = 25        # character embedding size
CHAR_LSTM_DIM = 25   # hidden size of the character-level LSTM
WORD_DIM = 100       # word embedding size
HIDDEN_DIM = 25      # hidden size of the sentence-level LSTM
EPOCH_NUM = 10       # number of passes over train_data

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the tagger and place its parameters on the chosen device.
model = LSTMTagger(
    n_chars,
    n_words,
    n_tags,
    CHAR_DIM,
    CHAR_LSTM_DIM,
    WORD_DIM,
    HIDDEN_DIM,
)
model = model.to(device)

# The model outputs log-probabilities, hence the negative log-likelihood loss.
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

In [None]:
%%time
print('using {}'.format(device))

for epoch in range(EPOCH_NUM):
    running_loss = 0
    for i, (sentence, tags) in enumerate(train_data):
        model.zero_grad()
        
        ## char_to_ix
        words_in = [prepare_sequence(w, char_to_ix) for w in sentence]
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)
        sentence_in, targets = sentence_in.to(device), targets.to(device)
        
        outputs = model(words_in, sentence_in)
        
        loss = loss_function(outputs, targets)

        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0            
        
print('finished')

---

# Evaluate

In [None]:
# Token-level accuracy on the test set.
correct = 0
total = 0

# Switch to inference mode for layers that behave differently in training.
model.eval()

with torch.no_grad():
    for sentence, tags in test_data:
        # Inputs and targets must live on the same device as the model;
        # without this the forward pass and the comparison fail on CUDA.
        words_in = [prepare_sequence(w, char_to_ix).to(device) for w in sentence]
        sentence_in = prepare_sequence(sentence, word_to_ix).to(device)
        targets = prepare_sequence(tags, tag_to_ix).to(device)

        outputs = model(words_in, sentence_in)
        # Predicted tag = index of the highest log-probability per token.
        _, predicted = torch.max(outputs, 1)

        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# Guard against an empty test set to avoid a ZeroDivisionError.
if total:
    print('Accuracy: {:.2f} %'.format(100 * correct / total))
else:
    print('Accuracy: n/a (no test tokens)')