In [1]:

# Environment setup.
# These variables are read by TensorFlow when it is loaded, so this cell must
# run before `import tensorflow` below.
import os

# oneDNN (oneAPI Deep Neural Network Library): CPU optimizations aimed mainly at
# Intel architectures. '0' turns TensorFlow's oneDNN custom ops off; TensorFlow's
# own startup message suggests this to avoid small floating-point differences
# caused by different computation orders.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Intended to enable NVIDIA cuDNN (CUDA Deep Neural Network library) optimizations.
# NOTE(review): TF_ENABLE_CUDNN_OPTS does not appear to be a documented TensorFlow
# variable, so this line probably has no effect (TensorFlow presumably selects
# cuDNN kernels on its own when a GPU is available) — confirm or remove.
os.environ['TF_ENABLE_CUDNN_OPTS'] = '1'

# Verbosity of DeepReg's log output; '2' is meant to select WARNING.
# NOTE(review): DeepReg is not imported in the cells shown here, and the level
# scale has not been verified — confirm against the DeepReg documentation.
os.environ['DEEPREG_LOG_LEVEL'] = '2'

# Minimum severity of TensorFlow's C++ log messages (valid values '0'..'3'):
#   '0' = show everything, '1' = hide INFO, '2' = hide INFO and WARNING,
#   '3' = hide INFO, WARNING and ERROR.
# '2' therefore shows only ERROR (and FATAL) messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'



import tensorflow as tf
from tensorflow.keras.layers import GRU, LSTM, SimpleRNN



In [2]:
# Sample input for the recurrent layers below: 16 sequences, each 5 time steps
# long with 3 features per step, laid out as (batch, time_steps, features).
batch_size = 16
time_steps = 5
features = 3

# Seed TensorFlow's global RNG so `input_data` is identical on every
# Restart & Run All (re-running this cell also resets the seed's counters).
# NOTE(review): Keras weight initializers may additionally need
# `keras.utils.set_random_seed` to be fully reproducible — confirm.
SEED = 42
tf.random.set_seed(SEED)

input_data = tf.random.normal((batch_size, time_steps, features), dtype=tf.float32)
assert input_data.shape == (batch_size, time_steps, features)
print(input_data.shape)

# Number of units (size of the output / state vectors) in each recurrent layer.
units_size_per_layer = 4

(16, 5, 3)


In [3]:
# SimpleRNN input/output: with return_sequences=True and return_state=True the
# layer returns the output at every time step, followed by the final state.
simple_rnn_layer = SimpleRNN(
    units_size_per_layer,
    return_sequences=True,
    return_state=True,
)
SimpleRNN_whole_sequence_output, SimpleRNN_final_memory_state = simple_rnn_layer(input_data)

for rnn_tensor in (SimpleRNN_whole_sequence_output, SimpleRNN_final_memory_state):
    print(rnn_tensor.shape)

(16, 5, 4)
(16, 4)


In [4]:
# LSTM input/output: with return_sequences=True and return_state=True the layer
# returns three tensors — the output at every time step, then the final memory
# state and the final carry state (names follow the Keras LSTM documentation).
lstm_layer = LSTM(
    units_size_per_layer,
    return_sequences=True,
    return_state=True,
)
LSTM_whole_sequence_output, LSTM_final_memory_state, LSTM_final_carry_state = lstm_layer(input_data)

for lstm_tensor in (LSTM_whole_sequence_output, LSTM_final_memory_state, LSTM_final_carry_state):
    print(lstm_tensor.shape)

(16, 5, 4)
(16, 4)
(16, 4)


In [5]:
# GRU input/output: with return_sequences=True and return_state=True the layer
# returns the output at every time step, followed by the final state.
# unroll=True asks Keras to unroll the recurrence instead of using a symbolic
# loop (the Keras docs describe this as suited to short sequences such as these).
gru_layer = GRU(
    units_size_per_layer,
    return_sequences=True,
    return_state=True,
    unroll=True,
)
GRU_whole_sequence_output, GRU_final_memory_state = gru_layer(input_data)

for gru_tensor in (GRU_whole_sequence_output, GRU_final_memory_state):
    print(gru_tensor.shape)

(16, 5, 4)
(16, 4)


In [6]:
# Another way to inspect the results: keep the LSTM's return value as a single
# list instead of unpacking it, then print each element's shape in order.
# The elements are, in order:
#   whole-sequence output, final memory state, final carry state.
out = LSTM(
    units_size_per_layer,
    return_sequences=True,
    return_state=True,
)(input_data)

for lstm_result in out:
    print(lstm_result.shape)

(16, 5, 4)
(16, 4)
(16, 4)
