From 0e11b10b0a482f0699f1507c8cbfc53315a97f26 Mon Sep 17 00:00:00 2001 From: Quentin Soubeyran Date: Tue, 17 Oct 2023 14:51:46 +0200 Subject: [PATCH] fix #7444 with implementation benchmark --- pydantic/main.py | 75 ++- tests/benchmarks/basemodel_eq_performance.py | 614 +++++++++++++++++++ 2 files changed, 685 insertions(+), 4 deletions(-) create mode 100644 tests/benchmarks/basemodel_eq_performance.py diff --git a/pydantic/main.py b/pydantic/main.py index de9bb24bcd7..55b7c1f246f 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -1,11 +1,13 @@ """Logic for creating models.""" from __future__ import annotations as _annotations +import enum +import operator import types import typing import warnings from copy import copy, deepcopy -from typing import Any, ClassVar +from typing import Any, ClassVar, Generic, Mapping, TypeVar import pydantic_core import typing_extensions @@ -56,6 +58,42 @@ _object_setattr = _model_construction.object_setattr +_K = TypeVar('_K') +_V = TypeVar('_V') + + +# We need a sentinel value for missing fields when comparing models +# Models are equals if-and-only-if they miss the same fields, and since None is a legitimate value +# we can't default to None +# We use the single-value enum trick to allow correct typing when using a sentinel +class _SentinelType(enum.Enum): + SENTINEL = enum.auto() + + +_SENTINEL = _SentinelType.SENTINEL + + +class _SafeGetItemProxy(Generic[_K, _V]): + """Wrapper redirecting `__getitem__` to `get` with a sentinel value as default + + This makes is safe to use in `operator.itemgetter` when some keys may be missing + """ + + wrapped: Mapping[_K, _V] + + def __init__(self, __mapping: Mapping[_K, _V]) -> None: + self.wrapped = __mapping + + def __getitem__(self, __key: _K) -> _SentinelType | _V: + return self.wrapped.get(__key, _SENTINEL) + + # required to pass the proxy to operator.itemgetter instances + def __contains__(self, __key: _K) -> bool: + return self.wrapped.__contains__(__key) + + def __repr__(self) -> str: + return f'{type(self).__name__}({self.wrapped!r})' + class BaseModel(metaclass=_model_construction.ModelMetaclass): """Usage docs: https://docs.pydantic.dev/2.4/concepts/models/ @@ -866,12 +904,41 @@ def __eq__(self, other: Any) -> bool: self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ - return ( + # Perform common checks first + if not ( self_type == other_type - and self.__dict__ == other.__dict__ and self.__pydantic_private__ == other.__pydantic_private__ and self.__pydantic_extra__ == other.__pydantic_extra__ - ) + ): + return False + + # Fix GH-7444 by comparing only pydantic fields + # We provide a fast-path for performance: __dict__ comparison is *much* faster + # See tests/benchmarks/test_basemodel_eq_performances.py and GH-7825 for benchmarks + if self.__dict__ == other.__dict__: + # If the check above passes, then pydantic fields are equal, we can return early + return True + else: + # Else, we need to perform a more detailed, costlier comparison + # We use operator.itemgetter because it is much faster than dict comprehensions + # NOTE: Contratry to standard python class and instances, when the Model class has + # attribute default values and the model instance doesn't has a corresponding + # attribute, accessing the missing attribute raises an error in + # __getattr__ instance of returning the class attribute + # Thus, using operator.itemgetter() instead of operator.attrgetter() is valid + model_fields = type(self).model_fields.keys() + getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _SENTINEL + try: + return getter(self.__dict__) == getter(other.__dict__) + except KeyError: + # In rare cases (such as when using the deprecated BaseModel.copy() method), + # the __dict__ may not contain all model fields, which is how we can get here. + # getter(self.__dict__) is much faster than any 'safe' method that accounts + # for missing keys, and wrapping it in a `try` doesn't slow things down much + # in the common case. + self_fields_proxy = _SafeGetItemProxy(self.__dict__) + other_fields_proxy = _SafeGetItemProxy(other.__dict__) + return getter(self_fields_proxy) == getter(other_fields_proxy) else: return NotImplemented # delegate to the other item in the comparison diff --git a/tests/benchmarks/basemodel_eq_performance.py b/tests/benchmarks/basemodel_eq_performance.py new file mode 100644 index 00000000000..090822202f6 --- /dev/null +++ b/tests/benchmarks/basemodel_eq_performance.py @@ -0,0 +1,614 @@ +from __future__ import annotations + +import dataclasses +import enum +import gc +import itertools +import operator +import sys +import textwrap +import timeit +from importlib import metadata +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Sized, TypeVar + +# Do not import additional dependencies at top-level +if TYPE_CHECKING: + import matplotlib.pyplot as plt + import numpy as np + from matplotlib import axes, figure + +import pydantic + +PYTHON_VERSION = '.'.join(map(str, sys.version_info)) +PYDANTIC_VERSION = metadata.version('pydantic') + + +# New implementation of pydantic.BaseModel.__eq__ to test + + +class OldImplementationModel(pydantic.BaseModel, frozen=True): + def __eq__(self, other: Any) -> bool: + if isinstance(other, pydantic.BaseModel): + # When comparing instances of generic types for equality, as long as all field values are equal, + # only require their generic origin types to be equal, rather than exact type equality. + # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). + self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ + other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ + + return ( + self_type == other_type + and self.__dict__ == other.__dict__ + and self.__pydantic_private__ == other.__pydantic_private__ + and self.__pydantic_extra__ == other.__pydantic_extra__ + ) + else: + return NotImplemented # delegate to the other item in the comparison + + +class DictComprehensionEqModel(pydantic.BaseModel, frozen=True): + def __eq__(self, other: Any) -> bool: + if isinstance(other, pydantic.BaseModel): + # When comparing instances of generic types for equality, as long as all field values are equal, + # only require their generic origin types to be equal, rather than exact type equality. + # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). + self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ + other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ + + field_names = type(self).model_fields.keys() + + return ( + self_type == other_type + and ({k: self.__dict__[k] for k in field_names} == {k: other.__dict__[k] for k in field_names}) + and self.__pydantic_private__ == other.__pydantic_private__ + and self.__pydantic_extra__ == other.__pydantic_extra__ + ) + else: + return NotImplemented # delegate to the other item in the comparison + + +class ItemGetterEqModel(pydantic.BaseModel, frozen=True): + def __eq__(self, other: Any) -> bool: + if isinstance(other, pydantic.BaseModel): + # When comparing instances of generic types for equality, as long as all field values are equal, + # only require their generic origin types to be equal, rather than exact type equality. + # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). + self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ + other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ + + model_fields = type(self).model_fields.keys() + getter = operator.itemgetter(*model_fields) if model_fields else lambda _: None + + return ( + self_type == other_type + and getter(self.__dict__) == getter(other.__dict__) + and self.__pydantic_private__ == other.__pydantic_private__ + and self.__pydantic_extra__ == other.__pydantic_extra__ + ) + else: + return NotImplemented # delegate to the other item in the comparison + + +class ItemGetterEqModelFastPath(pydantic.BaseModel, frozen=True): + def __eq__(self, other: Any) -> bool: + if isinstance(other, pydantic.BaseModel): + # When comparing instances of generic types for equality, as long as all field values are equal, + # only require their generic origin types to be equal, rather than exact type equality. + # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). + self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ + other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ + + # Perform common checks first + if not ( + self_type == other_type + and self.__pydantic_private__ == other.__pydantic_private__ + and self.__pydantic_extra__ == other.__pydantic_extra__ + ): + return False + + # Fix GH-7444 by comparing only pydantic fields + # We provide a fast-path for performance: __dict__ comparison is *much* faster + # See tests/benchmarks/test_basemodel_eq_performances.py and GH-7825 for benchmarks + if self.__dict__ == other.__dict__: + # If the check above passes, then pydantic fields are equal, we can return early + return True + else: + # Else, we need to perform a more detailed, costlier comparison + model_fields = type(self).model_fields.keys() + getter = operator.itemgetter(*model_fields) if model_fields else lambda _: None + return getter(self.__dict__) == getter(other.__dict__) + else: + return NotImplemented # delegate to the other item in the comparison + + +K = TypeVar('K') +V = TypeVar('V') + + +# We need a sentinel value for missing fields when comparing models +# Models are equals if-and-only-if they miss the same fields, and since None is a legitimate value +# we can't default to None +# We use the single-value enum trick to allow correct typing when using a sentinel +class _SentinelType(enum.Enum): + SENTINEL = enum.auto() + + +_SENTINEL = _SentinelType.SENTINEL + + +@dataclasses.dataclass +class _SafeGetItemProxy(Generic[K, V]): + """Wrapper redirecting `__getitem__` to `get` and a sentinel value + + This makes is safe to use in `operator.itemgetter` when some keys may be missing + """ + + wrapped: dict[K, V] + + def __getitem__(self, __key: K) -> V | _SentinelType: + return self.wrapped.get(__key, _SENTINEL) + + def __contains__(self, __key: K) -> bool: + return self.wrapped.__contains__(__key) + + +class SafeItemGetterEqModelFastPath(pydantic.BaseModel, frozen=True): + def __eq__(self, other: Any) -> bool: + if isinstance(other, pydantic.BaseModel): + # When comparing instances of generic types for equality, as long as all field values are equal, + # only require their generic origin types to be equal, rather than exact type equality. + # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). + self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ + other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ + + # Perform common checks first + if not ( + self_type == other_type + and self.__pydantic_private__ == other.__pydantic_private__ + and self.__pydantic_extra__ == other.__pydantic_extra__ + ): + return False + + # Fix GH-7444 by comparing only pydantic fields + # We provide a fast-path for performance: __dict__ comparison is *much* faster + # See tests/benchmarks/test_basemodel_eq_performances.py and GH-7825 for benchmarks + if self.__dict__ == other.__dict__: + # If the check above passes, then pydantic fields are equal, we can return early + return True + else: + # Else, we need to perform a more detailed, costlier comparison + model_fields = type(self).model_fields.keys() + getter = operator.itemgetter(*model_fields) if model_fields else lambda _: None + return getter(_SafeGetItemProxy(self.__dict__)) == getter(_SafeGetItemProxy(other.__dict__)) + else: + return NotImplemented # delegate to the other item in the comparison + + +class ItemGetterEqModelFastPathFallback(pydantic.BaseModel, frozen=True): + def __eq__(self, other: Any) -> bool: + if isinstance(other, pydantic.BaseModel): + # When comparing instances of generic types for equality, as long as all field values are equal, + # only require their generic origin types to be equal, rather than exact type equality. + # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). + self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ + other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ + + # Perform common checks first + if not ( + self_type == other_type + and self.__pydantic_private__ == other.__pydantic_private__ + and self.__pydantic_extra__ == other.__pydantic_extra__ + ): + return False + + # Fix GH-7444 by comparing only pydantic fields + # We provide a fast-path for performance: __dict__ comparison is *much* faster + # See tests/benchmarks/test_basemodel_eq_performances.py and GH-7825 for benchmarks + if self.__dict__ == other.__dict__: + # If the check above passes, then pydantic fields are equal, we can return early + return True + else: + # Else, we need to perform a more detailed, costlier comparison + model_fields = type(self).model_fields.keys() + getter = operator.itemgetter(*model_fields) if model_fields else lambda _: None + try: + return getter(self.__dict__) == getter(other.__dict__) + except KeyError: + return getter(_SafeGetItemProxy(self.__dict__)) == getter(_SafeGetItemProxy(other.__dict__)) + else: + return NotImplemented # delegate to the other item in the comparison + + +IMPLEMENTATIONS = { + # Commented out because it is too slow for benchmark to complete in reasonable time + # "dict comprehension": DictComprehensionEqModel, + 'itemgetter': ItemGetterEqModel, + 'itemgetter+fastpath': ItemGetterEqModelFastPath, + # Commented-out because it is too slow to run with run_benchmark_random_unequal + #'itemgetter+safety+fastpath': SafeItemGetterEqModelFastPath, + 'itemgetter+fastpath+safe-fallback': ItemGetterEqModelFastPathFallback, +} + +# Benchmark running & plotting code + + +def plot_all_benchmark( + bases: dict[str, type[pydantic.BaseModel]], + sizes: list[int], +) -> figure.Figure: + import matplotlib.pyplot as plt + + n_rows, n_cols = len(BENCHMARKS), 2 + fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 6, n_rows * 4)) + + for row, (name, benchmark) in enumerate(BENCHMARKS.items()): + for col, mimic_cached_property in enumerate([False, True]): + plot_benchmark( + f'{name}, {mimic_cached_property=}', + benchmark, + bases=bases, + sizes=sizes, + mimic_cached_property=mimic_cached_property, + ax=axes[row, col], + ) + for ax in axes.ravel(): + ax.legend() + fig.suptitle(f'python {PYTHON_VERSION}, pydantic {PYDANTIC_VERSION}') + return fig + + +def plot_benchmark( + title: str, + benchmark: Callable, + bases: dict[str, type[pydantic.BaseModel]], + sizes: list[int], + mimic_cached_property: bool, + ax: axes.Axes | None = None, +): + import matplotlib.pyplot as plt + import numpy as np + + ax = ax or plt.gca() + arr_sizes = np.asarray(sizes) + + baseline = benchmark( + title=f'{title}, baseline', + base=OldImplementationModel, + sizes=sizes, + mimic_cached_property=mimic_cached_property, + ) + ax.plot(sizes, baseline / baseline, label='baseline') + for name, base in bases.items(): + times = benchmark( + title=f'{title}, {name}', + base=base, + sizes=sizes, + mimic_cached_property=mimic_cached_property, + ) + mask_valid = ~np.isnan(times) + ax.plot(arr_sizes[mask_valid], times[mask_valid] / baseline[mask_valid], label=name) + + ax.set_title(title) + ax.set_xlabel('Number of pydantic fields') + ax.set_ylabel('Average time relative to baseline') + return ax + + +class SizedIterable(Sized, Iterable): + pass + + +def run_benchmark_nodiff( + title: str, + base: type[pydantic.BaseModel], + sizes: SizedIterable, + mimic_cached_property: bool, + n_execution: int = 10_000, + n_repeat: int = 5, +) -> np.ndarray: + setup = textwrap.dedent( + """ + import pydantic + + Model = pydantic.create_model( + "Model", + __base__=Base, + **{f'x{i}': (int, i) for i in range(%(size)d)} + ) + left = Model() + right = Model() + """ + ) + if mimic_cached_property: + # Mimic functools.cached_property editing __dict__ + # NOTE: we must edit both objects, otherwise the dict don't have the same size and + # dict.__eq__ has a very fast path. This makes our timing comparison incorrect + # However, the value must be different, otherwise *our* __dict__ == right.__dict__ + # fast-path prevents our correct code from running + setup += textwrap.dedent( + """ + object.__setattr__(left, 'cache', None) + object.__setattr__(right, 'cache', -1) + """ + ) + statement = 'left == right' + namespace = {'Base': base} + return run_benchmark( + title, + setup=setup, + statement=statement, + n_execution=n_execution, + n_repeat=n_repeat, + globals=namespace, + params={'size': sizes}, + ) + + +def run_benchmark_first_diff( + title: str, + base: type[pydantic.BaseModel], + sizes: SizedIterable, + mimic_cached_property: bool, + n_execution: int = 10_000, + n_repeat: int = 5, +) -> np.ndarray: + setup = textwrap.dedent( + """ + import pydantic + + Model = pydantic.create_model( + "Model", + __base__=Base, + **{f'x{i}': (int, i) for i in range(%(size)d)} + ) + left = Model() + right = Model(f0=-1) if %(size)d > 0 else Model() + """ + ) + if mimic_cached_property: + # Mimic functools.cached_property editing __dict__ + # NOTE: we must edit both objects, otherwise the dict don't have the same size and + # dict.__eq__ has a very fast path. This makes our timing comparison incorrect + # However, the value must be different, otherwise *our* __dict__ == right.__dict__ + # fast-path prevents our correct code from running + setup += textwrap.dedent( + """ + object.__setattr__(left, 'cache', None) + object.__setattr__(right, 'cache', -1) + """ + ) + statement = 'left == right' + namespace = {'Base': base} + return run_benchmark( + title, + setup=setup, + statement=statement, + n_execution=n_execution, + n_repeat=n_repeat, + globals=namespace, + params={'size': sizes}, + ) + + +def run_benchmark_last_diff( + title: str, + base: type[pydantic.BaseModel], + sizes: SizedIterable, + mimic_cached_property: bool, + n_execution: int = 10_000, + n_repeat: int = 5, +) -> np.ndarray: + setup = textwrap.dedent( + """ + import pydantic + + Model = pydantic.create_model( + "Model", + __base__=Base, + # shift the range() so that there is a field named size + **{f'x{i}': (int, i) for i in range(1, %(size)d + 1)} + ) + left = Model() + right = Model(f%(size)d=-1) if %(size)d > 0 else Model() + """ + ) + if mimic_cached_property: + # Mimic functools.cached_property editing __dict__ + # NOTE: we must edit both objects, otherwise the dict don't have the same size and + # dict.__eq__ has a very fast path. This makes our timing comparison incorrect + # However, the value must be different, otherwise *our* __dict__ == right.__dict__ + # fast-path prevents our correct code from running + setup += textwrap.dedent( + """ + object.__setattr__(left, 'cache', None) + object.__setattr__(right, 'cache', -1) + """ + ) + statement = 'left == right' + namespace = {'Base': base} + return run_benchmark( + title, + setup=setup, + statement=statement, + n_execution=n_execution, + n_repeat=n_repeat, + globals=namespace, + params={'size': sizes}, + ) + + +def run_benchmark_random_unequal( + title: str, + base: type[pydantic.BaseModel], + sizes: SizedIterable, + mimic_cached_property: bool, + n_samples: int = 100, + n_execution: int = 1_000, + n_repeat: int = 5, +) -> np.ndarray: + import numpy as np + + setup = textwrap.dedent( + """ + import pydantic + + Model = pydantic.create_model( + "Model", + __base__=Base, + **{f'x{i}': (int, i) for i in range(%(size)d)} + ) + left = Model() + right = Model(f%(field)d=-1) + """ + ) + if mimic_cached_property: + # Mimic functools.cached_property editing __dict__ + # NOTE: we must edit both objects, otherwise the dict don't have the same size and + # dict.__eq__ has a very fast path. This makes our timing comparison incorrect + # However, the value must be different, otherwise *our* __dict__ == right.__dict__ + # fast-path prevents our correct code from running + setup += textwrap.dedent( + """ + object.__setattr__(left, 'cache', None) + object.__setattr__(right, 'cache', -1) + """ + ) + statement = 'left == right' + namespace = {'Base': base} + arr_sizes = np.fromiter(sizes, dtype=int) + mask_valid_sizes = arr_sizes > 0 + arr_valid_sizes = arr_sizes[mask_valid_sizes] # we can't support 0 when sampling the field + rng = np.random.default_rng() + arr_fields = rng.integers(arr_valid_sizes, size=(n_samples, *arr_valid_sizes.shape)) + # broadcast the sizes against their sample, so we can iterate on (size, field) tuple + # as parameters of the timing test + arr_size_broadcast, _ = np.meshgrid(arr_valid_sizes, arr_fields[:, 0]) + + results = run_benchmark( + title, + setup=setup, + statement=statement, + n_execution=n_execution, + n_repeat=n_repeat, + globals=namespace, + params={'size': arr_size_broadcast.ravel(), 'field': arr_fields.ravel()}, + ) + times = np.empty(arr_sizes.shape, dtype=float) + times[~mask_valid_sizes] = np.nan + times[mask_valid_sizes] = results.reshape((n_samples, *arr_valid_sizes.shape)).mean(axis=0) + return times + + +BENCHMARKS = { + 'All field equals': run_benchmark_nodiff, + 'First field unequal': run_benchmark_first_diff, + 'Last field unequal': run_benchmark_last_diff, + 'Random unequal field': run_benchmark_random_unequal, +} + + +def run_benchmark( + title: str, + setup: str, + statement: str, + n_execution: int = 10_000, + n_repeat: int = 5, + globals: dict[str, Any] | None = None, + progress_bar: bool = True, + params: dict[str, SizedIterable] | None = None, +) -> np.ndarray: + import numpy as np + import tqdm.auto as tqdm + + namespace = globals or {} + # fast-path + if not params: + length = 1 + packed_params = [()] + else: + length = len(next(iter(params.values()))) + # This iterator yields a tuple of (key, value) pairs + # First, make a list of N iterator over (key, value), where the provided values are iterated + param_pairs = [zip(itertools.repeat(name), value) for name, value in params.items()] + # Then pack our individual parameter iterator into one + packed_params = zip(*param_pairs) + + times = [ + # Take the min of the repeats as recommended by timeit doc + min( + timeit.Timer( + setup=setup % dict(local_params), + stmt=statement, + globals=namespace, + ).repeat(repeat=n_repeat, number=n_execution) + ) + / n_execution + for local_params in tqdm.tqdm(packed_params, desc=title, total=length, disable=not progress_bar) + ] + gc.collect() + + return np.asarray(times, dtype=float) + + +if __name__ == '__main__': + # run with `pdm run tests/benchmarks/test_basemodel_eq_performance.py` + import argparse + import pathlib + + try: + import matplotlib # noqa: F401 + import numpy # noqa: F401 + import tqdm # noqa: F401 + except ImportError as err: + raise ImportError( + 'This benchmark additionally depends on numpy, matplotlib and tqdm. ' + 'Install those in your environment to run the benchmark.' + ) from err + + parser = argparse.ArgumentParser( + description='Test the performance of various BaseModel.__eq__ implementations fixing GH-7444.' + ) + parser.add_argument( + '-o', + '--output-path', + default=None, + type=pathlib.Path, + help=( + 'Output directory or file in which to save the benchmark results. ' + 'If a directory is passed, a default filename is used.' + ), + ) + parser.add_argument( + '--min-n-fields', + type=int, + default=0, + help=('Test the performance of BaseModel.__eq__ on models with at least this number of fields. Defaults to 0.'), + ) + parser.add_argument( + '--max-n-fields', + type=int, + default=100, + help=('Test the performance of BaseModel.__eq__ on models with up to this number of fields. Defaults to 100.'), + ) + + args = parser.parse_args() + + import matplotlib.pyplot as plt + + sizes = list(range(args.min_n_fields, args.max_n_fields)) + fig = plot_all_benchmark(IMPLEMENTATIONS, sizes=sizes) + plt.tight_layout() + + if args.output_path is None: + plt.show() + else: + if args.output_path.suffix: + filepath = args.output_path + else: + filepath = args.output_path / f"eq-benchmark_python-{PYTHON_VERSION.replace('.', '-')}.png" + fig.savefig( + filepath, + dpi=200, + facecolor='white', + transparent=False, + ) + print(f'wrote {filepath!s}', file=sys.stderr)