In [18]:
import argparse
import shutil
import re
import os

import torch
from torch.autograd import Variable

#from seq2seq import *
#from utils import *

In [19]:
def construct_vocab(file_, mincount=10):
    """Build word<->id lookup tables from a vocabulary count file.

    Every non-empty line of ``file_`` is split on the literal ``<sec>``
    separator; the first field is the word and the second field is parsed as
    its integer count.  Words whose count is at least ``mincount`` receive
    consecutive ids starting at 4, in file order.  Ids 0-3 are reserved for
    the special tokens ``<s>``, ``</s>``, ``<pad>`` and ``<unk>``.

    Args:
        file_: path of the vocabulary file.
        mincount: minimum count a word needs to be kept.

    Returns:
        ``(vocab2id, id2vocab)``: two dicts that are exact inverses of each
        other, so ``len(vocab2id)`` is always ``max(id) + 1``.

    Raises:
        IndexError: a non-empty line has no ``<sec>`` separated count field.
        ValueError: the count field is not an integer.
    """
    vocab2id = {
        '<s>': 0,
        '</s>': 1,
        '<pad>': 2,
        '<unk>': 3
    }

    id2vocab = {
        0: '<s>',
        1: '</s>',
        2: '<pad>',
        3: '<unk>'
    }
    cnt = 4
    with open(file_, 'r') as fp:
        for line in fp:
            # Strip only the line terminator.  Slicing off the last character
            # would eat a digit of the count when the final line of the file
            # is not newline-terminated.
            line = line.rstrip('\r\n')
            if not line:
                continue
            arr = re.split('<sec>', line)
            word = arr[0]
            # Keep the first id of a word (and the reserved ids of the special
            # tokens).  Re-assigning would leave two ids pointing at one word
            # and make len(vocab2id) smaller than the largest id.
            if word in vocab2id:
                continue
            if int(arr[1]) >= mincount:
                vocab2id[word] = cnt
                id2vocab[cnt] = word
                cnt += 1

    return vocab2id, id2vocab

In [20]:
def create_batch_file(file_name, batch_size):
    """Split a corpus file into fixed-size batch files.

    The directory ``batch_folder`` (relative to the current working
    directory) is deleted if it exists and recreated.  The lines of
    ``file_name`` are then written, ``batch_size`` lines at a time and in
    their original order, to ``batch_folder/batch_0``,
    ``batch_folder/batch_1``, ...

    Only full batches are written: trailing lines that do not fill a complete
    batch are discarded.

    Args:
        file_name: path of the corpus file, one example per line.
        batch_size: number of lines per batch file.

    Returns:
        The number of batch files written.
    """
    folder = 'batch_folder'
    fkey = 'batch_'
    if os.path.exists(folder):
        shutil.rmtree(folder)
    os.mkdir(folder)

    cnt = 0
    # Start with an empty buffer so the very first line of the corpus is
    # kept; creating the buffer lazily from an exception handler lost it.
    arr = []
    with open(file_name, 'r') as fp:
        for line in fp:
            arr.append(line)
            if len(arr) == batch_size:
                with open(folder + '/' + fkey + str(cnt), 'w') as fout:
                    fout.writelines(arr)
                arr = []
                cnt += 1

    return cnt

In [21]:
def _encode_tokens(text, vocab2id):
    """Whitespace-tokenize ``text``, wrap it in <s>/</s> and map words to ids.

    Words missing from ``vocab2id`` are mapped to the id of ``<unk>``.
    """
    # A list comprehension (instead of ``filter``) keeps this a real list on
    # both Python 2 and Python 3, so it can be concatenated below.
    words = [wd for wd in re.split(r'\s', text) if wd]
    words = ['<s>'] + words + ['</s>']
    unk_id = vocab2id['<unk>']
    return [vocab2id[wd] if wd in vocab2id else unk_id for wd in words]


