In [1]:
import numpy as np
import tensorflow as tf#r.12
import pandas as pd

### Loading toy data with 20 vocab size

In [2]:
source = open("data/s2s_sources.txt",'r')
source = [[int(word) for word in line.split()] for line in source]

target = open("data/s2s_targets.txt",'r')
target = [[int(word) for word in line.split()] for line in target]

### Hyper Parameters

In [3]:
ENCODER_VOCABSIZE = 20 + 3
DECODER_VOCABSIZE = 20 + 3

ENCODER_CELLSIZE = 512
DECODER_CELLSIZE = 512*2

BATCHSIZE = 100

NLAYERS = 3

EOS = 20
GO = 21
PAD = 22

In [4]:
#inverse encoder input: TBD
#r1.2 takes care of variable size input sequences
#make model such that it is independent of batch_size

### Encoder

In [5]:
encoder_inputs = tf.placeholder(tf.int32, [None, None], name="encoder_inputs") # [ BATCHSIZE, MAX_SEQLEN ]
encoder_embeddings = tf.get_variable(name="encoder_embedding_matrix",
                                     shape=[ENCODER_VOCABSIZE, ENCODER_CELLSIZE],
                                     initializer=tf.contrib.layers.xavier_initializer(uniform=True,seed=None,dtype=tf.float32),
                                     dtype=tf.float32)
encoder_inputs_embedded = tf.nn.embedding_lookup(encoder_embeddings, encoder_inputs) 
encoder_input_length = tf.placeholder(tf.int32, [None], name="encoder_input_length") # [ BATCHSIZE ]

In [6]:
e_cells = [tf.contrib.rnn.GRUCell(ENCODER_CELLSIZE) for _ in range(NLAYERS)]
encoder_cell = tf.contrib.rnn.MultiRNNCell(e_cells)

In [7]:
(outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(encoder_cell, 
                                encoder_cell, 
                                encoder_inputs_embedded, 
                                sequence_length=encoder_input_length,
                                dtype=tf.float32)

In [8]:
encoder_outputs = tf.concat(outputs, 2) # [ BATCHSIZE, MAX_SEQLEN, 2*ENCODER_CELLSIZE]
encoder_final_states = [tf.concat(x, 1) for x in zip(output_states[0],output_states[1])] # [ BATCHSIZE, 2*ENCODER_CELLSIZE]

## Decoder

In [9]:
decoder_input = tf.placeholder(tf.int32, [None, None], name='decoder_input') # [ BATCHSIZE, SEQLEN ]
decoder_input_length = tf.placeholder(shape=[None],dtype=tf.int32,name='decoder_input_length')

decoder_embeddings = tf.get_variable(name="decoder_embedding_matrix",
                                     shape=[DECODER_VOCABSIZE, DECODER_CELLSIZE],
                                     initializer=tf.contrib.layers.xavier_initializer(),
                                     dtype=tf.float32)

decoder_input_embed = tf.nn.embedding_lookup(decoder_embeddings, decoder_input) 


decoder_targets = tf.placeholder(shape=[None, None],dtype=tf.int32,name='decoder_targets')
decoder_targets_length = tf.placeholder(shape=[None],dtype=tf.int32,name='decoder_targets_length')

batch_size_t = tf.placeholder(tf.int32, [1], name="batch_size_t")

In [10]:
attention_mech = tf.contrib.seq2seq.BahdanauAttention(DECODER_CELLSIZE, encoder_outputs,memory_sequence_length=encoder_input_length)
#num_units: convert memory(hs) W * hs and query(ht) into W * ht num_units size first
#memory: The memory to query; usually the output of an RNN encoder.
#normalize
#probability_fn: Converts the score to probabilities.  The default is @{tf.nn.softmax}.
#score_mask_value: (optional): The mask value for score before passing into `probability_fn`. The default is -inf. Only used if`memory_sequence_length` is not None.

#def __call__(query, previous_alignments):
#score = math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query),[2]) #v * tanh(W * hs + W * ht)
#alignments = self._probability_fn(score, previous_alignments) #previous_alignments are ignored in BahdanauAttention

##Applying attention wrapper on top most cell
d_cells = [tf.contrib.rnn.GRUCell(DECODER_CELLSIZE) for _ in range(NLAYERS)]
top_d_cell = tf.contrib.seq2seq.AttentionWrapper(d_cells[-1], attention_mech,output_attention=False)#read AttentionWrapper once more
d_cells[-1] = top_d_cell

#Step 1: Mix the `inputs` and previous step's `attention` output via `cell_input_fn`. array_ops.concat([inputs, state.attention], -1)
#Step 2: Call the wrapped `cell` with this input and its previous state.
#Step 3: Score & alignment the cell's output with `attention_mechanism`. alignments(a(s) = self._attention_mechanism(query=cell_output, previous_alignments=state.alignments)
#Step 5: Calculate the context vector as the inner product between the alignments and the attention_mechanism's values (memory). sigma(a(s)*hs)
#Step 6: attention_layer_size!=None, attention = DenseLayer(attention_layer_size)(array_ops.concat([cell_output, context], 1))
#        else:                       attention = context
#output_attention == true return attention, next_state or return output, next_state

##Updating state of top cell to be equivalent to attention wrapper cell
decoder_initial_states = encoder_final_states
top_d_state = top_d_cell.zero_state(batch_size_t, tf.float32)
top_d_state = top_d_state.clone(cell_state=decoder_initial_states[-1])
decoder_initial_states[-1] = top_d_state
decoder_initial_states = tuple(decoder_initial_states)

decoder_cell = tf.contrib.rnn.MultiRNNCell(d_cells)

In [11]:
decoder_helper = tf.contrib.seq2seq.TrainingHelper(decoder_input_embed, decoder_input_length)
#Training Helper
#sample(time, outputs) -> argmax(output, -1)
#next_inputs(time, outputs, state) -> (allFinished?, decoder_input_embed[time+1], state)

decoder_output_layer = tf.contrib.keras.layers.Dense(DECODER_VOCABSIZE)

decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, 
                                          decoder_helper, 
                                          decoder_initial_states, 
                                          decoder_output_layer)
#step(time, inputs, state)
#Step1: cell_outputs, cell_state = self._cell(inputs, state)
#step2: if self._output_layer is not None: cell_outputs = self._output_layer(cell_outputs)
#step3: sample_ids = self._helper.sample(time=time, outputs=cell_outputs, state=cell_state) which is just argmax
#step4: (finished, next_inputs, next_state) = self._helper.next_inputs(time=time,outputs=cell_outputs,state=cell_state,sample_ids=sample_ids)

In [12]:
final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder)

In [13]:
Ylogits = final_outputs.rnn_output

In [14]:
Y_pred = tf.argmax(tf.nn.softmax(Ylogits),2)

In [15]:
loss_weights = tf.ones([BATCHSIZE,tf.reduce_max(final_sequence_lengths)], dtype=tf.float32, name="loss_weights")

In [16]:
loss = tf.contrib.seq2seq.sequence_loss(Ylogits, decoder_targets, loss_weights)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

## Training

In [17]:
def batch(inputs):
    sequence_lengths = [len(seq) for seq in inputs]
    batch_size = len(inputs)
    
    max_sequence_length = max(sequence_lengths)
    type(max_sequence_length)
    
    inputs_batch_major = np.ones(shape=[batch_size, max_sequence_length], dtype=np.int32)*PAD
    for i, seq in enumerate(inputs):
        for j, element in enumerate(seq):
            inputs_batch_major[i, j] = element
            
    return inputs_batch_major

def rnn_minibatch_sequencer(X, Y, batch_size, epochs):
    inputs = len(X)
    for ep in range(epochs):
        for i in range(int(inputs/batch_size)):
            encoder_input = X[i*batch_size: (i+1)*batch_size]
            encoder_input_len = [len(seq) for seq in encoder_input]
            y = Y[i*batch_size: (i+1)*batch_size]
            decoder_input = [[GO] + seq for seq in y]
            decoder_input_len = [len(seq) for seq in decoder_input]
            decoder_target =[(seq + [EOS]) for seq in y]
            decoder_target_len = [len(seq) for seq in decoder_target]
            yield batch(encoder_input),np.array(encoder_input_len),batch(decoder_input),\
            np.array(decoder_input_len),batch(decoder_target),np.array(decoder_target_len),np.array([batch_size]),ep        
        

In [18]:
#a,b,c,d,e,f,g,h = rnn_minibatch_sequencer(source, target, BATCHSIZE,10).__next__()


In [20]:
inn = tf.global_variables_initializer()
sess = tf.InteractiveSession()

sess.run(inn)

