From 907de8d0a571c1788644a9f8910818b0e50144f5 Mon Sep 17 00:00:00 2001 From: Marc Williamson Date: Wed, 2 Jun 2021 16:31:52 -0400 Subject: [PATCH] matplotlib last interaction radius plot working. --- tardis/visualization/__init__.py | 1 + .../tools/interaction_radius_plot.py | 334 ++++++++++++++++++ tardis/visualization/tools/sdec_plot.py | 2 + 3 files changed, 337 insertions(+) create mode 100644 tardis/visualization/tools/interaction_radius_plot.py diff --git a/tardis/visualization/__init__.py b/tardis/visualization/__init__.py index 4b806fdd147..9edc30d3805 100644 --- a/tardis/visualization/__init__.py +++ b/tardis/visualization/__init__.py @@ -9,3 +9,4 @@ from tardis.visualization.widgets.line_info import LineInfoWidget from tardis.visualization.widgets.custom_abundance import CustomAbundanceWidget from tardis.visualization.tools.sdec_plot import SDECPlotter +from tardis.visualization.tools.interaction_radius_plot import InteractionRadiusPlotter diff --git a/tardis/visualization/tools/interaction_radius_plot.py b/tardis/visualization/tools/interaction_radius_plot.py new file mode 100644 index 00000000000..a98ecb4adea --- /dev/null +++ b/tardis/visualization/tools/interaction_radius_plot.py @@ -0,0 +1,334 @@ +""" +Last interaction radius plot package for TARDIS simulations. + +This plot is a spectral diagnostics plot similar to those originally +proposed in Williamson et al. (2021). +""" + +import tardis.visualization.tools.sdec_plot as sdec + +import numpy as np +import pandas as pd +import astropy.units as u + +from tardis.util.base import ( + atomic_number2element_symbol, + element_symbol2atomic_number, + species_string_to_tuple, + species_tuple_to_string, + roman_to_int, + int_to_roman, +) + +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import matplotlib.colors as clr +import plotly.graph_objects as go + + +class InteractionRadiusPlotter: + """ + Plotting interface for the interaction radius plot. + """ + + def __init__(self, data, time_explosion): + """ + Initialize the plotter with required data from the simulation. + + Parameters + ---------- + data : dict of SDECData + Dictionary to store data required for interaction radius plot, + for both packet modes (real, virtual). + """ + + self.data = data + self.time_explosion = time_explosion + return + + @classmethod + def from_simulation(cls, sim): + """ + Create an instance of the plotter from a TARDIS simulation object. + + Parameters + ---------- + sim : tardis.simulation.Simulation + TARDIS simulation object produced by running a simulation. + + Returns + ------- + Plotter + """ + + return cls(dict(virtual=sdec.SDECData.from_simulation(sim, "virtual"), + real=sdec.SDECData.from_simulation(sim, "real")), + sim.model.time_explosion) + + @classmethod + def from_hdf(cls, hdf_fpath): + """ + Create an instance of the Plotter from a simulation HDF file. + + Parameters + ---------- + hdf_fpath : str + Valid path to the HDF file where simulation is saved. + + Returns + ------- + Plotter + """ + hdfstore = pd.HDFStore(hdf_fpath) + time_explosion = hdfstore['/simulation/plasma/scalars']['time_explosion'] * u.s + return cls(dict(virtual=sdec.SDECData.from_hdf(hdf_fpath, "virtual"), + real=sdec.SDECData.from_hdf(hdf_fpath, "real")), + ) + + def _parse_species_list(self, species_list): + """ + Parse user requested species list and create list of species ids to be used. + + Parameters + ---------- + species_list : list of species to plot + List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. + Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions + (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) + + """ + if species_list is not None: + # check if there are any digits in the species list. If there are, then exit. + # species_list should only contain species in the Roman numeral + # format, e.g. Si II, and each ion must contain a space + if any(char.isdigit() for char in " ".join(species_list)) == True: + raise ValueError( + "All species must be in Roman numeral form, e.g. Si II" + ) + else: + full_species_list = [] + for species in species_list: + # check if a hyphen is present. If it is, then it indicates a + # range of ions. Add each ion in that range to the list as a new entry + if "-" in species: + # split the string on spaces. First thing in the list is then the element + element = species.split(" ")[0] + # Next thing is the ion range + # convert the requested ions into numerals + first_ion_numeral = roman_to_int( + species.split(" ")[-1].split("-")[0] + ) + second_ion_numeral = roman_to_int( + species.split(" ")[-1].split("-")[-1] + ) + # add each ion between the two requested into the species list + for ion_number in np.arange( + first_ion_numeral, second_ion_numeral + 1 + ): + full_species_list.append( + f"{element} {int_to_roman(ion_number)}" + ) + else: + # Otherwise it's either an element or ion so just add to the list + full_species_list.append(species) + + # full_species_list is now a list containing each individual species requested + # e.g. it parses species_list = [Si I - V] into species_list = [Si I, Si II, Si III, Si IV, Si V] + self._full_species_list = full_species_list + requested_species_ids = [] + keep_colour = [] + + # go through each of the requested species. Check whether it is + # an element or ion (ions have spaces). If it is an element, + # add all possible ions to the ions list. Otherwise just add + # the requested ion + for species in full_species_list: + if " " in species: + requested_species_ids.append( + [ + species_string_to_tuple(species)[0] * 100 + + species_string_to_tuple(species)[1] + ] + ) + else: + atomic_number = element_symbol2atomic_number(species) + requested_species_ids.append( + [ + atomic_number * 100 + ion_number + for ion_number in np.arange(atomic_number) + ] + ) + # add the atomic number to a list so you know that this element should + # have all species in the same colour, i.e. it was requested like + # species_list = [Si] + keep_colour.append(atomic_number) + requested_species_ids = [ + species_id + for temp_list in requested_species_ids + for species_id in temp_list + ] + + self._species_list = requested_species_ids + self._keep_colour = keep_colour + else: + self._species_list = None + return + + def _make_colorbar_labels(self): + """Get the labels for the species in the colorbar.""" + if self._species_list is None: + # If species_list is none then the labels are just elements + species_name = [ + atomic_number2element_symbol(atomic_num) + for atomic_num in self.species + ] + else: + species_name = [] + for species in self.species: + # Go through each species requested + ion_number = species % 100 + atomic_number = (species - ion_number) / 100 + + ion_numeral = int_to_roman(ion_number + 1) + atomic_symbol = atomic_number2element_symbol(atomic_number) + + # if the element was requested, and not a specific ion, then + # add the element symbol to the label list + if (atomic_number in self._keep_colour) & ( + atomic_symbol not in species_name + ): + # compiling the label, and adding it to the list + label = f"{atomic_symbol}" + species_name.append(label) + elif atomic_number not in self._keep_colour: + # otherwise add the ion to the label list + label = f"{atomic_symbol} {ion_numeral}" + species_name.append(label) + + self._species_name = species_name + return + + def _make_colorbar_colors(self): + """Get the colours for the species to be plotted.""" + # the colours depends on the species present in the model and what's requested + # some species need to be shown in the same colour, so the exact colours have to be + # worked out + + color_list = [] + + # Colors for each element + # Create new variables to keep track of the last atomic number that was plotted + # This is used when plotting species in case an element was given in the list + # This is to ensure that all ions of that element are grouped together + # ii is to track the colour index + # e.g. if Si is given in species_list, this is to ensure Si I, Si II, etc. all have the same colour + color_counter = 0 + previous_atomic_number = 0 + for species_counter, identifier in enumerate(self.species): + if self._species_list is not None: + # Get the ion number and atomic number for each species + ion_number = identifier % 100 + atomic_number = (identifier - ion_number) / 100 + if previous_atomic_number == 0: + # If this is the first species being plotted, then take note of the atomic number + # don't update the colour index + color_counter = color_counter + previous_atomic_number = atomic_number + elif previous_atomic_number in self._keep_colour: + # If the atomic number is in the list of elements that should all be plotted in the same colour + # then don't update the colour index if this element has been plotted already + if previous_atomic_number == atomic_number: + color_counter = color_counter + previous_atomic_number = atomic_number + else: + # Otherwise, increase the colour counter by one, because this is a new element + color_counter = color_counter + 1 + previous_atomic_number = atomic_number + else: + # If this is just a normal species that was requested then increment the colour index + color_counter = color_counter + 1 + previous_atomic_number = atomic_number + # Calculate the colour of this species + color = self.cmap(color_counter / len(self._species_name)) + + else: + # If you're not using species list then this is just a fraction based on the total + # number of columns in the dataframe + color = self.cmap(species_counter / len(self.species)) + + color_list.append(color) + + self._color_list = color_list + + return + + def _show_colorbar_mpl(self): + """Show matplotlib colorbar with labels of elements mapped to colors.""" + + color_values = [ + self.cmap(species_counter / len(self._species_name)) + for species_counter in range(len(self._species_name)) + ] + + custcmap = clr.ListedColormap(color_values) + norm = clr.Normalize(vmin=0, vmax=len(self._species_name)) + mappable = cm.ScalarMappable(norm=norm, cmap=custcmap) + mappable.set_array(np.linspace(1, len(self._species_name) + 1, 256)) + cbar = plt.colorbar(mappable, ax=self.ax) + + bounds = np.arange(len(self._species_name)) + 0.5 + cbar.set_ticks(bounds) + + cbar.set_ticklabels(self._species_name) + return + + def generate_plot_mpl(self, + packets_mode="virtual", + ax=None, + figsize=(12, 7), + cmapname="jet", + species_list=None): + """ + Generate the last interaction radius distribution plot + using matplotlib. + """ + + # Parse the requested species list + self._parse_species_list(species_list=species_list) + species_in_model = np.unique( + self.data[packets_mode].packets_df_line_interaction['last_line_interaction_species'].values) + msk = np.isin(self._species_list, species_in_model) + self.species = np.array(self._species_list)[msk] + + if ax is None: + self.ax = plt.figure(figsize=figsize).add_subplot(111) + else: + self.ax = ax + + # Get the labels in the color bar. This determines the number of unique colors + self._make_colorbar_labels() + # Set colormap to be used in elements of emission and absorption plots + self.cmap = cm.get_cmap(cmapname, len(self._species_name)) + # Get the number of unqie colors + self._make_colorbar_colors() + self._show_colorbar_mpl() + + groups = self.data[packets_mode].packets_df_line_interaction.groupby(by='last_line_interaction_species') + + plot_colors = [] + plot_data = [] + + for species_counter, identifier in enumerate(self.species): + g_df = groups.get_group(identifier) + r_last_interaction = g_df['last_interaction_in_r'].values * u.cm + v_last_interaction = (r_last_interaction / self.time_explosion).to('km/s') + plot_data.append(v_last_interaction) + plot_colors.append(self._color_list[species_counter]) + + self.ax.hist(plot_data, bins=50, color=plot_colors) + self.ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0)) + self.ax.tick_params('both', labelsize=20) + self.ax.set_xlabel('Last Interaction Velocity (km/s)', fontsize=25) + self.ax.set_ylabel('Packet Count', fontsize=25) + + return plt.gca() diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index eb146ee0294..de5a3abe968 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -401,6 +401,8 @@ def from_hdf(cls, hdf_fpath, packets_mode): ) + + class SDECPlotter: """ Plotting interface for Spectral element DEComposition (SDEC) Plot.