def process_minibatch(batch_id, vocab2id, max_lens=[512, 64]):
    """Load one batch file and turn it into padded id tensors.

    Reads ``batch_folder/batch_<batch_id>``.  Each line is split on ``<sec>``;
    field 3 becomes the source sequence and field 2 the target sequence
    (presumably article and abstract -- confirm against the data format).
    Both fields are split on the ``<pg>``/``<st>`` markers; for the target a
    ``.`` is appended to every segment before the segments are re-joined.

    Sequences are truncated to ``max_lens`` and padded with ``<pad>``:

    * source:        tokens without the last one, padded to ``max_lens[0]``
    * target input:  tokens without the last one, padded to ``max_lens[1]``
    * target output: tokens without the first one, padded to ``max_lens[1]``

    Args:
        batch_id: index of the batch file to read.
        vocab2id: word -> id dict containing at least ``<pad>`` and ``<unk>``
            plus whatever id the caller wants for ``<s>`` and ``</s>``.
        max_lens: ``[max source length, max target length]``.  The default
            list is only read, never mutated.

    Returns:
        ``(src_var, trg_input_var, trg_output_var)``: three ``LongTensor``
        variables with one row per line of the batch file.

    Raises:
        IndexError: a line has fewer than four ``<sec>`` separated fields.
    """
    folder = 'batch_folder'
    fkey = 'batch_'
    file_ = folder + '/' + fkey + str(batch_id)

    src_arr = []
    trg_arr = []
    with open(file_, 'r') as fp:
        for line in fp:
            # Strip only the terminator; ``line[:-1]`` would drop a real
            # character from an unterminated last line.
            arr = re.split('<sec>', line.rstrip('\r\n'))

            dabs = ''.join(
                seg + '.' for seg in re.split('<pg>|<st>', arr[2])
            )
            trg_arr.append(_encode_tokens(dabs, vocab2id))

            dart = ''.join(re.split('<pg>|<st>', arr[3]))
            src_arr.append(_encode_tokens(dart, vocab2id))

    src_arr = [itm[:max_lens[0]] for itm in src_arr]
    trg_arr = [itm[:max_lens[1]] for itm in trg_arr]

    # Dropping one token and adding (1 + max - len) pads makes every row
    # exactly ``max`` long.
    pad_id = vocab2id['<pad>']
    src_arr = [
        itm[:-1] + [pad_id]*(1+max_lens[0]-len(itm))
        for itm in src_arr
    ]
    trg_input_arr = [
        itm[:-1] + [pad_id]*(1+max_lens[1]-len(itm))
        for itm in trg_arr
    ]
    # The output is the input shifted left by one position.
    trg_output_arr = [
        itm[1:] + [pad_id]*(1+max_lens[1]-len(itm))
        for itm in trg_arr
    ]

    src_var = Variable(torch.LongTensor(src_arr))
    trg_input_var = Variable(torch.LongTensor(trg_input_arr))
    trg_output_var = Variable(torch.LongTensor(trg_output_arr))

    return src_var, trg_input_var, trg_output_var

In [22]:
import numpy as np
import torch
from torch.autograd import Variable

