In [17]:
import tensorflow as tf
import pandas as pd
import numpy as np

#### Load and preprocess data

In [18]:
def read_dataset():
    with tf.gfile.GFile("dataset.tsv", "r") as f:
        df = pd.read_csv(f, sep="\t", header=None, names=["xml", "closure"])
        return df

df = read_dataset()
df.head(4)

Unnamed: 0,xml,closure
0,<h1>totam</h1>,"[:h1 ""totam""]"
1,"<h2><div position=""415u8Lu6""><ul style=""4lSV13...","[:h2 [:div {:position ""415u8Lu6""} [:ul {:style..."
2,<li>iustonemoatque</li>,"[:li ""iusto"" ""nemo"" ""atque""]"
3,"<p font=""lu3gBlCOl3ER97Zi2Cz66H6SsY3"" position...","[:p {:font ""lu3gBlCOl3ER97Zi2Cz66H6SsY3"", :pos..."


In [19]:
# Calculate which characters are used in the dataset
strs = df["xml"].str.cat(df["closure"])
chars = sorted(set("".join(strs.values.tolist())))
chars.insert(0, "<END>")
chars.insert(0, "")

int_to_char = { idx: ch for idx, ch in enumerate(chars) }
char_to_int = { ch: idx for idx, ch in int_to_char.items() }

# Calculate maximum string length
max_str_len = max(df["closure"].str.len().max(), df["xml"].str.len().max())

In [20]:
def encode_serie(serie, max_str_len):
    # Encoded sequences. Each character encoded to int.
    encoded = np.zeros(shape=(len(df), max_str_len), dtype=np.int32)
    # Sequnce lengths
    str_lens = np.zeros(shape=(len(df),), dtype=np.int32)
    # Encode
    for idx, single_str in enumerate(serie):
        str_lens[idx] = len(single_str)
        for j, char in enumerate(single_str):
            encoded[idx, j]   = char_to_int[char]
    return encoded, str_lens

closure_encoded, closure_lens = encode_serie(df["closure"], max_str_len)
xml_encoded, xml_lens = encode_serie(df["xml"], max_str_len)

In [21]:
# Check if the encoding is correct
def decode(arr):
    return "".join([ int_to_char[idx] for idx in arr])

print(decode(closure_encoded[0]))
print(decode(xml_encoded[0]))

[:h1 "totam"]
<h1>totam</h1>


In [22]:
def get_batches(input_seq, input_lens, target_seq, target_lens, batch_size):
    # Truncate dataset, so we'll have complete batches. Otherwise last batch won't be complete.
    elems_count = batch_size * (len(input_seq) // batch_size)
    for idx in range(0, elems_count, batch_size):
        yield input_seq[idx : idx + batch_size], input_lens[idx : idx + batch_size], target_seq[idx : idx + batch_size], target_lens[idx : idx + batch_size]

#### Create graph

In [26]:
num_units = 64
batch_size = 100
learning_rate = 0.001

tf.reset_default_graph()


# == Inputs ==

input_seq = tf.placeholder(tf.int32, shape=(None, max_str_len), name="input_seq")
input_seq_lens = tf.placeholder(tf.int32, shape=(None,), name="input_seq_lens")

target_seq = tf.placeholder(tf.int32, shape=(None, max_str_len), name="target_seq")
target_seq_lens = tf.placeholder(tf.int32, shape=(None,), name="target_seq_lens")


# == Encoder ==

char_ids = list(int_to_char.keys())
encoder_input = tf.one_hot(indices=input_seq,
                           depth=len(char_ids),
                           name="encoder_input_one_hot")

with tf.variable_scope("encoder"):
    encoder_cell = tf.contrib.rnn.BasicLSTMCell(num_units)
    encoder_initial_state = encoder_cell.zero_state(batch_size, tf.float32)
    _, encoder_state = tf.nn.dynamic_rnn(encoder_cell,
                                         encoder_input,
                                         sequence_length=input_seq_lens,
                                         initial_state=encoder_initial_state)

    
# == Decoder ==

trigger_idx = char_to_int["<END>"]
trigger_t = tf.constant(trigger_idx, shape=(batch_size, 1))

decoder_input_indices = tf.concat([trigger_t, target_seq], axis=1, name="decoder_input_indices")
decoder_input = tf.one_hot(indices=decoder_input_indices,
                          depth=len(char_ids),
                          name="decoder_input_one_hot")
decoder_lens = tf.add(target_seq_lens, tf.constant(1))

with tf.variable_scope("decoder"):
    decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    decoder_outputs, decoder_state = tf.nn.dynamic_rnn(decoder_cell,
                                                       decoder_input,
                                                       sequence_length=decoder_lens,
                                                       dtype=tf.float32,
                                                       initial_state=encoder_state)

logits = tf.layers.dense(decoder_outputs, len(char_ids), name="dense")
output = tf.argmax(logits, axis=2, output_type=tf.int32, name="argmax")

# Loss

decoder_target_indices = tf.concat([target_seq, trigger_t], axis=1, name="decoder_target_indices")
decoder_target_one_hot = tf.one_hot(indices=decoder_target_indices,
                                    depth=len(char_ids),
                                    name="decoder_target_one_hot")

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=decoder_target_one_hot,
                                                        logits=logits,
                                                        name="cross_entropy")
