diff --git a/yt/funcs.py b/yt/funcs.py index 97a3c0f6445..34352e1e5d1 100644 --- a/yt/funcs.py +++ b/yt/funcs.py @@ -25,7 +25,7 @@ import matplotlib import numpy as np from more_itertools import always_iterable, collapse, first -from packaging.version import parse as parse_version +from packaging.version import Version from tqdm import tqdm from yt.units import YTArray, YTQuantity @@ -1039,7 +1039,7 @@ def matplotlib_style_context(style_name=None, after_reset=False): import matplotlib style_name = {"mathtext.fontset": "cm"} - if parse_version(matplotlib.__version__) >= parse_version("3.3.0"): + if Version(matplotlib.__version__) >= Version("3.3.0"): style_name["mathtext.fallback"] = "cm" else: style_name["mathtext.fallback_to_cm"] = True diff --git a/yt/utilities/on_demand_imports.py b/yt/utilities/on_demand_imports.py index ba519191836..ba112a7a3ed 100644 --- a/yt/utilities/on_demand_imports.py +++ b/yt/utilities/on_demand_imports.py @@ -1,6 +1,6 @@ import sys -from packaging.version import parse as parse_version +from packaging.version import Version class NotAModule: @@ -361,7 +361,7 @@ def __init__(self): try: import h5py - if parse_version(h5py.__version__) < parse_version("2.4.0"): + if Version(h5py.__version__) < Version("2.4.0"): self._err = RuntimeError( "yt requires h5py version 2.4.0 or newer, " "please update h5py with e.g. `python -m pip install -U h5py` " diff --git a/yt/visualization/base_plot_types.py b/yt/visualization/base_plot_types.py index 90096e586e4..d4b06bfb227 100644 --- a/yt/visualization/base_plot_types.py +++ b/yt/visualization/base_plot_types.py @@ -1,8 +1,9 @@ +import warnings from io import BytesIO import matplotlib import numpy as np -from packaging.version import parse as parse_version +from packaging.version import Version from yt.funcs import ( get_brewer_cmap, @@ -12,7 +13,7 @@ mylog, ) -from ._commons import get_canvas, validate_image_name +from ._commons import MPL_VERSION, get_canvas, validate_image_name BACKEND_SPECS = { "GTK": ["backend_gtk", "FigureCanvasGTK", "FigureManagerGTK"], @@ -132,9 +133,9 @@ def save(self, name, mpl_kwargs=None, canvas=None): if mpl_kwargs is None: mpl_kwargs = {} - if "papertype" not in mpl_kwargs and parse_version( - matplotlib.__version__ - ) < parse_version("3.3.0"): + if "papertype" not in mpl_kwargs and Version(matplotlib.__version__) < Version( + "3.3.0" + ): mpl_kwargs["papertype"] = "auto" name = validate_image_name(name) @@ -206,20 +207,29 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): vmin=float(self.zmin) if self.zmin is not None else None, vmax=float(self.zmax) if self.zmax is not None else None, ) + zmin = float(self.zmin) if self.zmin is not None else np.nanmin(data) + zmax = float(self.zmax) if self.zmax is not None else np.nanmax(data) + + if cbnorm == "symlog": + # if cblinthresh is not specified, try to come up with a reasonable default + min_abs_val = np.min(np.abs((zmin, zmax))) + if cblinthresh is None: + cblinthresh = np.nanmin(np.absolute(data)[data != 0]) + elif zmin * zmax > 0 and cblinthresh < min_abs_val: + warnings.warn( + f"Cannot set a symlog norm with linear threshold {cblinthresh} " + f"lower than the minimal absolute data value {min_abs_val} . " + "Switching to log norm." + ) + cbnorm = "log10" + if cbnorm == "log10": cbnorm_cls = matplotlib.colors.LogNorm elif cbnorm == "linear": cbnorm_cls = matplotlib.colors.Normalize elif cbnorm == "symlog": - # if cblinthresh is not specified, try to come up with a reasonable default - vmin = float(np.nanmin(data)) - vmax = float(np.nanmax(data)) - if cblinthresh is None: - cblinthresh = np.nanmin(np.absolute(data)[data != 0]) - - cbnorm_kwargs.update(dict(linthresh=cblinthresh, vmin=vmin, vmax=vmax)) - MPL_VERSION = parse_version(matplotlib.__version__) - if MPL_VERSION >= parse_version("3.2.0"): + cbnorm_kwargs.update(dict(linthresh=cblinthresh)) + if MPL_VERSION >= Version("3.2.0"): # note that this creates an inconsistency between mpl versions # since the default value previous to mpl 3.4.0 is np.e # but it is only exposed since 3.2.0 @@ -275,27 +285,33 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): if cbnorm == "symlog": formatter = matplotlib.ticker.LogFormatterMathtext(linthresh=cblinthresh) self.cb = self.figure.colorbar(self.image, self.cax, format=formatter) - if np.nanmin(data) >= 0.0: - yticks = [np.nanmin(data).v] + list( + + if zmin >= 0.0: + yticks = [zmin] + list( 10 ** np.arange( np.rint(np.log10(cblinthresh)), - np.ceil(np.log10(np.nanmax(data))), + np.ceil(np.log10(zmax)), ) ) - elif np.nanmax(data) <= 0.0: + elif zmax <= 0.0: + if MPL_VERSION >= Version("3.5.0b"): + offset = 0 + else: + offset = 1 + yticks = ( list( -( 10 ** np.arange( - np.floor(np.log10(-np.nanmin(data))), - np.rint(np.log10(cblinthresh)) - 1, + np.floor(np.log10(-zmin)), + np.rint(np.log10(cblinthresh)) - offset, -1, ) ) ) - + [np.nanmax(data).v] + + [zmax] ) else: yticks = ( @@ -303,7 +319,7 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): -( 10 ** np.arange( - np.floor(np.log10(-np.nanmin(data))), + np.floor(np.log10(-zmin)), np.rint(np.log10(cblinthresh)) - 1, -1, ) @@ -314,7 +330,7 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect): 10 ** np.arange( np.rint(np.log10(cblinthresh)), - np.ceil(np.log10(np.nanmax(data))), + np.ceil(np.log10(zmax)), ) ) ) diff --git a/yt/visualization/color_maps.py b/yt/visualization/color_maps.py index 1203e413867..0c6b4837c6a 100644 --- a/yt/visualization/color_maps.py +++ b/yt/visualization/color_maps.py @@ -1,10 +1,10 @@ import numpy as np from matplotlib import __version__ as mpl_ver, cm as mcm, colors as cc -from packaging.version import parse as parse_version +from packaging.version import Version from . import _colormap_data as _cm -MPL_VERSION = parse_version(mpl_ver) +MPL_VERSION = Version(mpl_ver) del mpl_ver @@ -260,7 +260,7 @@ def show_colormaps(subset="all", filename=None): "to be 'all', 'yt_native', or a list of " "valid colormap names." ) from e - if parse_version("2.0.0") <= MPL_VERSION < parse_version("2.2.0"): + if Version("2.0.0") <= MPL_VERSION < Version("2.2.0"): # the reason we do this filtering is to avoid spurious warnings in CI when # testing against old versions of matplotlib (currently not older than 2.0.x) # and we can't easily filter warnings at the level of the relevant test itself diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 6b8ff098a62..416a03c115c 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -184,8 +184,10 @@ def get_symlog_minorticks(linthresh, vmin, vmax): the maximum value in the colorbar """ - if vmin > 0 or vmax < 0: + if vmin > 0: return get_log_minorticks(vmin, vmax) + elif vmax < 0 and vmin < 0: + return -get_log_minorticks(-vmax, -vmin) elif vmin == 0: return np.hstack((0, get_log_minorticks(linthresh, vmax))) elif vmax == 0: diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 50c953e8056..17348e39365 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -7,7 +7,7 @@ import numpy as np from more_itertools import always_iterable from mpl_toolkits.axes_grid1 import ImageGrid -from packaging.version import Version, parse as parse_version +from packaging.version import Version from unyt.exceptions import UnitConversionError from yt._maintenance.deprecation import issue_deprecation_warning @@ -64,7 +64,7 @@ def zip_equal(*args): return zip(*args, strict=True) -MPL_VERSION = parse_version(matplotlib.__version__) +MPL_VERSION = Version(matplotlib.__version__) # Some magic for dealing with pyparsing being included or not # included in matplotlib (not in gentoo, yes in everything else) @@ -1202,7 +1202,7 @@ def _setup_plots(self): self.plots[f].cax.minorticks_on() elif self._field_transform[f] == symlog_transform: - if Version("3.2.0") <= MPL_VERSION < Version("3.5.0"): + if Version("3.2.0") <= MPL_VERSION < Version("3.5.0b"): # no known working method to draw symlog minor ticks # see https://github.com/yt-project/yt/issues/3535 pass @@ -1211,13 +1211,13 @@ def _setup_plots(self): np.log10(self.plots[f].cb.norm.linthresh) ) mticks = get_symlog_minorticks(flinthresh, vmin, vmax) - if MPL_VERSION < Version("3.5.0"): + if MPL_VERSION < Version("3.5.0b"): # https://github.com/matplotlib/matplotlib/issues/21258 mticks = self.plots[f].image.norm(mticks) self.plots[f].cax.yaxis.set_ticks(mticks, minor=True) elif self._field_transform[f] == log_transform: - if MPL_VERSION >= parse_version("3.0.0"): + if MPL_VERSION >= Version("3.0.0"): self.plots[f].cax.minorticks_on() self.plots[f].cax.xaxis.set_visible(False) else: diff --git a/yt/visualization/profile_plotter.py b/yt/visualization/profile_plotter.py index ba45a06a368..b54cc05c6aa 100644 --- a/yt/visualization/profile_plotter.py +++ b/yt/visualization/profile_plotter.py @@ -8,7 +8,7 @@ import numpy as np from matplotlib.font_manager import FontProperties from more_itertools.more import always_iterable, unzip -from packaging.version import parse as parse_version +from packaging.version import Version from yt.data_objects.profiles import create_profile, sanitize_field_tuple_keys from yt.data_objects.static_output import Dataset @@ -30,7 +30,7 @@ validate_plot, ) -MPL_VERSION = parse_version(matplotlib.__version__) +MPL_VERSION = Version(matplotlib.__version__) def invalidate_profile(f): @@ -1175,7 +1175,7 @@ def _setup_plots(self): if self._cbar_minorticks[f]: if self._field_transform[f] == linear_transform: self.plots[f].cax.minorticks_on() - elif MPL_VERSION < parse_version("3.0.0"): + elif MPL_VERSION < Version("3.0.0"): # before matplotlib 3 log-scaled colorbars internally used # a linear scale going from zero to one and did not draw # minor ticks. Since we want minor ticks, calculate