In [47]:
import torch
from torch import nn
import numpy as np
from IPython.core.debugger import set_trace
# set device
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [3]:
import csv
jp_sentences = []
en_sentences = []
with open('data/kyoto_lexicon.csv', 'r', encoding='utf-8') as file:
    reader = csv.reader(file, delimiter=',')
    # skip the header row
    startLooking = False
    for row in reader:
        if startLooking:
            jp_sentences.append(row[0])
            en_sentences.append(row[1])
        startLooking = True
print(jp_sentences[:5])
print(en_sentences[:5])

['102世吉田日厚貫首', '1月15日：成人祭、新年祭', '1月3日：家運隆盛、商売繁盛祈願祭', '1月7日：七種粥神事', '21世紀COEプログラム']
['the 102nd head priest, Nikko TOSHIDA', '15th January: Seijin-sai (Adult Festival), the New Year Festival', '3rd January: Prayer Festival for the prosperity of family fortunes and business', '7th January: Nanakusa-gayu shinji (a divine service for a rice porridge with seven spring herbs to insure health for the new year)', 'The 21st Century Center Of Excellence Program']


# character-by-character prediction

In [40]:
class CharacterTable:
    def __init__(self, charset):
        self.charset = charset
        self.charset = frozenset(self.charset)
        self.charlist = ['<null>', '<sos>', '<eos>'] + list(self.charset)
        self.vocab_size = len(self.charlist)
    def encode(self, char):
        '''convert from character to index
        can process (nested) list of characters'''
        if type(char) is type('asdf'):
            # char is a string
            return self.charlist.index(char)
        else:
            # char is a list of strings
            return [self.encode(char) for char in char]
    def decode(self, charInd):
        '''convert from index to character
        can process (nested) list of indices'''
        if type(charInd) is type(22):
            # charInd is an int
            return self.charlist[charInd]
        else:
            # charInd is a list of ints
            return [self.encode(charInd) for charInd in charInd]
jp_chartable = CharacterTable(set(''.join(jp_sentences)))
en_chartable = CharacterTable(set(''.join(en_sentences)))
print(en_chartable.encode([['a', 'b'], ['c', 'd']]))
print(jp_chartable.decode(1234))

[[144, 90], [82, 44]]
譴


In [52]:
# character-by-character prediction model
class CharacterPredictor(nn.Module):
    def __init__(self, chartable, embedding_dimensions=64, hidden_size=100):
        super(CharacterPredictor, self).__init__()
        # model constants
        self.embedding_dimensions = embedding_dimensions
        self.hidden_size = hidden_size
        self.chartable = chartable
        self.vocab_size = self.chartable.vocab_size
        # model layers
        self.embedding = nn.Embedding(self.vocab_size, embedding_dimensions)
        self.RNN = nn.LSTM(
            input_size=self.embedding_dimensions,
            hidden_size=self.hidden_size, 
            batch_first=True
        )
        # linear layer for converting from hidden state to softmax
        self.linear = nn.Sequential(
            nn.Linear(self.hidden_size, self.vocab_size),
            nn.LogSoftmax(dim=-1)
        )
    
    
    def forward(self, padded_seq, lengths):
        '''
        predicts sequence of characters at every step
        seq (batch, seq) padded tensor of character indices
        returns (batch, seq, vocab) softmaxes
        implicit teacher forcing by torch RNN
        '''
        padded_seq_embed = self.embedding(padded_seq) # (batch, seq, embed)
        packed_seq_embed = torch.nn.utils.rnn.pad_packed_sequence(padded_seq_embed, lengths, batch_first=True)
        batch_size = seq.shape[0]
        hidden_states, (h_final, cell_final) = self.RNN(packed_seq_embed)
        # hidden_states (seq, batch, hidden) hidden states
        y_hat = self.linear(hidden_states)
        # y_hat (seq, batch, vocab) softmaxes
        return y_hat

In [50]:
# load data
def group_by_length(lists):
    '''2D list -> {length: 2D list of things with that length}
    rows grouped by length under a dictionary from length to list of rows of that length'''
    grouped_list = {}
    for row in lists:
        length = len(row)
        if length in grouped_list:
            grouped_list[length].append(row)
        else:
            grouped_list[length] = [row]
    return grouped_list
def train_test(sentences, chartable, train_test_split=.2):
    def pad_sequence(sentences):
        sentence_indices = [chartable.encode(list(sentence)) for sentence in sentences]
        # list of list of indices
        lengths = torch.LongTensor([len(sentence) for sentence in sentence_indices])
        sentence_tensors = [torch.LongTensor(sentence).to(device) for sentence in  sentence_indices]
        padded = torch.nn.utils.rnn.pad_sequence(sentence_tensors, batch_first=True)
        lengths, perm_idx = lengths.sort(0, descending=True)
        padded = padded[perm_idx]
        return padded, lengths
    length = len(sentences)
    split = floor(length * train_test_split)
    ### left off here. just call pad_sequence on split sentences 
    
train_test(en_sentences, en_chartable)

> [1;32m<ipython-input-50-49b30accb1c5>[0m(23)[0;36mtrain_test[1;34m()[0m
[1;32m     20 [1;33m    [0mpadded[0m [1;33m=[0m [0mpadded[0m[1;33m[[0m[0mperm_idx[0m[1;33m][0m[1;33m[0m[0m
[0m[1;32m     21 [1;33m    [0mset_trace[0m[1;33m([0m[1;33m)[0m[1;33m[0m[0m
[0m[1;32m     22 [1;33m    [1;31m### left off here. you need to pack the input before passing it through the lstm[0m[1;33m[0m[1;33m[0m[0m
[0m[1;32m---> 23 [1;33m    [1;32mpass[0m[1;33m[0m[0m
[0m[1;32m     24 [1;33m[0mtrain_test[0m[1;33m([0m[0men_sentences[0m[1;33m,[0m [0men_chartable[0m[1;33m)[0m[1;33m[0m[0m
[0m
torch.Size([51982, 315])


BdbQuit: 

In [None]:
def train_char(chartable, ):
    model = CharacterPredictor(chartable)
    optimizer = torch.optim.SGD()