In [1]:
# Import Packages
# This notebook uses the TensorFlow 1.x graph API (tf.contrib, tf.placeholder, tf.Session).
import numpy as np
import tensorflow as tf
import collections  # NOTE(review): collections, argparse and time appear unused in this notebook
import argparse
import time
import os
#from six.moves import cPickle  (Python 2-era import; the stdlib pickle below is used instead)
import pickle
print ("Packages Imported")

Packages Imported


In [2]:
# Load the character list and the char -> id vocabulary pickled at training time.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you trust.
# errors='ignore' only affects decoding of 8-bit strings pickled by Python 2.
load_dir = "data/linux_kernel"
with open(os.path.join(load_dir, 'chars_vocab.pkl'), 'rb') as vocab_file:
    chars, vocab = pickle.load(vocab_file, errors='ignore')

vocab_size = len(vocab)
print ("'vocab_size' is %d" % vocab_size)

'vocab_size' is 98


## Now we are ready to build our RNN model with seq2seq
### This network is only used for sampling, so we need neither batches of sequences nor an optimizer

In [3]:
# Important RNN parameters (rnn_size / num_layers must match training,
# otherwise the checkpoint variables cannot be restored)
rnn_size   = 128
num_layers = 2

batch_size = 1 # <= In the training phase, these were both 50
seq_length = 1

tf.reset_default_graph()

# Construct RNN model.
# Build a *distinct* LSTM cell object per layer: `[cell] * num_layers` puts the
# same object in every layer, which later TF 1.x releases reject (or treat as
# weight sharing). Each layer still gets its own multi_rnn_cell/cell_<i> scope.
cell       = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.BasicLSTMCell(rnn_size) for _ in range(num_layers)])
input_data = tf.placeholder(tf.int32, [batch_size, seq_length])
istate     = cell.zero_state(batch_size, tf.float32)

# Weights
with tf.variable_scope('rnnlm'):
    softmax_w = tf.get_variable("softmax_w", [rnn_size, vocab_size])
    softmax_b = tf.get_variable("softmax_b", [vocab_size])

    with tf.device("/cpu:0"):
        embedding = tf.get_variable("embedding", [vocab_size, rnn_size])
        # [batch_size, seq_length, rnn_size] -> seq_length tensors of [batch_size, 1, rnn_size]
        inputs = tf.split(tf.nn.embedding_lookup(embedding, input_data), seq_length, 1)
        # ... -> seq_length tensors of [batch_size, rnn_size], as rnn_decoder expects
        inputs = [tf.squeeze(_input, [1]) for _input in inputs]

outputs, final_state = tf.contrib.legacy_seq2seq.rnn_decoder(
    inputs, istate, cell, loop_function=None, scope='rnnlm')
# tf.concat(values, axis) replaces the deprecated tf.concat_v2 (same argument order).
output = tf.reshape(tf.concat(outputs, 1), [-1, rnn_size])

logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
probs  = tf.nn.softmax(logits)

print ("Network Ready")

Network Ready


In [5]:
# Restore RNN weights from the latest checkpoint recorded in load_dir
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
ckpt  = tf.train.get_checkpoint_state(load_dir)
# get_checkpoint_state returns None when load_dir has no 'checkpoint' file;
# fail with a clear message instead of an AttributeError on None.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise FileNotFoundError("No TensorFlow checkpoint found in %r" % load_dir)

print (ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)


data/linux_kernel/model.ckpt-0


# Finally, show what the RNN has generated!

In [6]:
# Sampling function
def weighted_pick(weights):
    """Draw one index at random, with probability proportional to its weight.

    Args:
        weights: sequence of non-negative weights (e.g. a softmax output row);
            they do not need to sum to 1.

    Returns:
        int: an index into ``weights``; for non-negative weights it is always
        an index whose weight is strictly positive.

    Raises:
        ValueError: if ``weights`` is empty or its total is not positive.
    """
    # Accumulate in float64 so float32 softmax output does not lose precision.
    cumulative = np.cumsum(weights, dtype=np.float64)
    if cumulative.size == 0 or not cumulative[-1] > 0:
        raise ValueError("weights must be non-empty with a positive sum")
    # Use the last cumulative value as the total (a separate np.sum rounds
    # differently and can exceed it, producing an out-of-range index), and
    # side='right' so zero-weight entries can never be selected.
    threshold = np.random.rand() * cumulative[-1]
    return int(np.searchsorted(cumulative, threshold, side='right'))

# Sample using RNN and prime characters
prime = "/* "
num   = 1000  # number of characters to generate after the prime

if not prime:
    raise ValueError("prime must contain at least one character")
unknown = [c for c in prime if c not in vocab]
if unknown:
    raise ValueError("prime contains characters not in the vocabulary: %r" % unknown)

# Start from the all-zero state. `istate` is cell.zero_state(batch_size, ...)
# with batch_size = 1, so reusing it avoids adding new ops on every re-run.
state = sess.run(istate)
# Warm up the state on every prime character except the last one; that one
# is fed as the first input of the sampling loop below.
for char in prime[:-1]:
    x = np.array([[vocab[char]]], dtype=np.int32)  # matches the int32 [1, 1] placeholder
    state = sess.run(final_state, feed_dict={input_data: x, istate: state})

# Sample 'num' characters
pieces = [prime]
char   = prime[-1] # <= This goes IN!
for _ in range(num):
    x = np.array([[vocab[char]]], dtype=np.int32)
    probsval, state = sess.run([probs, final_state],
                               feed_dict={input_data: x, istate: state})
    # Stochastic sampling; np.argmax(probsval[0]) would give greedy decoding.
    sample = weighted_pick(probsval[0])
    char   = chars[sample]
    pieces.append(char)
ret = "".join(pieces)

print ("Sampling Done. \n___________________________________________\n")

print (ret)

Sampling Done. 
___________________________________________

/* .
;otåjs[y%dgVi;XZ[
}B+<%(t<$luM?Z8Fg%43_.F&`M>Ha;9Z`4)V;-ZXC*3:2zTe;ts{dbWR!9M3;&GRqGSz.4MR<t+kZpZ7J[yyd+a;{Z:d$KvGGd*tåF[osrt yw;#&
7GV<lålRTes_LO<CW=ge{ipj<M
<RyH-y3a!9! *9*[åG1Np3O5As=l<u_1å-$
g1?+yje<|J1_##VJdMdZ,;;$ZTeGå$y<*_M8HD$}RJ#;M(tRV.p%-*e4:7O.Gzyg%wGF#s1T%Lå1[:J%p.VV:*tl1d_wyhTpW%Myag#*3#V_2-y!$SsR}1V;j*%,x`<-kp$Z*$q'9_*Gta-p{0R%*wH3By[aZ 0_H,.mgVvtKZ$\l< 
H1J-pG,Z\2r<%7å.$G0såi.,GgOgHå,z:,eKvyTt0je'M%N3{KSA0y#9åe_ksM,{_og*RgOZK%_Be-j7$\~<[t tE#	eå-0ts9gd_:%#<dz9?G,ly:#{uVepj	_K7be1 t I2dG0.<;g_-u_,:M3%_t1&Mk{1
d,Yåo09*$dJBBY%9g`q|,##a[M5X$. X3sy+SåT#;:<p*,9dZG;åzåGG1å_TT9%|0ps#y4{dS$-+;:e#g?:#g1V`t0Z(peaKps.TOl<RG|gF4j.=M?$]*.l1-Ku.TuF%2G+#Nyv"KdN.<lKelVyoH__zMp_1Z,_<eåvpp+HdOgdd-:Td eO0BåGaJ$Z%,yygttye.G,åetGJug;
gOGZje<:,gk'*#SJsy%V+[å1å7åZ$u9J#ld%d1ho_-"Ze~,yl{__}1U$-V0[p[Z__K,Ks%lZ-Z_{#}LgCs.d
G84H';sH*!1$;#,,#q9*+_9åd_37]3xs|.t.Ju;ep.J<lK_eg%p$bTe-RgptxZ'|eLw_,vT C--y;:R>:32;V Vq.mzGå1;|9b,B<L-=RGGb9$0L

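# Hopefully it was good. Note: the restored checkpoint is `model.ckpt-0`, which appears to have been saved at step 0 (before any training); that would explain why the sample above looks like random characters.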