In [33]:
import tensorflow as tf
from model import LanguageModel
import json
import numpy as np
from utils import batchify, get_batch
import time

In [2]:
# One session shared by every later cell: variable initialisation, checkpoint
# restore, text generation and evaluation all go through `sess`.
sess = tf.Session()

In [3]:
# Keyword arguments for LanguageModel (unpacked when the model is created).
# Three stacked RNN layers; each layer's `input_size` equals the previous
# layer's `units`. Only variables whose shapes match the checkpoint get
# restored later, so these sizes have to agree with the trained model.
params = {
    'rnn_layers': [
        {
            'units': 1150,
            'input_size': 400,
            'drop_i': 0.65,
            'drop_w': 0.5,
            'drop_o': 0.3,
        },
        {
            'units': 1150,
            'input_size': 1150,
            'drop_w': 0.5,
            'drop_o': 0.3,
        },
        {
            'units': 400,
            'input_size': 1150,
            'drop_o': 0.4,
            'drop_w': 0.5,
        },
    ],
    'vocab_size': 380,
    'drop_e': 0.1,
}

In [4]:
# Instantiate the network from `params`. is_training=False because this
# notebook only generates text and evaluates a saved checkpoint
# (presumably this turns the dropout settings off -- confirm in model.py).
model = LanguageModel(**params, is_training=False)

In [5]:
# Construct the model's graph (presumably this is where the variables listed
# after the restore cell get created -- confirm in model.py).
model.build_model()
# NOTE: this Saver is rebound a few cells later by the one that
# optimistic_restore returns; it is not used in between.
saver = tf.train.Saver()
# Initialise every global variable first, so anything the checkpoint does not
# supply still has a defined value after the partial restore.
sess.run(tf.global_variables_initializer())

In [6]:
def optimistic_restore(session, save_file):
    """Restore every graph variable that the checkpoint is able to supply.

    A global variable is restored only when the checkpoint contains a tensor
    with the same name (the part of ``var.name`` before ``:``) and the same
    shape. Variables that are missing from the checkpoint, or whose shape
    differs, are left untouched.

    Args:
        session: Session whose variables are overwritten by the restore.
        save_file: Checkpoint path passed to ``tf.train.NewCheckpointReader``
            and to ``Saver.restore``.

    Returns:
        ``(restore_vars, saver)`` -- the list of restored variables, ordered
        by variable name, and the ``tf.train.Saver`` built over that list.
    """
    ckpt_shapes = tf.train.NewCheckpointReader(save_file).get_variable_to_shape_map()

    graph_vars = tf.global_variables()

    # Map "scope/name" (without the ":0" suffix) to the graph variable.
    by_base_name = {}
    for variable in graph_vars:
        by_base_name[variable.name.split(':')[0]] = variable

    # Full names of graph variables that also exist in the checkpoint, sorted
    # so the returned list has a deterministic order.
    candidates = sorted(
        variable.name
        for variable in graph_vars
        if variable.name.split(':')[0] in ckpt_shapes
    )

    restore_vars = []
    with tf.variable_scope('', reuse=True):
        for full_name in candidates:
            base_name = full_name.split(':')[0]
            variable = by_base_name[base_name]
            # Skip tensors whose saved shape does not match the current graph.
            if variable.get_shape().as_list() == ckpt_shapes[base_name]:
                restore_vars.append(variable)

    saver = tf.train.Saver(restore_vars)
    saver.restore(session, save_file)
    return restore_vars, saver

In [7]:
# Load the trained weights into `sess`. `rv` is the list of variables that were
# actually restored; `saver` is rebound to the Saver covering only those.
rv, saver = optimistic_restore(sess, '2/checkpoints/test-53210')

INFO:tensorflow:Restoring parameters from 2/checkpoints/test-53210


In [8]:
# Display the restored variables to check that every expected weight was found in the checkpoint.
rv

