# Load the dataset

In [1]:
from modules.texts import Vocab, GloVeLoader
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import modules.extractive as ext
import modules.abstractive as abs
from modules.data import Documents
from torch.utils.data import DataLoader

# Initialize the pretrained embeddings

In [2]:
import numpy as np

# Load the pretrained embedding into the memory
path_glove = os.path.join(os.path.expanduser('~'),
             'data/NLP/word_embeddings/GloVe/glove.6B.200d.txt')
glove = GloVeLoader(path_glove)

# Load the dataset
doc_file = './data/kaggle_news_rouge1.pkl'
docs = Documents(doc_file, n_samples=2000, vocab_size = 30000)
docs.set_doc_classes(np.random.randint(2, size = len(docs)).tolist()) # attach random document labels
vocab = docs.vocab

d = 200
emb = nn.Embedding(vocab.V, d)

def init_emb(emb, vocab):
    for word in vocab.word2id:
        try:
            emb.weight.data[vocab[word]] = torch.from_numpy(glove[word])
        except KeyError as e:
            # Case when pretrained embedding for a word does not exist
            pass
#     emb.weight.requires_grad = False # suppress updates
    print('Initialized the word embeddings.')

init_emb(emb, vocab)

The pretrained vector file to use: /home/yhs/data/NLP/word_embeddings/GloVe/glove.6B.200d.txt
The number of words in the pretrained vector: 400000
The dimension of the pretrained vector: 200
Initialized the word embeddings.


In [3]:
# Test
from copy import deepcopy
from torch import optim
import time
from itertools import chain

vocab_size = vocab.V
emb_size = emb.weight.data.size(1)
n_kernels = 50
kernel_sizes = [1,2,3,4,5]
pretrained = emb
sent_size = len(kernel_sizes) * n_kernels
hidden_size = 400
num_layers = 1
n_classes = 2
batch_size = 1
torch.manual_seed(7)
torch.cuda.manual_seed(7)

ext_s_enc = ext.SentenceEncoder(vocab_size, emb_size,
                                   n_kernels, kernel_sizes, pretrained)
ext_d_enc = ext.DocumentEncoder(sent_size, hidden_size)
ext_extc = ext.ExtractorCell(sent_size, hidden_size)
ext_d_classifier = ext.DocumentClassifier(sent_size, n_classes)
abs_enc = abs.EncoderRNN(emb, hidden_size, num_layers)
abs_dec = abs.AttnDecoderRNN(emb, hidden_size * 2, num_layers)

models = [ext_s_enc, ext_d_enc, ext_extc, ext_d_classifier,
         abs_enc, abs_dec]
params = list(chain(*[model.parameters() for model in models]))
optimizer = optim.Adam(params, lr = .005)

loss_fn_ext = nn.BCELoss()
loss_fn_dclass = nn.NLLLoss()
loss_fn_abs = nn.CrossEntropyLoss()

def get_accuracy(probs, targets, verbose = False):   
    '''
    Calculates the accuracy for the predictions

    Args:
        probs: extraction probability
        targets: ground truth labels for extraction
    '''
    import numpy as np
    preds = np.array([1 if p > 0.5 else 0 for p in probs])
    if verbose:
        print(preds)
    accuracy = np.mean(preds == targets)
    
    return accuracy

# class RougeScorer:
#     def __init__(self):
#         from rouge import Rouge
#         self.rouge = Rouge()
#     def score(self, reference, generated, type = 1):
#         score = self.rouge.get_scores(reference, generated, avg=True)
#         score = score['rouge-%s' % type]['f']
#         return score

# rouge = RougeScorer()
    
