From 87484abb05282501d0fb7add1bc669fb7bc5ecf7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Mon, 20 Oct 2025 17:53:44 +0200
Subject: [PATCH] add one more model to test

---
 .../ut_tasks/test_tasks_text_generation.py    | 33 ++++++++++++++++
 .../hghub/hub_data_cached_configs.py          | 38 +++++++++++++++++++
 2 files changed, 71 insertions(+)
 create mode 100644 _unittests/ut_tasks/test_tasks_text_generation.py

diff --git a/_unittests/ut_tasks/test_tasks_text_generation.py b/_unittests/ut_tasks/test_tasks_text_generation.py
new file mode 100644
index 00000000..8303b788
--- /dev/null
+++ b/_unittests/ut_tasks/test_tasks_text_generation.py
@@ -0,0 +1,33 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    requires_transformers,
+    requires_torch,
+)
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+class TestTasksTextGeneration(ExtTestCase):
+    @hide_stdout()
+    @requires_transformers("4.53")
+    @requires_torch("2.7.99")
+    def test_image_text_to_text_gemma3_for_causallm(self):
+        # Text-generation check (despite the method name): two eager runs, then export.
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        bundle = get_untrained_model_with_inputs(model_id, verbose=1, add_second_input=True)
+        self.assertEqual(bundle["task"], "text-generation")
+        model, inputs, ds = bundle["model"], bundle["inputs"], bundle["dynamic_shapes"]
+        model(**torch_deepcopy(inputs))  # run on a copy: `inputs` is reused by the export
+        model(**bundle["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
+            shapes = use_dyn_not_str(ds)
+            torch.export.export(model, (), kwargs=inputs, dynamic_shapes=shapes, strict=False)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py 
b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
index 2bee6b27..f56317b9 100644
--- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -4865,3 +4865,41 @@ def _ccached_google_gemma_3_4b_it_like():
             },
         }
     )
+
+
+def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
+    "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+    # NOTE(review): the docstring looks like the hub model id; confirm before rewording it.
+    config_kwargs = dict(
+        architectures=["Gemma3ForCausalLM"],
+        attention_bias=false,  # false/null/true: presumably module-level aliases; verify
+        attention_dropout=0.0,
+        attn_logit_softcapping=null,
+        bos_token_id=2,
+        cache_implementation="hybrid",
+        eos_token_id=[1, 106],
+        final_logit_softcapping=null,
+        head_dim=8,
+        hidden_activation="gelu_pytorch_tanh",
+        hidden_size=16,
+        initializer_range=0.02,
+        intermediate_size=32,
+        max_position_embeddings=32768,
+        model_type="gemma3_text",
+        num_attention_heads=2,
+        num_hidden_layers=2,
+        num_key_value_heads=1,
+        pad_token_id=0,
+        query_pre_attn_scalar=256,
+        rms_norm_eps=1e-06,
+        rope_local_base_freq=10000,
+        rope_scaling=null,
+        rope_theta=1000000,
+        sliding_window=512,
+        sliding_window_pattern=6,
+        torch_dtype="float32",
+        transformers_version="4.52.0.dev0",
+        use_cache=true,
+        vocab_size=262144,
+    )
+    return transformers.Gemma3TextConfig(**config_kwargs)