From d6e01f117233d4a7d9ddb9024064efc2fa726727 Mon Sep 17 00:00:00 2001 From: Nils Vu Date: Thu, 29 Dec 2022 20:13:16 -0800 Subject: [PATCH] Refactor Render1D - Plot Lagrange interpolation in each DG element - Optionally plot collocation points, element boundaries, basis polynomials - Plot multiple variables (e.g. field and TCI status) - Set plot style with a stylesheet if needed. --- src/Visualization/Python/Render1D.py | 548 ++++++++++++------ .../Visualization/Python/Test_Render1D.py | 92 ++- 2 files changed, 429 insertions(+), 211 deletions(-) diff --git a/src/Visualization/Python/Render1D.py b/src/Visualization/Python/Render1D.py index fbada1383a763..0f9bf88509274 100755 --- a/src/Visualization/Python/Render1D.py +++ b/src/Visualization/Python/Render1D.py @@ -3,154 +3,169 @@ # Distributed under the MIT License. # See LICENSE.txt for details. -import glob -import h5py -import sys -import os -import numpy as np import logging -import matplotlib as mpl +import os +from itertools import cycle +from typing import Dict, Optional + import click -import rich -from spectre.Visualization.ReadH5 import available_subfiles - - -def find_extrema_over_data_set(arr): - ''' - Find max and min over a range of number arrays - :param arr: the array over which to find the max and min - ''' - - return (np.nanmin(arr), np.nanmax(arr)) - - -def get_data(h5files, subfile_name, var_name): - ''' - Get the data to be plotted - :param files: list of h5 filenames or common file prefix - :param subfile_name: name of .vol subfile in h5 file(s) - :param var_name: name of variable to render - :return: the list of time, coords and data - ''' - - time = [] - coords = [] - data = [] - volfiles = [h5file[subfile_name] for h5file in h5files] - # Get a list of times from the first vol file - ids_times = [(obs_id, volfiles[0][obs_id].attrs['observation_value']) - for obs_id in volfiles[0].keys()] - ids_times.sort(key=lambda pair: pair[1]) - for obs_id, local_time in ids_times: - local_coords = [] - local_data = [] - for volfile in volfiles: - try: - local_data = (local_data + list(volfile[obs_id][var_name])) - except KeyError: - print("The variable name {} is not a valid variable. " - "Use '--list-vars' to see the list of variable names in " - "the file(s) \n{}".format(var_name, - list(map(str, h5files)))) - sys.exit(1) - local_coords = (local_coords + - list(volfile[obs_id]['InertialCoordinates_x'])) - ordering = np.argsort(local_coords) - coords.append(np.array(local_coords)[ordering]) - data.append(np.array(local_data)[ordering]) - time.append(local_time) - - for h5file in h5files: - h5file.close() - return time, coords, data - - -def render_single_time(var_name, time_slice, output_prefix, time, coords, - data): - ''' - Renders image at a single time step - :param var_name: name of variable to render - :param time_slice: the integer observation step to render - :param output_prefix: name of output file - :param time: list of time steps - :param coords: list of coordinates to plot - :param data: list of variable data to plot - :return: None - ''' - import matplotlib.pyplot as plt - - plt.xlabel("x") - plt.ylabel(var_name) - try: - plt.title("t = {:.5f}".format(time[time_slice])) - plt.plot(coords[time_slice], data[time_slice], 'o') - except IndexError: - sys.exit("The integer time step provided {} is outside the range of " - "allowed time steps. The range of allowed integer time steps " - "in the files provided is 0-{}".format( - time_slice, - len(time) - 1)) - if output_prefix: - logging.info("Writing still to file {}.pdf".format(output_prefix)) - plt.savefig(output_prefix + ".pdf", format='pdf') - else: - plt.show() +import matplotlib.pyplot as plt +import numpy as np +import spectre.IO.H5 as spectre_h5 +from scipy.interpolate import lagrange +from spectre.DataStructures import DataVector +from spectre.DataStructures.Tensor import Frame, tnsr +from spectre.Domain import (FunctionOfTime, deserialize_domain, + deserialize_functions_of_time) +from spectre.IO.H5.IterElements import Element, iter_elements +from spectre.Spectral import Basis +from spectre.Visualization.PlotDatFile import parse_functions +logger = logging.getLogger(__name__) -def render_animation(var_name, output_prefix, interval, time, coords, data): - ''' - Render an animation of the data - :param var_name: name of variable to render - :param output_prefix: name of output file - :param interval: delay between frames for animation - :param time: list of time steps - :param coords: list of coordinates to plot - :param data: list of variable data to plot - :return: None - ''' - import matplotlib.pyplot as plt - import matplotlib.animation as animation - fig = plt.figure() - ax = plt.axes(xlim=(find_extrema_over_data_set( - np.concatenate(np.asarray(coords)).ravel())), - ylim=(find_extrema_over_data_set( - np.concatenate(np.asarray(data)).ravel()))) - line, = ax.plot([], [], 'o', lw=2) - ax.set_xlabel('x') - ax.set_ylabel(var_name) - title = ax.set_title("") - - def init(): - ''' - Initialize the animation canvas - :return: empty line and title info - ''' - line.set_data([], []) - title.set_text("") - return line, title - - def animate(iteration): - ''' - Update the animation canvas - :return: line and title info for particular time step - ''' - title.set_text("t = {:.5f}".format(time[iteration])) - line.set_data(coords[iteration], data[iteration]) - return line, title - - anim = animation.FuncAnimation(fig, - animate, - init_func=init, - frames=len(time), - interval=interval) - if output_prefix: - fps = 1000.0 / interval - logging.info( - "Writing animation to file {}.mp4 at {} frames per second".format( - output_prefix, fps)) - anim.save(output_prefix + ".mp4", writer='ffmpeg') - else: - plt.show() + +def get_bounds(volfiles, obs_ids, vars): + """Get the bounds in both x and y over all observations, vars, and files""" + x_bounds = [np.inf, -np.inf] + y_bounds = [np.inf, -np.inf] + for obs_id in obs_ids: + for element, vars_data in iter_elements(volfiles, obs_id, vars): + x_bounds[0] = min(x_bounds[0], + np.min(element.inertial_coordinates)) + x_bounds[1] = max(x_bounds[1], + np.max(element.inertial_coordinates)) + y_bounds[0] = min(y_bounds[0], np.nanmin(vars_data)) + y_bounds[1] = max(y_bounds[1], np.nanmax(vars_data)) + return x_bounds, y_bounds + + +def plot_element(element: Element, + vars_data: np.ndarray, + vars: Dict[str, str], + var_props: dict, + show_collocation_points: bool, + show_element_boundaries: bool, + show_basis_polynomials: bool, + time: Optional[float] = None, + functions_of_time: Optional[Dict[str, FunctionOfTime]] = None, + logical_space: np.ndarray = np.linspace(-1, 1, 50), + handles: Optional[dict] = None): + """Plot a 1D element, or update the plot in an animation""" + + # We store plots in these dicts so we can update them later in an animation + # by calling `plot_element` again + if handles is not None: + element_handles = handles.setdefault(element.id, dict()) + point_handles = element_handles.setdefault("point", dict()) + boundary_handles = element_handles.setdefault("boundary", dict()) + var_handles = element_handles.setdefault("var", dict()) + interpolation_handles = element_handles.setdefault("interp", dict()) + basis_handles = element_handles.setdefault("basis", dict()) + + # We collect legend items and return them from this function + legend_items = dict() + + inertial_coords = np.array(element.inertial_coordinates)[0] + + # Plot collocation points + if show_collocation_points: + for i, coord in enumerate(inertial_coords): + if handles and i in point_handles: + point_handles[i].set_xdata(coord) + else: + point_handles[i] = plt.axvline(coord, + color="black", + ls="dotted", + alpha=0.2) + legend_items["Collocation points"] = point_handles[0] + # Clean up leftover handles + if handles: + all_point_handles = sorted(list(point_handles.keys())) + for i in all_point_handles[all_point_handles.index(i) + 1:]: + point_handles.pop(i).remove() + + # Plot element boundaries + if show_element_boundaries: + logical_boundaries = tnsr.I[DataVector, 1, + Frame.ElementLogical](np.array([[-1., + 1.]])) + element_boundaries = np.asarray( + element.map(logical_boundaries, + time=time, + functions_of_time=functions_of_time))[0] + for i, coord in enumerate(element_boundaries): + if handles and i in boundary_handles: + boundary_handles[i].set_xdata(coord) + else: + boundary_handles[i] = plt.axvline(coord, color="black") + legend_items["Element boundaries"] = boundary_handles[0] + + # Prepare Lagrange interpolation + # Only show Lagrange interpolation for spectral elements. Finite-difference + # elements just show the data points. + show_lagrange_interpolation = all([ + basis == Basis.Legendre or basis == Basis.Chebyshev + for basis in element.mesh.basis() + ]) + if show_lagrange_interpolation: + # These are the control points for the Lagrange interpolation + logical_coords = np.array(element.logical_coordinates)[0] + # These are the points where we plot the interpolation + logical_space_tensor = tnsr.I[DataVector, 1, Frame.ElementLogical]( + np.expand_dims(logical_space, axis=0)) + inertial_space = np.asarray( + element.map(logical_space_tensor, + time=time, + functions_of_time=functions_of_time))[0] + + # Plot selected variables + for (var, label), var_data in zip(vars.items(), vars_data): + # Plot data points + if handles and var in var_handles: + var_handles[var].set_data(inertial_coords, var_data) + else: + var_handles[var] = plt.plot(inertial_coords, + var_data, + marker=".", + ls="none", + **var_props[var])[0] + legend_items[label] = var_handles[var] + # Plot Lagrange interpolation + if show_lagrange_interpolation: + interpolant = lagrange(logical_coords, var_data) + if handles and var in interpolation_handles: + interpolation_handles[var].set_data(inertial_space, + interpolant(logical_space)) + else: + interpolation_handles[var] = plt.plot( + inertial_space, interpolant(logical_space), + **var_props[var])[0] + # Plot polynomial basis + if show_basis_polynomials: + unit_weights = np.eye(len(logical_coords)) + for xi_i in range(len(logical_coords)): + basis_id = (var, xi_i) + basis_polynomial = lagrange( + logical_coords, var_data[xi_i] * unit_weights[xi_i]) + if handles and basis_id in basis_handles: + basis_handles[basis_id].set_data( + inertial_space, basis_polynomial(logical_space)) + else: + basis_handles[basis_id] = plt.plot( + inertial_space, + basis_polynomial(logical_space), + color="black", + alpha=0.2)[0] + elif handles: + # Clean up leftover handles + if var in interpolation_handles: + interpolation_handles.pop(var).remove() + for basis_id in list(basis_handles.keys()): + basis_handles.pop(basis_id).remove() + if show_lagrange_interpolation and show_basis_polynomials: + legend_items["Lagrange basis"] = basis_handles[basis_id] + + return legend_items @click.command() @@ -164,11 +179,19 @@ def animate(iteration): "-d", help=("Name of subfile within h5 file containing " "1D volume data to be rendered.")) -@click.option("--var", - "-y", - help=("Name of variable to render. E.g. 'Psi' " - "or 'Error(Psi)'. Can be specified multiple times. " - "If unspecified, print available variables and exit.")) +@click.option( + "--var", + "-y", + "vars", + multiple=True, + callback=parse_functions, + help=("Name of variable to plot, e.g. 'Psi' or 'Error(Psi)'. " + "Can be specified multiple times. " + "If unspecified, plot all available variables. " + "Labels for variables can be specified as key-value pairs such as " + "'Error(Psi)=$L_2(\\psi)$'. Remember to wrap the key-value pair in " + "quotes on the command line to avoid issues with special characters " + "or spaces.")) @click.option("--list-vars", "-l", is_flag=True, @@ -181,66 +204,215 @@ def animate(iteration): "--output", help=("Set the name of the output file you want " "written. For animations this saves an mp4 file and " - "for stills a pdf. Name of file should not include " - "file extension.")) -@click.option('--fps', - type=float, - default=5, - help=("Set the number of frames per second when writing " - "an animation to disk.")) + "for stills a pdf.")) @click.option('--interval', + default=100, type=float, help="Delay between frames in milliseconds") -def render_1d_command(h5_files, subfile_name, list_vars, **args): +# Plotting options +@click.option('--x-label', + help="The label on the x-axis.", + show_default="name of the x-axis column") +@click.option('--y-label', + required=False, + help="The label on the y-axis.", + show_default="no label") +@click.option('--x-logscale', + is_flag=True, + help="Set the x-axis to log scale.") +@click.option('--y-logscale', + is_flag=True, + help="Set the y-axis to log scale.") +@click.option('--x-bounds', + type=float, + nargs=2, + help="The lower and upper bounds of the x-axis.") +@click.option('--y-bounds', + type=float, + nargs=2, + help="The lower and upper bounds of the y-axis.") +@click.option('--title', + '-t', + help="Title of the graph.", + show_default="subfile name") +@click.option( + '--stylesheet', + '-s', + type=click.Path(exists=True, file_okay=True, dir_okay=False, + readable=True), + envvar="SPECTRE_MPL_STYLESHEET", + help=("Select a matplotlib stylesheet for customization of the plot, such " + "as linestyle cycles, linewidth, fontsize, legend, etc. " + "The stylesheet can also be set with the 'SPECTRE_MPL_STYLESHEET' " + "environment variable.")) +@click.option("--show-collocation-points", is_flag=True) +@click.option("--show-element-boundaries", is_flag=True) +@click.option("--show-basis-polynomials", is_flag=True) +def render_1d_command(h5_files, subfile_name, list_vars, vars, output, x_label, + y_label, x_logscale, y_logscale, x_bounds, y_bounds, + title, stylesheet, step, interval, + **plot_element_kwargs): """Render 1D data""" # Script should be a noop if input files are empty if not h5_files: return - open_h5_files = [h5py.File(filename, "r") for filename in h5_files] + open_h5_files = [spectre_h5.H5File(filename, "r") for filename in h5_files] # Print available subfile names and exit if not subfile_name: import rich.columns - rich.print( - rich.columns.Columns( - available_subfiles(open_h5_files[0], extension=".vol"))) + rich.print(rich.columns.Columns(open_h5_files[0].all_vol_files())) return - if not subfile_name.endswith(".vol"): - subfile_name += ".vol" + if subfile_name.endswith(".vol"): + subfile_name = subfile_name[:-4] + if not subfile_name.startswith("/"): + subfile_name = "/" + subfile_name - # Print available variables and exit - if list_vars or not args['var']: - volfile = open_h5_files[0][subfile_name] - obs_id_0 = next(iter(volfile)) - variables = list(volfile[obs_id_0].keys()) - variables.remove("connectivity") - variables.remove("InertialCoordinates_x") - variables.remove("total_extents") - variables.remove("grid_names") - variables.remove("bases") - variables.remove("quadratures") - if "domain" in variables: - variables.remove("domain") - if "functions_of_time" in variables: - variables.remove("functions_of_time") + volfiles = [h5file.get_vol(subfile_name) for h5file in open_h5_files] + dim = volfiles[0].get_dimension() + assert dim == 1, ( + f"The selected subfile contains {dim}D volume data, not 1D.") + obs_ids = volfiles[0].list_observation_ids() # Already sorted by obs value + all_vars = volfiles[0].list_tensor_components(obs_ids[0]) + if "InertialCoordinates_x" in all_vars: + all_vars.remove("InertialCoordinates_x") + # Print available variables and exit + if list_vars: import rich.columns - rich.print(rich.columns.Columns(variables)) + rich.print(rich.columns.Columns(all_vars)) return + for var in vars: + if var not in all_vars: + raise click.UsageError(f"Unknown variable '{var}'. " + f"Available variables are: {all_vars}") + if not vars: + vars = {var: var for var in all_vars} + plot_element_kwargs["vars"] = vars + + # Apply stylesheet + if stylesheet is not None: + plt.style.use(stylesheet) + + # Evaluate property cycles for each variable (by default this is just + # 'color'). We do multiple plotting commands per variable (at least one per + # element), so we don't want matplotlib to cycle through the properties at + # every plotting command. + prop_cycle = { + key: cycle(values) + for key, values in plt.rcParams['axes.prop_cycle'].by_key().items() + } + var_props = { + var: {key: next(values) + for key, values in prop_cycle.items()} + for var in vars + } + plot_element_kwargs["var_props"] = var_props - time, coords, data = get_data(open_h5_files, subfile_name, args['var']) - if args['interval'] is None: - interval = 1000.0 / args['fps'] + # Animate or single frame? + if len(obs_ids) == 1: + animate = False + obs_id = obs_ids[0] + elif step is None: + animate = True + obs_id = obs_ids[0] else: - interval = args['interval'] - if args['step'] is None: - render_animation(args['var'], args['output'], interval, time, coords, - data) + animate = False + obs_id = obs_ids[step] + obs_value = volfiles[0].get_observation_value(obs_id) + + # For Lagrange interpolation + domain = deserialize_domain[1](volfiles[0].get_domain(obs_id)) + if domain.is_time_dependent(): + functions_of_time = deserialize_functions_of_time( + volfiles[0].get_functions_of_time(obs_id)) + plot_element_kwargs["functions_of_time"] = functions_of_time + plot_element_kwargs["time"] = obs_value + + # We store plots here so we can update them later + plot_element_kwargs["handles"] = dict() + + # Plot first frame + fig = plt.figure() + for element, vars_data in iter_elements(volfiles, obs_id, vars): + legend_items = plot_element(element, vars_data, **plot_element_kwargs) + + # Configure the axes + if x_logscale: + plt.xscale("log") + if y_logscale: + plt.yscale("log") + plt.xlabel(x_label if x_label else "x") + plt.ylabel(y_label) + plt.legend(legend_items.values(), legend_items.keys()) + title_handle = plt.title(title if title else f"t = {obs_value:g}") + if animate and not (x_bounds and y_bounds): + data_bounds = get_bounds(volfiles, obs_ids, vars) + if not x_bounds: + x_bounds = data_bounds[0] + if not y_bounds: + y_bounds = data_bounds[1] + if not y_logscale: + margin = (y_bounds[1] - y_bounds[0]) * 0.05 + y_bounds[0] -= margin + y_bounds[1] += margin + if x_bounds: + plt.xlim(*x_bounds) + if y_bounds: + plt.ylim(*y_bounds) + + if animate: + import matplotlib.animation + import rich.progress + + progress = rich.progress.Progress( + rich.progress.TextColumn( + "[progress.description]{task.description}"), + rich.progress.BarColumn(), rich.progress.MofNCompleteColumn(), + rich.progress.TimeRemainingColumn()) + task_id = progress.add_task("Rendering", total=len(obs_ids)) + + def update(frame): + obs_id = obs_ids[frame] + obs_value = volfiles[0].get_observation_value(obs_id) + if domain.is_time_dependent(): + functions_of_time = deserialize_functions_of_time( + volfiles[0].get_functions_of_time(obs_id)) + plot_element_kwargs["functions_of_time"] = functions_of_time + plot_element_kwargs["time"] = obs_value + title_handle.set_text(title if title else f"t = {obs_value:g}") + for element, vars_data in iter_elements(volfiles, obs_id, vars): + plot_element(element, vars_data, **plot_element_kwargs) + progress.update(task_id, completed=frame + 1) + + anim = matplotlib.animation.FuncAnimation(fig=fig, + func=update, + frames=range(len(obs_ids)), + interval=interval, + blit=False) + + if output: + if animate: + if not output.endswith(".mp4"): + output += ".mp4" + with progress: + anim.save(output, writer='ffmpeg') + else: + if not output.endswith(".pdf"): + output += ".pdf" + plt.savefig(output, bbox_inches="tight") else: - render_single_time(args['var'], args['step'], args['output'], time, - coords, data) + if not os.environ.get("DISPLAY"): + logger.warning( + "No 'DISPLAY' environment variable is configured so plotting " + "interactively is unlikely to work. Write the plot to a file " + "with the --output/-o option.") + plt.show() + + for h5file in open_h5_files: + h5file.close() if __name__ == "__main__": diff --git a/tests/Unit/Visualization/Python/Test_Render1D.py b/tests/Unit/Visualization/Python/Test_Render1D.py index 92fa1a84190e0..270981c30f2bd 100644 --- a/tests/Unit/Visualization/Python/Test_Render1D.py +++ b/tests/Unit/Visualization/Python/Test_Render1D.py @@ -3,34 +3,80 @@ # Distributed under the MIT License. # See LICENSE.txt for details. -from spectre.Visualization.Render1D import (find_extrema_over_data_set, - render_single_time) - -import unittest import os +import shutil +import unittest + import numpy as np -import matplotlib as mpl -mpl.use('agg') +import spectre.IO.H5 as spectre_h5 +from click.testing import CliRunner +from spectre.Domain import ElementId, serialize_domain +from spectre.Domain.Creators import Interval +from spectre.Informer import unit_test_build_path +from spectre.IO.H5 import ElementVolumeData, TensorComponent +from spectre.Spectral import Basis, Mesh, Quadrature +from spectre.Visualization.Render1D import render_1d_command class TestRender1D(unittest.TestCase): - def test_find_extrema_over_data_set(self): - test_array = np.array([1.1, 6.45, 0.34, 2.3]) - expected_vals = (0.34, 6.45) - self.assertEqual(find_extrema_over_data_set(test_array), expected_vals) - - def test_render_single_time(self): - var_name = "Variable Test" - time_slice = 1 - output_prefix = "TestRenderSingleTime" - time = [0.0, 0.1] - coords = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]] - data = [[5.2, 4.5, 9.0, 2.0, 8.0], [1.1, 4.0, 6.0, 5.3, 3.0]] - # test whether a pdf file is saved when run - render_single_time(var_name, time_slice, output_prefix, time, coords, - data) - self.assertTrue(os.path.isfile(output_prefix + '.pdf')) - os.remove(output_prefix + '.pdf') + def setUp(self): + self.test_dir = os.path.join(unit_test_build_path(), + 'Visualization/Render1D') + shutil.rmtree(self.test_dir, ignore_errors=True) + os.makedirs(self.test_dir, exist_ok=True) + + # Generate 1D volume data + domain = Interval(lower_x=[0.], + upper_x=[1.], + initial_refinement_level_x=[1], + initial_number_of_grid_points_in_x=[4], + is_periodic_in_x=[False]).create_domain() + serialized_domain = serialize_domain(domain) + mesh = Mesh[1](4, Basis.Legendre, Quadrature.GaussLobatto) + self.h5file = os.path.join(self.test_dir, "voldata.h5") + with spectre_h5.H5File(self.h5file, "w") as open_h5file: + volfile = open_h5file.insert_vol("/VolumeData", version=0) + volfile.write_volume_data( + observation_id=0, + observation_value=1., + elements=[ + ElementVolumeData( + ElementId[1]("[B0,(L1I0)]"), + [TensorComponent("U", np.random.rand(4))], mesh), + ElementVolumeData( + ElementId[1]("[B0,(L1I1)]"), + [TensorComponent("U", np.random.rand(4))], mesh), + ], + serialized_domain=serialized_domain) + volfile.write_volume_data( + observation_id=1, + observation_value=2., + elements=[ + ElementVolumeData( + ElementId[1]("[B0,(L1I0)]"), + [TensorComponent("U", np.random.rand(4))], mesh), + ElementVolumeData( + ElementId[1]("[B0,(L1I1)]"), + [TensorComponent("U", np.random.rand(4))], mesh), + ], + serialized_domain=serialized_domain) + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def test_render_1d(self): + # Can't easily test the layout of the plot, so we just test that the + # script runs without error and produces output. + # We also don't have ffmpeg installed in the CI container, so we can't + # test an animation. + runner = CliRunner() + plot_file = os.path.join(self.test_dir, "plot") + result = runner.invoke( + render_1d_command, + [self.h5file, "-d", "VolumeData", "--step", "0", "-o", plot_file], + catch_exceptions=False) + self.assertEqual(result.exit_code, 0, msg=result.output) + self.assertTrue(os.path.exists(plot_file + ".pdf"), msg=result.output) if __name__ == '__main__':