Skip to content
This repository has been archived by the owner on Aug 31, 2021. It is now read-only.

Commit

Permalink
Merge pull request #74 from tensorflow/neural-translation
Browse files Browse the repository at this point in the history
Continue working on neural translation model #9. Added language model example #64
  • Loading branch information
ilblackdragon committed Feb 10, 2016
2 parents 8646210 + 05454c6 commit 5db4eb1
Show file tree
Hide file tree
Showing 10 changed files with 452 additions and 68 deletions.
20 changes: 14 additions & 6 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
Examples of Using skflow
========================
# Examples of Using skflow

* [Deep Neural Network Regression with Boston Data](boston.py)
* [Convolutional Neural Networks with Digits Data](digits.py)
Expand All @@ -14,19 +13,28 @@ Examples of Using skflow
* [Out-of-core Data Classification Using Dask](out_of_core_data_classification.py)


Image classification
--------------------
## Image classification

* [Convolutional Neural Networks on MNIST Data](mnist.py)
* [Deep Residual Networks on MNIST Data](resnet.py)


Text classification
-------------------
## Text classification

* [Text Classification Using Recurrent Neural Networks on Words](text_classification.py)
(See also [Simplified Version Using Built-in RNN Model](text_classification_builtin_rnn_model.py) with easy to use built-in parameters)
* [Text Classification Using Convolutional Neural Networks on Words](text_classification_cnn.py)
* [Text Classification Using Recurrent Neural Networks on Characters](text_classification_character_rnn.py)
* [Text Classification Using Convolutional Neural Networks on Characters](text_classification_character_cnn.py)


## Language modeling

* [Character level language modeling](language_model.py)


## Text sequence to sequence

* [Character level neural language translation](neural_translation.py)
* [Word level neural language translation](neural_translation_word.py)

102 changes: 102 additions & 0 deletions examples/language_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# encoding: utf-8

# Copyright 2015-present Scikit Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division, print_function, absolute_import

import itertools
import math
import os
import numpy as np

import tensorflow as tf

import skflow

### Training data

CORPUS_FILENAME = "europarl-v6.fr-en.en"
MAX_DOC_LENGTH = 10

def training_data(filename):
f = open(filename)
for line in f:
yield line


def iter_docs(docs):
for doc in docs:
n_parts = int(math.ceil(float(len(doc)) / MAX_DOC_LENGTH))
for part in range(n_parts):
offset_begin = part * MAX_DOC_LENGTH
offset_end = offset_begin + MAX_DOC_LENGTH
inp = np.zeros(MAX_DOC_LENGTH, dtype=np.int32)
out = np.zeros(MAX_DOC_LENGTH, dtype=np.int32)
inp[:min(offset_end - offset_begin, len(doc) - offset_begin)] = doc[offset_begin:offset_end]
out[:min(offset_end - offset_begin, len(doc) - offset_begin - 1)] = doc[offset_begin + 1:offset_end + 1]
yield inp, out


def unpack_xy(iter_obj):
X, y = itertools.tee(iter_obj)
return (item[0] for item in X), (item[1] for item in y)


byte_processor = skflow.preprocessing.ByteProcessor(
max_document_length=MAX_DOC_LENGTH)

data = training_data(CORPUS_FILENAME)
data = byte_processor.transform(data)
X, y = unpack_xy(iter_docs(data))


### Model

HIDDEN_SIZE = 10


def seq_autoencoder(X, y):
"""Sequence auto-encoder with RNN."""
inputs = skflow.ops.one_hot_matrix(X, 256)
in_X, in_y, out_y = skflow.ops.seq2seq_inputs(inputs, y, MAX_DOC_LENGTH, MAX_DOC_LENGTH)
encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y, encoder_cell, decoder_cell)
return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding)


def get_language_model(hidden_size):
"""Returns a language model with given hidden size."""

def language_model(X, y):
inputs = skflow.ops.one_hot_matrix(X, 256)
inputs = skflow.ops.split_squeeze(1, MAX_DOC_LENGTH, inputs)
target = skflow.ops.split_squeeze(1, MAX_DOC_LENGTH, y)
encoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(hidden_size),256)
output, _ = tf.nn.rnn(encoder_cell, inputs, dtype=tf.float32)
return skflow.ops.sequence_classifier(output, target)

return language_model


### Training model.

estimator = skflow.TensorFlowEstimator(model_fn=get_language_model(HIDDEN_SIZE),
n_classes=256,
optimizer='Adam', learning_rate=0.01,
steps=1000, batch_size=64, continue_training=True)

estimator.fit(X, y)

113 changes: 65 additions & 48 deletions examples/neural_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,88 +16,103 @@

from __future__ import division, print_function, absolute_import

import itertools
import os
import numpy as np

import tensorflow as tf
from tensorflow.python.ops import rnn_cell, rnn, seq2seq

import skflow

# Get training data

# This dataset can be downloaded from http://www.statmt.org/europarl/v6/fr-en.tgz

def X_iter():
ENGLISH_CORPUS = "europarl-v6.fr-en.en"
FRENCH_CORPUS = "europarl-v6.fr-en.fr"