In [21]:
for a,b,c,d,e,f,g,h in rnn_minibatch_sequencer(source, target, BATCHSIZE, 10):
    feed_dict = {encoder_inputs:a,
               encoder_input_length:b,
               decoder_input:c,
               decoder_input_length:d,
               decoder_targets:e,
               decoder_targets_length:f,
               batch_size_t:g}
    
    _, pred, c = sess.run([train_op, Y_pred, loss], feed_dict=feed_dict)
    
    print("epoch {} loss {}".format(h, c))
    print(pred[:5])
    
    #e_s = sess.run([encoder_final_states],feed_dict=feed_dict)
    #print(e_s)
    #logi = sess.run([Ylogits],feed_dict=feed_dict)
    #print(logi)
    

epoch 0 loss 3.1360700130462646
[[ 7  7  7  7  7  7  7 17 17 21 21 21 21 21 21 21  6  6  6  6  6]
 [ 5  5  5  5 13 13 13 13  5  5  5 14 14 14 14  2  2  2  0  0  0]
 [15 18 18 18 18 11 11 19 19 19 19 19 17 17 17 17 17 17 17 17 17]
 [19 13 13 13 13 13 20 20 20 20 17 13 13 13  0  0  0  0  0  0  0]
 [16 16 16 16 16 16 13 13 13 13 13  9  9  9  9  9 17 17 17 17 17]]
