In [5]:
import numpy as np 
from tqdm import *
import torch
from torch import nn
from torch import autograd
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import json
import random

# Index that is given zero weight in the loss below (by its name, the id of
# the padding token in the target vocabulary).
pad_id = 4002

# Per-class weights later handed to nn.CrossEntropyLoss: every one of the
# 4004 classes counts equally, except `pad_id`, which contributes nothing.
weight = [0 if class_id == pad_id else 1 for class_id in range(4004)]

class Seq2Seq(nn.Module):
    """Encoder/decoder translation model built from LSTMs.

    The source sentence is embedded and summarised by two single-layer LSTMs:
    one reads it left-to-right, the other right-to-left.  The two final
    outputs are concatenated into a fixed-size sentence vector which is fed,
    together with the embedding of the previous target token, into an
    ``LSTMCell`` decoder at every step.  A linear layer maps the decoder
    state to target-vocabulary logits.

    Args:
        en_dims: source embedding size.
        en_voc: source vocabulary size.
        zh_dims: target embedding size.
        zh_voc: target vocabulary size (number of output classes).
        dropout: forwarded to both encoder ``nn.LSTM`` constructors.
            NOTE: ``nn.LSTM`` only applies dropout *between* stacked layers,
            so with the single layer used here it has no effect.
        en_hiddens: hidden size of each encoder LSTM.
        zh_hiddens: hidden size of the decoder cell.
    """

    def __init__(self, en_dims, en_voc, zh_dims, zh_voc, dropout, en_hiddens, zh_hiddens):

        super(Seq2Seq, self).__init__()

        # Per-class loss weights, taken from the module-level `weight` list.
        # Its length has to match `zh_voc` for the loss to be computable.
        self.weight = torch.Tensor(weight)
        self.en_dims = en_dims
        self.zh_dims = zh_dims
        self.en_hiddens = en_hiddens
        self.zh_hiddens = zh_hiddens

        self.en_embedding = torch.nn.Embedding(num_embeddings=en_voc, embedding_dim=en_dims)
        self.zh_embedding = torch.nn.Embedding(num_embeddings=zh_voc, embedding_dim=zh_dims)

        # Forward-direction encoder (time-major input).
        self.enc_fw_lstm = torch.nn.LSTM(input_size=self.en_dims,
                                         hidden_size=self.en_hiddens,
                                         dropout=dropout,
                                         batch_first=False)

        # Backward-direction encoder; `forward` feeds it the reversed sequence.
        self.enc_bw_lstm = torch.nn.LSTM(input_size=self.en_dims,
                                         hidden_size=self.en_hiddens,
                                         dropout=dropout,
                                         batch_first=False)

        # Decoder input = [forward summary ; backward summary ; target embedding].
        self.dec_lstm_cell = torch.nn.LSTMCell(input_size=self.en_hiddens * 2 + self.zh_dims,
                                               hidden_size=self.zh_hiddens,
                                               bias=True)

        self.fc = nn.Linear(in_features=self.zh_hiddens, out_features=zh_voc)

        # Not used by `forward`/`get_loss`; kept so the module layout is unchanged.
        self.softmax = nn.Softmax()

        self.cost_func = nn.CrossEntropyLoss(weight=self.weight)

    def _to_model_device(self, tensor):
        """Move `tensor` to the GPU when the model's parameters live there."""
        return tensor.cuda() if self.fc.weight.is_cuda else tensor

    def forward(self, inputs, gtruths, is_train):
        """Run the encoder and decode one token per target position.

        Args:
            inputs: integer numpy array of source token ids, shape
                (batch, source_len).
            gtruths: integer numpy array of decoder input token ids, shape
                (batch, target_len).  Only its first column is used when
                `is_train` is false.
            is_train: if true, every decoder step is fed the ground-truth
                token (teacher forcing); otherwise each step after the first
                is fed the previous step's arg-max prediction.

        Returns:
            A tuple ``(logits, predictions)`` where ``logits`` is a tensor of
            shape (target_len, batch, zh_voc) and ``predictions`` is a numpy
            array of arg-max token ids with shape (batch, target_len).
        """
        inputs = self._to_model_device(Variable(torch.from_numpy(inputs).long()))
        gtruths = self._to_model_device(Variable(torch.from_numpy(gtruths).long()))

        # Embed, then switch to the time-major layout the LSTMs expect.
        inputs = torch.transpose(self.en_embedding(inputs), 0, 1)
        gtruths = torch.transpose(self.zh_embedding(gtruths), 0, 1)

        # The backward encoder must read the sentence from its last token to
        # its first, so hand it a time-reversed copy of the embeddings.
        source_len = inputs.size(0)
        reverse_index = self._to_model_device(
            Variable(torch.from_numpy(np.arange(source_len - 1, -1, -1)).long()))

        enc_fw, _ = self.enc_fw_lstm(inputs)
        enc_bw, _ = self.enc_bw_lstm(inputs.index_select(0, reverse_index))

        # The last output of each direction has seen the whole sentence.
        encoder = torch.cat([enc_fw[-1, :, :], enc_bw[-1, :, :]], -1)

        # Start the decoder from a zero state so that identical inputs always
        # produce identical outputs (a random state makes inference noisy).
        batch_size = encoder.size(0)
        hx = self._to_model_device(Variable(torch.zeros(batch_size, self.zh_hiddens)))
        cx = self._to_model_device(Variable(torch.zeros(batch_size, self.zh_hiddens)))

        logits = []
        predic = []
        prev_tokens = None

        for i in range(gtruths.size(0)):

            if is_train or i == 0:
                # Teacher forcing, or the very first step of free-running decoding.
                step_embedding = gtruths[i]
            else:
                # Free-running: feed back the previous arg-max prediction.
                step_embedding = self.zh_embedding(prev_tokens)

            inp = torch.cat([encoder, step_embedding], -1)

            hx, cx = self.dec_lstm_cell(inp, (hx, cx))
            step_logits = self.fc(hx)

            _, step_pred = torch.max(step_logits, 1)
            prev_tokens = step_pred.view(-1)

            # Add a leading time axis so the per-step results can be concatenated.
            logits.append(step_logits.view(1, step_logits.size(0), step_logits.size(1)))
            predic.append(prev_tokens.view(1, -1))

        predic = torch.cat(predic, 0)
        predic = torch.transpose(predic, 0, 1)
        return torch.cat(logits), predic.data.cpu().numpy()

    def get_loss(self, logits, labels):
        """Weighted cross-entropy between decoder logits and target ids.

        Args:
            logits: tensor of shape (target_len, batch, zh_voc), as returned
                by `forward`.
            labels: integer numpy array of target token ids, shape
                (batch, target_len).  Position ``t`` of `labels` is compared
                with decoder step ``t``; any shifting relative to the decoder
                inputs is the caller's responsibility.

        Returns:
            The scalar loss produced by ``self.cost_func``.
        """
        labels = Variable(torch.from_numpy(labels)).long()
        # Keep the labels on whatever device the logits were computed on.
        if logits.is_cuda:
            labels = labels.cuda()

        # Match the time-major layout of `logits` before flattening both.
        labels = torch.transpose(labels, 0, 1)

        logits = logits.view(-1, logits.size(-1))
        labels = labels.contiguous().view(-1)

        # CrossEntropyLoss already reduces to a single value.
        loss = self.cost_func(logits, labels)

        return loss
        



In [6]:
# Build the translation model.  The target vocabulary size (4004) equals the
# length of the module-level `weight` list used for the loss.
net = Seq2Seq(
    en_dims=256,      # source embedding size
    en_voc=50004,     # source vocabulary size
    zh_dims=256,      # target embedding size
    zh_voc=4004,      # target vocabulary size
    dropout=0.5,
    en_hiddens=128,   # hidden size of each encoder LSTM
    zh_hiddens=256,   # hidden size of the decoder cell
)

In [7]:
# Move the model's parameters and buffers to the GPU.  As the last expression
# of the cell, the returned module is displayed as the layer summary below.
net.cuda()

Seq2Seq (
  (en_embedding): Embedding(50004, 256)
  (zh_embedding): Embedding(4004, 256)
  (enc_fw_lstm): LSTM(256, 128, dropout=0.5)
  (enc_bw_lstm): LSTM(256, 128, dropout=0.5)
  (dec_lstm_cell): LSTMCell(512, 256)
  (fc): Linear (256 -> 4004)
  (softmax): Softmax ()
  (cost_func): CrossEntropyLoss (
  )
)