33 changes: 33 additions & 0 deletions _unittests/ut_tasks/test_tasks_text_generation.py
@@ -0,0 +1,33 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import (
    ExtTestCase,
    hide_stdout,
    requires_transformers,
    requires_torch,
)
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasksTextGeneration(ExtTestCase):
    @hide_stdout()
    @requires_transformers("4.53")
    @requires_torch("2.7.99")
    def test_text_generation_gemma3_for_causallm(self):
        mid = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "text-generation")
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        model(**torch_deepcopy(inputs))
        model(**data["inputs2"])
        with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
            torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
            )


if __name__ == "__main__":
    unittest.main(verbosity=2)
38 changes: 38 additions & 0 deletions onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -4865,3 +4865,41 @@ def _ccached_google_gemma_3_4b_it_like():
            },
        }
    )


def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
    "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
    return transformers.Gemma3TextConfig(
        **{
            "architectures": ["Gemma3ForCausalLM"],
            "attention_bias": false,
            "attention_dropout": 0.0,
            "attn_logit_softcapping": null,
            "bos_token_id": 2,
            "cache_implementation": "hybrid",
            "eos_token_id": [1, 106],
            "final_logit_softcapping": null,
            "head_dim": 8,
            "hidden_activation": "gelu_pytorch_tanh",
            "hidden_size": 16,
            "initializer_range": 0.02,
            "intermediate_size": 32,
            "max_position_embeddings": 32768,
            "model_type": "gemma3_text",
            "num_attention_heads": 2,
            "num_hidden_layers": 2,
            "num_key_value_heads": 1,
            "pad_token_id": 0,
            "query_pre_attn_scalar": 256,
            "rms_norm_eps": 1e-06,
            "rope_local_base_freq": 10000,
            "rope_scaling": null,
            "rope_theta": 1000000,
            "sliding_window": 512,
            "sliding_window_pattern": 6,
            "torch_dtype": "float32",
            "transformers_version": "4.52.0.dev0",
            "use_cache": true,
            "vocab_size": 262144,
        }
    )