In [2]:
import torch
from torch import nn
import numpy as np
import pdb
# set device
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [3]:
import csv
jp_sentences = []
en_sentences = []
with open('data/kyoto_lexicon.csv', 'r', encoding='utf-8') as file:
    reader = csv.reader(file, delimiter=',')
    # skip the header row
    startLooking = False
    for row in reader:
        if startLooking:
            jp_sentences.append(row[0])
            en_sentences.append(row[1])
        startLooking = True
print(jp_sentences[:5])
print(en_sentences[:5])

['102世吉田日厚貫首', '1月15日：成人祭、新年祭', '1月3日：家運隆盛、商売繁盛祈願祭', '1月7日：七種粥神事', '21世紀COEプログラム']
['the 102nd head priest, Nikko TOSHIDA', '15th January: Seijin-sai (Adult Festival), the New Year Festival', '3rd January: Prayer Festival for the prosperity of family fortunes and business', '7th January: Nanakusa-gayu shinji (a divine service for a rice porridge with seven spring herbs to insure health for the new year)', 'The 21st Century Center Of Excellence Program']


# character-by-character prediction

In [40]:
class CharacterTable:
    def __init__(self, charset):
        self.charset = charset
        self.charset = frozenset(self.charset)
        self.charlist = ['<null>', '<sos>', '<eos>'] + list(self.charset)
        self.vocab_size = len(self.charlist)
    def encode(self, char):
        '''convert from character to index
        can process (nested) list of characters'''
        if type(char) is type('asdf'):
            # char is a string
            return self.charlist.index(char)
        else:
            # char is a list of strings
            return [self.encode(char) for char in char]
    def decode(self, charInd):
        '''convert from index to character
        can process (nested) list of indices'''
        if type(charInd) is type(22):
            # charInd is an int
            return self.charlist[charInd]
        else:
            # charInd is a list of ints
            return [self.encode(charInd) for charInd in charInd]
jp_chartable = CharacterTable(set(''.join(jp_sentences)))
en_chartable = CharacterTable(set(''.join(en_sentences)))
print(en_chartable.encode([['a', 'b'], ['c', 'd']]))
print(jp_chartable.decode(1234))

[[144, 90], [82, 44]]
譴


In [41]:
import random
# character-by-character prediction model
class CharacterPredictor(nn.Module):
    def __init__(self, chartable, embedding_dimensions=64, hidden_size=100):
        super(CharacterPredictor, self).__init__()
        # model constants
        self.embedding_dimensions = embedding_dimensions
        self.hidden_size = hidden_size
        self.chartable = chartable
        self.vocab_size = self.chartable.vocab_size
        # model layers
        self.embedding = nn.Embedding(self.vocab_size, embedding_dimensions)
        self.RNN = nn.LSTM(
            input_size=self.embedding_dimensions,
            hidden_size=self.hidden_size, 
            batch_first=True
        )
        # linear layer for converting from hidden state to softmax
        self.linear = nn.Sequential(
            nn.Linear(self.hidden_size, self.vocab_size),
            nn.LogSoftmax(dim=-1)
        )
    
    
    def forward(self, seq):
        '''
        predicts sequence of characters at every step
        seq (batch, seq) tensor of character indices
        returns (batch, seq, vocab) softmaxes
        implicit teacher forcing by torch RNN
        '''
        seq_embed = self.embedding(seq) # (batch, seq, embed)
        batch_size = seq.shape[0]
        hidden_states, (h_final, cell_final) = self.RNN(seq_embed)
        # hidden_states (seq, batch, hidden) hidden states
        y_hat = self.linear(hidden_states)
        # y_hat (seq, batch, vocab) softmaxes
        return y_hat

In [42]:
# example usage of model
model = CharacterPredictor(en_chartable).to(device)
sentence = en_sentences[0]
sentence_indices = [en_chartable.encode(list(sentence))]
sentence_tensor = torch.LongTensor(sentence_indices).to(device)
print(sentence)
print(sentence_tensor, sentence_tensor.shape)
print(model(sentence_tensor), model(sentence_tensor).shape)

the 102nd head priest, Nikko TOSHIDA
tensor([[161,  99, 132,  55, 143,  28,  43,  38,  44,  55,  99, 132, 144,  44,
          55, 154,  94, 124, 132,  35, 161,  63,  55, 121, 124,  98,  98,  34,
          55,  22, 100, 164,  36,   4,  92, 110]], device='cuda:0') torch.Size([1, 36])
tensor([[[-5.1961, -5.1094, -5.1551,  ..., -5.2134, -5.1618, -5.2853],
         [-5.2293, -5.1681, -5.1038,  ..., -5.2491, -5.0742, -5.2203],
         [-5.2933, -5.1750, -5.1183,  ..., -5.2007, -5.1836, -5.2660],
         ...,
         [-5.1902, -5.0303, -5.1560,  ..., -5.2124, -5.1696, -5.3023],
         [-5.2850, -4.9637, -5.2178,  ..., -5.1580, -5.0941, -5.2161],
         [-5.3089, -5.0646, -5.1772,  ..., -5.2496, -5.0917, -5.1905]]],
       device='cuda:0', grad_fn=<LogSoftmaxBackward>) torch.Size([1, 36, 174])


In [None]:
# figure out packing
l = [
    [1, 2, 3],
    [4, 5, 0],
    [6, 1, 0]# not actually length 1, 
]
padded = torch.tensor(l)
lengths = torch.tensor([3,2,1])

In [39]:
# load data
def group_by_length(lists):
    '''2D list -> {length: 2D list of things with that length}
    rows grouped by length under a dictionary from length to list of rows of that length'''
    grouped_list = {}
    for row in lists:
        length = len(row)
        if length in grouped_list:
            grouped_list[length].append(row)
        else:
            grouped_list[length] = [row]
    return grouped_list
def train_test(sentences, chartable, train_test_split=.2):
    sentence_indices = [chartable.encode(list(sentence)) for sentence in sentences]
    # list of list of indices
    lengths = torch.LongTensor([len(sentence) for sentence in sentence_indices])
    sentence_tensors = [torch.LongTensor(sentence).to(device) for sentence in  sentence_indices]
    padded = torch.nn.utils.rnn.pad_sequence(sentence_tensors, batch_first=True)
    lengths, perm_idx = lengths.sort(0, descending=True)
    padded = padded[perm_idx]
    ### left off here. you need to pack the input before passing it through the lstm
    lengths = sentence_tensor.

[161, 98, 131, 54, 142, 27, 42, 37, 43, 54, 98, 131, 143, 43, 54, 154, 93, 123, 131, 34, 161, 62, 54, 120, 123, 97, 97, 33, 54, 21, 99, 164, 35, 1, 91, 109]


In [None]:
def train_char(chartable, ):
    model = CharacterPredictor(chartable)
    optimizer = torch.optim.SGD()