"""Train a small attention seq2seq model, save its weights, restore them into
a fresh model, and report whether predictions and weights survive the round trip.
"""
from collections import namedtuple
import os
import shutil

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import tensorflow as tf
import tensorflow_addons as tfa

NUMBER_OBSERVATIONS = 500
MAX_VOCAB_TOKEN = 100
MAX_SEQUENCE_LENGTH = 10
GO_SYMBOL = MAX_VOCAB_TOKEN + 1
END_SYMBOL = 0
EMBED_SIZE = 10
RNN_UNITS = 256


def sequence_length_3D(sequence):
    """Count, per batch entry, the steps along axis 1 that are not all-zero.

    A step counts as "used" when the maximum absolute value along axis 2 is
    non-zero. The result is an int32 tensor with one entry per batch row.
    """
    used = tf.sign(tf.reduce_max(tf.abs(sequence), 2))
    length = tf.reduce_sum(used, 1)
    return tf.cast(length, tf.int32)


#
# Custom Layer and Model
#
class CustomEncoder(tf.keras.layers.Layer):
    """Embedding followed by an LSTM that returns both sequences and state."""

    def __init__(self, vocab_size, embed_size, rnn_units):
        super().__init__()
        self.embeddings_enc = tf.keras.layers.Embedding(vocab_size, embed_size)
        self.encoder = tf.keras.layers.LSTM(
            rnn_units,
            return_sequences=True,
            return_state=True
        )

    def call(self, inputs, training=True, mask=None):
        """Return ``(encoder_output, encoder_end_state)`` for token ids ``inputs``.

        ``encoder_end_state`` is everything the LSTM returns after its
        sequence output (the slice ``[1:]`` of the LSTM's return value).
        """
        encoder_embeddings = self.embeddings_enc(inputs)
        encoder_outputs = self.encoder(encoder_embeddings)
        encoder_output = encoder_outputs[0]
        encoder_end_state = encoder_outputs[1:]
        return encoder_output, encoder_end_state


class CustomDecoder(tf.keras.layers.Layer):
    """LSTM-cell decoder wrapped with Luong attention.

    ``call`` produces logits from the supplied targets through a
    ``TrainingSampler``; ``decoder_greedy`` generates token ids step by step
    with a ``GreedyEmbeddingSampler``.
    """

    def __init__(self, vocab_size, embed_size, rnn_units):
        super().__init__()
        self.decoder_embeddings = tf.keras.layers.Embedding(vocab_size,
                                                            embed_size)
        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)
        self.output_layer = tf.keras.layers.Dense(vocab_size)
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()

        self.attention_mechanism = tfa.seq2seq.LuongAttention(units=rnn_units)
        # From here on ``decoder_rnncell`` is the attention-wrapped cell, not
        # the bare LSTMCell created above.
        self.decoder_rnncell = tfa.seq2seq.AttentionWrapper(
            self.decoder_rnncell,
            self.attention_mechanism,
            attention_layer_size=rnn_units
        )

        self.decoder = tfa.seq2seq.basic_decoder.BasicDecoder(
            self.decoder_rnncell,
            self.sampler,
            output_layer=self.output_layer
        )

    def build_decoder_initial_state(self, batch_size, encoder_state, dtype):
        """Return the wrapped cell's initial state seeded with ``encoder_state``.

        The attention wrapper's initial state is fetched for ``batch_size``
        and ``dtype`` and then cloned with ``cell_state=encoder_state``.
        """
        decoder_initial_state = self.decoder_rnncell.get_initial_state(
            batch_size=batch_size,
            dtype=dtype
        )
        return decoder_initial_state.clone(cell_state=encoder_state)

    def decoder_greedy(
            self,
            encoder_output,
            encoder_end_state=None,
            training=None
    ):
        """Greedily decode ``MAX_SEQUENCE_LENGTH`` tokens per batch entry.

        Args:
            encoder_output: encoder sequence output; its first dimension is
                used as the batch size.
            encoder_end_state: encoder state used to seed the decoder cell.
            training: forwarded to every ``decoder.step`` call.

        Returns:
            int64 tensor of sampled token ids with one column per decoding
            step. Decoding always runs for the full number of steps; it does
            not stop early once ``END_SYMBOL`` has been produced.
        """
        # ================ Setup ================
        batch_size = encoder_output.shape[0]
        encoder_sequence_length = sequence_length_3D(encoder_output)

        # ================ predictions =================
        greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()

        decoder_input = tf.expand_dims([GO_SYMBOL] * batch_size, 1)
        start_tokens = tf.fill([batch_size], GO_SYMBOL)

        # NOTE: the result is unused, but the call is kept on purpose: calling
        # the embedding layer ensures its variables exist before
        # ``self.decoder_embeddings.variables[0]`` is read below (this method
        # may be the first use of the layer on a freshly constructed model).
        self.decoder_embeddings(decoder_input)

        self.attention_mechanism.setup_memory(
            encoder_output,
            memory_sequence_length=encoder_sequence_length
        )

        decoder_initial_state = self.build_decoder_initial_state(
            batch_size,
            encoder_state=encoder_end_state,
            dtype=tf.float32
        )

        decoder = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_rnncell,
            sampler=greedy_sampler,
            output_layer=self.output_layer
        )

        # ================ generate logits ==================
        maximum_iterations = MAX_SEQUENCE_LENGTH

        # initialize inference decoder
        decoder_embedding_matrix = self.decoder_embeddings.variables[0]
        _first_finished, inputs, state = decoder.initialize(
            decoder_embedding_matrix,
            start_tokens=start_tokens,
            end_token=END_SYMBOL,
            initial_state=decoder_initial_state
        )

        # create empty predictions tensor
        predictions = tf.convert_to_tensor(
            np.array([]).reshape([batch_size, 0]),
            dtype=tf.int32
        )

        # append one column of sampled ids per step
        for step in range(maximum_iterations):
            outputs, state, inputs, _finished = decoder.step(
                step, inputs, state, training=training
            )
            one_prediction = tf.expand_dims(outputs.sample_id, axis=1)
            predictions = tf.concat([predictions, one_prediction], axis=1)

        predictions = tf.cast(
            predictions,
            tf.int64,
            name='predictions_{}'.format(self.name)
        )

        return predictions

    def call(self, inputs_targets_encoder_outputs, training=None, mask=None):
        """Return decoder logits for the supplied targets.

        ``inputs_targets_encoder_outputs`` is indexed as
        ``(inputs, targets, encoder_outputs, encoder_state)``. ``inputs`` is
        only used to read the batch size. Every sequence is treated as having
        length ``MAX_SEQUENCE_LENGTH`` for both the attention memory and the
        decoder.
        """
        inputs = inputs_targets_encoder_outputs[0]
        targets = inputs_targets_encoder_outputs[1]
        encoder_outputs = inputs_targets_encoder_outputs[2]
        encoder_state = inputs_targets_encoder_outputs[3]

        batch_size = inputs.shape[0]
        full_lengths = (
            tf.ones([batch_size], dtype=tf.int32) * MAX_SEQUENCE_LENGTH
        )

        # create embeddings for target
        decoder_emb_inp = self.decoder_embeddings(targets)

        # setup for Attention
        self.attention_mechanism.setup_memory(
            encoder_outputs,
            memory_sequence_length=full_lengths
        )

        decoder_initial_state = self.build_decoder_initial_state(
            batch_size,
            encoder_state,
            tf.float32
        )

        # generate logits for embedded targets
        final_outputs, _final_state, _final_sequence_lengths = self.decoder(
            decoder_emb_inp,
            initial_state=decoder_initial_state,
            sequence_length=full_lengths
        )

        return final_outputs.rnn_output


