In [1]:
import pickle as pkl
import gzip
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

## Import Dictionaries

In [2]:
def load_zipped_pickle(filename):
    with gzip.open(filename, 'rb') as f:
        loaded_object = pkl.load(f)
        return loaded_object

In [3]:
id2word_zh_dic = load_zipped_pickle("../embeddings/id2word_zh_dic.p")
word2id_zh_dic = load_zipped_pickle("../embeddings/word2id_zh_dic.p")

id2word_en_dic = load_zipped_pickle("../embeddings/id2word_en_dic.p")
word2id_en_dic = load_zipped_pickle("../embeddings/word2id_en_dic.p")

## Import tokenzied datasets

In [4]:
train_zh_num = load_zipped_pickle("../data/zh-en-tokens/train_zh_tok_num.p")
train_en_num = load_zipped_pickle("../data/zh-en-tokens/train_en_tok_num.p")

test_zh_num = load_zipped_pickle("../data/zh-en-tokens/test_zh_tok_num.p")
test_en_num = load_zipped_pickle("../data/zh-en-tokens/test_en_tok_num.p")

In [5]:
# Append eos to english sentence
for i, line in enumerate(train_en_num):
    if len(line) == 0:
        del train_en_num[i]
        del train_zh_num[i]
    else:
        line[-1] = 110000

for i, line in enumerate(test_en_num):
    if len(line) == 0:
        del test_en_num[i]
        del test_zh_num[i]
    else:
        line[-1] = 110000

# Add eos to en dict
id2word_en_dic[110000] = '</s>'
word2id_en_dic['</s>'] = 110000

In [6]:
index = 27598

for num in train_zh_num[index]:
    print(id2word_zh_dic[num], end=' ')
print()
for num in train_en_num[index]:
    print(id2word_en_dic[num], end=' ')

简单 的 说 如果 <unk> 重病 病患 <unk> 患者 接受 了 搭桥 手术 他 的 症状 会 稍微 好转 </s> 
basically , if you take an extremely sick patient and you give them a bypass , they get a little bit better . </s> 

In [7]:
def max_length(sample):
    max_length = 0
    for line in sample:
        if len(line) > max_length:
            max_length = len(line)
    return max_length

max_train_zh = max_length(train_zh_num) #531
max_train_en = max_length(train_en_num) #666

## Padding

In [8]:
def pad(data, length):
    for i, line in enumerate(data):
        if len(line) < length:
            for i in range(len(line), length):
                line.append(0)
        else:
            data[i] = line[0:length]
    return data

train_zh_num = pad(train_zh_num, 10)#max_length(train_zh_num))
train_en_num = pad(train_en_num, 10) #max_length(train_en_num))

train_zh_num = torch.tensor(train_zh_num, dtype=torch.long, device=device)
train_en_num = torch.tensor(train_en_num, dtype=torch.long, device=device)

## Encoder and Decoder

In [9]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
    
    def forward(self, input_mat, hidden):
        embedded = self.embedding(input_mat).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden
    
    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
    
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_mat, hidden):
        output = self.embedding(input_mat).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [10]:
def train(input_tensor, target_tensor, encoder, decoder, 
          encoder_optimizer, decoder_optimizer, criterion, 
          max_length=700):
    
    encoder_hidden = encoder.initHidden()
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    
    loss = 0
    
    
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(
            input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[0]], device=device)

    decoder_hidden = encoder_hidden

    for di in range(target_length):
        decoder_output, decoder_hidden = decoder(
            decoder_input, decoder_hidden)
        topv, topi = decoder_output.topk(1)
        decoder_input = topi.squeeze().detach()  # detach from history as input
        loss += criterion(decoder_output, torch.tensor([target_tensor[di]]))
        if decoder_input.item() == word2id_en_dic['</s>']:
            break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [11]:
learning_rate = 0.01
#batch_size = 32
encoder = EncoderRNN(110001, 40).to(device)
decoder = DecoderRNN(40, 110001).to(device)
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
criterion = nn.NLLLoss()

In [12]:
for epoch in range(1):
    #for batch_i in range(int(train_zh_num.shape[0]/batch_size)):
    for batch_i in range(train_zh_num.shape[0]):
        if batch_i > 1000:
            break
        input_batch = train_zh_num[batch_i, :].to(device)
        target_batch = train_en_num[batch_i, :].to(device)
        loss = train(input_batch, target_batch, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        if batch_i % 50 == 0:
            print(loss)

11.580879211425781
7.921060180664062
7.822477722167969
2.417995262145996
8.301812744140625
2.338675880432129
3.250885772705078
7.495432281494141
8.13516845703125
9.270360565185547
3.7707706451416017
8.732978820800781
5.452120590209961
4.979118728637696
8.055752563476563
8.035268402099609
5.584222030639649
3.3305824279785154
6.803962707519531
5.576179885864258
5.780976486206055


In [13]:
def test(index):
    for num in train_zh_num[index]:
        print(id2word_zh_dic[int(num)], end=' ')
    input_tensor = train_zh_num[index, :]
    target_tensor = train_en_num[index, :]
    target_length = target_tensor.size(0)
    input_length = input_tensor.size(0)
    encoder_hidden = encoder.initHidden()
    encoder_outputs = torch.zeros(700, encoder.hidden_size, device=device)
    
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]
    
    decoder_hidden = encoder_hidden
    decoder_input = torch.tensor([[0]], device=device)
    
    print()
    
    for di in range(target_length):
        decoder_output, decoder_hidden = decoder(
            decoder_input, decoder_hidden)
        topv, topi = decoder_output.topk(1)
        decoder_input = topi.squeeze().detach()  # detach from history as input
        print(id2word_en_dic[int(torch.argmax(decoder_output))], end=' ')
        if decoder_input.item() == word2id_en_dic['</s>']:
            break

In [29]:
test(15937)

<unk> 后 警方 又 给 她 <unk> 新 的 照片 
so is , <unk> , <unk> , <pad> <pad> <pad> 