In [1]:
import tensorflow as tf
import numpy as np


In [2]:
char_rdic = ['h', 'e', 'l', 'o'] # id -> char
char_dic = {w : i for i, w in enumerate(char_rdic)} # char -> id
print (char_dic)

{'l': 2, 'e': 1, 'h': 0, 'o': 3}


In [3]:
ground_truth = [char_dic[c] for c in 'hello']
print (ground_truth)

[0, 1, 2, 2, 3]


In [4]:
x_data = np.array([[1,0,0,0], # h
                   [0,1,0,0], # e
                   [0,0,1,0], # l
                   [0,0,0,1]], # l
                 dtype = 'f')

In [5]:
x_data = tf.one_hot(ground_truth[:-1], len(char_dic), 1.0, 0.0, -1)
print(x_data)

Tensor("one_hot:0", shape=(4, 4), dtype=float32)


In [6]:
# Configuration
rnn_size = len(char_dic) # 4
batch_size = 1
output_size = 4

In [7]:
# RNN Model
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units = rnn_size,
                                       input_size = None, # deprecated at tensorflow 0.9
                                       #activation = tanh,
                                       
                                       )
print(rnn_cell)

<tensorflow.python.ops.rnn_cell.BasicRNNCell object at 0x10f1a8588>


In [8]:
initial_state = rnn_cell.zero_state(batch_size, tf.float32)
print(initial_state)

Tensor("zeros:0", shape=(1, 4), dtype=float32)


In [9]:
initial_state_1 = tf.zeros([batch_size, rnn_cell.state_size]) #  위 코드와 같은 결과
print(initial_state_1)

Tensor("zeros_1:0", shape=(1, 4), dtype=float32)


In [10]:
x_split = tf.split(0, len(char_dic), x_data) # 가로축으로 4개로 split
print(x_split)
"""
[[1,0,0,0]] # h
[[0,1,0,0]] # e
[[0,0,1,0]] # l
[[0,0,0,1]] # l
"""

[<tf.Tensor 'split:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:1' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:2' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:3' shape=(1, 4) dtype=float32>]


'\n[[1,0,0,0]] # h\n[[0,1,0,0]] # e\n[[0,0,1,0]] # l\n[[0,0,0,1]] # l\n'

In [11]:
outputs, state = tf.nn.rnn(cell = rnn_cell, inputs = x_split, initial_state = initial_state)

In [12]:
print (outputs, '\n')
print (state)

[<tf.Tensor 'RNN/BasicRNNCell/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_1/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_2/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_3/Tanh:0' shape=(1, 4) dtype=float32>] 

Tensor("RNN/BasicRNNCell_3/Tanh:0", shape=(1, 4), dtype=float32)


In [13]:
logits = tf.reshape(tf.concat(1, outputs), # shape = 1 x 16
                    [-1, rnn_size])        # shape = 4 x 4
logits.get_shape()
"""
[[logit from 1st output],
[logit from 2nd output],
[logit from 3rd output],
[logit from 4th output]]
"""

'\n[[logit from 1st output],\n[logit from 2nd output],\n[logit from 3rd output],\n[logit from 4th output]]\n'

In [14]:
targets = tf.reshape(ground_truth[1:], [-1]) # a shape of [-1] flattens into 1-D
targets.get_shape()

TensorShape([Dimension(4)])

In [15]:
weights = tf.ones([len(char_dic) * batch_size])

In [16]:
loss = tf.nn.seq2seq.sequence_loss_by_example([logits], [targets], [weights])
cost = tf.reduce_sum(loss) / batch_size
train_op = tf.train.RMSPropOptimizer(0.01, 0.9).minimize(cost)

In [17]:
# Launch the graph in a session
with tf.Session() as sess:
    tf.initialize_all_variables().run()
    for i in range(100):
        sess.run(train_op)
        result = sess.run(tf.argmax(logits, 1))
        print(result, [char_rdic[t] for t in result]) 

[0 0 3 0] ['h', 'h', 'o', 'h']
[0 0 3 0] ['h', 'h', 'o', 'h']
[0 0 3 0] ['h', 'h', 'o', 'h']
[0 1 3 0] ['h', 'e', 'o', 'h']
[2 1 3 0] ['l', 'e', 'o', 'h']
[2 1 3 0] ['l', 'e', 'o', 'h']
[2 1 3 0] ['l', 'e', 'o', 'h']
[2 1 2 0] ['l', 'e', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 0] ['l', 'l', 'l', 'h']
[2 2 2 2] ['l', 'l', 'l', 'l']
[2 2 2 2] ['l', 'l', 'l', 'l']
[2 2 2 2] ['l', 'l', 'l', 'l']
[2 2 2 2] ['l', 'l', 'l', 'l']
[2 2 2 3] ['l', 'l', 'l', 'o']
[2 2 2 3] ['l', 'l', 'l', 'o']
[2 2 2 3] ['l', 'l', 'l', 'o']
[2 2 2 3] ['l', 'l', 'l', 'o']
[2 2 2 3] ['l', 'l', 'l', 'o']
[2 2 2 3