[<tf.Variable 'LanguageModel/decoder_b:0' shape=(380,) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/embedding_weight:0' shape=(380, 400) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/lstm_fused_cell/U:0' shape=(1150, 4600) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/lstm_fused_cell/W:0' shape=(400, 4600) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/lstm_fused_cell/bias:0' shape=(4600,) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/lstm_fused_cell_1/U:0' shape=(1150, 4600) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/lstm_fused_cell_1/W:0' shape=(1150, 4600) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/lstm_fused_cell_1/bias:0' shape=(4600,) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/lstm_fused_cell_2/U:0' shape=(400, 1600) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/lstm_fused_cell_2/W:0' shape=(1150, 1600) dtype=float32_ref>,
 <tf.Variable 'LanguageModel/lstm_fused_cell_2/bias:0' shape=(1600,) dtype=float32_ref>]

In [9]:
# Vocabulary: token -> integer id, as stored in word2idx.json.
with open('word2idx.json', 'r') as inp:
    word2idx = json.load(inp)
# Reverse lookup (integer id -> token) used to decode sampled ids back to text.
idx2word = {index: token for token, index in word2idx.items()}

In [10]:
def sample(a, temperature=0.5):
    """Draw one index from the softmax of ``a / temperature``.

    The softmax is computed in float64 after subtracting the maximum score.
    This keeps ``exp`` from overflowing on large scores or small temperatures
    (which previously produced NaN probabilities), and keeps the probabilities
    summing to 1 within the tolerance ``np.random.choice`` checks, even when
    the scores arrive as float32.

    Args:
        a: 1-D array-like of unnormalised scores (logits).
        temperature: Divisor applied to the scores before the softmax; values
            below 1 sharpen the distribution, values above 1 flatten it.
            Must be non-zero.

    Returns:
        An index in ``range(len(a))`` drawn with ``np.random.choice``.

    Raises:
        ValueError: If ``temperature`` is zero (or ``a`` is empty).
    """
    if temperature == 0:
        raise ValueError("temperature must be non-zero")
    scaled = np.asarray(a, dtype=np.float64) / temperature
    # Shifting by the max leaves the softmax unchanged but bounds exp() by 1.
    scaled = scaled - scaled.max()
    weights = np.exp(scaled)
    dist = weights / weights.sum()
    return np.random.choice(len(dist), p=dist)

In [11]:
def generate_text(initial_word, gen_len):
    """Continue ``initial_word`` with tokens sampled from the language model.

    The seed string is split into characters (spaces become ``'_'``), mapped
    through ``word2idx`` and fed to the model in one call with the recurrent
    state reset. After that, each sampled token is fed back one at a time with
    the state kept.

    Args:
        initial_word: Seed text; every character (after the space to ``'_'``
            substitution) must be a key of ``word2idx``.
        gen_len: Number of tokens to generate. One token is always produced
            from the seed pass, so values below 1 still yield one token.

    Returns:
        ``initial_word`` followed by the generated tokens, with every token
        that is exactly ``'_'`` rendered as a space.
    """
    def _next_token(token_ids, lengths, reset):
        # Run the decoder and sample from the logits of the last time step
        # of the first (only) batch column.
        logits = sess.run(
            model.decoder,
            feed_dict={
                model.inputs: token_ids,
                model.reset_state: reset,
                model.seq_lens: lengths,
            },
        )
        return sample(logits[-1, 0, :])

    seed_tokens = list(initial_word.replace(' ', '_'))
    seed_ids = np.expand_dims(np.array([word2idx[tok] for tok in seed_tokens]), -1)

    # Seed pass: consume the whole prompt with a fresh state.
    next_idx = _next_token(seed_ids, [len(seed_tokens)], True)
    generated = [idx2word[next_idx]]

    # Remaining tokens: feed back one id at a time, carrying the state over.
    for _ in range(1, gen_len):
        next_idx = _next_token([[next_idx]], [1], False)
        generated.append(idx2word[next_idx])

    rendered = [' ' if tok == '_' else tok for tok in generated]
    return initial_word + ''.join(rendered)
        

