67 changes: 53 additions & 14 deletions _doc/technical/plot_layer_norm_discrepancies.py
@@ -6,21 +6,26 @@
:ref:`l-plot-parallelized-reduction`, reduction operations
are sensitive to parallelization.

We consider a small model made of a layer normalization
followed by a matrix multiplication and show that replacing
one kernel with another may significantly impact the output.

Methodology
+++++++++++

We consider a simple model with a LayerNormalization followed by a MatMul.
Each operator can be run with :epkg:`onnxruntime` or :epkg:`pytorch`.
We compare all four combinations; a standalone sketch of the measured
quantity follows the imports below.

The model
+++++++++
"""

import itertools
import numpy as np
import pandas
import onnx
import onnx.helper as oh
import onnxruntime
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic.doc import rotate_align, save_fig, plot_histogram, title
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name, onnx_dtype_to_np_dtype
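For reference, here is a minimal sketch, not part of the diff, of the quantity the
example measures, ``LayerNorm(X) @ W + B``, against a float64 baseline; the shapes,
the seed and ``eps`` are illustrative assumptions.

import numpy as np
import torch

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8, 64)).astype(np.float32)
W = rng.standard_normal((64, 64)).astype(np.float32)
B = rng.standard_normal((64,)).astype(np.float32)

def layer_norm_ref(x, eps=1e-5):
    # float64 reference: normalize over the last axis
    x = x.astype(np.float64)
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

expected = layer_norm_ref(X) @ W.astype(np.float64) + B.astype(np.float64)
got = (
    torch.nn.functional.layer_norm(torch.from_numpy(X), (64,)) @ torch.from_numpy(W)
    + torch.from_numpy(B)
)
print("max abs diff vs float64 baseline:", np.abs(expected - got.numpy()).max())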
@@ -79,6 +84,8 @@ def make_feeds(last_dim: int):


def cast_feeds(itype, provider, feeds):
ttype = onnx_dtype_to_torch_dtype(itype)
np_dtype = onnx_dtype_to_np_dtype(itype)
np_feeds = {k: v.detach().numpy() for k, v in feeds.items()}
if provider == "CUDA":
if not torch.cuda.is_available():
@@ -101,8 +108,6 @@ def cast_feeds(itype, provider, feeds):
baseline = {}

for provider, itype in itertools.product(["CPU", "CUDA"], [TFLOAT, TFLOAT16]):
ttype = onnx_dtype_to_torch_dtype(itype)
np_dtype = onnx_dtype_to_np_dtype(itype)
tch_feeds, ort_feeds = cast_feeds(itype, provider, feeds)
if tch_feeds is None:
continue
@@ -143,13 +148,34 @@ def cast_feeds(itype, provider, feeds):
# %%
# Visually.

df["abs"].plot.bar(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
save_fig(
rotate_align(
df[["abs"]].plot.bar(title="Discrepancies ORT / torch for LayerNorm(X) @ W + B")
),
"plot_layer_norm_discrepancies_1.png",
)

# %%
# The discrepancies are significant on CUDA, and higher with float16.
# Let's see which operator is responsible for them,
# *LayerNormalization* or *MatMul*.
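The float16 sensitivity is consistent with reduction order: two summation
schedules of the same float16 vector can round differently. A small
illustration, independent of the model (the vector size is arbitrary):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(2**14).astype(np.float16)
s1 = x.sum(dtype=np.float16)  # one accumulation schedule
s2 = x.reshape(16, -1).sum(axis=1, dtype=np.float16).sum(dtype=np.float16)  # another
print(s1, s2, "float64 reference:", x.astype(np.float64).sum())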

# %%
# Distribution of the results
# +++++++++++++++++++++++++++

tensor = baseline[TFLOAT16, "CPU", "ort"][0].ravel().astype(np.float32)
print(pandas.DataFrame({"expected": tensor}).describe())

# %%
# Histogram.

save_fig(
title(plot_histogram(tensor), "Distribution of the computed results"),
"plot_layer_norm_discrepancies_hist.png",
)


# %%
# Where do the discrepancies come from?
# ++++++++++++++++++++++++++++++++++++++
@@ -159,19 +185,18 @@ def cast_feeds(itype, provider, feeds):
data = []

for mod, provider, itype in itertools.product(
["ORT-TORCH", "TORCH-ORT"], ["CPU", "CUDA"], [TFLOAT, TFLOAT16]
["ORT-ORT", "ORT-TORCH", "TORCH-ORT", "TORCH-TORCH"], ["CPU", "CUDA"], [TFLOAT, TFLOAT16]
):
ttype = onnx_dtype_to_torch_dtype(itype)
np_dtype = onnx_dtype_to_np_dtype(itype)
tch_feeds, _ = cast_feeds(itype, provider, feeds)
if tch_feeds is None:
continue

ker1, ker2 = mod.split("-")
custom_kernels = (
{("", "LayerNormalization"): LayerNormalizationOrt}
if mod == "ORT-TORCH"
else {("", "MatMul"): MatMulOrt}
)
{("", "LayerNormalization"): LayerNormalizationOrt} if ker1 == "ORT" else {}
) | ({("", "MatMul"): MatMulOrt} if ker2 == "ORT" else {})

model = get_model(itype)
print()
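The ``custom_kernels`` expression above relies on the dict union operator,
available since Python 3.9; a two-line sketch of its semantics:

a = {("", "LayerNormalization"): "ORT kernel"}
b = {("", "MatMul"): "ORT kernel"}
print(a | b)  # keys from both dicts, the right-hand side winning on clashes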
@@ -200,13 +225,27 @@ def cast_feeds(itype, provider, feeds):
)

# %%
df = pandas.DataFrame(data).set_index(["model", "provider", "dtype"])
df = pandas.DataFrame(data).set_index(["dtype", "provider", "model"])
df = df.sort_index()
print(df)

# %%
# Visually.

df[["diff_ort", "diff_torch"]].plot.bar(
title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B"
save_fig(
rotate_align(
df[["diff_ort", "diff_torch"]].plot.bar(
title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B",
figsize=(10, 4),
)
),
"plot_layer_norm_discrepancies_2.png",
)

# %%
# Conclusion
# ++++++++++
#
# :epkg:`torch` seems able to reproduce the same results when the same
# computation is run multiple times; :epkg:`onnxruntime` only achieves this
# on CUDA. With float16 on CUDA, LayerNormalization seems to introduce
# some discrepancies.
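A sketch of the repeatability check this conclusion refers to; the model path
``model.onnx`` and the input name ``X`` are placeholders, not taken from the diff:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
x = np.random.default_rng(0).standard_normal((2, 8, 64)).astype(np.float16)
r1 = sess.run(None, {"X": x})[0]
r2 = sess.run(None, {"X": x})[0]
print("bitwise identical across two runs:", np.array_equal(r1, r2))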
7 changes: 7 additions & 0 deletions _doc/technical/plot_parallelized_reduction.py
@@ -23,6 +23,12 @@

With :math:`\\mathbb{E}X = mean(X)`,
:math:`\\mathbb{V}X = mean\\left(\\left(X - mean(X)\\right)^2\\right)`.

Methodology
+++++++++++

**Permutation should not change the average.**

We draw 128 random permutations of X. The average should not change,
and the normalized vector should keep the same values. In the first case, we compute
the difference between the highest and the lowest value obtained for the average.
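A condensed sketch of that test, assuming a float32 vector of arbitrary size:
permuting the vector changes its computed mean because pairwise summation is
order dependent.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(2**20).astype(np.float32)
means = [rng.permutation(x).mean(dtype=np.float32) for _ in range(128)]
print("spread of the mean over 128 permutations:", max(means) - min(means))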
@@ -188,6 +194,7 @@ def make_value(base, value):
# Visually.

ax = df.plot.bar(logy=True)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
fig = ax.get_figure()
fig.savefig("plot_parallelized_reduction.png")

13 changes: 13 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_rewrite.py
@@ -0,0 +1,13 @@
import unittest
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting


class TestPatchRewrite(ExtTestCase):
def test_code_needing_rewriting(self):
res = code_needing_rewriting("BartModel")
self.assertEqual(len(res), 2)


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion _unittests/ut_torch_models/test_test_helpers.py
@@ -176,7 +176,7 @@ def test_validate_model_custom_torch(self):
mid,
do_run=True,
verbose=10,
exporter="custom-inline",
exporter="custom-noinline",
dump_folder="dump_test_validate_model_custom_torch",
patch=True,
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
30 changes: 0 additions & 30 deletions k.py

This file was deleted.

46 changes: 46 additions & 0 deletions onnx_diagnostic/doc.py
@@ -1,3 +1,7 @@
from typing import Optional
import numpy as np


def reset_torch_transformers(gallery_conf, fname):
"Resets torch dynamo for :epkg:`sphinx-gallery`."
import matplotlib.pyplot as plt
@@ -30,3 +34,45 @@ def plot_legend(
ax.grid(False)
ax.set_axis_off()
return ax


def rotate_align(ax, angle=15, align="right"):
"""Rotates x-label and align them to thr right. Returns ax."""
for label in ax.get_xticklabels():
label.set_rotation(angle)
label.set_horizontalalignment(align)
return ax


def save_fig(ax, name: str):
"""Applies ``tight_layout`` and saves the figures. Returns ax."""
import matplotlib.pyplot as plt

plt.tight_layout()
fig = ax.get_figure()
fig.savefig(name)
return ax


def title(ax: "plt.axes", title: str) -> "plt.axes": # noqa: F821
"Adds a title to axes and returns them."
ax.set_title(title)
return ax


def plot_histogram(
tensor: np.ndarray,
ax: Optional["plt.axes"] = None, # noqa: F821
bins: int = 30,
color: str = "orange",
alpha: float = 0.7,
) -> "plt.axes": # noqa: F821
"Computes the distribution for a tensor."
if ax is None:
import matplotlib.pyplot as plt

ax = plt.gca()
ax.cla()
ax.hist(tensor, bins=bins, color=color, alpha=alpha)
ax.set_yscale("log")
return ax
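These helpers compose the way the example script above uses them; a minimal
usage sketch, assuming matplotlib is installed:

import numpy as np
from onnx_diagnostic.doc import plot_histogram, save_fig, title

t = np.random.default_rng(0).standard_normal(10_000).astype(np.float32)
save_fig(title(plot_histogram(t), "distribution of a random tensor"), "demo_hist.png")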
30 changes: 25 additions & 5 deletions onnx_diagnostic/helpers/doc_helper.py
@@ -1,19 +1,31 @@
from typing import Dict, Optional, Tuple
import os
from typing import Dict, List, Optional, Tuple
import onnx
import onnx.helper as oh
import torch
from ..reference.torch_ops import OpRunKernel, OpRunTensor
from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
from .ort_session import InferenceSessionForTorch

_SAVED: List[str] = []
_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))


def _get_model_name(op_name: str, provider: str) -> Optional[str]:
if _SAVE_OPTIMIZED_MODEL_:
name = f"dump_doc_layer_norm_{provider}_{len(_SAVED)}.onnx"
_SAVED.append(name)
return name
return None
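# Usage sketch (an assumption based on the variable above, not shown in the
# diff): dumping is enabled through the environment, e.g.
#   DUMP_ONNX=1 python plot_layer_norm_discrepancies.py
# and every session then writes its optimized model to
# ``dump_doc_<op_name>_<provider>_<n>.onnx`` in the working directory.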


class LayerNormalizationOrt(OpRunKernel):
"LayerNormalization with onnxruntime"

@classmethod
def device_dependent(cls) -> bool:
"Needs device."
return False
return True

def __init__(
self,
@@ -70,7 +82,11 @@ def _make_model(self, itype: int, rank: int, has_bias: bool) -> onnx.ModelProto:
)
provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
self._provider = provider
return InferenceSessionForTorch(layer_model, providers=[provider])
return InferenceSessionForTorch(
layer_model,
optimized_model_filepath=_get_model_name("layer_norm", provider),
providers=[provider],
)

def run(self, x, scale, bias=None):
itype = torch_dtype_to_onnx_dtype(x.dtype)
@@ -94,7 +110,7 @@ class MatMulOrt(OpRunKernel):
@classmethod
def device_dependent(cls) -> bool:
"Needs device."
return False
return True

def __init__(
self,
@@ -127,7 +143,11 @@ def _make_model(self, itype: int, ranka: int, rankb: int) -> onnx.ModelProto:
)
provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
self._provider = provider
return InferenceSessionForTorch(model, providers=[provider])
return InferenceSessionForTorch(
model,
optimized_model_filepath=_get_model_name("matmul", provider),
providers=[provider],
)

def run(self, a, b):
itype = torch_dtype_to_onnx_dtype(a.dtype)
@@ -80,6 +80,7 @@ def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
"AutoformerModel": "AutoformerEncoderLayer",
"BartEncoderLayer": "BartEncoderLayer",
"BartForConditionalGeneration": "BartEncoderLayer",
"BartModel": "BartEncoderLayer",
"BigBirdPegasusForConditionalGeneration": "BigBirdPegasusEncoderLayer",
"BigBirdPegasusForQuestionAnswering": "BigBirdPegasusEncoderLayer",
"BigBirdPegasusForCausalLM": "BigBirdPegasusEncoderLayer",
27 changes: 16 additions & 11 deletions onnx_diagnostic/torch_models/test_helper.py
@@ -387,6 +387,12 @@ def validate_model(
if model_options:
print(f"[validate_model] model_options={model_options!r}")
print(f"[validate_model] get dummy inputs with input_options={input_options}...")
print(
f"[validate_model] rewrite={rewrite}, patch={patch}, "
f"stop_if_static={stop_if_static}"
)
print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
print(f"[validate_model] dump_folder={dump_folder!r}")
summary["model_id"] = model_id
summary["model_subfolder"] = subfolder or ""

@@ -446,6 +452,8 @@ def validate_model(
print(f"[validate_model] model_rewrite={summary['model_rewrite']}")
else:
del data["rewrite"]
if verbose:
print("[validate_model] no rewrite")
if os.environ.get("PRINT_CONFIG", "0") in (1, "1"):
print("[validate_model] -- PRINT CONFIG")
print("-- type(config)", type(data["configuration"]))
@@ -1334,13 +1342,13 @@ def call_torch_export_custom(
"custom-nostrict",
"custom-nostrict-default",
"custom-nostrict-all",
"custom-inline",
"custom-strict-inline",
"custom-strict-default-inline",
"custom-strict-all-inline",
"custom-nostrict-inline",
"custom-nostrict-default-inline",
"custom-nostrict-all-inline",
"custom-noinline",
"custom-strict-noinline",
"custom-strict-default-noinline",
"custom-strict-all-noinline",
"custom-nostrict-noinline",
"custom-nostrict-default-noinline",
"custom-nostrict-all-noinline",
}
assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -1381,10 +1389,7 @@ def call_torch_export_custom(
),
save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
)
inline = "-inline" in exporter
if inline:
export_options.aten_as_function = set()

inline = "-noinline" not in exporter
options = OptimizationOptions(patterns=optimization) if optimization else None
model = data["model"]
kws = dict(