class CustomModel(tf.keras.models.Model):
    """Encoder/decoder model with a masked loss and a running-mean metric."""

    def __init__(self):
        super().__init__()
        self.encoder_layer = CustomEncoder(
            GO_SYMBOL + 1,
            EMBED_SIZE,
            RNN_UNITS
        )
        self.decoder_layer = CustomDecoder(
            GO_SYMBOL + 1,
            EMBED_SIZE,
            RNN_UNITS
        )
        self.metric_function = tf.keras.metrics.Mean()

    def loss_function(self, y, y_pred):
        """Sparse categorical cross-entropy with positions where ``y == 0`` zeroed.

        shape of y      [batch_size, output_sequence]
        shape of y_pred [batch_size, output_sequence, output_vocab_size]

        NOTE: the final mean is taken over every position, including the
        zeroed ones, so the value is not a per-real-token average.
        """
        sparsecategoricalcrossentropy = \
            tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True,
                reduction='none'
            )
        loss = sparsecategoricalcrossentropy(y_true=y, y_pred=y_pred)
        # output 0 for y=0 else output 1
        mask = tf.logical_not(tf.math.equal(y, 0))
        mask = tf.cast(mask, dtype=loss.dtype)
        loss = mask * loss
        return tf.reduce_mean(loss)

    def call(self, inputs_targets, training=None, mask=None):
        """Return logits when ``training`` is truthy, else greedy predictions.

        In training mode ``inputs_targets`` is ``(inputs, targets)``;
        otherwise it is the inputs alone.
        """
        if training:
            inputs = inputs_targets[0]
            targets = inputs_targets[1]
            encoder_outputs, encoder_state = self.encoder_layer(inputs)
            logits = self.decoder_layer(
                (inputs, targets, encoder_outputs, encoder_state)
            )
            return logits

        inputs = inputs_targets
        encoder_output, encoder_state = self.encoder_layer(inputs)
        response = self.decoder_layer.decoder_greedy(
            encoder_output,
            encoder_state,
            training=False
        )
        return response


#
# Helper functions
#
GeneratedData = namedtuple('GeneratedData', 'train test')


