Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport symlog norm patches to yt-4.0.x #3651

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions yt/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import matplotlib
import numpy as np
from more_itertools import always_iterable, collapse, first
from packaging.version import parse as parse_version
from packaging.version import Version
from tqdm import tqdm

from yt.units import YTArray, YTQuantity
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def matplotlib_style_context(style_name=None, after_reset=False):
import matplotlib

style_name = {"mathtext.fontset": "cm"}
if parse_version(matplotlib.__version__) >= parse_version("3.3.0"):
if Version(matplotlib.__version__) >= Version("3.3.0"):
style_name["mathtext.fallback"] = "cm"
else:
style_name["mathtext.fallback_to_cm"] = True
Expand Down
4 changes: 2 additions & 2 deletions yt/utilities/on_demand_imports.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys

from packaging.version import parse as parse_version
from packaging.version import Version


class NotAModule:
Expand Down Expand Up @@ -361,7 +361,7 @@ def __init__(self):
try:
import h5py

if parse_version(h5py.__version__) < parse_version("2.4.0"):
if Version(h5py.__version__) < Version("2.4.0"):
self._err = RuntimeError(
"yt requires h5py version 2.4.0 or newer, "
"please update h5py with e.g. `python -m pip install -U h5py` "
Expand Down
62 changes: 39 additions & 23 deletions yt/visualization/base_plot_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import warnings
from io import BytesIO

import matplotlib
import numpy as np
from packaging.version import parse as parse_version
from packaging.version import Version

from yt.funcs import (
get_brewer_cmap,
Expand All @@ -12,7 +13,7 @@
mylog,
)

from ._commons import get_canvas, validate_image_name
from ._commons import MPL_VERSION, get_canvas, validate_image_name

BACKEND_SPECS = {
"GTK": ["backend_gtk", "FigureCanvasGTK", "FigureManagerGTK"],
Expand Down Expand Up @@ -132,9 +133,9 @@ def save(self, name, mpl_kwargs=None, canvas=None):

if mpl_kwargs is None:
mpl_kwargs = {}
if "papertype" not in mpl_kwargs and parse_version(
matplotlib.__version__
) < parse_version("3.3.0"):
if "papertype" not in mpl_kwargs and Version(matplotlib.__version__) < Version(
"3.3.0"
):
mpl_kwargs["papertype"] = "auto"

name = validate_image_name(name)
Expand Down Expand Up @@ -206,20 +207,29 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect):
vmin=float(self.zmin) if self.zmin is not None else None,
vmax=float(self.zmax) if self.zmax is not None else None,
)
zmin = float(self.zmin) if self.zmin is not None else np.nanmin(data)
zmax = float(self.zmax) if self.zmax is not None else np.nanmax(data)

if cbnorm == "symlog":
# if cblinthresh is not specified, try to come up with a reasonable default
min_abs_val = np.min(np.abs((zmin, zmax)))
if cblinthresh is None:
cblinthresh = np.nanmin(np.absolute(data)[data != 0])
elif zmin * zmax > 0 and cblinthresh < min_abs_val:
warnings.warn(
f"Cannot set a symlog norm with linear threshold {cblinthresh} "
f"lower than the minimal absolute data value {min_abs_val} . "
"Switching to log norm."
)
cbnorm = "log10"

if cbnorm == "log10":
cbnorm_cls = matplotlib.colors.LogNorm
elif cbnorm == "linear":
cbnorm_cls = matplotlib.colors.Normalize
elif cbnorm == "symlog":
# if cblinthresh is not specified, try to come up with a reasonable default
vmin = float(np.nanmin(data))
vmax = float(np.nanmax(data))
if cblinthresh is None:
cblinthresh = np.nanmin(np.absolute(data)[data != 0])

