Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ jobs:

- name: Check for errors and warnings
run: |
if [[ $(grep ERROR doc.txt | grep -v 'Unknown target name: "l_shape"' | grep -v 'Unknown target name: "l_x"') ]]; then
if [[ $(grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export') ]]; then
echo "Documentation produces errors."
grep ERROR doc.txt
grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export'
exit 1
fi
if [[ $(grep WARNING doc.txt) ]]; then
if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export') ]]; then
echo "Documentation produces warnings."
grep WARNING doc.txt
grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export'
exit 1
fi

Expand Down
2 changes: 2 additions & 0 deletions _doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ API of onnx_diagnostic
:maxdepth: 1
:caption: submodules

torch_export_patches/index
torch_models/index

.. toctree::
:maxdepth: 1
Expand Down
6 changes: 6 additions & 0 deletions _doc/api/torch_export_patches/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
onnx_diagnostic.torch_export_patches
====================================

.. automodule:: onnx_diagnostic.torch_export_patches
:members:
:no-undoc-members:
12 changes: 12 additions & 0 deletions _doc/api/torch_models/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
onnx_diagnostic.torch_models
============================

.. toctree::
:maxdepth: 1
:caption: submodules

llms

.. automodule:: onnx_diagnostic.torch_models
:members:
:no-undoc-members:
7 changes: 7 additions & 0 deletions _doc/api/torch_models/llms.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

onnx_diagnostic.torch_models.llms
=================================

.. automodule:: onnx_diagnostic.torch_models.llms
:members:
:no-undoc-members:
3 changes: 2 additions & 1 deletion _doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
("py:class", "torch.utils._pytree.Context"),
("py:class", "torch.utils._pytree.KeyEntry"),
("py:class", "torch.utils._pytree.TreeSpec"),
("py:class", "transformers.LlamaConfig"),
("py:class", "transformers.cache_utils.Cache"),
("py:class", "transformers.cache_utils.DynamicCache"),
("py:class", "transformers.cache_utils.MambaCache"),
Expand Down Expand Up @@ -154,7 +155,7 @@
}

if int(os.environ.get("UNITTEST_GOING", "0")):
sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*"
sphinx_gallery_conf["ignore_pattern"] = ".*((tiny_llm)|(dort)|(draft_mode)).*"
elif pv.Version(torch.__version__) < pv.Version("2.8"):
sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*"

Expand Down
139 changes: 24 additions & 115 deletions _doc/examples/plot_export_tiny_llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""
.. _l-plot-tiny-llm-export:

Export LLM with dynamic shapes
==============================

Expand All @@ -15,11 +17,11 @@
We use the dummy example from the model page.
"""

from typing import Any, Dict
import copy
import torch
import transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.cache_helpers import make_dynamic_cache
from onnx_diagnostic.torch_models.llms import get_tiny_llm


MODEL_NAME = "arnir0/Tiny-LLM"
Expand All @@ -30,21 +32,6 @@
# We rewrite the forward method to print the cache dimension.


def string_inputs(args, kwargs):
def _cache(a):
if len(a.key_cache):
return f"n_caches={len(a.key_cache)}, shape={a.key_cache[0].shape}"
return f"n_caches={len(a.key_cache)}"

for a in args:
if isinstance(a, transformers.cache_utils.DynamicCache):
return _cache(a)
for k, a in kwargs.items():
if isinstance(a, transformers.cache_utils.DynamicCache):
return f"{k}={_cache(a)}"
return "no_cache"


def _forward_(*args, _f=None, **kwargs):
assert _f is not None
if not torch.compiler.is_exporting():
Expand Down Expand Up @@ -83,100 +70,6 @@ def _forward_(*args, _f=None, **kwargs):
# Let's create an untrained model.


def get_tiny_llm(
batch_size: int = 2,
input_cache: bool = True,
common_dynamic_shapes: bool = True,
dynamic_rope: bool = False,
**kwargs,
) -> Dict[str, Any]:
"""
Gets a non initialized model.

:param batch_size: batch size
:param input_cache: generate data for this iteration with or without cache
:param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
:param common_dynamic_shapes: if True returns dynamic shapes as well
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
:return: dictionary
"""
import transformers

config = {
"architectures": ["LlamaForCausalLM"],
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 192,
"initializer_range": 0.02,
"intermediate_size": 1024,
"max_position_embeddings": 1024,
"model_type": "llama",
"num_attention_heads": 2,
"num_hidden_layers": 1,
"num_key_value_heads": 1,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
"tie_word_embeddings": False,
"torch_dtype": "float32",
"transformers_version": "4.31.0.dev0",
"use_cache": True,
"vocab_size": 32000,
}

config.update(**kwargs)
conf = transformers.LlamaConfig(**config)
model = transformers.LlamaForCausalLM(conf)
model.eval()

# now the inputs
cache_last_dim = 96
sequence_length = 30
sequence_length2 = 3
num_key_value_heads = 1
max_token_id = config["vocab_size"] - 1
n_layers = config["num_hidden_layers"]

batch = torch.export.Dim("batch", min=1, max=1024)
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
cache_length = torch.export.Dim("cache_length", min=1, max=4096)

shapes = {
"input_ids": {0: batch, 1: seq_length},
"attention_mask": {
0: batch,
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
},
"past_key_values": [
[{0: batch, 2: cache_length} for _ in range(n_layers)],
[{0: batch, 2: cache_length} for _ in range(n_layers)],
],
}
inputs = dict(
input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
torch.int64
),
attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
torch.int64
),
past_key_values=make_dynamic_cache(
[
(
torch.randn(
batch_size, num_key_value_heads, sequence_length, cache_last_dim
),
torch.randn(
batch_size, num_key_value_heads, sequence_length, cache_last_dim
),
)
for i in range(n_layers)
]
),
)
return dict(inputs=inputs, model=model, dynamic_shapes=shapes)


# %%
# Let's get the model, inputs and dynamic shapes.

Expand All @@ -187,9 +80,25 @@ def get_tiny_llm(
experiment["dynamic_shapes"],
)

# %%
# Before we run it, we make a copy of the inputs as the cache
# get modified by the execution. Then it is no longer valid
# associated with the previous input_ids and mask.
cloned_inputs = copy.deepcopy(inputs)


# %% Let's run it.
expected_output = model(**inputs)
print("result type", type(expected_output))
print("input type", string_type(inputs, with_shape=True))

expected_output = untrained_model(**inputs)


print("input after the execution", string_type(inputs, with_shape=True))
print("result type", string_type(expected_output, with_shape=True))

ep = torch.export.export(
untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
)

# %%
# It works.
Expand All @@ -199,7 +108,7 @@ def get_tiny_llm(

try:
ep = torch.export.export(
untrained_model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False
untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
)
print("It worked:")
print(ep)
Expand All @@ -217,7 +126,7 @@ def get_tiny_llm(
# Let's use the same dummy inputs but we use the downloaded model.

try:
ep = torch.export.export(model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False)
ep = torch.export.export(model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes)
print("It worked:")
print(ep)
except Exception as e:
Expand Down
7 changes: 0 additions & 7 deletions _doc/galleries.rst

This file was deleted.

2 changes: 1 addition & 1 deletion _doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Source are `sdpython/onnx-diagnostic
:caption: Contents

api/index
galleries
auto_examples/index

.. toctree::
:maxdepth: 1
Expand Down
Loading
Loading