Skip to content

Commit

Permalink
Merge pull request #3651 from neutrinoceros/backport_symlog_norm_patches
Browse files Browse the repository at this point in the history
Backport symlog norm patches to yt-4.0.x
  • Loading branch information
matthewturk committed Nov 17, 2021
2 parents 1baa461 + 6c1bfa1 commit 1f1af71
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 39 deletions.
4 changes: 2 additions & 2 deletions yt/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import matplotlib
import numpy as np
from more_itertools import always_iterable, collapse, first
from packaging.version import parse as parse_version
from packaging.version import Version
from tqdm import tqdm

from yt.units import YTArray, YTQuantity
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def matplotlib_style_context(style_name=None, after_reset=False):
import matplotlib

style_name = {"mathtext.fontset": "cm"}
if parse_version(matplotlib.__version__) >= parse_version("3.3.0"):
if Version(matplotlib.__version__) >= Version("3.3.0"):
style_name["mathtext.fallback"] = "cm"
else:
style_name["mathtext.fallback_to_cm"] = True
Expand Down
4 changes: 2 additions & 2 deletions yt/utilities/on_demand_imports.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys

from packaging.version import parse as parse_version
from packaging.version import Version


class NotAModule:
Expand Down Expand Up @@ -361,7 +361,7 @@ def __init__(self):
try:
import h5py

if parse_version(h5py.__version__) < parse_version("2.4.0"):
if Version(h5py.__version__) < Version("2.4.0"):
self._err = RuntimeError(
"yt requires h5py version 2.4.0 or newer, "
"please update h5py with e.g. `python -m pip install -U h5py` "
Expand Down
62 changes: 39 additions & 23 deletions yt/visualization/base_plot_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import warnings
from io import BytesIO

import matplotlib
import numpy as np
from packaging.version import parse as parse_version
from packaging.version import Version

from yt.funcs import (
get_brewer_cmap,
Expand All @@ -12,7 +13,7 @@
mylog,
)

from ._commons import get_canvas, validate_image_name
from ._commons import MPL_VERSION, get_canvas, validate_image_name

BACKEND_SPECS = {
"GTK": ["backend_gtk", "FigureCanvasGTK", "FigureManagerGTK"],
Expand Down Expand Up @@ -132,9 +133,9 @@ def save(self, name, mpl_kwargs=None, canvas=None):

if mpl_kwargs is None:
mpl_kwargs = {}
if "papertype" not in mpl_kwargs and parse_version(
matplotlib.__version__
) < parse_version("3.3.0"):
if "papertype" not in mpl_kwargs and Version(matplotlib.__version__) < Version(
"3.3.0"
):
mpl_kwargs["papertype"] = "auto"

name = validate_image_name(name)
Expand Down Expand Up @@ -206,20 +207,29 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect):
vmin=float(self.zmin) if self.zmin is not None else None,
vmax=float(self.zmax) if self.zmax is not None else None,
)
zmin = float(self.zmin) if self.zmin is not None else np.nanmin(data)
zmax = float(self.zmax) if self.zmax is not None else np.nanmax(data)

if cbnorm == "symlog":
# if cblinthresh is not specified, try to come up with a reasonable default
min_abs_val = np.min(np.abs((zmin, zmax)))
if cblinthresh is None:
cblinthresh = np.nanmin(np.absolute(data)[data != 0])
elif zmin * zmax > 0 and cblinthresh < min_abs_val:
warnings.warn(
f"Cannot set a symlog norm with linear threshold {cblinthresh} "
f"lower than the minimal absolute data value {min_abs_val} . "
"Switching to log norm."
)
cbnorm = "log10"

if cbnorm == "log10":
cbnorm_cls = matplotlib.colors.LogNorm
elif cbnorm == "linear":
cbnorm_cls = matplotlib.colors.Normalize
elif cbnorm == "symlog":
# if cblinthresh is not specified, try to come up with a reasonable default
vmin = float(np.nanmin(data))
vmax = float(np.nanmax(data))
if cblinthresh is None:
cblinthresh = np.nanmin(np.absolute(data)[data != 0])

