In [1]:
import random
import torch
import torch.utils.data
from torch import nn
import numpy as np
from IPython.core.debugger import set_trace
# set device
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [2]:
import csv
jp_sentences = []
en_sentences = []
with open('data/kyoto_lexicon.csv', 'r', encoding='utf-8') as file:
    reader = csv.reader(file, delimiter=',')
    # skip the header row
    startLooking = False
    for row in reader:
        if startLooking:
            jp_sentences.append(row[0])
            en_sentences.append(row[1])
        startLooking = True
print(jp_sentences[:5])
print(en_sentences[:5])
print(len(jp_sentences))
print(len(en_sentences))

['102世吉田日厚貫首', '1月15日：成人祭、新年祭', '1月3日：家運隆盛、商売繁盛祈願祭', '1月7日：七種粥神事', '21世紀COEプログラム']
['the 102nd head priest, Nikko TOSHIDA', '15th January: Seijin-sai (Adult Festival), the New Year Festival', '3rd January: Prayer Festival for the prosperity of family fortunes and business', '7th January: Nanakusa-gayu shinji (a divine service for a rice porridge with seven spring herbs to insure health for the new year)', 'The 21st Century Center Of Excellence Program']
51982
51982


# character-by-character prediction

In [3]:
# encoding and decoding characters
class CharacterTable:
    def __init__(self, charset):
        self.charset = charset
        self.charset = frozenset(self.charset)
        self.charlist = ['<null>', '<sos>'] + list(self.charset)
        # it is important that null is at index 0 since padding fills with zeroes
        self.vocab_size = len(self.charlist)
    def encode(self, char):
        '''convert from character to index
        can process (nested) list of characters'''
        if type(char) is type('asdf'):
            # char is a string
            return self.charlist.index(char)
        else:
            # char is a list of strings
            return [self.encode(char) for char in char]
    def decode(self, charInd):
        '''convert from index to character
        can process (nested) list of indices'''
        if type(charInd) is type(22):
            # charInd is an int
            return self.charlist[charInd]
        else:
            # charInd is a list of ints
            return [self.decode(charInd) for charInd in charInd]
jp_chartable = CharacterTable(set(''.join(jp_sentences)))
en_chartable = CharacterTable(set(''.join(en_sentences)))
print(en_chartable.encode([['a', 'b'], ['c', 'd']]))
print(jp_chartable.decode(1234))
print(jp_chartable.vocab_size, en_chartable.vocab_size)

[[102, 78], [32, 47]]
値
3911 173


In [4]:
# sequence prediction model
class Predictor(nn.Module):
    def __init__(self, table, embedding_dimensions=64, hidden_size=100):
        super(Predictor, self).__init__()
        # model constants
        self.embedding_dimensions = embedding_dimensions
        self.hidden_size = hidden_size
        self.table = table
        self.vocab_size = self.table.vocab_size
        # model layers
        self.embedding = nn.Embedding(self.vocab_size, embedding_dimensions)
        self.RNN = nn.LSTM(
            input_size=self.embedding_dimensions,
            hidden_size=self.hidden_size, 
            batch_first=True
        )
        # linear layer for converting from hidden state to softmax
        self.linear = nn.Sequential(
            nn.Linear(self.hidden_size, self.vocab_size),
            nn.LogSoftmax(dim=-1)
        )
    
    
    def forward(self, padded_seq, lengths):
        '''
        predicts sequence of characters at every step
        seq (batch, seq) padded tensor of character indices
        returns (batch, seq, vocab) softmaxes
        implicit teacher forcing by torch RNN
        '''
        seq_len = padded_seq.shape[1]
        padded_seq_embed = self.embedding(padded_seq) # (batch, seq, embed)
        packed_seq_embed = torch.nn.utils.rnn.pack_padded_sequence(padded_seq_embed, lengths, batch_first=True)
        packed_hidden_states, (h_final, cell_final) = self.RNN(packed_seq_embed)
        padded_hidden_states, input_sizes = torch.nn.utils.rnn.pad_packed_sequence(packed_hidden_states, batch_first=True, total_length=seq_len)
        # hidden_states (batch, seq, hidden) hidden states
        y_hat = self.linear(padded_hidden_states)
        # y_hat (batch, seq, vocab) softmaxes
        return y_hat
    
    
    def predict(self, padded_seq, lengths):
        pred = self.forward(padded_seq, lengths)
        # (batch, seq, vocab)
        maxInds = pred.max(2)[1]
        # (batch, seq)
        return pred, maxInds