In [28]:
# Load the held-out test split from disk as a NumPy array
# (presumably already encoded as token ids -- confirm against the preprocessing step).
with open('baomoi/test.npy','rb') as inp:
    test = np.load(inp)

In [30]:
# Arrange the test array for evaluation with 530 as the batch argument, then transpose.
# NOTE(review): looks like batchify returns (batch, time) and .T makes it
# time-major for evaluate_step, which slices along axis 0 -- confirm in utils.py.
test_data = batchify(test, 530).T

In [32]:
# Placeholder for the integer target ids the decoder logits are scored against.
y = tf.placeholder(dtype=tf.int32, shape=[None, None], name='y')
# Scalar evaluation loss: weighted sequence cross-entropy over model.decoder,
# averaged across both timesteps and batch, with model.seq_masks as weights.
# NOTE(review): sequence_loss documents its logits as [batch, time, vocab],
# while generate_text indexes the decoder output time-first (output[-1, 0, :]).
# With both averages enabled this should only swap axis labels, but confirm
# that `y` and model.seq_masks use the same layout as model.decoder.
test_loss = tf.contrib.seq2seq.sequence_loss(
    logits=model.decoder,
    targets=y,
    weights=model.seq_masks,
    average_across_timesteps=True,
    average_across_batch=True,
    name='test_loss'
)

In [36]:
def evaluate_step(test_data, bptt):
    """Evaluate the model over ``test_data`` in windows of ``bptt`` rows.

    For every window the ``test_loss`` tensor is run and the running loss and
    elapsed time are printed. At the end a summary line is printed with the
    length-weighted average loss, the total time and the average time per
    window.

    Args:
        test_data: Array-like sliced along its first axis by ``get_batch``.
        bptt: Step between window start offsets (also passed to ``get_batch``).

    Returns:
        None. Results are reported with ``print`` only.
    """
    start_time = time.time()
    total_loss = 0
    num_batches = 0
    for i in range(0, len(test_data), bptt):
        next_x, next_y = get_batch(test_data, bptt, i, evaluate=True)
        loss = sess.run(
            test_loss,
            feed_dict={
                model.inputs: next_x,
                y: next_y,
                # Every column is fed at the full length of this window.
                model.seq_lens: [next_x.shape[0]] * next_x.shape[1],
                # Reset the recurrent state only for the first window, so
                # later windows continue from the previous one.
                model.reset_state: i == 0
            }
        )
        # Weight by window length so a shorter final window counts less.
        total_loss += loss * len(next_x)
        print("Evaluate loss {}, time {}".format(
            loss, time.time() - start_time))
        num_batches += 1

    if num_batches == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print("Evaluate: no data to evaluate")
        return

    total_time = time.time() - start_time
    # NOTE(review): normalises by len(test_data); if get_batch yields fewer
    # rows in total (e.g. one is held back for targets), the average is
    # slightly under-estimated -- confirm against utils.get_batch.
    total_loss /= len(test_data)
    # The original format string had no placeholder for the third argument,
    # so the average time per window was never printed.
    print("Evaluate total loss {}, time {}, avg.time {}".format(
        total_loss, total_time, total_time / num_batches))

In [37]:
# Evaluate the restored model on the test split in windows of 100 rows;
# per-window losses and timings are printed below.
evaluate_step(test_data, 100)

Evaluate loss 1.2794289588928223, time 0.6178994178771973
Evaluate loss 1.2413350343704224, time 1.2045297622680664
Evaluate loss 1.2227312326431274, time 1.7912700176239014
Evaluate loss 1.2187833786010742, time 2.3776304721832275
Evaluate loss 1.2169562578201294, time 2.963874578475952
Evaluate loss 1.2116812467575073, time 3.5499892234802246
Evaluate loss 1.225031852722168, time 4.136122226715088
Evaluate loss 1.2177261114120483, time 4.72249960899353
Evaluate loss 1.2154732942581177, time 5.3103649616241455
Evaluate loss 1.2171730995178223, time 5.896874666213989
Evaluate loss 1.2257611751556396, time 6.483163595199585
Evaluate loss 1.2218865156173706, time 7.069720983505249
Evaluate loss 1.231946349143982, time 7.656374454498291
Evaluate loss 1.2176791429519653, time 8.243592739105225
Evaluate loss 1.2294617891311646, time 8.83069109916687
Evaluate loss 1.2303547859191895, time 9.417949676513672
Evaluate loss 1.2253262996673584, time 10.005308866500854
Evaluate loss 1.207040071487