cbnorm_kwargs.update(dict(linthresh=cblinthresh, vmin=vmin, vmax=vmax))
MPL_VERSION = parse_version(matplotlib.__version__)
if MPL_VERSION >= parse_version("3.2.0"):
cbnorm_kwargs.update(dict(linthresh=cblinthresh))
if MPL_VERSION >= Version("3.2.0"):
# note that this creates an inconsistency between mpl versions
# since the default value previous to mpl 3.4.0 is np.e
# but it is only exposed since 3.2.0
Expand Down Expand Up @@ -275,35 +285,41 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect):
if cbnorm == "symlog":
formatter = matplotlib.ticker.LogFormatterMathtext(linthresh=cblinthresh)
self.cb = self.figure.colorbar(self.image, self.cax, format=formatter)
if np.nanmin(data) >= 0.0:
yticks = [np.nanmin(data).v] + list(

if zmin >= 0.0:
yticks = [zmin] + list(
10
** np.arange(
np.rint(np.log10(cblinthresh)),
np.ceil(np.log10(np.nanmax(data))),
np.ceil(np.log10(zmax)),
)
)
elif np.nanmax(data) <= 0.0:
elif zmax <= 0.0:
if MPL_VERSION >= Version("3.5.0b"):
offset = 0
else:
offset = 1

yticks = (
list(
-(
10
** np.arange(
np.floor(np.log10(-np.nanmin(data))),
np.rint(np.log10(cblinthresh)) - 1,
np.floor(np.log10(-zmin)),
np.rint(np.log10(cblinthresh)) - offset,
-1,
)
)
)
+ [np.nanmax(data).v]
+ [zmax]
)
else:
yticks = (
list(
-(
10
** np.arange(
np.floor(np.log10(-np.nanmin(data))),
np.floor(np.log10(-zmin)),
np.rint(np.log10(cblinthresh)) - 1,
-1,
)
Expand All @@ -314,7 +330,7 @@ def _init_image(self, data, cbnorm, cblinthresh, cmap, extent, aspect):
10
** np.arange(
np.rint(np.log10(cblinthresh)),
np.ceil(np.log10(np.nanmax(data))),
np.ceil(np.log10(zmax)),
)
)
)
Expand Down
6 changes: 3 additions & 3 deletions yt/visualization/color_maps.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy as np
from matplotlib import __version__ as mpl_ver, cm as mcm, colors as cc
from packaging.version import parse as parse_version
from packaging.version import Version

from . import _colormap_data as _cm

MPL_VERSION = parse_version(mpl_ver)
MPL_VERSION = Version(mpl_ver)
del mpl_ver


Expand Down Expand Up @@ -260,7 +260,7 @@ def show_colormaps(subset="all", filename=None):
"to be 'all', 'yt_native', or a list of "
"valid colormap names."
) from e
if parse_version("2.0.0") <= MPL_VERSION < parse_version("2.2.0"):
if Version("2.0.0") <= MPL_VERSION < Version("2.2.0"):
# the reason we do this filtering is to avoid spurious warnings in CI when
# testing against old versions of matplotlib (currently not older than 2.0.x)
# and we can't easily filter warnings at the level of the relevant test itself
Expand Down
4 changes: 3 additions & 1 deletion yt/visualization/plot_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,10 @@ def get_symlog_minorticks(linthresh, vmin, vmax):
the maximum value in the colorbar
"""
if vmin > 0 or vmax < 0:
if vmin > 0:
return get_log_minorticks(vmin, vmax)
elif vmax < 0 and vmin < 0:
return -get_log_minorticks(-vmax, -vmin)
elif vmin == 0:
return np.hstack((0, get_log_minorticks(linthresh, vmax)))
elif vmax == 0:
Expand Down
10 changes: 5 additions & 5 deletions yt/visualization/plot_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from more_itertools import always_iterable
from mpl_toolkits.axes_grid1 import ImageGrid
from packaging.version import Version, parse as parse_version
from packaging.version import Version
from unyt.exceptions import UnitConversionError

from yt._maintenance.deprecation import issue_deprecation_warning
Expand Down Expand Up @@ -64,7 +64,7 @@ def zip_equal(*args):
return zip(*args, strict=True)


MPL_VERSION = parse_version(matplotlib.__version__)
MPL_VERSION = Version(matplotlib.__version__)

# Some magic for dealing with pyparsing being included or not
# included in matplotlib (not in gentoo, yes in everything else)
Expand Down Expand Up @@ -1202,7 +1202,7 @@ def _setup_plots(self):
self.plots[f].cax.minorticks_on()

elif self._field_transform[f] == symlog_transform:
if Version("3.2.0") <= MPL_VERSION < Version("3.5.0"):
if Version("3.2.0") <= MPL_VERSION < Version("3.5.0b"):
# no known working method to draw symlog minor ticks
# see https://github.com/yt-project/yt/issues/3535
pass
Expand All @@ -1211,13 +1211,13 @@ def _setup_plots(self):
np.log10(self.plots[f].cb.norm.linthresh)
)
mticks = get_symlog_minorticks(flinthresh, vmin, vmax)
if MPL_VERSION < Version("3.5.0"):
if MPL_VERSION < Version("3.5.0b"):
# https://github.com/matplotlib/matplotlib/issues/21258
mticks = self.plots[f].image.norm(mticks)
self.plots[f].cax.yaxis.set_ticks(mticks, minor=True)

elif self._field_transform[f] == log_transform:
if MPL_VERSION >= parse_version("3.0.0"):
if MPL_VERSION >= Version("3.0.0"):
self.plots[f].cax.minorticks_on()
self.plots[f].cax.xaxis.set_visible(False)
else:
Expand Down
6 changes: 3 additions & 3 deletions yt/visualization/profile_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from matplotlib.font_manager import FontProperties
from more_itertools.more import always_iterable, unzip
from packaging.version import parse as parse_version
from packaging.version import Version

from yt.data_objects.profiles import create_profile, sanitize_field_tuple_keys
from yt.data_objects.static_output import Dataset
Expand All @@ -30,7 +30,7 @@
validate_plot,
)

MPL_VERSION = parse_version(matplotlib.__version__)
MPL_VERSION = Version(matplotlib.__version__)


def invalidate_profile(f):
Expand Down Expand Up @@ -1175,7 +1175,7 @@ def _setup_plots(self):
if self._cbar_minorticks[f]:
if self._field_transform[f] == linear_transform:
self.plots[f].cax.minorticks_on()
elif MPL_VERSION < parse_version("3.0.0"):
elif MPL_VERSION < Version("3.0.0"):
# before matplotlib 3 log-scaled colorbars internally used
# a linear scale going from zero to one and did not draw
# minor ticks. Since we want minor ticks, calculate
Expand Down

0 comments on commit 1f1af71

Please sign in to comment.