51 changes: 37 additions & 14 deletions onnx2kerastl/ltsm_layers.py
@@ -46,9 +46,9 @@ def convert_lstm(node, params, layers, lambda_func, node_name, keras_name):
tf_name=f"{params['cleaned_name']}_lstm_cast_h")
initial_c_state = tf_cast(
tf_squeeze(
ensure_tf_type(layers[node.input[6]]),
axis=0,
tf_name=f"{params['cleaned_name']}_lstm_squeeze_c"), input_tensor.dtype,
ensure_tf_type(layers[node.input[6]]),
axis=0,
tf_name=f"{params['cleaned_name']}_lstm_squeeze_c"), input_tensor.dtype,
tf_name=f"{params['cleaned_name']}_lstm_cast_c")

tf.keras.backend.set_image_data_format("channels_last")
@@ -85,24 +85,47 @@ def convert_lstm(node, params, layers, lambda_func, node_name, keras_name):
h_out = tf.expand_dims(h_out, 0)

lstm_tensor = res[:, 1:-1, :]
layers[node.output[1]] = h_out
layers[node.output[2]] = c_out
else:
lstm_tensor = res
lstm_tensor_in_onnx_order = tf_transpose(lstm_tensor, perm=[1, 0, 2], tf_name=f"{params['cleaned_name']}_lstm_transpose")

# Add an identity Dense over the lstm tensor for easy fetching of the latent space
input_dim = int(lstm_tensor.shape[2])
@tomkoren21 (Collaborator) commented on Jun 25, 2025:

Suggest raising an Exception here if the shape is dynamic? This line would fail if the lstm_tensor shape is None.
Or else replace it with tf.shape(lstm_tensor)[2].
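
A minimal sketch of the static-shape guard suggested in this comment, assuming the surrounding convert_lstm variables; note that tf.shape(lstm_tensor)[2] is a runtime tensor, so it could feed later dynamic ops but not Dense(units=...), which needs a concrete integer:

    # Sketch only (not part of the PR): guard the static feature dimension before
    # building the identity Dense, since Dense(units=...) requires a concrete int.
    static_dim = lstm_tensor.shape[2]
    if static_dim is None:
        raise NotImplementedError(
            "convert_lstm: lstm_tensor has a dynamic feature dimension; "
            "the identity Dense latent-space hook needs a static size"
        )
    input_dim = int(static_dim)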

dense = tf.keras.layers.Dense(
units=input_dim,
use_bias=False,
kernel_initializer=tf.keras.initializers.Identity()
)

lstm_tensor_dense = dense(lstm_tensor)

if should_return_state:
mul_o = lstm_tensor_dense[0, 0, 0] * 0
c_out = tf.add(c_out, mul_o)
h_out = tf.add(h_out, mul_o)

layers[node.output[1]] = h_out
layers[node.output[2]] = c_out

lstm_tensor = lstm_tensor_dense

lstm_tensor_in_onnx_order = tf_transpose(lstm_tensor, perm=[1, 0, 2],
tf_name=f"{params['cleaned_name']}_lstm_transpose")
lstm_tensor_in_onnx_order = tf_expand_dims(lstm_tensor_in_onnx_order, axis=1,
tf_name=f"{params['cleaned_name']}_lstm_expand_dims")
layers[node_name] = lstm_tensor_in_onnx_order
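
For reference, the block added above works as follows: the identity-initialized Dense passes the LSTM output through unchanged, which gives the latent space a named, fetchable node, and the zero-valued scalar taken from its output ties h_out and c_out into the same subgraph so the Dense node is not pruned away. A small standalone sketch of the pattern (made-up shapes and the layer name latent_space_hook are illustrative, not the converter code):

    import tensorflow as tf

    seq = tf.keras.Input(shape=(10, 8))                       # [batch, time, features]
    lstm_out, h_out, c_out = tf.keras.layers.LSTM(
        8, return_sequences=True, return_state=True)(seq)

    identity_dense = tf.keras.layers.Dense(
        units=8, use_bias=False,
        kernel_initializer=tf.keras.initializers.Identity(),
        name="latent_space_hook")                             # hypothetical name
    dense_out = identity_dense(lstm_out)                      # numerically equal to lstm_out

    zero = dense_out[0, 0, 0] * 0                             # scalar zero that depends on the Dense
    h_out = tf.add(h_out, zero)                               # states now route through the hook node
    c_out = tf.add(c_out, zero)

    model = tf.keras.Model(seq, [dense_out, h_out, c_out])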


def convert_gru(node, params, layers, lambda_func, node_name, keras_name):
logger = logging.getLogger('onnx2keras.convert_gru')
if len(params["_outputs"]) > 1:
logger.warning("The GRU return hidden state is currently not supported. Accessing in deeper layers will raise Exception")
logger.warning(
"The GRU return hidden state is currently not supported. Accessing in deeper layers will raise Exception")
if params.get('activation_alpha') or params.get('activation_beta') or params.get('activations'):
raise NotImplementedError('Custom Activations in GRU not implemented')
if params.get('clip'):
raise NotImplementedError('Clip in GRU not implemented')
if params.get('direction'): #After implementation - verify weights reshaping, and h default_size for all directions
if params.get(
'direction'): # After implementation - verify weights reshaping, and h default_size for all directions
raise NotImplementedError('direction in GRU not implemented')
else:
num_directions = 1
@@ -114,25 +137,25 @@ def convert_gru(node, params, layers, lambda_func, node_name, keras_name):
raise NotImplementedError('GRU sequence_lens is not yet implemented')
hidden_size = params.get('hidden_size')
linear_before_reset = bool(params.get('linear_before_reset', 0))
x = layers[node.input[0]] # [seq_length, batch_size, input_size] iff layout = 0
x = layers[node.input[0]] # [seq_length, batch_size, input_size] iff layout = 0
w = layers[node.input[1]]
r = layers[node.input[2]]
b = layers.get(node.input[3], np.zeros((num_directions, 6*hidden_size), dtype=np.float32))
b = layers.get(node.input[3], np.zeros((num_directions, 6 * hidden_size), dtype=np.float32))
h = layers.get(node.input[5], np.zeros((1, x.shape[1], hidden_size), dtype=np.float32))
if isinstance(h, np.ndarray):
tensor_h = tf.convert_to_tensor(h)
else:
tensor_h = h
tf.keras.backend.set_image_data_format("channels_last")
gru_layer = tf.keras.layers.GRU(units=hidden_size,
reset_after=linear_before_reset,
return_sequences=True,
name=f"{params['cleaned_name']}_gru")
reset_after=linear_before_reset,
return_sequences=True,
name=f"{params['cleaned_name']}_gru")
if layout == 0:
batch_first_x = tf_transpose(x, [1, 0, 2], tf_name=f"{params['cleaned_name']}_gru_transpose")
res = gru_layer(batch_first_x, initial_state=tf.convert_to_tensor(tensor_h[0]))
# gru_layer.build(tf.shape(batch_first_x))
gru_layer.set_weights([w[0].swapaxes(0, 1), r[0].swapaxes(0, 1), b[0].reshape(-1, 3*hidden_size)])
gru_layer.set_weights([w[0].swapaxes(0, 1), r[0].swapaxes(0, 1), b[0].reshape(-1, 3 * hidden_size)])
# res = gru_layer(batch_first_x, initial_state=tf.convert_to_tensor(tensor_h[0]))
if num_directions == 1:
reshaped_res = tf_expand_dims(tf_transpose(res,
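As a shape reference for the set_weights call in the GRU hunk above, here is a standalone sketch (made-up sizes, single direction, reset_after=True; not the converter code) of how the ONNX W [num_directions, 3*hidden, input], R [num_directions, 3*hidden, hidden], and B [num_directions, 6*hidden] parameters line up with the Keras kernel, recurrent kernel, and (2, 3*hidden) bias:

    import numpy as np
    import tensorflow as tf

    hidden_size, input_size, batch, seq_len = 4, 3, 2, 5

    # ONNX GRU parameters for a single direction, gates ordered z, r, h:
    w = np.random.randn(1, 3 * hidden_size, input_size).astype(np.float32)   # W
    r = np.random.randn(1, 3 * hidden_size, hidden_size).astype(np.float32)  # R
    b = np.random.randn(1, 6 * hidden_size).astype(np.float32)               # [Wb | Rb]

    gru = tf.keras.layers.GRU(hidden_size, reset_after=True, return_sequences=True)
    x = np.random.randn(batch, seq_len, input_size).astype(np.float32)
    _ = gru(x)  # build the layer so the variables exist before set_weights

    gru.set_weights([
        w[0].swapaxes(0, 1),                # kernel: (input_size, 3*hidden_size)
        r[0].swapaxes(0, 1),                # recurrent kernel: (hidden_size, 3*hidden_size)
        b[0].reshape(-1, 3 * hidden_size),  # bias: (2, 3*hidden_size) = input bias / recurrent bias
    ])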
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "onnx2kerastl"
version = "0.0.170"
version = "0.0.174"
description = ""
authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
license = "MIT"