def run_epoch(docs):
    
    epoch_loss_abs = 0
    epoch_loss_ext = 0
    epoch_loss_dclass = 0
    epoch_accuracy_ext = 0
    epoch_accuracy_dclass = 0

    for doc in docs:
        optimizer.zero_grad()
        docloader = DataLoader(doc, batch_size=1, shuffle=False)
        # Encode the sentences in a document
        sents_raw = []
        sents_encoded = []
        ext_labels = []
        doc_class = Variable(torch.LongTensor([doc.doc_class])).cuda()
        for sent, ext_label in docloader:
            # only accept sentences that conforms the maximum kernel sizes
            if sent.size(1) < max(kernel_sizes):
                continue
            sent = Variable(sent).cuda()
            sents_raw.append(sent)
            sents_encoded.append(ext_s_enc(sent))
            ext_labels.append(ext_label.cuda())
        # Ignore if the content is a single sentence(no need to train)
        if len(sents_raw) <= 1:
            continue

        # Build the document representation using encoded sentences
        d_encoded = torch.cat(sents_encoded, dim = 0).unsqueeze(1)
        ext_labels = Variable(torch.cat(ext_labels, dim = 0).type(torch.FloatTensor).view(-1)).cuda()
        init_sent = ext_s_enc.init_sent(batch_size)
        d_ext = torch.cat([init_sent, d_encoded[:-1]], dim = 0)

        # Extractive Summarizer
        ## Initialize the d_encoder
        h, c = ext_d_enc.init_h0c0(batch_size)
        h0 = Variable(h.data)
        ## An input goes through the document encoder
        output, hn, cn = ext_d_enc(d_ext, h, c)
        ## Initialize the decoder
        ### calculate p0, h_bar0, c_bar0
        h_ = hn.squeeze(0)
        c_ = cn.squeeze(0)
        p = ext_extc.init_p(h0.squeeze(0), h_)
        ### calculate p_t, h_bar_t, c_bar_t
        d_encoder_hiddens = torch.cat((h0, output[:-1]), 0) #h0 ~ h_{n-1}
        extract_probs = Variable(torch.zeros(len(sents_encoded))).cuda()
        for i, (s, h) in enumerate(zip(sents_encoded, d_encoder_hiddens)):
            h_, c_, p = ext_extc(s, h, h_, c_, p)
            extract_probs[i] = p
        ## Document Classifier
        q = ext_d_classifier(extract_probs.view(-1,1), d_encoded.squeeze(1))

        # Abstractive Summarizer
        loss_abs = 0
        ## Run through the encoder
        words = torch.cat(sents_raw, dim=1).t()
        abs_enc_hidden = abs_enc.init_hidden(batch_size)
        abs_enc_output, abs_enc_hidden = abs_enc(words, abs_enc_hidden)
        ## Remove to too long documents to tackle memory overflow
        if len(abs_enc_output) > 6000:
            continue
        ## Run through the decoder
        abs_dec_hidden = abs_dec.init_hidden(batch_size)
        for target in doc.summ:
            target = Variable(torch.LongTensor([target]).unsqueeze(1)).cuda()
            abs_dec_output, abs_dec_hidden, attn_weights = abs_dec(target, abs_dec_hidden, abs_enc_output)
            loss_abs += loss_fn_abs(abs_dec_output, target.squeeze(1))

        # Optimization
        loss_ext = loss_fn_ext(extract_probs, ext_labels)
        loss_dclass = loss_fn_dclass(q.view(1,-1), doc_class)
        epoch_loss_ext += loss_ext.data.cpu().numpy()[0]
        epoch_loss_dclass += loss_dclass.data.cpu().numpy()[0]
        epoch_loss_abs += loss_abs.data.cpu().numpy()[0]
#         torch.autograd.backward([loss_ext, loss_dclass, loss_abs])
        torch.autograd.backward([loss_ext, loss_abs])
        optimizer.step()

        # Measure the accuracy
        p_cpu = extract_probs.data.cpu().numpy()
        t_cpu = ext_labels.data.cpu().numpy()
        q_cpu = q.data.cpu().numpy()
        c_cpu = doc_class.data.cpu().numpy()
        epoch_accuracy_ext += get_accuracy(p_cpu, t_cpu)
        epoch_accuracy_dclass += get_accuracy(q_cpu, c_cpu)

    acc_ext = epoch_accuracy_ext / len(docs)
    acc_dclass = epoch_accuracy_dclass / len(docs)
    
    return epoch_loss_ext, epoch_loss_dclass, epoch_loss_abs, acc_ext, acc_dclass

def train(docs, n_epochs = 10, print_every = 1):
    import time
    
    for epoch in range(n_epochs):
        start_time = time.time()
        ext_loss, dclass_loss, abs_loss, ext_acc, dclass_acc = run_epoch(docs)
        end_time = time.time()
        wall_clock = (end_time - start_time) / 60
        if epoch % print_every == 0:
            print('Epoch:%2i / Loss:(%.3f/%.3f/%.3f) / Accuracy:(%.3f/%.3f) / TrainingTime:%.3f(min)' %
                  (epoch, ext_loss, dclass_loss, abs_loss, ext_acc, dclass_acc, wall_clock))

# Initial Training
train(docs, n_epochs = 10, print_every = 1)            

Epoch: 0 / Loss:(1261.125/3068.850/773485.033) / Accuracy:(0.570/0.483) / TrainingTime:12.322(min)
Epoch: 1 / Loss:(1221.491/3937.374/618076.558) / Accuracy:(0.599/0.483) / TrainingTime:12.284(min)
Epoch: 2 / Loss:(1248.366/6934.532/495365.583) / Accuracy:(0.568/0.483) / TrainingTime:12.244(min)
Epoch: 3 / Loss:(1261.470/10030.308/nan) / Accuracy:(0.578/0.483) / TrainingTime:12.251(min)


KeyboardInterrupt: 