Merge branch copy into develop #3

Merged
merged 25 commits into develop from copy on Sep 1, 2018
Commits (25)
82ed883
Fixed topk decoder.
kylegao91 Oct 1, 2017
57d72c7
Fixed unit test.
kylegao91 Oct 4, 2017
2bfe47c
Made outputs of Decoder and TopKDecoder consistent.
kylegao91 Oct 4, 2017
33cd333
Fixed division in python3
kylegao91 Oct 4, 2017
19f4615
Updated apidocs.
kylegao91 Oct 4, 2017
c5ccf32
Updated code style.
kylegao91 Oct 4, 2017
d78d333
Refactored decoder from decoder rnn.
kylegao91 Oct 8, 2017
7486875
Added Seq2SeqDataset class.
kylegao91 Oct 9, 2017
c3b70e5
Integration test using new dataset class.
kylegao91 Oct 9, 2017
5665934
Added index field for indexing.
kylegao91 Oct 10, 2017
48205a1
Added copy decoder.
kylegao91 Oct 10, 2017
5d6b479
Refactored interfaces to take batch instead of variables.
kylegao91 Oct 12, 2017
799a922
Refactored the place of symbol generation.
kylegao91 Oct 12, 2017
ab06567
Pass batch through decoder.
kylegao91 Oct 12, 2017
3613879
Added dynamic vocabulary to dataset.
kylegao91 Oct 13, 2017
a6ac847
Merge from `topk`.
kylegao91 Oct 14, 2017
8ec3900
Implemented CopyDecoder and refactoring.
kylegao91 Oct 24, 2017
9bbface
Use new dataset class in sample.
kylegao91 Nov 7, 2017
d8fcaef
Fixed argument order.
kylegao91 Nov 7, 2017
6469d9b
Fixed non-contiguous tensor view.
kylegao91 Nov 7, 2017
325b62d
Fixed wrong variable.
kylegao91 Nov 7, 2017
f851cdc
Merge branch 'master' into copy
kylegao91 Nov 7, 2017
fd274c2
Updated integration test script with new dataset class.
kylegao91 Nov 7, 2017
1da2f9c
Merge branch 'develop' into copy
pskrunner14 Sep 1, 2018
cde0b25
Update DecoderRNN.py
pskrunner14 Sep 1, 2018
52 changes: 17 additions & 35 deletions examples/sample.py
@@ -11,7 +11,7 @@
from seq2seq.models import EncoderRNN, DecoderRNN, Seq2seq
from seq2seq.loss import Perplexity
from seq2seq.optim import Optimizer
from seq2seq.dataset import SourceField, TargetField
from seq2seq.dataset import Seq2SeqDataset
from seq2seq.evaluator import Predictor
from seq2seq.util.checkpoint import Checkpoint

@@ -29,10 +29,10 @@
# python examples/sample.py --train_path $TRAIN_PATH --dev_path $DEV_PATH --expt_dir $EXPT_PATH --load_checkpoint $CHECKPOINT_DIR

parser = argparse.ArgumentParser()
parser.add_argument('--train_path', action='store', dest='train_path',
help='Path to train data')
parser.add_argument('--dev_path', action='store', dest='dev_path',
help='Path to dev data')
parser.add_argument('--train_src', action='store', help='Path to source train data')
parser.add_argument('--train_tgt', action='store', help='Path to target train data')
parser.add_argument('--dev_src', action='store', help='Path to source dev data')
parser.add_argument('--dev_tgt', action='store', help='Path to target dev data')
parser.add_argument('--expt_dir', action='store', dest='expt_dir', default='./experiment',
help='Path to experiment directory. If load_checkpoint is True, then path to checkpoint directory has to be provided')
parser.add_argument('--load_checkpoint', action='store', dest='load_checkpoint',
@@ -59,35 +59,15 @@
output_vocab = checkpoint.output_vocab
else:
# Prepare dataset
src = SourceField()
tgt = TargetField()
max_len = 50
def len_filter(example):
return len(example.src) <= max_len and len(example.tgt) <= max_len
train = torchtext.data.TabularDataset(
path=opt.train_path, format='tsv',
fields=[('src', src), ('tgt', tgt)],
filter_pred=len_filter
)
dev = torchtext.data.TabularDataset(
path=opt.dev_path, format='tsv',
fields=[('src', src), ('tgt', tgt)],
filter_pred=len_filter
)
src.build_vocab(train, max_size=50000)
tgt.build_vocab(train, max_size=50000)
input_vocab = src.vocab
output_vocab = tgt.vocab

# NOTE: If the source field name and the target field name
# are different from 'src' and 'tgt' respectively, they have
# to be set explicitly before any training or inference
# seq2seq.src_field_name = 'src'
# seq2seq.tgt_field_name = 'tgt'
train = Seq2SeqDataset.from_file(opt.train_src, opt.train_tgt)
train.build_vocab(50000, 50000)
dev = Seq2SeqDataset.from_file(opt.dev_src, opt.dev_tgt, share_fields_from=train)
input_vocab = train.src_field.vocab
output_vocab = train.tgt_field.vocab

# Prepare loss
weight = torch.ones(len(tgt.vocab))
pad = tgt.vocab.stoi[tgt.pad_token]
weight = torch.ones(len(output_vocab))
pad = output_vocab.stoi[train.tgt_field.pad_token]
loss = Perplexity(weight, pad)
if torch.cuda.is_available():
loss.cuda()
@@ -98,11 +98,13 @@ def len_filter(example):
# Initialize model
hidden_size=128
bidirectional = True
encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
max_len = 50
encoder = EncoderRNN(len(input_vocab), max_len, hidden_size,
bidirectional=bidirectional, variable_lengths=True)
decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else hidden_size,

decoder = DecoderRNN(len(output_vocab), max_len, hidden_size * 2 if bidirectional else hidden_size,
dropout_p=0.2, use_attention=True, bidirectional=bidirectional,
eos_id=tgt.eos_id, sos_id=tgt.sos_id)
eos_id=train.tgt_field.eos_id, sos_id=train.tgt_field.sos_id)
seq2seq = Seq2seq(encoder, decoder)
if torch.cuda.is_available():
seq2seq.cuda()
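The remainder of examples/sample.py is collapsed above. For orientation, a minimal sketch of how the script continues with the new vocabulary handles, assuming the pre-existing SupervisedTrainer and Predictor interfaces are unchanged by this refactor; the batch size, epoch count, and sample input below are illustrative, not part of this PR.

# Sketch only: signatures follow the repository's existing sample.py and
# may differ after the batch-based refactor introduced in this PR.
from seq2seq.trainer import SupervisedTrainer
from seq2seq.evaluator import Predictor

t = SupervisedTrainer(loss=loss, batch_size=32,
                      checkpoint_every=50, print_every=10,
                      expt_dir=opt.expt_dir)
seq2seq = t.train(seq2seq, train, num_epochs=6, dev_data=dev,
                  optimizer=None, teacher_forcing_ratio=0.5,
                  resume=opt.resume)

predictor = Predictor(seq2seq, input_vocab, output_vocab)
print(predictor.predict("1 2 3 4 5".split()))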
19 changes: 7 additions & 12 deletions scripts/generate_toy_data.py
@@ -1,36 +1,31 @@
from __future__ import print_function
import argparse
import os
import shutil
import random

parser = argparse.ArgumentParser()
parser.add_argument('--dir', help="data directory", default="../data")
parser.add_argument('--max-len', help="max sequence length", default=10)
args = parser.parse_args()


def generate_dataset(root, name, size):
path = os.path.join(root, name)
if not os.path.exists(path):
os.mkdir(path)

# generate data file
data_path = os.path.join(path, 'data.txt')
with open(data_path, 'w') as fout:
src_path = os.path.join(path, 'src.txt')
tgt_path = os.path.join(path, 'tgt.txt')
with open(src_path, 'w') as src_out, open(tgt_path, 'w') as tgt_out:
for _ in range(size):
length = random.randint(1, args.max_len)
seq = []
for _ in range(length):
seq.append(str(random.randint(0, 9)))
fout.write("\t".join([" ".join(seq), " ".join(reversed(seq))]))
fout.write('\n')

# generate vocabulary
src_vocab = os.path.join(path, 'vocab.source')
with open(src_vocab, 'w') as fout:
fout.write("\n".join([str(i) for i in range(10)]))
tgt_vocab = os.path.join(path, 'vocab.target')
shutil.copy(src_vocab, tgt_vocab)
src_out.write(" ".join(seq) + "\n")
tgt_out.write(" ".join(reversed(seq)) + "\n")


if __name__ == '__main__':
data_dir = args.dir
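With this change the toy data is written as two parallel files rather than a single TSV: line i of src.txt holds a space-separated digit sequence and line i of tgt.txt holds its reversal. A small sanity-check sketch; the paths assume the data/toy_reverse layout used by scripts/integration_test.sh and are illustrative.

# Paths are illustrative; point them at wherever generate_toy_data.py wrote its output.
with open('data/toy_reverse/train/src.txt') as src_in, \
        open('data/toy_reverse/train/tgt.txt') as tgt_in:
    for src_line, tgt_line in zip(src_in, tgt_in):
        assert tgt_line.split() == list(reversed(src_line.split()))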
50 changes: 19 additions & 31 deletions scripts/integration_test.py
@@ -10,15 +10,15 @@
from seq2seq.trainer import SupervisedTrainer
from seq2seq.models import EncoderRNN, DecoderRNN, TopKDecoder, Seq2seq
from seq2seq.loss import Perplexity
from seq2seq.dataset import SourceField, TargetField
from seq2seq.dataset import Seq2SeqDataset
from seq2seq.evaluator import Predictor, Evaluator
from seq2seq.util.checkpoint import Checkpoint

parser = argparse.ArgumentParser()
parser.add_argument('--train_path', action='store', dest='train_path',
help='Path to train data')
parser.add_argument('--dev_path', action='store', dest='dev_path',
help='Path to dev data')
parser.add_argument('--train_src', action='store', help='Path to train source data')
parser.add_argument('--train_tgt', action='store', help='Path to train target data')
parser.add_argument('--dev_src', action='store', help='Path to dev source data')
parser.add_argument('--dev_tgt', action='store', help='Path to dev target data')
parser.add_argument('--expt_dir', action='store', dest='expt_dir', default='./experiment',
help='Path to experiment directory. If load_checkpoint is True, then path to checkpoint directory has to be provided')
parser.add_argument('--load_checkpoint', action='store', dest='load_checkpoint',
@@ -37,29 +37,15 @@
logging.info(opt)

# Prepare dataset
src = SourceField()
tgt = TargetField()
max_len = 50
def len_filter(example):
return len(example.src) <= max_len and len(example.tgt) <= max_len
train = torchtext.data.TabularDataset(
path=opt.train_path, format='tsv',
fields=[('src', src), ('tgt', tgt)],
filter_pred=len_filter
)
dev = torchtext.data.TabularDataset(
path=opt.dev_path, format='tsv',
fields=[('src', src), ('tgt', tgt)],
filter_pred=len_filter
)
src.build_vocab(train, max_size=50000)
tgt.build_vocab(train, max_size=50000)
input_vocab = src.vocab
output_vocab = tgt.vocab
train = Seq2SeqDataset.from_file(opt.train_src, opt.train_tgt)
train.build_vocab(50000, 50000)
dev = Seq2SeqDataset.from_file(opt.dev_src, opt.dev_tgt, share_fields_from=train)
input_vocab = train.src_field.vocab
output_vocab = train.tgt_field.vocab

# Prepare loss
weight = torch.ones(len(tgt.vocab))
pad = tgt.vocab.stoi[tgt.pad_token]
weight = torch.ones(len(output_vocab))
pad = output_vocab.stoi[train.tgt_field.pad_token]
loss = Perplexity(weight, pad)
if torch.cuda.is_available():
loss.cuda()
@@ -76,17 +76,19 @@ def len_filter(example):
optimizer = None
if not opt.resume:
# Initialize model
hidden_size=128
hidden_size = 128
bidirectional = True
encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
max_len = 50
encoder = EncoderRNN(len(input_vocab), max_len, hidden_size,
bidirectional=bidirectional,
rnn_cell='lstm',
variable_lengths=True)
decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2,
dropout_p=0, use_attention=True,

decoder = DecoderRNN(len(output_vocab), max_len, hidden_size * 2,
dropout_p=0.2, use_attention=True,
bidirectional=bidirectional,
rnn_cell='lstm',
eos_id=tgt.eos_id, sos_id=tgt.sos_id)
eos_id=train.tgt_field.eos_id, sos_id=train.tgt_field.sos_id)
seq2seq = Seq2seq(encoder, decoder)
if torch.cuda.is_available():
seq2seq = seq2seq.cuda()
14 changes: 8 additions & 6 deletions scripts/integration_test.sh
@@ -1,12 +1,14 @@
#! /bin/sh

TRAIN_PATH=data/toy_reverse/train/data.txt
DEV_PATH=data/toy_reverse/dev/data.txt
TRAIN_SRC=data/toy_reverse/train/src.txt
TRAIN_TGT=data/toy_reverse/train/tgt.txt
DEV_SRC=data/toy_reverse/dev/src.txt
DEV_TGT=data/toy_reverse/dev/tgt.txt

# Start training
python scripts/integration_test.py --train_path $TRAIN_PATH --dev_path $DEV_PATH
python scripts/integration_test.py --train_src $TRAIN_SRC --train_tgt $TRAIN_TGT --dev_src $DEV_SRC --dev_tgt $DEV_TGT
# Resume training
python scripts/integration_test.py --train_path $TRAIN_PATH --dev_path $DEV_PATH --resume
python scripts/integration_test.py --train_src $TRAIN_SRC --train_tgt $TRAIN_TGT --dev_src $DEV_SRC --dev_tgt $DEV_TGT --resume
# Load checkpoint
python scripts/integration_test.py --train_path $TRAIN_PATH --dev_path $DEV_PATH \
--load_checkpoint $(ls -t experiment/checkpoints/ | head -1)
python scripts/integration_test.py --train_src $TRAIN_SRC --train_tgt $TRAIN_TGT --dev_src $DEV_SRC --dev_tgt $DEV_TGT \
--load_checkpoint $(ls -t experiment/checkpoints/ | head -1)
1 change: 1 addition & 0 deletions seq2seq/dataset/__init__.py
@@ -1 +1,2 @@
from .fields import SourceField, TargetField
from .dataset import Seq2SeqDataset
88 changes: 88 additions & 0 deletions seq2seq/dataset/dataset.py
@@ -0,0 +1,88 @@
import codecs
from collections import Counter

import torch
import torchtext

from . import SourceField, TargetField
from .. import src_field_name, tgt_field_name

def make_example(line, fields):
pass

def _read_corpus(path):
with codecs.open(path, 'r', 'utf-8') as fin:
for line in fin:
yield line

class Seq2SeqDataset(torchtext.data.Dataset):
""" The idea of dynamic vocabulary is bought from [Opennmt-py](https://github.com/OpenNMT/OpenNMT-py)"""

def __init__(self, examples, src_field, tgt_field=None, dynamic=False, **kwargs):

# construct fields
self.src_field = src_field
self.tgt_field = tgt_field
self.fields = [(src_field_name, src_field)]
if tgt_field is not None:
self.fields.append((tgt_field_name, tgt_field))

self.dynamic = dynamic
self.dynamic_vocab = []
if self.dynamic:
src_index_field = torchtext.data.Field(use_vocab=False,
tensor_type=torch.LongTensor,
pad_token=0, sequential=True,
batch_first=True)
self.fields.append(('src_index', src_index_field))
examples = self._add_dynamic_vocab(examples)

idx_field = torchtext.data.Field(use_vocab=False,
tensor_type=torch.LongTensor,
sequential=False)
self.fields.append(('index', idx_field))
# construct examples
examples = [torchtext.data.Example.fromlist(list(data) + [i], self.fields)
for i, data in enumerate(examples)]


super(Seq2SeqDataset, self).__init__(examples, self.fields, **kwargs)

def _add_dynamic_vocab(self, examples):
tokenize = self.fields[0][1].tokenize # Tokenize function of the source field
for example in examples:
src_seq = tokenize(example[0])
dy_vocab = torchtext.vocab.Vocab(Counter(src_seq), specials=[])
self.dynamic_vocab.append(dy_vocab)
# src_indices = torch.LongTensor([dy_vocab.stoi[w] for w in tokenize(src_seq)])
src_indices = [dy_vocab.stoi[w] for w in src_seq]
yield tuple(list(example) + [src_indices])

@staticmethod
def from_file(src_path, tgt_path=None, share_fields_from=None, **kwargs):
src_list = _read_corpus(src_path)
if tgt_path is not None:
tgt_list = _read_corpus(tgt_path)
else:
tgt_list = None
return Seq2SeqDataset.from_list(src_list, tgt_list, share_fields_from, **kwargs)

@staticmethod
def from_list(src_list, tgt_list=None, share_fields_from=None, **kwargs):
corpus = src_list
if share_fields_from is not None:
src_field = share_fields_from.fields[src_field_name]
else:
src_field = SourceField()
tgt_field = None
if tgt_list is not None:
corpus = zip(corpus, tgt_list)
if share_fields_from is not None:
tgt_field = share_fields_from.fields[tgt_field_name]
else:
tgt_field = TargetField()
return Seq2SeqDataset(corpus, src_field, tgt_field, **kwargs)

def build_vocab(self, src_vocab_size, tgt_vocab_size):
self.src_field.build_vocab(self, max_size=src_vocab_size)
self.tgt_field.build_vocab(self, max_size=tgt_vocab_size)
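A minimal usage sketch of the new Seq2SeqDataset, mirroring the call sites in examples/sample.py and scripts/integration_test.py above; the file paths, batch size, and sort key are illustrative assumptions rather than part of this diff.

import torchtext
from seq2seq.dataset import Seq2SeqDataset

train = Seq2SeqDataset.from_file('data/toy_reverse/train/src.txt',
                                 'data/toy_reverse/train/tgt.txt')
train.build_vocab(50000, 50000)
dev = Seq2SeqDataset.from_file('data/toy_reverse/dev/src.txt',
                               'data/toy_reverse/dev/tgt.txt',
                               share_fields_from=train)

# Vocabularies hang off the shared fields, as in the updated sample.py.
input_vocab = train.src_field.vocab
output_vocab = train.tgt_field.vocab

# Batching goes through the standard torchtext iterator; settings assumed.
batch_iterator = torchtext.data.BucketIterator(
    dataset=train, batch_size=32,
    sort_key=lambda ex: len(ex.src), repeat=False)
for batch in batch_iterator:
    src_variables, src_lengths = batch.src  # SourceField keeps lengths
    tgt_variables = batch.tgt

# With dynamic=True each example additionally carries a src_index field
# built from a per-sentence vocabulary, which the copy decoder consumes.
train_dyn = Seq2SeqDataset.from_file('data/toy_reverse/train/src.txt',
                                     'data/toy_reverse/train/tgt.txt',
                                     dynamic=True)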
33 changes: 16 additions & 17 deletions seq2seq/evaluator/evaluator.py
@@ -43,23 +43,22 @@ def evaluate(self, model, data):
tgt_vocab = data.fields[seq2seq.tgt_field_name].vocab
pad = tgt_vocab.stoi[data.fields[seq2seq.tgt_field_name].pad_token]

with torch.no_grad():
for batch in batch_iterator:
input_variables, input_lengths = getattr(batch, seq2seq.src_field_name)
target_variables = getattr(batch, seq2seq.tgt_field_name)

decoder_outputs, decoder_hidden, other = model(input_variables, input_lengths.tolist(), target_variables)

# Evaluation
seqlist = other['sequence']
for step, step_output in enumerate(decoder_outputs):
target = target_variables[:, step + 1]
loss.eval_batch(step_output.view(target_variables.size(0), -1), target)

non_padding = target.ne(pad)
correct = seqlist[step].view(-1).eq(target).masked_select(non_padding).sum().item()
match += correct
total += non_padding.sum().item()
# with torch.no_grad():
for batch in batch_iterator:
decoder_outputs, decoder_hidden, other = model(batch)

# Evaluation
loss.eval_batch(decoder_outputs, batch)

seqlist = other['sequence']
target_variables = getattr(batch, seq2seq.tgt_field_name)
for step, step_output in enumerate(decoder_outputs):
target = target_variables[:, step + 1]

non_padding = target.ne(pad)
correct = seqlist[step].view(-1).eq(target).masked_select(non_padding).sum().data[0]
match += correct
total += non_padding.sum().data[0]

if total == 0:
accuracy = float('nan')
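Note that this hunk comments out the torch.no_grad() guard and replaces .item() with .data[0], which pins the evaluator to the pre-0.4 PyTorch API. For reference, a sketch of the same loop against PyTorch >= 0.4, combining the new model(batch) interface with the idioms removed on the left-hand side of the diff; it is not part of this PR.

with torch.no_grad():  # disable autograd bookkeeping during evaluation
    for batch in batch_iterator:
        decoder_outputs, decoder_hidden, other = model(batch)

        # Evaluation
        loss.eval_batch(decoder_outputs, batch)

        seqlist = other['sequence']
        target_variables = getattr(batch, seq2seq.tgt_field_name)
        for step, step_output in enumerate(decoder_outputs):
            target = target_variables[:, step + 1]

            non_padding = target.ne(pad)
            correct = seqlist[step].view(-1).eq(target) \
                .masked_select(non_padding).sum().item()
            match += correct
            total += non_padding.sum().item()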