class seq2seq(torch.nn.Module):
    '''
    Sequence-to-sequence model with an LSTM encoder and an LSTM decoder.

    The final hidden state of the encoder's last layer is passed through a
    linear layer + tanh and used as the decoder's initial hidden state; the
    final cell state of the encoder's last layer is used directly as the
    decoder's initial cell state.  Because the cell state is handed over
    unprojected, ``src_hidden_dim * n_directions`` must equal
    ``trg_hidden_dim`` (n_directions is 2 for a bidirectional encoder).

    Both LSTMs are batch-first: inputs are (batch, seq_len) id tensors.
    '''
    def __init__(
        self,
        src_emb_dim=100,
        trg_emb_dim=100,
        src_hidden_dim=25,
        trg_hidden_dim=50,
        src_vocab_size=999,
        trg_vocab_size=999,
        src_pad_token=0,
        trg_pad_token=0,
        src_nlayer=2,
        trg_nlayer=1,
        src_bidirect=True,
        batch_size=128,
        dropout=0.0
    ):
        '''
        Args:
            src_emb_dim, trg_emb_dim: embedding sizes of encoder / decoder.
            src_hidden_dim: hidden size of the encoder LSTM, per direction.
            trg_hidden_dim: hidden size of the decoder LSTM.
            src_vocab_size, trg_vocab_size: number of embedding rows; the
                decoder projects onto trg_vocab_size logits.
            src_pad_token, trg_pad_token: ids used as ``padding_idx`` of the
                source / target embedding.
            src_nlayer, trg_nlayer: number of stacked LSTM layers.
            src_bidirect: whether the encoder is bidirectional.
            batch_size: unused; accepted for backward compatibility (the
                batch size is taken from the input at call time).
            dropout: dropout passed to both LSTMs.

        Raises:
            ValueError: if the encoder state width does not match
                trg_hidden_dim (see class docstring).
        '''
        super(seq2seq, self).__init__()

        self.src_bidirect = src_bidirect
        self.trg_vocab_size = trg_vocab_size

        self.n_directions = 2 if src_bidirect else 1
        # Per-direction hidden size; identical to the LSTM's hidden_size in
        # both the uni- and bidirectional case.
        self.src_hidden_dim = src_hidden_dim

        if src_hidden_dim*self.n_directions != trg_hidden_dim:
            raise ValueError(
                'src_hidden_dim*n_directions ({0}) must equal '
                'trg_hidden_dim ({1}): the encoder cell state is used '
                'directly as the decoder cell state'.format(
                    src_hidden_dim*self.n_directions, trg_hidden_dim
                )
            )

        self.src_embedding = torch.nn.Embedding(
            src_vocab_size,
            src_emb_dim,
            padding_idx=src_pad_token
        )

        self.trg_embedding = torch.nn.Embedding(
            trg_vocab_size,
            trg_emb_dim,
            padding_idx=trg_pad_token
        )

        self.encoder = torch.nn.LSTM(
            input_size=src_emb_dim,
            hidden_size=src_hidden_dim,
            num_layers=src_nlayer,
            bidirectional=src_bidirect,
            batch_first=True,
            dropout=dropout
        )

        self.decoder = torch.nn.LSTM(
            input_size=trg_emb_dim,
            hidden_size=trg_hidden_dim,
            num_layers=trg_nlayer,
            batch_first=True,
            dropout=dropout
        )

        self.src2trg = torch.nn.Linear(
            src_hidden_dim*self.n_directions,
            trg_hidden_dim
        )

        self.trg2vocab = torch.nn.Linear(
            trg_hidden_dim,
            trg_vocab_size
        )

        # init weights
        # Prefer the in-place initializers; fall back to the old names on
        # PyTorch releases that do not have them yet.
        init = torch.nn.init
        init_normal = getattr(init, 'normal_', None) or init.normal
        init_constant = getattr(init, 'constant_', None) or init.constant
        init_normal(self.src_embedding.weight, mean=0.0, std=0.02)
        init_normal(self.trg_embedding.weight, mean=0.0, std=0.02)
        init_constant(self.src2trg.bias, 0.0)
        init_constant(self.trg2vocab.bias, 0.0)

        # The normal init above also overwrote the padding rows, which never
        # receive gradient updates; reset them to zero.
        if src_pad_token is not None:
            self.src_embedding.weight.data[src_pad_token].zero_()
        if trg_pad_token is not None:
            self.trg_embedding.weight.data[trg_pad_token].zero_()

        # Keep the historical behaviour of placing the model on the GPU, but
        # do not crash on machines without CUDA.
        if torch.cuda.is_available():
            self.cuda()

    def forward(self, input_src, input_trg):
        '''
        Run the encoder on input_src and the (teacher-forced) decoder on
        input_trg.

        Args:
            input_src: (batch, src_len) tensor of source token ids.
            input_trg: (batch, trg_len) tensor of decoder input token ids.

        Returns:
            (batch, trg_len, trg_vocab_size) tensor of unnormalized logits.
        '''
        src_emb = self.src_embedding(input_src)
        trg_emb = self.trg_embedding(input_trg)

        # With no explicit state the LSTM starts from zeros allocated on the
        # same device as its input.
        _, (src_h_t, src_c_t) = self.encoder(src_emb)

        if self.src_bidirect:
            # Last layer: index -1 and -2 are its two directions.
            h_t = torch.cat((src_h_t[-1], src_h_t[-2]), 1)
            c_t = torch.cat((src_c_t[-1], src_c_t[-2]), 1)
        else:
            h_t = src_h_t[-1]
            c_t = src_c_t[-1]

        trg_init_state = torch.tanh(self.src2trg(h_t))

        # Every decoder layer starts from the same state; for a single-layer
        # decoder this is just adding the leading layer dimension.
        n_layers = self.decoder.num_layers
        trg_h_0 = trg_init_state.unsqueeze(0).repeat(n_layers, 1, 1)
        trg_c_0 = c_t.unsqueeze(0).repeat(n_layers, 1, 1)

        trg_h, _ = self.decoder(
            trg_emb,
            (trg_h_0, trg_c_0)
        )

        # Flatten (batch, len) so the projection sees a 2-D input, then
        # restore the original layout.
        trg_h_reshape = trg_h.contiguous().view(
            trg_h.size(0)*trg_h.size(1),
            trg_h.size(2)
        )

        decoder_output = self.trg2vocab(trg_h_reshape)
        decoder_output = decoder_output.view(
            trg_h.size(0),
            trg_h.size(1),
            decoder_output.size(1)
        )

        return decoder_output

    def decode(self, logits):
        '''
        Convert logits of shape (batch, len, trg_vocab_size) into word
        probabilities of the same shape (softmax over the vocabulary axis).
        '''
        logits_reshape = logits.view(-1, self.trg_vocab_size)
        # dim is given explicitly; relying on the implicit choice is
        # deprecated.
        word_probs = torch.nn.functional.softmax(logits_reshape, dim=1)
        word_probs = word_probs.view(
            logits.size()[0], logits.size()[1], logits.size()[2]
        )
        return word_probs