loss_op = tf.reduce_mean(cross_entropy, name="loss_op")

# copy-paste from intro-to-RNN
grad_clip = 5
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(loss_op, tvars), grad_clip)
train_op = tf.train.AdamOptimizer(learning_rate)
optimizer_op = train_op.apply_gradients(zip(grads, tvars))


# Summaries

tf.summary.histogram("cross_entropy/histogram", cross_entropy)
tf.summary.scalar("cross_entropy/mean", loss_op)
merged_summary = tf.summary.merge_all()

In [None]:
num_epochs = 50

counter = 0
save_every_n = 100

saver = tf.train.Saver()

with tf.Session() as sess:
    summary_writer = tf.summary.FileWriter("logs/", sess.graph)
    sess.run(tf.global_variables_initializer())
    for epoch in range(num_epochs):
        print("Epoch: {}".format(epoch))
        for xx, xx_lens, yy, yy_lens in get_batches(closure_encoded, closure_lens, xml_encoded, xml_lens, batch_size):
            
            loss, _, summary = sess.run([loss_op,
                                         optimizer_op,
                                         merged_summary,
                                        ],
                                        feed_dict={input_seq: xx,
                                                   input_seq_lens: xx_lens,
                                                   target_seq: yy,
                                                   target_seq_lens: yy_lens,
                                                  })
            summary_writer.add_summary(summary, counter)
            counter += 1
            if (counter % save_every_n == 0):
                saver.save(sess, "checkpoints/i_{}.ckpt".format(counter))
                
    saver.save(sess, "checkpoints/i_{}.ckpt".format(counter))

Epoch: 0


In [80]:
checkpoint = tf.train.latest_checkpoint('checkpoints')

trigger_idx = char_to_int["<END>"]
empty_str = char_to_int[""]

with tf.Session() as sess:
    saver.restore(sess, checkpoint)
    
    result = np.zeros_like(xml_encoded[0:batch_size])
    result.fill(empty_str)
    result_len = np.zeros_like(xml_lens[0:batch_size])
    result_len[:] = 0
    for idx in range(20):
        predicted = sess.run([output], feed_dict={
            input_seq: closure_encoded[0:batch_size],
            input_seq_lens: closure_lens[0:batch_size],
            target_seq: result,
            target_seq_lens: result_len,
        })
        result_len[:] += 1
        trigger_coords = np.argwhere(predicted == trigger_idx)
        for c in trigger_coords:
            print(c[0])
            result_len[c[0]] = c[1]
        
    print(predicted)
    #xml_idx = np.argmax(xml_idx_logits, axis=1)
    #print(xml_idx.shape)
    #xml_idx = np.expand_dims(xml_idx, axis=0)
    #print(xml_idx.shape)
    #s = "".join([int_to_char[idx] for idx in xml_idx])
    #print(s)
    #decode(xml_idx)
#    txt = [in xml_idx]

INFO:tensorflow:Restoring parameters from checkpoints/i_1000.ckpt
[array([[17, 62, 63, ...,  0,  0,  0],
       [17, 59, 63, ...,  0,  0,  0],
       [17, 68, 68, ...,  0,  0,  0],
       ...,
       [17, 62, 63, ...,  0,  0,  0],
       [17, 68, 68, ...,  0,  0,  0],
       [17, 68, 68, ...,  0,  0,  0]], dtype=int32)]


In [81]:
i = 0
print("input:")
print(decode(closure_encoded[i]))
print("predicted:")
print(decode(predicted[0][i]))
print("expected:")
print(decode(xml_encoded[i]))

input:
[:h1 "totam"]
predicted:
<oppaaaapppppppppppp
expected:
<h1>totam</h1>