In [5]:
# load data
def padded_train_test(sentences, table, train_test_split=.2, batch_size=500, word=False):
    '''
    small train_test_split means mostly train data
    ['hello world', ...] or [['hello', 'world',...],...], table, train_test_split -> (train data, test data) padded tensor dataloaders
    small train_test_split means mostly train data
    output "shapes" (train_size, maxlen), (test_size, maxlen) with given batch size
    '''
    
    def encode_sequence(sentences, ans):
        # ans is whether this is the input sequence or the true sequence
        # ans=True means don't append an <sos>
        # ans=False means append an <sos> and remove the last
        sentence_indices = []
        for sentence in sentences:
            if word:
                encoded = table.encode(sentence)
            else:
                encoded = table.encode(list(sentence))
            if not ans:
                # add sos and remove last
                # this is an input sequence
                encoded = [table.encode('<sos>')] + encoded[:-1]
            sentence_indices.append(encoded)
        return sentence_indices
    def pad_sequence(sentences, ans):
        # ans is whether this is the input sequence or the true sequence
        # ans=True means don't append an <sos>
        # ans=False means append an <sos> and remove the last
        '''
        ['hello world', ...] or [['hello', 'world',...],...] -> (padded long tensor, lengths tensor)
        tensors are padded and sorted 
        '''
        sentence_indices = encode_sequence(sentences, ans)
        # list of list of indices
        lengths = torch.LongTensor([len(sentence) for sentence in sentence_indices])
        sentence_tensors = [torch.LongTensor(sentence).to(device) for sentence in  sentence_indices]
        padded = torch.nn.utils.rnn.pad_sequence(sentence_tensors, batch_first=True)
        lengths, perm_idx = lengths.sort(0, descending=True)
        # perm_idx is the permutation of sentence indices as sorted by length
        padded = padded[perm_idx]
        return padded, lengths
    
    length = len(sentences)
    # the index to separate train from test
    split = int(length * train_test_split)
    
    # shuffle before splitting so test doesn't just get the alphabetically sooner sentences
    sentences = random.sample(sentences, length)
    
    train_sentences = sentences[split:]
    test_sentences = sentences[:split]
    
    # the input sequences (with sos and removed last)
    padded_train_in = pad_sequence(train_sentences, False)
    padded_test_in = pad_sequence(test_sentences, False)
    # the output sequences (with no sos)
    padded_train_true = pad_sequence(train_sentences, True)
    padded_test_true = pad_sequence(test_sentences, True)
    
    padded_trainset = torch.utils.data.TensorDataset(*padded_train_in, *padded_train_true)
    padded_testset = torch.utils.data.TensorDataset(*padded_test_in, *padded_test_true)
    
    padded_trainloader = torch.utils.data.DataLoader(padded_trainset, batch_size=batch_size, shuffle=False, num_workers=0)
    padded_testloader = torch.utils.data.DataLoader(padded_testset, batch_size=batch_size, shuffle=False, num_workers=0)
    # shuffle must be false to maintain sorting by length
    
    return padded_trainloader, padded_testloader
padded_en_trainloader, padded_en_testloader = padded_train_test(en_sentences, en_chartable)
padded_jp_trainloader, padded_jp_testloader = padded_train_test(jp_sentences, jp_chartable)