In [23]:
# ---- configuration -------------------------------------------------------
data_dir = '../sum_data/'
file_vocab = 'cnn_vocab.txt'
file_corpus = 'cnn.txt'
n_epoch = 10
batch_size = 64
max_lens = [512, 64]   # [max source length, max target length]
out_dir = 'results'
use_cuda = torch.cuda.is_available()


def to_device(obj):
    """Move a module/tensor to the GPU when one is available."""
    return obj.cuda() if use_cuda else obj


# ---- data ----------------------------------------------------------------
vocab2id, id2vocab = construct_vocab(os.path.join(data_dir, file_vocab))
print('The vocabulary size: {0}'.format(len(vocab2id)))

n_batch = create_batch_file(
    file_name=os.path.join(data_dir, file_corpus),
    batch_size=batch_size
)
print('The number of batches: {0}'.format(n_batch))

# create_batch_file left its batch files in 'batch_folder'; the try/finally
# removes them even when training is interrupted or fails.
try:
    # ---- model -----------------------------------------------------------
    model = to_device(seq2seq(
        src_emb_dim=100,
        trg_emb_dim=100,
        src_hidden_dim=25,
        trg_hidden_dim=50,
        src_vocab_size=len(vocab2id),
        trg_vocab_size=len(vocab2id),
        # '<pad>' is not id 0 in this vocabulary ('<s>' is).
        src_pad_token=vocab2id['<pad>'],
        trg_pad_token=vocab2id['<pad>'],
        src_nlayer=2,
        trg_nlayer=1,
        src_bidirect=True,
        batch_size=batch_size,
        dropout=0.0
    ))

    # A zero class weight removes the padded positions from the loss.
    weight_mask = to_device(torch.ones(len(vocab2id)))
    weight_mask[vocab2id['<pad>']] = 0
    loss_criterion = to_device(torch.nn.CrossEntropyLoss(weight=weight_mask))

    optimizer = torch.optim.Adam(model.parameters())

    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.mkdir(out_dir)

    # ---- training --------------------------------------------------------
    losses = []
    for epoch in range(n_epoch):
        for batch_id in range(n_batch):
            src_var, trg_input_var, trg_output_var = process_minibatch(
                batch_id, vocab2id, max_lens=max_lens
            )

            optimizer.zero_grad()
            logits = model(to_device(src_var), to_device(trg_input_var))

            loss = loss_criterion(
                logits.contiguous().view(-1, len(vocab2id)),
                to_device(trg_output_var.view(-1))
            )
            loss.backward()
            optimizer.step()

            # reshape(-1) handles both a 1-element and a 0-dim loss array,
            # whichever the installed PyTorch returns.
            loss_val = float(loss.data.cpu().numpy().reshape(-1)[0])
            losses.append([epoch, batch_id, loss_val])

            if batch_id % 100 == 0:
                np.save(out_dir+'/loss', np.array(losses))

                print('epoch={0} batch={1} loss={2}'.format(
                    epoch, batch_id, loss_val
                ))

                # Greedy prediction for the first example of the batch.  The
                # argmax of the logits equals the argmax of their softmax,
                # so the probabilities need not be materialized.
                word_ids = logits.data.cpu().numpy().argmax(axis=2)
                sen_pred = [id2vocab[x] for x in word_ids[0]]
                if '</s>' in sen_pred:
                    sen_pred = sen_pred[:sen_pred.index('</s>')]
                print(' '.join(sen_pred))
finally:
    if os.path.exists('batch_folder'):
        shutil.rmtree('batch_folder')

The vocabulary size: 80475
The number of batches: 1444
epoch=0 batch=0 loss=11.292509079
distorted distorted aircrew aircrew four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly four-yearly aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew aircrew


KeyboardInterrupt: 