epoch 0 loss 3.0731582641601562
[[19  5  5  5 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1 18 18 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1  5  5 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 3  7  7 16 21 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 16 16 16 16 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 0 loss 3.0175206661224365
[[17 16 16 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 2  2  2 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 2  2 11 22 22 22 22 22 22 22 22 22 22 22 

epoch 0 loss 2.148702383041382
[[18  5  5  5  5 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  5 11 11 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18  6  6 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  5  5  5  5 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 0 loss 2.223414659500122
[[ 5 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  5  3  6 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 14  4 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 0 loss 2.1411707401275635
[[18 18 18 14 14 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 18 16 16 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 11 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  5  5 22 22 22 22 22 22 22 22 22 22 22 22

epoch 0 loss 1.8916898965835571
[[ 1  1 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 18 18 18  2  2  2  2  2  2  2  2  2 10 10 20 20 20 20]
 [18 18 14 14 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 18 18 18 18 16 17 17 17 17 17 17 17 17 10 20 20 20 20]
 [18 18 18 18 18 18 18 18  2  2  2  2  2  2  2 10 20 20 20 20 20]]
epoch 0 loss 1.97297203540802
[[18  2  2  2  2  2  2 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 18 18 18 18 16 16  2  2  2  2  2  2  2  2 20 20 20 22]
 [18 18 18 18 10 10 10 10 10 10 10 10 10 10 10 10 10 20 22 22 22]
 [18 18 18 18 18 18 10 10 10 10 17 17 17 17 10 10 10 10 10 20 20]
 [18 18 18 18 16 16 16 16 16 16 16 16 10 10 20 20 22 22 22 22 22]]
epoch 0 loss 1.9484732151031494
[[18 11 11 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 15 15 15 10 10 10 10 20 20 22 22 22 22 22 22 22 22 22]
 [ 5  2  7 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 18 18 16 16 16 16 16 10 10 10 10 10

epoch 0 loss 1.795719861984253
[[ 5 10 10 10  0  0  0  0  0  0  0  0 20 20 22 22 22 22 22 22 22]
 [ 5 10 10 10 10 10 10  0  0  0  0  0  0  0  0 20 20 22 22 22 22]
 [ 5 10 10 10 10 10 10 10 10 10 10 10  1  1 20 20 22 22 22 22 22]
 [11 11 11 11 20 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  5  5  5 10 10  1  1  1  1  1 20 20 22 22 22 22 22 22 22 22]]
epoch 0 loss 1.7927340269088745
[[ 5 17 17 17 17 17 17 17 17  0  0  0  0  0  0  0  0  0  0  0 20]
 [ 5  2 10 10 10 10  2  2  2  2 10 10 10 10  0  0  0  0  1  1 20]
 [ 2 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 20 22 22 22 22]
 [ 5  2 12 12 12 12 12 10 10 17 17 17  1  1  1  1 20 22 22 22 22]
 [ 5  5 17 17 17 17 17 17 17 17 17  1  1  1  1 20 20 22 22 22 22]]
epoch 0 loss 1.7651150226593018
[[ 5  8  8  8  8 20 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 15 15 15 15 20 20 20 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  6  6  6  6  6  6  6  6  6  6  6  1  1  1 20 20 22 22 22]
 [ 5  2  2  2  2  2  2  2  2  2  2  2  2  7  

epoch 0 loss 1.666081190109253
[[ 1 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 20 22 22 22 22]
 [ 9  9  9  9  9  9  9  9  9 20 20 22 22 22 22 22 22 22 22 22 22]
 [ 8 13 13 13 13 13 13 13  2  2  2  2  2  2  2  2  2 20 20 22 22]
 [ 8  8  8  9  9  9  9  9  9  9  9  9  9  9  2  6  6 20 20 22 22]
 [ 9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9 20 20 22 22]]
epoch 0 loss 1.802170991897583
[[18 18 15 14 14 14 14 20 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 20 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17 17 17 17 20 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5 13 13  2  2  2 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 15 20 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 0 loss 1.7357920408248901
[[ 5 10 19 20 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 18 18 18 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [10 10 10 10 10 10  8  8  8  8  8  8 10 20 22 22 22 22 22 22 22]
 [ 5  1  1 14 14 14 14 14 14 14 14 14 14 14 14

epoch 1 loss 1.6237484216690063
[[10 10 15 15 15 15 15 15 15 15 15 15 15 15 15 20 22 22 22 22 22]
 [ 7  7  7  7  7  7  7  7 20 20 22 22 22 22 22 22 22 22 22 22 22]
 [16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 20 22 22 22 22]
 [ 5  5  5  5  5 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11 11 11 11 11 11 11 11 11 11 11 20 22 22 22 22 22 22 22]]
epoch 1 loss 1.6826621294021606
[[18 18 18 18 18 18 18 18 18 18 18 18 18 20 22 22 22 22 22 22 22]
 [ 5  5 13 13 13 13 13 13 20 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  5  5  5  5  5  5  5  5  5 20 22 22 22 22 22 22 22 22 22 22]
 [14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 20 22 22 22 22]
 [15 11  6  6  6  6  6 20 20 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 1 loss 1.7719080448150635
[[11 11 11 11 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 18 18 18 18 18 18 18 18 20 20 22 22 22 22 22 22 22 22]
 [ 0  0  0  0  0  0 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1 15 15 15 15 15 15 20 20 22 22 22 22 22 

epoch 1 loss 1.5170819759368896
[[15 15 15 15 15 15 15 15 20 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 7  7  7 18 18 18 18 18 16 20 20 22 22 22 22 22 22 22 22 22 22]
 [17 13 13 13 13 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11 11 11 11 11 11 11 11 11 11 11 13 20 22 22 22 22 22 22]
 [14 14 15 15 15 15 15 15 15 15 20 22 22 22 22 22 22 22 22 22 22]]
epoch 1 loss 1.611401081085205
[[19 19 19 19 19 19 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 20 22 22]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  3  3  3 20 22 22 22 22]
 [17 17 11 11 11  6  6  6  6 20 20 22 22 22 22 22 22 22 22 22 22]
 [ 5 15 15 15 15 15 10 10 10 10 10 10 10 10 10 10 10  5 20 22 22]]
epoch 1 loss 1.655617356300354
[[11 11 11 11 11 11 11 11 11 20 20 22 22 22 22 22 22 22 22 22 22]
 [ 6 11 11 11 11 11 11 17 17 17 17 17 17 17 17 17 20 22 22 22 22]
 [14 14  0  0  0  0  0  0  0  0  0  0  0  0  0  0 20 22 22 22 22]
 [ 8  8  8  8  8  8  8  8  8  8  8  8  8  8  8

epoch 1 loss 1.6499165296554565
[[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 20 22 22]
 [ 5  4  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9 20 22]
 [ 5  5  5  5  5  5  5  5  5  5  5  5  5  5  5 20 22 22 22 22 22]
 [ 3  9  9  9  9  9  9  9  9  9  6 20 22 22 22 22 22 22 22 22 22]
 [16 14 14 14 14 14 14 14 14 14 14 14 14 14 20 22 22 22 22 22 22]]
epoch 1 loss 1.5133072137832642
[[14 14 14 14 14 14 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 14 14 14 14 14 14 14 14 14 14 20 22 22 22 22 22 22 22 22 22]
 [ 2  2  2  2  2  2 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 20 22 22 22 22]
 [ 5  5  5  5  5  5  5  5  5  5  5  5  5 20 22 22 22 22 22 22 22]]
epoch 1 loss 1.5163527727127075
[[15 10 10 10 14 14 14 14 14 14 14 20 22 22 22 22 22 22 22 22 22]
 [10 19 19 19 19 19 19 19 19 19 19 19 19 19  2  2 20 22 22 22 22]
 [ 5 17 17 17 17 17 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11 11 11 20 20 22 22 22 22 22 22 22 

epoch 1 loss 1.456110954284668
[[ 5 11 11 11 10 10 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 0  7  7  7  7  7 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 13 14 14 14 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [19 19 19 19 19 19 19  7  7  7  7  7  7  7  7  7  7  7  7 20 22]
 [10  5  5  5  5  5  5  5  5 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 1 loss 1.422655463218689
[[15  1  1  1  1  1  1  1  1  1  1  1  1 20 22 22 22 22 22 22 22]
 [12 15 15 15  3  3  3  3  3  3  3  3  3  3  3  3 20 22 22 22 22]
 [18  6  6  6  6  6  6  6  6  6  6  6  6  6 10 10 10 20 22 22 22]
 [ 9 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 20 22 22 22]
 [ 0  0  0  0 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 1 loss 1.378897786140442
[[18 11 11  2  2  2  2  2 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [19 19 19 19 19 19  6  6  6  6  6  6  6  6  6  6  6  6  6 20 22]
 [17 11 11 11 11 11 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17 17 17 17 17 17 17 17 17 17 16 16 20 22 22 

epoch 2 loss 1.418966293334961
[[17 17 17 17 17 17 17 12 12 12 12 12 12 12 12 12 12 12 12 12 20]
 [ 5  5  5  5  5  5 10 10 10 10 10 10 10 20 22 22 22 22 22 22 22]
 [ 6 17 17 17 17 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17 13 13 13 19  8  8  8  8 20 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11  5  5  5  5  7  7  7  7  7  7 20 22 22 22 22 22 22 22]]
epoch 2 loss 1.3832626342773438
[[ 9  9  9  9  9  9  9  9 17 17 17 17 17 17 17 17 17 20 22 22 22]
 [18  7  7  7  3  3  3  3  3  3  3  3  3  3 20 22 22 22 22 22 22]
 [ 7  7  7  7  7  7  7  7 15 15 15 15 15 15 15 15 19 20 22 22 22]
 [18  6  6  6  6  6  6  6  6  6 20 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 20 22 22 22 22]]
epoch 2 loss 1.4321863651275635
[[ 5  5  5  5  3  3  3  3  3 20 22 22 22 22 22 22 22 22 22 22 22]
 [17  7  7  7  7  7 17 17 17 17 17 17 17 17 17 17 17 17 17 20 22]
 [ 2  7  7  7  7  3  3  3  3  3  3  3  3  3  3  3  5  5  5  5 20]
 [ 7  7  7 18 18 18 18 18 18 18 18 18 18  6  

epoch 2 loss 1.3408616781234741
[[ 6  6  6  0  0  0  0  0  0  0  0  0  0  0  0 17 20 22 22 22 22]
 [19 19 19 19 19 19 19 19 19 19 19 19 19 19 19 19 19 19 19 20 22]
 [ 2  2  2  2  2  2  2  2  2  7 18 18 18 18 18 18 20 22 22 22 22]
 [ 6  6  6  6 15 15  4 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 0  7  7  7  7  7 18 18 18 18 18 18 18 18 18 18 18 18 20 22 22]]
epoch 2 loss 1.2758808135986328
[[15  3  3  3  3  3 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [16 16 16 16 16 16  5  5  5  5  5  5  5  5  5  5  5  5 20 22 22]
 [ 7  7  7  7  7  7  7 15 15 15 15 12 12 12 20 22 22 22 22 22 22]
 [15 15 15 15 15 15 15 15 15 15 15 15 15 15 15  3  3  3  5  5 20]
 [11  9  9  9  9  9 15 15 15 15  3  3  3  3  3  3  3 20 22 22 22]]
epoch 2 loss 1.2970103025436401
[[13 15 15 15 15 15 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17  8  8  8  8  8  8  8 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1  1  1  1  1 11 11 16 16 16 16 16 16 16 20 22 22 22 22 22 22]
 [17 17 17 17 17 17 20 22 22 22 22 22 22 22 

epoch 2 loss 1.31052565574646
[[ 1 15 12 12 12 12  3  3  3  3 20 22 22 22 22 22 22 22 22 22 22]
 [ 7  7  7  7  7  7  7  7  7  7  7  7  7 17 17 17 17 17 20 22 22]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  8  8  8  8  8  8 20]
 [ 6  6  6 18 18 17 17 17 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [10 10 10 10 10 10 10 10 10 10 10 10 16 13 13 13  9  9  9 20 22]]
epoch 2 loss 1.363384485244751
[[13 13 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 18 18 20]
 [19 19 17 17 17 17 19 19 19 19 19  8  8  8  8  8  8  8 20 22 22]
 [ 8  8  8  8  6  6  6 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1  1 12 12 12 12 12 12 12 12 12 12 12 12 12 12  2  2 20 22 22]
 [ 3  3  3  3  3  3  3  1  1  1  1  1  1  1  1  4 20 22 22 22 22]]
epoch 2 loss 1.3986459970474243
[[17  9  9  9  9  9  9  9  9 20 22 22 22 22 22 22 22 22 22 22 22]
 [13 13 13 13 13 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  5  5  5  5 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  6  6  6  6  6 15 15 15 15 20 22 22 22 

epoch 2 loss 1.2283961772918701
[[14 14 14 14 14 14 14 13 13 13 19 19 19 19 19  1 20 22 22 22 22]
 [17 17 17 17 17 17  0  0  0 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  6  5  5  5  5  5  5  5  5  1  1 20 22 22 22 22 22 22 22]
 [ 4  4  4  4  4  4  7  7  7  7  7  7 20 22 22 22 22 22 22 22 22]
 [ 7  7  9  9  9  9  9  9 14 14 14 14 14 14 14 14 14 14 14 20 22]]
epoch 2 loss 1.5652823448181152
[[ 9  9  9  9  9  9  9  9  9  9  9  9 13 13 13 13 13 13 14 14 20]
 [16 16 16 16 16 16  1  1  1  5  5  5  5  5  5  5 22 22 22 22 22]
 [ 1  1  1  1  5  5  5  5  5  5 16 16 16 16  0  0  0 22 22 22 22]
 [ 5  5  5  5  5  5  5  2  2  2  2 20 22 22 22 22 22 22 22 22 22]
 [12 12 12 12 12 12 12 12 12 12 12 18 18 18 18 18 18 22 22 22 22]]
epoch 2 loss 2.151710271835327
[[ 8  8  8 16 16 16 16 16 16 16 20 22 22 22 22 22 22 22 22 22 22]
 [ 0  0  0  0  0 16 16 16 20 20 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 15 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [16 16 16 16 16 16 16 16 16 16 16 20 22 22 2

epoch 2 loss 1.2209622859954834
[[ 7  7  7  7  7  7  7  9  9  9 11 11 11 20 22 22 22 22 22 22 22]
 [ 5  5  5  5  5  5  5  5  5  5  5  5  4  4  0  0  0  0  0 20 22]
 [15 15 15  6  6  6 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [16 16 16 16 13 13  1 20 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [10 10 10 10 10 10 10 16 16 16 16 16  4 20 22 22 22 22 22 22 22]]
epoch 3 loss 1.3034942150115967
[[14 14 14 14 14  3  3  3  3  3 12 12 12 12 12 12 20 22 22 22 22]
 [14 14 14 14 14 14 14 14 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 8  8  8  8  8  8  8  8  8  3  6  6  6  6  6  6  6 20 22 22 22]
 [ 2  2 18 18 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 4  4  4  4  4  4  6  6  6  6  6 20 20 22 22 22 22 22 22 22 22]]
epoch 3 loss 1.2635360956192017
[[ 5  5  3 15 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 3  3  3  0  0  0  0  0  0  0 16 16 16 16 16 16 16 20 20 22 22]
 [15 15 15 15 15 15 15 19 19 19  9  9  9  9 20 20 22 22 22 22 22]
 [ 0  0  4  4  4  4  4  4  4  4  4 11 11 11 

epoch 3 loss 1.2115737199783325
[[19 19 19 19 19  1  1  1  1  1  0  0  0  0  0  0  0 20 22 22 22]
 [ 4  4  4  4  4  4  4  4  6  6  6  6  6  6  6  6  6  6  6  6 20]
 [12 12  6  6  6  6  6  6  6  6  6  7  7  7 17 17 20 22 22 22 22]
 [10 11  8  8 20 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11  0 17 17 17 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 3 loss 1.2129727602005005
[[11  6  6  6  6  6  6  6  6 11 11  4  4  4  4  4  4  4  4 20 22]
 [ 8  8  3  1  1  1  1 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11 11 11  7  7 14 14 14 14 14 14 14 20 22 22 22 22 22 22]
 [ 6  6  6  6  6  6  6  6  6  6  6  6  9  9  9  9 20 22 22 22 22]
 [ 2  2  2  2  2  2  2  2  2  2  2  2  2  2  4  4  4  4  4  4 20]]
epoch 3 loss 1.2615294456481934
[[ 9  9  9  9  9  9 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [13 13  8  8  6  6  6  6  6  6  6 17 20 22 22 22 22 22 22 22 22]
 [17 17 15 15 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 3  3  3  3  3  3  3 20 22 22 22 22 22 22 

epoch 3 loss 1.2511475086212158
[[ 7  7  7  3  3  3  3 17 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 0  0  0  0  0  9  9  9  4 20 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 15 15 15 15 15 15 15 15 15 15 17 17 17 17 17 17 20]
 [ 1 17 17 17 17 17 17  3  3  3  3  3  3  3  3  9  9 20 22 22 22]
 [13 13 13 13 13  8  8  8  8 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 3 loss 1.160396695137024
[[ 1  1  1  1  1 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 2  2  2  2  2  2  2  6  6  6  6  6  4  4  4  4  4  4  4 20 22]
 [14 14 14 14 14  1  1  1 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 3  2  2  2  3  3  3  3  3  3 17 17 17 17 17 17 17 17  6  6 20]
 [ 6  6  6  6  6  6  6  6  6  6  6  6 16 16  8  8  8  8 14 20 22]]
epoch 3 loss 1.2490178346633911
[[ 8 16 16 16  2  2  2  2  2  2 20 22 22 22 22 22 22 22 22 22 22]
 [ 3  3  3  3  3  3  3  3  3  3  3  3  3 16  2  2  2 18 20 22 22]
 [ 3  3  3  3  3 10 10 10 10 10 13 16 11 11 11 11 11 20 22 22 22]
 [17 17 17 17 17 17 17 17 17 17 17 17 17 17 1

epoch 3 loss 1.2660130262374878
[[ 7 14 14 14 14 14  6  6  6 20 22 22 22 22 22 22 22 22 22 22 22]
 [12 12 12 12  2  2  2  2  1  1  1  1 20 22 22 22 22 22 22 22 22]
 [ 4  4  4  4  4  4  4  4  4  4 14  1  1  1  1 10 10  1 20 22 22]
 [14 14 14 14 14 14 14 14  2  2  2  2 15 15 15 15 15 15  1  1 20]
 [19 19 19 19 19  4  4  4 13 13 13 13 13  5 20 22 22 22 22 22 22]]
epoch 3 loss 1.1489827632904053
[[10 10 10 10 10 10  2  4  4  4  4  4  4 20 22 22 22 22 22 22 22]
 [13 13 13 13 13  0  0  0  0  0  0  0  0 16 11 11 20 22 22 22 22]
 [ 6  6  6  6 13 13 13 13 13 10 18 18 18 18 16 20 22 22 22 22 22]
 [11 11 11 11 15  8 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1  1  1  1  1  1  1  1  1 10 10 10 20 22 22 22 22 22 22 22 22]]
epoch 3 loss 1.140614628791809
[[ 9  9  9  9  9  9  9  7  7  7  7  7  7  7  7  7  7 15 15 13 20]
 [ 5  5  5  5  5  5  5 13 13 13 18 18 18 18 18 18 18 18 18 18 20]
 [16 16 16 16 16 16 16 16 16 16 16 16 16 16  7  7 20 22 22 22 22]
 [ 9  9  9  9 13 13 13 13 13 13 15 15 15 15 1

epoch 3 loss 1.096660852432251
[[16 16 16 11 11 12 12 12 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [10 10 10 10 10 10 10 10 10 10 10 10 10  1  1  1  4 16 16 20 22]
 [ 0  0  0  0  0  0  0  0  5  5  5  5  5  5 19 19 20 22 22 22 22]
 [12 12 12 12 12 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 8  8  8  8  8  8  8  4  4  4  4 18 18 18 18 18 20 22 22 22 22]]
epoch 3 loss 1.0664541721343994
[[18 18 18 18 18 18 18 11 11 13 13 13 13 13 13  9 20 22 22 22 22]
 [ 1  1  1 10 10 10 10 10  9  9 20 22 22 22 22 22 22 22 22 22 22]
 [13 16 16 16 16 16 16 16 13 13  6  6  6  6  6 10 10 10 20 22 22]
 [13 13 13 13 13 13 13 13 13 13 13 13  6  6  1  1  1  1 20 22 22]
 [17  8  8  8  8  8  8  8  9  2  2  2  2  2  2  2  2  2 20 22 22]]
epoch 3 loss 1.1773574352264404
[[18 18 18 16 16 14 14 14 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 15  1 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 7 17 17 17 17 17 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [13 13 13 13  2  2  2 20 22 22 22 22 22 22 2

epoch 4 loss 1.1037040948867798
[[ 8  8  8  8  8  8  8 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [12 12 19 19 19  8 11 11 11 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 8  8  8  8  8  8 10 13 13 13 13 13  2  2  5  1  7 20 22 22 22]
 [ 3  3  3  3  3  3 12  8  8  8  8 13 13 13 13 13  4 13 13 13 20]
 [19 18 18 18 12 12 12 12 12 12 17 20 22 22 22 22 22 22 22 22 22]]
epoch 4 loss 1.062999963760376
[[10 10 10 10 10 10 10 10 15 15 15 15 15 15 15 20 22 22 22 22 22]
 [ 8  8  8  7  7  4  4  7  4 20 22 22 22 22 22 22 22 22 22 22 22]
 [16 12 12 16 16 16 16 16 16  9  9  9  9  9  9  9 20 22 22 22 22]
 [ 4  5  5  5  5  5 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 15 15 15 11 11 11 11 11 11 20 22 22 22 22 22 22 22]]
epoch 4 loss 1.102520227432251
[[ 1  4  4 18 18 18 18 18 18  8  8 12 12 20 22 22 22 22 22 22 22]
 [12 19 19 19 19 13 13 13 13 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  6  6 12 13  5  5  5  5 20 22 22 22 22 22 22 22 22 22 22]
 [16 16 16 16 16 15 15 15 15 15 15 15  6  5  5

epoch 4 loss 1.0141079425811768
[[ 5 13 13 12 12 12 12 12  6  6 14 14 20 22 22 22 22 22 22 22 22]
 [15 15 17 17 17 10 10 10 10 10 10  1  6  6  6  6 20 22 22 22 22]
 [ 0  6  6  6  6  6  6 16 16 16 16 16 14 19 19 19 20 22 22 22 22]
 [ 8  8  8  2 13 13  7  7 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11 11 11 11 11 11 11 13 13 13 13 13 13 13 13 20 22 22 22]]
epoch 4 loss 0.9821217060089111
[[ 8 15 15 15 15 15 15  9  0 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 7  7  7  7  7  8 18  3  3 16 20 22 22 22 22 22 22 22 22 22 22]
 [13 13 13 13 13 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11 11 11 11 11 11  8  8  8  0  0 13 20 22 22 22 22 22 22]
 [18 18 18 18 18 15 15 15 15 15 20 22 22 22 22 22 22 22 22 22 22]]
epoch 4 loss 1.0682165622711182
[[19 19 19 19 19 19 19 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13  7  7 20 22 22]
 [ 2  6  6  6  4  4  7  7  7  1  1  1  1  1  1  0 20 22 22 22 22]
 [17 17 17 17 17  4  4 18 18  6 20 22 22 22 

epoch 4 loss 1.0324058532714844
[[ 6  6  6  6  6 18 18 18  8  8  8 16 16 16 20 22 22 22 22 22 22]
 [17  1  1  9 18 18 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 15 15 17 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [13 13 13 13 13 13 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 15 15 15  5  5 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 4 loss 1.108207106590271
[[17 17 17 17 17 17 17 17 17 17 18 18 18 18 18 18 18 18 20 22 22]
 [ 4  4  4  4  4  4  4  4  9 16 16 16 16 16 16 16 16 16 16 20 22]
 [14 14 14 14 14  5 19 19 19 19 19 16 16 16  5 20 22 22 22 22 22]
 [ 7 19  1  1  1  1  9  9  9  9  9 20 22 22 22 22 22 22 22 22 22]
 [12 12 12 12 12 12 12 12 12 12 12 18 18 14 20 22 22 22 22 22 22]]
epoch 4 loss 0.9906834959983826
[[ 6 14 14 14 14 14 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [13 13 13 13 13 14 14 14  3  3  3 20 22 22 22 22 22 22 22 22 22]
 [12  2  2  2  2  2  2 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  6 11 11 11 11 11 11 11 14 14 14 14 14 1

epoch 4 loss 0.949661135673523
[[12 12 12 12  1 19 19 19  8  8  8 20 22 22 22 22 22 22 22 22 22]
 [13 13 13 13  9  9  4 10 10  7 20 22 22 22 22 22 22 22 22 22 22]
 [17 17 17 17 17  0  0 10 10 10 10 10  6 20 22 22 22 22 22 22 22]
 [ 9 11  5  2  8  8 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11 11 10  9  9  9  9  9  9 19 14 20 22 22 22 22 22 22 22]]
epoch 4 loss 0.9310810565948486
[[ 5  7  4 13 19 19 19 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 4  7  7  7  7 14 14 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 11 11 13 13  6 14 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 6 19 19 19 19 19 16  7  7  7  7  7  7  7  5 11 11 11  3 20 22]
 [ 7 19 19 19 19  5  5  5  5 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 4 loss 0.9148685336112976
[[15 15 15 15  1  1  1  1  1  1  1  8  8 20 22 22 22 22 22 22 22]
 [ 9  9  9  9  9  9  9 15  3 15 15  3  3 17 17 17 20 22 22 22 22]
 [ 6  6  6  6  6  6  6  6  6 14  4  4  4  4  4 10 10 20 22 22 22]
 [ 9  9  9  9  9  9  9  9  4  4  4 11 11 11 1

epoch 5 loss 0.884148120880127
[[ 2  2  7 12 12 12 12 12 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  6  6  6  3  3  3 18 18 18 18 18 20 22 22 22 22 22 22 22]
 [ 8  8  8  8  8  5  5  5  8  8 16 16 16 16 19  2  2  2  0  0 20]
 [ 3  3  3  3 19  4 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 14  6  6 13 13 13 11  5  5  5 17  0 11  9 20 22 22 22 22 22]]
epoch 5 loss 0.9040859937667847
[[12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 13 13 13 20]
 [ 5  5  5  5  5  5  5 14 14 10 10 14  3 20 22 22 22 22 22 22 22]
 [ 6  3 12 17 13 13 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17  5 18 18 13 13 19 12 12 20 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11 11 11 11 15 15 17 17 18 18 18 20 22 22 22 22 22 22 22]]
epoch 5 loss 0.8792818784713745
[[ 9 18 18  9  9 12 12 12 12 12 12  5  1 17  1 16 17 20 22 22 22]
 [18 18  8  8  8  8  8  8  3  7  7  7  7  6 20 22 22 22 22 22 22]
 [ 7  7  7  7  7  7  7  7  8  8  8  8  8  6 19 19 19 20 22 22 22]
 [18 18 18 18  6  6  6  6  0  0 20 22 22 22 2

epoch 5 loss 0.9013702273368835
[[ 4  4  4  4  4  4 18  7 13  3 13 13 20 22 22 22 22 22 22 22 22]
 [ 1  3 13 13 13 13 10 14 10 20 22 22 22 22 22 22 22 22 22 22 22]
 [13 13 13 13 13  7 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 3  3  3  3  9  9  7  2  4  4  1  1 18 11 17 17 16 16 20 22 22]
 [18  6  6 17 17  0 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 5 loss 0.8307300806045532
[[ 6  6  6  8  8  8  9  9  9  9  9 15 15 17 17 17 20 22 22 22 22]
 [14 12 12 12 12 12 12 12 12 12 12 12  5  5 19 19 19 19 14 20 22]
 [ 9  9  9  9  7  7  7  7  7  7  7  7  7 14 14 14 20 22 22 22 22]
 [ 6  6  6  0 15  1 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 0  0  0  0  0  0 18 18 18 18 18 14 14  7  4  4  4 15 20 22 22]]
epoch 5 loss 0.7733476161956787
[[15  4 12 12  3  3 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 2  2 19 19 19 17 17 17  2  8  2 15 15  5 16 16 16 16 20 22 22]
 [ 7  7  7  7  7  3  3  3  9 12 12 12 12 11 20 22 22 22 22 22 22]
 [15 15 15 15 15 15 18 11 11 18 11  9  9  9 

epoch 5 loss 0.8752070665359497
[[19 19 19 19 19 19 14 14 15  0  0  0 20 22 22 22 22 22 22 22 22]
 [ 6 12 12 12 12 12 12 12 12  0  0  0  0 20 22 22 22 22 22 22 22]
 [12 12 12 12 12 12 19 19 19 19 19 19 19  7  7  7  8  0  0 20 22]
 [ 6 12 12 12 12 12 12  6  6  7  6 19 19 19 19 20 22 22 22 22 22]
 [17 17 19 19 14 14  8  6  5 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 5 loss 0.8341452479362488
[[ 7  7  7 16  5 16 10  9  3  3 20 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  6  6  6  6  6  6  6 12  9 12 17  7 17 17 13 13 20 22 22]
 [ 0  0  0  0  0  0  0  0  0  0  0 14  5  5  8 10  8  8  8  8  8]
 [ 6  6  6  6 18 10 17 17 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  6  6 10 10 10 10 10 10 10 16 16 15 11 11 11 11  9 20 22]]
epoch 5 loss 0.8851592540740967
[[13 13 13 13 13  5  5 12 12 12 12 12 12 12 12 12 12 18 19 18 20]
 [17 17 17 17 17  7  7  7 18  8  6  8  8  8 12 12  8  6 20 22 22]
 [ 8  8  8  8 10 11  6 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1  1  1  1  1  1  1  5  4  4  4  2  4 12 

epoch 5 loss 0.7353855967521667
[[16 14  5 18  0  0  0 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [14 14 14 14 14 14  2 17  2  2  2  1  2  3  3 11  3 20 22 22 22]
 [ 5  5  5  5  5 19 19  6 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [13 13 13  3  3 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 0  3  3  1 10 10 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 5 loss 0.7260699272155762
[[14 14 14 14 14 14  2  7  7  8 19 19 19  1  1  1 20 22 22 22 22]
 [11  4 17 17 17 17  7  6  6 20 22 22 22 22 22 22 22 22 22 22 22]
 [15 15 15 15 15  6  6  6  5  5  5  3  1 20 22 22 22 22 22 22 22]
 [ 4  4  4  4  4  4  0 11  7  7 12  7 20 22 22 22 22 22 22 22 22]
 [ 9  9  9  9  9  9  9  9  7  7 12 12  4  4 12 13 13 13 14 20 22]]
epoch 5 loss 0.7952955961227417
[[ 9  9  9  9  9  9  9  9  1  1 13  1  1  1 19  2  2 14 14 20 22]
 [ 2  2  2 12  7  1  1  1  1  1  5  5  5  8 17 20 22 22 22 22 22]
 [ 5 14  5  5  5  5  5  5 16 16 16 15  8  0 17 17 20 22 22 22 22]
 [ 5  5 15  1  1 17 17  2  2  2  2 20 22 22 

epoch 5 loss 0.6472902894020081
[[18 14 14  4  8  2  2  2 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [16 16 16 17  7  7  4  3 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 18 15  3  5 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [12 12  6 18  8  8  8 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17 17  0 17  5 14 11  3 10 12 17 20 22 22 22 22 22 22 22 22 22]]
epoch 5 loss 0.6454424858093262
[[ 7  7  0  0  7  7  9  8  9  9  9 10 11 20 22 22 22 22 22 22 22]
 [ 8  8  8  8  8  5  5  5  4  4 14 14 14  0  0  0 15 16 16 20 22]
 [15 15 16  9  4  8  6 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [16 16 10 11  9 13  1  6 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 7 10 10 14 14  7 10 10 16 16 16 19 19 20 22 22 22 22 22 22 22]]
epoch 6 loss 0.7186011672019958
[[ 3  3 18 18 18 18  6  6  6  6  3  3  8 12  9 12 20 22 22 22 22]
 [ 1 10 14 14 14 14  7  5 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 2  2  2  2  2  2  2 12 12 12 12 18  0  6  6  6  6 20 22 22 22]
 [12  2  3 18 13 20 22 22 22 22 22 22 22 22 

epoch 6 loss 0.5839869379997253
[[15 15  7 15 19 19 19  9  2  4  4 17 20 22 22 22 22 22 22 22 22]
 [14 14 14 14 14 14 14 14 14  0  2  9  2  0 20 22 22 22 22 22 22]
 [10 10 12 12 12 12  2  2  7 12 12 19 19 14 20 22 22 22 22 22 22]
 [18  6  6  6  6 11  9  9 17 17  3  0 17 17  8  8 20 22 22 22 22]
 [18  7  9  9  3  3  9  9 17  9  9  9 14 14 14 10 20 22 22 22 22]]
epoch 6 loss 0.607039749622345
[[19 19 19 19 19 19 14  8  3  2  2  2  2  0 13  0 20 20 22 22 22]
 [12 12 12 12 19 19 19  8  8  6  6  6  6  6  6  4 11 11  5 20 20]
 [12  6 12 14  5  5  5 15  4  0  0  0  1 18  7 17 20 22 22 22 22]
 [10 12  2 11  8 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11  2  0  6 17 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 6 loss 0.6100989580154419
[[10  6  6  6  6  6  6  6  6 15 15 15 15 13 14 13 13 19 19 20 22]
 [ 8  8 19  2  3  1  1 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11  7  7  7  7 12 19 19  5 14  0 13 14 20 22 22 22 22 22 22]
 [ 6  6  6  6  6  6  6  4 16 16  6 15 15  3 1

epoch 6 loss 0.5803384184837341
[[15  4 11  1  7  8  8  3 16 17 16 20 22 22 22 22 22 22 22 22 22]
 [12  1  4  4  6  8  4  8  1  5 20 22 22 22 22 22 22 22 22 22 22]
 [ 0  0  0  0 13 18 18 18  7 18 12  2  2  2  4  9 16 20 20 22 22]
 [17 17  9  9  9  9 19  2  4  4  4  4 16 15 20 22 22 22 22 22 22]
 [ 8  8  8 12 12  7 11  4  2 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 6 loss 0.6409836411476135
[[ 7  7  7  3  4  3 17 13 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  6 13 13 16  9  9  4 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 1  1  4  1 15  1 15 15 15 15 15 15 15 13 16 16 17 17  6 17 20]
 [10  2  2  1  1  1 17 17  4  3  3  3  5  5 12  9  9 20 22 22 22]
 [13 13 13 13  3  6 16 16  8 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 6 loss 0.5761889815330505
[[14  1  1  1  1 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 13 13  2  2  2  4  2  6  6  1 10  5  1  4 11 11  4  4 20 22]
 [14 14 14 10  8 16  7  1 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 2  2  2  0  0  0  3  0  0  3 17  2 17 17 

epoch 6 loss 0.49740514159202576
[[ 6  6  6  2  2  2  2  0  8  8  8  8  8  9 16 16  5  2 10 20 22]
 [17 17 17 17 17  1 12 15 11 11  0 18 20 22 22 22 22 22 22 22 22]
 [ 3 15  7  7  7 14  6  6  8  5  6 11  3  3  1 18 18 20 22 22 22]
 [ 0  0 13  2 12  7 11 16 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 18 17 17 13 14 19 14  5 17  1  0  0 10  5  4 20 22 22 22 22]]
epoch 6 loss 0.5759918689727783
[[14 14  5 16 10 14  3 12  6 20 22 22 22 22 22 22 22 22 22 22 22]
 [12  9  9 12 12  2 13  6  2  3  4  1 20 22 22 22 22 22 22 22 22]
 [ 4  0  0  0  4  4 17 14 14  1  1  1 15 15  1 10  2  2 20 22 22]
 [14 14 14 14 14 14  2  7  2  2  2 15 15  3 16 16 15  9  1 13 20]
 [19 19 19  4 14 14  7 12  3  3  4 13  5 10 20 22 22 22 22 22 22]]
epoch 6 loss 0.5130829215049744
[[15  5 10 10 19  2  9  6  4  0  3  4 16 20 22 22 22 22 22 22 22]
 [13 14 14 14 14  0  1 13  8 17  0  0 10 19 16 11 20 22 22 22 22]
 [ 6  6 15 15 15 15 13  1 10 10 12 12 17 18 16 20 22 22 22 22 22]
 [11 11 11 15 12  8 20 22 22 22 22 22 22 22

epoch 6 loss 0.4457206130027771
[[11  3  5 12 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 7  8  6  6  6 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1  5  5 11 15  6 15 15 15  2  2  2  4  4  4  3  9 14 14  1 20]
 [ 2  6  6  6  6 13  5  5  5  2  2  2 17 17 10  8  8 10 10 20 22]
 [ 1  6  6  5 16 16 16 19 19  1  7  5  5  1  7  7 20 22 22 22 22]]
epoch 6 loss 0.4181487262248993
[[16 16 11 11 11  5 12 12 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [10 10 12  9  9 10 13  9  9  1 17 18 18 13 10  9 16 16  4 20 22]
 [ 5  2  5  3  3  7 13 14  0  0  5  4 15  5  7 19 20 22 22 22 22]
 [16 12 12 12  8 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 7  8  8  8  4  9 15 15 18 16  8  4 11  7  9 18 20 22 22 22 22]]
epoch 6 loss 0.40777552127838135
[[18 18 18 18 11 19 19  1 11 11 11 13 16 16 13  9 20 22 22 22 22]
 [ 1  1 16 14 10 10 10 17 17  9 20 22 22 22 22 22 22 22 22 22 22]
 [13 16 16 16 16  5 17 17 15  6 13 12 11 11  2  6 10  9 20 22 22]
 [ 9 17  9  9 13 13 13 15 15  2  2  2  9 19

epoch 7 loss 0.3754153847694397
[[ 8 15 15  3  8 13  4 13 10  8  8  8  3  6  6  2  0 20 22 22 22]
 [11 11 11  4 19 17 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 3  3  6 11 10 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [13 13 14  6 14 14 14 14  1 13 13 10  9 20 22 22 22 22 22 22 22]
 [ 1  1  5  5 13 18 17 13 15 12 12  0 20 22 22 22 22 22 22 22 22]]
epoch 7 loss 0.42136985063552856
[[ 0  0 14 16  5  8 11 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [12  4  4  3 19  8  7 11 11 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 1 10  8  8 19 13  8 16 19 13 10 17 17  2  5  1  7 20 22 22 22]
 [ 3  3  3  3 12 12 12 14  8 11 11 18  0  4  4  4  4  7 13 13 20]
 [19 13  6 18  9 12 11 11 12 12 17 20 22 22 22 22 22 22 22 22 22]]
epoch 7 loss 0.3562184274196625
[[ 0 12 12 10 10 10 10  2  2  0 15 15 15 16 16 20 22 22 22 22 22]
 [ 8  8  4  4  3 17 12  7  4 20 22 22 22 22 22 22 22 22 22 22 22]
 [16 12 12  5 12  8 19  2 16  9 16 13  5  9  9 18 20 22 22 22 22]
 [ 4  2  5  5  5 15 20 22 22 22 22 22 22 22

epoch 7 loss 0.41474029421806335
[[12 19 18  1 15 16 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 7  7  5  5 17 13 19 19  6  8  6  7 20 22 22 22 22 22 22 22 22]
 [ 8 11 11 11  9 17 17 17 10  3 19 19 18 18 15 15  5  5  5  5 20]
 [19  0  8 18  5  4 10  6 19 19 19 19 19 17 17  5 14 18 20 22 22]
 [ 1  1 14  0 10 13  0  8 16  4  5  9 13 13 20 22 22 22 22 22 22]]
epoch 7 loss 0.43833425641059875
[[ 5  0 11 12 13  8 18 18 18  6  7 14 20 22 22 22 22 22 22 22 22]
 [15 15 15 17 17 17 10 10 10 10  1  2  1  6  6  6 20 22 22 22 22]
 [ 8  6  6 13  6  0 16 16 16 16 11 14 14 18 18 20 20 22 22 22 22]
 [18 18  5  2 11  0 13  7 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [13 11 11 11 11  3  7 19  3  5 13 13 13 13 10 19 19 20 22 22 22]]
epoch 7 loss 0.38699260354042053
[[ 8 17 11 15  1 15  9  0  0 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 7  7 19 19  8 14 18 13  3 16 20 22 22 22 22 22 22 22 22 22 22]
 [17 17 19  1 13 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 11 11 10 11 11 19 10 17 11  8 14  0 

epoch 7 loss 0.27323150634765625
[[18  8  8  8  8  1 17  4 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [19 10 19  4  2 12  6  2 12  7 20 22 22 22 22 22 22 22 22 22 22]
 [13 14  1  8 11  0 14 15 18 16  8 12 13 19 18  6 19  5  1 10 20]
 [ 7  7 12 15  1  3  1  4 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [19 19  7  7  3  9  2 16 14  6 20 22 22 22 22 22 22 22 22 22 22]]
epoch 7 loss 0.301116943359375
[[ 6  2 12 13 17 18  0  6 12 10  8  2 15 16 20 22 22 22 22 22 22]
 [17  1 14  3  9 18 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [15  3 15 14  9 15 17 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 8  8 17  5  9 13 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [10 15 18 18  3  1 17 15  5 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 7 loss 0.32157912850379944
[[ 6  8  0 17 14 16  5 15 17  4  0  9  2 18 18  0  6 11 20 22 22]
 [ 4  5  4  4 10  9  5  0  6  1  9  9 11 11 10 14  9  9 16 20 22]
 [ 6 18  4 14  5  9  9 19  8 19  6 15 16  7  5 20 22 22 22 22 22]
 [19 19  4  1  3 18 10  9  9  9  0 20 22 22

epoch 7 loss 0.1896847039461136
[[ 4  1 18  1 10 10  4  4  1 13  6 19 20 22 22 22 22 22 22 22 22]
 [13 14 10 15 19  2  3 10  6  4  9 10 18 13  0 12  6 17 20 22 22]
 [ 6  2 15 17  6  8 18 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18  7 18 17 13 10  8 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 0 10 13 14 12 11  7 20 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 7 loss 0.21096819639205933
[[18 12 12 12  1 19  1 19 10 11  8 20 22 22 22 22 22 22 22 22 22]
 [13 19 19 13  6  9  9  4 10  7 20 22 22 22 22 22 22 22 22 22 22]
 [ 7 17  0 11 15 16 16  0 12 10 10 10  6 20 22 22 22 22 22 22 22]
 [ 9 11  5 10  2  8 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  6  6 10 11  9  7 10  9  8  0 19 14 20 22 22 22 22 22 22 22]]
epoch 7 loss 0.19733484089374542
[[ 5  7  4 13 10 11 19 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 4  7  6  0  7 18 14 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [18 11 15  0 13  6 14 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  5 19 16 19 19  7  7  2  4 15  7  

epoch 8 loss 0.1642276495695114
[[18 19 18  1 19 19  1  9  5 15 10 17  7 17 11 16 16 16 20 22 22]
 [ 3 13  3 11  1 16 18 14  2 15 13  7  7 20 22 22 22 22 22 22 22]
 [ 9 10 19  2  7 14 13 16 17 10 11 17 20 22 22 22 22 22 22 22 22]
 [ 3  3 18 10 18  9 16  3  3 12 10 13 18 20 22 22 22 22 22 22 22]
 [17 15 12  0 15  5 10 16 17  1  1  1  1  5  3 18 13  9 17 15 20]]
epoch 8 loss 0.16084344685077667
[[ 2 13  7  8  3 12 12  7 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  6  2  6 15 12  7  2  3 18 16 18 18 20 22 22 22 22 22 22 22]
 [ 3  0  7  0 10  8 13  5  9  8  9 16 18 16 19 15 13  2 19  0 20]
 [ 3  3  3 11 19  4 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 14  6 18  9  2 13 13 11  4  5 17  0 11  9 20 22 22 22 22 22]]
epoch 8 loss 0.15739168226718903
[[17  4  8 11 16 12 11 17  9 12 14  8 17 15 12 12 12  1  0 13 20]
 [ 9  3 12  7  5  6 15  5 14 10 10 14  3 20 22 22 22 22 22 22 22]
 [ 6  3 12  8 17 13 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17  5 18  4 15 13 19  8 12 20 22 22 22 2

epoch 8 loss 0.13721740245819092
[[ 1 10 11  5  9 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 7 15  8 14  8 14  5  2  0 12 13  6 18 12 20 22 22 22 22 22 22]
 [ 0 10  7 13  2 13 11 15  6 12 15 13 20 22 22 22 22 22 22 22 22]
 [ 1 11  9 18 10 18 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 6  1  8 14  8  0  2 15  0  8  6  6  1 20 22 22 22 22 22 22 22]]
epoch 8 loss 0.1488630771636963
[[ 3  4 18 19  4  6 18  7 13  3  6 13 20 22 22 22 22 22 22 22 22]
 [ 1  3 13  8 10 19 13 14 10 20 22 22 22 22 22 22 22 22 22 22 22]
 [13  1 11 12 13  7 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [13  3  1 18  9 12  7  2  9  4  3  1 18 11 18 17 18 20 20 22 22]
 [18  6  5 19 17  0 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 8 loss 0.14894506335258484
[[ 6  6  6  8  0 13  9 10  8 12  9  6 15  0 18 17 20 22 22 22 22]
 [19  9 14 12  0  4 18 19  4 12 12  6  1  5 19  6 19  6 19 20 22]
 [ 5 16  9  9  7  2 18  8 19  7  7  6  2 18 15 14 20 22 22 22 22]
 [ 6  6  9  0  4  1 15 20 22 22 22 22 22 2

epoch 8 loss 0.2533426880836487
[[ 6 13 14 18 16  1 18 18  6 20 22 22 22 22 22 22 22 22 22 22 22]
 [16 16 14 11 17  1 15  2 15 11  2  7  7 11 17 12  8  8 20 22 22]
 [ 9 11  0  8  6  3  1  4 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [12  0  0 12  0  5 11 11 14  6  9 20 22 22 22 22 22 22 22 22 22]
 [16 16 16 13 13 13 13 14  5  9  9 15 20 22 22 22 22 22 22 22 22]]
epoch 8 loss 0.22208982706069946
[[ 3 19  7  7 19 19  4 14 15  2  0  0 20 22 22 22 22 22 22 22 22]
 [ 6  6  6 12 14  9 12 12 12 10  0  0  0 20 22 22 22 22 22 22 22]
 [ 0  0 12 10 12  7  7 19 18  4  9  5  4 12  7  8  8  0  0 20 22]
 [ 6  6 15 13 12 18  5  3  4  7  6  2  0 19  8 20 22 22 22 22 22]
 [13 13 16 19  5 14  8  6  5 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 8 loss 0.24779029190540314
[[14  7 15  1 16 16 10  9 12  3 20 22 22 22 22 22 22 22 22 22 22]
 [ 7  7  3  2 19 18  7  7 16 16 13  9 12  7 17 17 13 13 20 22 22]
 [ 3  0 19 16 13 12  0  0  0  0 14 14  8  8 10 10  4  8  0 20 20]
 [ 6  6  1  4 18 10 13 17 20 22 22 22 22 2

epoch 8 loss 0.14588816463947296
[[ 3  8 17  4  6  5  1 11 18 15 15 15  4  1 15 15  7  6 20 22 22]
 [15 12 11 16  9 11  0  2  2  6 15  6  3 20 22 22 22 22 22 22 22]
 [ 2  1  4 17 11  0 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 9 10  9 14 14  9  1 10 19  7  9 20 22 22 22 22 22 22 22 22 22]
 [16 11  1  7  4 12 13 16  5 16 19 18 12 10 10 18 20 22 22 22 22]]
epoch 8 loss 0.12989574670791626
[[16 14  5 18 12  0  0 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 2  7 14 14 11 11 15 17 10  7  2  1  2  6  3 11  3 20 22 22 22]
 [ 0 14  5  5  4 15 19  6 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [13  3 13 10  3 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 0  5  3  1  9 10 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 8 loss 0.12422528862953186
[[ 5 12 13 14 14 12  2 11  7  8  0 13 19  1  1 20 20 22 22 22 22]
 [ 4  4 16 17  9 17  7  6  6 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  3 13 15 16  6 15  6 19  5  5  3  1 20 22 22 22 22 22 22 22]
 [ 2 17  4 18 12  4  0 11  7 16 12  7 20 

epoch 8 loss 0.09395720064640045
[[16  4  7  6  6 17  9  6  9 11  0  3  0 12 20 22 22 22 22 22 22]
 [ 8  8  9 17 16  9  4  9 19 11 13 17  5  9 20 22 22 22 22 22 22]
 [10  4  9  2 14 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [15 15  1 19  4 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  0 16 14  6  5 12  7  8  1  7 11 18 16  2 20 22 22 22 22 22]]
epoch 8 loss 0.09303603321313858
[[18  2 14  4  8 18 11  2 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [16 16  1 17 12  7  4  3 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 9  4 18 15  3  5 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 7 12  6 18  8 11  8 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17 17  0 17  5 14 11  3 10 12 17 20 22 22 22 22 22 22 22 22 22]]
epoch 8 loss 0.08678176254034042
[[12 12  6  0 11  7  9  8  4 18  9 10 11 20 22 22 22 22 22 22 22]
 [17 17 17  8  8 15  5  5  4  4  5 18 14  6 12  0 15  0 16 20 22]
 [10 15 16  9  4  8  6 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [16 16 10 11  9 13  1  6 20 22 22 22 22 

epoch 9 loss 0.0731380507349968
[[14  7 16  0  8 11 12 17  5 20 22 22 22 22 22 22 22 22 22 22 22]
 [13 15  2 12 14 14 12 18  2 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 2  8  3  4  6 18  6 16  9 11 17  7 10 12 20 22 22 22 22 22 22]
 [19 17 14  8  1  1 14  8 15 20 22 22 22 22 22 22 22 22 22 22 22]
 [14  8  1 11 12 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 9 loss 0.07507845014333725
[[15  3 17 15  7 19 19  9  2  0  4 17 20 22 22 22 22 22 22 22 22]
 [ 5 19 14 17 14 18  2 10 14  0 18  9  2  0 20 22 22 22 22 22 22]
 [16 10 19  4 12 12  1  2  7  4 12  0 19 14 20 22 22 22 22 22 22]
 [ 7  7 17  8  6 13 11  9 15 17  3  0  2 17  7  8 20 22 22 22 22]
 [18  7  9 13 16  3 19 15 17  9  8  9 15 18 14 10 20 22 22 22 22]]
epoch 9 loss 0.07252027839422226
[[ 1 19  8 19 19 19  1 14  3  3  8  9  2  0 13  1  0 20 22 22 22]
 [13  4 12 12  0 19 19  4  8 13  3  6 15  6  6  4 11 11 10 10 20]
 [14  6  7 12 19  1  5 15  4 14  0  6  1 18  7 17 20 22 22 22 22]
 [10 12  2 11  8 20 22 22 22 22 22 22 22 2

epoch 9 loss 0.06384650617837906
[[ 1 11 17 11 11  0 12 18 10  0 15 12  3  8  8 20 22 22 22 22 22]
 [ 0  0  8 13 19  7  1 20 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [11  2 13  7 19  6  2  1 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 0 16  4  1 15  7  1  5  6 20 22 22 22 22 22 22 22 22 22 22 22]
 [15  0 10 17 11 17 16 11 14 10 19  5 19 13 20 22 22 22 22 22 22]]
epoch 9 loss 0.0632747933268547
[[15  4 11  1  7  6  8  3 18 17 16 20 22 22 22 22 22 22 22 22 22]
 [12  1  8  4  6  9  4  8  1  5 20 22 22 22 22 22 22 22 22 22 22]
 [10 14  7  0 13  6 18 13  7 18 12 15  2  2  4  9 12 16 20 22 22]
 [14 17 17  9 13  9 19  2 15  6  4  4 16 15 20 22 22 22 22 22 22]
 [ 7 11  8  5 12  7 11  4  2 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 9 loss 0.06723449379205704
[[ 7  7  7  3  4  3 17 13 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [11  0  6  3 13 16  8  9  4 20 22 22 22 22 22 22 22 22 22 22 22]
 [ 1 13 10  4 15 12  9  1  7 17 17 15 15 13 15 16  5 17  6 17 20]
 [10 10  2  1  0 15  1 17  4 17  3  3  3  

epoch 9 loss 0.05561337247490883
[[15  3  5  3 17 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 9  9 10  3 13 11 13  0  2 10  5  1  4  1  5 15 10 20 22 22 22]
 [ 2 10 13 18  5  7  4 12 19  6  8  8 20 22 22 22 22 22 22 22 22]
 [10  2 12  5  1 19  2 14 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [10 10  1  1 11 16 10 16  0 20 22 22 22 22 22 22 22 22 22 22 22]]
epoch 9 loss 0.04581916704773903
[[12  9  6  5  6  2  2  0 10  6  8  0  8  9 19 16  5  2 10 20 22]
 [17  1 17  3 17  1 12 15 11 11  0 18 20 22 22 22 22 22 22 22 22]
 [ 0 15  3  1  7 14 12  6  8  5  6 11 18  3  1  9 18 20 22 22 22]
 [ 7  0 13  2 12  7 11 16 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [10 18 14 17 13 12 19 14  5 17  1  0  0 10  5  4 20 22 22 22 22]]
epoch 9 loss 0.053632091730833054
[[ 7 14  5 16 10 14  3 12  6 20 22 22 22 22 22 22 22 22 22 22 22]
 [12 10  9  1 12  2 13  6  2  3  4  1 20 22 22 22 22 22 22 22 22]
 [12  4  9  0 11  4  6  7 14 17 14  1  4 15  1 10 14  2 20 22 22]
 [14 14 16  4 14  1  8  7 17  2  2  0 15

epoch 9 loss 0.04527127370238304
[[11 18 12 19 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17 15  8  6 16  5 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 2 13  1 16 15  4 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 5  6  9 19  5  9 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [17 10 16  8  6 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22]]
epoch 9 loss 0.044146955013275146
[[11  3  5 12 15 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 7  8  6  1  6 20 22 22 22 22 22 22 22 22 22 22 22 22 22 22 22]
 [ 1  8  5 11 15  6 10 16 15 18  4  2  1  7  4  3  9 11 14  1 20]
 [ 2 19  2  6  6 13  5  6  5  7 11  2 17 17 10 17  8 13 10 20 22]
 [ 1 18  6  5 17  9 16 12 19  1  7 17  5  1 12  7 20 22 22 22 22]]
epoch 9 loss 0.04470420628786087
[[16 16 12 11 11  5  8 12 20 22 22 22 22 22 22 22 22 22 22 22 22]
 [11 10 12 16  9 10 13 12  9  1 17 19 18 13 10  1  9 16  4 20 22]
 [ 6  2  5  0  3  7 13 14  8  0  5  4 15  5  7 19 20 22 22 22 22]
 [16 13 12 12  8 20 22 22 22 22 22 22 22

In [None]:
encoder_final_states

## Inference

In [22]:
decoder_start_token = tf.placeholder(tf.int32, [None], name="decoder_start_token") # [ BATCHSIZE ]
decoder_end_token = EOS
decoder_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings,decoder_start_token,decoder_end_token)
#initialize = finished, first_inputs = embedding_lookup(start_token)
#next_inputs (time, outputs, state, sample_ids) -> (finished=sample_id==end_token, next_input=embedding_lookup(sample_ids), state)
#sample_id is argmax of the output, so decoder should give output which are after softmax

decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, decoder_helper, decoder_initial_states,decoder_output_layer)
final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder)

In [45]:
#YlogitsTest = final_outputs.rnn_output
#Y_pred_test = tf.argmax(tf.nn.softmax(YlogitsTest),2)

In [89]:
input_seq = np.array([int(x) for x in '1 2 3 4 5 6'.split()], dtype=np.int32)
input_seq=input_seq.reshape((1,len(input_seq)))

In [90]:
input_seq_len = np.array(input_seq.shape[1]).reshape((1))

In [173]:
predicted_seq, predicted_seq_len = sess.run([final_outputs.sample_id, final_sequence_lengths],\
         feed_dict={decoder_start_token:np.array([GO]) , 
                    encoder_inputs:input_seq, 
                    encoder_input_length:input_seq_len,
                    batch_size_t:np.array([1])})

print("{}, {}".format(predicted_seq[0], predicted_seq_len[0]))#first prediction for batch

[ 6  5  4  3  2  1 20], 7


In [174]:
beam_width = 5
beamsearch_decoder = tf.contrib.seq2seq.BeamSearchDecoder(decoder_cell,
                                                          decoder_embeddings,
                                                          decoder_start_token,
                                                          decoder_end_token,
                                                          decoder_initial_states,
                                                          beam_width,
                                                          decoder_output_layer)

#np.tile(decoder_start_token, beam_width) is applied 
#therefore, when decoder_start_token = [batch_size]
#then decoder_initial_states = ([batch_size*beam_width, cell_size])
#therefore, encoder_inputs = ([batch_size*beam_width, seq_len])
final_outputs_bs, final_state_bs, final_sequence_lengths_bs = tf.contrib.seq2seq.dynamic_decode(beamsearch_decoder)

predicted_seq_bs, predicted_seq_len_bs = sess.run([final_outputs_bs, final_sequence_lengths_bs],\
         feed_dict={decoder_start_token:np.array([GO]) , 
                    batch_size_t:np.array([beam_width]),
                  encoder_inputs: np.tile(input_seq,(beam_width,1)),
                  encoder_input_length:np.tile(input_seq_len,beam_width)})

print("{}, {}".format(predicted_seq_bs.predicted_ids[0].T, predicted_seq_len_bs[0]))#first prediction for batch

[[ 6  5  4  3  2  1 20]
 [ 6  4  5  3  2  1 20]
 [ 5  6  4  3  2  1 20]
 [ 6  5  4  2  3  1 20]
 [ 5  6  4  2  3  1 20]], [7 7 7 7 7]


In [175]:
saver = tf.train.Saver()
save_path = saver.save(sess, "seq2seq_reverse.ckpt")
print("Model saved in file: %s" % save_path)

Model saved in file: seq2seq_reverse.ckpt
