29 changes: 22 additions & 7 deletions onnx2kerastl/operation_layers.py
@@ -687,16 +687,31 @@ def convert_or(node, params, layers, lambda_func, node_name, keras_name):


def convert_trilu(node, params, layers, lambda_func, node_name, keras_name):
-    input = layers[node.input[0]]
+    x = layers[node.input[0]]
    k = 0
    if len(node.input) > 1:
-        k = layers[node.input[1]]
-
-    if "upper" in params and not params["upper"]:
-        result = tf.experimental.numpy.tril(input, k)
-    else:
-        result = tf.experimental.numpy.triu(input, k)
+        k_tensor = layers[node.input[1]]
+        try:
+            k = int(tf.keras.backend.get_value(k_tensor))
+        except Exception:
+            k = 0  # fallback if k is symbolic
+    upper = params.get("upper", 1)
+
+    # tf.experimental.numpy.tril/triu cannot be used here: the input is not an
+    # eager tensor and its shape is unknown at conversion time
+    def trilu_fn(tensor):
+        shape = tf.shape(tensor)
+        m, n = shape[-2], shape[-1]
+        row_idx = tf.range(m)[:, None]
+        col_idx = tf.range(n)[None, :]
+        if upper:
+            mask = row_idx <= (col_idx - k)
+        else:
+            mask = row_idx >= (col_idx - k)
+        mask = tf.cast(mask, tensor.dtype)
+        mask = tf.broadcast_to(mask, shape)
+        return tensor * mask
+
+    result = tf.keras.layers.Lambda(trilu_fn, name=keras_name)(x)
layers[node_name] = result

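As a quick sanity check (not part of the PR), the mask built by trilu_fn matches NumPy's triu/tril semantics when evaluated eagerly; a minimal sketch, assuming TF 2.x and NumPy:

import numpy as np
import tensorflow as tf

def trilu_mask(m, n, k=0, upper=True):
    # same index arithmetic as trilu_fn above, evaluated eagerly
    row_idx = tf.range(m)[:, None]
    col_idx = tf.range(n)[None, :]
    mask = row_idx <= (col_idx - k) if upper else row_idx >= (col_idx - k)
    return tf.cast(mask, tf.float32).numpy()

x = np.arange(12, dtype=np.float32).reshape(3, 4)
np.testing.assert_array_equal(x * trilu_mask(3, 4, k=1, upper=True), np.triu(x, k=1))
np.testing.assert_array_equal(x * trilu_mask(3, 4, k=-1, upper=False), np.tril(x, k=-1))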

6 changes: 6 additions & 0 deletions onnx2kerastl/reshape_layers.py
@@ -428,6 +428,7 @@ def convert_slice(node, params, layers, lambda_func, node_name, keras_name):
    :return: None
    """
    logger = logging.getLogger('onnx2keras.slice')
+    max_ends_val = np.iinfo(np.int32).max

    if params['change_ordering']:
        raise NotImplementedError("change_ordering for Slice is not implemented")
@@ -448,6 +449,11 @@
        steps = list(layers[node.input[4]])
    except IndexError:
        steps = list(params.get("steps", [None] * len(axes)))
+
+    # when 'ends' holds the int64 maximum, the slice was most likely exported
+    # from an open-ended index like x[idx:], so clamp it to the int32 maximum
+    if ends[0].dtype == np.int64 and not isinstance(ends[0], KerasTensor):
+        if ends[0] > max_ends_val:
+            ends = [np.int32(max_ends_val)]
    try:
        max_len = len(layers[node.input[0]].shape)
        axes_positives = [axis if axis >= 0 else max_len + axis for axis in axes]
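For illustration (not part of the PR): exporters encode an open-ended slice such as x[idx:] with ends = 2**63 - 1, and since out-of-range ends are clamped during slicing, capping the value at the int32 maximum preserves the "to the end" semantics:

import numpy as np

INT64_MAX = np.iinfo(np.int64).max  # what the exporter emits for x[idx:]
INT32_MAX = np.iinfo(np.int32).max  # the clamped sentinel used above

x = np.arange(5)
assert (x[2:INT64_MAX] == x[2:INT32_MAX]).all()  # both mean "slice to the end"
assert (x[2:INT32_MAX] == x[2:]).all()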
267 changes: 239 additions & 28 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "onnx2kerastl"
version = "0.0.164"
version = "0.0.165"
description = ""
authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
license = "MIT"
@@ -17,6 +17,7 @@ fvcore = "^0.1.5.post20221221"
boto3 = "^1.24.22"
tensorflow-io-gcs-filesystem = "0.34.0"
keras-data-format-converter = "0.1.22"
optimum = "1.23.3"

[tool.poetry.dev-dependencies]
pytest = "^7.1.2"
57 changes: 57 additions & 0 deletions test/models/test_llama_sentiment_analysis.py
@@ -0,0 +1,57 @@
+import os.path
+
+import onnx
+import pytest
+import tensorflow as tf
+from transformers import AutoTokenizer
+import numpy as np
+from onnx2kerastl import onnx_to_keras
+from keras_data_format_converter import convert_channels_first_to_last
+from onnx2kerastl.customonnxlayer import onnx_custom_objects_map
+from test.utils import export_torch_to_onnx_optimum
+
+
+@pytest.mark.skip(reason="Fails on CI but works locally (the model might be too large)")
+def test_llama_32_1b_inst():
+    onnx_model_folder = 'onnx_model'
+    onnx_path = os.path.join(onnx_model_folder, 'model.onnx')
+    model_name = "meta-llama/Llama-3.2-1B-Instruct"
+    # --------------------------------- Export to ONNX -------------------------------------
+    export_torch_to_onnx_optimum(model_name, model_output_path=onnx_model_folder)
+    # ----------------------------------------- Input Preparation --------------------------
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer.pad_token = tokenizer.eos_token
+    text = "i love this movie!"
+    prompt = tokenizer.apply_chat_template(
+        [{"role": "user",
+          "content": f"What is the sentiment of this sentence: \"{text}\"? Respond with 'positive' or 'negative' only."}],
+        add_generation_prompt=True,
+        return_tensors="np"
+    )
+    input_ids = prompt
+    attention_mask = (input_ids != tokenizer.pad_token_id).astype(np.int64)
+    position_ids = np.arange(input_ids.shape[1])[None, :]
+    model_inputs = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "position_ids": position_ids
+    }
+    keras_inputs = {k: tf.convert_to_tensor(v) for k, v in model_inputs.items()}
+    # --------------------------------- Export to Keras -------------------------------------
+    onnx_model = onnx.load(onnx_path)  # TODO: add onnx==1.17.0 to the requirements
+    keras_model = onnx_to_keras(onnx_model, ['input_ids', 'attention_mask', 'position_ids'],
+                                allow_partial_compilation=False)
+    keras_model = keras_model.converted_model
+    flipped_model = convert_channels_first_to_last(keras_model, [])
+    flipped_model.save('temp.h5')
+    model = tf.keras.models.load_model('temp.h5', custom_objects=onnx_custom_objects_map)
+    # --------------------------------- Evaluating Inference -------------------------------------
+    outputs = model(keras_inputs)
+    last_token_logits = outputs[0, -1]
+    pred_token_id = np.argmax(last_token_logits)
+    pred_token = tokenizer.decode([pred_token_id]).strip().lower()
+
+    assert pred_token == 'positive'
+
+
+if __name__ == "__main__":
+    test_llama_32_1b_inst()
1 change: 1 addition & 0 deletions test/requirements.txt
@@ -1,4 +1,5 @@
torch>=1.1.0,<=1.5.0
torchvision>=0.3.0,<=0.6.0
+optimum==1.23.3
pytest
pytest-repeat
31 changes: 31 additions & 0 deletions test/utils.py
@@ -7,6 +7,7 @@

from onnx2kerastl import onnx_to_keras
from onnx2kerastl.utils import check_torch_keras_error
+from optimum.exporters.onnx import main_export

NP_SEED = 42

@@ -55,3 +56,33 @@ def test_conversion(onnx_model, k_model, input_variable, change_ordering=False,

def is_lambda_layers_exist(model: Model):
    return any(isinstance(layer, Lambda) for layer in model.layers)


+def export_torch_to_onnx_optimum(model_name: str, model_output_path: str, task="causal-lm"):
+    """
+    Export a model (a Hugging Face name or a local path) to ONNX: creates the
+    output folder and saves the exported model there, using the optimum library.
+    NOTE: for Llama models the maximum absolute difference of the logits is
+    larger than 1e-5; this should not matter in practice.
+    Args:
+        model_name: model path (local or HF name)
+        model_output_path: output folder path
+        task: model task
+
+    Returns:
+        creates the onnx model in the output folder path
+    """
+    main_export(
+        model_name_or_path=model_name,
+        task=task,
+        output=model_output_path,
+        opset=None,
+        device="cpu",
+        dtype=None,
+        pad_token_id=None,
+        trust_remote_code=False,
+        do_validation=True,
+        framework=None,
+        no_post_process=False,
+        model_kwargs=None,
+        atol=1e-5
+    )
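
A minimal usage sketch mirroring the Llama test above (assumes optimum==1.23.3 is installed and the Hugging Face model is accessible):

# exports onnx_model/model.onnx into the given output folder
export_torch_to_onnx_optimum(
    "meta-llama/Llama-3.2-1B-Instruct",
    model_output_path="onnx_model",
)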