Skip to content
This repository has been archived by the owner on Sep 16, 2019. It is now read-only.

Commit

Permalink
1.兼容tensorflow 1.2.0版本;2.采取临时方法fix了一个tensorflow 1.2.0版本的bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Miopas committed Jul 4, 2017
1 parent 0ec94f1 commit 6369d98
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions s2s_model.py
@@ -1,5 +1,7 @@

import pdb
import random
import copy

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -28,9 +30,9 @@ def __init__(self,
self.learning_rate = learning_rate

# LSTM cells
cell = tf.nn.rnn_cell.BasicLSTMCell(size)
cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=dropout)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers)
cell = tf.contrib.rnn.BasicLSTMCell(size)
cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=dropout)
cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers)

output_projection = None
softmax_loss_function = None
Expand All @@ -51,27 +53,32 @@ def __init__(self,
)
output_projection = (w, b)

def sampled_loss(inputs, labels):
def sampled_loss(labels, logits):
labels = tf.reshape(labels, [-1, 1])
# 因为选项有选fp16的训练,这里统一转换为fp32
local_w_t = tf.cast(w_t, tf.float32)
local_b = tf.cast(b, tf.float32)
local_inputs = tf.cast(inputs, tf.float32)
local_inputs = tf.cast(logits, tf.float32)
return tf.cast(
tf.nn.sampled_softmax_loss(
local_w_t, local_b, local_inputs, labels,
num_samples, self.target_vocab_size
weights=local_w_t,
biases=local_b,
labels=labels,
inputs=local_inputs,
num_sampled=num_samples,
num_classes=self.target_vocab_size
),
dtype
)
softmax_loss_function = sampled_loss

# seq2seq_f
def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
return tf.nn.seq2seq.embedding_attention_seq2seq(
tmp_cell = copy.deepcopy(cell)
return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
encoder_inputs,
decoder_inputs,
cell,
tmp_cell,
num_encoder_symbols=source_vocab_size,
num_decoder_symbols=target_vocab_size,
embedding_size=size,
Expand Down Expand Up @@ -110,7 +117,7 @@ def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
]

if forward_only:
self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets(
self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
self.encoder_inputs,
self.decoder_inputs,
targets,
Expand All @@ -129,7 +136,7 @@ def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
for output in self.outputs[b]
]
else:
self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets(
self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
self.encoder_inputs,
self.decoder_inputs,
targets,
Expand Down

0 comments on commit 6369d98

Please sign in to comment.