def read_iterator(filename):
f = open(filename)
for line in f:
yield line.strip()


def repeated_read_iterator(filename):
while True:
yield "some sentence"
yield "some other sentence"
f = open(filename)
for line in f:
yield line.strip()


def split_train_test(data, partition=0.2, random_seed=42):
rnd = np.random.RandomState(random_seed)
for item in data:
if rnd.uniform() > partition:
yield (0, item)
else:
yield (1, item)


def save_partitions(data, filenames):
files = [open(filename, 'w') for filename in filenames]
for partition, item in data:
files[partition].write(item + '\n')

X_pred = ["some sentence", "some other sentence"]

def y_iter():
def loop_iterator(data):
while True:
yield "какое-то приложение"
yield "какое-то другое приложение"
for item in data:
yield item

# Translation model

MAX_DOCUMENT_LENGTH = 10
HIDDEN_SIZE = 10

def rnn_decoder(decoder_inputs, initial_state, cell, scope=None):
with tf.variable_scope(scope or "dnn_decoder"):
states, sampling_states = [initial_state], [initial_state]
outputs, sampling_outputs = [], []
with tf.op_scope([decoder_inputs, initial_state], "training"):
for i in xrange(len(decoder_inputs)):
inp = decoder_inputs[i]
if i > 0:
tf.get_variable_scope().reuse_variables()
output, new_state = cell(inp, states[-1])
outputs.append(output)
states.append(new_state)
with tf.op_scope([initial_state], "sampling"):
for i in xrange(len(decoder_inputs)):
if i == 0:
sampling_outputs.append(outputs[i])
sampling_states.append(states[i])
else:
sampling_output, sampling_state = cell(sampling_outputs[-1], sampling_states[-1])
sampling_outputs.append(sampling_output)
sampling_states.append(sampling_state)
return outputs, states, sampling_outputs, sampling_states


def rnn_seq2seq(encoder_inputs, decoder_inputs, cell, dtype=tf.float32, scope=None):
with tf.variable_scope(scope or "rnn_seq2seq"):
_, enc_states = rnn.rnn(cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_states[-1], cell)
if not (os.path.exists('train.data') and os.path.exists('test.data')):
english_data = read_iterator(ENGLISH_CORPUS)
french_data = read_iterator(FRENCH_CORPUS)
parallel_data = ('%s;;;%s' % (eng, fr) for eng, fr in itertools.izip(english_data, french_data))
save_partitions(split_train_test(parallel_data), ['train.data', 'test.data'])

def Xy(data):
def split_lines(data):
for item in data:
yield item.split(';;;')
X, y = itertools.tee(split_lines(data))
return (item[0] for item in X), (item[1] for item in y)

X_train, y_train = Xy(repeated_read_iterator('train.data'))
X_test, y_test = Xy(read_iterator('test.data'))


# Translation model

MAX_DOCUMENT_LENGTH = 30
HIDDEN_SIZE = 100

def translate_model(X, y):
byte_list = skflow.ops.one_hot_matrix(X, 256)
in_X, in_y, out_y = skflow.ops.seq2seq_inputs(
byte_list, y, MAX_DOCUMENT_LENGTH, MAX_DOCUMENT_LENGTH)
cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(HIDDEN_SIZE), 256)
decoding, _, sampling_decoding, _ = rnn_seq2seq(in_X, in_y, cell)
cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y, cell)
return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding)


vocab_processor = skflow.preprocessing.ByteProcessor(
max_document_length=MAX_DOCUMENT_LENGTH)

x_iter = vocab_processor.transform(X_iter())
y_iter = vocab_processor.transform(y_iter())
xpred = np.array(list(vocab_processor.transform(X_pred)))
x_iter = vocab_processor.transform(X_train)
y_iter = vocab_processor.transform(y_train)
xpred = np.array(list(vocab_processor.transform(X_test))[:20])
ygold = list(y_test)[:20]

PATH = '/tmp/tf_examples/ntm/'

if os.path.exists(PATH):
translator = skflow.TensorFlowEstimator.restore(PATH)
else:
translator = skflow.TensorFlowEstimator(model_fn=translate_model,
n_classes=256, continue_training=True)
n_classes=256,
optimizer='Adam', learning_rate=0.01, batch_size=128,
continue_training=True)

while True:
translator.fit(x_iter, y_iter, logdir=PATH)
Expand All @@ -106,7 +121,9 @@ def translate_model(X, y):
predictions = translator.predict(xpred, axis=2)
xpred_inp = vocab_processor.reverse(xpred)
text_outputs = vocab_processor.reverse(predictions)
for inp_data, input_text, pred, output_text in zip(xpred, xpred_inp, predictions, text_outputs):
print(input_text, output_text)
for inp_data, input_text, pred, output_text, gold in zip(xpred, xpred_inp,
predictions, text_outputs, ygold):
print('English: %s. French (pred): %s, French (gold): %s' %
(input_text, output_text, gold.decode('utf-8')))
print(inp_data, pred)

0 comments on commit 5db4eb1

Please sign in to comment.