cbnorm_kwargs.update(dict(linthresh=cblinthresh, vmin=vmin, vmax=vmax))
MPL_VERSION = parse_version(matplotlib.__version__)
if MPL_VERSION >= parse_version("3.2.0"):
cbnorm_kwargs.update(dict(linthresh=cblinthresh))
if MPL_VERSION >= Version("3.2.0"):
# note that this creates an inconsistency between mpl versions
# since the default value previous to mpl 3.4.0 is np.e
# but it is only exposed since 3.2.0
Expand Down Expand Up @@ -275,35 +285,41 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect):
if cbnorm == "symlog":
formatter = matplotlib.ticker.LogFormatterMathtext(linthresh=cblinthresh)
self.cb = self.figure.colorbar(self.image, self.cax, format=formatter)
if np.nanmin(data) >= 0.0:
yticks = [np.nanmin(data).v] + list(

if zmin >= 0.0:
yticks = [zmin] + list(
10
** np.arange(
np.rint(np.log10(cblinthresh)),
np.ceil(np.log10(np.nanmax(data))),
np.ceil(np.log10(zmax)),
)
)
elif np.nanmax(data) <= 0.0:
elif zmax <= 0.0:
if MPL_VERSION >= Version("3.5.0b"):
offset = 0
else:
offset = 1

yticks = (
list(
-(
10
** np.arange(
np.floor(np.log10(-np.nanmin(data))),
np.rint(np.log10(cblinthresh)) - 1,
np.floor(np.log10(-zmin)),
np.rint(np.log10(cblinthresh)) - offset,
-1,
)
)
)
+ [np.nanmax(data).v]
+ [zmax]
)
else:
yticks = (
list(
-(
10
** np.arange(
np.floor(np.log10(-np.nanmin(data))),
np.floor(np.log10(-zmin)),
np.rint(np.log10(cblinthresh)) - 1,
-1,
)
Expand All @@ -314,7 +330,7 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect):
10
** np.arange(
np.rint(np.log10(cblinthresh)),
np.ceil(np.log10(np.nanmax(data))),
np.ceil(np.log10(zmax)),
)
)
)
Expand Down
6 changes: 3 additions & 3 deletions yt/visualization/color_maps.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy as np
from matplotlib import __version__ as mpl_ver, cm as mcm, colors as cc
from packaging.version import parse as parse_version
from packaging.version import Version

from . import _colormap_data as _cm

MPL_VERSION = parse_version(mpl_ver)
MPL_VERSION = Version(mpl_ver)
del mpl_ver


Expand Down Expand Up @@ -260,7 +260,7 @@ def show_colormaps(subset="all", filename=None):
"to be 'all', 'yt_native', or a list of "
"valid colormap names."
) from e
if parse_version("2.0.0") <= MPL_VERSION < parse_version("2.2.0"):
if Version("2.0.0") <= MPL_VERSION < Version("2.2.0"):
# the reason we do this filtering is to avoid spurious warnings in CI when
# testing against old versions of matplotlib (currently not older than 2.0.x)
# and we can't easily filter warnings at the level of the relevant test itself
Expand Down
4 changes: 3 additions & 1 deletion yt/visualization/plot_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,10 @@ def get_symlog_minorticks(linthresh, vmin, vmax):
the maximum value in the colorbar

"""
if vmin > 0 or vmax < 0:
if vmin > 0:
return get_log_minorticks(vmin, vmax)
elif vmax < 0 and vmin < 0:
return -get_log_minorticks(-vmax, -vmin)
elif vmin == 0:
return np.hstack((0, get_log_minorticks(linthresh, vmax)))
elif vmax == 0:
Expand Down
10 changes: 5 additions & 5 deletions yt/visualization/plot_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from more_itertools import always_iterable
from mpl_toolkits.axes_grid1 import ImageGrid
from packaging.version import Version, parse as parse_version
from packaging.version import Version
from unyt.exceptions import UnitConversionError

from yt._maintenance.deprecation import issue_deprecation_warning
Expand Down Expand Up @@ -64,7 +64,7 @@ def zip_equal(*args):
return zip(*args, strict=True)


MPL_VERSION = parse_version(matplotlib.__version__)
MPL_VERSION = Version(matplotlib.__version__)

# Some magic for dealing with pyparsing being included or not
# included in matplotlib (not in gentoo, yes in everything else)
Expand Down Expand Up @@ -1194,7 +1194,7 @@ def _setup_plots(self):
self.plots[f].cax.minorticks_on()

elif self._field_transform[f] == symlog_transform:
if Version("3.2.0") <= MPL_VERSION < Version("3.5.0"):
if Version("3.2.0") <= MPL_VERSION < Version("3.5.0b"):
# no known working method to draw symlog minor ticks
# see https://github.com/yt-project/yt/issues/3535
pass
Expand All @@ -1203,13 +1203,13 @@ def _setup_plots(self):
np.log10(self.plots[f].cb.norm.linthresh)
)
mticks = get_symlog_minorticks(flinthresh, vmin, vmax)
if MPL_VERSION < Version("3.5.0"):
if MPL_VERSION < Version("3.5.0b"):
# https://github.com/matplotlib/matplotlib/issues/21258
mticks = self.plots[f].image.norm(mticks)
self.plots[f].cax.yaxis.set_ticks(mticks, minor=True)

elif self._field_transform[f] == log_transform:
if MPL_VERSION >= parse_version("3.0.0"):
if MPL_VERSION >= Version("3.0.0"):
self.plots[f].cax.minorticks_on()
self.plots[f].cax.xaxis.set_visible(False)
else:
Expand Down
6 changes: 3 additions & 3 deletions yt/visualization/profile_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from matplotlib.font_manager import FontProperties
from more_itertools.more import always_iterable, unzip
from packaging.version import parse as parse_version
from packaging.version import Version

from yt.data_objects.profiles import create_profile, sanitize_field_tuple_keys
from yt.data_objects.static_output import Dataset
Expand All @@ -30,7 +30,7 @@
validate_plot,
)

MPL_VERSION = parse_version(matplotlib.__version__)
MPL_VERSION = Version(matplotlib.__version__)


def invalidate_profile(f):
Expand Down Expand Up @@ -1175,7 +1175,7 @@ def _setup_plots(self):
if self._cbar_minorticks[f]:
if self._field_transform[f] == linear_transform:
self.plots[f].cax.minorticks_on()
elif MPL_VERSION < parse_version("3.0.0"):
elif MPL_VERSION < Version("3.0.0"):
# before matplotlib 3 log-scaled colorbars internally used
# a linear scale going from zero to one and did not draw
# minor ticks. Since we want minor ticks, calculate
Expand Down