In [6]:
def train_model(trainloader, table, lr=.1, epochs=200):
    model = Predictor(table).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.NLLLoss()
    losses = []
    for epoch in range(epochs):
        total_loss = 0
        num_losses = 0
        for index, data in enumerate(trainloader, 0):
            model.zero_grad()
            padded_seq_in, lengths_in, padded_seq_true, lengths_true = data
            pred = model(padded_seq_in, lengths_in)
            
            batch_size = padded_seq_true.shape[0]
            maxlen = padded_seq_true.shape[1]
            vocab_size = pred.shape[-1]
            padded_seq_true_flat = padded_seq_true.view(batch_size*maxlen)
            pred_flat = pred.contiguous().view(batch_size*maxlen, vocab_size)
            
            loss = loss_fn(pred_flat, padded_seq_true_flat)
            loss.backward()
            optimizer.step()
            total_loss += loss.data
            num_losses += 1
        avg_loss = total_loss / num_losses
        losses.append(avg_loss)
        if (epoch + 1) % (epochs // 10) == 0:
            print('loss at epoch {}: {}'.format(epoch+1, avg_loss))
    print('final loss after {} epochs: {}'.format(epochs, losses[-1]))
    return model, losses

In [7]:
# save and load model
def get_state_path(name):
    return 'states/{}.pt'.format(name)
def save_model(model, name):
    torch.save(model, get_state_path(name))
def load_model(model, name):
    '''loads state dict into given model and returns it'''
    model = torch.load(get_state_path(name))
    return model

In [8]:
def initialize_models(should_train=True):
    global jp_model, en_model
    jp_model = Predictor(jp_chartable).to(device)
    jp_losses = None
    en_model = Predictor(en_chartable).to(device)
    en_losses = None
    if should_train:
        print('jp training')
        jp_model, jp_losses = train_model(padded_jp_trainloader, jp_chartable)
        save_model(jp_model, 'jp_char_model')
        print('en training')
        en_model, en_losses = train_model(padded_en_trainloader, en_chartable)
        save_model(en_model, 'en_char_model')
    else:
        jp_model = load_model(jp_model, 'jp_char_model')
        en_model = load_model(en_model, 'en_char_model')
    return jp_model, en_model
initialize_models(True)

jp training
loss at epoch 20: 0.814089834690094
loss at epoch 40: 0.7519795894622803
loss at epoch 60: 0.7228744029998779
loss at epoch 80: 0.709435760974884
loss at epoch 100: 0.7016097903251648
loss at epoch 120: 0.6961705684661865
loss at epoch 140: 0.691956102848053
loss at epoch 160: 0.6884689927101135
loss at epoch 180: 0.6854720115661621
loss at epoch 200: 0.682823121547699
final loss after 200 epochs: 0.682823121547699
en training


  "type " + obj.__name__ + ". It won't be checked "


loss at epoch 20: 0.25185489654541016
loss at epoch 40: 0.23443341255187988
loss at epoch 60: 0.2246248722076416
loss at epoch 80: 0.21615473926067352
loss at epoch 100: 0.20942629873752594
loss at epoch 120: 0.2041759341955185
loss at epoch 140: 0.19995242357254028
loss at epoch 160: 0.19644670188426971
loss at epoch 180: 0.19343675673007965
loss at epoch 200: 0.1907539814710617
final loss after 200 epochs: 0.1907539814710617


(Predictor(
   (embedding): Embedding(3911, 64)
   (RNN): LSTM(64, 100, batch_first=True)
   (linear): Sequential(
     (0): Linear(in_features=100, out_features=3911, bias=True)
     (1): LogSoftmax()
   )
 ), Predictor(
   (embedding): Embedding(173, 64)
   (RNN): LSTM(64, 100, batch_first=True)
   (linear): Sequential(
     (0): Linear(in_features=100, out_features=173, bias=True)
     (1): LogSoftmax()
   )
 ))

In [9]:
# metrics
def perplexity_metric(pred, actual):
    '''
    pred (batch, seq, vocab) logsoftmax
    actual (batch, seq) longs
    geometric mean of product of p(next word | previous words) for whole sentence
    average (arithmetic mean) by batch
    '''
    batch_size, seq_len, vocab_size = pred.shape
    pred = pred.cpu()
    pred = torch.exp(pred)
    geo_means = [] # probabilities of correct characters
    for i in range(batch_size):
        product = 1
        num_factors = 0
        curr_pred = pred[i]
        curr_actual = actual[i]
        for t in range(seq_len):
            trueInd = curr_actual[t].item()
            # the character index at this timestep
            if trueInd != 0:
                # we don't care how well it predicts nulls
                predSoftmax = curr_pred[t]
                confidence = predSoftmax[trueInd].item()
                product *= confidence
                num_factors += 1
        geo_means.append(product ** (1/num_factors))
    return sum(geo_means) / len(geo_means)
def print_metrics(model, name, testloader, word=False):
    loss_fn = nn.NLLLoss()
    losses = []
    sentence_accuracies = []
    character_accuracies = []
    perplexities = []
    for index, data in enumerate(testloader, 0):
        padded_seq_in, lengths_in, padded_seq, lengths = data
        pred, maxInds = model.predict(padded_seq, lengths)
        
        perplexity = perplexity_metric(pred, padded_seq)
        
        batch_size = padded_seq.shape[0]
        maxlen = padded_seq.shape[1]
        vocab_size = pred.shape[-1]
        
        padded_seq_flat = padded_seq.view(batch_size*maxlen)
        pred_flat = pred.contiguous().view(batch_size*maxlen, vocab_size)
        loss = loss_fn(pred_flat, padded_seq_flat).item()
        
        correct_characters = torch.sum(maxInds == padded_seq).item()
        total_characters = batch_size*maxlen
        correct_sentences = 0
        total_sentences = batch_size
        
        for i in range(batch_size):
            if torch.all(maxInds[i] == padded_seq[i]):
                correct_sentences += 1
        sentence_accuracy = correct_sentences / total_sentences
        character_accuracy = correct_characters / total_characters
        
        losses.append(loss)
        sentence_accuracies.append(sentence_accuracy)
        character_accuracies.append(character_accuracy)
        perplexities.append(perplexity)
    loss_avg = sum(losses) / len(losses)
    sentence_accuracy_avg = sum(sentence_accuracies) / len(sentence_accuracies)
    character_accuracy_avg = sum(character_accuracies) / len(character_accuracies)
    perplexity_avg = sum(perplexities) / len(perplexities)
    if word:
        print('model: {}\n\tvalidation loss: {}\n\tsentence accuracy: {}\n\tword accuracy: {}\n\tperplexity: {}'.format(name, loss_avg, sentence_accuracy_avg, character_accuracy_avg, perplexity_avg))
    else:
        print('model: {}\n\tvalidation loss: {}\n\tsentence accuracy: {}\n\tcharacter accuracy: {}\n\tperplexity: {}'.format(name, loss_avg, sentence_accuracy_avg, character_accuracy_avg, perplexity_avg))

# english word-to-word
since the japanese model had to learn a mixture of character prediction and word prediction at the same time, let's see how the english model predicts words, and compare it to the japanese character predictor

In [10]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\mthun\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [11]:
# tokenize sentences
tokenized_sentences = []
for sentence in en_sentences:
    tokenized = nltk.word_tokenize(sentence)
    if len(tokenized) > 0:
        tokenized_sentences.append(tokenized)
print(tokenized_sentences[0])

['the', '102nd', 'head', 'priest', ',', 'Nikko', 'TOSHIDA']


In [12]:
wordlist = []
for sentence in tokenized_sentences:
    for word in sentence:
        wordlist.append(word)
wordset = set(wordlist)
len(wordset)

43216

### that's way too many words!
let's limit the vocab size to 4000 to make the complexity theoretically similar to the japanese model

In [13]:
max_vocab_size = 4000
# word -> frequency
counts = {}
for word in wordlist:
    if word in counts:
        counts[word] += 1
    else:
        counts[word] = 1
sorted_wordset = sorted(list(wordset), key=lambda word: counts[word], reverse=True)
for word in sorted_wordset[:10]:
    print(word, counts[word], sep='\t')
vocab = set([])
for word in sorted_wordset:
    if len(vocab) < max_vocab_size:
        vocab.add(word)
len(vocab)

of	6995
(	6793
)	6769
the	5777
,	3457
no	2899
a	2872
Temple	1617
and	1278
in	1175


4000

In [14]:
# word encoding and decoding
class WordTable:
    def __init__(self, wordset):
        self.wordset = frozenset(wordset)
        self.wordlist = ['<null>', '<sos>', '<unk>'] + list(wordset)
        self.vocab_size = len(self.wordlist)
        
        
    def encode(self, word):
        '''
        expects word string or possibly nested list of word strings
        unks out-of-vocab words
        word(s) -> indices
        '''
        if type(word) == type('asdf'):
            if word in self.wordlist:
                return self.wordlist.index(word)
            else:
                # encode out-of-vocab words with unk
                return self.wordlist.index('<unk>')
        else:
            words = word
            return [self.encode(word) for word in words]
        
        
    def decode(self, wordInd):
        '''
        expects wordInd index or possibly nested list of word indices
        '''
        if type(wordInd) == type(123):
            return self.wordlist[wordInd]
        else:
            wordInds = wordInd
            return [self.decode(wordInd) for wordInd in wordInds]
wordtable = WordTable(vocab)
print(wordtable.decode(200))
print(wordtable.decode(wordtable.encode('why relu works')))
print(wordtable.vocab_size)

Junction
<unk>
4003


In [15]:
# load data
padded_word_trainloader, padded_word_testloader = padded_train_test(tokenized_sentences, wordtable, word=True)

In [16]:
print('training english word model')
word_model, word_model_losses = train_model(padded_word_trainloader, wordtable)

training english word model
loss at epoch 20: 0.327519029378891
loss at epoch 40: 0.2957797348499298
loss at epoch 60: 0.28763270378112793
loss at epoch 80: 0.2824614942073822
loss at epoch 100: 0.27881553769111633
loss at epoch 120: 0.2760738134384155
loss at epoch 140: 0.27389097213745117
loss at epoch 160: 0.27206066250801086
loss at epoch 180: 0.2704685628414154
loss at epoch 200: 0.2690542936325073
final loss after 200 epochs: 0.2690542936325073


In [17]:
print_metrics(jp_model, 'jp character predictor', padded_jp_testloader)
print_metrics(en_model, 'en character predictor', padded_en_testloader)
print_metrics(word_model, 'english word to word', padded_word_testloader, word=True)

model: jp character predictor
	validation loss: 1.0359487959316798
	sentence accuracy: 0.0
	character accuracy: 0.8824169164169166
	perplexity: 0.000256158712471985
model: en character predictor
	validation loss: 0.38650648295879364
	sentence accuracy: 0.0
	character accuracy: 0.9165669905904281
	perplexity: 0.00980208015188648
model: english word to word
	validation loss: 0.5625579747415724
	sentence accuracy: 0.0
	word accuracy: 0.9287701863354039
	perplexity: 0.001651558462524638


In [18]:
def predict_sentence(model, table, loader):
    seq_in, len_in, seq_true, len_true = iter(loader).next()
    pred, maxInds = model.predict(seq_in, len_in)
    trues = table.decode(seq_true[200:201].cpu().numpy().tolist())
    preds = table.decode(maxInds[200:201].cpu().detach().numpy().tolist())
    zipped = list(zip(trues, preds))
    for t, p in zipped:
        print(t)
        print(p)
predict_sentence(word_model, wordtable, padded_word_testloader)

['Six', '<unk>', 'of', 'Takashima', ',', '<unk>', ',', '<unk>', ',', '<unk>', ',', '<unk>', 'and', 'Aichi', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>']
['<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>', '<null>']