Evaluate loss 1.2245556116104126, time 85.94512629508972
Evaluate loss 1.2208575010299683, time 86.54025292396545
Evaluate loss 1.2171481847763062, time 87.13532209396362
Evaluate loss 1.2137984037399292, time 87.73044419288635
Evaluate loss 1.225003957748413, time 88.32616591453552
Evaluate loss 1.2169160842895508, time 88.9215931892395
Evaluate loss 1.2212707996368408, time 89.51661109924316
Evaluate loss 1.2096461057662964, time 90.11299681663513
Evaluate loss 1.221401572227478, time 90.7087779045105
Evaluate loss 1.217787742614746, time 91.30518817901611
Evaluate loss 1.2176214456558228, time 91.90243196487427
Evaluate loss 1.2240887880325317, time 92.49833583831787
Evaluate loss 1.2347034215927124, time 93.09418559074402
Evaluate loss 1.219739556312561, time 93.68912100791931
Evaluate loss 1.2149226665496826, time 94.28417658805847
Evaluate loss 1.219927430152893, time 94.87947225570679
Evaluate loss 1.2366633415222168, time 95.47469139099121
Evaluate loss 1.2189414501190186, time

Evaluate loss 1.2219680547714233, time 171.73472356796265
Evaluate loss 1.221449851989746, time 172.3310101032257
Evaluate loss 1.222186803817749, time 172.92722511291504
Evaluate loss 1.222793459892273, time 173.5236337184906
Evaluate loss 1.2155845165252686, time 174.12065982818604
Evaluate loss 1.2189505100250244, time 174.71727967262268
Evaluate loss 1.212111234664917, time 175.3140890598297
Evaluate loss 1.2237862348556519, time 175.91034722328186
Evaluate loss 1.2107315063476562, time 176.50656700134277
Evaluate loss 1.1951760053634644, time 177.10266184806824
Evaluate loss 1.2285171747207642, time 177.69953393936157
Evaluate loss 1.2207238674163818, time 178.2956600189209
Evaluate loss 1.2230889797210693, time 178.8917531967163
Evaluate loss 1.218405842781067, time 179.48787569999695
Evaluate loss 1.234965205192566, time 180.08413290977478
Evaluate loss 1.2288219928741455, time 180.68092370033264
Evaluate loss 1.2301366329193115, time 181.27692794799805
Evaluate loss 1.231378078

Evaluate loss 1.234251856803894, time 257.044960975647
Evaluate loss 1.2242413759231567, time 257.6411590576172
Evaluate loss 1.23050856590271, time 258.23758268356323
Evaluate loss 1.2247917652130127, time 258.83385586738586
Evaluate loss 1.2323541641235352, time 259.431676864624
Evaluate loss 1.22425377368927, time 260.0280261039734
Evaluate loss 1.2221533060073853, time 260.6242094039917
Evaluate loss 1.2173981666564941, time 261.2205128669739
Evaluate loss 1.2247272729873657, time 261.81762194633484
Evaluate loss 1.224429965019226, time 262.4139702320099
Evaluate loss 1.230566143989563, time 263.011438369751
Evaluate loss 1.2167742252349854, time 263.6078293323517
Evaluate loss 1.2372338771820068, time 264.20400643348694
Evaluate loss 1.2265651226043701, time 264.80083560943604
Evaluate loss 1.2297868728637695, time 265.39694356918335
Evaluate loss 1.238741159439087, time 265.99284195899963
Evaluate loss 1.2210954427719116, time 266.5895662307739
Evaluate loss 1.2220721244812012, t