
Commit 87094a0: tie embedding in paraphrase identifier
Pengcheng Yin committed Nov 22, 2018 (1 parent: 761a98d)
Showing 2 changed files with 43 additions and 6 deletions.
model/decomposable_attention_model.py — 9 changes: 6 additions & 3 deletions
@@ -15,11 +15,14 @@
 class DecomposableAttentionModel(nn.Module):
     """Decomposable attention model for paraphrase identification"""
 
-    def __init__(self, src_vocab, tgt_vocab, embed_size, dropout=0., cuda=False):
+    def __init__(self, src_vocab, tgt_vocab, embed_size, dropout=0., tie_embed=False, cuda=False):
         super(DecomposableAttentionModel, self).__init__()
 
-        self.src_embed = nn.Embedding(len(src_vocab), embed_size, padding_idx=src_vocab['<pad>'])
-        self.tgt_embed = nn.Embedding(len(tgt_vocab), embed_size, padding_idx=tgt_vocab['<pad>'])
+        if tie_embed:
+            self.src_embed = self.tgt_embed = nn.Embedding(len(src_vocab), embed_size, padding_idx=src_vocab['<pad>'])
+        else:
+            self.src_embed = nn.Embedding(len(src_vocab), embed_size, padding_idx=src_vocab['<pad>'])
+            self.tgt_embed = nn.Embedding(len(tgt_vocab), embed_size, padding_idx=tgt_vocab['<pad>'])
 
         self.att_linear = nn.Linear(embed_size, embed_size, bias=False)
 
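A quick sanity check of the new flag (not part of this commit): a minimal sketch assuming a plain dict can stand in for the vocabulary objects, i.e. that the rest of the constructor only depends on embed_size and the '<pad>' entry.

# Minimal sketch, not from the repository: with tie_embed=True the source and
# target sides should share a single nn.Embedding module.
from model.decomposable_attention_model import DecomposableAttentionModel

toy_vocab = {'<pad>': 0, '<unk>': 1, 'return': 2, 'x': 3}   # assumed stand-in for the real vocab objects
model = DecomposableAttentionModel(src_vocab=toy_vocab, tgt_vocab=toy_vocab,
                                   embed_size=8, tie_embed=True)
assert model.src_embed is model.tgt_embed                    # one shared embedding module
assert model.src_embed.weight.shape == (len(toy_vocab), 8)   # |V| x embed_size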
model/paraphrase.py — 40 changes: 37 additions & 3 deletions
@@ -1,6 +1,7 @@
 # coding=utf-8
 import os
 from itertools import chain
+import numpy as np
 
 import torch
 import torch.nn as nn
@@ -14,15 +15,17 @@
 from components.reranker import RerankingFeature
 from model import nn_utils
 from model.decomposable_attention_model import DecomposableAttentionModel
+from model.nn_utils import input_transpose
 
 
 @Registrable.register('paraphrase_identifier')
 class ParaphraseIdentificationModel(nn.Module, RerankingFeature, Savable):
     def __init__(self, args, vocab, transition_system):
         super(ParaphraseIdentificationModel, self).__init__()
-        self.pi_model = DecomposableAttentionModel(src_vocab=vocab.code, tgt_vocab=vocab.source,
+        self.pi_model = DecomposableAttentionModel(src_vocab=vocab, tgt_vocab=vocab,
                                                    embed_size=args.embed_size,
                                                    dropout=args.dropout,
+                                                   tie_embed=True,
                                                    cuda=args.cuda)
 
         self.vocab = vocab
@@ -41,8 +44,8 @@ def _score(self, src_codes, tgt_nls):
"""score examples sorted by code length"""
args = self.args

src_code_var = nn_utils.to_input_variable(src_codes, self.vocab.code, cuda=args.cuda).t()
tgt_nl_var = nn_utils.to_input_variable(tgt_nls, self.vocab.source, cuda=args.cuda).t()
src_code_var = self.to_input_variable(src_codes, cuda=args.cuda).t()
tgt_nl_var = self.to_input_variable(tgt_nls, cuda=args.cuda).t()

src_code_mask = Variable(nn_utils.length_array_to_mask_tensor([len(x) for x in src_codes], cuda=args.cuda, valid_entry_has_mask_one=True).float(), requires_grad=False)
tgt_nl_mask = Variable(nn_utils.length_array_to_mask_tensor([len(x) for x in tgt_nls], cuda=args.cuda, valid_entry_has_mask_one=True).float(), requires_grad=False)
@@ -65,6 +68,37 @@ def score(self, examples):
     def tokenize_code(self, code):
         return self.transition_system.tokenize_code(code, mode='decoder')
 
+    def to_input_variable(self, sequences, cuda=False, training=True):
+        """
+        given a list of sequences,
+        return a tensor of shape (max_sent_len, batch_size)
+        """
+        word_ids = []
+        for seq in sequences:
+            unk_dict = dict()
+            seq_wids = []
+            for word in seq:
+                if self.vocab.is_unk(word):
+                    if word in unk_dict:
+                        word_id = unk_dict[word]
+                    else:
+                        word_id = self.vocab['<unk_%d>' % len(unk_dict)]
+                        unk_dict[word] = word_id
+                else:
+                    word_id = self.vocab[word]
+
+                seq_wids.append(word_id)
+
+            word_ids.append(seq_wids)
+
+        sents_t = input_transpose(word_ids, self.vocab['<pad>'])
+
+        sents_var = Variable(torch.LongTensor(sents_t), volatile=(not training), requires_grad=False)
+        if cuda:
+            sents_var = sents_var.cuda()
+
+        return sents_var
+
     def save(self, path):
         dir_name = os.path.dirname(path)
         if not os.path.exists(dir_name):
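For reference, a hypothetical call to the new to_input_variable (not part of the commit), assuming pi_model is an already-constructed ParaphraseIdentificationModel whose vocabulary implements is_unk() and contains '<unk_0>', '<unk_1>', ... placeholder entries, and that the example tokens other than 'my_list' are in-vocabulary.

# Illustrative usage only. Unknown words are renumbered per sequence, so the
# numbering of '<unk_0>', '<unk_1>', ... restarts for every sequence in the batch.
seqs = [['sorted', 'my_list', 'reverse'],   # 'my_list' unseen -> mapped to '<unk_0>'
        ['my_list', 'append'],              # numbering restarts: 'my_list' -> '<unk_0>' again
        ['print', 'x']]
var = pi_model.to_input_variable(seqs, cuda=False, training=False)
print(var.size())   # torch.Size([3, 3]): (max_sent_len, batch_size), shorter rows padded with '<pad>'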