def generate_data():
    """Generate random, zero-padded input/output sequences split 70/30.

    Sizes come from the module-level ``NUMBER_OBSERVATIONS``,
    ``MAX_VOCAB_TOKEN`` and ``MAX_SEQUENCE_LENGTH`` constants, which are the
    same values the model is built from.

    Returns:
        ``GeneratedData`` whose ``train`` and ``test`` fields are dicts with
        keys ``'X'`` and ``'y'``.
    """
    def pad_all(sequences):
        # right-pad every sequence with 0 up to MAX_SEQUENCE_LENGTH
        return np.array([
            np.pad(seq, (0, MAX_SEQUENCE_LENGTH - seq.shape[0]),
                   constant_values=0)
            for seq in sequences
        ])

    # generate lengths for each sequence (order of the random draws is kept
    # fixed so a given seed always yields the same data)
    inp_sequence_len = np.random.randint(3, MAX_SEQUENCE_LENGTH + 1,
                                         size=NUMBER_OBSERVATIONS)
    out_sequence_len = np.random.randint(3, MAX_SEQUENCE_LENGTH + 1,
                                         size=NUMBER_OBSERVATIONS)

    # generate sequence of specified length
    inp_sequences = pad_all([
        np.random.randint(1, MAX_VOCAB_TOKEN + 1, size=length, dtype=np.int32)
        for length in inp_sequence_len
    ])
    out_sequences = pad_all([
        np.random.randint(1, MAX_VOCAB_TOKEN + 1, size=length, dtype=np.int32)
        for length in out_sequence_len
    ])

    # create train/test data sets
    train_X, test_X, train_y, test_y = train_test_split(
        inp_sequences,
        out_sequences,
        train_size=0.7
    )

    return GeneratedData(
        {'X': train_X, 'y': train_y},
        {'X': test_X, 'y': test_y}
    )


def batcher(X, y, batch_size):
    """Yield ``(X_batch, y_batch)`` pairs covering the data in shuffled order."""
    # explicit int: the section count for array_split is a whole number
    num_batches = int(np.ceil(X.shape[0] / batch_size))
    idxs = np.arange(X.shape[0])
    np.random.shuffle(idxs)
    for idx in np.array_split(idxs, num_batches):
        yield X[idx], y[idx]


def train_step(inputs_targets, model, optimizer):
    """Run one optimisation step and record the loss in the model's metric."""
    with tf.GradientTape() as tape:
        logits = model(inputs_targets, training=True)
        loss = model.loss_function(inputs_targets[1], logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    model.metric_function.update_state(loss)


def predict(inputs, model):
    """Return the model's inference-mode output for ``inputs``."""
    return model(inputs, training=False)


def _weights_match(expected, actual):
    """Return True only if both weight lists have equal length and close values.

    A length or shape mismatch counts as a failure; a plain ``zip`` would
    silently ignore the extra entries (and an empty list would always "match").
    """
    if len(expected) != len(actual):
        return False
    return all(
        np.shape(w1) == np.shape(w2) and np.all(np.isclose(w1, w2))
        for w1, w2 in zip(expected, actual)
    )


#
# Make repeatable
#
np.random.seed(13)
tf.random.set_seed(13)

#
# generate training and test data
#
generated_data = generate_data()

#
# clean out old results
#
shutil.rmtree('./results', ignore_errors=True)
# exist_ok: rmtree ignores errors, so the directory may still be there
os.makedirs('./results', exist_ok=True)

#
# Perform initial model training
#
print("initial model training")
model1 = CustomModel()
optimizer = tf.keras.optimizers.Adam()

for epoch in range(5):
    model1.metric_function.reset_states()
    batches = batcher(generated_data.train['X'],
                      generated_data.train['y'], 16)
    for X, y in batches:
        train_step((X, y), model1, optimizer)

    print("training loss metric, epoch", (epoch+1),
          model1.metric_function.result().numpy())

#
# Save predictions from test data set
#
preds1 = predict(generated_data.test['X'], model1)

#
# save weights from trained model
#
weights1 = model1.get_weights()
print("saving initial trained model weights")
model1.save_weights('./results/model_weights')

#
# restore model weights
#
print("\nrestore saved model weights")
model2 = CustomModel()
model2.load_weights('./results/model_weights')
weights2_before_predictions = model2.get_weights()

#
# generate predictions from test data set using restored model
#
preds2 = predict(generated_data.test['X'], model2)
weights2_after_predictions = model2.get_weights()

#
# Compare predictions and weights between initial trained model and restored model
#
# Plain if/else instead of assert: asserts are removed under ``python -O``,
# which would make every check report OK.
print("\ncomparing predictions and weights")
if np.all(np.isclose(preds1, preds2)):
    print("OK predictions matched")
else:
    print("FAILED prediction comparision")

if _weights_match(weights1, weights2_after_predictions):
    print('OK model weights matched after predictions matched')
else:
    print("FAILED weight comparision AFTER predictions")

if _weights_match(weights1, weights2_before_predictions):
    print('OK model weights before predictions matched')
else:
    print("FAILED weight comparision BEFORE predictions")

print("all done")