diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000000..cf5cb2db4aa --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[tool.black] +line-length = 80 +target-version = ['py36'] \ No newline at end of file diff --git a/tardis/__init__.py b/tardis/__init__.py index 77b6a16437b..152b22a141d 100644 --- a/tardis/__init__.py +++ b/tardis/__init__.py @@ -6,21 +6,24 @@ import pyne.data from tardis.util.colored_logger import ColoredFormatter, formatter_message + # Affiliated packages may add whatever they like to this file, but # should keep this content at the top. # ---------------------------------------------------------------------------- from ._astropy_init import * + # ---------------------------------------------------------------------------- from tardis.base import run_tardis from tardis.io.util import yaml_load_config_file as yaml_load -warnings.filterwarnings('ignore', category=pyne.data.QAWarning) + +warnings.filterwarnings("ignore", category=pyne.data.QAWarning) FORMAT = "[$BOLD%(name)-20s$RESET][%(levelname)-18s] %(message)s ($BOLD%(filename)s$RESET:%(lineno)d)" COLOR_FORMAT = formatter_message(FORMAT, True) logging.captureWarnings(True) -logger = logging.getLogger('tardis') +logger = logging.getLogger("tardis") logger.setLevel(logging.INFO) console_handler = logging.StreamHandler(sys.stdout) @@ -28,4 +31,4 @@ console_handler.setFormatter(console_formatter) logger.addHandler(console_handler) -logging.getLogger('py.warnings').addHandler(console_handler) +logging.getLogger("py.warnings").addHandler(console_handler) diff --git a/tardis/_astropy_init.py b/tardis/_astropy_init.py index 3d761295658..7a4e25060e2 100644 --- a/tardis/_astropy_init.py +++ b/tardis/_astropy_init.py @@ -1,12 +1,13 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst -__all__ = ['__version__', '__githash__', 'test'] +__all__ = ["__version__", "__githash__", "test"] # this indicates whether or not we are in the package's setup.py try: _ASTROPY_SETUP_ except NameError: from sys import version_info + if version_info[0] >= 3: import builtins else: @@ -16,21 +17,34 @@ try: from .version import version as __version__ except ImportError: - __version__ = '' + __version__ = "" try: from .version import githash as __githash__ except ImportError: - __githash__ = '' + __githash__ = "" # set up the test command def _get_test_runner(): import os from astropy.tests.helper import TestRunner + return TestRunner(os.path.dirname(__file__)) -def test(package=None, test_path=None, args=None, plugins=None, - verbose=False, pastebin=None, remote_data=False, pep8=False, - pdb=False, coverage=False, open_files=False, **kwargs): + +def test( + package=None, + test_path=None, + args=None, + plugins=None, + verbose=False, + pastebin=None, + remote_data=False, + pep8=False, + pdb=False, + coverage=False, + open_files=False, + **kwargs, +): """ Run the tests using `py.test `__. A proper set of arguments is constructed and passed to `pytest.main`_. @@ -105,10 +119,20 @@ def test(package=None, test_path=None, args=None, plugins=None, """ test_runner = _get_test_runner() return test_runner.run_tests( - package=package, test_path=test_path, args=args, - plugins=plugins, verbose=verbose, pastebin=pastebin, - remote_data=remote_data, pep8=pep8, pdb=pdb, - coverage=coverage, open_files=open_files, **kwargs) + package=package, + test_path=test_path, + args=args, + plugins=plugins, + verbose=verbose, + pastebin=pastebin, + remote_data=remote_data, + pep8=pep8, + pdb=pdb, + coverage=coverage, + open_files=open_files, + **kwargs, + ) + if not _ASTROPY_SETUP_: import os @@ -118,21 +142,30 @@ def test(package=None, test_path=None, args=None, plugins=None, # add these here so we only need to cleanup the namespace at the end config_dir = None - if not os.environ.get('ASTROPY_SKIP_CONFIG_UPDATE', False): + if not os.environ.get("ASTROPY_SKIP_CONFIG_UPDATE", False): config_dir = os.path.dirname(__file__) config_template = os.path.join(config_dir, __package__ + ".cfg") if os.path.isfile(config_template): try: config.configuration.update_default_config( - __package__, config_dir, version=__version__) + __package__, config_dir, version=__version__ + ) except TypeError as orig_error: try: config.configuration.update_default_config( - __package__, config_dir) + __package__, config_dir + ) except config.configuration.ConfigurationDefaultMissingError as e: - wmsg = (e.args[0] + " Cannot install default profile. If you are " - "importing from source, this is expected.") - warn(config.configuration.ConfigurationDefaultMissingWarning(wmsg)) + wmsg = ( + e.args[0] + + " Cannot install default profile. If you are " + "importing from source, this is expected." + ) + warn( + config.configuration.ConfigurationDefaultMissingWarning( + wmsg + ) + ) del e except: raise orig_error diff --git a/tardis/analysis.py b/tardis/analysis.py index 8118eca7ce0..89a8781079f 100644 --- a/tardis/analysis.py +++ b/tardis/analysis.py @@ -1,4 +1,4 @@ -#codes to for analyse the model. +# codes to for analyse the model. import re import os @@ -10,25 +10,36 @@ class LastLineInteraction(object): - @classmethod def from_model(cls, model): - return cls(model.runner.last_line_interaction_in_id, - model.runner.last_line_interaction_out_id, - model.runner.last_line_interaction_shell_id, - model.runner.output_nu, model.plasma.atomic_data.lines) - - def __init__(self, last_line_interaction_in_id, - last_line_interaction_out_id, last_line_interaction_shell_id, - output_nu, lines, packet_filter_mode='packet_nu'): + return cls( + model.runner.last_line_interaction_in_id, + model.runner.last_line_interaction_out_id, + model.runner.last_line_interaction_shell_id, + model.runner.output_nu, + model.plasma.atomic_data.lines, + ) + + def __init__( + self, + last_line_interaction_in_id, + last_line_interaction_out_id, + last_line_interaction_shell_id, + output_nu, + lines, + packet_filter_mode="packet_nu", + ): # mask out packets which did not perform a line interaction # TODO mask out packets which do not escape to observer? mask = last_line_interaction_out_id != -1 self.last_line_interaction_in_id = last_line_interaction_in_id[mask] self.last_line_interaction_out_id = last_line_interaction_out_id[mask] - self.last_line_interaction_shell_id = last_line_interaction_shell_id[mask] + self.last_line_interaction_shell_id = last_line_interaction_shell_id[ + mask + ] self.last_line_interaction_angstrom = output_nu.to( - u.Angstrom, equivalencies=u.spectral())[mask] + u.Angstrom, equivalencies=u.spectral() + )[mask] self.lines = lines self._wavelength_start = 0 * u.angstrom @@ -38,27 +49,25 @@ def __init__(self, last_line_interaction_in_id, self.packet_filter_mode = packet_filter_mode self.update_last_interaction_filter() - - @property def wavelength_start(self): - return self._wavelength_start.to('angstrom') + return self._wavelength_start.to("angstrom") @wavelength_start.setter def wavelength_start(self, value): if not isinstance(value, u.Quantity): - raise ValueError('needs to be a Quantity') + raise ValueError("needs to be a Quantity") self._wavelength_start = value self.update_last_interaction_filter() @property def wavelength_end(self): - return self._wavelength_end.to('angstrom') + return self._wavelength_end.to("angstrom") @wavelength_end.setter def wavelength_end(self, value): if not isinstance(value, u.Quantity): - raise ValueError('needs to be a Quantity') + raise ValueError("needs to be a Quantity") self._wavelength_end = value self.update_last_interaction_filter() @@ -81,126 +90,150 @@ def ion_number(self, value): self.update_last_interaction_filter() def update_last_interaction_filter(self): - if self.packet_filter_mode == 'packet_nu': + if self.packet_filter_mode == "packet_nu": packet_filter = ( - (self.last_line_interaction_angstrom > - self.wavelength_start) & - (self.last_line_interaction_angstrom < - self.wavelength_end)) - elif self.packet_filter_mode == 'line_in_nu': - line_in_nu = ( - self.lines.wavelength.iloc[ - self.last_line_interaction_in_id].values) + self.last_line_interaction_angstrom > self.wavelength_start + ) & (self.last_line_interaction_angstrom < self.wavelength_end) + elif self.packet_filter_mode == "line_in_nu": + line_in_nu = self.lines.wavelength.iloc[ + self.last_line_interaction_in_id + ].values packet_filter = ( - (line_in_nu > self.wavelength_start.to(u.angstrom).value) & - (line_in_nu < self.wavelength_end.to(u.angstrom).value)) - + line_in_nu > self.wavelength_start.to(u.angstrom).value + ) & (line_in_nu < self.wavelength_end.to(u.angstrom).value) self.last_line_in = self.lines.iloc[ - self.last_line_interaction_in_id[packet_filter]] + self.last_line_interaction_in_id[packet_filter] + ] self.last_line_out = self.lines.iloc[ - self.last_line_interaction_out_id[packet_filter]] + self.last_line_interaction_out_id[packet_filter] + ] if self.atomic_number is not None: self.last_line_in = self.last_line_in.xs( - self.atomic_number, level='atomic_number', drop_level=False) + self.atomic_number, level="atomic_number", drop_level=False + ) self.last_line_out = self.last_line_out.xs( - self.atomic_number, level='atomic_number', drop_level=False) + self.atomic_number, level="atomic_number", drop_level=False + ) if self.ion_number is not None: self.last_line_in = self.last_line_in.xs( - self.ion_number, level='ion_number', drop_level=False) + self.ion_number, level="ion_number", drop_level=False + ) self.last_line_out = self.last_line_out.xs( - self.ion_number, level='ion_number', drop_level=False) + self.ion_number, level="ion_number", drop_level=False + ) last_line_in_count = self.last_line_in.line_id.value_counts() last_line_out_count = self.last_line_out.line_id.value_counts() self.last_line_in_table = self.last_line_in.reset_index()[ - [ - 'wavelength', 'atomic_number', 'ion_number', - 'level_number_lower', 'level_number_upper']] - self.last_line_in_table['count'] = last_line_in_count - self.last_line_in_table.sort_values(by='count', ascending=False, - inplace=True) + [ + "wavelength", + "atomic_number", + "ion_number", + "level_number_lower", + "level_number_upper", + ] + ] + self.last_line_in_table["count"] = last_line_in_count + self.last_line_in_table.sort_values( + by="count", ascending=False, inplace=True + ) self.last_line_out_table = self.last_line_out.reset_index()[ - [ - 'wavelength', 'atomic_number', 'ion_number', - 'level_number_lower', 'level_number_upper']] - self.last_line_out_table['count'] = last_line_out_count - self.last_line_out_table.sort_values(by='count', ascending=False, - inplace=True) + [ + "wavelength", + "atomic_number", + "ion_number", + "level_number_lower", + "level_number_upper", + ] + ] + self.last_line_out_table["count"] = last_line_out_count + self.last_line_out_table.sort_values( + by="count", ascending=False, inplace=True + ) def plot_wave_in_out(self, fig, do_clf=True, plot_resonance=True): if do_clf: fig.clf() ax = fig.add_subplot(111) - wave_in = self.last_line_list_in['wavelength'] - wave_out = self.last_line_list_out['wavelength'] + wave_in = self.last_line_list_in["wavelength"] + wave_out = self.last_line_list_out["wavelength"] if plot_resonance: min_wave = np.min([wave_in.min(), wave_out.min()]) max_wave = np.max([wave_in.max(), wave_out.max()]) - ax.plot([min_wave, max_wave], [min_wave, max_wave], 'b-') + ax.plot([min_wave, max_wave], [min_wave, max_wave], "b-") - ax.plot(wave_in, wave_out, 'b.', picker=True) - ax.set_xlabel('Last interaction Wave in') - ax.set_ylabel('Last interaction Wave out') + ax.plot(wave_in, wave_out, "b.", picker=True) + ax.set_xlabel("Last interaction Wave in") + ax.set_ylabel("Last interaction Wave out") def onpick(event): print("-" * 80) - print("Line_in (%d/%d):\n%s" % ( - len(event.ind), self.current_no_packets, - self.last_line_list_in.ix[event.ind])) + print( + "Line_in (%d/%d):\n%s" + % ( + len(event.ind), + self.current_no_packets, + self.last_line_list_in.ix[event.ind], + ) + ) print("\n\n") - print("Line_out (%d/%d):\n%s" % ( - len(event.ind), self.current_no_packets, - self.last_line_list_in.ix[event.ind])) + print( + "Line_out (%d/%d):\n%s" + % ( + len(event.ind), + self.current_no_packets, + self.last_line_list_in.ix[event.ind], + ) + ) print("^" * 80) def onpress(event): pass - fig.canvas.mpl_connect('pick_event', onpick) - fig.canvas.mpl_connect('on_press', onpress) + fig.canvas.mpl_connect("pick_event", onpick) + fig.canvas.mpl_connect("on_press", onpress) class TARDISHistory(object): """ Records the history of the model """ + def __init__(self, hdf5_fname, iterations=None): self.hdf5_fname = hdf5_fname if iterations is None: iterations = [] - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') + hdf_store = pd.HDFStore(self.hdf5_fname, "r") for key in hdf_store.keys(): - if key.split('/')[1] == 'atom_data': + if key.split("/")[1] == "atom_data": continue iterations.append( - int(re.match(r'model(\d+)', key.split('/')[1]).groups()[0])) + int(re.match(r"model(\d+)", key.split("/")[1]).groups()[0]) + ) self.iterations = np.sort(np.unique(iterations)) hdf_store.close() else: - self.iterations=iterations + self.iterations = iterations self.levels = None self.lines = None - - def load_atom_data(self): if self.levels is None or self.lines is None: - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') - self.levels = hdf_store['atom_data/levels'] - self.lines = hdf_store['atom_data/lines'] + hdf_store = pd.HDFStore(self.hdf5_fname, "r") + self.levels = hdf_store["atom_data/levels"] + self.lines = hdf_store["atom_data/lines"] hdf_store.close() - def load_t_inner(self, iterations=None): t_inners = [] - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') + hdf_store = pd.HDFStore(self.hdf5_fname, "r") if iterations is None: iterations = self.iterations @@ -210,7 +243,9 @@ def load_t_inner(self, iterations=None): iterations = self.iterations[iterations] for iter in iterations: - t_inners.append(hdf_store['model%03d/configuration' %iter].ix['t_inner']) + t_inners.append( + hdf_store["model%03d/configuration" % iter].ix["t_inner"] + ) hdf_store.close() t_inners = np.array(t_inners) @@ -218,7 +253,7 @@ def load_t_inner(self, iterations=None): def load_t_rads(self, iterations=None): t_rads_dict = {} - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') + hdf_store = pd.HDFStore(self.hdf5_fname, "r") if iterations is None: iterations = self.iterations @@ -227,10 +262,9 @@ def load_t_rads(self, iterations=None): else: iterations = self.iterations[iterations] - for iter in iterations: - current_iter = 'iter%03d' % iter - t_rads_dict[current_iter] = hdf_store['model%03d/t_rads' % iter] + current_iter = "iter%03d" % iter + t_rads_dict[current_iter] = hdf_store["model%03d/t_rads" % iter] t_rads = pd.DataFrame(t_rads_dict) hdf_store.close() @@ -238,7 +272,7 @@ def load_t_rads(self, iterations=None): def load_ws(self, iterations=None): ws_dict = {} - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') + hdf_store = pd.HDFStore(self.hdf5_fname, "r") if iterations is None: iterations = self.iterations @@ -248,8 +282,8 @@ def load_ws(self, iterations=None): iterations = self.iterations[iterations] for iter in iterations: - current_iter = 'iter{:03d}'.format(iter) - ws_dict[current_iter] = hdf_store['model{:03d}/ws'.format(iter)] + current_iter = "iter{:03d}".format(iter) + ws_dict[current_iter] = hdf_store["model{:03d}/ws".format(iter)] hdf_store.close() @@ -257,7 +291,7 @@ def load_ws(self, iterations=None): def load_level_populations(self, iterations=None): level_populations_dict = {} - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') + hdf_store = pd.HDFStore(self.hdf5_fname, "r") is_scalar = False if iterations is None: iterations = self.iterations @@ -268,9 +302,10 @@ def load_level_populations(self, iterations=None): iterations = self.iterations[iterations] for iter in iterations: - current_iter = 'iter%03d' % iter + current_iter = "iter%03d" % iter level_populations_dict[current_iter] = hdf_store[ - 'model{:03d}/level_populations'.format(iter)] + "model{:03d}/level_populations".format(iter) + ] hdf_store.close() if is_scalar: @@ -280,7 +315,7 @@ def load_level_populations(self, iterations=None): def load_jblues(self, iterations=None): jblues_dict = {} - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') + hdf_store = pd.HDFStore(self.hdf5_fname, "r") is_scalar = False if iterations is None: iterations = self.iterations @@ -291,9 +326,10 @@ def load_jblues(self, iterations=None): iterations = self.iterations[iterations] for iter in iterations: - current_iter = 'iter{:03d}'.format(iter) + current_iter = "iter{:03d}".format(iter) jblues_dict[current_iter] = hdf_store[ - 'model{:03d}/j_blues'.format(iter)] + "model{:03d}/j_blues".format(iter) + ] hdf_store.close() if is_scalar: @@ -301,10 +337,9 @@ def load_jblues(self, iterations=None): else: return pd.Panel(jblues_dict) - def load_ion_populations(self, iterations=None): ion_populations_dict = {} - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') + hdf_store = pd.HDFStore(self.hdf5_fname, "r") is_scalar = False if iterations is None: @@ -316,9 +351,10 @@ def load_ion_populations(self, iterations=None): iterations = self.iterations[iterations] for iter in iterations: - current_iter = 'iter{:03d}'.format(iter) + current_iter = "iter{:03d}".format(iter) ion_populations_dict[current_iter] = hdf_store[ - 'model{:03d}/ion_populations'.format(iter)] + "model{:03d}/ion_populations".format(iter) + ] hdf_store.close() if is_scalar: @@ -326,37 +362,47 @@ def load_ion_populations(self, iterations=None): else: return pd.Panel(ion_populations_dict) - def load_spectrum(self, iteration, spectrum_keyword='luminosity_density'): - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') + def load_spectrum(self, iteration, spectrum_keyword="luminosity_density"): + hdf_store = pd.HDFStore(self.hdf5_fname, "r") - spectrum = hdf_store['model%03d/%s' % (self.iterations[iteration], spectrum_keyword)] + spectrum = hdf_store[ + "model%03d/%s" % (self.iterations[iteration], spectrum_keyword) + ] hdf_store.close() return spectrum def calculate_relative_lte_level_populations(self, species, iteration=-1): self.load_atom_data() t_rads = self.load_t_rads(iteration) - beta_rads = 1 / (constants.k_B.cgs.value * t_rads.values[:,0]) + beta_rads = 1 / (constants.k_B.cgs.value * t_rads.values[:, 0]) species_levels = self.levels.ix[species] relative_lte_level_populations = ( - (species_levels.g.values[np.newaxis].T / - float(species_levels.g.loc[0])) * - np.exp(-beta_rads * species_levels.energy.values[np.newaxis].T)) + species_levels.g.values[np.newaxis].T + / float(species_levels.g.loc[0]) + ) * np.exp(-beta_rads * species_levels.energy.values[np.newaxis].T) - return pd.DataFrame(relative_lte_level_populations, index=species_levels.index) + return pd.DataFrame( + relative_lte_level_populations, index=species_levels.index + ) def calculate_departure_coefficients(self, species, iteration=-1): self.load_atom_data() t_rads = self.load_t_rads(iteration) - beta_rads = 1 / (constants.k_B.cgs.value * t_rads.values[:,0]) + beta_rads = 1 / (constants.k_B.cgs.value * t_rads.values[:, 0]) species_levels = self.levels.ix[species] - species_level_populations = self.load_level_populations(iteration).ix[species] - departure_coefficient = ((species_level_populations.values * species_levels.g.ix[0]) / - (species_level_populations.ix[0].values * species_levels.g.values[np.newaxis].T)) \ - * np.exp(beta_rads * species_levels.energy.values[np.newaxis].T) + species_level_populations = self.load_level_populations(iteration).ix[ + species + ] + departure_coefficient = ( + (species_level_populations.values * species_levels.g.ix[0]) + / ( + species_level_populations.ix[0].values + * species_levels.g.values[np.newaxis].T + ) + ) * np.exp(beta_rads * species_levels.energy.values[np.newaxis].T) return pd.DataFrame(departure_coefficient, index=species_levels.index) @@ -364,15 +410,28 @@ def get_last_line_interaction(self, iteration=-1): iteration = self.iterations[iteration] self.load_atom_data() - hdf_store = pd.HDFStore(self.hdf5_fname, 'r') - model_string = 'model'+('%03d' % iteration) + '/%s' - last_line_interaction_in_id = hdf_store[model_string % 'last_line_interaction_in_id'].values - last_line_interaction_out_id = hdf_store[model_string % 'last_line_interaction_out_id'].values - last_line_interaction_shell_id = hdf_store[model_string % 'last_line_interaction_shell_id'].values + hdf_store = pd.HDFStore(self.hdf5_fname, "r") + model_string = "model" + ("%03d" % iteration) + "/%s" + last_line_interaction_in_id = hdf_store[ + model_string % "last_line_interaction_in_id" + ].values + last_line_interaction_out_id = hdf_store[ + model_string % "last_line_interaction_out_id" + ].values + last_line_interaction_shell_id = hdf_store[ + model_string % "last_line_interaction_shell_id" + ].values try: - montecarlo_nu = hdf_store[model_string % 'montecarlo_nus_path'].values + montecarlo_nu = hdf_store[ + model_string % "montecarlo_nus_path" + ].values except KeyError: - montecarlo_nu = hdf_store[model_string % 'montecarlo_nus'].values + montecarlo_nu = hdf_store[model_string % "montecarlo_nus"].values hdf_store.close() - return LastLineInteraction(last_line_interaction_in_id, last_line_interaction_out_id, last_line_interaction_shell_id, - montecarlo_nu, self.lines) + return LastLineInteraction( + last_line_interaction_in_id, + last_line_interaction_out_id, + last_line_interaction_shell_id, + montecarlo_nu, + self.lines, + ) diff --git a/tardis/base.py b/tardis/base.py index 065e52017f8..906bc34cfab 100644 --- a/tardis/base.py +++ b/tardis/base.py @@ -1,7 +1,9 @@ # functions that are important for the general usage of TARDIS -def run_tardis(config, atom_data=None, packet_source=None, - simulation_callbacks=[]): + +def run_tardis( + config, atom_data=None, packet_source=None, simulation_callbacks=[] +): """ This function is one of the core functions to run TARDIS from a given config object. @@ -35,9 +37,9 @@ def run_tardis(config, atom_data=None, packet_source=None, except TypeError: tardis_config = Configuration.from_config_dict(config) - simulation = Simulation.from_config(tardis_config, - packet_source=packet_source, - atom_data=atom_data) + simulation = Simulation.from_config( + tardis_config, packet_source=packet_source, atom_data=atom_data + ) for cb in simulation_callbacks: simulation.add_callback(*cb) diff --git a/tardis/conftest.py b/tardis/conftest.py index faae42e5caf..9a2573861a0 100644 --- a/tardis/conftest.py +++ b/tardis/conftest.py @@ -6,7 +6,8 @@ # test infrastructure. from astropy.version import version as astropy_version -if astropy_version < '3.0': + +if astropy_version < "3.0": # With older versions of Astropy, we actually need to import the pytest # plugins themselves in order to make them discoverable by pytest. from astropy.tests.pytest_plugins import * @@ -15,7 +16,10 @@ # automatically made available when Astropy is installed. This means it's # not necessary to import them here, but we still need to import global # variables that are used for configuration. - from astropy.tests.plugins.display import PYTEST_HEADER_MODULES, TESTED_VERSIONS + from astropy.tests.plugins.display import ( + PYTEST_HEADER_MODULES, + TESTED_VERSIONS, + ) from astropy.tests.helper import enable_deprecations_as_exceptions @@ -82,15 +86,15 @@ # the tests. Making it pass for KeyError is essential in some cases when # the package uses other astropy affiliated packages. try: - PYTEST_HEADER_MODULES['Numpy'] = 'numpy' - PYTEST_HEADER_MODULES['Scipy'] = 'scipy' - PYTEST_HEADER_MODULES['Pandas'] = 'pandas' - PYTEST_HEADER_MODULES['Astropy'] = 'astropy' - PYTEST_HEADER_MODULES['Yaml'] = 'yaml' - PYTEST_HEADER_MODULES['Cython'] = 'cython' - PYTEST_HEADER_MODULES['h5py'] = 'h5py' - PYTEST_HEADER_MODULES['Matplotlib'] = 'matplotlib' - PYTEST_HEADER_MODULES['Ipython'] = 'IPython' + PYTEST_HEADER_MODULES["Numpy"] = "numpy" + PYTEST_HEADER_MODULES["Scipy"] = "scipy" + PYTEST_HEADER_MODULES["Pandas"] = "pandas" + PYTEST_HEADER_MODULES["Astropy"] = "astropy" + PYTEST_HEADER_MODULES["Yaml"] = "yaml" + PYTEST_HEADER_MODULES["Cython"] = "cython" + PYTEST_HEADER_MODULES["h5py"] = "h5py" + PYTEST_HEADER_MODULES["Matplotlib"] = "matplotlib" + PYTEST_HEADER_MODULES["Ipython"] = "IPython" # del PYTEST_HEADER_MODULES['h5py'] except (NameError, KeyError): # NameError is needed to support Astropy < 1.0 pass @@ -105,40 +109,50 @@ try: from .version import version except ImportError: - version = 'dev' + version = "dev" try: packagename = os.path.basename(os.path.dirname(__file__)) TESTED_VERSIONS[packagename] = version -except NameError: # Needed to support Astropy <= 1.0.0 +except NameError: # Needed to support Astropy <= 1.0.0 pass - # ------------------------------------------------------------------------- # Initialization # ------------------------------------------------------------------------- def pytest_addoption(parser): - parser.addoption("--tardis-refdata", default=None, - help="Path to Tardis Reference Folder") - parser.addoption("--integration-tests", - dest="integration-tests", default=None, - help="path to configuration file for integration tests") - parser.addoption("--generate-reference", - action="store_true", default=False, - help="generate reference data instead of testing") - parser.addoption("--less-packets", - action="store_true", default=False, - help="Run integration tests with less packets.") + parser.addoption( + "--tardis-refdata", default=None, help="Path to Tardis Reference Folder" + ) + parser.addoption( + "--integration-tests", + dest="integration-tests", + default=None, + help="path to configuration file for integration tests", + ) + parser.addoption( + "--generate-reference", + action="store_true", + default=False, + help="generate reference data instead of testing", + ) + parser.addoption( + "--less-packets", + action="store_true", + default=False, + help="Run integration tests with less packets.", + ) + # ------------------------------------------------------------------------- # project specific fixtures # ------------------------------------------------------------------------- -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def generate_reference(pytestconfig): option = pytestconfig.getvalue("generate_reference") if option is None: @@ -151,32 +165,32 @@ def generate_reference(pytestconfig): def tardis_ref_path(pytestconfig): tardis_ref_path = pytestconfig.getvalue("tardis_refdata") if tardis_ref_path is None: - pytest.skip('--tardis-refdata was not specified') + pytest.skip("--tardis-refdata was not specified") else: return os.path.expandvars(os.path.expanduser(tardis_ref_path)) + from tardis.tests.fixtures.atom_data import * + @pytest.yield_fixture(scope="session") def tardis_ref_data(tardis_ref_path, generate_reference): if generate_reference: - mode = 'w' + mode = "w" else: - mode = 'r' + mode = "r" with pd.HDFStore( - os.path.join( - tardis_ref_path, - 'unit_test_data.h5'), - mode=mode - ) as store: + os.path.join(tardis_ref_path, "unit_test_data.h5"), mode=mode + ) as store: yield store - @pytest.fixture def tardis_config_verysimple(): return yaml_load_config_file( - 'tardis/io/tests/data/tardis_configv1_verysimple.yml') + "tardis/io/tests/data/tardis_configv1_verysimple.yml" + ) + ### # HDF Fixtures @@ -185,14 +199,14 @@ def tardis_config_verysimple(): @pytest.fixture(scope="session") def hdf_file_path(tmpdir_factory): - path = tmpdir_factory.mktemp('hdf_buffer').join('test.hdf') + path = tmpdir_factory.mktemp("hdf_buffer").join("test.hdf") return str(path) @pytest.fixture(scope="session") def config_verysimple(): - filename = 'tardis_configv1_verysimple.yml' - path = os.path.abspath(os.path.join('tardis/io/tests/data/', filename)) + filename = "tardis_configv1_verysimple.yml" + path = os.path.abspath(os.path.join("tardis/io/tests/data/", filename)) config = Configuration.from_yaml(path) return config diff --git a/tardis/constants.py b/tardis/constants.py index 269d599b5fc..c4926adee17 100644 --- a/tardis/constants.py +++ b/tardis/constants.py @@ -1 +1 @@ -from astropy.constants.astropyconst13 import * \ No newline at end of file +from astropy.constants.astropyconst13 import * diff --git a/tardis/gui/datahandler.py b/tardis/gui/datahandler.py index b7bc9a803a2..4e9798a2761 100644 --- a/tardis/gui/datahandler.py +++ b/tardis/gui/datahandler.py @@ -6,30 +6,33 @@ import matplotlib.pylab as plt -if os.environ.get('QT_API', None)=='pyqt': +if os.environ.get("QT_API", None) == "pyqt": from PyQt5 import QtGui, QtCore, QtWidgets -elif os.environ.get('QT_API', None)=='pyside': +elif os.environ.get("QT_API", None) == "pyside": from PySide2 import QtGui, QtCore, QtWidgets else: - raise ImportError('QT_API was not set! Please exit the IPython console\n' - ' and at the bash prompt use : \n\n export QT_API=pyside \n or\n' - ' export QT_API=pyqt \n\n For more information refer to user guide.') + raise ImportError( + "QT_API was not set! Please exit the IPython console\n" + " and at the bash prompt use : \n\n export QT_API=pyside \n or\n" + " export QT_API=pyqt \n\n For more information refer to user guide." + ) import yaml from tardis import run_tardis from tardis.gui.widgets import MatplotlibWidget, ModelViewer, ShellInfo from tardis.gui.widgets import LineInfo, LineInteractionTables -if (parse_version(matplotlib.__version__) >= parse_version('1.4')): - matplotlib.style.use('fivethirtyeight') +if parse_version(matplotlib.__version__) >= parse_version("1.4"): + matplotlib.style.use("fivethirtyeight") else: print("Please upgrade matplotlib to a version >=1.4 for best results!") -matplotlib.rcParams['font.family'] = 'serif' -matplotlib.rcParams['font.size'] = 10.0 -matplotlib.rcParams['lines.linewidth'] = 1.0 -matplotlib.rcParams['axes.formatter.use_mathtext'] = True -matplotlib.rcParams['axes.edgecolor'] = matplotlib.rcParams['grid.color'] -matplotlib.rcParams['axes.linewidth'] = matplotlib.rcParams['grid.linewidth'] +matplotlib.rcParams["font.family"] = "serif" +matplotlib.rcParams["font.size"] = 10.0 +matplotlib.rcParams["lines.linewidth"] = 1.0 +matplotlib.rcParams["axes.formatter.use_mathtext"] = True +matplotlib.rcParams["axes.edgecolor"] = matplotlib.rcParams["grid.color"] +matplotlib.rcParams["axes.linewidth"] = matplotlib.rcParams["grid.linewidth"] + class Node(object): """Object that serves as the nodes in the TreeModel. @@ -93,8 +96,8 @@ def __init__(self, data, parent=None): self.parent = parent self.children = [] self.data = data - self.siblings = {} #For 'type' fields. Will store the nodes to - #enable disable on selection + self.siblings = {} # For 'type' fields. Will store the nodes to + # enable disable on selection def append_child(self, child): """Add a child to this node.""" @@ -159,6 +162,7 @@ def set_data(self, column, value): return True + class TreeModel(QtCore.QAbstractItemModel): """The class that defines the tree for ConfigEditor. @@ -174,6 +178,7 @@ class TreeModel(QtCore.QAbstractItemModel): nodes that have values that can be set from a list. """ + def __init__(self, dictionary, parent=None): """Create a tree of tardis configuration dictionary. @@ -192,7 +197,7 @@ def __init__(self, dictionary, parent=None): self.typenodes = [] self.dict_to_tree(dictionary, self.root) - #mandatory functions for subclasses + # mandatory functions for subclasses def columnCount(self, index): """Return the number of columns in the node pointed to by the given model index. @@ -222,15 +227,19 @@ def flags(self, index): return QtCore.Qt.NoItemFlags node = index.internalPointer() - if ((node.get_parent() in self.disabledNodes) or - (node in self.disabledNodes)): + if (node.get_parent() in self.disabledNodes) or ( + node in self.disabledNodes + ): return QtCore.Qt.NoItemFlags - if node.num_children()==0: - return (QtCore.Qt.ItemIsEditable | QtCore.Qt.ItemIsEnabled | - QtCore.Qt.ItemIsSelectable) + if node.num_children() == 0: + return ( + QtCore.Qt.ItemIsEditable + | QtCore.Qt.ItemIsEnabled + | QtCore.Qt.ItemIsSelectable + ) - return QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable + return QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable def getItem(self, index): """Returns the node to which the model index is pointing. If the @@ -249,8 +258,10 @@ def headerData(self, section, orientation, role): be needed for QTreeView. """ - if (orientation == QtCore.Qt.Horizontal and - role == QtCore.Qt.DisplayRole): + if ( + orientation == QtCore.Qt.Horizontal + and role == QtCore.Qt.DisplayRole + ): return self.root.get_data(section) return None @@ -283,8 +294,9 @@ def insertRows(self, position, rows, parent=QtCore.QModelIndex()): """Insert rows in the tree model.""" parentItem = self.getItem(parent) self.beginInsertRows(parent, position, position + rows - 1) - success = parentItem.insertChildren(position, rows, - self.rootItem.columnCount()) + success = parentItem.insertChildren( + position, rows, self.rootItem.columnCount() + ) self.endInsertRows() return success @@ -327,8 +339,9 @@ def setData(self, index, value, role=QtCore.Qt.EditRole): self.dataChanged.emit(index, index) return result - def setHeaderData(self, section, orientation, value, - role=QtCore.Qt.EditRole): + def setHeaderData( + self, section, orientation, value, role=QtCore.Qt.EditRole + ): """Change header data. Unused in columnview.""" if role != QtCore.Qt.EditRole or orientation != QtCore.Qt.Horizontal: return False @@ -338,8 +351,8 @@ def setHeaderData(self, section, orientation, value, self.headerDataChanged.emit(orientation, section, section) return result - - #Custom functions + + # Custom functions def dict_to_tree(self, dictionary, root): """Create the tree and append siblings to nodes that need them. @@ -351,11 +364,11 @@ def dict_to_tree(self, dictionary, root): The root node of the tree. """ - #Construct tree with all nodes + # Construct tree with all nodes self.tree_from_node(dictionary, root) - #Append siblings to type nodes - for node in self.typenodes: #For every type node + # Append siblings to type nodes + for node in self.typenodes: # For every type node parent = node.get_parent() sibsdict = {} for i in range(parent.num_children()): @@ -363,8 +376,8 @@ def dict_to_tree(self, dictionary, root): typesleaf = node.get_child(0) for i in range(typesleaf.num_columns()): - sibstrings = typesleaf.get_data(i).split('|_:_|') - + sibstrings = typesleaf.get_data(i).split("|_:_|") + typesleaf.set_data(i, sibstrings[0]) sibslist = [] for j in range(1, len(sibstrings)): @@ -372,15 +385,14 @@ def dict_to_tree(self, dictionary, root): sibslist.append(sibsdict[sibstrings[j]]) typesleaf.siblings[sibstrings[0]] = sibslist - - #Then append siblings of current selection for all type nodes to - #disabled nodes - for i in range(1,typesleaf.num_columns()): + + # Then append siblings of current selection for all type nodes to + # disabled nodes + for i in range(1, typesleaf.num_columns()): key = typesleaf.get_data(i) for nd in typesleaf.siblings[key]: self.disabledNodes.append(nd) - def tree_from_node(self, dictionary, root): """Convert dictionary to tree. Called by dict_to_tree.""" for key in dictionary: @@ -390,15 +402,15 @@ def tree_from_node(self, dictionary, root): self.tree_from_node(dictionary[key], child) elif isinstance(dictionary[key], list): if isinstance(dictionary[key][1], list): - leaf = Node(dictionary[key][1]) + leaf = Node(dictionary[key][1]) else: leaf = Node([dictionary[key][1]]) child.append_child(leaf) - if key == 'type': + if key == "type": self.typenodes.append(child) - def dict_from_node(self, node): + def dict_from_node(self, node): """Take a node and convert the whole subtree rooted at it into a dictionary. @@ -412,7 +424,7 @@ def dict_from_node(self, node): else: dictionary[nd.get_data(0)] = self.dict_from_node(nd) return dictionary - elif len(children)==1: + elif len(children) == 1: return children[0].get_data(0) @@ -421,21 +433,24 @@ class TreeDelegate(QtWidgets.QStyledItemDelegate): TreeModel. """ + def __init__(self, parent=None): """Call the constructor of the superclass.""" QtWidgets.QStyledItemDelegate.__init__(self, parent) - - #Mandatory methods for subclassing + + # Mandatory methods for subclassing def createEditor(self, parent, option, index): """Create a lineEdit or combobox depending on the type of node.""" node = index.internalPointer() - if node.num_columns()>1: + if node.num_columns() > 1: combobox = QtGui.QComboBox(parent) - combobox.addItems([node.get_data(i) for i in range(node.num_columns())]) + combobox.addItems( + [node.get_data(i) for i in range(node.num_columns())] + ) combobox.setEditable(False) return combobox else: - editor = QtWidgets.QLineEdit(parent) + editor = QtWidgets.QLineEdit(parent) editor.setText(str(node.get_data(0))) editor.returnPressed.connect(self.close_and_commit) return editor @@ -448,13 +463,13 @@ def setModelData(self, editor, model, index): """ node = index.internalPointer() - if node.num_columns() > 1 and node.get_parent().get_data(0) != 'type': + if node.num_columns() > 1 and node.get_parent().get_data(0) != "type": selectedIndex = editor.currentIndex() firstItem = node.get_data(0) node.setData(0, str(editor.currentText())) node.setData(selectedIndex, str(firstItem)) - elif node.num_columns() > 1 and node.get_parent().get_data(0) == 'type': + elif node.num_columns() > 1 and node.get_parent().get_data(0) == "type": selectedIndex = editor.currentIndex() firstItem = node.get_data(0) node.setData(0, str(editor.currentText())) @@ -468,26 +483,37 @@ def setModelData(self, editor, model, index): for nd in itemsToEnable: if nd in model.disabledNodes: - model.disabledNodes.remove(nd) + model.disabledNodes.remove(nd) - elif isinstance(editor, QtWidgets.QLineEdit): + elif isinstance(editor, QtWidgets.QLineEdit): node.setData(0, str(editor.text())) else: - QtWidgets.QStyledItemDelegate.setModelData(self, editor, model, index) - - #Custom methods + QtWidgets.QStyledItemDelegate.setModelData( + self, editor, model, index + ) + + # Custom methods def close_and_commit(self): """Saver for the line edits.""" editor = self.sender() if isinstance(editor, QtWidgets.QLineEdit): self.commitData.emit(editor) - self.closeEditor.emit(editor, QtWidgets.QAbstractItemDelegate.NoHint) + self.closeEditor.emit( + editor, QtWidgets.QAbstractItemDelegate.NoHint + ) + class SimpleTableModel(QtCore.QAbstractTableModel): """Create a table data structure for the table widgets.""" - - def __init__(self, headerdata=None, iterate_header=(0, 0), - index_info=None, parent=None, *args): + + def __init__( + self, + headerdata=None, + iterate_header=(0, 0), + index_info=None, + parent=None, + *args, + ): """Call constructor of the QAbstractTableModel and set parameters given by user. """ @@ -496,8 +522,8 @@ def __init__(self, headerdata=None, iterate_header=(0, 0), self.arraydata = [] self.iterate_header = iterate_header self.index_info = index_info - - #Implementing methods mandatory for subclassing QAbstractTableModel + + # Implementing methods mandatory for subclassing QAbstractTableModel def rowCount(self, parent=QtCore.QModelIndex()): """Return number of rows.""" return len(self.arraydata[0]) @@ -518,7 +544,10 @@ def headerData(self, section, orientation, role=QtCore.Qt.DisplayRole): return self.headerdata[0][0] + str(section + 1) else: return self.headerdata[0][section] - elif orientation == QtCore.Qt.Horizontal and role == QtCore.Qt.DisplayRole: + elif ( + orientation == QtCore.Qt.Horizontal + and role == QtCore.Qt.DisplayRole + ): if self.iterate_header[1] == 1: return self.headerdata[1][0] + str(section + 1) elif self.iterate_header[1] == 2: @@ -534,7 +563,7 @@ def data(self, index, role=QtCore.Qt.DisplayRole): return None elif role != QtCore.Qt.DisplayRole: return None - return (self.arraydata[index.column()][index.row()]) + return self.arraydata[index.column()][index.row()] def setData(self, index, value, role=QtCore.Qt.EditRole): """Change the data in the model for specified index and role @@ -544,19 +573,21 @@ def setData(self, index, value, role=QtCore.Qt.EditRole): elif role != QtCore.Qt.EditRole: return False self.arraydata[index.column()][index.row()] = value - - self.dataChanged=QtCore.Signal(QtGui.QModelIndex(),QtGui.QModelIndex()) + + self.dataChanged = QtCore.Signal( + QtGui.QModelIndex(), QtGui.QModelIndex() + ) self.dataChanged.emit(index, index) return True - #Methods used to inderact with the SimpleTableModel + # Methods used to inderact with the SimpleTableModel def update_table(self): """Update table to set all the new data.""" for r in range(self.rowCount()): for c in range(self.columnCount()): index = self.createIndex(r, c) self.setData(index, self.arraydata[c][r]) - + def add_data(self, datain): """Add data to the model.""" self.arraydata.append(datain) diff --git a/tardis/gui/interface.py b/tardis/gui/interface.py index ee3de06bc8f..dce94639fd0 100644 --- a/tardis/gui/interface.py +++ b/tardis/gui/interface.py @@ -1,24 +1,30 @@ import os -if os.environ.get('QT_API', None)=='pyqt': + +if os.environ.get("QT_API", None) == "pyqt": from PyQt5 import QtCore, QtWidgets -elif os.environ.get('QT_API', None)=='pyside': - from PySide2 import QtCore,QtWidgets +elif os.environ.get("QT_API", None) == "pyside": + from PySide2 import QtCore, QtWidgets else: - raise ImportError('QT_API was not set! Please exit the IPython console\n' - ' and at the bash prompt use : \n\n export QT_API=pyside \n or\n' - ' export QT_API=pyqt \n\n For more information refer to user guide.') + raise ImportError( + "QT_API was not set! Please exit the IPython console\n" + " and at the bash prompt use : \n\n export QT_API=pyside \n or\n" + " export QT_API=pyqt \n\n For more information refer to user guide." + ) import sys + try: from IPython.lib.guisupport import get_app_qt5, start_event_loop_qt5 from IPython.lib.guisupport import is_event_loop_running_qt5 - importFailed = False + + importFailed = False except ImportError: importFailed = True -from tardis.gui.widgets import Tardis +from tardis.gui.widgets import Tardis from tardis.gui.datahandler import SimpleTableModel from tardis import run_tardis - + + def show(model): """Take an instance of tardis model and display it. @@ -47,14 +53,15 @@ def show(model): else: start_event_loop_qt5(app) - #If the IPython console is being used, this will evaluate to true. - #In that case the window created will be garbage collected unless a - #reference to it is maintained after this function exits. So the win is - #returned. + # If the IPython console is being used, this will evaluate to true. + # In that case the window created will be garbage collected unless a + # reference to it is maintained after this function exits. So the win is + # returned. if is_event_loop_running_qt5(app): return win -if __name__=='__main__': + +if __name__ == "__main__": """When this module is executed as script, take arguments, calculate model and call the show function. @@ -62,5 +69,4 @@ def show(model): yamlfile = sys.argv[1] atomfile = sys.argv[2] mdl = run_tardis(yamlfile, atomfile) - show(mdl) - + show(mdl) diff --git a/tardis/gui/tests/test_gui.py b/tardis/gui/tests/test_gui.py index 233b6ab40f4..d5a7de98e0d 100644 --- a/tardis/gui/tests/test_gui.py +++ b/tardis/gui/tests/test_gui.py @@ -5,29 +5,30 @@ from tardis.simulation import Simulation import astropy.units as u -if 'QT_API' in os.environ: +if "QT_API" in os.environ: from tardis.gui.widgets import Tardis from tardis.gui.datahandler import SimpleTableModel -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def refdata(tardis_ref_data): def get_ref_data(key): - return tardis_ref_data[os.path.join( - 'test_simulation', key)] + return tardis_ref_data[os.path.join("test_simulation", key)] + return get_ref_data -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def config(): return Configuration.from_yaml( - 'tardis/io/tests/data/tardis_configv1_verysimple.yml') + "tardis/io/tests/data/tardis_configv1_verysimple.yml" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def simulation_one_loop( - atomic_data_fname, config, - tardis_ref_data, generate_reference): + atomic_data_fname, config, tardis_ref_data, generate_reference +): config.atom_data = atomic_data_fname config.montecarlo.iterations = 2 config.montecarlo.no_of_packets = int(4e4) @@ -39,7 +40,9 @@ def simulation_one_loop( return simulation -@pytest.mark.skipif('QT_API' not in os.environ, reason="enviroment variable QT_API is not set") +@pytest.mark.skipif( + "QT_API" not in os.environ, reason="enviroment variable QT_API is not set" +) def test_gui(simulation_one_loop): simulation = simulation_one_loop app = QtWidgets.QApplication([]) diff --git a/tardis/gui/widgets.py b/tardis/gui/widgets.py index a49386e08c3..0b26f580d08 100644 --- a/tardis/gui/widgets.py +++ b/tardis/gui/widgets.py @@ -2,20 +2,24 @@ import tardis.util.base -if os.environ.get('QT_API', None)=='pyqt': +if os.environ.get("QT_API", None) == "pyqt": from PyQt5 import QtGui, QtCore, QtWidgets -elif os.environ.get('QT_API', None)=='pyside': - from PySide2 import QtGui, QtCore,QtWidgets +elif os.environ.get("QT_API", None) == "pyside": + from PySide2 import QtGui, QtCore, QtWidgets else: - raise ImportError('QT_API was not set! Please exit the IPython console\n' - ' and at the bash prompt use : \n\n export QT_API=pyside \n or\n' - ' export QT_API=pyqt \n\n For more information refer to user guide.') + raise ImportError( + "QT_API was not set! Please exit the IPython console\n" + " and at the bash prompt use : \n\n export QT_API=pyside \n or\n" + " export QT_API=pyqt \n\n For more information refer to user guide." + ) import matplotlib from matplotlib.figure import * import matplotlib.gridspec as gridspec from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt5 import NavigationToolbar2QT as NavigationToolbar +from matplotlib.backends.backend_qt5 import ( + NavigationToolbar2QT as NavigationToolbar, +) from matplotlib import colors from matplotlib.patches import Circle import matplotlib.pylab as plt @@ -24,6 +28,7 @@ import tardis from tardis import analysis, util + class MatplotlibWidget(FigureCanvas): """Canvas to draw graphs on.""" @@ -35,67 +40,80 @@ def __init__(self, tablecreator, parent, fig=None): self.tablecreator = tablecreator self.parent = parent - self.figure = Figure()#(frameon=False,facecolor=(1,1,1)) + self.figure = Figure() # (frameon=False,facecolor=(1,1,1)) self.cid = {} - if fig != 'model': + if fig != "model": self.ax = self.figure.add_subplot(111) else: self.gs = gridspec.GridSpec(2, 1, height_ratios=[1, 3]) self.ax1 = self.figure.add_subplot(self.gs[0]) - self.ax2 = self.figure.add_subplot(self.gs[1])#, aspect='equal') + self.ax2 = self.figure.add_subplot(self.gs[1]) # , aspect='equal') self.cb = None self.span = None super(MatplotlibWidget, self).__init__(self.figure) - super(MatplotlibWidget, self).setSizePolicy(QtWidgets.QSizePolicy.Expanding, - QtWidgets.QSizePolicy.Expanding) + super(MatplotlibWidget, self).setSizePolicy( + QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding + ) super(MatplotlibWidget, self).updateGeometry() - if fig != 'model': + if fig != "model": self.toolbar = NavigationToolbar(self, parent) - self.cid[0] = self.figure.canvas.mpl_connect('pick_event', - self.on_span_pick) + self.cid[0] = self.figure.canvas.mpl_connect( + "pick_event", self.on_span_pick + ) else: - self.cid[0] = self.figure.canvas.mpl_connect('pick_event', - self.on_shell_pick) + self.cid[0] = self.figure.canvas.mpl_connect( + "pick_event", self.on_shell_pick + ) def show_line_info(self): """Show line info for span selected region.""" - self.parent.line_info.append(LineInfo(self.parent, self.span.xy[0][0], - self.span.xy[2][0], self.tablecreator)) + self.parent.line_info.append( + LineInfo( + self.parent, + self.span.xy[0][0], + self.span.xy[2][0], + self.tablecreator, + ) + ) def show_span(self, garbage=0, left=5000, right=10000): """Hide/Show/Change the buttons that show line info in spectrum plot widget. """ - if self.parent.spectrum_span_button.text() == 'Show Wavelength Range': + if self.parent.spectrum_span_button.text() == "Show Wavelength Range": if not self.span: - self.span = self.ax.axvspan(left, right, color='r', alpha=0.3, - picker=self.span_picker) + self.span = self.ax.axvspan( + left, right, color="r", alpha=0.3, picker=self.span_picker + ) else: self.span.set_visible(True) self.parent.spectrum_line_info_button.show() - self.parent.spectrum_span_button.setText('Hide Wavelength Range') + self.parent.spectrum_span_button.setText("Hide Wavelength Range") else: self.span.set_visible(False) self.parent.spectrum_line_info_button.hide() - self.parent.spectrum_span_button.setText('Show Wavelength Range') + self.parent.spectrum_span_button.setText("Show Wavelength Range") self.draw() def on_span_pick(self, event): """Callback to 'pick'(grab with mouse) the span selector tool.""" self.figure.canvas.mpl_disconnect(self.cid[0]) - self.span.set_edgecolor('m') + self.span.set_edgecolor("m") self.span.set_linewidth(5) self.draw() - if event.edge == 'left': - self.cid[1] = self.figure.canvas.mpl_connect('motion_notify_event', - self.on_span_left_motion) - elif event.edge == 'right': - self.cid[1] = self.figure.canvas.mpl_connect('motion_notify_event', - self.on_span_right_motion) - self.cid[2] = self.figure.canvas.mpl_connect('button_press_event', - self.on_span_resized) + if event.edge == "left": + self.cid[1] = self.figure.canvas.mpl_connect( + "motion_notify_event", self.on_span_left_motion + ) + elif event.edge == "right": + self.cid[1] = self.figure.canvas.mpl_connect( + "motion_notify_event", self.on_span_right_motion + ) + self.cid[2] = self.figure.canvas.mpl_connect( + "button_press_event", self.on_span_resized + ) def on_span_left_motion(self, mouseevent): """Update data of span selector tool on left movement of mouse and @@ -122,9 +140,10 @@ def on_span_resized(self, mouseevent): """Redraw the red rectangle to currently selected span.""" self.figure.canvas.mpl_disconnect(self.cid[1]) self.figure.canvas.mpl_disconnect(self.cid[2]) - self.cid[0] = self.figure.canvas.mpl_connect('pick_event', - self.on_span_pick) - self.span.set_edgecolor('r') + self.cid[0] = self.figure.canvas.mpl_connect( + "pick_event", self.on_span_pick + ) + self.span.set_edgecolor("r") self.span.set_linewidth(1) self.draw() @@ -137,9 +156,9 @@ def highlight_shell(self, index): self.parent.tableview.selectRow(index) for i in range(len(self.parent.shells)): if i != index and i != index + 1: - self.parent.shells[i].set_edgecolor('k') + self.parent.shells[i].set_edgecolor("k") else: - self.parent.shells[i].set_edgecolor('w') + self.parent.shells[i].set_edgecolor("w") self.draw() def shell_picker(self, shell, mouseevent): @@ -158,33 +177,39 @@ def span_picker(self, span, mouseevent, tolerance=5): """ left = float(span.xy[0][0]) right = float(span.xy[2][0]) - tolerance = span.axes.transData.inverted().transform((tolerance, 0) - )[0] - span.axes.transData.inverted().transform((0, 0))[0] - event_attributes = {'edge': None} + tolerance = ( + span.axes.transData.inverted().transform((tolerance, 0))[0] + - span.axes.transData.inverted().transform((0, 0))[0] + ) + event_attributes = {"edge": None} if mouseevent.xdata is None: return False, event_attributes if left - tolerance <= mouseevent.xdata <= left + tolerance: - event_attributes['edge'] = 'left' + event_attributes["edge"] = "left" return True, event_attributes elif right - tolerance <= mouseevent.xdata <= right + tolerance: - event_attributes['edge'] = 'right' + event_attributes["edge"] = "right" return True, event_attributes return False, event_attributes + class Shell(matplotlib.patches.Wedge): """A data holder to store measurements of shells that will be drawn in the plot. """ + def __init__(self, index, center, r_inner, r_outer, **kwargs): - super(Shell, self).__init__(center, r_outer, 0, 90, - width=r_outer - r_inner, **kwargs) + super(Shell, self).__init__( + center, r_outer, 0, 90, width=r_outer - r_inner, **kwargs + ) self.index = index self.center = center self.r_outer = r_outer self.r_inner = r_inner self.width = r_outer - r_inner + class ConfigEditor(QtWidgets.QWidget): """The configuration editor widget. @@ -206,120 +231,151 @@ def __init__(self, yamlconfigfile, parent=None): """ super(ConfigEditor, self).__init__(parent) - #Configurations from the input and template + # Configurations from the input and template configDict = yaml.load(open(yamlconfigfile), Loader=yaml.CLoader) - templatedictionary ={'tardis_config_version':[True, 'v1.0'], - 'supernova':{ 'luminosity_requested':[True, '1 solLum'], - 'time_explosion':[True, None], - 'distance':[False, None], - 'luminosity_wavelength_start':[False, '0 angstrom'], - 'luminosity_wavelength_end':[False, 'inf angstrom'], - }, - 'atom_data':[True,'File Browser'], - 'plasma':{ 'initial_t_inner':[False, '-1K'], - 'initial_t_rad':[False,'10000K'], - 'disable_electron_scattering':[False, False], - 'ionization':[True, None], - 'excitation':[True, None], - 'radiative_rates_type':[True, None], - 'line_interaction_type':[True, None], - 'w_epsilon':[False, 1e-10], - 'delta_treatment':[False, None], - 'nlte':{ 'species':[False, []], - 'coronal_approximation':[False, False], - 'classical_nebular':[False, False] - } - }, - 'model':{ 'structure':{'type':[True, ['file|_:_|filename|_:_|' - 'filetype|_:_|v_inner_boundary|_:_|v_outer_boundary', - 'specific|_:_|velocity|_:_|density']], - 'filename':[True, None], - 'filetype':[True, None], - 'v_inner_boundary':[False, '0 km/s'], - 'v_outer_boundary':[False, 'inf km/s'], - 'velocity':[True, None], - 'density':{ 'type':[True, ['branch85_w7|_:_|w7_time_0' - '|_:_|w7_time_0|_:_|w7_time_0', - 'exponential|_:_|time_0|_:_|rho_0|_:_|' - 'v_0','power_law|_:_|time_0|_:_|rho_0' - '|_:_|v_0|_:_|exponent','uniform|_:_|value']], - 'w7_time_0':[False, '0.000231481 day'], - 'w7_rho_0':[False, '3e29 g/cm^3'], - 'w7_v_0': [False, '1 km/s'], - 'time_0':[True, None], - 'rho_0':[True, None], - 'v_0': [True, None], - 'exponent': [True, None], - 'value':[True, None] - } - }, - 'abundances':{ 'type':[True, ['file|_:_|filetype|_:_|' - 'filename', 'uniform']], - 'filename':[True, None], - 'filetype':[False, None] - } + templatedictionary = { + "tardis_config_version": [True, "v1.0"], + "supernova": { + "luminosity_requested": [True, "1 solLum"], + "time_explosion": [True, None], + "distance": [False, None], + "luminosity_wavelength_start": [False, "0 angstrom"], + "luminosity_wavelength_end": [False, "inf angstrom"], + }, + "atom_data": [True, "File Browser"], + "plasma": { + "initial_t_inner": [False, "-1K"], + "initial_t_rad": [False, "10000K"], + "disable_electron_scattering": [False, False], + "ionization": [True, None], + "excitation": [True, None], + "radiative_rates_type": [True, None], + "line_interaction_type": [True, None], + "w_epsilon": [False, 1e-10], + "delta_treatment": [False, None], + "nlte": { + "species": [False, []], + "coronal_approximation": [False, False], + "classical_nebular": [False, False], + }, + }, + "model": { + "structure": { + "type": [ + True, + [ + "file|_:_|filename|_:_|" + "filetype|_:_|v_inner_boundary|_:_|v_outer_boundary", + "specific|_:_|velocity|_:_|density", + ], + ], + "filename": [True, None], + "filetype": [True, None], + "v_inner_boundary": [False, "0 km/s"], + "v_outer_boundary": [False, "inf km/s"], + "velocity": [True, None], + "density": { + "type": [ + True, + [ + "branch85_w7|_:_|w7_time_0" + "|_:_|w7_time_0|_:_|w7_time_0", + "exponential|_:_|time_0|_:_|rho_0|_:_|" "v_0", + "power_law|_:_|time_0|_:_|rho_0" + "|_:_|v_0|_:_|exponent", + "uniform|_:_|value", + ], + ], + "w7_time_0": [False, "0.000231481 day"], + "w7_rho_0": [False, "3e29 g/cm^3"], + "w7_v_0": [False, "1 km/s"], + "time_0": [True, None], + "rho_0": [True, None], + "v_0": [True, None], + "exponent": [True, None], + "value": [True, None], + }, + }, + "abundances": { + "type": [ + True, + ["file|_:_|filetype|_:_|" "filename", "uniform"], + ], + "filename": [True, None], + "filetype": [False, None], + }, + }, + "montecarlo": { + "seed": [False, 23111963], + "no_of_packets": [True, None], + "iterations": [True, None], + "black_body_sampling": { + "start": "1 angstrom", + "stop": "1000000 angstrom", + "num": "1.e+6", + }, + "last_no_of_packets": [False, -1], + "no_of_virtual_packets": [False, 0], + "enable_reflective_inner_boundary": [False, False], + "inner_boundary_albedo": [False, 0.0], + "convergence_strategy": { + "type": [ + True, + [ + "damped|_:_|damping_constant|_:_|t_inner|_:_|" + "t_rad|_:_|w|_:_|lock_t_inner_cycles|_:_|" + "t_inner_update_exponent", + "specific|_:_|threshold" + "|_:_|fraction|_:_|hold_iterations|_:_|t_inner" + "|_:_|t_rad|_:_|w|_:_|lock_t_inner_cycles|_:_|" + "damping_constant|_:_|t_inner_update_exponent", + ], + ], + "t_inner_update_exponent": [False, -0.5], + "lock_t_inner_cycles": [False, 1], + "hold_iterations": [True, 3], + "fraction": [True, 0.8], + "damping_constant": [False, 0.5], + "threshold": [True, None], + "t_inner": { + "damping_constant": [False, 0.5], + "threshold": [False, None], + }, + "t_rad": { + "damping_constant": [False, 0.5], + "threshold": [True, None], + }, + "w": { + "damping_constant": [False, 0.5], + "threshold": [True, None], }, - 'montecarlo':{'seed':[False, 23111963], - 'no_of_packets':[True, None], - 'iterations':[True, None], - 'black_body_sampling':{ - 'start': '1 angstrom', - 'stop': '1000000 angstrom', - 'num': '1.e+6', - }, - 'last_no_of_packets':[False, -1], - 'no_of_virtual_packets':[False, 0], - 'enable_reflective_inner_boundary':[False, False], - 'inner_boundary_albedo':[False, 0.0], - 'convergence_strategy':{ 'type':[True, - ['damped|_:_|damping_constant|_:_|t_inner|_:_|' - 't_rad|_:_|w|_:_|lock_t_inner_cycles|_:_|' - 't_inner_update_exponent','specific|_:_|threshold' - '|_:_|fraction|_:_|hold_iterations|_:_|t_inner' - '|_:_|t_rad|_:_|w|_:_|lock_t_inner_cycles|_:_|' - 'damping_constant|_:_|t_inner_update_exponent']], - 't_inner_update_exponent':[False, -0.5], - 'lock_t_inner_cycles':[False, 1], - 'hold_iterations':[True, 3], - 'fraction':[True, 0.8], - 'damping_constant':[False, 0.5], - 'threshold':[True, None], - 't_inner':{ 'damping_constant':[False, 0.5], - 'threshold': [False, None] - }, - 't_rad':{'damping_constant':[False, 0.5], - 'threshold':[True, None] - }, - 'w':{'damping_constant': [False, 0.5], - 'threshold': [True, None] - } - } - }, - 'spectrum':[True, None] - } + }, + }, + "spectrum": [True, None], + } self.match_dicts(configDict, templatedictionary) self.layout = QtWidgets.QVBoxLayout() - #Make tree + # Make tree self.trmodel = TreeModel(templatedictionary) self.colView = QtWidgets.QColumnView() self.colView.setModel(self.trmodel) - #Five columns of width 256 each can be visible at once - self.colView.setFixedWidth(256*5) + # Five columns of width 256 each can be visible at once + self.colView.setFixedWidth(256 * 5) self.colView.setItemDelegate(TreeDelegate(self)) self.layout.addWidget(self.colView) - #Recalculate button - button = QtWidgets.QPushButton('Recalculate') + # Recalculate button + button = QtWidgets.QPushButton("Recalculate") button.setFixedWidth(90) self.layout.addWidget(button) button.clicked.connect(self.recalculate) - #Finally put them all in + # Finally put them all in self.setLayout(self.layout) - def match_dicts(self, dict1, dict2): #dict1<=dict2 + def match_dicts(self, dict1, dict2): # dict1<=dict2 """Compare and combine two dictionaries. If there are new keys in `dict1` then they are appended to `dict2`. @@ -357,13 +413,14 @@ def match_dicts(self, dict1, dict2): #dict1<=dict2 elif isinstance(dict2[key], list): if isinstance(dict2[key][1], list): - #options = dict2[key][1] #This is passed by reference. - #So copy the list manually. - options = [dict2[key][1][i] for i in range( - len(dict2[key][1]))] + # options = dict2[key][1] #This is passed by reference. + # So copy the list manually. + options = [ + dict2[key][1][i] for i in range(len(dict2[key][1])) + ] for i in range(len(options)): - options[i] = options[i].split('|_:_|')[0] + options[i] = options[i].split("|_:_|")[0] optionselected = dict1[key] @@ -374,13 +431,14 @@ def match_dicts(self, dict1, dict2): #dict1<=dict2 dict2[key][1][0] = dict2[key][1][indexofselected] dict2[key][1][indexofselected] = temp - else: - print('The selected and available options') + print("The selected and available options") print(optionselected) print(options) - raise IOError("An invalid option was" - " provided in the input file") + raise IOError( + "An invalid option was" + " provided in the input file" + ) else: dict2[key] = dict1[key] @@ -394,6 +452,7 @@ def recalculate(self): """ pass + class ModelViewer(QtWidgets.QWidget): """The widget that holds all the plots and tables that visualize the data in the tardis model. This is also appended to the stacked @@ -405,49 +464,55 @@ def __init__(self, tablecreator, parent=None): """Create all widgets that are children of ModelViewer.""" QtWidgets.QWidget.__init__(self, parent) - #Data structures + # Data structures self.model = None self.shell_info = {} self.line_info = [] - #functions + # functions self.createTable = tablecreator - #Shells widget + # Shells widget self.shellWidget = self.make_shell_widget() - #Spectrum widget + # Spectrum widget self.spectrumWidget = self.make_spectrum_widget() - #Plot tab widget + # Plot tab widget self.plotTabWidget = QtWidgets.QTabWidget() - self.plotTabWidget.addTab(self.shellWidget,"&Shells") + self.plotTabWidget.addTab(self.shellWidget, "&Shells") self.plotTabWidget.addTab(self.spectrumWidget, "S&pectrum") - #Table widget - self.tablemodel = self.createTable([['Shell: '], ["Rad. temp", "Ws", "V"]], - (1, 0)) + # Table widget + self.tablemodel = self.createTable( + [["Shell: "], ["Rad. temp", "Ws", "V"]], (1, 0) + ) self.tableview = QtWidgets.QTableView() self.tableview.setMinimumWidth(200) - self.sectionClicked=QtCore.Signal(int) - self.tableview.verticalHeader().sectionClicked.connect(self.graph.highlight_shell) + self.sectionClicked = QtCore.Signal(int) + self.tableview.verticalHeader().sectionClicked.connect( + self.graph.highlight_shell + ) - self.sectionDoubleClicked=QtCore.Signal(int) - self.tableview.verticalHeader().sectionDoubleClicked.connect(self.on_header_double_clicked) + self.sectionDoubleClicked = QtCore.Signal(int) + self.tableview.verticalHeader().sectionDoubleClicked.connect( + self.on_header_double_clicked + ) - #Label for text output + # Label for text output self.outputLabel = QtWidgets.QLabel() - self.outputLabel.setFrameStyle(QtWidgets.QFrame.StyledPanel | - QtWidgets.QFrame.Sunken) + self.outputLabel.setFrameStyle( + QtWidgets.QFrame.StyledPanel | QtWidgets.QFrame.Sunken + ) self.outputLabel.setStyleSheet("QLabel{background-color:white;}") - #Group boxes + # Group boxes graphsBox = QtWidgets.QGroupBox("Visualized results") textsBox = QtWidgets.QGroupBox("Model parameters") tableBox = QtWidgets.QGroupBox("Tabulated results") - #For textbox + # For textbox textlayout = QtWidgets.QHBoxLayout() textlayout.addWidget(self.outputLabel) @@ -474,19 +539,20 @@ def fill_output_label(self): quick user access. """ - labeltext = 'Iterations requested: {}
Iterations executed: {}
\ + labeltext = "Iterations requested: {}
Iterations executed: {}
\ Model converged : {}
Simulation Time : {} s
\ Inner Temperature : {} K
Number of packets : {}
\ - Inner Luminosity : {}'\ - .format(self.model.iterations, - self.model.iterations_executed, - 'True' if - self.model.converged else - 'False', - self.model.runner.time_of_simulation.value, - self.model.model.t_inner.value, - self.model.last_no_of_packets, - self.model.runner.calculate_luminosity_inner(self.model.model)) + Inner Luminosity : {}".format( + self.model.iterations, + self.model.iterations_executed, + 'True' + if self.model.converged + else 'False', + self.model.runner.time_of_simulation.value, + self.model.model.t_inner.value, + self.model.last_no_of_packets, + self.model.runner.calculate_luminosity_inner(self.model.model), + ) self.outputLabel.setText(labeltext) def make_shell_widget(self): @@ -494,19 +560,21 @@ def make_shell_widget(self): container widget. Return the container widget. """ - #Widgets for plot of shells - self.graph = MatplotlibWidget(self.createTable, self, 'model') - self.graph_label = QtWidgets.QLabel('Select Property:') + # Widgets for plot of shells + self.graph = MatplotlibWidget(self.createTable, self, "model") + self.graph_label = QtWidgets.QLabel("Select Property:") self.graph_button = QtWidgets.QToolButton() - self.graph_button.setText('Rad. temp') + self.graph_button.setText("Rad. temp") self.graph_button.setPopupMode(QtWidgets.QToolButton.MenuButtonPopup) self.graph_button.setMenu(QtWidgets.QMenu(self.graph_button)) - self.graph_button.menu().addAction('Rad. temp').triggered.connect( - self.change_graph_to_t_rads) - self.graph_button.menu().addAction('Ws').triggered.connect( - self.change_graph_to_ws) - - #Layouts: bottom up + self.graph_button.menu().addAction("Rad. temp").triggered.connect( + self.change_graph_to_t_rads + ) + self.graph_button.menu().addAction("Ws").triggered.connect( + self.change_graph_to_ws + ) + + # Layouts: bottom up self.graph_subsublayout = QtWidgets.QHBoxLayout() self.graph_subsublayout.addWidget(self.graph_label) self.graph_subsublayout.addWidget(self.graph_button) @@ -525,20 +593,26 @@ def make_spectrum_widget(self): """ self.spectrum = MatplotlibWidget(self.createTable, self) - self.spectrum_label = QtWidgets.QLabel('Select Spectrum:') + self.spectrum_label = QtWidgets.QLabel("Select Spectrum:") self.spectrum_button = QtWidgets.QToolButton() - self.spectrum_button.setText('spec_flux_angstrom') + self.spectrum_button.setText("spec_flux_angstrom") self.spectrum_button.setPopupMode(QtWidgets.QToolButton.MenuButtonPopup) self.spectrum_button.setMenu(QtWidgets.QMenu(self.spectrum_button)) - self.spectrum_button.menu().addAction('spec_flux_angstrom' - ).triggered.connect(self.change_spectrum_to_spec_flux_angstrom) - self.spectrum_button.menu().addAction('spec_virtual_flux_angstrom' - ).triggered.connect(self.change_spectrum_to_spec_virtual_flux_angstrom) - self.spectrum_span_button = QtWidgets.QPushButton('Show Wavelength Range') + self.spectrum_button.menu().addAction( + "spec_flux_angstrom" + ).triggered.connect(self.change_spectrum_to_spec_flux_angstrom) + self.spectrum_button.menu().addAction( + "spec_virtual_flux_angstrom" + ).triggered.connect(self.change_spectrum_to_spec_virtual_flux_angstrom) + self.spectrum_span_button = QtWidgets.QPushButton( + "Show Wavelength Range" + ) self.spectrum_span_button.clicked.connect(self.spectrum.show_span) - self.spectrum_line_info_button = QtWidgets.QPushButton('Show Line Info') + self.spectrum_line_info_button = QtWidgets.QPushButton("Show Line Info") self.spectrum_line_info_button.hide() - self.spectrum_line_info_button.clicked.connect(self.spectrum.show_line_info) + self.spectrum_line_info_button.clicked.connect( + self.spectrum.show_line_info + ) self.spectrum_subsublayout = QtWidgets.QHBoxLayout() self.spectrum_subsublayout.addWidget(self.spectrum_span_button) @@ -563,10 +637,10 @@ def update_data(self, model=None): for index in self.shell_info.keys(): self.shell_info[index].update_tables() self.plot_model() - if self.graph_button.text == 'Ws': + if self.graph_button.text == "Ws": self.change_graph_to_ws() self.plot_spectrum() - if self.spectrum_button.text == 'spec_virtual_flux_angstrom': + if self.spectrum_button.text == "spec_virtual_flux_angstrom": self.change_spectrum_to_spec_virtual_flux_angstrom() self.show() @@ -582,24 +656,28 @@ def change_spectrum_to_spec_virtual_flux_angstrom(self): """Change the spectrum data to the virtual spectrum.""" if self.model.runner.spectrum_virtual.luminosity_density_lambda is None: luminosity_density_lambda = np.zeros_like( - self.model.runner.spectrum_virtual.wavelength) + self.model.runner.spectrum_virtual.wavelength + ) else: - luminosity_density_lambda = \ - self.model.runner.spectrum_virtual.luminosity_density_lambda.value + luminosity_density_lambda = ( + self.model.runner.spectrum_virtual.luminosity_density_lambda.value + ) - self.change_spectrum(luminosity_density_lambda, 'spec_flux_angstrom') + self.change_spectrum(luminosity_density_lambda, "spec_flux_angstrom") def change_spectrum_to_spec_flux_angstrom(self): """Change spectrum data back from virtual spectrum. (See the method above).""" if self.model.runner.spectrum.luminosity_density_lambda is None: luminosity_density_lambda = np.zeros_like( - self.model.runner.spectrum.wavelength) + self.model.runner.spectrum.wavelength + ) else: - luminosity_density_lambda = \ - self.model.runner.spectrum.luminosity_density_lambda.value + luminosity_density_lambda = ( + self.model.runner.spectrum.luminosity_density_lambda.value + ) - self.change_spectrum(luminosity_density_lambda, 'spec_flux_angstrom') + self.change_spectrum(luminosity_density_lambda, "spec_flux_angstrom") def change_spectrum(self, data, name): """Replot the spectrum plot using the data provided. Called @@ -615,27 +693,29 @@ def change_spectrum(self, data, name): def plot_spectrum(self): """Plot the spectrum and add labels to the graph.""" self.spectrum.ax.clear() - self.spectrum.ax.set_title('Spectrum') - self.spectrum.ax.set_xlabel('Wavelength (A)') - self.spectrum.ax.set_ylabel('Intensity') + self.spectrum.ax.set_title("Spectrum") + self.spectrum.ax.set_xlabel("Wavelength (A)") + self.spectrum.ax.set_ylabel("Intensity") wavelength = self.model.runner.spectrum.wavelength.value if self.model.runner.spectrum.luminosity_density_lambda is None: luminosity_density_lambda = np.zeros_like(wavelength) else: - luminosity_density_lambda =\ - self.model.runner.spectrum.luminosity_density_lambda.value + luminosity_density_lambda = ( + self.model.runner.spectrum.luminosity_density_lambda.value + ) - self.spectrum.dataplot = self.spectrum.ax.plot(wavelength, - luminosity_density_lambda, label='b') + self.spectrum.dataplot = self.spectrum.ax.plot( + wavelength, luminosity_density_lambda, label="b" + ) self.spectrum.draw() def change_graph_to_ws(self): """Change the shell plot to show dilution factor.""" - self.change_graph(self.model.model.w, 'Ws', '') + self.change_graph(self.model.model.w, "Ws", "") def change_graph_to_t_rads(self): """Change the graph back to radiation Temperature.""" - self.change_graph(self.model.model.t_rad.value, 't_rad', '(K)') + self.change_graph(self.model.model.t_rad.value, "t_rad", "(K)") def change_graph(self, data, name, unit): """Called to change the shell plot by the two methods above.""" @@ -643,15 +723,15 @@ def change_graph(self, data, name, unit): self.graph.dataplot[0].set_ydata(data) self.graph.ax1.relim() self.graph.ax1.autoscale() - self.graph.ax1.set_title(name + ' vs Shell') - self.graph.ax1.set_ylabel(name + ' ' + unit) + self.graph.ax1.set_title(name + " vs Shell") + self.graph.ax1.set_ylabel(name + " " + unit) normalizer = colors.Normalize(vmin=data.min(), vmax=data.max()) color_map = plt.cm.ScalarMappable(norm=normalizer, cmap=plt.cm.jet) color_map.set_array(data) self.graph.cb.set_clim(vmin=data.min(), vmax=data.max()) self.graph.cb.update_normal(color_map) - if unit == '(K)': - unit = 'T (K)' + if unit == "(K)": + unit = "T (K)" self.graph.cb.set_label(unit) for i, item in enumerate(data): self.shells[i].set_facecolor(color_map.to_rgba(item)) @@ -663,52 +743,76 @@ def plot_model(self): """ self.graph.ax1.clear() - self.graph.ax1.set_title('Rad. Temp vs Shell') - self.graph.ax1.set_xlabel('Shell Number') - self.graph.ax1.set_ylabel('Rad. Temp (K)') + self.graph.ax1.set_title("Rad. Temp vs Shell") + self.graph.ax1.set_xlabel("Shell Number") + self.graph.ax1.set_ylabel("Rad. Temp (K)") self.graph.ax1.yaxis.get_major_formatter().set_powerlimits((0, 1)) self.graph.dataplot = self.graph.ax1.plot( - range(len(self.model.model.t_rad.value)), self.model.model.t_rad.value) + range(len(self.model.model.t_rad.value)), + self.model.model.t_rad.value, + ) self.graph.ax2.clear() - self.graph.ax2.set_title('Shell View') + self.graph.ax2.set_title("Shell View") self.graph.ax2.set_xticklabels([]) self.graph.ax2.set_yticklabels([]) self.graph.ax2.grid = True self.shells = [] - t_rad_normalizer = colors.Normalize(vmin=self.model.model.t_rad.value.min(), - vmax=self.model.model.t_rad.value.max()) - t_rad_color_map = plt.cm.ScalarMappable(norm=t_rad_normalizer, - cmap=plt.cm.jet) + t_rad_normalizer = colors.Normalize( + vmin=self.model.model.t_rad.value.min(), + vmax=self.model.model.t_rad.value.max(), + ) + t_rad_color_map = plt.cm.ScalarMappable( + norm=t_rad_normalizer, cmap=plt.cm.jet + ) t_rad_color_map.set_array(self.model.model.t_rad.value) if self.graph.cb: - self.graph.cb.set_clim(vmin=self.model.model.t_rad.value.min(), - vmax=self.model.model.t_rad.value.max()) + self.graph.cb.set_clim( + vmin=self.model.model.t_rad.value.min(), + vmax=self.model.model.t_rad.value.max(), + ) self.graph.cb.update_normal(t_rad_color_map) else: self.graph.cb = self.graph.figure.colorbar(t_rad_color_map) - self.graph.cb.set_label('T (K)') - self.graph.normalizing_factor = 0.2 * ( - self.model.model.r_outer.value[-1] - - self.model.model.r_inner.value[0]) / ( - self.model.model.r_inner.value[0]) - - #self.graph.normalizing_factor = 8e-16 + self.graph.cb.set_label("T (K)") + self.graph.normalizing_factor = ( + 0.2 + * ( + self.model.model.r_outer.value[-1] + - self.model.model.r_inner.value[0] + ) + / (self.model.model.r_inner.value[0]) + ) + + # self.graph.normalizing_factor = 8e-16 for i, t_rad in enumerate(self.model.model.t_rad.value): - r_inner = (self.model.model.r_inner.value[i] * - self.graph.normalizing_factor) - r_outer = (self.model.model.r_outer.value[i] * - self.graph.normalizing_factor) - self.shells.append(Shell(i, (0,0), r_inner, r_outer, - facecolor=t_rad_color_map.to_rgba(t_rad), - picker=self.graph.shell_picker)) + r_inner = ( + self.model.model.r_inner.value[i] + * self.graph.normalizing_factor + ) + r_outer = ( + self.model.model.r_outer.value[i] + * self.graph.normalizing_factor + ) + self.shells.append( + Shell( + i, + (0, 0), + r_inner, + r_outer, + facecolor=t_rad_color_map.to_rgba(t_rad), + picker=self.graph.shell_picker, + ) + ) self.graph.ax2.add_patch(self.shells[i]) - self.graph.ax2.set_xlim(0, - self.model.model.r_outer.value[-1] * - self.graph.normalizing_factor) - self.graph.ax2.set_ylim(0, - self.model.model.r_outer.value[-1] * - self.graph.normalizing_factor) + self.graph.ax2.set_xlim( + 0, + self.model.model.r_outer.value[-1] * self.graph.normalizing_factor, + ) + self.graph.ax2.set_ylim( + 0, + self.model.model.r_outer.value[-1] * self.graph.normalizing_factor, + ) self.graph.figure.tight_layout() self.graph.draw() @@ -716,6 +820,7 @@ def on_header_double_clicked(self, index): """Callback to get counts for different Z from table.""" self.shell_info[index] = ShellInfo(index, self.createTable, self) + class ShellInfo(QtWidgets.QDialog): """Dialog to display Shell abundances.""" @@ -727,19 +832,21 @@ def __init__(self, index, tablecreator, parent=None): self.parent = parent self.shell_index = index self.setGeometry(400, 150, 200, 400) - self.setWindowTitle('Shell %d Abundances' % (self.shell_index + 1)) + self.setWindowTitle("Shell %d Abundances" % (self.shell_index + 1)) self.atomstable = QtWidgets.QTableView() self.ionstable = QtWidgets.QTableView() self.levelstable = QtWidgets.QTableView() - self.sectionClicked=QtCore.Signal(int) - self.atomstable.verticalHeader().sectionClicked.connect(self.on_atom_header_double_clicked) - - - self.table1_data = self.parent.model.plasma.abundance[ - self.shell_index] - self.atomsdata = self.createTable([['Z = '], ['Count (Shell %d)' % ( - self.shell_index + 1)]], iterate_header=(2, 0), - index_info=self.table1_data.index.values.tolist()) + self.sectionClicked = QtCore.Signal(int) + self.atomstable.verticalHeader().sectionClicked.connect( + self.on_atom_header_double_clicked + ) + + self.table1_data = self.parent.model.plasma.abundance[self.shell_index] + self.atomsdata = self.createTable( + [["Z = "], ["Count (Shell %d)" % (self.shell_index + 1)]], + iterate_header=(2, 0), + index_info=self.table1_data.index.values.tolist(), + ) self.ionsdata = None self.levelsdata = None self.atomsdata.add_data(self.table1_data.values.tolist()) @@ -759,22 +866,30 @@ def on_atom_header_double_clicked(self, index): ion populations.""" self.current_atom_index = self.table1_data.index.values.tolist()[index] self.table2_data = self.parent.model.plasma.ion_number_density[ - self.shell_index].ix[self.current_atom_index] - self.ionsdata = self.createTable([['Ion: '], - ['Count (Z = %d)' % self.current_atom_index]], + self.shell_index + ].ix[self.current_atom_index] + self.ionsdata = self.createTable( + [["Ion: "], ["Count (Z = %d)" % self.current_atom_index]], iterate_header=(2, 0), - index_info=self.table2_data.index.values.tolist()) + index_info=self.table2_data.index.values.tolist(), + ) normalized_data = [] for item in self.table2_data.values: - normalized_data.append(float(item / - self.parent.model.plasma.number_density[self.shell_index] - .ix[self.current_atom_index])) - + normalized_data.append( + float( + item + / self.parent.model.plasma.number_density[ + self.shell_index + ].ix[self.current_atom_index] + ) + ) self.ionsdata.add_data(normalized_data) self.ionstable.setModel(self.ionsdata) - self.sectionClicked=QtCore.Signal(int) - self.ionstable.verticalHeader().sectionClicked.connect(self.on_ion_header_double_clicked) + self.sectionClicked = QtCore.Signal(int) + self.ionstable.verticalHeader().sectionClicked.connect( + self.on_ion_header_double_clicked + ) self.levelstable.hide() self.ionstable.setColumnWidth(0, 120) self.ionstable.show() @@ -785,15 +900,18 @@ def on_ion_header_double_clicked(self, index): """Called on double click of ion headers to show level populations.""" self.current_ion_index = self.table2_data.index.values.tolist()[index] self.table3_data = self.parent.model.plasma.level_number_density[ - self.shell_index].ix[self.current_atom_index, self.current_ion_index] - self.levelsdata = self.createTable([['Level: '], - ['Count (Ion %d)' % self.current_ion_index]], + self.shell_index + ].ix[self.current_atom_index, self.current_ion_index] + self.levelsdata = self.createTable( + [["Level: "], ["Count (Ion %d)" % self.current_ion_index]], iterate_header=(2, 0), - index_info=self.table3_data.index.values.tolist()) + index_info=self.table3_data.index.values.tolist(), + ) normalized_data = [] for item in self.table3_data.values.tolist(): - normalized_data.append(float(item / - self.table2_data.ix[self.current_ion_index])) + normalized_data.append( + float(item / self.table2_data.ix[self.current_ion_index]) + ) self.levelsdata.add_data(normalized_data) self.levelstable.setModel(self.levelsdata) self.levelstable.setColumnWidth(0, 120) @@ -804,8 +922,9 @@ def on_ion_header_double_clicked(self, index): def update_tables(self): """Update table data for shell info viewer.""" self.table1_data = self.parent.model.plasma.number_density[ - self.shell_index] - self.atomsdata.index_info=self.table1_data.index.values.tolist() + self.shell_index + ] + self.atomsdata.index_info = self.table1_data.index.values.tolist() self.atomsdata.arraydata = [] self.atomsdata.add_data(self.table1_data.values.tolist()) self.atomsdata.update_table() @@ -814,8 +933,10 @@ def update_tables(self): self.setGeometry(400, 150, 200, 400) self.show() + class LineInfo(QtWidgets.QDialog): """Dialog to show the line info used by spectrum widget.""" + def __init__(self, parent, wavelength_start, wavelength_end, tablecreator): """Create the dialog and set data in it from the model. Show widget.""" @@ -823,28 +944,47 @@ def __init__(self, parent, wavelength_start, wavelength_end, tablecreator): self.createTable = tablecreator self.parent = parent self.setGeometry(180 + len(self.parent.line_info) * 20, 150, 250, 400) - self.setWindowTitle('Line Interaction: %.2f - %.2f (A) ' % ( - wavelength_start, wavelength_end,)) + self.setWindowTitle( + "Line Interaction: %.2f - %.2f (A) " + % (wavelength_start, wavelength_end,) + ) self.layout = QtWidgets.QVBoxLayout() packet_nu_line_interaction = analysis.LastLineInteraction.from_model( - self.parent.model) - packet_nu_line_interaction.packet_filter_mode = 'packet_nu' - packet_nu_line_interaction.wavelength_start = wavelength_start * u.angstrom + self.parent.model + ) + packet_nu_line_interaction.packet_filter_mode = "packet_nu" + packet_nu_line_interaction.wavelength_start = ( + wavelength_start * u.angstrom + ) packet_nu_line_interaction.wavelength_end = wavelength_end * u.angstrom line_in_nu_line_interaction = analysis.LastLineInteraction.from_model( - self.parent.model) - line_in_nu_line_interaction.packet_filter_mode = 'line_in_nu' - line_in_nu_line_interaction.wavelength_start = wavelength_start * u.angstrom + self.parent.model + ) + line_in_nu_line_interaction.packet_filter_mode = "line_in_nu" + line_in_nu_line_interaction.wavelength_start = ( + wavelength_start * u.angstrom + ) line_in_nu_line_interaction.wavelength_end = wavelength_end * u.angstrom - - self.layout.addWidget(LineInteractionTables(packet_nu_line_interaction, - self.parent.model.plasma.atomic_data.atom_data, self.parent.model.plasma.lines, 'filtered by frequency of packet', - self.createTable)) - self.layout.addWidget(LineInteractionTables(line_in_nu_line_interaction, - self.parent.model.plasma.atomic_data.atom_data, self.parent.model.plasma.lines, - 'filtered by frequency of line interaction', self.createTable)) + self.layout.addWidget( + LineInteractionTables( + packet_nu_line_interaction, + self.parent.model.plasma.atomic_data.atom_data, + self.parent.model.plasma.lines, + "filtered by frequency of packet", + self.createTable, + ) + ) + self.layout.addWidget( + LineInteractionTables( + line_in_nu_line_interaction, + self.parent.model.plasma.atomic_data.atom_data, + self.parent.model.plasma.lines, + "filtered by frequency of line interaction", + self.createTable, + ) + ) self.setLayout(self.layout) self.show() @@ -856,35 +996,51 @@ def get_data(self, wavelength_start, wavelength_end): """ self.wavelength_start = wavelength_start * u.angstrom self.wavelength_end = wavelength_end * u.angstrom - last_line_in_ids, last_line_out_ids = analysis.get_last_line_interaction( - self.wavelength_start, self.wavelength_end, self.parent.model) + ( + last_line_in_ids, + last_line_out_ids, + ) = analysis.get_last_line_interaction( + self.wavelength_start, self.wavelength_end, self.parent.model + ) self.last_line_in, self.last_line_out = ( self.parent.model.atom_data.lines.ix[last_line_in_ids], - self.parent.model.atom_data.lines.ix[last_line_out_ids]) - self.grouped_lines_in, self.grouped_lines_out = (self.last_line_in.groupby( - ['atomic_number', 'ion_number']), - self.last_line_out.groupby(['atomic_number', 'ion_number'])) - self.ions_in, self.ions_out = (self.grouped_lines_in.groups.keys(), - self.grouped_lines_out.groups.keys()) + self.parent.model.atom_data.lines.ix[last_line_out_ids], + ) + self.grouped_lines_in, self.grouped_lines_out = ( + self.last_line_in.groupby(["atomic_number", "ion_number"]), + self.last_line_out.groupby(["atomic_number", "ion_number"]), + ) + self.ions_in, self.ions_out = ( + self.grouped_lines_in.groups.keys(), + self.grouped_lines_out.groups.keys(), + ) self.ions_in.sort() self.ions_out.sort() self.header_list = [] - self.ion_table = (self.grouped_lines_in.wavelength.count().astype(float) / - self.grouped_lines_in.wavelength.count().sum()).values.tolist() + self.ion_table = ( + self.grouped_lines_in.wavelength.count().astype(float) + / self.grouped_lines_in.wavelength.count().sum() + ).values.tolist() for z, ion in self.ions_in: - self.header_list.append('Z = %d: Ion %d' % (z, ion)) + self.header_list.append("Z = %d: Ion %d" % (z, ion)) def get_transition_table(self, lines, atom, ion): """Called by the two methods below to get transition table for given lines, atom and ions. """ - grouped = lines.groupby(['atomic_number', 'ion_number']) - transitions_with_duplicates = lines.ix[grouped.groups[(atom, ion)] - ].groupby(['level_number_lower', 'level_number_upper']).groups - transitions = lines.ix[grouped.groups[(atom, ion)] - ].drop_duplicates().groupby(['level_number_lower', - 'level_number_upper']).groups + grouped = lines.groupby(["atomic_number", "ion_number"]) + transitions_with_duplicates = ( + lines.ix[grouped.groups[(atom, ion)]] + .groupby(["level_number_lower", "level_number_upper"]) + .groups + ) + transitions = ( + lines.ix[grouped.groups[(atom, ion)]] + .drop_duplicates() + .groupby(["level_number_lower", "level_number_upper"]) + .groups + ) transitions_count = [] transitions_parsed = [] for item in transitions.values(): @@ -898,8 +1054,16 @@ def get_transition_table(self, lines, atom, ion): for index in range(len(transitions_count)): transitions_count[index] /= float(s) for key, value in transitions.items(): - transitions_parsed.append("%d-%d (%.2f A)" % (key[0], key[1], - self.parent.model.atom_data.lines.ix[value[0]]['wavelength'])) + transitions_parsed.append( + "%d-%d (%.2f A)" + % ( + key[0], + key[1], + self.parent.model.atom_data.lines.ix[value[0]][ + "wavelength" + ], + ) + ) return transitions_parsed, transitions_count def on_atom_clicked(self, index): @@ -907,16 +1071,24 @@ def on_atom_clicked(self, index): dialog created by the spectrum widget. """ - self.transitionsin_parsed, self.transitionsin_count = ( - self.get_transition_table(self.last_line_in, - self.ions_in[index][0], self.ions_in[index][1])) - self.transitionsout_parsed, self.transitionsout_count = ( - self.get_transition_table(self.last_line_out, - self.ions_out[index][0], self.ions_out[index][1])) - self.transitionsindata = self.createTable([self.transitionsin_parsed, - ['Lines In']]) - self.transitionsoutdata = self.createTable([self.transitionsout_parsed, - ['Lines Out']]) + ( + self.transitionsin_parsed, + self.transitionsin_count, + ) = self.get_transition_table( + self.last_line_in, self.ions_in[index][0], self.ions_in[index][1] + ) + ( + self.transitionsout_parsed, + self.transitionsout_count, + ) = self.get_transition_table( + self.last_line_out, self.ions_out[index][0], self.ions_out[index][1] + ) + self.transitionsindata = self.createTable( + [self.transitionsin_parsed, ["Lines In"]] + ) + self.transitionsoutdata = self.createTable( + [self.transitionsout_parsed, ["Lines Out"]] + ) self.transitionsindata.add_data(self.transitionsin_count) self.transitionsoutdata.add_data(self.transitionsout_count) self.transitionsintable.setModel(self.transitionsindata) @@ -931,16 +1103,24 @@ def on_atom_clicked2(self, index): dialog created by the spectrum widget. """ - self.transitionsin_parsed, self.transitionsin_count = ( - self.get_transition_table(self.last_line_in, self.ions_in[index][0], - self.ions_in[index][1])) - self.transitionsout_parsed, self.transitionsout_count = ( - self.get_transition_table(self.last_line_out, - self.ions_out[index][0], self.ions_out[index][1])) - self.transitionsindata = self.createTable([self.transitionsin_parsed, - ['Lines In']]) - self.transitionsoutdata = self.createTable([self.transitionsout_parsed, - ['Lines Out']]) + ( + self.transitionsin_parsed, + self.transitionsin_count, + ) = self.get_transition_table( + self.last_line_in, self.ions_in[index][0], self.ions_in[index][1] + ) + ( + self.transitionsout_parsed, + self.transitionsout_count, + ) = self.get_transition_table( + self.last_line_out, self.ions_out[index][0], self.ions_out[index][1] + ) + self.transitionsindata = self.createTable( + [self.transitionsin_parsed, ["Lines In"]] + ) + self.transitionsoutdata = self.createTable( + [self.transitionsout_parsed, ["Lines Out"]] + ) self.transitionsindata.add_data(self.transitionsin_count) self.transitionsoutdata.add_data(self.transitionsout_count) self.transitionsintable2.setModel(self.transitionsindata) @@ -950,14 +1130,21 @@ def on_atom_clicked2(self, index): self.setGeometry(180 + len(self.parent.line_info) * 20, 150, 750, 400) self.show() + class LineInteractionTables(QtWidgets.QWidget): """Widget to hold the line interaction tables used by LineInfo which in turn is used by spectrum widget. """ - def __init__(self, line_interaction_analysis, atom_data, lines_data, description, - tablecreator): + def __init__( + self, + line_interaction_analysis, + atom_data, + lines_data, + description, + tablecreator, + ): """Create the widget and set data.""" super(LineInteractionTables, self).__init__() self.createTable = tablecreator @@ -967,17 +1154,26 @@ def __init__(self, line_interaction_analysis, atom_data, lines_data, description self.layout = QtWidgets.QHBoxLayout() self.line_interaction_analysis = line_interaction_analysis self.atom_data = atom_data - self.lines_data = lines_data.reset_index().set_index('line_id') - line_interaction_species_group = \ - line_interaction_analysis.last_line_in.groupby(['atomic_number', - 'ion_number']) + self.lines_data = lines_data.reset_index().set_index("line_id") + line_interaction_species_group = line_interaction_analysis.last_line_in.groupby( + ["atomic_number", "ion_number"] + ) self.species_selected = sorted( - line_interaction_species_group.groups.keys()) - species_symbols = [tardis.util.base.species_tuple_to_string(item) for item in self.species_selected] - species_table_model = self.createTable([species_symbols, ['Species']]) + line_interaction_species_group.groups.keys() + ) + species_symbols = [ + tardis.util.base.species_tuple_to_string(item) + for item in self.species_selected + ] + species_table_model = self.createTable([species_symbols, ["Species"]]) species_abundances = ( - line_interaction_species_group.wavelength.count().astype(float) / - line_interaction_analysis.last_line_in.wavelength.count()).astype(float).tolist() + ( + line_interaction_species_group.wavelength.count().astype(float) + / line_interaction_analysis.last_line_in.wavelength.count() + ) + .astype(float) + .tolist() + ) species_abundances = list(map(float, species_abundances)) species_table_model.add_data(species_abundances) self.species_table.setModel(species_table_model) @@ -985,8 +1181,10 @@ def __init__(self, line_interaction_analysis, atom_data, lines_data, description line_interaction_species_group.wavelength.count() self.layout.addWidget(self.text_description) self.layout.addWidget(self.species_table) - self.sectionClicked=QtCore.Signal(int) - self.species_table.verticalHeader().sectionClicked.connect(self.on_species_clicked) + self.sectionClicked = QtCore.Signal(int) + self.species_table.verticalHeader().sectionClicked.connect( + self.on_species_clicked + ) self.layout.addWidget(self.transitions_table) self.setLayout(self.layout) @@ -999,41 +1197,54 @@ def on_species_clicked(self, index): last_line_out = self.line_interaction_analysis.last_line_out current_last_line_in = last_line_in.xs( - key=(current_species[0], current_species[1]), - level=['atomic_number', 'ion_number'], - drop_level=False).reset_index() + key=(current_species[0], current_species[1]), + level=["atomic_number", "ion_number"], + drop_level=False, + ).reset_index() current_last_line_out = last_line_out.xs( - key=(current_species[0], current_species[1]), - level=['atomic_number', 'ion_number'], - drop_level=False).reset_index() - - current_last_line_in['line_id_out'] = current_last_line_out.line_id + key=(current_species[0], current_species[1]), + level=["atomic_number", "ion_number"], + drop_level=False, + ).reset_index() + current_last_line_in["line_id_out"] = current_last_line_out.line_id last_line_in_string = [] last_line_count = [] grouped_line_interactions = current_last_line_in.groupby( - ['line_id', 'line_id_out']) - exc_deexc_string = 'exc. %d-%d (%.2f A) de-exc. %d-%d (%.2f A)' - - for line_id, row in grouped_line_interactions.wavelength.count().iteritems(): + ["line_id", "line_id_out"] + ) + exc_deexc_string = "exc. %d-%d (%.2f A) de-exc. %d-%d (%.2f A)" + + for ( + line_id, + row, + ) in grouped_line_interactions.wavelength.count().iteritems(): current_line_in = self.lines_data.loc[line_id[0]] current_line_out = self.lines_data.loc[line_id[1]] - last_line_in_string.append(exc_deexc_string % ( - current_line_in['level_number_lower'], - current_line_in['level_number_upper'], - current_line_in['wavelength'], - current_line_out['level_number_upper'], - current_line_out['level_number_lower'], - current_line_out['wavelength'])) + last_line_in_string.append( + exc_deexc_string + % ( + current_line_in["level_number_lower"], + current_line_in["level_number_upper"], + current_line_in["wavelength"], + current_line_out["level_number_upper"], + current_line_out["level_number_lower"], + current_line_out["wavelength"], + ) + ) last_line_count.append(int(row)) - - last_line_in_model = self.createTable([last_line_in_string, [ - 'Num. pkts %d' % current_last_line_in.wavelength.count()]]) + last_line_in_model = self.createTable( + [ + last_line_in_string, + ["Num. pkts %d" % current_last_line_in.wavelength.count()], + ] + ) last_line_in_model.add_data(last_line_count) self.transitions_table.setModel(last_line_in_model) + class Tardis(QtWidgets.QMainWindow): """Create the top level window for the GUI and wait for call to display data. @@ -1065,8 +1276,8 @@ def __init__(self, tablemodel, config=None, atom_data=None, parent=None): """ - #assumes that qt has already been initialized by starting IPython - #with the flag "--pylab=qt"gut + # assumes that qt has already been initialized by starting IPython + # with the flag "--pylab=qt"gut # app = QtCore.QCoreApplication.instance() # if app is None: # app = QtGui.QApplication([]) @@ -1078,48 +1289,52 @@ def __init__(self, tablemodel, config=None, atom_data=None, parent=None): QtWidgets.QMainWindow.__init__(self, parent) - #path to icons folder - self.path = os.path.join(tardis.__path__[0],'gui','images') + # path to icons folder + self.path = os.path.join(tardis.__path__[0], "gui", "images") - #Check if configuration file was provided - self.mode = 'passive' + # Check if configuration file was provided + self.mode = "passive" if config is not None: - self.mode = 'active' + self.mode = "active" - #Statusbar + # Statusbar statusbr = self.statusBar() lblstr = 'Calculation did not converge' self.successLabel = QtWidgets.QLabel(lblstr) - self.successLabel.setFrameStyle(QtWidgets.QFrame.StyledPanel | - QtWidgets.QFrame.Sunken) + self.successLabel.setFrameStyle( + QtWidgets.QFrame.StyledPanel | QtWidgets.QFrame.Sunken + ) statusbr.addPermanentWidget(self.successLabel) - self.modeLabel = QtWidgets.QLabel('Passive mode') + self.modeLabel = QtWidgets.QLabel("Passive mode") statusbr.addPermanentWidget(self.modeLabel) statusbr.showMessage(self.mode, 5000) statusbr.showMessage("Ready", 5000) - #Actions + # Actions quitAction = QtWidgets.QAction("&Quit", self) - quitAction.setIcon(QtGui.QIcon(os.path.join(self.path, - 'closeicon.png'))) + quitAction.setIcon( + QtGui.QIcon(os.path.join(self.path, "closeicon.png")) + ) quitAction.triggered.connect(self.close) self.viewMdv = QtWidgets.QAction("View &Model", self) - self.viewMdv.setIcon(QtGui.QIcon(os.path.join(self.path, - 'mdvswitch.png'))) + self.viewMdv.setIcon( + QtGui.QIcon(os.path.join(self.path, "mdvswitch.png")) + ) self.viewMdv.setCheckable(True) self.viewMdv.setChecked(True) self.viewMdv.setEnabled(False) self.viewMdv.triggered.connect(self.switch_to_mdv) self.viewForm = QtWidgets.QAction("&Edit Model", self) - self.viewForm.setIcon(QtGui.QIcon(os.path.join(self.path, - 'formswitch.png'))) + self.viewForm.setIcon( + QtGui.QIcon(os.path.join(self.path, "formswitch.png")) + ) self.viewForm.setCheckable(True) self.viewForm.setEnabled(False) self.viewForm.triggered.connect(self.switch_to_form) - #Menubar + # Menubar self.fileMenu = self.menuBar().addMenu("&File") self.fileMenu.addAction(quitAction) self.viewMenu = self.menuBar().addMenu("&View") @@ -1127,7 +1342,7 @@ def __init__(self, tablemodel, config=None, atom_data=None, parent=None): self.viewMenu.addAction(self.viewForm) self.helpMenu = self.menuBar().addMenu("&Help") - #Toolbar + # Toolbar fileToolbar = self.addToolBar("File") fileToolbar.setObjectName("FileToolBar") fileToolbar.addAction(quitAction) @@ -1137,14 +1352,14 @@ def __init__(self, tablemodel, config=None, atom_data=None, parent=None): viewToolbar.addAction(self.viewMdv) viewToolbar.addAction(self.viewForm) - #Central Widget + # Central Widget self.stackedWidget = QtWidgets.QStackedWidget() self.mdv = ModelViewer(tablemodel) self.stackedWidget.addWidget(self.mdv) - #In case of active mode - if self.mode == 'active': - #Disabled currently + # In case of active mode + if self.mode == "active": + # Disabled currently # self.formWidget = ConfigEditor(config) # #scrollarea # scrollarea = QtGui.QScrollArea() @@ -1154,8 +1369,10 @@ def __init__(self, tablemodel, config=None, atom_data=None, parent=None): # self.viewMdv.setEnabled(True) # model = run_tardis(config, atom_data) # self.show_model(model) - raise TemporarilyUnavaliable("The active mode is under" - "development. Please use the passive mode for now.") + raise TemporarilyUnavaliable( + "The active mode is under" + "development. Please use the passive mode for now." + ) self.setCentralWidget(self.stackedWidget) @@ -1172,8 +1389,8 @@ def show_model(self, model=None): self.mdv.change_model(model) if model.converged: self.successLabel.setText('converged') - if self.mode == 'active': - self.modeLabel.setText('Active Mode') + if self.mode == "active": + self.modeLabel.setText("Active Mode") self.mdv.fill_output_label() self.mdv.tableview.setModel(self.mdv.tablemodel) @@ -1181,7 +1398,6 @@ def show_model(self, model=None): self.mdv.plot_spectrum() self.showMaximized() - def switch_to_mdv(self): """Switch the cental stacked widget to show the modelviewer.""" self.stackedWidget.setCurrentIndex(0) @@ -1192,6 +1408,7 @@ def switch_to_form(self): self.stackedWidget.setCurrentIndex(1) self.viewMdv.setChecked(False) + class TemporarilyUnavaliable(Exception): """Exception raised when creation of active mode of tardis is attempted.""" diff --git a/tardis/io/__init__.py b/tardis/io/__init__.py index 8034d6d0cc0..ca4d671e64b 100644 --- a/tardis/io/__init__.py +++ b/tardis/io/__init__.py @@ -1,4 +1,8 @@ -#readin model_data -from tardis.io.model_reader import read_simple_ascii_density, read_simple_ascii_abundances, read_density_file +# readin model_data +from tardis.io.model_reader import ( + read_simple_ascii_density, + read_simple_ascii_abundances, + read_density_file, +) from tardis.io.config_internal import get_internal_configuration, get_data_dir diff --git a/tardis/io/atom_data/__init__.py b/tardis/io/atom_data/__init__.py index 10a3e57fd9b..5fd5f11b1ce 100644 --- a/tardis/io/atom_data/__init__.py +++ b/tardis/io/atom_data/__init__.py @@ -1,2 +1,2 @@ from tardis.io.atom_data.base import AtomData -from tardis.io.atom_data.atom_web_download import download_atom_data \ No newline at end of file +from tardis.io.atom_data.atom_web_download import download_atom_data diff --git a/tardis/io/atom_data/atom_web_download.py b/tardis/io/atom_data/atom_web_download.py index 3abd24d884a..fdc60e55328 100644 --- a/tardis/io/atom_data/atom_web_download.py +++ b/tardis/io/atom_data/atom_web_download.py @@ -7,6 +7,7 @@ logger = logging.getLogger(__name__) + def get_atomic_repo_config(): """ Get the repo configuration dictionary for the atomic data @@ -17,7 +18,7 @@ def get_atomic_repo_config(): """ - atomic_repo_fname = get_internal_data_path('atomic_data_repo.yml') + atomic_repo_fname = get_internal_data_path("atomic_data_repo.yml") return yaml.load(open(atomic_repo_fname), Loader=yaml.CLoader) @@ -38,11 +39,15 @@ def download_atom_data(atomic_data_name=None): atomic_repo = get_atomic_repo_config() if atomic_data_name is None: - atomic_data_name = atomic_repo['default'] + atomic_data_name = atomic_repo["default"] if atomic_data_name not in atomic_repo: - raise ValueError('Atomic Data name {0} not known'.format(atomic_data_name)) - dst_dir = os.path.join(get_data_dir(), '{0}.h5'.format(atomic_data_name)) - src_url = atomic_repo[atomic_data_name]['url'] - logger.info('Downloading atomic data from {0} to {1}'.format(src_url, dst_dir)) - download_from_url(src_url, dst_dir) \ No newline at end of file + raise ValueError( + "Atomic Data name {0} not known".format(atomic_data_name) + ) + dst_dir = os.path.join(get_data_dir(), "{0}.h5".format(atomic_data_name)) + src_url = atomic_repo[atomic_data_name]["url"] + logger.info( + "Downloading atomic data from {0} to {1}".format(src_url, dst_dir) + ) + download_from_url(src_url, dst_dir) diff --git a/tardis/io/atom_data/base.py b/tardis/io/atom_data/base.py index 8202ec8fcb4..26841c1175e 100644 --- a/tardis/io/atom_data/base.py +++ b/tardis/io/atom_data/base.py @@ -12,6 +12,7 @@ from astropy.units import Quantity from tardis.io.atom_data.util import resolve_atom_data_fname + class AtomDataNotPreparedError(Exception): pass @@ -117,23 +118,25 @@ class AtomData(object): """ hdf_names = [ - "atom_data", - "ionization_data", - "levels", - "lines", - "macro_atom_data", - "macro_atom_references", - "zeta_data", - "collision_data", - "collision_data_temperatures", - "synpp_refs", - "photoionization_data" + "atom_data", + "ionization_data", + "levels", + "lines", + "macro_atom_data", + "macro_atom_references", + "zeta_data", + "collision_data", + "collision_data_temperatures", + "synpp_refs", + "photoionization_data", ] # List of tuples of the related dataframes. # Either all or none of the related dataframes must be given - related_groups = [("macro_atom_data_all", "macro_atom_references_all"), - ("collision_data", "collision_data_temperatures")] + related_groups = [ + ("macro_atom_data_all", "macro_atom_references_all"), + ("collision_data", "collision_data_temperatures"), + ] @classmethod def from_hdf(cls, fname=None): @@ -153,7 +156,7 @@ def from_hdf(cls, fname=None): fname = resolve_atom_data_fname(fname) - with pd.HDFStore(fname, 'r') as store: + with pd.HDFStore(fname, "r") as store: for name in cls.hdf_names: try: dataframes[name] = store[name] @@ -163,17 +166,17 @@ def from_hdf(cls, fname=None): atom_data = cls(**dataframes) try: - atom_data.uuid1 = store.root._v_attrs['uuid1'].decode('ascii') + atom_data.uuid1 = store.root._v_attrs["uuid1"].decode("ascii") except KeyError: atom_data.uuid1 = None try: - atom_data.md5 = store.root._v_attrs['md5'].decode('ascii') + atom_data.md5 = store.root._v_attrs["md5"].decode("ascii") except KeyError: atom_data.md5 = None try: - atom_data.version = store.root._v_attrs['database_version'] + atom_data.version = store.root._v_attrs["database_version"] except KeyError: atom_data.version = None @@ -181,18 +184,32 @@ def from_hdf(cls, fname=None): logger.info( "Read Atom Data with UUID={0} and MD5={1}.".format( - atom_data.uuid1, atom_data.md5)) + atom_data.uuid1, atom_data.md5 + ) + ) if nonavailable: - logger.info("Non provided atomic data: {0}".format( - ", ".join(nonavailable))) + logger.info( + "Non provided atomic data: {0}".format( + ", ".join(nonavailable) + ) + ) return atom_data - def __init__(self, atom_data, ionization_data, levels=None, lines=None, - macro_atom_data=None, macro_atom_references=None, - zeta_data=None, collision_data=None, - collision_data_temperatures=None, synpp_refs=None, - photoionization_data=None): + def __init__( + self, + atom_data, + ionization_data, + levels=None, + lines=None, + macro_atom_data=None, + macro_atom_references=None, + zeta_data=None, + collision_data=None, + collision_data_temperatures=None, + synpp_refs=None, + photoionization_data=None, + ): self.prepared = False @@ -205,7 +222,8 @@ def __init__(self, atom_data, ionization_data, levels=None, lines=None, # the value of constants.u is used in all cases) if u.u.cgs == const.u.cgs: atom_data.loc[:, "mass"] = Quantity( - atom_data["mass"].values, "u").cgs + atom_data["mass"].values, "u" + ).cgs else: atom_data.loc[:, "mass"] = atom_data["mass"].values * const.u.cgs @@ -214,11 +232,12 @@ def __init__(self, atom_data, ionization_data, levels=None, lines=None, ionization_data[:] = Quantity(ionization_data[:], "eV").cgs # Convert energy to CGS - levels.loc[:, "energy"] = Quantity(levels["energy"].values, 'eV').cgs + levels.loc[:, "energy"] = Quantity(levels["energy"].values, "eV").cgs # Create a new columns with wavelengths in the CGS units - lines.loc[:, 'wavelength_cm'] = Quantity( - lines['wavelength'], 'angstrom').cgs + lines.loc[:, "wavelength_cm"] = Quantity( + lines["wavelength"], "angstrom" + ).cgs # SET ATTRIBUTES @@ -243,31 +262,33 @@ def __init__(self, atom_data, ionization_data, levels=None, lines=None, self._check_related() self.symbol2atomic_number = OrderedDict( - zip(self.atom_data['symbol'].values, self.atom_data.index)) + zip(self.atom_data["symbol"].values, self.atom_data.index) + ) self.atomic_number2symbol = OrderedDict( - zip(self.atom_data.index, self.atom_data['symbol'])) + zip(self.atom_data.index, self.atom_data["symbol"]) + ) def _check_related(self): """ Check that either all or none of the related dataframes are given. """ for group in self.related_groups: - check_list = [ - name for name in group if getattr(self, name) is None] + check_list = [name for name in group if getattr(self, name) is None] if len(check_list) != 0 and len(check_list) != len(group): raise AtomDataMissingError( "The following dataframes from the related group [{0}] " "were not given: {1}".format( - ", ".join(group), - ", ".join(check_list) - ) + ", ".join(group), ", ".join(check_list) ) + ) def prepare_atom_data( - self, selected_atomic_numbers, - line_interaction_type='scatter', - nlte_species=[]): + self, + selected_atomic_numbers, + line_interaction_type="scatter", + nlte_species=[], + ): """ Prepares the atom data to set the lines, levels and if requested macro atom data. This function mainly cuts the `levels` and `lines` by @@ -295,105 +316,155 @@ def prepare_atom_data( self.nlte_species = nlte_species self.levels = self.levels[ - self.levels.index.isin( - self.selected_atomic_numbers, - level='atomic_number')] + self.levels.index.isin( + self.selected_atomic_numbers, level="atomic_number" + ) + ] self.levels_index = pd.Series( - np.arange(len(self.levels), dtype=int), - index=self.levels.index) + np.arange(len(self.levels), dtype=int), index=self.levels.index + ) # cutting levels_lines self.lines = self.lines[ - self.lines.index.isin( - self.selected_atomic_numbers, - level='atomic_number')] + self.lines.index.isin( + self.selected_atomic_numbers, level="atomic_number" + ) + ] - self.lines.sort_values(by='wavelength', inplace=True) + self.lines.sort_values(by="wavelength", inplace=True) self.lines_index = pd.Series( - np.arange(len(self.lines), dtype=int), - index=self.lines.set_index('line_id').index) + np.arange(len(self.lines), dtype=int), + index=self.lines.set_index("line_id").index, + ) - tmp_lines_lower2level_idx = self.lines.index.droplevel('level_number_upper') + tmp_lines_lower2level_idx = self.lines.index.droplevel( + "level_number_upper" + ) self.lines_lower2level_idx = ( - self.levels_index.loc[tmp_lines_lower2level_idx]. - astype(np.int64).values) + self.levels_index.loc[tmp_lines_lower2level_idx] + .astype(np.int64) + .values + ) - tmp_lines_upper2level_idx = self.lines.index.droplevel('level_number_lower') + tmp_lines_upper2level_idx = self.lines.index.droplevel( + "level_number_lower" + ) self.lines_upper2level_idx = ( - self.levels_index.loc[tmp_lines_upper2level_idx]. - astype(np.int64).values) + self.levels_index.loc[tmp_lines_upper2level_idx] + .astype(np.int64) + .values + ) if ( - self.macro_atom_data_all is not None and - not line_interaction_type == 'scatter'): + self.macro_atom_data_all is not None + and not line_interaction_type == "scatter" + ): self.macro_atom_data = self.macro_atom_data_all.loc[ - self.macro_atom_data_all['atomic_number'].isin(self.selected_atomic_numbers) + self.macro_atom_data_all["atomic_number"].isin( + self.selected_atomic_numbers + ) ].copy() self.macro_atom_references = self.macro_atom_references_all[ self.macro_atom_references_all.index.isin( - self.selected_atomic_numbers, - level='atomic_number') + self.selected_atomic_numbers, level="atomic_number" + ) ].copy() - if line_interaction_type == 'downbranch': + if line_interaction_type == "downbranch": self.macro_atom_data = self.macro_atom_data.loc[ - self.macro_atom_data['transition_type'] == -1 + self.macro_atom_data["transition_type"] == -1 ] self.macro_atom_references = self.macro_atom_references.loc[ - self.macro_atom_references['count_down'] > 0 + self.macro_atom_references["count_down"] > 0 ] - self.macro_atom_references.loc[:, 'count_total'] = self.macro_atom_references['count_down'] - self.macro_atom_references.loc[:, 'block_references'] = np.hstack( - (0, np.cumsum(self.macro_atom_references['count_down'].values[:-1])) + self.macro_atom_references.loc[ + :, "count_total" + ] = self.macro_atom_references["count_down"] + self.macro_atom_references.loc[ + :, "block_references" + ] = np.hstack( + ( + 0, + np.cumsum( + self.macro_atom_references["count_down"].values[:-1] + ), + ) ) - elif line_interaction_type == 'macroatom': - self.macro_atom_references.loc[:, 'block_references'] = np.hstack( - (0, np.cumsum(self.macro_atom_references['count_total'].values[:-1])) + elif line_interaction_type == "macroatom": + self.macro_atom_references.loc[ + :, "block_references" + ] = np.hstack( + ( + 0, + np.cumsum( + self.macro_atom_references["count_total"].values[ + :-1 + ] + ), + ) ) - self.macro_atom_references.loc[:, "references_idx"] = np.arange(len(self.macro_atom_references)) + self.macro_atom_references.loc[:, "references_idx"] = np.arange( + len(self.macro_atom_references) + ) self.macro_atom_data.loc[:, "lines_idx"] = self.lines_index.loc[ - self.macro_atom_data['transition_line_id'] + self.macro_atom_data["transition_line_id"] ].values - self.lines_upper2macro_reference_idx = self.macro_atom_references.loc[ - tmp_lines_upper2level_idx, 'references_idx' - ].astype(np.int64).values + self.lines_upper2macro_reference_idx = ( + self.macro_atom_references.loc[ + tmp_lines_upper2level_idx, "references_idx" + ] + .astype(np.int64) + .values + ) - if line_interaction_type == 'macroatom': + if line_interaction_type == "macroatom": # Sets all - tmp_macro_destination_level_idx = pd.MultiIndex.from_arrays([ - self.macro_atom_data['atomic_number'], - self.macro_atom_data['ion_number'], - self.macro_atom_data['destination_level_number'] - ]) - - tmp_macro_source_level_idx = pd.MultiIndex.from_arrays([ - self.macro_atom_data['atomic_number'], - self.macro_atom_data['ion_number'], - self.macro_atom_data['source_level_number'] - ]) - - self.macro_atom_data.loc[:, 'destination_level_idx'] = self.macro_atom_references.loc[ - tmp_macro_destination_level_idx, "references_idx" - ].astype(np.int64).values - - self.macro_atom_data.loc[:, 'source_level_idx'] = self.macro_atom_references.loc[ - tmp_macro_source_level_idx, "references_idx" - ].astype(np.int64).values - - elif line_interaction_type == 'downbranch': + tmp_macro_destination_level_idx = pd.MultiIndex.from_arrays( + [ + self.macro_atom_data["atomic_number"], + self.macro_atom_data["ion_number"], + self.macro_atom_data["destination_level_number"], + ] + ) + + tmp_macro_source_level_idx = pd.MultiIndex.from_arrays( + [ + self.macro_atom_data["atomic_number"], + self.macro_atom_data["ion_number"], + self.macro_atom_data["source_level_number"], + ] + ) + + self.macro_atom_data.loc[:, "destination_level_idx"] = ( + self.macro_atom_references.loc[ + tmp_macro_destination_level_idx, "references_idx" + ] + .astype(np.int64) + .values + ) + + self.macro_atom_data.loc[:, "source_level_idx"] = ( + self.macro_atom_references.loc[ + tmp_macro_source_level_idx, "references_idx" + ] + .astype(np.int64) + .values + ) + + elif line_interaction_type == "downbranch": # Sets all the destination levels to -1 to indicate that they # are not used in downbranch calculations - self.macro_atom_data.loc[:, 'destination_level_idx'] = -1 + self.macro_atom_data.loc[:, "destination_level_idx"] = -1 self.nlte_data = NLTEData(self, nlte_species) @@ -402,15 +473,14 @@ def _check_selected_atomic_numbers(self): available_atomic_numbers = np.unique( self.ionization_data.index.get_level_values(0) ) - atomic_number_check = np.isin(selected_atomic_numbers, - available_atomic_numbers) + atomic_number_check = np.isin( + selected_atomic_numbers, available_atomic_numbers + ) if not all(atomic_number_check): missing_atom_mask = np.logical_not(atomic_number_check) missing_atomic_numbers = selected_atomic_numbers[missing_atom_mask] - missing_numbers_str = ','.join( - missing_atomic_numbers.astype('str') - ) + missing_numbers_str = ",".join(missing_atomic_numbers.astype("str")) msg = "For atomic numbers {} there is no atomic data.".format( missing_numbers_str ) @@ -418,8 +488,11 @@ def _check_selected_atomic_numbers(self): def __repr__(self): return "".format( - self.uuid1, self.md5, self.lines.line_id.count(), - self.levels.energy.count()) + self.uuid1, + self.md5, + self.lines.line_id.count(), + self.levels.energy.count(), + ) class NLTEData(object): @@ -429,7 +502,7 @@ def __init__(self, atom_data, nlte_species): self.nlte_species = nlte_species if nlte_species: - logger.info('Preparing the NLTE data') + logger.info("Preparing the NLTE data") self._init_indices() if atom_data.collision_data is not None: self._create_collision_coefficient_matrix() @@ -444,12 +517,16 @@ def _init_indices(self): for species in self.nlte_species: lines_idx = np.where( - (self.lines.atomic_number == species[0]) & - (self.lines.ion_number == species[1]) - ) + (self.lines.atomic_number == species[0]) + & (self.lines.ion_number == species[1]) + ) self.lines_idx[species] = lines_idx - self.lines_level_number_lower[species] = self.lines.level_number_lower.values[lines_idx].astype(int) - self.lines_level_number_upper[species] = self.lines.level_number_upper.values[lines_idx].astype(int) + self.lines_level_number_lower[ + species + ] = self.lines.level_number_lower.values[lines_idx].astype(int) + self.lines_level_number_upper[ + species + ] = self.lines.level_number_upper.values[lines_idx].astype(int) self.A_uls[species] = self.atom_data.lines.A_ul.values[lines_idx] self.B_uls[species] = self.atom_data.lines.B_ul.values[lines_idx] @@ -459,55 +536,69 @@ def _create_collision_coefficient_matrix(self): self.C_ul_interpolator = {} self.delta_E_matrices = {} self.g_ratio_matrices = {} - collision_group = self.atom_data.collision_data.groupby(level=['atomic_number', 'ion_number']) + collision_group = self.atom_data.collision_data.groupby( + level=["atomic_number", "ion_number"] + ) for species in self.nlte_species: no_of_levels = self.atom_data.levels.loc[species].energy.count() C_ul_matrix = np.zeros( - ( - no_of_levels, - no_of_levels, - len(self.atom_data.collision_data_temperatures)) - ) + ( + no_of_levels, + no_of_levels, + len(self.atom_data.collision_data_temperatures), + ) + ) delta_E_matrix = np.zeros((no_of_levels, no_of_levels)) g_ratio_matrix = np.zeros((no_of_levels, no_of_levels)) for ( + ( atomic_number, ion_number, level_number_lower, - level_number_upper), line in ( - collision_group.get_group(species).iterrows()): - # line.columns : delta_e, g_ratio, temperatures ... - C_ul_matrix[level_number_lower, level_number_upper, :] = line.values[2:] - delta_E_matrix[level_number_lower, level_number_upper] = line['delta_e'] - #TODO TARDISATOMIC fix change the g_ratio to be the otherway round - I flip them now here. - g_ratio_matrix[level_number_lower, level_number_upper] = 1/line['g_ratio'] + level_number_upper, + ), + line, + ) in collision_group.get_group(species).iterrows(): + # line.columns : delta_e, g_ratio, temperatures ... + C_ul_matrix[ + level_number_lower, level_number_upper, : + ] = line.values[2:] + delta_E_matrix[level_number_lower, level_number_upper] = line[ + "delta_e" + ] + # TODO TARDISATOMIC fix change the g_ratio to be the otherway round - I flip them now here. + g_ratio_matrix[level_number_lower, level_number_upper] = ( + 1 / line["g_ratio"] + ) self.C_ul_interpolator[species] = interpolate.interp1d( - self.atom_data.collision_data_temperatures, - C_ul_matrix) + self.atom_data.collision_data_temperatures, C_ul_matrix + ) self.delta_E_matrices[species] = delta_E_matrix self.g_ratio_matrices[species] = g_ratio_matrix def get_collision_matrix(self, species, t_electrons): - ''' + """ Creat collision matrix by interpolating the C_ul values for the desired temperatures. - ''' + """ c_ul_matrix = self.C_ul_interpolator[species](t_electrons) no_of_levels = c_ul_matrix.shape[0] c_ul_matrix[np.isnan(c_ul_matrix)] = 0.0 - #TODO in tardisatomic the g_ratio is the other way round - here I'll flip it in prepare_collision matrix + # TODO in tardisatomic the g_ratio is the other way round - here I'll flip it in prepare_collision matrix c_lu_matrix = ( - c_ul_matrix * np.exp( - -self.delta_E_matrices[species].reshape( - (no_of_levels, no_of_levels, 1)) / - t_electrons.reshape( - (1, 1, t_electrons.shape[0])) - ) * - self.g_ratio_matrices[species].reshape( - (no_of_levels, no_of_levels, 1)) + c_ul_matrix + * np.exp( + -self.delta_E_matrices[species].reshape( + (no_of_levels, no_of_levels, 1) ) + / t_electrons.reshape((1, 1, t_electrons.shape[0])) + ) + * self.g_ratio_matrices[species].reshape( + (no_of_levels, no_of_levels, 1) + ) + ) return c_ul_matrix + c_lu_matrix.transpose(1, 0, 2) diff --git a/tardis/io/atom_data/util.py b/tardis/io/atom_data/util.py index d5f1fde1939..599e774d69f 100644 --- a/tardis/io/atom_data/util.py +++ b/tardis/io/atom_data/util.py @@ -2,10 +2,14 @@ import logging from tardis.io.config_internal import get_data_dir -from tardis.io.atom_data.atom_web_download import get_atomic_repo_config, download_atom_data +from tardis.io.atom_data.atom_web_download import ( + get_atomic_repo_config, + download_atom_data, +) logger = logging.getLogger(__name__) + def resolve_atom_data_fname(fname): """ Check where if atom data HDF file is available on disk, can be downloaded or does not exist @@ -26,17 +30,23 @@ def resolve_atom_data_fname(fname): fpath = os.path.join(os.path.join(get_data_dir(), fname)) if os.path.exists(fpath): - logger.info('Atom Data {0} not found in local path. Exists in TARDIS Data repo {1}'.format(fname, fpath)) + logger.info( + "Atom Data {0} not found in local path. Exists in TARDIS Data repo {1}".format( + fname, fpath + ) + ) return fpath - atom_data_name = fname.replace('.h5', '') + atom_data_name = fname.replace(".h5", "") atom_repo_config = get_atomic_repo_config() if atom_data_name in atom_repo_config: - raise IOError('Atom Data {0} not found in path or in TARDIS data repo - it is available as download:\n' - 'from tardis.io.atom_data import download_atom_data\n' - 'download_atom_data(\'{1}\')'.format(fname, atom_data_name)) - - raise IOError('Atom Data {0} is not found in current path or in TARDIS data repo. {1} is also not a standard known' - 'TARDIS atom dataset.'.format(fname, atom_data_name)) - - + raise IOError( + "Atom Data {0} not found in path or in TARDIS data repo - it is available as download:\n" + "from tardis.io.atom_data import download_atom_data\n" + "download_atom_data('{1}')".format(fname, atom_data_name) + ) + + raise IOError( + f"Atom Data {fname} is not found in current path or in TARDIS data repo. {atom_data_name} " + "is also not a standard known TARDIS atom dataset." + ) diff --git a/tardis/io/config_internal.py b/tardis/io/config_internal.py index 0963696fd4d..30ff8a1567e 100644 --- a/tardis/io/config_internal.py +++ b/tardis/io/config_internal.py @@ -5,40 +5,57 @@ from astropy.config import get_config_dir TARDIS_PATH = TARDIS_PATH[0] -DEFAULT_CONFIG_PATH = os.path.join(TARDIS_PATH, 'data', 'default_tardis_internal_config.yml') -DEFAULT_DATA_DIR = os.path.join(os.path.expanduser('~'), 'Downloads', 'tardis-data') +DEFAULT_CONFIG_PATH = os.path.join( + TARDIS_PATH, "data", "default_tardis_internal_config.yml" +) +DEFAULT_DATA_DIR = os.path.join( + os.path.expanduser("~"), "Downloads", "tardis-data" +) logger = logging.getLogger(__name__) + def get_internal_configuration(): - config_fpath = os.path.join(get_config_dir(), 'tardis_internal_config.yml') + config_fpath = os.path.join(get_config_dir(), "tardis_internal_config.yml") if not os.path.exists(config_fpath): - logger.warning("Configuration File {0} does not exist - creating new one from default".format(config_fpath)) + logger.warning( + "Configuration File {0} does not exist - creating new one from default".format( + config_fpath + ) + ) shutil.copy(DEFAULT_CONFIG_PATH, config_fpath) with open(config_fpath) as config_fh: return yaml.load(config_fh, Loader=yaml.CLoader) - def get_data_dir(): config = get_internal_configuration() - data_dir = config.get('data_dir', None) + data_dir = config.get("data_dir", None) if data_dir is None: - config_fpath = os.path.join(get_config_dir(), 'tardis_internal_config.yml') - logging.critical('\n{line_stars}\n\nTARDIS will download different kinds of data (e.g. atomic) to its data directory {default_data_dir}\n\n' - 'TARDIS DATA DIRECTORY not specified in {config_file}:\n\n' - 'ASSUMING DEFAULT DATA DIRECTORY {default_data_dir}\n ' - 'YOU CAN CHANGE THIS AT ANY TIME IN {config_file} \n\n' - '{line_stars} \n\n'.format(line_stars='*'*80, config_file=config_fpath, - default_data_dir=DEFAULT_DATA_DIR)) + config_fpath = os.path.join( + get_config_dir(), "tardis_internal_config.yml" + ) + logging.critical( + "\n{line_stars}\n\nTARDIS will download different kinds of data (e.g. atomic) to its data directory {default_data_dir}\n\n" + "TARDIS DATA DIRECTORY not specified in {config_file}:\n\n" + "ASSUMING DEFAULT DATA DIRECTORY {default_data_dir}\n " + "YOU CAN CHANGE THIS AT ANY TIME IN {config_file} \n\n" + "{line_stars} \n\n".format( + line_stars="*" * 80, + config_file=config_fpath, + default_data_dir=DEFAULT_DATA_DIR, + ) + ) if not os.path.exists(DEFAULT_DATA_DIR): os.makedirs(DEFAULT_DATA_DIR) - config['data_dir'] = DEFAULT_DATA_DIR - yaml.dump(config, open(config_fpath, 'w'), default_flow_style=False) + config["data_dir"] = DEFAULT_DATA_DIR + yaml.dump(config, open(config_fpath, "w"), default_flow_style=False) data_dir = DEFAULT_DATA_DIR if not os.path.exists(data_dir): - raise IOError('Data directory specified in {0} does not exist'.format(data_dir)) + raise IOError( + "Data directory specified in {0} does not exist".format(data_dir) + ) return data_dir diff --git a/tardis/io/config_reader.py b/tardis/io/config_reader.py index 5927194d93d..6c6badb77c0 100644 --- a/tardis/io/config_reader.py +++ b/tardis/io/config_reader.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -data_dir = os.path.abspath(os.path.join(tardis.__path__[0], 'data')) +data_dir = os.path.abspath(os.path.join(tardis.__path__[0], "data")) class ConfigurationError(ValueError): @@ -31,17 +31,20 @@ def parse_convergence_section(convergence_section_dict): dictionary """ - convergence_parameters = ['damping_constant', 'threshold'] + convergence_parameters = ["damping_constant", "threshold"] - for convergence_variable in ['t_inner', 't_rad', 'w']: + for convergence_variable in ["t_inner", "t_rad", "w"]: if convergence_variable not in convergence_section_dict: convergence_section_dict[convergence_variable] = {} - convergence_variable_section = convergence_section_dict[convergence_variable] + convergence_variable_section = convergence_section_dict[ + convergence_variable + ] for param in convergence_parameters: if convergence_variable_section.get(param, None) is None: if param in convergence_section_dict: - convergence_section_dict[convergence_variable][param] = ( - convergence_section_dict[param]) + convergence_section_dict[convergence_variable][ + param + ] = convergence_section_dict[param] return convergence_section_dict @@ -79,10 +82,9 @@ def from_yaml(cls, fname): try: yaml_dict = yaml_load_file(fname) except IOError as e: - logger.critical('No config file named: %s', fname) + logger.critical(f"No config file named: {fname}") raise e - return cls.from_config_dict(yaml_dict) @classmethod @@ -115,30 +117,32 @@ def __init__(self, value=None): for key in value: self.__setitem__(key, value[key]) else: - raise (TypeError, 'expected dict') + raise (TypeError, "expected dict") - if hasattr(self, 'csvy_model') and hasattr(self, 'model'): - raise ValueError('Cannot specify both model and csvy_model in main config file.') - if hasattr(self, 'csvy_model'): + if hasattr(self, "csvy_model") and hasattr(self, "model"): + raise ValueError( + "Cannot specify both model and csvy_model in main config file." + ) + if hasattr(self, "csvy_model"): model = dict() - csvy_model_path = os.path.join(self.config_dirname,self.csvy_model) + csvy_model_path = os.path.join(self.config_dirname, self.csvy_model) csvy_yml = load_yaml_from_csvy(csvy_model_path) - if 'v_inner_boundary' in csvy_yml: - model['v_inner_boundary'] = csvy_yml['v_inner_boundary'] - if 'v_outer_boundary' in csvy_yml: - model['v_outer_boundary'] = csvy_yml['v_outer_boundary'] + if "v_inner_boundary" in csvy_yml: + model["v_inner_boundary"] = csvy_yml["v_inner_boundary"] + if "v_outer_boundary" in csvy_yml: + model["v_outer_boundary"] = csvy_yml["v_outer_boundary"] - self.__setitem__('model',model) + self.__setitem__("model", model) for key in self.model: self.model.__setitem__(key, self.model[key]) - def __setitem__(self, key, value): - if isinstance(value, dict) and not isinstance(value, - ConfigurationNameSpace): + if isinstance(value, dict) and not isinstance( + value, ConfigurationNameSpace + ): value = ConfigurationNameSpace(value) - if key in self and hasattr(self[key], 'unit'): + if key in self and hasattr(self[key], "unit"): value = u.Quantity(value, self[key].unit) dict.__setitem__(self, key, value) @@ -167,23 +171,26 @@ def get_config_item(self, config_item_string): config_item_string: ~str string of shape 'section1.sectionb.param1' """ - config_item_path = config_item_string.split('.') + config_item_path = config_item_string.split(".") if len(config_item_path) == 1: config_item = config_item_path[0] - if config_item.startswith('item'): + if config_item.startswith("item"): return self[config_item_path[0]] else: return self[config_item] - elif len(config_item_path) == 2 and\ - config_item_path[1].startswith('item'): + elif len(config_item_path) == 2 and config_item_path[1].startswith( + "item" + ): return self[config_item_path[0]][ - int(config_item_path[1].replace('item', ''))] + int(config_item_path[1].replace("item", "")) + ] else: return self[config_item_path[0]].get_config_item( - '.'.join(config_item_path[1:])) + ".".join(config_item_path[1:]) + ) def set_config_item(self, config_item_string, value): """ @@ -199,24 +206,28 @@ def set_config_item(self, config_item_string, value): value to set the parameter with it """ - config_item_path = config_item_string.split('.') + config_item_path = config_item_string.split(".") if len(config_item_path) == 1: self[config_item_path[0]] = value - elif len(config_item_path) == 2 and \ - config_item_path[1].startswith('item'): + elif len(config_item_path) == 2 and config_item_path[1].startswith( + "item" + ): current_value = self[config_item_path[0]][ - int(config_item_path[1].replace('item', ''))] - if hasattr(current_value, 'unit'): + int(config_item_path[1].replace("item", "")) + ] + if hasattr(current_value, "unit"): self[config_item_path[0]][ - int(config_item_path[1].replace('item', ''))] =\ - u.Quantity(value, current_value.unit) + int(config_item_path[1].replace("item", "")) + ] = u.Quantity(value, current_value.unit) else: self[config_item_path[0]][ - int(config_item_path[1].replace('item', ''))] = value + int(config_item_path[1].replace("item", "")) + ] = value else: self[config_item_path[0]].set_config_item( - '.'.join(config_item_path[1:]), value) + ".".join(config_item_path[1:]), value + ) def deepcopy(self): return ConfigurationNameSpace(copy.deepcopy(dict(self))) @@ -230,24 +241,25 @@ class Configuration(ConfigurationNameSpace): @classmethod def from_yaml(cls, fname, *args, **kwargs): try: - yaml_dict = yaml_load_file(fname, - loader=kwargs.pop('loader', YAMLLoader)) + yaml_dict = yaml_load_file( + fname, loader=kwargs.pop("loader", YAMLLoader) + ) except IOError as e: - logger.critical('No config file named: %s', fname) + logger.critical(f"No config file named: {fname}") raise e - tardis_config_version = yaml_dict.get('tardis_config_version', None) - if tardis_config_version != 'v1.0': + tardis_config_version = yaml_dict.get("tardis_config_version", None) + if tardis_config_version != "v1.0": raise ConfigurationError( - 'Currently only tardis_config_version v1.0 supported') + "Currently only tardis_config_version v1.0 supported" + ) - kwargs['config_dirname'] = os.path.dirname(fname) + kwargs["config_dirname"] = os.path.dirname(fname) - return cls.from_config_dict( - yaml_dict, *args, **kwargs) + return cls.from_config_dict(yaml_dict, *args, **kwargs) @classmethod - def from_config_dict(cls, config_dict, validate=True, config_dirname=''): + def from_config_dict(cls, config_dict, validate=True, config_dirname=""): """ Validating and subsequently parsing a config file. @@ -272,24 +284,28 @@ def from_config_dict(cls, config_dict, validate=True, config_dirname=''): else: validated_config_dict = config_dict - validated_config_dict['config_dirname'] = config_dirname + validated_config_dict["config_dirname"] = config_dirname - montecarlo_section = validated_config_dict['montecarlo'] - if montecarlo_section['convergence_strategy']['type'] == "damped": - montecarlo_section['convergence_strategy'] = ( - parse_convergence_section( - montecarlo_section['convergence_strategy'])) - elif montecarlo_section['convergence_strategy']['type'] == "custom": + montecarlo_section = validated_config_dict["montecarlo"] + if montecarlo_section["convergence_strategy"]["type"] == "damped": + montecarlo_section[ + "convergence_strategy" + ] = parse_convergence_section( + montecarlo_section["convergence_strategy"] + ) + elif montecarlo_section["convergence_strategy"]["type"] == "custom": raise NotImplementedError( 'convergence_strategy is set to "custom"; ' - 'you need to implement your specific convergence treatment') + "you need to implement your specific convergence treatment" + ) else: - raise ValueError('convergence_strategy is not "damped" ' - 'or "custom"') + raise ValueError( + 'convergence_strategy is not "damped" ' 'or "custom"' + ) - enable_full_relativity = montecarlo_section['enable_full_relativity'] + enable_full_relativity = montecarlo_section["enable_full_relativity"] spectrum_integrated = ( - validated_config_dict['spectrum']['method'] == 'integrated' + validated_config_dict["spectrum"]["method"] == "integrated" ) if enable_full_relativity and spectrum_integrated: raise NotImplementedError( @@ -301,6 +317,5 @@ def from_config_dict(cls, config_dict, validate=True, config_dirname=''): return cls(validated_config_dict) - def __init__(self, config_dict): super(Configuration, self).__init__(config_dict) diff --git a/tardis/io/config_validator.py b/tardis/io/config_validator.py index fa1584e1755..8b23aefc8c2 100644 --- a/tardis/io/config_validator.py +++ b/tardis/io/config_validator.py @@ -6,8 +6,8 @@ from tardis.io.util import YAMLLoader base_dir = os.path.abspath(os.path.dirname(__file__)) -schema_dir = os.path.join(base_dir, 'schemas') -config_schema_file = os.path.join(schema_dir, 'base.yml') +schema_dir = os.path.join(base_dir, "schemas") +config_schema_file = os.path.join(schema_dir, "base.yml") def extend_with_default(validator_class): @@ -26,56 +26,55 @@ def extend_with_default(validator_class): The extended `jsonschema.IValidator` """ - validate_properties = validator_class.VALIDATORS['properties'] + validate_properties = validator_class.VALIDATORS["properties"] def set_defaults(validator, properties, instance, schema): # This validator also checks if default values # are of the correct type and properly sets default # values on schemas that use the oneOf keyword if not list( - validate_properties(validator, properties, instance, schema)): + validate_properties(validator, properties, instance, schema) + ): for property, subschema in properties.items(): - if 'default' in subschema: - instance.setdefault(property, subschema['default']) + if "default" in subschema: + instance.setdefault(property, subschema["default"]) for error in validate_properties( - validator, properties, instance, schema, + validator, properties, instance, schema, ): yield error - return validators.extend( - validator_class, {'properties': set_defaults}, - ) + return validators.extend(validator_class, {"properties": set_defaults},) DefaultDraft4Validator = extend_with_default(Draft4Validator) def _yaml_handler(path): - if not path.startswith('file://'): - raise Exception('Not a file URL: {}'.format(path)) - with open(path[len('file://'):]) as f: + if not path.startswith("file://"): + raise Exception("Not a file URL: {}".format(path)) + with open(path[len("file://") :]) as f: return yaml.load(f, Loader=YAMLLoader) -def validate_dict(config_dict, schemapath=config_schema_file, - validator=DefaultDraft4Validator): +def validate_dict( + config_dict, schemapath=config_schema_file, validator=DefaultDraft4Validator +): with open(schemapath) as f: schema = yaml.load(f, Loader=YAMLLoader) schemaurl = "file://" + schemapath - handlers = {'file': _yaml_handler} + handlers = {"file": _yaml_handler} resolver = RefResolver(schemaurl, schema, handlers=handlers) validated_dict = deepcopy(config_dict) - validator(schema=schema, - types={'quantity': (Quantity,)}, - resolver=resolver - ).validate(validated_dict) + validator( + schema=schema, types={"quantity": (Quantity,)}, resolver=resolver + ).validate(validated_dict) return validated_dict -def validate_yaml(configpath, schemapath=config_schema_file, - validator=DefaultDraft4Validator): +def validate_yaml( + configpath, schemapath=config_schema_file, validator=DefaultDraft4Validator +): with open(configpath) as f: config = yaml.load(f, Loader=YAMLLoader) return validate_dict(config, schemapath, validator) - diff --git a/tardis/io/decay.py b/tardis/io/decay.py index 570fd7e96e2..59f85b62d75 100644 --- a/tardis/io/decay.py +++ b/tardis/io/decay.py @@ -2,14 +2,15 @@ from pyne import nucname, material from astropy import units as u + class IsotopeAbundances(pd.DataFrame): _metadata = ["time_0"] def __init__(self, *args, **kwargs): - if 'time_0' in kwargs: - time_0 = kwargs['time_0'] - kwargs.pop('time_0') + if "time_0" in kwargs: + time_0 = kwargs["time_0"] + kwargs.pop("time_0") else: time_0 = 0 * u.d super(IsotopeAbundances, self).__init__(*args, **kwargs) @@ -22,8 +23,7 @@ def _constructor(self): def _update_material(self): self.comp_dicts = [dict() for i in range(len(self.columns))] for (atomic_number, mass_number), abundances in self.iterrows(): - nuclear_symbol = '%s%d'.format(nucname.name(atomic_number), - mass_number) + nuclear_symbol = f"{nucname.name(atomic_number)}{mass_number}" for i in range(len(self.columns)): self.comp_dicts[i][nuclear_symbol] = abundances[i] @@ -31,14 +31,17 @@ def _update_material(self): def from_materials(cls, materials): multi_index_tuples = set([]) for material in materials: - multi_index_tuples.update([cls.id_to_tuple(key) - for key in material.keys()]) + multi_index_tuples.update( + [cls.id_to_tuple(key) for key in material.keys()] + ) index = pd.MultiIndex.from_tuples( - multi_index_tuples, names=['atomic_number', 'mass_number']) - + multi_index_tuples, names=["atomic_number", "mass_number"] + ) - abundances = pd.DataFrame(data=0.0, index=index, columns=range(len(materials))) + abundances = pd.DataFrame( + data=0.0, index=index, columns=range(len(materials)) + ) for i, material in enumerate(materials): for key, value in material.items(): @@ -46,14 +49,10 @@ def from_materials(cls, materials): return cls(abundances) - - - @staticmethod def id_to_tuple(atomic_id): return nucname.znum(atomic_id), nucname.anum(atomic_id) - def to_materials(self): """ Convert DataFrame to a list of materials interpreting the MultiIndex as @@ -68,14 +67,13 @@ def to_materials(self): comp_dicts = [dict() for i in range(len(self.columns))] for (atomic_number, mass_number), abundances in self.iterrows(): - nuclear_symbol = '{0:s}{1:d}'.format(nucname.name(atomic_number), - mass_number) + nuclear_symbol = "{0:s}{1:d}".format( + nucname.name(atomic_number), mass_number + ) for i in range(len(self.columns)): comp_dicts[i][nuclear_symbol] = abundances[i] return [material.Material(comp_dict) for comp_dict in comp_dicts] - - def decay(self, t): """ Decay the Model @@ -91,13 +89,15 @@ def decay(self, t): """ materials = self.to_materials() - t_second = u.Quantity(t, u.day).to(u.s).value - self.time_0.to(u.s).value + t_second = ( + u.Quantity(t, u.day).to(u.s).value - self.time_0.to(u.s).value + ) decayed_materials = [item.decay(t_second) for item in materials] for i in range(len(materials)): materials[i].update(decayed_materials[i]) df = IsotopeAbundances.from_materials(materials) df.sort_index(inplace=True) - return df + return df def as_atoms(self): """ @@ -107,7 +107,7 @@ def as_atoms(self): : merged isotope abundances """ - return self.groupby('atomic_number').sum() + return self.groupby("atomic_number").sum() def merge(self, other, normalize=True): """ @@ -124,7 +124,7 @@ def merge(self, other, normalize=True): """ isotope_abundance = self.as_atoms() isotope_abundance = isotope_abundance.fillna(0.0) - #Merge abundance and isotope dataframe + # Merge abundance and isotope dataframe modified_df = isotope_abundance.add(other, fill_value=0) if normalize: diff --git a/tardis/io/model_reader.py b/tardis/io/model_reader.py index 4812e1236fc..d8060e2de67 100644 --- a/tardis/io/model_reader.py +++ b/tardis/io/model_reader.py @@ -1,4 +1,4 @@ -#reading different model files +# reading different model files import warnings import numpy as np @@ -8,6 +8,7 @@ from pyne import nucname import logging + # Adding logging support logger = logging.getLogger(__name__) @@ -43,37 +44,62 @@ def read_density_file(filename, filetype): the array containing the densities """ - file_parsers = {'artis': read_artis_density, - 'simple_ascii': read_simple_ascii_density, - 'cmfgen_model': read_cmfgen_density} + file_parsers = { + "artis": read_artis_density, + "simple_ascii": read_simple_ascii_density, + "cmfgen_model": read_cmfgen_density, + } electron_densities = None temperature = None - if filetype == 'cmfgen_model': - (time_of_model, velocity, - unscaled_mean_densities, electron_densities, temperature) = read_cmfgen_density(filename) + if filetype == "cmfgen_model": + ( + time_of_model, + velocity, + unscaled_mean_densities, + electron_densities, + temperature, + ) = read_cmfgen_density(filename) else: - (time_of_model, velocity, - unscaled_mean_densities) = file_parsers[filetype](filename) + (time_of_model, velocity, unscaled_mean_densities) = file_parsers[ + filetype + ](filename) v_inner = velocity[:-1] v_outer = velocity[1:] invalid_volume_mask = (v_outer - v_inner) <= 0 if invalid_volume_mask.sum() > 0: - message = "\n".join(["cell {0:d}: v_inner {1:s}, v_outer " - "{2:s}".format(i, v_inner_i, v_outer_i) for i, - v_inner_i, v_outer_i in - zip(np.arange(len(v_outer))[invalid_volume_mask], - v_inner[invalid_volume_mask], - v_outer[invalid_volume_mask])]) - raise ConfigurationError("Invalid volume of following cell(s):\n" - "{:s}".format(message)) - - return time_of_model, velocity, unscaled_mean_densities, electron_densities, temperature - -def read_abundances_file(abundance_filename, abundance_filetype, - inner_boundary_index=None, outer_boundary_index=None): + message = "\n".join( + [ + "cell {0:d}: v_inner {1:s}, v_outer " + "{2:s}".format(i, v_inner_i, v_outer_i) + for i, v_inner_i, v_outer_i in zip( + np.arange(len(v_outer))[invalid_volume_mask], + v_inner[invalid_volume_mask], + v_outer[invalid_volume_mask], + ) + ] + ) + raise ConfigurationError( + "Invalid volume of following cell(s):\n" "{:s}".format(message) + ) + + return ( + time_of_model, + velocity, + unscaled_mean_densities, + electron_densities, + temperature, + ) + + +def read_abundances_file( + abundance_filename, + abundance_filetype, + inner_boundary_index=None, + outer_boundary_index=None, +): """ read different density file formats @@ -95,25 +121,29 @@ def read_abundances_file(abundance_filename, abundance_filetype, """ - file_parsers = {'simple_ascii': read_simple_ascii_abundances, - 'artis': read_simple_ascii_abundances, - 'cmfgen_model': read_cmfgen_composition, - 'custom_composition': read_csv_composition} + file_parsers = { + "simple_ascii": read_simple_ascii_abundances, + "artis": read_simple_ascii_abundances, + "cmfgen_model": read_cmfgen_composition, + "custom_composition": read_csv_composition, + } isotope_abundance = pd.DataFrame() if abundance_filetype in ["cmfgen_model", "custom_composition"]: index, abundances, isotope_abundance = file_parsers[abundance_filetype]( - abundance_filename) + abundance_filename + ) else: - index, abundances = file_parsers[abundance_filetype]( - abundance_filename) + index, abundances = file_parsers[abundance_filetype](abundance_filename) if outer_boundary_index is not None: outer_boundary_index_m1 = outer_boundary_index - 1 else: outer_boundary_index_m1 = None index = index[inner_boundary_index:outer_boundary_index] - abundances = abundances.loc[:, slice(inner_boundary_index, outer_boundary_index_m1)] + abundances = abundances.loc[ + :, slice(inner_boundary_index, outer_boundary_index_m1) + ] abundances.columns = np.arange(len(abundances.columns)) return index, abundances, isotope_abundance @@ -131,37 +161,45 @@ def read_uniform_abundances(abundances_section, no_of_shells): abundance: ~pandas.DataFrame isotope_abundance: ~pandas.DataFrame """ - abundance = pd.DataFrame(columns=np.arange(no_of_shells), - index=pd.Index(np.arange(1, 120), - name='atomic_number'), - dtype=np.float64) + abundance = pd.DataFrame( + columns=np.arange(no_of_shells), + index=pd.Index(np.arange(1, 120), name="atomic_number"), + dtype=np.float64, + ) isotope_index = pd.MultiIndex( - [[]] * 2, [[]] * 2, names=['atomic_number', 'mass_number']) - isotope_abundance = pd.DataFrame(columns=np.arange(no_of_shells), - index=isotope_index, - dtype=np.float64) + [[]] * 2, [[]] * 2, names=["atomic_number", "mass_number"] + ) + isotope_abundance = pd.DataFrame( + columns=np.arange(no_of_shells), index=isotope_index, dtype=np.float64 + ) for element_symbol_string in abundances_section: - if element_symbol_string == 'type': + if element_symbol_string == "type": continue try: if element_symbol_string in nucname.name_zz: z = nucname.name_zz[element_symbol_string] abundance.loc[z] = float( - abundances_section[element_symbol_string]) + abundances_section[element_symbol_string] + ) else: mass_no = nucname.anum(element_symbol_string) z = nucname.znum(element_symbol_string) isotope_abundance.loc[(z, mass_no), :] = float( - abundances_section[element_symbol_string]) + abundances_section[element_symbol_string] + ) except RuntimeError as err: raise RuntimeError( - "Abundances are not defined properly in config file : {}".format(err.args)) + "Abundances are not defined properly in config file : {}".format( + err.args + ) + ) return abundance, isotope_abundance + def read_simple_ascii_density(fname): """ Reading a density file of the following structure (example; lines starting with a hash will be ignored): @@ -192,14 +230,18 @@ def read_simple_ascii_density(fname): time_of_model_string = fh.readline().strip() time_of_model = parse_quantity(time_of_model_string) - data = recfromtxt(fname, skip_header=1, - names=('index', 'velocity', 'density'), - dtype=(int, float, float)) - velocity = (data['velocity'] * u.km / u.s).to('cm/s') - mean_density = (data['density'] * u.Unit('g/cm^3'))[1:] + data = recfromtxt( + fname, + skip_header=1, + names=("index", "velocity", "density"), + dtype=(int, float, float), + ) + velocity = (data["velocity"] * u.km / u.s).to("cm/s") + mean_density = (data["density"] * u.Unit("g/cm^3"))[1:] return time_of_model, velocity, mean_density + def read_artis_density(fname): """ Reading a density file of the following structure (example; lines starting with a hash will be ignored): @@ -231,18 +273,31 @@ def read_artis_density(fname): if i == 0: no_of_shells = np.int64(line.strip()) elif i == 1: - time_of_model = u.Quantity(float(line.strip()), 'day').to('s') + time_of_model = u.Quantity(float(line.strip()), "day").to("s") elif i == 2: break - artis_model_columns = ['index', 'velocities', 'mean_densities_0', 'ni56_fraction', 'co56_fraction', 'fe52_fraction', - 'cr48_fraction'] - artis_model = recfromtxt(fname, skip_header=2, usecols=(0, 1, 2, 4, 5, 6, 7), unpack=True, - dtype=[(item, np.float64) for item in artis_model_columns]) - - - velocity = u.Quantity(artis_model['velocities'], 'km/s').to('cm/s') - mean_density = u.Quantity(10 ** artis_model['mean_densities_0'], 'g/cm^3')[1:] + artis_model_columns = [ + "index", + "velocities", + "mean_densities_0", + "ni56_fraction", + "co56_fraction", + "fe52_fraction", + "cr48_fraction", + ] + artis_model = recfromtxt( + fname, + skip_header=2, + usecols=(0, 1, 2, 4, 5, 6, 7), + unpack=True, + dtype=[(item, np.float64) for item in artis_model_columns], + ) + + velocity = u.Quantity(artis_model["velocities"], "km/s").to("cm/s") + mean_density = u.Quantity(10 ** artis_model["mean_densities_0"], "g/cm^3")[ + 1: + ] return time_of_model, velocity, mean_density @@ -282,26 +337,35 @@ def read_cmfgen_density(fname): temperature: ~np.ndarray """ - warnings.warn("The current CMFGEN model parser is deprecated", - DeprecationWarning) + warnings.warn( + "The current CMFGEN model parser is deprecated", DeprecationWarning + ) - df = pd.read_csv(fname, comment='#', delimiter=r'\s+', skiprows=[0, 2]) + df = pd.read_csv(fname, comment="#", delimiter=r"\s+", skiprows=[0, 2]) with open(fname) as fh: for row_index, line in enumerate(fh): if row_index == 0: - time_of_model_string = line.strip().replace('t0:', '') + time_of_model_string = line.strip().replace("t0:", "") time_of_model = parse_quantity(time_of_model_string) elif row_index == 2: quantities = line.split() - velocity = u.Quantity(df['velocity'].values, quantities[1]).to('cm/s') - temperature = u.Quantity(df['temperature'].values, quantities[2])[1:] - mean_density = u.Quantity(df['densities'].values, quantities[3])[1:] + velocity = u.Quantity(df["velocity"].values, quantities[1]).to("cm/s") + temperature = u.Quantity(df["temperature"].values, quantities[2])[1:] + mean_density = u.Quantity(df["densities"].values, quantities[3])[1:] electron_densities = u.Quantity( - df['electron_densities'].values, quantities[4])[1:] + df["electron_densities"].values, quantities[4] + )[1:] + + return ( + time_of_model, + velocity, + mean_density, + electron_densities, + temperature, + ) - return time_of_model, velocity, mean_density, electron_densities, temperature def read_simple_ascii_abundances(fname): """ @@ -327,13 +391,15 @@ def read_simple_ascii_abundances(fname): """ data = np.loadtxt(fname) - index = data[1:,0].astype(int) - abundances = pd.DataFrame(data[1:,1:].transpose(), index=np.arange(1, data.shape[1])) + index = data[1:, 0].astype(int) + abundances = pd.DataFrame( + data[1:, 1:].transpose(), index=np.arange(1, data.shape[1]) + ) return index, abundances -def read_cmfgen_composition(fname, delimiter=r'\s+'): +def read_cmfgen_composition(fname, delimiter=r"\s+"): """Read composition from a CMFGEN model file The CMFGEN file format contains information about the ejecta state in the @@ -346,14 +412,16 @@ def read_cmfgen_composition(fname, delimiter=r'\s+'): filename of the csv file """ - warnings.warn("The current CMFGEN model parser is deprecated", - DeprecationWarning) + warnings.warn( + "The current CMFGEN model parser is deprecated", DeprecationWarning + ) - return read_csv_isotope_abundances(fname, delimiter=delimiter, - skip_columns=4, skip_rows=[0, 2, 3]) + return read_csv_isotope_abundances( + fname, delimiter=delimiter, skip_columns=4, skip_rows=[0, 2, 3] + ) -def read_csv_composition(fname, delimiter=r'\s+'): +def read_csv_composition(fname, delimiter=r"\s+"): """Read composition from a simple CSV file The CSV file can contain specific isotopes or elemental abundances in the @@ -369,12 +437,14 @@ def read_csv_composition(fname, delimiter=r'\s+'): filename of the csv file """ - return read_csv_isotope_abundances(fname, delimiter=delimiter, - skip_columns=0, skip_rows=[1]) + return read_csv_isotope_abundances( + fname, delimiter=delimiter, skip_columns=0, skip_rows=[1] + ) -def read_csv_isotope_abundances(fname, delimiter=r'\s+', skip_columns=0, - skip_rows=[1]): +def read_csv_isotope_abundances( + fname, delimiter=r"\s+", skip_columns=0, skip_rows=[1] +): """ A generic parser for a TARDIS composition stored as a CSV file @@ -413,20 +483,23 @@ def read_csv_isotope_abundances(fname, delimiter=r'\s+', skip_columns=0, isotope_abundance: ~pandas.MultiIndex """ - df = pd.read_csv(fname, comment='#', - sep=delimiter, skiprows=skip_rows, index_col=0) + df = pd.read_csv( + fname, comment="#", sep=delimiter, skiprows=skip_rows, index_col=0 + ) df = df.transpose() - abundance = pd.DataFrame(columns=np.arange(df.shape[1]), - index=pd.Index([], - name='atomic_number'), - dtype=np.float64) + abundance = pd.DataFrame( + columns=np.arange(df.shape[1]), + index=pd.Index([], name="atomic_number"), + dtype=np.float64, + ) isotope_index = pd.MultiIndex( - [[]] * 2, [[]] * 2, names=['atomic_number', 'mass_number']) - isotope_abundance = pd.DataFrame(columns=np.arange(df.shape[1]), - index=isotope_index, - dtype=np.float64) + [[]] * 2, [[]] * 2, names=["atomic_number", "mass_number"] + ) + isotope_abundance = pd.DataFrame( + columns=np.arange(df.shape[1]), index=isotope_index, dtype=np.float64 + ) for element_symbol_string in df.index[skip_columns:]: if element_symbol_string in nucname.name_zz: @@ -435,11 +508,13 @@ def read_csv_isotope_abundances(fname, delimiter=r'\s+', skip_columns=0, else: z = nucname.znum(element_symbol_string) mass_no = nucname.anum(element_symbol_string) - isotope_abundance.loc[( - z, mass_no), :] = df.loc[element_symbol_string].tolist() + isotope_abundance.loc[(z, mass_no), :] = df.loc[ + element_symbol_string + ].tolist() return abundance.index, abundance, isotope_abundance + def parse_csv_abundances(csvy_data): """ A parser for the csv data part of a csvy model file. This function filters out columns that are not abundances. @@ -457,21 +532,27 @@ def parse_csv_abundances(csvy_data): isotope_abundance : ~pandas.MultiIndex """ - abundance_col_names = [name for name in csvy_data.columns if nucname.iselement(name) or nucname.isnuclide(name)] + abundance_col_names = [ + name + for name in csvy_data.columns + if nucname.iselement(name) or nucname.isnuclide(name) + ] df = csvy_data.loc[:, abundance_col_names] - + df = df.transpose() - abundance = pd.DataFrame(columns=np.arange(df.shape[1]), - index=pd.Index([], - name='atomic_number'), - dtype=np.float64) + abundance = pd.DataFrame( + columns=np.arange(df.shape[1]), + index=pd.Index([], name="atomic_number"), + dtype=np.float64, + ) isotope_index = pd.MultiIndex( - [[]] * 2, [[]] * 2, names=['atomic_number', 'mass_number']) - isotope_abundance = pd.DataFrame(columns=np.arange(df.shape[1]), - index=isotope_index, - dtype=np.float64) + [[]] * 2, [[]] * 2, names=["atomic_number", "mass_number"] + ) + isotope_abundance = pd.DataFrame( + columns=np.arange(df.shape[1]), index=isotope_index, dtype=np.float64 + ) for element_symbol_string in df.index[0:]: if element_symbol_string in nucname.name_zz: @@ -480,7 +561,8 @@ def parse_csv_abundances(csvy_data): else: z = nucname.znum(element_symbol_string) mass_no = nucname.anum(element_symbol_string) - isotope_abundance.loc[( - z, mass_no), :] = df.loc[element_symbol_string].tolist() + isotope_abundance.loc[(z, mass_no), :] = df.loc[ + element_symbol_string + ].tolist() return abundance.index, abundance, isotope_abundance diff --git a/tardis/io/parsers/__init__.py b/tardis/io/parsers/__init__.py index a5bdddb0a28..ef3825cde24 100644 --- a/tardis/io/parsers/__init__.py +++ b/tardis/io/parsers/__init__.py @@ -1,3 +1,4 @@ -from tardis.io.parsers.blondin_toymodel import (read_blondin_toymodel, - convert_blondin_toymodel) - +from tardis.io.parsers.blondin_toymodel import ( + read_blondin_toymodel, + convert_blondin_toymodel, +) diff --git a/tardis/io/parsers/blondin_toymodel.py b/tardis/io/parsers/blondin_toymodel.py index b6dece9f488..de9ff79cb03 100644 --- a/tardis/io/parsers/blondin_toymodel.py +++ b/tardis/io/parsers/blondin_toymodel.py @@ -9,8 +9,8 @@ from tardis.util.base import parse_quantity -PATTERN_REMOVE_BRACKET = re.compile(r'\[.+\]') -T0_PATTERN = re.compile('tend = (.+)\n') +PATTERN_REMOVE_BRACKET = re.compile(r"\[.+\]") +T0_PATTERN = re.compile("tend = (.+)\n") def read_blondin_toymodel(fname): @@ -32,64 +32,95 @@ def read_blondin_toymodel(fname): blondin_csv: pandas.DataFrame DataFrame containing the csv part of the toymodel """ - with open(fname, 'r') as fh: + with open(fname, "r") as fh: for line in fh: if line.startswith("#idx"): break else: raise ValueError( - 'File {0} does not conform to Toy Model format as it does ' - 'not contain #idx') - columns = [PATTERN_REMOVE_BRACKET.sub('', item) for item in - line[1:].split()] - - raw_blondin_csv = pd.read_csv(fname, delim_whitespace=True, comment='#', - header=None, names=columns) - raw_blondin_csv.set_index('idx', inplace=True) - - blondin_csv = raw_blondin_csv.loc[:, - ['vel', 'dens', 'temp', 'X_56Ni0', 'X_Ti', 'X_Ca', 'X_S', - 'X_Si', 'X_O', 'X_C']] - rename_col_dict = {'vel': 'velocity', 'dens': 'density', - 'temp': 't_electron'} + "File {0} does not conform to Toy Model format as it does " + "not contain #idx" + ) + columns = [ + PATTERN_REMOVE_BRACKET.sub("", item) for item in line[1:].split() + ] + + raw_blondin_csv = pd.read_csv( + fname, delim_whitespace=True, comment="#", header=None, names=columns + ) + raw_blondin_csv.set_index("idx", inplace=True) + + blondin_csv = raw_blondin_csv.loc[ + :, + [ + "vel", + "dens", + "temp", + "X_56Ni0", + "X_Ti", + "X_Ca", + "X_S", + "X_Si", + "X_O", + "X_C", + ], + ] + rename_col_dict = { + "vel": "velocity", + "dens": "density", + "temp": "t_electron", + } rename_col_dict.update({item: item[2:] for item in blondin_csv.columns[3:]}) - rename_col_dict['X_56Ni0'] = 'Ni56' + rename_col_dict["X_56Ni0"] = "Ni56" blondin_csv.rename(columns=rename_col_dict, inplace=True) blondin_csv.iloc[:, 3:] = blondin_csv.iloc[:, 3:].divide( - blondin_csv.iloc[:, 3:].sum(axis=1), axis=0) + blondin_csv.iloc[:, 3:].sum(axis=1), axis=0 + ) # changing velocities to outer boundary - new_velocities = 0.5 * (blondin_csv.velocity.iloc[ - :-1].values + blondin_csv.velocity.iloc[1:].values) + new_velocities = 0.5 * ( + blondin_csv.velocity.iloc[:-1].values + + blondin_csv.velocity.iloc[1:].values + ) new_velocities = np.hstack( - (new_velocities, [2 * new_velocities[-1] - new_velocities[-2]])) - blondin_csv['velocity'] = new_velocities + (new_velocities, [2 * new_velocities[-1] - new_velocities[-2]]) + ) + blondin_csv["velocity"] = new_velocities - with open(fname, 'r') as fh: + with open(fname, "r") as fh: t0_string = T0_PATTERN.findall(fh.read())[0] - t0 = parse_quantity(t0_string.replace('DAYS', 'day')) + t0 = parse_quantity(t0_string.replace("DAYS", "day")) blondin_dict = {} - blondin_dict['model_density_time_0'] = str(t0) - blondin_dict['description'] = 'Converted {0} to csvy format'.format(fname) - blondin_dict['tardis_model_config_version'] = 'v1.0' - blondin_dict_fields = [dict(name='velocity', unit='km/s', - desc='velocities of shell outer bounderies.')] + blondin_dict["model_density_time_0"] = str(t0) + blondin_dict["description"] = "Converted {0} to csvy format".format(fname) + blondin_dict["tardis_model_config_version"] = "v1.0" + blondin_dict_fields = [ + dict( + name="velocity", + unit="km/s", + desc="velocities of shell outer bounderies.", + ) + ] blondin_dict_fields.append( - dict(name='density', unit='g/cm^3', desc='mean density of shell.')) + dict(name="density", unit="g/cm^3", desc="mean density of shell.") + ) blondin_dict_fields.append( - dict(name='t_electron', unit='K', desc='electron temperature.')) + dict(name="t_electron", unit="K", desc="electron temperature.") + ) for abund in blondin_csv.columns[3:]: blondin_dict_fields.append( - dict(name=abund, desc='Fraction {0} abundance'.format(abund))) - blondin_dict['datatype'] = {'fields': blondin_dict_fields} + dict(name=abund, desc="Fraction {0} abundance".format(abund)) + ) + blondin_dict["datatype"] = {"fields": blondin_dict_fields} return blondin_dict, blondin_csv -def convert_blondin_toymodel(in_fname, out_fname, v_inner, v_outer, - conversion_t_electron_rad=None): +def convert_blondin_toymodel( + in_fname, out_fname, v_inner, v_outer, conversion_t_electron_rad=None +): """ Parameters @@ -109,25 +140,29 @@ def convert_blondin_toymodel(in_fname, out_fname, v_inner, v_outer, outer boundary velocity. If float will be interpreted as km/s """ blondin_dict, blondin_csv = read_blondin_toymodel(in_fname) - blondin_dict['v_inner_boundary'] = str(u.Quantity(v_inner, u.km / u.s)) - blondin_dict['v_outer_boundary'] = str(u.Quantity(v_outer, u.km / u.s)) + blondin_dict["v_inner_boundary"] = str(u.Quantity(v_inner, u.km / u.s)) + blondin_dict["v_outer_boundary"] = str(u.Quantity(v_outer, u.km / u.s)) if conversion_t_electron_rad is not None: - blondin_dict['datatype']['fields'].append({ - 'desc': - 'converted radiation temperature ' - 'using multiplicative factor={0}'.format( - conversion_t_electron_rad), - 'name': 't_rad', 'unit': 'K'}) - - blondin_csv['t_rad'] = (conversion_t_electron_rad * - blondin_csv.t_electron) - - - csvy_file = '---\n{0}\n---\n{1}'.format( + blondin_dict["datatype"]["fields"].append( + { + "desc": "converted radiation temperature " + "using multiplicative factor={0}".format( + conversion_t_electron_rad + ), + "name": "t_rad", + "unit": "K", + } + ) + + blondin_csv["t_rad"] = ( + conversion_t_electron_rad * blondin_csv.t_electron + ) + + csvy_file = "---\n{0}\n---\n{1}".format( yaml.dump(blondin_dict, default_flow_style=False), - blondin_csv.to_csv(index=False)) + blondin_csv.to_csv(index=False), + ) - with open(out_fname, 'w') as fh: + with open(out_fname, "w") as fh: fh.write(csvy_file) - diff --git a/tardis/io/parsers/csvy.py b/tardis/io/parsers/csvy.py index f0e23e81552..fcb5c175019 100644 --- a/tardis/io/parsers/csvy.py +++ b/tardis/io/parsers/csvy.py @@ -2,7 +2,9 @@ import pandas as pd from tardis.io.util import YAMLLoader -YAML_DELIMITER = '---' +YAML_DELIMITER = "---" + + def load_csvy(fname): """ Parameters @@ -24,14 +26,16 @@ def load_csvy(fname): yaml_end_ind = -1 for i, line in enumerate(fh): if i == 0: - assert line.strip() == YAML_DELIMITER, 'First line of csvy file is not \'---\'' + assert ( + line.strip() == YAML_DELIMITER + ), "First line of csvy file is not '---'" yaml_lines.append(line) if i > 0 and line.strip() == YAML_DELIMITER: yaml_end_ind = i break else: - raise ValueError('End %s not found'%(YAML_DELIMITER)) - yaml_dict = yaml.load(''.join(yaml_lines[1:-1]), YAMLLoader) + raise ValueError("End %s not found" % (YAML_DELIMITER)) + yaml_dict = yaml.load("".join(yaml_lines[1:-1]), YAMLLoader) try: data = pd.read_csv(fname, skiprows=yaml_end_ind + 1) except pd.io.common.EmptyDataError as e: @@ -39,6 +43,7 @@ def load_csvy(fname): return yaml_dict, data + def load_yaml_from_csvy(fname): """ Parameters @@ -57,16 +62,19 @@ def load_yaml_from_csvy(fname): yaml_end_ind = -1 for i, line in enumerate(fh): if i == 0: - assert line.strip() == YAML_DELIMITER, 'First line of csvy file is not \'---\'' + assert ( + line.strip() == YAML_DELIMITER + ), "First line of csvy file is not '---'" yaml_lines.append(line) if i > 0 and line.strip() == YAML_DELIMITER: yaml_end_ind = i break else: - raise ValueError('End %s not found'%(YAML_DELIMITER)) - yaml_dict = yaml.load(''.join(yaml_lines[1:-1]), YAMLLoader) + raise ValueError("End %s not found" % (YAML_DELIMITER)) + yaml_dict = yaml.load("".join(yaml_lines[1:-1]), YAMLLoader) return yaml_dict + def load_csv_from_csvy(fname): """ Parameters diff --git a/tardis/io/parsers/stella.py b/tardis/io/parsers/stella.py index 8d0119bf730..b102e4c1897 100644 --- a/tardis/io/parsers/stella.py +++ b/tardis/io/parsers/stella.py @@ -3,14 +3,18 @@ from astropy import units as u import numpy as np + def read_stella_data(filename): with open(filename) as fh: col = fh.readlines()[5] - col_names = re.split(r'\s{3,}', col.strip()) - col_names = [re.sub(r'\s\(.+\)', '', col_name).replace(' ', '_') for - col_name in col_names] - data = pd.read_csv(filename, skiprows=7, delim_whitespace=True, - names = col_names) + col_names = re.split(r"\s{3,}", col.strip()) + col_names = [ + re.sub(r"\s\(.+\)", "", col_name).replace(" ", "_") + for col_name in col_names + ] + data = pd.read_csv( + filename, skiprows=7, delim_whitespace=True, names=col_names + ) # drop last row of data data = data.iloc[0:-1] diff --git a/tardis/io/setup_package.py b/tardis/io/setup_package.py index 4c9cbe860ba..c12c8607825 100644 --- a/tardis/io/setup_package.py +++ b/tardis/io/setup_package.py @@ -1,6 +1,24 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst + def get_package_data(): - return {'tardis.io.tests': ['data/*.dat', 'data/*.yml', 'data/*.csv', 'data/*.csvy'], - 'tardis.model.tests' : ['data/*.dat', 'data/*.yml', 'data/*.csv', 'data/*.csvy'], - 'tardis.plasma.tests' : ['data/*.dat', 'data/*.yml', 'data/*.csv', 'data/*.txt']} + return { + "tardis.io.tests": [ + "data/*.dat", + "data/*.yml", + "data/*.csv", + "data/*.csvy", + ], + "tardis.model.tests": [ + "data/*.dat", + "data/*.yml", + "data/*.csv", + "data/*.csvy", + ], + "tardis.plasma.tests": [ + "data/*.dat", + "data/*.yml", + "data/*.csv", + "data/*.txt", + ], + } diff --git a/tardis/io/tests/test_HDFWriter.py b/tardis/io/tests/test_HDFWriter.py index e072e873ab1..eafd59228de 100644 --- a/tardis/io/tests/test_HDFWriter.py +++ b/tardis/io/tests/test_HDFWriter.py @@ -11,123 +11,155 @@ from tardis.io.util import HDFWriterMixin -#Test Cases +# Test Cases + +# DataFrame +# None +# Numpy Arrays +# Strings +# Numeric Values +# Pandas Series Object +# MultiIndex Object +# Quantity Objects with - Numeric Values, Numpy Arrays, DataFrame, Pandas Series, None objects -#DataFrame -#None -#Numpy Arrays -#Strings -#Numeric Values -#Pandas Series Object -#MultiIndex Object -#Quantity Objects with - Numeric Values, Numpy Arrays, DataFrame, Pandas Series, None objects class MockHDF(HDFWriterMixin): - hdf_properties = ['property'] + hdf_properties = ["property"] def __init__(self, property): self.property = property -simple_objects = [1.5, 'random_string', 4.2e7] + +simple_objects = [1.5, "random_string", 4.2e7] + @pytest.mark.parametrize("attr", simple_objects) def test_simple_write(tmpdir, attr): - fname = str(tmpdir.mkdir('data').join('test.hdf')) + fname = str(tmpdir.mkdir("data").join("test.hdf")) actual = MockHDF(attr) - actual.to_hdf(fname, path='test') - expected = pd.read_hdf(fname, key='/test/mock_hdf/scalars')['property'] + actual.to_hdf(fname, path="test") + expected = pd.read_hdf(fname, key="/test/mock_hdf/scalars")["property"] assert actual.property == expected -mock_df = pd.DataFrame({'one': pd.Series([1., 2., 3.], index=['a', 'b', 'c']), - 'two': pd.Series([1., 2., 3., 4.], index=['a', 'b', 'c', 'd'])}) -complex_objects = [np.array([4.0e14, 2, 2e14, 27.5]), - pd.Series([1., 2., 3.]), mock_df] + +mock_df = pd.DataFrame( + { + "one": pd.Series([1.0, 2.0, 3.0], index=["a", "b", "c"]), + "two": pd.Series([1.0, 2.0, 3.0, 4.0], index=["a", "b", "c", "d"]), + } +) +complex_objects = [ + np.array([4.0e14, 2, 2e14, 27.5]), + pd.Series([1.0, 2.0, 3.0]), + mock_df, +] + @pytest.mark.parametrize("attr", complex_objects) def test_complex_obj_write(tmpdir, attr): - fname = str(tmpdir.mkdir('data').join('test.hdf')) + fname = str(tmpdir.mkdir("data").join("test.hdf")) actual = MockHDF(attr) - actual.to_hdf(fname, path='test') - expected = pd.read_hdf(fname, key='/test/mock_hdf/property').values + actual.to_hdf(fname, path="test") + expected = pd.read_hdf(fname, key="/test/mock_hdf/property").values assert_array_almost_equal(actual.property, expected) -arr = np.array([['L1', 'L1', 'L2', 'L2', 'L3', 'L3', 'L4', 'L4'], - ['one', 'two', 'one', 'two', 'one', 'two', 'one', 'two']]) + +arr = np.array( + [ + ["L1", "L1", "L2", "L2", "L3", "L3", "L4", "L4"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ] +) mock_multiIndex = pd.MultiIndex.from_arrays(arr.transpose()) + def test_MultiIndex_write(tmpdir): - fname = str(tmpdir.mkdir('data').join('test.hdf')) + fname = str(tmpdir.mkdir("data").join("test.hdf")) actual = MockHDF(mock_multiIndex) - actual.to_hdf(fname, path='test') - expected = pd.read_hdf(fname, key='/test/mock_hdf/property') + actual.to_hdf(fname, path="test") + expected = pd.read_hdf(fname, key="/test/mock_hdf/property") expected = pd.MultiIndex.from_tuples(expected.unstack().values) pdt.assert_almost_equal(actual.property, expected) -#Test Quantity Objects + +# Test Quantity Objects quantity_objects = [np.array([4.0e14, 2, 2e14, 27.5]), mock_df] + @pytest.mark.parametrize("attr", quantity_objects) def test_quantity_objects_write(tmpdir, attr): - fname = str(tmpdir.mkdir('data').join('test.hdf')) - attr_quantity = u.Quantity(attr, 'g/cm**3') + fname = str(tmpdir.mkdir("data").join("test.hdf")) + attr_quantity = u.Quantity(attr, "g/cm**3") actual = MockHDF(attr_quantity) - actual.to_hdf(fname, path='test') - expected = pd.read_hdf(fname, key='/test/mock_hdf/property') + actual.to_hdf(fname, path="test") + expected = pd.read_hdf(fname, key="/test/mock_hdf/property") assert_array_almost_equal(actual.property.cgs.value, expected) + scalar_quantity_objects = [1.5, 4.2e7] + @pytest.mark.parametrize("attr", scalar_quantity_objects) def test_scalar_quantity_objects_write(tmpdir, attr): - fname = str(tmpdir.mkdir('data').join('test.hdf')) - attr_quantity = u.Quantity(attr, 'g/cm**3') + fname = str(tmpdir.mkdir("data").join("test.hdf")) + attr_quantity = u.Quantity(attr, "g/cm**3") actual = MockHDF(attr_quantity) - actual.to_hdf(fname, path='test') - expected = pd.read_hdf(fname, key='/test/mock_hdf/scalars/')['property'] + actual.to_hdf(fname, path="test") + expected = pd.read_hdf(fname, key="/test/mock_hdf/scalars/")["property"] assert_array_almost_equal(actual.property.cgs.value, expected) + def test_none_write(tmpdir): - fname = str(tmpdir.mkdir('data').join('test.hdf')) + fname = str(tmpdir.mkdir("data").join("test.hdf")) actual = MockHDF(None) - actual.to_hdf(fname, path='test') - expected = pd.read_hdf(fname, key='/test/mock_hdf/scalars/')['property'] - if expected == 'none': + actual.to_hdf(fname, path="test") + expected = pd.read_hdf(fname, key="/test/mock_hdf/scalars/")["property"] + if expected == "none": expected = None assert actual.property == expected + # Test class_properties parameter (like homologous_density is a class # instance/object inside Model class) + class MockClass(HDFWriterMixin): - hdf_properties = ['property', 'nested_object'] + hdf_properties = ["property", "nested_object"] def __init__(self, property, nested_object): self.property = property self.nested_object = nested_object + @pytest.mark.parametrize("attr", quantity_objects) def test_objects_write(tmpdir, attr): - fname = str(tmpdir.mkdir('data').join('test.hdf')) + fname = str(tmpdir.mkdir("data").join("test.hdf")) nested_object = MockHDF(np.array([4.0e14, 2, 2e14, 27.5])) - attr_quantity = u.Quantity(attr, 'g/cm**3') + attr_quantity = u.Quantity(attr, "g/cm**3") actual = MockClass(attr_quantity, nested_object) - actual.to_hdf(fname, path='test') - expected_property = pd.read_hdf(fname, key='/test/mock_class/property') + actual.to_hdf(fname, path="test") + expected_property = pd.read_hdf(fname, key="/test/mock_class/property") assert_array_almost_equal(actual.property.cgs.value, expected_property) nested_property = pd.read_hdf( - fname, key='/test/mock_class/nested_object/property') - assert_array_almost_equal( - actual.nested_object.property, nested_property) + fname, key="/test/mock_class/nested_object/property" + ) + assert_array_almost_equal(actual.nested_object.property, nested_property) def test_snake_case(): - assert MockHDF.convert_to_snake_case( - "HomologousDensity") == "homologous_density" + assert ( + MockHDF.convert_to_snake_case("HomologousDensity") + == "homologous_density" + ) assert MockHDF.convert_to_snake_case("TARDISSpectrum") == "tardis_spectrum" assert MockHDF.convert_to_snake_case("BasePlasma") == "base_plasma" assert MockHDF.convert_to_snake_case("LTEPlasma") == "lte_plasma" - assert MockHDF.convert_to_snake_case( - "MonteCarloRunner") == "monte_carlo_runner" - assert MockHDF.convert_to_snake_case( - "homologous_density") == "homologous_density" + assert ( + MockHDF.convert_to_snake_case("MonteCarloRunner") + == "monte_carlo_runner" + ) + assert ( + MockHDF.convert_to_snake_case("homologous_density") + == "homologous_density" + ) diff --git a/tardis/io/tests/test_ascii_readers.py b/tardis/io/tests/test_ascii_readers.py index 0e420291dbe..12f7519e9ca 100644 --- a/tardis/io/tests/test_ascii_readers.py +++ b/tardis/io/tests/test_ascii_readers.py @@ -13,29 +13,36 @@ def data_path(filename): data_dir = os.path.dirname(__file__) - return os.path.join(data_dir, 'data', filename) + return os.path.join(data_dir, "data", filename) def test_simple_ascii_density_reader_time(): - time_model, velocity, density = io.read_simple_ascii_density(data_path('tardis_simple_ascii_density_test.dat')) + time_model, velocity, density = io.read_simple_ascii_density( + data_path("tardis_simple_ascii_density_test.dat") + ) - assert time_model.unit.physical_type == 'time' + assert time_model.unit.physical_type == "time" npt.assert_almost_equal(time_model.to(u.day).value, 1.0) + def test_simple_ascii_density_reader_data(): - time_model, velocity, density = io.read_simple_ascii_density(data_path('tardis_simple_ascii_density_test.dat')) - assert velocity.unit == u.Unit('cm/s') + time_model, velocity, density = io.read_simple_ascii_density( + data_path("tardis_simple_ascii_density_test.dat") + ) + assert velocity.unit == u.Unit("cm/s") - npt.assert_allclose(velocity[3].value, 1.3e4*1e5) + npt.assert_allclose(velocity[3].value, 1.3e4 * 1e5) def test_simple_ascii_abundance_reader(): - index, abundances = io.read_simple_ascii_abundances(data_path('artis_abundances.dat')) + index, abundances = io.read_simple_ascii_abundances( + data_path("artis_abundances.dat") + ) npt.assert_almost_equal(abundances.loc[1, 0], 1.542953e-08) npt.assert_almost_equal(abundances.loc[14, 54], 0.21864420000000001) def test_ascii_reader_invalid_volumes(): with pytest.raises(io.model_reader.ConfigurationError): - io.read_density_file(data_path('invalid_artis_model.dat'), 'artis') + io.read_density_file(data_path("invalid_artis_model.dat"), "artis") diff --git a/tardis/io/tests/test_atomic.py b/tardis/io/tests/test_atomic.py index 75f77ded901..70de086d9b0 100644 --- a/tardis/io/tests/test_atomic.py +++ b/tardis/io/tests/test_atomic.py @@ -26,37 +26,37 @@ def lines(kurucz_atomic_data): def test_atom_data_basic_atom_data(basic_atom_data): - assert basic_atom_data.loc[2, 'symbol'] == 'He' + assert basic_atom_data.loc[2, "symbol"] == "He" assert_quantity_allclose( - basic_atom_data.at[2, 'mass'] * u.Unit('g'), - 4.002602 * const.u.cgs - ) + basic_atom_data.at[2, "mass"] * u.Unit("g"), 4.002602 * const.u.cgs + ) def test_atom_data_ionization_data(ionization_data): assert_quantity_allclose( - ionization_data.loc[(2, 1)] * u.Unit('erg'), - 24.587387936 * u.Unit('eV') + ionization_data.loc[(2, 1)] * u.Unit("erg"), 24.587387936 * u.Unit("eV") ) def test_atom_data_levels(levels): assert_quantity_allclose( - u.Quantity(levels.at[(2, 0, 2), 'energy'], u.Unit('erg')).to(u.Unit('cm-1'), equivalencies=u.spectral()), - 166277.542 * u.Unit('cm-1') + u.Quantity(levels.at[(2, 0, 2), "energy"], u.Unit("erg")).to( + u.Unit("cm-1"), equivalencies=u.spectral() + ), + 166277.542 * u.Unit("cm-1"), ) def test_atom_data_lines(lines): assert_quantity_allclose( - lines.at[(2, 0, 0, 6), 'wavelength_cm'] * u.Unit('cm'), - 584.335 * u.Unit('Angstrom') + lines.at[(2, 0, 0, 6), "wavelength_cm"] * u.Unit("cm"), + 584.335 * u.Unit("Angstrom"), ) def test_atomic_reprepare(kurucz_atomic_data): kurucz_atomic_data.prepare_atom_data([14, 20]) lines = kurucz_atomic_data.lines.reset_index() - assert lines['atomic_number'].isin([14, 20]).all() - assert len(lines.loc[lines['atomic_number'] == 14]) > 0 - assert len(lines.loc[lines['atomic_number'] == 20]) > 0 + assert lines["atomic_number"].isin([14, 20]).all() + assert len(lines.loc[lines["atomic_number"] == 14]) > 0 + assert len(lines.loc[lines["atomic_number"] == 20]) > 0 diff --git a/tardis/io/tests/test_config_reader.py b/tardis/io/tests/test_config_reader.py index 70034c4f493..7ebd13514e7 100644 --- a/tardis/io/tests/test_config_reader.py +++ b/tardis/io/tests/test_config_reader.py @@ -10,42 +10,50 @@ def data_path(filename): data_dir = os.path.dirname(__file__) - return os.path.abspath(os.path.join(data_dir, 'data', filename)) + return os.path.abspath(os.path.join(data_dir, "data", filename)) def test_convergence_section_parser(): - test_convergence_section = {'type': 'damped', - 'lock_t_inner_cyles': 1, - 't_inner_update_exponent': -0.5, - 'damping_constant': 0.5, - 'threshold': 0.05, - 'fraction': 0.8, - 'hold_iterations': 3, - 't_rad': {'damping_constant': 1.0}} + test_convergence_section = { + "type": "damped", + "lock_t_inner_cyles": 1, + "t_inner_update_exponent": -0.5, + "damping_constant": 0.5, + "threshold": 0.05, + "fraction": 0.8, + "hold_iterations": 3, + "t_rad": {"damping_constant": 1.0}, + } parsed_convergence_section = config_reader.parse_convergence_section( - test_convergence_section) + test_convergence_section + ) - assert_almost_equal(parsed_convergence_section['t_rad']['damping_constant'], - 1.0) + assert_almost_equal( + parsed_convergence_section["t_rad"]["damping_constant"], 1.0 + ) - assert_almost_equal(parsed_convergence_section['w']['damping_constant'], - 0.5) + assert_almost_equal( + parsed_convergence_section["w"]["damping_constant"], 0.5 + ) def test_from_config_dict(tardis_config_verysimple): - conf = Configuration.from_config_dict(tardis_config_verysimple, - validate=True, - config_dirname='test') - assert conf.config_dirname == 'test' - assert_almost_equal(conf.spectrum.start.value, - tardis_config_verysimple['spectrum']['start'].value) - assert_almost_equal(conf.spectrum.stop.value, - tardis_config_verysimple['spectrum']['stop'].value) - - tardis_config_verysimple['spectrum']['start'] = 'Invalid' + conf = Configuration.from_config_dict( + tardis_config_verysimple, validate=True, config_dirname="test" + ) + assert conf.config_dirname == "test" + assert_almost_equal( + conf.spectrum.start.value, + tardis_config_verysimple["spectrum"]["start"].value, + ) + assert_almost_equal( + conf.spectrum.stop.value, + tardis_config_verysimple["spectrum"]["stop"].value, + ) + + tardis_config_verysimple["spectrum"]["start"] = "Invalid" with pytest.raises(ValidationError): - conf = Configuration.from_config_dict(tardis_config_verysimple, - validate=True, - config_dirname='test') - + conf = Configuration.from_config_dict( + tardis_config_verysimple, validate=True, config_dirname="test" + ) diff --git a/tardis/io/tests/test_configuration_namespace.py b/tardis/io/tests/test_configuration_namespace.py index 6101eb5cf19..a44e17ec123 100644 --- a/tardis/io/tests/test_configuration_namespace.py +++ b/tardis/io/tests/test_configuration_namespace.py @@ -5,75 +5,75 @@ from numpy.testing import assert_almost_equal -simple_config_dict = {'a' : {'b' : {'param1' : 1, 'param2': [0, 1, 2 * u.km], - 'param3' : 4.0 * u.km}}} +simple_config_dict = { + "a": {"b": {"param1": 1, "param2": [0, 1, 2 * u.km], "param3": 4.0 * u.km}} +} + def data_path(filename): data_dir = os.path.dirname(__file__) - return os.path.join(data_dir, 'data', filename) - + return os.path.join(data_dir, "data", filename) def test_simple_configuration_namespace(): config_ns = ConfigurationNameSpace(simple_config_dict) assert config_ns.a.b.param1 == 1 config_ns.a.b.param1 = 2 - assert (config_ns['a']['b']['param1'] - == 2) + assert config_ns["a"]["b"]["param1"] == 2 - config_ns['a']['b']['param1'] = 3 + config_ns["a"]["b"]["param1"] = 3 assert config_ns.a.b.param1 == 3 + def test_quantity_configuration_namespace(): config_ns = ConfigurationNameSpace(simple_config_dict) config_ns.a.b.param3 = 3 - assert_almost_equal(config_ns['a']['b']['param3'].to(u.km).value, 3) - + assert_almost_equal(config_ns["a"]["b"]["param3"].to(u.km).value, 3) config_ns.a.b.param3 = 5000 * u.m - assert_almost_equal(config_ns['a']['b']['param3'].to(u.km).value, 5) - + assert_almost_equal(config_ns["a"]["b"]["param3"].to(u.km).value, 5) def test_access_with_config_item_string(): config_ns = ConfigurationNameSpace(simple_config_dict) - assert config_ns.get_config_item('a.b.param1') == 1 + assert config_ns.get_config_item("a.b.param1") == 1 - config_ns.set_config_item('a.b.param1', 2) + config_ns.set_config_item("a.b.param1", 2) assert config_ns.a.b.param1 == 2 + def test_set_with_config_item_string_quantity(): config_ns = ConfigurationNameSpace(simple_config_dict) - config_ns.set_config_item('a.b.param3', 2) + config_ns.set_config_item("a.b.param3", 2) assert_almost_equal(config_ns.a.b.param3.to(u.km).value, 2) def test_get_with_config_item_string_item_access(): config_ns = ConfigurationNameSpace(simple_config_dict) - item = config_ns.get_config_item('a.b.param2.item0') + item = config_ns.get_config_item("a.b.param2.item0") assert item == 0 - item = config_ns.get_config_item('a.b.param2.item1') + item = config_ns.get_config_item("a.b.param2.item1") assert item == 1 + def test_set_with_config_item_string_item_access(): config_ns = ConfigurationNameSpace(simple_config_dict) - config_ns.set_config_item('a.b.param2.item0', 2) + config_ns.set_config_item("a.b.param2.item0", 2) - item = config_ns.get_config_item('a.b.param2.item0') + item = config_ns.get_config_item("a.b.param2.item0") assert item == 2 + def test_set_with_config_item_string_item_access_quantity(): config_ns = ConfigurationNameSpace(simple_config_dict) - config_ns.set_config_item('a.b.param2.item2', 7 ) - - item = config_ns.get_config_item('a.b.param2.item2') - - assert_almost_equal(item.to(u.km).value ,7) + config_ns.set_config_item("a.b.param2.item2", 7) + item = config_ns.get_config_item("a.b.param2.item2") + assert_almost_equal(item.to(u.km).value, 7) def test_config_namespace_copy(): @@ -82,5 +82,6 @@ def test_config_namespace_copy(): config_ns2.a.b.param1 = 2 assert config_ns2.a.b.param1 != config_ns.a.b.param1 + def test_config_namespace_quantity_set(): - data_path('paper1_tardis_configv1.yml') \ No newline at end of file + data_path("paper1_tardis_configv1.yml") diff --git a/tardis/io/tests/test_csvy_reader.py b/tardis/io/tests/test_csvy_reader.py index e858f4ef538..7bb6e8762f3 100644 --- a/tardis/io/tests/test_csvy_reader.py +++ b/tardis/io/tests/test_csvy_reader.py @@ -8,38 +8,49 @@ import numpy.testing as npt -DATA_PATH = os.path.join(tardis.__path__[0], 'io', 'tests', 'data') +DATA_PATH = os.path.join(tardis.__path__[0], "io", "tests", "data") + @pytest.fixture def csvy_full_fname(): - return os.path.join(DATA_PATH, 'csvy_full.csvy') + return os.path.join(DATA_PATH, "csvy_full.csvy") + @pytest.fixture def csvy_nocsv_fname(): - return os.path.join(DATA_PATH, 'csvy_nocsv.csvy') + return os.path.join(DATA_PATH, "csvy_nocsv.csvy") + @pytest.fixture def csvy_missing_fname(): - return os.path.join(DATA_PATH, 'csvy_missing.csvy') + return os.path.join(DATA_PATH, "csvy_missing.csvy") + def test_csvy_finds_csv_first_line(csvy_full_fname): yaml_dict, csv = csvy.load_csvy(csvy_full_fname) - npt.assert_almost_equal(csv['velocity'][0],10000) + npt.assert_almost_equal(csv["velocity"][0], 10000) + def test_csv_colnames_equiv_datatype_fields(csvy_full_fname): yaml_dict, csv = csvy.load_csvy(csvy_full_fname) - datatype_names = [od['name'] for od in yaml_dict['datatype']['fields']] + datatype_names = [od["name"] for od in yaml_dict["datatype"]["fields"]] for key in csv.columns: assert key in datatype_names for name in datatype_names: assert name in csv.columns + def test_csvy_nocsv_data_is_none(csvy_nocsv_fname): yaml_dict, csv = csvy.load_csvy(csvy_nocsv_fname) assert csv is None + def test_missing_required_property(csvy_missing_fname): yaml_dict, csv = csvy.load_csvy(csvy_missing_fname) with pytest.raises(Exception): - vy = validate_dict(yaml_dict, schemapath=os.path.join(tardis.__path__[0], 'io', 'schemas', - 'csvy_model.yml')) + vy = validate_dict( + yaml_dict, + schemapath=os.path.join( + tardis.__path__[0], "io", "schemas", "csvy_model.yml" + ), + ) diff --git a/tardis/io/tests/test_decay.py b/tardis/io/tests/test_decay.py index 09c5abf4601..42856c47b78 100644 --- a/tardis/io/tests/test_decay.py +++ b/tardis/io/tests/test_decay.py @@ -4,12 +4,15 @@ from tardis.io.decay import IsotopeAbundances from numpy.testing import assert_almost_equal + @pytest.fixture def simple_abundance_model(): - index = pd.MultiIndex.from_tuples([(28, 56)], - names=['atomic_number', 'mass_number']) + index = pd.MultiIndex.from_tuples( + [(28, 56)], names=["atomic_number", "mass_number"] + ) return IsotopeAbundances([[1.0, 1.0]], index=index) + def test_simple_decay(simple_abundance_model): decayed_abundance = simple_abundance_model.decay(100) assert_almost_equal(decayed_abundance.loc[26, 56][0], 0.55752) @@ -19,18 +22,26 @@ def test_simple_decay(simple_abundance_model): assert_almost_equal(decayed_abundance.loc[28, 56][0], 1.1086e-05) assert_almost_equal(decayed_abundance.loc[28, 56][1], 1.1086e-05) + @pytest.fixture def raw_abundance_simple(): abundances = pd.DataFrame([[0.2, 0.2], [0.1, 0.1]], index=[28, 30]) - abundances.index.rename('atomic_number', inplace=True) + abundances.index.rename("atomic_number", inplace=True) return abundances + def test_abundance_merge(simple_abundance_model, raw_abundance_simple): decayed_df = simple_abundance_model.decay(100) isotope_df = decayed_df.as_atoms() combined_df = decayed_df.merge(raw_abundance_simple, normalize=False) - - assert_almost_equal(combined_df.loc[28][0], raw_abundance_simple.loc[28][0] + isotope_df.loc[28][0]) - assert_almost_equal(combined_df.loc[28][1], raw_abundance_simple.loc[28][1] + isotope_df.loc[28][1]) + + assert_almost_equal( + combined_df.loc[28][0], + raw_abundance_simple.loc[28][0] + isotope_df.loc[28][0], + ) + assert_almost_equal( + combined_df.loc[28][1], + raw_abundance_simple.loc[28][1] + isotope_df.loc[28][1], + ) assert_almost_equal(combined_df.loc[30][1], raw_abundance_simple.loc[30][1]) - assert_almost_equal(combined_df.loc[26][0], isotope_df.loc[26][0]) \ No newline at end of file + assert_almost_equal(combined_df.loc[26][0], isotope_df.loc[26][0]) diff --git a/tardis/io/tests/test_model_reader.py b/tardis/io/tests/test_model_reader.py index aa0b646716d..60a110b1a91 100644 --- a/tardis/io/tests/test_model_reader.py +++ b/tardis/io/tests/test_model_reader.py @@ -7,91 +7,119 @@ import tardis from tardis.io.config_reader import Configuration from tardis.io.model_reader import ( - read_artis_density, read_simple_ascii_abundances, read_csv_composition, read_uniform_abundances, read_cmfgen_density, read_cmfgen_composition) + read_artis_density, + read_simple_ascii_abundances, + read_csv_composition, + read_uniform_abundances, + read_cmfgen_density, + read_cmfgen_composition, +) + +data_path = os.path.join(tardis.__path__[0], "io", "tests", "data") -data_path = os.path.join(tardis.__path__[0], 'io', 'tests', 'data') @pytest.fixture def artis_density_fname(): - return os.path.join(data_path, 'artis_model.dat') + return os.path.join(data_path, "artis_model.dat") + @pytest.fixture def artis_abundances_fname(): - return os.path.join(data_path, 'artis_abundances.dat') + return os.path.join(data_path, "artis_abundances.dat") + @pytest.fixture def cmfgen_fname(): - return os.path.join(data_path, 'cmfgen_model.csv') + return os.path.join(data_path, "cmfgen_model.csv") + @pytest.fixture def csv_composition_fname(): - return os.path.join(data_path, 'csv_composition.csv') + return os.path.join(data_path, "csv_composition.csv") @pytest.fixture def isotope_uniform_abundance(): config_path = os.path.join( - data_path, 'tardis_configv1_isotope_uniabund.yml') + data_path, "tardis_configv1_isotope_uniabund.yml" + ) config = Configuration.from_yaml(config_path) return config.model.abundances + def test_simple_read_artis_density(artis_density_fname): - time_of_model, velocity, mean_density = read_artis_density(artis_density_fname) + time_of_model, velocity, mean_density = read_artis_density( + artis_density_fname + ) assert np.isclose(0.00114661 * u.day, time_of_model, atol=1e-7 * u.day) - assert np.isclose(mean_density[23], 0.2250048 * u.g / u.cm**3, atol=1.e-6 - * u.g / u.cm**3) + assert np.isclose( + mean_density[23], + 0.2250048 * u.g / u.cm ** 3, + atol=1.0e-6 * u.g / u.cm ** 3, + ) assert len(mean_density) == 69 assert len(velocity) == len(mean_density) + 1 + # Artis files are currently read with read ascii files function def test_read_simple_ascii_abundances(artis_abundances_fname): index, abundances = read_simple_ascii_abundances(artis_abundances_fname) assert len(abundances.columns) == 69 - assert np.isclose(abundances[23].loc[2], 2.672351e-08 , atol=1.e-12) + assert np.isclose(abundances[23].loc[2], 2.672351e-08, atol=1.0e-12) def test_read_simple_isotope_abundances(csv_composition_fname): index, abundances, isotope_abundance = read_csv_composition( - csv_composition_fname) - assert np.isclose(abundances.loc[6, 8], 0.5, atol=1.e-12) - assert np.isclose(abundances.loc[12, 5], 0.8, atol=1.e-12) - assert np.isclose(abundances.loc[14, 1], 0.1, atol=1.e-12) - assert np.isclose(isotope_abundance.loc[(28, 56), 0], 0.4, atol=1.e-12) - assert np.isclose(isotope_abundance.loc[(28, 58), 2], 0.7, atol=1.e-12) + csv_composition_fname + ) + assert np.isclose(abundances.loc[6, 8], 0.5, atol=1.0e-12) + assert np.isclose(abundances.loc[12, 5], 0.8, atol=1.0e-12) + assert np.isclose(abundances.loc[14, 1], 0.1, atol=1.0e-12) + assert np.isclose(isotope_abundance.loc[(28, 56), 0], 0.4, atol=1.0e-12) + assert np.isclose(isotope_abundance.loc[(28, 58), 2], 0.7, atol=1.0e-12) assert abundances.shape == (4, 10) assert isotope_abundance.shape == (2, 10) def test_read_cmfgen_isotope_abundances(cmfgen_fname): - index, abundances, isotope_abundance = read_cmfgen_composition( - cmfgen_fname) - assert np.isclose(abundances.loc[6, 8], 0.5, atol=1.e-12) - assert np.isclose(abundances.loc[12, 5], 0.8, atol=1.e-12) - assert np.isclose(abundances.loc[14, 1], 0.3, atol=1.e-12) - assert np.isclose(isotope_abundance.loc[(28, 56), 0], 0.5, atol=1.e-12) - assert np.isclose(isotope_abundance.loc[(28, 58), 1], 0.7, atol=1.e-12) + index, abundances, isotope_abundance = read_cmfgen_composition(cmfgen_fname) + assert np.isclose(abundances.loc[6, 8], 0.5, atol=1.0e-12) + assert np.isclose(abundances.loc[12, 5], 0.8, atol=1.0e-12) + assert np.isclose(abundances.loc[14, 1], 0.3, atol=1.0e-12) + assert np.isclose(isotope_abundance.loc[(28, 56), 0], 0.5, atol=1.0e-12) + assert np.isclose(isotope_abundance.loc[(28, 58), 1], 0.7, atol=1.0e-12) assert abundances.shape == (4, 9) assert isotope_abundance.shape == (2, 9) def test_read_uniform_abundances(isotope_uniform_abundance): abundances, isotope_abundance = read_uniform_abundances( - isotope_uniform_abundance, 20) - assert np.isclose(abundances.loc[8, 2], 0.19, atol=1.e-12) - assert np.isclose(abundances.loc[20, 5], 0.03, atol=1.e-12) - assert np.isclose(isotope_abundance.loc[(28, 56), 15], 0.05, atol=1.e-12) - assert np.isclose(isotope_abundance.loc[(28, 58), 2], 0.05, atol=1.e-12) + isotope_uniform_abundance, 20 + ) + assert np.isclose(abundances.loc[8, 2], 0.19, atol=1.0e-12) + assert np.isclose(abundances.loc[20, 5], 0.03, atol=1.0e-12) + assert np.isclose(isotope_abundance.loc[(28, 56), 15], 0.05, atol=1.0e-12) + assert np.isclose(isotope_abundance.loc[(28, 58), 2], 0.05, atol=1.0e-12) def test_simple_read_cmfgen_density(cmfgen_fname): - time_of_model, velocity, mean_density, electron_densities, temperature = read_cmfgen_density( - cmfgen_fname) + ( + time_of_model, + velocity, + mean_density, + electron_densities, + temperature, + ) = read_cmfgen_density(cmfgen_fname) assert np.isclose(0.976 * u.day, time_of_model, atol=1e-7 * u.day) - assert np.isclose(mean_density[4], 4.2539537e-09 * u.g / u.cm**3, atol=1.e-6 - * u.g / u.cm**3) - assert np.isclose(electron_densities[5], 2.6e+14 * u.cm**-3, atol=1.e-6 - * u.cm**-3) + assert np.isclose( + mean_density[4], + 4.2539537e-09 * u.g / u.cm ** 3, + atol=1.0e-6 * u.g / u.cm ** 3, + ) + assert np.isclose( + electron_densities[5], 2.6e14 * u.cm ** -3, atol=1.0e-6 * u.cm ** -3 + ) assert len(mean_density) == 9 assert len(velocity) == len(mean_density) + 1 diff --git a/tardis/io/util.py b/tardis/io/util.py index 8644122db99..7ad89813243 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -1,4 +1,4 @@ -#Utility functions for the IO part of TARDIS +# Utility functions for the IO part of TARDIS import os import re @@ -21,7 +21,6 @@ logger = logging.getLogger(__name__) - def get_internal_data_path(fname): """ Get internal data path of TARDIS @@ -32,7 +31,8 @@ def get_internal_data_path(fname): internal data path of TARDIS """ - return os.path.join(TARDIS_PATH[0], 'data', fname) + return os.path.join(TARDIS_PATH[0], "data", fname) + def quantity_from_str(text): """ @@ -47,9 +47,9 @@ def quantity_from_str(text): """ value_str, unit_str = text.split(None, 1) value = float(value_str) - if unit_str.strip() == 'log_lsun': + if unit_str.strip() == "log_lsun": value = 10 ** (value + np.log10(constants.L_sun.cgs.value)) - unit_str = 'erg/s' + unit_str = "erg/s" unit = u.Unit(unit_str) if unit == u.L_sun: @@ -65,6 +65,7 @@ class MockRegexPattern(object): Note: This is usually a lot slower than regex matching. """ + def __init__(self, target_type): self.type = target_type @@ -115,13 +116,18 @@ def construct_quantity(self, node): def mapping_constructor(self, node): return OrderedDict(self.construct_pairs(node)) -YAMLLoader.add_constructor(u'!quantity', YAMLLoader.construct_quantity) -YAMLLoader.add_implicit_resolver(u'!quantity', - MockRegexPattern(quantity_from_str), None) -YAMLLoader.add_implicit_resolver(u'tag:yaml.org,2002:float', - MockRegexPattern(float), None) -YAMLLoader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, - YAMLLoader.mapping_constructor) + +YAMLLoader.add_constructor("!quantity", YAMLLoader.construct_quantity) +YAMLLoader.add_implicit_resolver( + "!quantity", MockRegexPattern(quantity_from_str), None +) +YAMLLoader.add_implicit_resolver( + "tag:yaml.org,2002:float", MockRegexPattern(float), None +) +YAMLLoader.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, + YAMLLoader.mapping_constructor, +) def yaml_load_file(filename, loader=yaml.Loader): @@ -154,7 +160,11 @@ def traverse_configs(base, other, func, *args): if isinstance(base, collections.Mapping): for k in base: traverse_configs(base[k], other[k], func, *args) - elif isinstance(base, collections.Iterable) and not isinstance(base, basestring) and not hasattr(base, 'shape'): + elif ( + isinstance(base, collections.Iterable) + and not isinstance(base, basestring) + and not hasattr(base, "shape") + ): for val1, val2 in zip(base, other): traverse_configs(val1, val2, func, *args) else: @@ -164,7 +174,7 @@ def traverse_configs(base, other, func, *args): def assert_equality(item1, item2): assert type(item1) is type(item2) try: - if hasattr(item1, 'unit'): + if hasattr(item1, "unit"): assert item1.unit == item2.unit assert np.allclose(item1, item2, atol=0.0) except (ValueError, TypeError): @@ -188,7 +198,7 @@ def __new__(cls, *args, **kwargs): return instance @staticmethod - def to_hdf_util(path_or_buf, path, elements, complevel=9, complib='blosc'): + def to_hdf_util(path_or_buf, path, elements, complevel=9, complib="blosc"): """ A function to uniformly store TARDIS data to an HDF file. @@ -216,13 +226,9 @@ def to_hdf_util(path_or_buf, path, elements, complevel=9, complib='blosc'): we_opened = False try: - buf = pd.HDFStore( - path_or_buf, - complevel=complevel, - complib=complib - ) + buf = pd.HDFStore(path_or_buf, complevel=complevel, complib=complib) except TypeError as e: # Already a HDFStore - if e.message == 'Expected bytes, got HDFStore': + if e.message == "Expected bytes, got HDFStore": buf = path_or_buf else: raise e @@ -236,23 +242,20 @@ def to_hdf_util(path_or_buf, path, elements, complevel=9, complib='blosc'): scalars = {} for key, value in elements.items(): if value is None: - value = 'none' - if hasattr(value, 'cgs'): + value = "none" + if hasattr(value, "cgs"): value = value.cgs.value if np.isscalar(value): scalars[key] = value - elif hasattr(value, 'shape'): + elif hasattr(value, "shape"): if value.ndim == 1: # This try,except block is only for model.plasma.levels try: - pd.Series(value).to_hdf(buf, - os.path.join(path, key)) + pd.Series(value).to_hdf(buf, os.path.join(path, key)) except NotImplementedError: - pd.DataFrame(value).to_hdf(buf, - os.path.join(path, key)) + pd.DataFrame(value).to_hdf(buf, os.path.join(path, key)) else: - pd.DataFrame(value).to_hdf( - buf, os.path.join(path, key)) + pd.DataFrame(value).to_hdf(buf, os.path.join(path, key)) else: try: value.to_hdf(buf, path, name=key) @@ -264,12 +267,12 @@ def to_hdf_util(path_or_buf, path, elements, complevel=9, complib='blosc'): scalars_series = pd.Series(scalars) # Unfortunately, with to_hdf we cannot append, so merge beforehand - scalars_path = os.path.join(path, 'scalars') + scalars_path = os.path.join(path, "scalars") try: scalars_series = buf[scalars_path].append(scalars_series) except KeyError: # no scalars in HDFStore pass - scalars_series.to_hdf(buf, os.path.join(path, 'scalars')) + scalars_series.to_hdf(buf, os.path.join(path, "scalars")) if we_opened: buf.close() @@ -284,10 +287,10 @@ def full_hdf_properties(self): @staticmethod def convert_to_snake_case(s): - s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', s) - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def to_hdf(self, file_path, path='', name=None): + def to_hdf(self, file_path, path="", name=None): """ Parameters ---------- @@ -314,27 +317,29 @@ def to_hdf(self, file_path, path='', name=None): class PlasmaWriterMixin(HDFWriterMixin): - def get_properties(self): data = {} if self.collection: - properties = [name for name in self.plasma_properties - if isinstance(name, tuple(self.collection))] + properties = [ + name + for name in self.plasma_properties + if isinstance(name, tuple(self.collection)) + ] else: properties = self.plasma_properties for prop in properties: for output in prop.outputs: data[output] = getattr(prop, output) - data['atom_data_uuid'] = self.atomic_data.uuid1 - if 'atomic_data' in data: - data.pop('atomic_data') - if 'nlte_data' in data: + data["atom_data_uuid"] = self.atomic_data.uuid1 + if "atomic_data" in data: + data.pop("atomic_data") + if "nlte_data" in data: logger.warning("nlte_data can't be saved") - data.pop('nlte_data') + data.pop("nlte_data") return data - def to_hdf(self, file_path, path='', name=None, collection=None): - ''' + def to_hdf(self, file_path, path="", name=None, collection=None): + """ Parameters ---------- file_path: str @@ -355,7 +360,7 @@ def to_hdf(self, file_path, path='', name=None, collection=None): Returns ------- - ''' + """ self.collection = collection super(PlasmaWriterMixin, self).to_hdf(file_path, path, name) @@ -376,10 +381,14 @@ def download_from_url(url, dst): return file_size header = {"Range": "bytes=%s-%s" % (first_byte, file_size)} pbar = tqdm( - total=file_size, initial=first_byte, - unit='B', unit_scale=True, desc=url.split('/')[-1]) + total=file_size, + initial=first_byte, + unit="B", + unit_scale=True, + desc=url.split("/")[-1], + ) req = requests.get(url, headers=header, stream=True) - with(open(dst, 'ab')) as f: + with open(dst, "ab") as f: for chunk in req.iter_content(chunk_size=1024): if chunk: f.write(chunk) diff --git a/tardis/plasma/base.py b/tardis/plasma/base.py index b90f3ab2ddd..56816bfb4a2 100644 --- a/tardis/plasma/base.py +++ b/tardis/plasma/base.py @@ -13,17 +13,20 @@ logger = logging.getLogger(__name__) + class BasePlasma(PlasmaWriterMixin): outputs_dict = {} - hdf_name = 'plasma' + hdf_name = "plasma" + def __init__(self, plasma_properties, property_kwargs=None, **kwargs): self.outputs_dict = {} self.input_properties = [] - self.plasma_properties = self._init_properties(plasma_properties, - property_kwargs, **kwargs) + self.plasma_properties = self._init_properties( + plasma_properties, property_kwargs, **kwargs + ) self._build_graph() -# self.write_to_tex('Plasma_Graph') + # self.write_to_tex('Plasma_Graph') self.update(**kwargs) def __getattr__(self, item): @@ -33,23 +36,24 @@ def __getattr__(self, item): super(BasePlasma, self).__getattribute__(item) def __setattr__(self, key, value): - if key != 'module_dict' and key in self.outputs_dict: - raise AttributeError('Plasma inputs can only be updated using ' - 'the \'update\' method') + if key != "module_dict" and key in self.outputs_dict: + raise AttributeError( + "Plasma inputs can only be updated using " "the 'update' method" + ) else: super(BasePlasma, self).__setattr__(key, value) def __dir__(self): - attrs = [item for item in self.__dict__ - if not item.startswith('_')] - attrs += [item for item in self.__class__.__dict__ - if not item.startswith('_')] + attrs = [item for item in self.__dict__ if not item.startswith("_")] + attrs += [ + item for item in self.__class__.__dict__ if not item.startswith("_") + ] attrs += self.outputs_dict.keys() return attrs @property def plasma_properties_dict(self): - return {item.name:item for item in self.plasma_properties} + return {item.name: item for item in self.plasma_properties} def get_value(self, item): return getattr(self.outputs_dict[item], item) @@ -64,36 +68,48 @@ def _build_graph(self): self.graph = nx.DiGraph() ## Adding all nodes - self.graph.add_nodes_from([(plasma_property.name, {}) - for plasma_property - in self.plasma_properties]) - - #Flagging all input modules - self.input_properties = [item for item in self.plasma_properties - if not hasattr(item, 'inputs')] + self.graph.add_nodes_from( + [ + (plasma_property.name, {}) + for plasma_property in self.plasma_properties + ] + ) + + # Flagging all input modules + self.input_properties = [ + item + for item in self.plasma_properties + if not hasattr(item, "inputs") + ] for plasma_property in self.plasma_properties: - #Skipping any module that is an input module + # Skipping any module that is an input module if plasma_property in self.input_properties: continue for input in plasma_property.inputs: if input not in self.outputs_dict: - raise PlasmaMissingModule('Module {0} requires input ' - '{1} which has not been added' - ' to this plasma'.format( - plasma_property.name, input)) + raise PlasmaMissingModule( + "Module {0} requires input " + "{1} which has not been added" + " to this plasma".format(plasma_property.name, input) + ) try: position = self.outputs_dict[input].outputs.index(input) label = self.outputs_dict[input].latex_name[position] - label = '$' + label + '$' - label = label.replace('\\', '\\\\') + label = "$" + label + "$" + label = label.replace("\\", "\\\\") except: - label = input.replace('_', '-') - self.graph.add_edge(self.outputs_dict[input].name, - plasma_property.name, label = label) - - def _init_properties(self, plasma_properties, property_kwargs=None, **kwargs): + label = input.replace("_", "-") + self.graph.add_edge( + self.outputs_dict[input].name, + plasma_property.name, + label=label, + ) + + def _init_properties( + self, plasma_properties, property_kwargs=None, **kwargs + ): """ Builds a dictionary with the plasma module names as keys @@ -120,25 +136,31 @@ def _init_properties(self, plasma_properties, property_kwargs=None, **kwargs): if issubclass(plasma_property, PreviousIterationProperty): current_property_object = plasma_property( - **property_kwargs.get(plasma_property, {})) + **property_kwargs.get(plasma_property, {}) + ) current_property_object.set_initial_value(kwargs) self.previous_iteration_properties.append( - current_property_object) + current_property_object + ) elif issubclass(plasma_property, Input): if not set(kwargs.keys()).issuperset(plasma_property.outputs): - missing_input_values = (set(plasma_property.outputs) - - set(kwargs.keys())) - raise NotInitializedModule('Input {0} required for ' - 'plasma but not given when ' - 'instantiating the ' - 'plasma'.format( - missing_input_values)) + missing_input_values = set(plasma_property.outputs) - set( + kwargs.keys() + ) + raise NotInitializedModule( + "Input {0} required for " + "plasma but not given when " + "instantiating the " + "plasma".format(missing_input_values) + ) current_property_object = plasma_property( - **property_kwargs.get(plasma_property, {})) + **property_kwargs.get(plasma_property, {}) + ) else: current_property_object = plasma_property( - self, **property_kwargs.get(plasma_property, {})) + self, **property_kwargs.get(plasma_property, {}) + ) for output in plasma_property.outputs: self.outputs_dict[output] = current_property_object plasma_property_objects.append(current_property_object) @@ -148,13 +170,16 @@ def store_previous_properties(self): for property in self.previous_iteration_properties: p = property.outputs[0] self.outputs_dict[p].set_value( - self.get_value(re.sub(r'^previous_', '', p))) + self.get_value(re.sub(r"^previous_", "", p)) + ) def update(self, **kwargs): for key in kwargs: if key not in self.outputs_dict: - raise PlasmaMissingModule('Trying to update property {0}' - ' that is unavailable'.format(key)) + raise PlasmaMissingModule( + "Trying to update property {0}" + " that is unavailable".format(key) + ) self.outputs_dict[key].set_value(kwargs[key]) for module_name in self._resolve_update_list(kwargs.keys()): @@ -192,35 +217,38 @@ def _resolve_update_list(self, changed_properties): descendants_ob = list(set(descendants_ob)) sort_order = list(nx.topological_sort(self.graph)) - descendants_ob.sort(key=lambda val: sort_order.index(val) ) + descendants_ob.sort(key=lambda val: sort_order.index(val)) - logger.debug('Updating modules in the following order: {}'.format( - '->'.join(descendants_ob))) + logger.debug( + "Updating modules in the following order: {}".format( + "->".join(descendants_ob) + ) + ) return descendants_ob def write_to_dot(self, fname, latex_label=True): -# self._update_module_type_str() + # self._update_module_type_str() try: import pygraphviz except: - logger.warn('pygraphviz missing. Plasma graph will not be ' - 'generated.') + logger.warn( + "pygraphviz missing. Plasma graph will not be " "generated." + ) return print_graph = self.graph.copy() print_graph = self.remove_hidden_properties(print_graph) for node in print_graph: - print_graph.node[str(node)]['label'] = node - if hasattr(self.plasma_properties_dict[node], - 'latex_formula'): + print_graph.node[str(node)]["label"] = node + if hasattr(self.plasma_properties_dict[node], "latex_formula"): formulae = self.plasma_properties_dict[node].latex_formula for output in range(0, len(formulae)): formula = formulae[output] - label = formula.replace('\\', '\\\\') - print_graph.node[str(node)]['label']+='\\n$' - print_graph.node[str(node)]['label']+=label - print_graph.node[str(node)]['label']+='$' + label = formula.replace("\\", "\\\\") + print_graph.node[str(node)]["label"] += "\\n$" + print_graph.node[str(node)]["label"] += label + print_graph.node[str(node)]["label"] += "$" nx.drawing.nx_agraph.write_dot(print_graph, fname) @@ -228,8 +256,9 @@ def write_to_tex(self, fname_graph): try: import dot2tex except: - logger.warn('dot2tex missing. Plasma graph will not be ' - 'generated.') + logger.warn( + "dot2tex missing. Plasma graph will not be " "generated." + ) return temp_fname = tempfile.NamedTemporaryFile().name @@ -238,37 +267,47 @@ def write_to_tex(self, fname_graph): dot_string = open(temp_fname).read() - open(fname_graph, 'w').write(dot2tex.dot2tex(dot_string, - texmode='raw')) + open(fname_graph, "w").write(dot2tex.dot2tex(dot_string, texmode="raw")) - for line in fileinput.input(fname_graph, inplace = 1): - print(line.replace(r'\documentclass{article}', - r'\documentclass[class=minimal,border=20pt]{standalone}'), - end='') + for line in fileinput.input(fname_graph, inplace=1): + print( + line.replace( + r"\documentclass{article}", + r"\documentclass[class=minimal,border=20pt]{standalone}", + ), + end="", + ) - for line in fileinput.input(fname_graph, inplace = 1): - print(line.replace(r'\enlargethispage{100cm}', ''), end='') + for line in fileinput.input(fname_graph, inplace=1): + print(line.replace(r"\enlargethispage{100cm}", ""), end="") def remove_hidden_properties(self, print_graph): for item in self.plasma_properties_dict.values(): module = self.plasma_properties_dict[item.name].__class__ - if (issubclass(module, HiddenPlasmaProperty)): + if issubclass(module, HiddenPlasmaProperty): output = module.outputs[0] for value in self.plasma_properties_dict.keys(): if output in getattr( - self.plasma_properties_dict[value], 'inputs', []): + self.plasma_properties_dict[value], "inputs", [] + ): for input in self.plasma_properties_dict[ - item.name].inputs: + item.name + ].inputs: try: position = self.outputs_dict[ - input].outputs.index(input) - label = self.outputs_dict[ - input].latex_name[position] - label = '$' + label + '$' - label = label.replace('\\', '\\\\') + input + ].outputs.index(input) + label = self.outputs_dict[input].latex_name[ + position + ] + label = "$" + label + "$" + label = label.replace("\\", "\\\\") except: - label = input.replace('_', '-') - self.graph.add_edge(self.outputs_dict[input].name, - value, label = label) + label = input.replace("_", "-") + self.graph.add_edge( + self.outputs_dict[input].name, + value, + label=label, + ) print_graph.remove_node(str(item.name)) return print_graph diff --git a/tardis/plasma/exceptions.py b/tardis/plasma/exceptions.py index 70108c0b308..897ce4cf658 100644 --- a/tardis/plasma/exceptions.py +++ b/tardis/plasma/exceptions.py @@ -1,25 +1,33 @@ class PlasmaException(Exception): pass + class IncompleteAtomicData(PlasmaException): def __init__(self, atomic_data_name): - message = ('The current plasma calculation requires {0}, ' - 'which is not provided by the given atomic data'.format( - atomic_data_name)) + message = ( + "The current plasma calculation requires {0}, " + "which is not provided by the given atomic data".format( + atomic_data_name + ) + ) super(PlasmaException, self).__init__(message) class PlasmaMissingModule(PlasmaException): pass + class PlasmaIsolatedModule(PlasmaException): pass + class NotInitializedModule(PlasmaException): pass + class PlasmaIonizationError(PlasmaException): pass + class PlasmaConfigError(PlasmaException): - pass \ No newline at end of file + pass diff --git a/tardis/plasma/properties/atomic.py b/tardis/plasma/properties/atomic.py index 503daba3856..3e1c7014a7d 100644 --- a/tardis/plasma/properties/atomic.py +++ b/tardis/plasma/properties/atomic.py @@ -4,15 +4,27 @@ import pandas as pd from collections import Counter as counter -from tardis.plasma.properties.base import (ProcessingPlasmaProperty, - HiddenPlasmaProperty, BaseAtomicDataProperty) +from tardis.plasma.properties.base import ( + ProcessingPlasmaProperty, + HiddenPlasmaProperty, + BaseAtomicDataProperty, +) from tardis.plasma.exceptions import IncompleteAtomicData logger = logging.getLogger(__name__) -__all__ = ['Levels', 'Lines', 'LinesLowerLevelIndex', 'LinesUpperLevelIndex', - 'AtomicMass', 'IonizationData', 'ZetaData', 'NLTEData', - 'PhotoIonizationData'] +__all__ = [ + "Levels", + "Lines", + "LinesLowerLevelIndex", + "LinesUpperLevelIndex", + "AtomicMass", + "IonizationData", + "ZetaData", + "NLTEData", + "PhotoIonizationData", +] + class Levels(BaseAtomicDataProperty): """ @@ -27,9 +39,14 @@ class Levels(BaseAtomicDataProperty): g : Pandas DataFrame (index=levels), dtype float Statistical weights of atomic levels """ - outputs = ('levels', 'excitation_energy', 'metastability', 'g') - latex_name = ('\\textrm{levels}', '\\epsilon_{\\textrm{k}}', '\\textrm{metastability}', - 'g') + + outputs = ("levels", "excitation_energy", "metastability", "g") + latex_name = ( + "\\textrm{levels}", + "\\epsilon_{\\textrm{k}}", + "\\textrm{metastability}", + "g", + ) def _filter_atomic_property(self, levels, selected_atoms): return levels @@ -38,8 +55,13 @@ def _filter_atomic_property(self, levels, selected_atoms): def _set_index(self, levels): # levels = levels.set_index(['atomic_number', 'ion_number', # 'level_number']) - return (levels.index, levels['energy'], levels['metastable'], - levels['g']) + return ( + levels.index, + levels["energy"], + levels["metastable"], + levels["g"], + ) + class Lines(BaseAtomicDataProperty): """ @@ -55,8 +77,9 @@ class Lines(BaseAtomicDataProperty): wavelength_cm: Pandas DataFrame (index=line_id), dtype float Line wavelengths in cm """ -# Would like for lines to just be the line_id values - outputs = ('lines', 'nu', 'f_lu', 'wavelength_cm') + + # Would like for lines to just be the line_id values + outputs = ("lines", "nu", "f_lu", "wavelength_cm") def _filter_atomic_property(self, lines, selected_atoms): # return lines[lines.atomic_number.isin(selected_atoms)] @@ -64,7 +87,7 @@ def _filter_atomic_property(self, lines, selected_atoms): def _set_index(self, lines): # lines.set_index('line_id', inplace=True) - return lines, lines['nu'], lines['f_lu'], lines['wavelength_cm'] + return lines, lines["nu"], lines["f_lu"], lines["wavelength_cm"] class PhotoIonizationData(ProcessingPlasmaProperty): @@ -86,23 +109,27 @@ class PhotoIonizationData(ProcessingPlasmaProperty): Atomic, ion and level numbers for which photoionization data exists. """ - outputs = ('photo_ion_cross_sections', 'photo_ion_block_references', - 'photo_ion_index') - latex_name = ('\\xi_{\\textrm{i}}(\\nu)', '', '') + + outputs = ( + "photo_ion_cross_sections", + "photo_ion_block_references", + "photo_ion_index", + ) + latex_name = ("\\xi_{\\textrm{i}}(\\nu)", "", "") def calculate(self, atomic_data, continuum_interaction_species): photoionization_data = atomic_data.photoionization_data.set_index( - ['atomic_number', 'ion_number', 'level_number'] + ["atomic_number", "ion_number", "level_number"] ) selected_species_idx = pd.IndexSlice[ - continuum_interaction_species.get_level_values('atomic_number'), - continuum_interaction_species.get_level_values('ion_number'), - slice(None) + continuum_interaction_species.get_level_values("atomic_number"), + continuum_interaction_species.get_level_values("ion_number"), + slice(None), ] photoionization_data = photoionization_data.loc[selected_species_idx] - phot_nus = photoionization_data['nu'] + phot_nus = photoionization_data["nu"] block_references = np.hstack( - [[0], phot_nus.groupby(level=[0,1,2]).count().values.cumsum()] + [[0], phot_nus.groupby(level=[0, 1, 2]).count().values.cumsum()] ) photo_ion_index = photoionization_data.index.unique() return photoionization_data, block_references, photo_ion_index @@ -114,25 +141,31 @@ class LinesLowerLevelIndex(HiddenPlasmaProperty): lines_lower_level_index : One-dimensional Numpy Array, dtype int Levels data for lower levels of particular lines """ - outputs = ('lines_lower_level_index',) + + outputs = ("lines_lower_level_index",) + def calculate(self, levels, lines): - levels_index = pd.Series(np.arange(len(levels), dtype=np.int64), - index=levels) - lines_index = lines.index.droplevel('level_number_upper') + levels_index = pd.Series( + np.arange(len(levels), dtype=np.int64), index=levels + ) + lines_index = lines.index.droplevel("level_number_upper") return np.array(levels_index.loc[lines_index]) + class LinesUpperLevelIndex(HiddenPlasmaProperty): """ Attributes: lines_upper_level_index : One-dimensional Numpy Array, dtype int Levels data for upper levels of particular lines """ - outputs = ('lines_upper_level_index',) + + outputs = ("lines_upper_level_index",) def calculate(self, levels, lines): - levels_index = pd.Series(np.arange(len(levels), dtype=np.int64), - index=levels) - lines_index = lines.index.droplevel('level_number_lower') + levels_index = pd.Series( + np.arange(len(levels), dtype=np.int64), index=levels + ) + lines_index = lines.index.droplevel("level_number_lower") return np.array(levels_index.loc[lines_index]) @@ -142,44 +175,44 @@ class AtomicMass(ProcessingPlasmaProperty): atomic_mass : Pandas Series Atomic masses of the elements used. Indexed by atomic number. """ - outputs = ('atomic_mass',) + + outputs = ("atomic_mass",) def calculate(self, atomic_data, selected_atoms): if getattr(self, self.outputs[0]) is not None: - return getattr(self, self.outputs[0]), + return (getattr(self, self.outputs[0]),) else: return atomic_data.atom_data.loc[selected_atoms].mass + class IonizationData(BaseAtomicDataProperty): """ Attributes: ionization_data : Pandas Series holding ionization energies Indexed by atomic number, ion number. """ - outputs = ('ionization_data',) + + outputs = ("ionization_data",) def _filter_atomic_property(self, ionization_data, selected_atoms): - mask = ionization_data.index.isin( - selected_atoms, - level='atomic_number' - ) + mask = ionization_data.index.isin(selected_atoms, level="atomic_number") ionization_data = ionization_data[mask] - counts = ionization_data.groupby( - level='atomic_number').count() + counts = ionization_data.groupby(level="atomic_number").count() if np.alltrue(counts.index == counts): return ionization_data else: raise IncompleteAtomicData( - 'ionization data for the ion ({}, {})'.format( - str(counts.index[counts.index != counts]), - str(counts[counts.index != counts]) - ) - ) + "ionization data for the ion ({}, {})".format( + str(counts.index[counts.index != counts]), + str(counts[counts.index != counts]), + ) + ) def _set_index(self, ionization_data): return ionization_data + class ZetaData(BaseAtomicDataProperty): """ Attributes: @@ -189,11 +222,12 @@ class ZetaData(BaseAtomicDataProperty): The zeta value represents the fraction of recombination events from the ionized state that go directly to the ground state. """ - outputs = ('zeta_data',) + + outputs = ("zeta_data",) def _filter_atomic_property(self, zeta_data, selected_atoms): - zeta_data['atomic_number'] = zeta_data.index.labels[0] + 1 - zeta_data['ion_number'] = zeta_data.index.labels[1] + 1 + zeta_data["atomic_number"] = zeta_data.index.labels[0] + 1 + zeta_data["ion_number"] = zeta_data.index.labels[1] + 1 zeta_data = zeta_data[zeta_data.atomic_number.isin(selected_atoms)] zeta_data_check = counter(zeta_data.atomic_number.values) keys = np.array(list(zeta_data_check.keys())) @@ -201,36 +235,47 @@ def _filter_atomic_property(self, zeta_data, selected_atoms): if np.alltrue(keys + 1 == values): return zeta_data else: -# raise IncompleteAtomicData('zeta data') -# This currently replaces missing zeta data with 1, which is necessary with -# the present atomic data. Will replace with the error above when I have -# complete atomic data. + # raise IncompleteAtomicData('zeta data') + # This currently replaces missing zeta data with 1, which is necessary with + # the present atomic data. Will replace with the error above when I have + # complete atomic data. missing_ions = [] updated_index = [] for atom in selected_atoms: for ion in range(1, atom + 2): if (atom, ion) not in zeta_data.index: - missing_ions.append((atom,ion)) + missing_ions.append((atom, ion)) updated_index.append([atom, ion]) - logger.warn('Zeta_data missing - replaced with 1s. Missing ions: {}'.format(missing_ions)) + logger.warn( + "Zeta_data missing - replaced with 1s. Missing ions: {}".format( + missing_ions + ) + ) updated_index = np.array(updated_index) - updated_dataframe = pd.DataFrame(index=pd.MultiIndex.from_arrays( - updated_index.transpose().astype(int)), - columns=zeta_data.columns) + updated_dataframe = pd.DataFrame( + index=pd.MultiIndex.from_arrays( + updated_index.transpose().astype(int) + ), + columns=zeta_data.columns, + ) for value in range(len(zeta_data)): - updated_dataframe.loc[zeta_data.atomic_number.values[value], - zeta_data.ion_number.values[value]] = \ - zeta_data.loc[zeta_data.atomic_number.values[value], - zeta_data.ion_number.values[value]] + updated_dataframe.loc[ + zeta_data.atomic_number.values[value], + zeta_data.ion_number.values[value], + ] = zeta_data.loc[ + zeta_data.atomic_number.values[value], + zeta_data.ion_number.values[value], + ] updated_dataframe = updated_dataframe.astype(float) updated_index = pd.DataFrame(updated_index) - updated_dataframe['atomic_number'] = np.array(updated_index[0]) - updated_dataframe['ion_number'] = np.array(updated_index[1]) + updated_dataframe["atomic_number"] = np.array(updated_index[0]) + updated_dataframe["ion_number"] = np.array(updated_index[1]) updated_dataframe.fillna(1.0, inplace=True) return updated_dataframe def _set_index(self, zeta_data): - return zeta_data.set_index(['atomic_number', 'ion_number']) + return zeta_data.set_index(["atomic_number", "ion_number"]) + class NLTEData(ProcessingPlasmaProperty): """ @@ -238,7 +283,8 @@ class NLTEData(ProcessingPlasmaProperty): nlte_data : #Finish later (need atomic dataset with NLTE data). """ - outputs = ('nlte_data',) + + outputs = ("nlte_data",) def calculate(self, atomic_data): if getattr(self, self.outputs[0]) is not None: diff --git a/tardis/plasma/properties/base.py b/tardis/plasma/properties/base.py index e38dd2f5b5f..3a8138e1014 100644 --- a/tardis/plasma/properties/base.py +++ b/tardis/plasma/properties/base.py @@ -5,14 +5,22 @@ import pandas as pd -__all__ = ['BasePlasmaProperty', 'BaseAtomicDataProperty', - 'HiddenPlasmaProperty', 'Input', 'ArrayInput', 'DataFrameInput', - 'ProcessingPlasmaProperty', 'PreviousIterationProperty'] +__all__ = [ + "BasePlasmaProperty", + "BaseAtomicDataProperty", + "HiddenPlasmaProperty", + "Input", + "ArrayInput", + "DataFrameInput", + "ProcessingPlasmaProperty", + "PreviousIterationProperty", +] logger = logging.getLogger(__name__) import os + class BasePlasmaProperty(object, metaclass=ABCMeta): """ Attributes @@ -45,25 +53,19 @@ def get_latex_label(self): \textbf{{Formula}} {formula} {description} """ - outputs = self.outputs.replace('_', r'\_') - latex_name = getattr(self, 'latex_name', '') - if latex_name != '': - complete_name = '{0} [{1}]'.format(latex_name, self.latex_name) + outputs = self.outputs.replace("_", r"\_") + latex_name = getattr(self, "latex_name", "") + if latex_name != "": + complete_name = "{0} [{1}]".format(latex_name, self.latex_name) else: complete_name = latex_name latex_label = latex_template.format( - name=complete_name, - formula=getattr( - self, - 'latex_formula', '--'), - description=getattr( - self, - 'latex_description', - '')) - return latex_label.replace('\\', r'\\') - - + name=complete_name, + formula=getattr(self, "latex_formula", "--"), + description=getattr(self, "latex_description", ""), + ) + return latex_label.replace("\\", r"\\") class ProcessingPlasmaProperty(BasePlasmaProperty, metaclass=ABCMeta): @@ -85,10 +87,11 @@ def _update_inputs(self): `calculate`-function and makes the plasma routines easily programmable. """ calculate_call_signature = self.calculate.__code__.co_varnames[ - :self.calculate.__code__.co_argcount] + : self.calculate.__code__.co_argcount + ] self.inputs = [ - item for item in calculate_call_signature - if item != 'self'] + item for item in calculate_call_signature if item != "self" + ] def _get_input_values(self): return (self.plasma_parent.get_value(item) for item in self.inputs) @@ -101,8 +104,9 @@ def update(self): :return: """ if len(self.outputs) == 1: - setattr(self, self.outputs[0], self.calculate( - *self._get_input_values())) + setattr( + self, self.outputs[0], self.calculate(*self._get_input_values()) + ) else: new_values = self.calculate(*self._get_input_values()) for i, output in enumerate(self.outputs): @@ -110,8 +114,10 @@ def update(self): @abstractmethod def calculate(self, *args, **kwargs): - raise NotImplementedError('This method needs to be implemented by ' - 'processing plasma modules') + raise NotImplementedError( + "This method needs to be implemented by " + "processing plasma modules" + ) class HiddenPlasmaProperty(ProcessingPlasmaProperty, metaclass=ABCMeta): @@ -131,7 +137,7 @@ class BaseAtomicDataProperty(ProcessingPlasmaProperty, metaclass=ABCMeta): the simulation. """ - inputs = ['atomic_data', 'selected_atoms'] + inputs = ["atomic_data", "selected_atoms"] def __init__(self, plasma_parent): @@ -140,11 +146,11 @@ def __init__(self, plasma_parent): @abstractmethod def _set_index(self, raw_atomic_property): - raise NotImplementedError('Needs to be implemented in subclasses') + raise NotImplementedError("Needs to be implemented in subclasses") @abstractmethod def _filter_atomic_property(self, raw_atomic_property): - raise NotImplementedError('Needs to be implemented in subclasses') + raise NotImplementedError("Needs to be implemented in subclasses") def calculate(self, atomic_data, selected_atoms): @@ -153,9 +159,10 @@ def calculate(self, atomic_data, selected_atoms): else: raw_atomic_property = getattr(atomic_data, self.outputs[0]) return self._set_index( - self._filter_atomic_property( - raw_atomic_property, selected_atoms) - ) + self._filter_atomic_property( + raw_atomic_property, selected_atoms + ) + ) class Input(BasePlasmaProperty): @@ -163,6 +170,7 @@ class Input(BasePlasmaProperty): The plasma property class for properties that are input directly from model and not calculated within the plasma module, e.g. t_rad. """ + def _set_output_value(self, output, value): setattr(self, output, value) @@ -188,6 +196,7 @@ class PreviousIterationProperty(BasePlasmaProperty): calculations. Given a sufficient number of iterations, the values should converge successfully on the correct solution. """ + def _set_initial_value(self, value): self.set_value(value) diff --git a/tardis/plasma/properties/continuum_processes.py b/tardis/plasma/properties/continuum_processes.py index 5de3f3825e0..5c53d02617f 100644 --- a/tardis/plasma/properties/continuum_processes.py +++ b/tardis/plasma/properties/continuum_processes.py @@ -8,11 +8,11 @@ from tardis.plasma.properties.base import ProcessingPlasmaProperty -__all__ = ['SpontRecombRateCoeff'] +__all__ = ["SpontRecombRateCoeff"] logger = logging.getLogger(__name__) -njit_dict = {'fastmath': False, 'parallel': False} +njit_dict = {"fastmath": False, "parallel": False} @njit(**njit_dict) @@ -73,20 +73,31 @@ class SpontRecombRateCoeff(ProcessingPlasmaProperty): alpha_sp : Pandas DataFrame, dtype float The rate coefficient for spontaneous recombination. """ - outputs = ('alpha_sp',) - latex_name = ('\\alpha^{\\textrm{sp}}',) - def calculate(self, photo_ion_cross_sections, t_electrons, - photo_ion_block_references, photo_ion_index, phi_ik): - x_sect = photo_ion_cross_sections['x_sect'].values - nu = photo_ion_cross_sections['nu'].values - - alpha_sp = (8 * np.pi * x_sect * nu ** 2 / (const.c.cgs.value) ** 2) + outputs = ("alpha_sp",) + latex_name = ("\\alpha^{\\textrm{sp}}",) + + def calculate( + self, + photo_ion_cross_sections, + t_electrons, + photo_ion_block_references, + photo_ion_index, + phi_ik, + ): + x_sect = photo_ion_cross_sections["x_sect"].values + nu = photo_ion_cross_sections["nu"].values + + alpha_sp = 8 * np.pi * x_sect * nu ** 2 / (const.c.cgs.value) ** 2 alpha_sp = alpha_sp[:, np.newaxis] - boltzmann_factor = np.exp(-nu[np.newaxis].T / t_electrons * - (const.h.cgs.value / const.k_B.cgs.value)) + boltzmann_factor = np.exp( + -nu[np.newaxis].T + / t_electrons + * (const.h.cgs.value / const.k_B.cgs.value) + ) alpha_sp = alpha_sp * boltzmann_factor - alpha_sp = integrate_array_by_blocks(alpha_sp, nu, - photo_ion_block_references) + alpha_sp = integrate_array_by_blocks( + alpha_sp, nu, photo_ion_block_references + ) alpha_sp = pd.DataFrame(alpha_sp, index=photo_ion_index) return alpha_sp * phi_ik diff --git a/tardis/plasma/properties/general.py b/tardis/plasma/properties/general.py index cd8717a8d71..e6dbfc33188 100644 --- a/tardis/plasma/properties/general.py +++ b/tardis/plasma/properties/general.py @@ -9,9 +9,18 @@ logger = logging.getLogger(__name__) -__all__ = ['BetaRadiation', 'GElectron', 'NumberDensity', 'SelectedAtoms', - 'ElectronTemperature', 'BetaElectron', 'LuminosityInner', - 'TimeSimulation', 'ThermalGElectron'] +__all__ = [ + "BetaRadiation", + "GElectron", + "NumberDensity", + "SelectedAtoms", + "ElectronTemperature", + "BetaElectron", + "LuminosityInner", + "TimeSimulation", + "ThermalGElectron", +] + class BetaRadiation(ProcessingPlasmaProperty): """ @@ -19,9 +28,10 @@ class BetaRadiation(ProcessingPlasmaProperty): ---------- beta_rad : Numpy Array, dtype float """ - outputs = ('beta_rad',) - latex_name = ('\\beta_{\\textrm{rad}}',) - latex_formula = ('\\dfrac{1}{k_{B} T_{\\textrm{rad}}}',) + + outputs = ("beta_rad",) + latex_name = ("\\beta_{\\textrm{rad}}",) + latex_formula = ("\\dfrac{1}{k_{B} T_{\\textrm{rad}}}",) def __init__(self, plasma_parent): super(BetaRadiation, self).__init__(plasma_parent) @@ -30,20 +40,23 @@ def __init__(self, plasma_parent): def calculate(self, t_rad): return 1 / (self.k_B_cgs * t_rad) + class GElectron(ProcessingPlasmaProperty): """ Attributes ---------- g_electron : Numpy Array, dtype float """ - outputs = ('g_electron',) - latex_name = ('g_{\\textrm{electron}}',) - latex_formula = ('\\Big(\\dfrac{2\\pi m_{e}/\ - \\beta_{\\textrm{rad}}}{h^2}\\Big)^{3/2}',) + + outputs = ("g_electron",) + latex_name = ("g_{\\textrm{electron}}",) + latex_formula = (r"\Big(\dfrac{2\pi m_{e}/\beta_{\textrm{rad}}}{h^2}\Big)^{3/2}",) def calculate(self, beta_rad): - return ((2 * np.pi * const.m_e.cgs.value / beta_rad) / - (const.h.cgs.value ** 2)) ** 1.5 + return ( + (2 * np.pi * const.m_e.cgs.value / beta_rad) + / (const.h.cgs.value ** 2) + ) ** 1.5 class ThermalGElectron(GElectron): @@ -52,10 +65,12 @@ class ThermalGElectron(GElectron): ---------- thermal_g_electron : Numpy Array, dtype float """ - outputs = ('thermal_g_electron',) - latex_name = ('g_{\\textrm{electron_thermal}}',) - latex_formula = ('\\Big(\\dfrac{2\\pi m_{e}/\ - \\beta_{\\textrm{electron}}}{h^2}\\Big)^{3/2}',) + + outputs = ("thermal_g_electron",) + latex_name = ("g_{\\textrm{electron_thermal}}",) + latex_formula = ( + r"\Big(\dfrac{2\pi m_{e}/\beta_{\textrm{electron}}}{h^2}\Big)^{3/2}", + ) def calculate(self, beta_electron): return super(ThermalGElectron, self).calculate(beta_electron) @@ -68,14 +83,16 @@ class NumberDensity(ProcessingPlasmaProperty): number_density : Pandas DataFrame, dtype float Indexed by atomic number, columns corresponding to zones """ - outputs = ('number_density',) - latex_name = ('N_{i}',) + + outputs = ("number_density",) + latex_name = ("N_{i}",) @staticmethod def calculate(atomic_mass, abundance, density): - number_densities = (abundance * density) + number_densities = abundance * density return number_densities.div(atomic_mass.loc[abundance.index], axis=0) + class SelectedAtoms(ProcessingPlasmaProperty): """ Attributes @@ -83,33 +100,38 @@ class SelectedAtoms(ProcessingPlasmaProperty): selected_atoms : Pandas Int64Index, dtype int Atomic numbers of elements required for particular simulation """ - outputs = ('selected_atoms',) + + outputs = ("selected_atoms",) def calculate(self, abundance): return abundance.index + class ElectronTemperature(ProcessingPlasmaProperty): """ Attributes ---------- t_electron : Numpy Array, dtype float """ - outputs = ('t_electrons',) - latex_name = ('T_{\\textrm{electron}}',) - latex_formula = ('\\textrm{const.}\\times T_{\\textrm{rad}}',) + + outputs = ("t_electrons",) + latex_name = ("T_{\\textrm{electron}}",) + latex_formula = ("\\textrm{const.}\\times T_{\\textrm{rad}}",) def calculate(self, t_rad, link_t_rad_t_electron): return t_rad * link_t_rad_t_electron + class BetaElectron(ProcessingPlasmaProperty): """ Attributes ---------- beta_electron : Numpy Array, dtype float """ - outputs = ('beta_electron',) - latex_name = ('\\beta_{\\textrm{electron}}',) - latex_formula = ('\\frac{1}{K_{B} T_{\\textrm{electron}}}',) + + outputs = ("beta_electron",) + latex_name = ("\\beta_{\\textrm{electron}}",) + latex_formula = ("\\frac{1}{K_{B} T_{\\textrm{electron}}}",) def __init__(self, plasma_parent): super(BetaElectron, self).__init__(plasma_parent) @@ -118,16 +140,19 @@ def __init__(self, plasma_parent): def calculate(self, t_electrons): return 1 / (self.k_B_cgs * t_electrons) + class LuminosityInner(ProcessingPlasmaProperty): - outputs = ('luminosity_inner',) + outputs = ("luminosity_inner",) @staticmethod def calculate(r_inner, t_inner): - return (4 * np.pi * const.sigma_sb.cgs * r_inner[0] ** 2 - * t_inner ** 4).to('erg/s') + return ( + 4 * np.pi * const.sigma_sb.cgs * r_inner[0] ** 2 * t_inner ** 4 + ).to("erg/s") + class TimeSimulation(ProcessingPlasmaProperty): - outputs = ('time_simulation',) + outputs = ("time_simulation",) @staticmethod def calculate(luminosity_inner): diff --git a/tardis/plasma/properties/ion_population.py b/tardis/plasma/properties/ion_population.py index 92ef72747e9..31bdcaf6a6f 100644 --- a/tardis/plasma/properties/ion_population.py +++ b/tardis/plasma/properties/ion_population.py @@ -15,15 +15,22 @@ logger = logging.getLogger(__name__) -__all__ = ['PhiSahaNebular', 'PhiSahaLTE', 'RadiationFieldCorrection', - 'IonNumberDensity', 'IonNumberDensityHeNLTE', 'SahaFactor', - 'ThermalPhiSahaLTE'] +__all__ = [ + "PhiSahaNebular", + "PhiSahaLTE", + "RadiationFieldCorrection", + "IonNumberDensity", + "IonNumberDensityHeNLTE", + "SahaFactor", + "ThermalPhiSahaLTE", +] def calculate_block_ids_from_dataframe(dataframe): - block_start_id = np.where(np.diff( - dataframe.index.get_level_values(0)) != 0.0)[0] + 1 - return np.hstack(([0], block_start_id, [len(dataframe)])) + block_start_id = ( + np.where(np.diff(dataframe.index.get_level_values(0)) != 0.0)[0] + 1 + ) + return np.hstack(([0], block_start_id, [len(dataframe)])) class PhiSahaLTE(ProcessingPlasmaProperty): @@ -33,12 +40,15 @@ class PhiSahaLTE(ProcessingPlasmaProperty): Used for LTE ionization (at the radiation temperature). Indexed by atomic number, ion number. Columns are zones. """ - outputs = ('phi',) - latex_name = ('\\Phi',) - latex_formula = ('\\dfrac{2Z_{i,j+1}}{Z_{i,j}}\\Big(\ + + outputs = ("phi",) + latex_name = ("\\Phi",) + latex_formula = ( + "\\dfrac{2Z_{i,j+1}}{Z_{i,j}}\\Big(\ \\dfrac{2\\pi m_{e}/\\beta_{\\textrm{rad}}}{h^2}\ \\Big)^{3/2}e^{\\dfrac{-\\chi_{i,j}}{kT_{\ - \\textrm{rad}}}}',) + \\textrm{rad}}}}", + ) broadcast_ionization_energy = None @@ -46,9 +56,12 @@ class PhiSahaLTE(ProcessingPlasmaProperty): def calculate(g_electron, beta_rad, partition_function, ionization_data): phis = np.empty( - (partition_function.shape[0] - - partition_function.index.get_level_values(0).unique().size, - partition_function.shape[1])) + ( + partition_function.shape[0] + - partition_function.index.get_level_values(0).unique().size, + partition_function.shape[1], + ) + ) block_ids = calculate_block_ids_from_dataframe(partition_function) @@ -56,15 +69,19 @@ def calculate(g_electron, beta_rad, partition_function, ionization_data): end_id = block_ids[i + 1] current_block = partition_function.values[start_id:end_id] current_phis = current_block[1:] / current_block[:-1] - phis[start_id - i:end_id - i - 1] = current_phis + phis[start_id - i : end_id - i - 1] = current_phis - broadcast_ionization_energy = ( - ionization_data[partition_function.index].dropna()) + broadcast_ionization_energy = ionization_data[ + partition_function.index + ].dropna() phi_index = broadcast_ionization_energy.index broadcast_ionization_energy = broadcast_ionization_energy.values - phi_coefficient = (2 * g_electron * np.exp( - np.outer(broadcast_ionization_energy, -beta_rad))) + phi_coefficient = ( + 2 + * g_electron + * np.exp(np.outer(broadcast_ionization_energy, -beta_rad)) + ) return pd.DataFrame(phis * phi_coefficient, index=phi_index) @@ -80,19 +97,28 @@ class ThermalPhiSahaLTE(PhiSahaLTE): Used for LTE ionization (at the electron temperature). Indexed by atomic number, ion number. Columns are zones. """ - outputs = ('thermal_phi_lte',) - latex_name = ('\\Phi^{*}(T_\\mathrm{e})',) - latex_formula = ('\\dfrac{2Z_{i,j+1}}{Z_{i,j}}\\Big(\ + + outputs = ("thermal_phi_lte",) + latex_name = ("\\Phi^{*}(T_\\mathrm{e})",) + latex_formula = ( + "\\dfrac{2Z_{i,j+1}}{Z_{i,j}}\\Big(\ \\dfrac{2\\pi m_{e}/\\beta_{\\textrm{electron}}}{h^2}\ \\Big)^{3/2}e^{\\dfrac{-\\chi_{i,j}}{kT_{\ - \\textrm{electron}}}}',) + \\textrm{electron}}}}", + ) @staticmethod - def calculate(thermal_g_electron, beta_electron, - thermal_lte_partition_function, ionization_data): + def calculate( + thermal_g_electron, + beta_electron, + thermal_lte_partition_function, + ionization_data, + ): return super(ThermalPhiSahaLTE, ThermalPhiSahaLTE).calculate( - thermal_g_electron, beta_electron, thermal_lte_partition_function, - ionization_data + thermal_g_electron, + beta_electron, + thermal_lte_partition_function, + ionization_data, ) @@ -102,39 +128,63 @@ class PhiSahaNebular(ProcessingPlasmaProperty): phi : Pandas DataFrame, dtype float Used for nebular ionization. Indexed by atomic number, ion number. Columns are zones. """ - outputs = ('phi',) - latex_name = ('\\Phi',) - latex_formula = ('W(\\delta\\zeta_{i,j}+W(1-\\zeta_{i,j}))\\left(\ + + outputs = ("phi",) + latex_name = ("\\Phi",) + latex_formula = ( + "W(\\delta\\zeta_{i,j}+W(1-\\zeta_{i,j}))\\left(\ \\dfrac{T_{\\textrm{electron}}}{T_{\\textrm{rad}}}\ - \\right)^{1/2}',) + \\right)^{1/2}", + ) + @staticmethod - def calculate(t_rad, w, zeta_data, t_electrons, delta, - g_electron, beta_rad, partition_function, ionization_data): - phi_lte = PhiSahaLTE.calculate(g_electron, beta_rad, - partition_function, ionization_data) + def calculate( + t_rad, + w, + zeta_data, + t_electrons, + delta, + g_electron, + beta_rad, + partition_function, + ionization_data, + ): + phi_lte = PhiSahaLTE.calculate( + g_electron, beta_rad, partition_function, ionization_data + ) zeta = PhiSahaNebular.get_zeta_values(zeta_data, phi_lte.index, t_rad) - phis = phi_lte * w * ((zeta * delta) + w * (1 - zeta)) * \ - (t_electrons/t_rad) ** .5 + phis = ( + phi_lte + * w + * ((zeta * delta) + w * (1 - zeta)) + * (t_electrons / t_rad) ** 0.5 + ) return phis @staticmethod def get_zeta_values(zeta_data, ion_index, t_rad): zeta_t_rad = zeta_data.columns.values.astype(np.float64) zeta_values = zeta_data.loc[ion_index].values.astype(np.float64) - zeta = interpolate.interp1d(zeta_t_rad, zeta_values, bounds_error=False, - fill_value=np.nan)(t_rad) + zeta = interpolate.interp1d( + zeta_t_rad, zeta_values, bounds_error=False, fill_value=np.nan + )(t_rad) zeta = zeta.astype(float) if np.any(np.isnan(zeta)): - warnings.warn('t_rads outside of zeta factor interpolation' - ' zeta_min={0:.2f} zeta_max={1:.2f} ' - '- replacing with 1s'.format( - zeta_data.columns.values.min(), zeta_data.columns.values.max(), - t_rad)) + warnings.warn( + "t_rads outside of zeta factor interpolation" + " zeta_min={0:.2f} zeta_max={1:.2f} " + "- replacing with 1s".format( + zeta_data.columns.values.min(), + zeta_data.columns.values.max(), + t_rad, + ) + ) zeta[np.isnan(zeta)] = 1.0 return zeta + class RadiationFieldCorrection(ProcessingPlasmaProperty): """ Attributes: @@ -144,11 +194,17 @@ class RadiationFieldCorrection(ProcessingPlasmaProperty): Ca II, which is good for type Ia supernovae. For type II supernovae, (1, 1) should be used. Indexed by atomic number, ion number. The columns are zones. """ - outputs = ('delta',) - latex_name = ('\\delta',) - def __init__(self, plasma_parent=None, departure_coefficient=None, - chi_0_species=(20,2), delta_treatment=None): + outputs = ("delta",) + latex_name = ("\\delta",) + + def __init__( + self, + plasma_parent=None, + departure_coefficient=None, + chi_0_species=(20, 2), + delta_treatment=None, + ): super(RadiationFieldCorrection, self).__init__(plasma_parent) self.departure_coefficient = departure_coefficient self.delta_treatment = delta_treatment @@ -160,36 +216,48 @@ def _set_chi_0(self, ionization_data): else: self.chi_0 = ionization_data.loc[self.chi_0_species] - def calculate(self, w, ionization_data, beta_rad, t_electrons, t_rad, - beta_electron): - if getattr(self, 'chi_0', None) is None: + def calculate( + self, w, ionization_data, beta_rad, t_electrons, t_rad, beta_electron + ): + if getattr(self, "chi_0", None) is None: self._set_chi_0(ionization_data) if self.delta_treatment is None: if self.departure_coefficient is None: - departure_coefficient = 1. / w + departure_coefficient = 1.0 / w else: departure_coefficient = self.departure_coefficient - radiation_field_correction = -np.ones((len(ionization_data), len( - beta_rad))) - less_than_chi_0 = ( - ionization_data < self.chi_0).values - factor_a = (t_electrons / (departure_coefficient * w * t_rad)) - radiation_field_correction[~less_than_chi_0] = factor_a * \ - np.exp(np.outer(ionization_data.values[ - ~less_than_chi_0], beta_rad - beta_electron)) - radiation_field_correction[less_than_chi_0] = 1 - np.exp(np.outer( - ionization_data.values[less_than_chi_0], - beta_rad) - beta_rad * self.chi_0) + radiation_field_correction = -np.ones( + (len(ionization_data), len(beta_rad)) + ) + less_than_chi_0 = (ionization_data < self.chi_0).values + factor_a = t_electrons / (departure_coefficient * w * t_rad) + radiation_field_correction[~less_than_chi_0] = factor_a * np.exp( + np.outer( + ionization_data.values[~less_than_chi_0], + beta_rad - beta_electron, + ) + ) + radiation_field_correction[less_than_chi_0] = 1 - np.exp( + np.outer(ionization_data.values[less_than_chi_0], beta_rad) + - beta_rad * self.chi_0 + ) radiation_field_correction[less_than_chi_0] += factor_a * np.exp( - np.outer(ionization_data.values[ - less_than_chi_0],beta_rad) - self.chi_0 * beta_electron) + np.outer(ionization_data.values[less_than_chi_0], beta_rad) + - self.chi_0 * beta_electron + ) else: - radiation_field_correction = np.ones((len(ionization_data), - len(beta_rad))) * self.delta_treatment - delta = pd.DataFrame(radiation_field_correction, - columns=np.arange(len(t_rad)), index=ionization_data.index) + radiation_field_correction = ( + np.ones((len(ionization_data), len(beta_rad))) + * self.delta_treatment + ) + delta = pd.DataFrame( + radiation_field_correction, + columns=np.arange(len(t_rad)), + index=ionization_data.index, + ) return delta + class IonNumberDensity(ProcessingPlasmaProperty): """ Attributes: @@ -206,19 +274,30 @@ class IonNumberDensity(ProcessingPlasmaProperty): value, a new guess for the value of the electron density is chosen and the process is repeated. """ - outputs = ('ion_number_density', 'electron_densities') - latex_name = ('N_{i,j}','n_{e}',) - def __init__(self, plasma_parent, ion_zero_threshold=1e-20, electron_densities=None): + outputs = ("ion_number_density", "electron_densities") + latex_name = ( + "N_{i,j}", + "n_{e}", + ) + + def __init__( + self, plasma_parent, ion_zero_threshold=1e-20, electron_densities=None + ): super(IonNumberDensity, self).__init__(plasma_parent) self.ion_zero_threshold = ion_zero_threshold self.block_ids = None self._electron_densities = electron_densities @staticmethod - def calculate_with_n_electron(phi, partition_function, - number_density, n_electron, block_ids, - ion_zero_threshold): + def calculate_with_n_electron( + phi, + partition_function, + number_density, + n_electron, + block_ids, + ion_zero_threshold, + ): if block_ids is None: block_ids = IonNumberDensity._calculate_block_ids(phi) @@ -231,18 +310,22 @@ def calculate_with_n_electron(phi, partition_function, current_phis = phi_electron[start_id:end_id] phis_product = np.cumprod(current_phis, 0) - tmp_ion_populations = np.empty((current_phis.shape[0] + 1, - current_phis.shape[1])) - tmp_ion_populations[0] = (number_density.values[i] / - (1 + np.sum(phis_product, axis=0))) + tmp_ion_populations = np.empty( + (current_phis.shape[0] + 1, current_phis.shape[1]) + ) + tmp_ion_populations[0] = number_density.values[i] / ( + 1 + np.sum(phis_product, axis=0) + ) tmp_ion_populations[1:] = tmp_ion_populations[0] * phis_product - ion_populations[start_id + i:end_id + 1 + i] = tmp_ion_populations + ion_populations[start_id + i : end_id + 1 + i] = tmp_ion_populations ion_populations[ion_populations < ion_zero_threshold] = 0.0 - return pd.DataFrame(data = ion_populations, - index=partition_function.index), block_ids + return ( + pd.DataFrame(data=ion_populations, index=partition_function.index), + block_ids, + ) @staticmethod def _calculate_block_ids(phi): @@ -255,35 +338,56 @@ def calculate(self, phi, partition_function, number_density): n_electron_iterations = 0 while True: - ion_number_density, self.block_ids = \ - self.calculate_with_n_electron( - phi, partition_function, number_density, n_electron, - self.block_ids, self.ion_zero_threshold) - ion_numbers = ion_number_density.index.get_level_values(1).values + ( + ion_number_density, + self.block_ids, + ) = self.calculate_with_n_electron( + phi, + partition_function, + number_density, + n_electron, + self.block_ids, + self.ion_zero_threshold, + ) + ion_numbers = ion_number_density.index.get_level_values( + 1 + ).values ion_numbers = ion_numbers.reshape((ion_numbers.shape[0], 1)) new_n_electron = (ion_number_density.values * ion_numbers).sum( - axis=0) + axis=0 + ) if np.any(np.isnan(new_n_electron)): - raise PlasmaIonizationError('n_electron just turned "nan" -' - ' aborting') + raise PlasmaIonizationError( + 'n_electron just turned "nan" -' " aborting" + ) n_electron_iterations += 1 if n_electron_iterations > 100: - logger.warn('n_electron iterations above 100 ({0}) -' - ' something is probably wrong'.format( - n_electron_iterations)) - if np.all(np.abs(new_n_electron - n_electron) - / n_electron < n_e_convergence_threshold): + logger.warn( + "n_electron iterations above 100 ({0}) -" + " something is probably wrong".format( + n_electron_iterations + ) + ) + if np.all( + np.abs(new_n_electron - n_electron) / n_electron + < n_e_convergence_threshold + ): break n_electron = 0.5 * (new_n_electron + n_electron) else: n_electron = self._electron_densities - ion_number_density, self.block_ids = \ - self.calculate_with_n_electron( - phi, partition_function, number_density, n_electron, - self.block_ids, self.ion_zero_threshold) + ion_number_density, self.block_ids = self.calculate_with_n_electron( + phi, + partition_function, + number_density, + n_electron, + self.block_ids, + self.ion_zero_threshold, + ) return ion_number_density, n_electron + class IonNumberDensityHeNLTE(ProcessingPlasmaProperty): """ Attributes: @@ -300,80 +404,124 @@ class IonNumberDensityHeNLTE(ProcessingPlasmaProperty): value, a new guess for the value of the electron density is chosen and the process is repeated. """ - outputs = ('ion_number_density', 'electron_densities', - 'helium_population_updated') - latex_name = ('N_{i,j}','n_{e}',) - def __init__(self, plasma_parent, ion_zero_threshold=1e-20, electron_densities=None): + outputs = ( + "ion_number_density", + "electron_densities", + "helium_population_updated", + ) + latex_name = ( + "N_{i,j}", + "n_{e}", + ) + + def __init__( + self, plasma_parent, ion_zero_threshold=1e-20, electron_densities=None + ): super(IonNumberDensityHeNLTE, self).__init__(plasma_parent) self.ion_zero_threshold = ion_zero_threshold self.block_ids = None self._electron_densities = electron_densities - def update_he_population(self, helium_population, n_electron, - number_density): + def update_he_population( + self, helium_population, n_electron, number_density + ): helium_population_updated = helium_population.copy() he_one_population = helium_population_updated.loc[0].mul(n_electron) he_three_population = helium_population_updated.loc[2].mul( - 1./n_electron) + 1.0 / n_electron + ) helium_population_updated.loc[0].update(he_one_population) helium_population_updated.loc[2].update(he_three_population) unnormalised = helium_population_updated.sum() - normalised = helium_population_updated.mul(number_density.loc[2] / - unnormalised) + normalised = helium_population_updated.mul( + number_density.loc[2] / unnormalised + ) helium_population_updated.update(normalised) return helium_population_updated - def calculate(self, phi, partition_function, number_density, - helium_population): + def calculate( + self, phi, partition_function, number_density, helium_population + ): if self._electron_densities is None: n_e_convergence_threshold = 0.05 n_electron = number_density.sum(axis=0) n_electron_iterations = 0 while True: - ion_number_density, self.block_ids = \ - IonNumberDensity.calculate_with_n_electron( - phi, partition_function, number_density, n_electron, - self.block_ids, self.ion_zero_threshold) + ( + ion_number_density, + self.block_ids, + ) = IonNumberDensity.calculate_with_n_electron( + phi, + partition_function, + number_density, + n_electron, + self.block_ids, + self.ion_zero_threshold, + ) helium_population_updated = self.update_he_population( - helium_population, n_electron, number_density) - ion_number_density.loc[2, 0].update(helium_population_updated.loc[ - 0].sum(axis=0)) - ion_number_density.loc[2, 1].update(helium_population_updated.loc[ - 1].sum(axis=0)) - ion_number_density.loc[2, 2].update(helium_population_updated.loc[ - 2, 0]) - ion_numbers = ion_number_density.index.get_level_values(1).values + helium_population, n_electron, number_density + ) + ion_number_density.loc[2, 0].update( + helium_population_updated.loc[0].sum(axis=0) + ) + ion_number_density.loc[2, 1].update( + helium_population_updated.loc[1].sum(axis=0) + ) + ion_number_density.loc[2, 2].update( + helium_population_updated.loc[2, 0] + ) + ion_numbers = ion_number_density.index.get_level_values( + 1 + ).values ion_numbers = ion_numbers.reshape((ion_numbers.shape[0], 1)) new_n_electron = (ion_number_density.values * ion_numbers).sum( - axis=0) + axis=0 + ) if np.any(np.isnan(new_n_electron)): - raise PlasmaIonizationError('n_electron just turned "nan" -' - ' aborting') + raise PlasmaIonizationError( + 'n_electron just turned "nan" -' " aborting" + ) n_electron_iterations += 1 if n_electron_iterations > 100: - logger.warn('n_electron iterations above 100 ({0}) -' - ' something is probably wrong'.format( - n_electron_iterations)) - if np.all(np.abs(new_n_electron - n_electron) - / n_electron < n_e_convergence_threshold): + logger.warn( + "n_electron iterations above 100 ({0}) -" + " something is probably wrong".format( + n_electron_iterations + ) + ) + if np.all( + np.abs(new_n_electron - n_electron) / n_electron + < n_e_convergence_threshold + ): break n_electron = 0.5 * (new_n_electron + n_electron) else: n_electron = self._electron_densities - ion_number_density, self.block_ids = \ - IonNumberDensity.calculate_with_n_electron( - phi, partition_function, number_density, n_electron, - self.block_ids, self.ion_zero_threshold) + ( + ion_number_density, + self.block_ids, + ) = IonNumberDensity.calculate_with_n_electron( + phi, + partition_function, + number_density, + n_electron, + self.block_ids, + self.ion_zero_threshold, + ) helium_population_updated = self.update_he_population( - helium_population, n_electron, number_density) - ion_number_density.loc[2, 0].update(helium_population_updated.loc[ - 0].sum(axis=0)) - ion_number_density.loc[2, 1].update(helium_population_updated.loc[ - 1].sum(axis=0)) - ion_number_density.loc[2, 2].update(helium_population_updated.loc[ - 2, 0]) + helium_population, n_electron, number_density + ) + ion_number_density.loc[2, 0].update( + helium_population_updated.loc[0].sum(axis=0) + ) + ion_number_density.loc[2, 1].update( + helium_population_updated.loc[1].sum(axis=0) + ) + ion_number_density.loc[2, 2].update( + helium_population_updated.loc[2, 0] + ) return ion_number_density, n_electron, helium_population_updated @@ -388,27 +536,34 @@ class SahaFactor(ProcessingPlasmaProperty): Indexed by atom number, ion number, level number. Columns are zones. """ - outputs = ('phi_ik',) - latex_name = ('\\Phi_{i,\\kappa}',) - def calculate(self, thermal_phi_lte, thermal_lte_level_boltzmann_factor, - thermal_lte_partition_function): + outputs = ("phi_ik",) + latex_name = ("\\Phi_{i,\\kappa}",) + + def calculate( + self, + thermal_phi_lte, + thermal_lte_level_boltzmann_factor, + thermal_lte_partition_function, + ): boltzmann_factor = self._prepare_boltzmann_factor( thermal_lte_level_boltzmann_factor ) phi_saha_index = get_ion_multi_index(boltzmann_factor.index) - partition_function_index = get_ion_multi_index(boltzmann_factor.index, - next_higher=False) + partition_function_index = get_ion_multi_index( + boltzmann_factor.index, next_higher=False + ) phi_saha = thermal_phi_lte.loc[phi_saha_index].values # Replace zero values in phi_saha to avoid zero division in Saha factor phi_saha[phi_saha == 0.0] = sys.float_info.min partition_function = thermal_lte_partition_function.loc[ - partition_function_index].values + partition_function_index + ].values return boltzmann_factor / (phi_saha * partition_function) @staticmethod def _prepare_boltzmann_factor(boltzmann_factor): atomic_number = boltzmann_factor.index.get_level_values(0) ion_number = boltzmann_factor.index.get_level_values(1) - selected_ions_mask = (atomic_number != ion_number) + selected_ions_mask = atomic_number != ion_number return boltzmann_factor[selected_ions_mask] diff --git a/tardis/plasma/properties/j_blues.py b/tardis/plasma/properties/j_blues.py index b85fe5b551e..3e28302fa98 100644 --- a/tardis/plasma/properties/j_blues.py +++ b/tardis/plasma/properties/j_blues.py @@ -2,81 +2,92 @@ import pandas as pd from tardis import constants as const -from tardis.plasma.properties.base import (ProcessingPlasmaProperty, - DataFrameInput) +from tardis.plasma.properties.base import ( + ProcessingPlasmaProperty, + DataFrameInput, +) from tardis.util.base import intensity_black_body class JBluesBlackBody(ProcessingPlasmaProperty): - ''' + """ Attributes ---------- lte_j_blues : Pandas DataFrame, dtype float J_blue values as calculated in LTE. - ''' - outputs = ('j_blues',) - latex_name = ('J^{b}_{lu(LTE)}') + """ + + outputs = ("j_blues",) + latex_name = "J^{b}_{lu(LTE)}" @staticmethod def calculate(lines, nu, t_rad): j_blues = intensity_black_body(nu.values[np.newaxis].T, t_rad) - j_blues = pd.DataFrame(j_blues, index=lines.index, - columns=np.arange(len(t_rad))) + j_blues = pd.DataFrame( + j_blues, index=lines.index, columns=np.arange(len(t_rad)) + ) return j_blues class JBluesDiluteBlackBody(ProcessingPlasmaProperty): - outputs = ('j_blues',) - latex_name = ('J_{\\textrm{blue}}') + outputs = ("j_blues",) + latex_name = "J_{\\textrm{blue}}" @staticmethod def calculate(lines, nu, t_rad, w): j_blues = w * intensity_black_body(nu.values[np.newaxis].T, t_rad) - j_blues = pd.DataFrame(j_blues, index=lines.index, - columns=np.arange(len(t_rad))) + j_blues = pd.DataFrame( + j_blues, index=lines.index, columns=np.arange(len(t_rad)) + ) return j_blues class JBluesDetailed(ProcessingPlasmaProperty): - outputs = ('j_blues',) - latex_name = ('J_{\\textrm{blue}}') + outputs = ("j_blues",) + latex_name = "J_{\\textrm{blue}}" def __init__(self, plasma_parent, w_epsilon): super(JBluesDetailed, self).__init__(plasma_parent) self.w_epsilon = w_epsilon - def calculate(self, lines, nu, t_rad, w, j_blues_norm_factor, - j_blue_estimator): + def calculate( + self, lines, nu, t_rad, w, j_blues_norm_factor, j_blue_estimator + ): # Used for initialization if len(j_blue_estimator) == 0: return JBluesDiluteBlackBody.calculate(lines, nu, t_rad, w) else: j_blues = pd.DataFrame( - j_blue_estimator * - j_blues_norm_factor.value, + j_blue_estimator * j_blues_norm_factor.value, index=lines.index, - columns=np.arange(len(t_rad))) + columns=np.arange(len(t_rad)), + ) for i in range(len(t_rad)): zero_j_blues = j_blues[i] == 0.0 j_blues[i][zero_j_blues] = ( - self.w_epsilon * - intensity_black_body(nu[zero_j_blues].values, - t_rad[i])) + self.w_epsilon + * intensity_black_body(nu[zero_j_blues].values, t_rad[i]) + ) return j_blues class JBluesNormFactor(ProcessingPlasmaProperty): - outputs = ('j_blues_norm_factor',) - latex = ('\\frac{c time_\\textrm{simulation}}}{4 \\pi ' - 'time_\\textrm{simulation} volume}') + outputs = ("j_blues_norm_factor",) + latex = ( + "\\frac{c time_\\textrm{simulation}}}{4 \\pi " + "time_\\textrm{simulation} volume}" + ) @staticmethod def calculate(time_explosion, time_simulation, volume): - return (const.c.cgs * time_explosion / - (4 * np.pi * time_simulation * volume)) + return ( + const.c.cgs + * time_explosion + / (4 * np.pi * time_simulation * volume) + ) class JBluesEstimator(DataFrameInput): - outputs = ('j_blue_estimator',) + outputs = ("j_blue_estimator",) diff --git a/tardis/plasma/properties/level_population.py b/tardis/plasma/properties/level_population.py index a35a25d3192..6b8dd2aef31 100644 --- a/tardis/plasma/properties/level_population.py +++ b/tardis/plasma/properties/level_population.py @@ -6,7 +6,7 @@ logger = logging.getLogger(__name__) -__all__ = ['LevelNumberDensity', 'LevelNumberDensityHeNLTE'] +__all__ = ["LevelNumberDensity", "LevelNumberDensityHeNLTE"] class LevelNumberDensity(ProcessingPlasmaProperty): @@ -18,9 +18,10 @@ class LevelNumberDensity(ProcessingPlasmaProperty): Index atom number, ion number, level number. Columns are zones. """ - outputs = ('level_number_density',) - latex_name = ('N_{i,j,k}',) - latex_formula = ('N_{i,j}\\dfrac{bf_{i,j,k}}{Z_{i,j}}',) + + outputs = ("level_number_density",) + latex_name = ("N_{i,j,k}",) + latex_formula = ("N_{i,j}\\dfrac{bf_{i,j,k}}{Z_{i,j}}",) def __init__(self, plasma_parent): super(LevelNumberDensity, self).__init__(plasma_parent) @@ -28,13 +29,19 @@ def __init__(self, plasma_parent): self.initialize_indices = True def _initialize_indices(self, levels, partition_function): - indexer = pd.Series(np.arange(partition_function.shape[0]), - index=partition_function.index) + indexer = pd.Series( + np.arange(partition_function.shape[0]), + index=partition_function.index, + ) self._ion2level_idx = indexer.loc[levels.droplevel(2)].values def _calculate_dilute_lte( - self, level_boltzmann_factor, ion_number_density, - levels, partition_function): + self, + level_boltzmann_factor, + ion_number_density, + levels, + partition_function, + ): """ Calculate the level populations from the level_boltzmann_factor, ion_number_density and partition_function @@ -43,15 +50,20 @@ def _calculate_dilute_lte( self._initialize_indices(levels, partition_function) self.initialize_indices = False partition_function_broadcast = partition_function.values[ - self._ion2level_idx] - level_population_fraction = (level_boltzmann_factor.values / - partition_function_broadcast) + self._ion2level_idx + ] + level_population_fraction = ( + level_boltzmann_factor.values / partition_function_broadcast + ) ion_number_density_broadcast = ion_number_density.values[ - self._ion2level_idx] - level_number_density = (level_population_fraction * - ion_number_density_broadcast) - return pd.DataFrame(level_number_density, - index=level_boltzmann_factor.index) + self._ion2level_idx + ] + level_number_density = ( + level_population_fraction * ion_number_density_broadcast + ) + return pd.DataFrame( + level_number_density, index=level_boltzmann_factor.index + ) calculate = _calculate_dilute_lte @@ -65,17 +77,24 @@ class LevelNumberDensityHeNLTE(LevelNumberDensity): """ def calculate( - self, level_boltzmann_factor, - ion_number_density, levels, partition_function, - helium_population_updated): + self, + level_boltzmann_factor, + ion_number_density, + levels, + partition_function, + helium_population_updated, + ): """ If one of the two helium NLTE methods is used, this updates the helium level populations to the appropriate values. """ level_number_density = self._calculate_dilute_lte( - level_boltzmann_factor, ion_number_density, levels, - partition_function) + level_boltzmann_factor, + ion_number_density, + levels, + partition_function, + ) if helium_population_updated is not None: level_number_density.loc[2].update(helium_population_updated) return level_number_density diff --git a/tardis/plasma/properties/nlte.py b/tardis/plasma/properties/nlte.py index e319e5bd8d6..ef7e58c3cc2 100644 --- a/tardis/plasma/properties/nlte.py +++ b/tardis/plasma/properties/nlte.py @@ -4,128 +4,193 @@ import numpy as np import pandas as pd -from tardis.plasma.properties.base import (PreviousIterationProperty, - ProcessingPlasmaProperty) +from tardis.plasma.properties.base import ( + PreviousIterationProperty, + ProcessingPlasmaProperty, +) from tardis.plasma.properties.ion_population import PhiSahaNebular -__all__ = ['PreviousElectronDensities', 'PreviousBetaSobolev', - 'HeliumNLTE', 'HeliumNumericalNLTE'] +__all__ = [ + "PreviousElectronDensities", + "PreviousBetaSobolev", + "HeliumNLTE", + "HeliumNumericalNLTE", +] logger = logging.getLogger(__name__) + class PreviousElectronDensities(PreviousIterationProperty): """ Attributes ---------- previous_electron_densities : The values for the electron densities converged upon in the previous iteration. """ - outputs = ('previous_electron_densities',) + + outputs = ("previous_electron_densities",) def set_initial_value(self, kwargs): - initial_value = pd.Series( - 1000000.0, - index=kwargs['abundance'].columns, - ) + initial_value = pd.Series(1000000.0, index=kwargs["abundance"].columns,) self._set_initial_value(initial_value) + class PreviousBetaSobolev(PreviousIterationProperty): """ Attributes ---------- previous_beta_sobolev : The beta sobolev values converged upon in the previous iteration. """ - outputs = ('previous_beta_sobolev',) + + outputs = ("previous_beta_sobolev",) def set_initial_value(self, kwargs): initial_value = pd.DataFrame( - 1., - index=kwargs['atomic_data'].lines.index, - columns=kwargs['abundance'].columns, - ) + 1.0, + index=kwargs["atomic_data"].lines.index, + columns=kwargs["abundance"].columns, + ) self._set_initial_value(initial_value) + class HeliumNLTE(ProcessingPlasmaProperty): - outputs = ('helium_population',) + outputs = ("helium_population",) @staticmethod - def calculate(level_boltzmann_factor, - ionization_data, beta_rad, g, g_electron, w, t_rad, t_electrons, - delta, zeta_data, number_density, partition_function): + def calculate( + level_boltzmann_factor, + ionization_data, + beta_rad, + g, + g_electron, + w, + t_rad, + t_electrons, + delta, + zeta_data, + number_density, + partition_function, + ): """ Updates all of the helium level populations according to the helium NLTE recomb approximation. """ helium_population = level_boltzmann_factor.loc[2].copy() # He I excited states - he_one_population = HeliumNLTE.calculate_helium_one(g_electron, beta_rad, - ionization_data, level_boltzmann_factor, g, w) + he_one_population = HeliumNLTE.calculate_helium_one( + g_electron, beta_rad, ionization_data, level_boltzmann_factor, g, w + ) helium_population.loc[0].update(he_one_population) - #He I ground state + # He I ground state helium_population.loc[0, 0] = 0.0 - #He II excited states + # He II excited states he_two_population = level_boltzmann_factor.loc[2, 1].mul( - (g.loc[2, 1, 0]**(-1.0))) + (g.loc[2, 1, 0] ** (-1.0)) + ) helium_population.loc[1].update(he_two_population) - #He II ground state + # He II ground state helium_population.loc[1, 0] = 1.0 - #He III states - helium_population.loc[2, 0] = HeliumNLTE.calculate_helium_three(t_rad, w, - zeta_data, t_electrons, delta, g_electron, beta_rad, - ionization_data, g) -# unnormalised = helium_population.sum() -# normalised = helium_population.mul(number_density.ix[2] / -# unnormalised) -# helium_population.update(normalised) + # He III states + helium_population.loc[2, 0] = HeliumNLTE.calculate_helium_three( + t_rad, + w, + zeta_data, + t_electrons, + delta, + g_electron, + beta_rad, + ionization_data, + g, + ) + # unnormalised = helium_population.sum() + # normalised = helium_population.mul(number_density.ix[2] / unnormalised) + # helium_population.update(normalised) return helium_population @staticmethod - def calculate_helium_one(g_electron, beta_rad, ionization_data, - level_boltzmann_factor, g, w): + def calculate_helium_one( + g_electron, beta_rad, ionization_data, level_boltzmann_factor, g, w + ): """ Calculates the He I level population values, in equilibrium with the He II ground state. """ - return level_boltzmann_factor.loc[2,0] * (1./(2*g.loc[2, 1, 0])) * \ - (1/g_electron) * (1/(w**2.)) * np.exp( - ionization_data.loc[2,1] * beta_rad) + return ( + level_boltzmann_factor.loc[2, 0] + * (1.0 / (2 * g.loc[2, 1, 0])) + * (1 / g_electron) + * (1 / (w ** 2.0)) + * np.exp(ionization_data.loc[2, 1] * beta_rad) + ) @staticmethod - def calculate_helium_three(t_rad, w, zeta_data, t_electrons, delta, - g_electron, beta_rad, ionization_data, g): + def calculate_helium_three( + t_rad, + w, + zeta_data, + t_electrons, + delta, + g_electron, + beta_rad, + ionization_data, + g, + ): """ Calculates the He III level population values. """ zeta = PhiSahaNebular.get_zeta_values(zeta_data, 2, t_rad)[1] - he_three_population = 2 * \ - (float(g.loc[2, 2, 0]) / g.loc[2, 1, 0]) * g_electron * \ - np.exp(-ionization_data.loc[2, 2] * beta_rad) \ - * w * (delta.loc[2, 2] * zeta + w * (1. - zeta)) * \ - (t_electrons / t_rad) ** 0.5 + he_three_population = ( + 2 + * (float(g.loc[2, 2, 0]) / g.loc[2, 1, 0]) + * g_electron + * np.exp(-ionization_data.loc[2, 2] * beta_rad) + * w + * (delta.loc[2, 2] * zeta + w * (1.0 - zeta)) + * (t_electrons / t_rad) ** 0.5 + ) return he_three_population + class HeliumNumericalNLTE(ProcessingPlasmaProperty): - ''' + """ IMPORTANT: This particular property requires a specific numerical NLTE solver and a specific atomic dataset (neither of which are distributed with Tardis) to work. - ''' - outputs = ('helium_population',) + """ + + outputs = ("helium_population",) + def __init__(self, plasma_parent, heating_rate_data_file): super(HeliumNumericalNLTE, self).__init__(plasma_parent) self._g_upper = None self._g_lower = None - self.heating_rate_data = np.loadtxt( - heating_rate_data_file, unpack=True) - - def calculate(self, ion_number_density, electron_densities, t_electrons, w, - lines, j_blues, levels, level_boltzmann_factor, t_rad, - zeta_data, g_electron, delta, partition_function, ionization_data, - beta_rad, g, time_explosion): - logger.info('Performing numerical NLTE He calculations.') - if len(j_blues)==0: + self.heating_rate_data = np.loadtxt(heating_rate_data_file, unpack=True) + + def calculate( + self, + ion_number_density, + electron_densities, + t_electrons, + w, + lines, + j_blues, + levels, + level_boltzmann_factor, + t_rad, + zeta_data, + g_electron, + delta, + partition_function, + ionization_data, + beta_rad, + g, + time_explosion, + ): + logger.info("Performing numerical NLTE He calculations.") + if len(j_blues) == 0: return None - #Outputting data required by SH module + # Outputting data required by SH module for zone, _ in enumerate(electron_densities): - with open('He_NLTE_Files/shellconditions_{}.txt'.format(zone), - 'w') as output_file: + with open( + "He_NLTE_Files/shellconditions_{}.txt".format(zone), "w" + ) as output_file: output_file.write(ion_number_density.loc[2].sum()[zone]) output_file.write(electron_densities[zone]) output_file.write(t_electrons[zone]) @@ -137,74 +202,109 @@ def calculate(self, ion_number_density, electron_densities, t_electrons, w, output_file.write(self.plasma_parent.v_outer[zone]) for zone, _ in enumerate(electron_densities): - with open('He_NLTE_Files/abundances_{}.txt'.format(zone), 'w') as \ - output_file: - for element in range(1,31): + with open( + "He_NLTE_Files/abundances_{}.txt".format(zone), "w" + ) as output_file: + for element in range(1, 31): try: - number_density = ion_number_density[zone].loc[ - element].sum() + number_density = ( + ion_number_density[zone].loc[element].sum() + ) except: number_density = 0.0 output_file.write(number_density) - helium_lines = lines[lines['atomic_number']==2] - helium_lines = helium_lines[helium_lines['ion_number']==0] + helium_lines = lines[lines["atomic_number"] == 2] + helium_lines = helium_lines[helium_lines["ion_number"] == 0] for zone, _ in enumerate(electron_densities): - with open('He_NLTE_Files/discradfield_{}.txt'.format(zone), 'w') \ - as output_file: + with open( + "He_NLTE_Files/discradfield_{}.txt".format(zone), "w" + ) as output_file: j_blues = pd.DataFrame(j_blues, index=lines.index) helium_j_blues = j_blues[zone].loc[helium_lines.index] for value in helium_lines.index: - if (helium_lines.level_number_lower.loc[value]<35): + if helium_lines.level_number_lower.loc[value] < 35: output_file.write( - int(helium_lines.level_number_lower.loc[value]+1), - int(helium_lines.level_number_upper.loc[value]+1), - j_blues[zone].loc[value]) - #Running numerical simulations + int(helium_lines.level_number_lower.loc[value] + 1), + int(helium_lines.level_number_upper.loc[value] + 1), + j_blues[zone].loc[value], + ) + # Running numerical simulations for zone, _ in enumerate(electron_densities): - os.rename('He_NLTE_Files/abundances{}.txt'.format(zone), - 'He_NLTE_Files/abundances_current.txt') - os.rename('He_NLTE_Files/shellconditions{}.txt'.format(zone), - 'He_NLTE_Files/shellconditions_current.txt') - os.rename('He_NLTE_Files/discradfield{}.txt'.format(zone), - 'He_NLTE_Files/discradfield_current.txt') + os.rename( + "He_NLTE_Files/abundances{}.txt".format(zone), + "He_NLTE_Files/abundances_current.txt", + ) + os.rename( + "He_NLTE_Files/shellconditions{}.txt".format(zone), + "He_NLTE_Files/shellconditions_current.txt", + ) + os.rename( + "He_NLTE_Files/discradfield{}.txt".format(zone), + "He_NLTE_Files/discradfield_current.txt", + ) os.system("nlte-solver-module/bin/nlte_solvertest >/dev/null") - os.rename('He_NLTE_Files/abundances_current.txt', - 'He_NLTE_Files/abundances{}.txt'.format(zone)) - os.rename('He_NLTE_Files/shellconditions_current.txt', - 'He_NLTE_Files/shellconditions{}.txt'.format(zone)) - os.rename('He_NLTE_Files/discradfield_current.txt', - 'He_NLTE_Files/discradfield{}.txt'.format(zone)) - os.rename('debug_occs.dat', 'He_NLTE_Files/occs{}.txt'.format(zone)) - #Reading in populations from files + os.rename( + "He_NLTE_Files/abundances_current.txt", + "He_NLTE_Files/abundances{}.txt".format(zone), + ) + os.rename( + "He_NLTE_Files/shellconditions_current.txt", + "He_NLTE_Files/shellconditions{}.txt".format(zone), + ) + os.rename( + "He_NLTE_Files/discradfield_current.txt", + "He_NLTE_Files/discradfield{}.txt".format(zone), + ) + os.rename("debug_occs.dat", "He_NLTE_Files/occs{}.txt".format(zone)) + # Reading in populations from files helium_population = level_boltzmann_factor.loc[2].copy() for zone, _ in enumerate(electron_densities): - with open('He_NLTE_Files/discradfield{}.txt'.format(zone), 'r') as \ - read_file: + with open( + "He_NLTE_Files/discradfield{}.txt".format(zone), "r" + ) as read_file: for level in range(0, 35): level_population = read_file.readline() level_population = float(level_population) helium_population[zone].loc[0, level] = level_population - helium_population[zone].loc[1, 0] = float( - read_file.readline()) - #Performing He LTE level populations (upper two energy levels, - #He II excited states, He III) - he_one_population = HeliumNLTE.calculate_helium_one(g_electron, - beta_rad, partition_function, ionization_data, - level_boltzmann_factor, electron_densities, g, w, t_rad, - t_electrons) + helium_population[zone].loc[1, 0] = float(read_file.readline()) + # Performing He LTE level populations (upper two energy levels, + # He II excited states, He III) + he_one_population = HeliumNLTE.calculate_helium_one( + g_electron, + beta_rad, + partition_function, + ionization_data, + level_boltzmann_factor, + electron_densities, + g, + w, + t_rad, + t_electrons, + ) helium_population.loc[0, 35].update(he_one_population.loc[35]) helium_population.loc[0, 36].update(he_one_population.loc[36]) he_two_population = level_boltzmann_factor.loc[2, 1, 1:].mul( - (g.loc[2, 1, 0] ** (-1)) * helium_population.loc[s1, 0]) + (g.loc[2, 1, 0] ** (-1)) * helium_population.loc[s1, 0] + ) helium_population.loc[1, 1:].update(he_two_population) helium_population.loc[2, 0] = HeliumNLTE.calculate_helium_three( - t_rad, w, zeta_data, t_electrons, delta, g_electron, beta_rad, - partition_function, ionization_data, electron_densities) + t_rad, + w, + zeta_data, + t_electrons, + delta, + g_electron, + beta_rad, + partition_function, + ionization_data, + electron_densities, + ) unnormalised = helium_population.sum() - normalised = helium_population.mul(ion_number_density.loc[2].sum() - / unnormalised) + normalised = helium_population.mul( + ion_number_density.loc[2].sum() / unnormalised + ) helium_population.update(normalised) return helium_population diff --git a/tardis/plasma/properties/partition_function.py b/tardis/plasma/properties/partition_function.py index 0b31d3c6966..204ede31aa4 100644 --- a/tardis/plasma/properties/partition_function.py +++ b/tardis/plasma/properties/partition_function.py @@ -9,10 +9,15 @@ logger = logging.getLogger(__name__) -__all__ = ['LevelBoltzmannFactorLTE', 'LevelBoltzmannFactorDiluteLTE', - 'LevelBoltzmannFactorNoNLTE', 'LevelBoltzmannFactorNLTE', - 'PartitionFunction', 'ThermalLevelBoltzmannFactorLTE', - 'ThermalLTEPartitionFunction'] +__all__ = [ + "LevelBoltzmannFactorLTE", + "LevelBoltzmannFactorDiluteLTE", + "LevelBoltzmannFactorNoNLTE", + "LevelBoltzmannFactorNLTE", + "PartitionFunction", + "ThermalLevelBoltzmannFactorLTE", + "ThermalLTEPartitionFunction", +] class LevelBoltzmannFactorLTE(ProcessingPlasmaProperty): @@ -26,20 +31,24 @@ class LevelBoltzmannFactorLTE(ProcessingPlasmaProperty): Columns corresponding to zones. Does not consider NLTE. """ - outputs = ('general_level_boltzmann_factor',) - latex_name = ('bf_{i,j,k}',) - latex_formula = ('g_{i,j,k}e^{\\dfrac{-\\epsilon_{i,j,k}}{k_{\ - \\textrm{B}}T_{\\textrm{rad}}}}',) + + outputs = ("general_level_boltzmann_factor",) + latex_name = ("bf_{i,j,k}",) + latex_formula = ( + "g_{i,j,k}e^{\\dfrac{-\\epsilon_{i,j,k}}{k_{\ + \\textrm{B}}T_{\\textrm{rad}}}}", + ) @staticmethod def calculate(excitation_energy, g, beta_rad, levels): exponential = np.exp(np.outer(excitation_energy.values, -beta_rad)) - level_boltzmann_factor_array = (g.values[np.newaxis].T * - exponential) - level_boltzmann_factor = pd.DataFrame(level_boltzmann_factor_array, - index=levels, - columns=np.arange(len(beta_rad)), - dtype=np.float64) + level_boltzmann_factor_array = g.values[np.newaxis].T * exponential + level_boltzmann_factor = pd.DataFrame( + level_boltzmann_factor_array, + index=levels, + columns=np.arange(len(beta_rad)), + dtype=np.float64, + ) return level_boltzmann_factor @@ -54,17 +63,19 @@ class ThermalLevelBoltzmannFactorLTE(LevelBoltzmannFactorLTE): by atomic number, ion number, level number. Columns corresponding to zones. """ - outputs = ('thermal_lte_level_boltzmann_factor',) - latex_name = ('bf_{i,j,k}^{\\textrm{LTE}}(T_e)',) - latex_formula = ('g_{i,j,k}e^{\\dfrac{-\\epsilon_{i,j,k}}{k_{\ - \\textrm{B}}T_{\\textrm{electron}}}}',) + + outputs = ("thermal_lte_level_boltzmann_factor",) + latex_name = ("bf_{i,j,k}^{\\textrm{LTE}}(T_e)",) + latex_formula = ( + "g_{i,j,k}e^{\\dfrac{-\\epsilon_{i,j,k}}{k_{\ + \\textrm{B}}T_{\\textrm{electron}}}}", + ) @staticmethod def calculate(excitation_energy, g, beta_electron, levels): - return super(ThermalLevelBoltzmannFactorLTE, - ThermalLevelBoltzmannFactorLTE).calculate( - excitation_energy, g, beta_electron, levels - ) + return super( + ThermalLevelBoltzmannFactorLTE, ThermalLevelBoltzmannFactorLTE + ).calculate(excitation_energy, g, beta_electron, levels) class LevelBoltzmannFactorDiluteLTE(ProcessingPlasmaProperty): @@ -79,16 +90,20 @@ class LevelBoltzmannFactorDiluteLTE(ProcessingPlasmaProperty): multiplied by an additional factor W. Does not consider NLTE. """ - outputs = ('general_level_boltzmann_factor',) - latex_name = ('bf_{i,j,k}',) - latex_formula = ('Wg_{i,j,k}e^{\\dfrac{-\\epsilon_{i,j,k}}{k_{\ - \\textrm{B}}T_{\\textrm{rad}}}}',) + + outputs = ("general_level_boltzmann_factor",) + latex_name = ("bf_{i,j,k}",) + latex_formula = ( + "Wg_{i,j,k}e^{\\dfrac{-\\epsilon_{i,j,k}}{k_{\ + \\textrm{B}}T_{\\textrm{rad}}}}", + ) def calculate( - self, levels, g, excitation_energy, beta_rad, w, - metastability): + self, levels, g, excitation_energy, beta_rad, w, metastability + ): level_boltzmann_factor = LevelBoltzmannFactorLTE.calculate( - excitation_energy, g, beta_rad, levels) + excitation_energy, g, beta_rad, levels + ) level_boltzmann_factor[~metastability] *= w return level_boltzmann_factor @@ -101,7 +116,8 @@ class LevelBoltzmannFactorNoNLTE(ProcessingPlasmaProperty): Returns general_level_boltzmann_factor as this property is included if NLTE is not used. """ - outputs = ('level_boltzmann_factor',) + + outputs = ("level_boltzmann_factor",) @staticmethod def calculate(general_level_boltzmann_factor): @@ -116,25 +132,27 @@ class LevelBoltzmannFactorNLTE(ProcessingPlasmaProperty): Returns general_level_boltzmann_factor but updated for those species treated in NLTE. """ - outputs = ('level_boltzmann_factor',) + + outputs = ("level_boltzmann_factor",) def calculate(self): raise AttributeError( - 'This attribute is not defined on the parent class.' - 'Please use one of the subclasses.') + "This attribute is not defined on the parent class." + "Please use one of the subclasses." + ) @staticmethod def from_config(nlte_conf): if nlte_conf.classical_nebular and not nlte_conf.coronal_approximation: return LevelBoltzmannFactorNLTEClassic elif ( - nlte_conf.coronal_approximation and - not nlte_conf.classical_nebular): + nlte_conf.coronal_approximation and not nlte_conf.classical_nebular + ): return LevelBoltzmannFactorNLTECoronal elif nlte_conf.coronal_approximation and nlte_conf.classical_nebular: - raise PlasmaConfigError('Both coronal approximation and ' - 'classical nebular specified in the ' - 'config.') + raise PlasmaConfigError( + "Both coronal approximation and classical nebular specified in the config." + ) else: return LevelBoltzmannFactorNLTEGeneral @@ -148,19 +166,26 @@ def __init__(self, plasma_parent): self._update_inputs() def _main_nlte_calculation( - self, atomic_data, nlte_data, t_electrons, - j_blues, beta_sobolevs, general_level_boltzmann_factor, - previous_electron_densities, g): + self, + atomic_data, + nlte_data, + t_electrons, + j_blues, + beta_sobolevs, + general_level_boltzmann_factor, + previous_electron_densities, + g, + ): """ The core of the NLTE calculation, used with all possible config. options. """ for species in nlte_data.nlte_species: - logger.info('Calculating rates for species %s', species) + logger.info(f"Calculating rates for species {species}") number_of_levels = atomic_data.levels.energy.loc[species].count() lnl = nlte_data.lines_level_number_lower[species] lnu = nlte_data.lines_level_number_upper[species] - lines_index, = nlte_data.lines_idx[species] + (lines_index,) = nlte_data.lines_idx[species] try: j_blues_filtered = j_blues.iloc[lines_index] @@ -176,27 +201,33 @@ def _main_nlte_calculation( r_lu_index = lnu * number_of_levels + lnl r_ul_index = lnl * number_of_levels + lnu r_ul_matrix = np.zeros( - (number_of_levels, number_of_levels, len(t_electrons)), - dtype=np.float64) + (number_of_levels, number_of_levels, len(t_electrons)), + dtype=np.float64, + ) r_ul_matrix_reshaped = r_ul_matrix.reshape( - (number_of_levels**2, len(t_electrons))) - r_ul_matrix_reshaped[r_ul_index] = A_uls[np.newaxis].T + \ - B_uls[np.newaxis].T * j_blues_filtered + (number_of_levels ** 2, len(t_electrons)) + ) + r_ul_matrix_reshaped[r_ul_index] = ( + A_uls[np.newaxis].T + B_uls[np.newaxis].T * j_blues_filtered + ) r_ul_matrix_reshaped[r_ul_index] *= beta_sobolevs_filtered r_lu_matrix = np.zeros_like(r_ul_matrix) r_lu_matrix_reshaped = r_lu_matrix.reshape( - (number_of_levels**2, len(t_electrons))) - r_lu_matrix_reshaped[r_lu_index] = B_lus[np.newaxis].T * \ - j_blues_filtered * beta_sobolevs_filtered + (number_of_levels ** 2, len(t_electrons)) + ) + r_lu_matrix_reshaped[r_lu_index] = ( + B_lus[np.newaxis].T * j_blues_filtered * beta_sobolevs_filtered + ) if atomic_data.collision_data is None: collision_matrix = np.zeros_like(r_ul_matrix) else: if previous_electron_densities is None: collision_matrix = np.zeros_like(r_ul_matrix) else: - collision_matrix = nlte_data.get_collision_matrix( - species, t_electrons - ) * previous_electron_densities.values + collision_matrix = ( + nlte_data.get_collision_matrix(species, t_electrons) + * previous_electron_densities.values + ) rates_matrix = r_lu_matrix + r_ul_matrix + collision_matrix for i in range(number_of_levels): rates_matrix[i, i] = -rates_matrix[:, i].sum(axis=0) @@ -206,23 +237,35 @@ def _main_nlte_calculation( for i in range(len(t_electrons)): try: level_boltzmann_factor = np.linalg.solve( - rates_matrix[:, :, i], x) + rates_matrix[:, :, i], x + ) except LinAlgError as e: - if e.message == 'Singular matrix': + if e.message == "Singular matrix": raise ValueError( - 'SingularMatrixError during solving of the ' - 'rate matrix. Does the atomic data contain ' - 'collision data?') + "SingularMatrixError during solving of the " + "rate matrix. Does the atomic data contain " + "collision data?" + ) else: raise e - general_level_boltzmann_factor[i].ix[species] = \ - level_boltzmann_factor * g.loc[species][0] / level_boltzmann_factor[0] + general_level_boltzmann_factor[i].ix[species] = ( + level_boltzmann_factor + * g.loc[species][0] + / level_boltzmann_factor[0] + ) return general_level_boltzmann_factor def _calculate_classical_nebular( - self, t_electrons, lines, atomic_data, - nlte_data, general_level_boltzmann_factor, j_blues, - previous_electron_densities, g): + self, + t_electrons, + lines, + atomic_data, + nlte_data, + general_level_boltzmann_factor, + j_blues, + previous_electron_densities, + g, + ): """ Performs NLTE calculations using the classical nebular treatment. All beta sobolev values taken as 1. @@ -230,19 +273,27 @@ def _calculate_classical_nebular( beta_sobolevs = 1.0 general_level_boltzmann_factor = self._main_nlte_calculation( - atomic_data, - nlte_data, - t_electrons, - j_blues, - beta_sobolevs, - general_level_boltzmann_factor, - previous_electron_densities, g) + atomic_data, + nlte_data, + t_electrons, + j_blues, + beta_sobolevs, + general_level_boltzmann_factor, + previous_electron_densities, + g, + ) return general_level_boltzmann_factor def _calculate_coronal_approximation( - self, t_electrons, lines, atomic_data, - nlte_data, general_level_boltzmann_factor, - previous_electron_densities, g): + self, + t_electrons, + lines, + atomic_data, + nlte_data, + general_level_boltzmann_factor, + previous_electron_densities, + g, + ): """ Performs NLTE calculations using the coronal approximation. All beta sobolev values taken as 1 and j_blues taken as 0. @@ -250,15 +301,29 @@ def _calculate_coronal_approximation( beta_sobolevs = 1.0 j_blues = 0.0 general_level_boltzmann_factor = self._main_nlte_calculation( - atomic_data, nlte_data, t_electrons, j_blues, - beta_sobolevs, general_level_boltzmann_factor, - previous_electron_densities, g) + atomic_data, + nlte_data, + t_electrons, + j_blues, + beta_sobolevs, + general_level_boltzmann_factor, + previous_electron_densities, + g, + ) return general_level_boltzmann_factor def _calculate_general( - self, t_electrons, lines, atomic_data, nlte_data, - general_level_boltzmann_factor, j_blues, - previous_beta_sobolev, previous_electron_densities, g): + self, + t_electrons, + lines, + atomic_data, + nlte_data, + general_level_boltzmann_factor, + j_blues, + previous_beta_sobolev, + previous_electron_densities, + g, + ): """ Full NLTE calculation without approximations. """ @@ -268,9 +333,15 @@ def _calculate_general( beta_sobolevs = previous_beta_sobolev general_level_boltzmann_factor = self._main_nlte_calculation( - atomic_data, nlte_data, t_electrons, j_blues, - beta_sobolevs, general_level_boltzmann_factor, - previous_electron_densities, g) + atomic_data, + nlte_data, + t_electrons, + j_blues, + beta_sobolevs, + general_level_boltzmann_factor, + previous_electron_densities, + g, + ) return general_level_boltzmann_factor @@ -294,13 +365,15 @@ class PartitionFunction(ProcessingPlasmaProperty): Indexed by atomic number, ion number. Columns are zones. """ - outputs = ('partition_function',) - latex_name = ('Z_{i,j}',) - latex_formula = ('\\sum_{k}bf_{i,j,k}',) + + outputs = ("partition_function",) + latex_name = ("Z_{i,j}",) + latex_formula = ("\\sum_{k}bf_{i,j,k}",) def calculate(self, level_boltzmann_factor): return level_boltzmann_factor.groupby( - level=['atomic_number', 'ion_number']).sum() + level=["atomic_number", "ion_number"] + ).sum() class ThermalLTEPartitionFunction(PartitionFunction): @@ -311,8 +384,9 @@ class ThermalLTEPartitionFunction(PartitionFunction): Indexed by atomic number, ion number. Columns are zones. """ - outputs = ('thermal_lte_partition_function',) - latex_name = ('Z_{i,j}(T_\\mathrm{e}',) + + outputs = ("thermal_lte_partition_function",) + latex_name = ("Z_{i,j}(T_\\mathrm{e}",) def calculate(self, thermal_lte_level_boltzmann_factor): return super(ThermalLTEPartitionFunction, self).calculate( diff --git a/tardis/plasma/properties/plasma_input.py b/tardis/plasma/properties/plasma_input.py index a9e0cb90f55..b5d6c55bd85 100644 --- a/tardis/plasma/properties/plasma_input.py +++ b/tardis/plasma/properties/plasma_input.py @@ -1,9 +1,20 @@ -from tardis.plasma.properties.base import (Input, ArrayInput, DataFrameInput) - -__all__ = ['TRadiative', 'DilutionFactor', 'AtomicData', 'Abundance', 'Density', - 'TimeExplosion', 'JBlueEstimator', 'LinkTRadTElectron', - 'HeliumTreatment', 'RInner', 'TInner', 'Volume', - 'ContinuumInteractionSpecies'] +from tardis.plasma.properties.base import Input, ArrayInput, DataFrameInput + +__all__ = [ + "TRadiative", + "DilutionFactor", + "AtomicData", + "Abundance", + "Density", + "TimeExplosion", + "JBlueEstimator", + "LinkTRadTElectron", + "HeliumTreatment", + "RInner", + "TInner", + "Volume", + "ContinuumInteractionSpecies", +] class TRadiative(ArrayInput): @@ -12,8 +23,9 @@ class TRadiative(ArrayInput): ---------- t_rad : Numpy Array, dtype float """ - outputs = ('t_rad',) - latex_name = ('T_{\\textrm{rad}}',) + + outputs = ("t_rad",) + latex_name = ("T_{\\textrm{rad}}",) class DilutionFactor(ArrayInput): @@ -24,8 +36,9 @@ class DilutionFactor(ArrayInput): Factor used in nebular ionisation / dilute excitation calculations to account for the dilution of the radiation field. """ - outputs = ('w',) - latex_name = ('W',) + + outputs = ("w",) + latex_name = ("W",) class AtomicData(Input): @@ -34,7 +47,8 @@ class AtomicData(Input): ---------- atomic_data : Object """ - outputs = ('atomic_data',) + + outputs = ("atomic_data",) class Abundance(Input): @@ -44,7 +58,8 @@ class Abundance(Input): abundance : Numpy array, dtype float Fractional abundance of elements """ - outputs = ('abundance',) + + outputs = ("abundance",) class Density(ArrayInput): @@ -54,8 +69,9 @@ class Density(ArrayInput): density : Numpy array, dtype float Total density values """ - outputs = ('density',) - latex_name = ('\\rho',) + + outputs = ("density",) + latex_name = ("\\rho",) class TimeExplosion(Input): @@ -65,8 +81,9 @@ class TimeExplosion(Input): time_explosion : Float Time since explosion in seconds """ - outputs = ('time_explosion',) - latex_name = ('t_{\\textrm{exp}}',) + + outputs = ("time_explosion",) + latex_name = ("t_{\\textrm{exp}}",) class JBlueEstimator(ArrayInput): @@ -75,8 +92,9 @@ class JBlueEstimator(ArrayInput): ---------- j_blue_estimators : Numpy array """ - outputs = ('j_blue_estimators',) - latex_name = ('J_{\\textrm{blue-estimator}}',) + + outputs = ("j_blue_estimators",) + latex_name = ("J_{\\textrm{blue-estimator}}",) class LinkTRadTElectron(Input): @@ -87,24 +105,25 @@ class LinkTRadTElectron(Input): Value used for estimate of electron temperature. Default is 0.9. """ - outputs = ('link_t_rad_t_electron',) - latex_name = ('T_{\\textrm{electron}}/T_{\\textrm{rad}}',) + + outputs = ("link_t_rad_t_electron",) + latex_name = ("T_{\\textrm{electron}}/T_{\\textrm{rad}}",) class HeliumTreatment(Input): - outputs = ('helium_treatment',) + outputs = ("helium_treatment",) class RInner(Input): - outputs = ('r_inner',) + outputs = ("r_inner",) class TInner(Input): - outputs = ('t_inner',) + outputs = ("t_inner",) class Volume(Input): - outputs = ('volume',) + outputs = ("volume",) class ContinuumInteractionSpecies(Input): @@ -115,4 +134,5 @@ class ContinuumInteractionSpecies(Input): Atomic and ion numbers of elements for which continuum interactions (radiative/collisional ionization and recombination) are treated """ - outputs = ('continuum_interaction_species',) + + outputs = ("continuum_interaction_species",) diff --git a/tardis/plasma/properties/properties.py b/tardis/plasma/properties/properties.py deleted file mode 100644 index 43e6f024367..00000000000 --- a/tardis/plasma/properties/properties.py +++ /dev/null @@ -1,3 +0,0 @@ -#### Importing properties from other modules ######## - -###################################################### \ No newline at end of file diff --git a/tardis/plasma/properties/property_collections.py b/tardis/plasma/properties/property_collections.py index 3f943270b3f..4b8eb7fa608 100644 --- a/tardis/plasma/properties/property_collections.py +++ b/tardis/plasma/properties/property_collections.py @@ -1,43 +1,93 @@ from tardis.plasma.properties import * + class PlasmaPropertyCollection(list): pass -basic_inputs = PlasmaPropertyCollection([TRadiative, Abundance, Density, - TimeExplosion, AtomicData, DilutionFactor, LinkTRadTElectron, - HeliumTreatment, ContinuumInteractionSpecies]) -basic_properties = PlasmaPropertyCollection([BetaRadiation, - Levels, Lines, AtomicMass, PartitionFunction, - GElectron, IonizationData, NumberDensity, LinesLowerLevelIndex, - LinesUpperLevelIndex, TauSobolev, - StimulatedEmissionFactor, SelectedAtoms, ElectronTemperature]) + +basic_inputs = PlasmaPropertyCollection( + [ + TRadiative, + Abundance, + Density, + TimeExplosion, + AtomicData, + DilutionFactor, + LinkTRadTElectron, + HeliumTreatment, + ContinuumInteractionSpecies, + ] +) +basic_properties = PlasmaPropertyCollection( + [ + BetaRadiation, + Levels, + Lines, + AtomicMass, + PartitionFunction, + GElectron, + IonizationData, + NumberDensity, + LinesLowerLevelIndex, + LinesUpperLevelIndex, + TauSobolev, + StimulatedEmissionFactor, + SelectedAtoms, + ElectronTemperature, + ] +) lte_ionization_properties = PlasmaPropertyCollection([PhiSahaLTE]) lte_excitation_properties = PlasmaPropertyCollection([LevelBoltzmannFactorLTE]) -macro_atom_properties = PlasmaPropertyCollection([BetaSobolev, - TransitionProbabilities]) -nebular_ionization_properties = PlasmaPropertyCollection([PhiSahaNebular, - ZetaData, BetaElectron, RadiationFieldCorrection]) -dilute_lte_excitation_properties = PlasmaPropertyCollection([ - LevelBoltzmannFactorDiluteLTE]) +macro_atom_properties = PlasmaPropertyCollection( + [BetaSobolev, TransitionProbabilities] +) +nebular_ionization_properties = PlasmaPropertyCollection( + [PhiSahaNebular, ZetaData, BetaElectron, RadiationFieldCorrection] +) +dilute_lte_excitation_properties = PlasmaPropertyCollection( + [LevelBoltzmannFactorDiluteLTE] +) non_nlte_properties = PlasmaPropertyCollection([LevelBoltzmannFactorNoNLTE]) -nlte_properties = PlasmaPropertyCollection([ - LevelBoltzmannFactorNLTE, NLTEData, PreviousElectronDensities, - PreviousBetaSobolev, BetaSobolev]) -helium_nlte_properties = PlasmaPropertyCollection([HeliumNLTE, - RadiationFieldCorrection, ZetaData, - BetaElectron, LevelNumberDensityHeNLTE, IonNumberDensityHeNLTE]) -helium_lte_properties = PlasmaPropertyCollection([LevelNumberDensity, - IonNumberDensity]) -helium_numerical_nlte_properties = PlasmaPropertyCollection([ - HeliumNumericalNLTE]) -detailed_j_blues_inputs = PlasmaPropertyCollection([JBluesEstimator, RInner, - TInner, Volume]) -detailed_j_blues_properties = PlasmaPropertyCollection([JBluesDetailed, - JBluesNormFactor, - LuminosityInner, - TimeSimulation]) +nlte_properties = PlasmaPropertyCollection( + [ + LevelBoltzmannFactorNLTE, + NLTEData, + PreviousElectronDensities, + PreviousBetaSobolev, + BetaSobolev, + ] +) +helium_nlte_properties = PlasmaPropertyCollection( + [ + HeliumNLTE, + RadiationFieldCorrection, + ZetaData, + BetaElectron, + LevelNumberDensityHeNLTE, + IonNumberDensityHeNLTE, + ] +) +helium_lte_properties = PlasmaPropertyCollection( + [LevelNumberDensity, IonNumberDensity] +) +helium_numerical_nlte_properties = PlasmaPropertyCollection( + [HeliumNumericalNLTE] +) +detailed_j_blues_inputs = PlasmaPropertyCollection( + [JBluesEstimator, RInner, TInner, Volume] +) +detailed_j_blues_properties = PlasmaPropertyCollection( + [JBluesDetailed, JBluesNormFactor, LuminosityInner, TimeSimulation] +) continuum_interaction_properties = PlasmaPropertyCollection( - [PhotoIonizationData, SpontRecombRateCoeff, - ThermalLevelBoltzmannFactorLTE, ThermalLTEPartitionFunction, BetaElectron, - ThermalGElectron, ThermalPhiSahaLTE, SahaFactor] + [ + PhotoIonizationData, + SpontRecombRateCoeff, + ThermalLevelBoltzmannFactorLTE, + ThermalLTEPartitionFunction, + BetaElectron, + ThermalGElectron, + ThermalPhiSahaLTE, + SahaFactor, + ] ) diff --git a/tardis/plasma/properties/radiative_properties.py b/tardis/plasma/properties/radiative_properties.py index 3d26f8cd9ec..21af5713a98 100644 --- a/tardis/plasma/properties/radiative_properties.py +++ b/tardis/plasma/properties/radiative_properties.py @@ -12,8 +12,13 @@ logger = logging.getLogger(__name__) -__all__ = ['StimulatedEmissionFactor', 'TauSobolev', 'BetaSobolev', - 'TransitionProbabilities'] +__all__ = [ + "StimulatedEmissionFactor", + "TauSobolev", + "BetaSobolev", + "TransitionProbabilities", +] + class StimulatedEmissionFactor(ProcessingPlasmaProperty): """ @@ -22,8 +27,9 @@ class StimulatedEmissionFactor(ProcessingPlasmaProperty): stimulated_emission_factor : Numpy Array, dtype float Indexed by lines, columns as zones. """ - outputs = ('stimulated_emission_factor',) - latex_formula = ('1-\\dfrac{g_{lower}n_{upper}}{g_{upper}n_{lower}}',) + + outputs = ("stimulated_emission_factor",) + latex_formula = ("1-\\dfrac{g_{lower}n_{upper}}{g_{upper}n_{lower}}",) def __init__(self, plasma_parent=None, nlte_species=None): super(StimulatedEmissionFactor, self).__init__(plasma_parent) @@ -33,52 +39,74 @@ def __init__(self, plasma_parent=None, nlte_species=None): def get_g_lower(self, g, lines_lower_level_index): if self._g_lower is None: - g_lower = np.array(g.iloc[lines_lower_level_index], - dtype=np.float64) + g_lower = np.array( + g.iloc[lines_lower_level_index], dtype=np.float64 + ) self._g_lower = g_lower[np.newaxis].T return self._g_lower def get_g_upper(self, g, lines_upper_level_index): if self._g_upper is None: - g_upper = np.array(g.iloc[lines_upper_level_index], - dtype=np.float64) + g_upper = np.array( + g.iloc[lines_upper_level_index], dtype=np.float64 + ) self._g_upper = g_upper[np.newaxis].T return self._g_upper def get_metastable_upper(self, metastability, lines_upper_level_index): - if getattr(self, '_meta_stable_upper', None) is None: + if getattr(self, "_meta_stable_upper", None) is None: self._meta_stable_upper = metastability.values[ - lines_upper_level_index][np.newaxis].T + lines_upper_level_index + ][np.newaxis].T return self._meta_stable_upper - def calculate(self, g, level_number_density, lines_lower_level_index, - lines_upper_level_index, metastability, lines): - n_lower = level_number_density.values.take(lines_lower_level_index, - axis=0, mode='raise') - n_upper = level_number_density.values.take(lines_upper_level_index, - axis=0, mode='raise') + def calculate( + self, + g, + level_number_density, + lines_lower_level_index, + lines_upper_level_index, + metastability, + lines, + ): + n_lower = level_number_density.values.take( + lines_lower_level_index, axis=0, mode="raise" + ) + n_upper = level_number_density.values.take( + lines_upper_level_index, axis=0, mode="raise" + ) g_lower = self.get_g_lower(g, lines_lower_level_index) g_upper = self.get_g_upper(g, lines_upper_level_index) - meta_stable_upper = self.get_metastable_upper(metastability, - lines_upper_level_index) + meta_stable_upper = self.get_metastable_upper( + metastability, lines_upper_level_index + ) - stimulated_emission_factor = ne.evaluate('1 - ((g_lower * n_upper) / ' - '(g_upper * n_lower))') + stimulated_emission_factor = ne.evaluate( + "1 - ((g_lower * n_upper) / " "(g_upper * n_lower))" + ) stimulated_emission_factor[n_lower == 0.0] = 0.0 - stimulated_emission_factor[np.isneginf(stimulated_emission_factor)]\ - = 0.0 - stimulated_emission_factor[meta_stable_upper & - (stimulated_emission_factor < 0)] = 0.0 + stimulated_emission_factor[ + np.isneginf(stimulated_emission_factor) + ] = 0.0 + stimulated_emission_factor[ + meta_stable_upper & (stimulated_emission_factor < 0) + ] = 0.0 if self.nlte_species: - nlte_lines_mask = lines.reset_index().apply( - lambda row: - (row.atomic_number, row.ion_number) in self.nlte_species, - axis=1 - ).values - stimulated_emission_factor[(stimulated_emission_factor < 0) & - nlte_lines_mask[np.newaxis].T] = 0.0 + nlte_lines_mask = ( + lines.reset_index() + .apply( + lambda row: (row.atomic_number, row.ion_number) + in self.nlte_species, + axis=1, + ) + .values + ) + stimulated_emission_factor[ + (stimulated_emission_factor < 0) & nlte_lines_mask[np.newaxis].T + ] = 0.0 return stimulated_emission_factor + class TauSobolev(ProcessingPlasmaProperty): """ Attributes @@ -87,35 +115,65 @@ class TauSobolev(ProcessingPlasmaProperty): Sobolev optical depth for each line. Indexed by line. Columns as zones. """ - outputs = ('tau_sobolevs',) - latex_name = ('\\tau_{\\textrm{sobolev}}',) - latex_formula = ('\\dfrac{\\pi e^{2}}{m_{e} c}f_{lu}\\lambda t_{exp}\ - n_{lower} \\Big(1-\\dfrac{g_{lower}n_{upper}}{g_{upper}n_{lower}}\\Big)',) + + outputs = ("tau_sobolevs",) + latex_name = ("\\tau_{\\textrm{sobolev}}",) + latex_formula = ( + "\\dfrac{\\pi e^{2}}{m_{e} c}f_{lu}\\lambda t_{exp}\ + n_{lower} \\Big(1-\\dfrac{g_{lower}n_{upper}}{g_{upper}n_{lower}}\\Big)", + ) def __init__(self, plasma_parent): super(TauSobolev, self).__init__(plasma_parent) - self.sobolev_coefficient = (((np.pi * const.e.gauss ** 2) / - (const.m_e.cgs * const.c.cgs)) - * u.cm * u.s / u.cm**3).to(1).value - - def calculate(self, lines, level_number_density, lines_lower_level_index, - time_explosion, stimulated_emission_factor, j_blues, - f_lu, wavelength_cm): + self.sobolev_coefficient = ( + ( + ((np.pi * const.e.gauss ** 2) / (const.m_e.cgs * const.c.cgs)) + * u.cm + * u.s + / u.cm ** 3 + ) + .to(1) + .value + ) + + def calculate( + self, + lines, + level_number_density, + lines_lower_level_index, + time_explosion, + stimulated_emission_factor, + j_blues, + f_lu, + wavelength_cm, + ): f_lu = f_lu.values[np.newaxis].T wavelength = wavelength_cm.values[np.newaxis].T - n_lower = level_number_density.values.take(lines_lower_level_index, - axis=0, mode='raise') - tau_sobolevs = (self.sobolev_coefficient * f_lu * wavelength * - time_explosion * n_lower * stimulated_emission_factor) - - if (np.any(np.isnan(tau_sobolevs)) or - np.any(np.isinf(np.abs(tau_sobolevs)))): + n_lower = level_number_density.values.take( + lines_lower_level_index, axis=0, mode="raise" + ) + tau_sobolevs = ( + self.sobolev_coefficient + * f_lu + * wavelength + * time_explosion + * n_lower + * stimulated_emission_factor + ) + + if np.any(np.isnan(tau_sobolevs)) or np.any( + np.isinf(np.abs(tau_sobolevs)) + ): raise ValueError( - 'Some tau_sobolevs are nan, inf, -inf in tau_sobolevs.' - ' Something went wrong!') + "Some tau_sobolevs are nan, inf, -inf in tau_sobolevs." + " Something went wrong!" + ) - return pd.DataFrame(tau_sobolevs, index=lines.index, - columns=np.array(level_number_density.columns)) + return pd.DataFrame( + tau_sobolevs, + index=lines.index, + columns=np.array(level_number_density.columns), + ) class BetaSobolev(ProcessingPlasmaProperty): @@ -124,23 +182,23 @@ class BetaSobolev(ProcessingPlasmaProperty): ---------- beta_sobolev : Numpy Array, dtype float """ - outputs = ('beta_sobolev',) - latex_name = ('\\beta_{\\textrm{sobolev}}',) + + outputs = ("beta_sobolev",) + latex_name = ("\\beta_{\\textrm{sobolev}}",) def calculate(self, tau_sobolevs): - if getattr(self, 'beta_sobolev', None) is None: - initial = 0. + if getattr(self, "beta_sobolev", None) is None: + initial = 0.0 else: initial = self.beta_sobolev beta_sobolev = pd.DataFrame( - initial, - index=tau_sobolevs.index, - columns=tau_sobolevs.columns - ) + initial, index=tau_sobolevs.index, columns=tau_sobolevs.columns + ) - self.calculate_beta_sobolev(tau_sobolevs.values.ravel(), - beta_sobolev.values.ravel()) + self.calculate_beta_sobolev( + tau_sobolevs.values.ravel(), beta_sobolev.values.ravel() + ) return beta_sobolev @staticmethod @@ -148,12 +206,13 @@ def calculate(self, tau_sobolevs): def calculate_beta_sobolev(tau_sobolevs, beta_sobolevs): for i in prange(len(tau_sobolevs)): if tau_sobolevs[i] > 1e3: - beta_sobolevs[i] = tau_sobolevs[i]**-1 + beta_sobolevs[i] = tau_sobolevs[i] ** -1 elif tau_sobolevs[i] < 1e-4: beta_sobolevs[i] = 1 - 0.5 * tau_sobolevs[i] else: beta_sobolevs[i] = (1 - np.exp(-tau_sobolevs[i])) / ( - tau_sobolevs[i]) + tau_sobolevs[i] + ) return beta_sobolevs @@ -163,15 +222,22 @@ class TransitionProbabilities(ProcessingPlasmaProperty): ---------- transition_probabilities : Pandas DataFrame, dtype float """ - outputs = ('transition_probabilities',) + + outputs = ("transition_probabilities",) def __init__(self, plasma_parent): super(TransitionProbabilities, self).__init__(plasma_parent) self.initialize = True - def calculate(self, atomic_data, beta_sobolev, j_blues, - stimulated_emission_factor, tau_sobolevs): - #I wonder why? + def calculate( + self, + atomic_data, + beta_sobolev, + j_blues, + stimulated_emission_factor, + tau_sobolevs, + ): + # I wonder why? # Not sure who wrote this but the answer is that when the plasma is # first initialised (before the first iteration, without temperature # values etc.) there are no j_blues values so this just prevents @@ -180,82 +246,107 @@ def calculate(self, atomic_data, beta_sobolev, j_blues, return None macro_atom_data = self._get_macro_atom_data(atomic_data) if self.initialize: - self.initialize_macro_atom_transition_type_filters(atomic_data, - macro_atom_data) - self.transition_probability_coef = ( - self._get_transition_probability_coefs(macro_atom_data)) + self.initialize_macro_atom_transition_type_filters( + atomic_data, macro_atom_data + ) + self.transition_probability_coef = self._get_transition_probability_coefs( + macro_atom_data + ) self.initialize = False transition_probabilities = self._calculate_transition_probability( - macro_atom_data, - beta_sobolev, - j_blues, - stimulated_emission_factor) - transition_probabilities = pd.DataFrame(transition_probabilities, + macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor + ) + transition_probabilities = pd.DataFrame( + transition_probabilities, index=macro_atom_data.transition_line_id, - columns=tau_sobolevs.columns) + columns=tau_sobolevs.columns, + ) return transition_probabilities - def _calculate_transition_probability(self, macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor): - transition_probabilities = np.empty((self.transition_probability_coef.shape[0], beta_sobolev.shape[1])) - #trans_old = self.calculate_transition_probabilities(macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor) + def _calculate_transition_probability( + self, macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor + ): + transition_probabilities = np.empty( + (self.transition_probability_coef.shape[0], beta_sobolev.shape[1]) + ) + # trans_old = self.calculate_transition_probabilities(macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor) transition_type = macro_atom_data.transition_type.values lines_idx = macro_atom_data.lines_idx.values tpos = macro_atom_data.transition_probability.values macro_atom.calculate_transition_probabilities( - tpos, - beta_sobolev.values, - j_blues.values, - stimulated_emission_factor, - transition_type, - lines_idx, - self.block_references, - transition_probabilities) + tpos, + beta_sobolev.values, + j_blues.values, + stimulated_emission_factor, + transition_type, + lines_idx, + self.block_references, + transition_probabilities, + ) return transition_probabilities - def calculate_transition_probabilities(self, macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor): - transition_probabilities = self.prepare_transition_probabilities(macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor) + def calculate_transition_probabilities( + self, macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor + ): + transition_probabilities = self.prepare_transition_probabilities( + macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor + ) return transition_probabilities - def initialize_macro_atom_transition_type_filters(self, atomic_data, - macro_atom_data): - self.transition_up_filter = (macro_atom_data.transition_type.values - == 1) + def initialize_macro_atom_transition_type_filters( + self, atomic_data, macro_atom_data + ): + self.transition_up_filter = macro_atom_data.transition_type.values == 1 self.transition_up_line_filter = macro_atom_data.lines_idx.values[ - self.transition_up_filter] - self.block_references = np.hstack(( - atomic_data.macro_atom_references.block_references, - len(macro_atom_data))) + self.transition_up_filter + ] + self.block_references = np.hstack( + ( + atomic_data.macro_atom_references.block_references, + len(macro_atom_data), + ) + ) @staticmethod def _get_transition_probability_coefs(macro_atom_data): return macro_atom_data.transition_probability.values[np.newaxis].T - def prepare_transition_probabilities(self, macro_atom_data, beta_sobolev, - j_blues, stimulated_emission_factor): + def prepare_transition_probabilities( + self, macro_atom_data, beta_sobolev, j_blues, stimulated_emission_factor + ): current_beta_sobolev = beta_sobolev.values.take( - macro_atom_data.lines_idx.values, axis=0, mode='raise') - transition_probabilities = self.transition_probability_coef * current_beta_sobolev - j_blues = j_blues.take(self.transition_up_line_filter, axis=0, - mode='raise') + macro_atom_data.lines_idx.values, axis=0, mode="raise" + ) + transition_probabilities = ( + self.transition_probability_coef * current_beta_sobolev + ) + j_blues = j_blues.take( + self.transition_up_line_filter, axis=0, mode="raise" + ) macro_stimulated_emission = stimulated_emission_factor.take( - self.transition_up_line_filter, axis=0, mode='raise') - transition_probabilities[self.transition_up_filter] *= (j_blues * macro_stimulated_emission) + self.transition_up_line_filter, axis=0, mode="raise" + ) + transition_probabilities[self.transition_up_filter] *= ( + j_blues * macro_stimulated_emission + ) return transition_probabilities def _normalize_transition_probabilities(self, transition_probabilities): macro_atom.normalize_transition_probabilities( - transition_probabilities, self.block_references) + transition_probabilities, self.block_references + ) def _new_normalize_transition_probabilities(self, transition_probabilites): for i, start_id in enumerate(self.block_references[:-1]): end_id = self.block_references[i + 1] block = transition_probabilites[start_id:end_id] transition_probabilites[start_id:end_id] *= 1 / ne.evaluate( - 'sum(block, 0)') + "sum(block, 0)" + ) @staticmethod def _get_macro_atom_data(atomic_data): - try: - return atomic_data.macro_atom_data - except: - return atomic_data.macro_atom_data_all + try: + return atomic_data.macro_atom_data + except: + return atomic_data.macro_atom_data_all diff --git a/tardis/plasma/setup_package.py b/tardis/plasma/setup_package.py index a36ef1659a3..0a33ad7fbb9 100644 --- a/tardis/plasma/setup_package.py +++ b/tardis/plasma/setup_package.py @@ -3,23 +3,47 @@ from astropy_helpers.distutils_helpers import get_distutils_option import numpy as np + def get_package_data(): - return {'tardis.plasma.tests':['data/*.dat', 'data/*.txt', 'data/*.yml', 'data/*.h5', 'data/*.dot', 'data/*.tex']} + return { + "tardis.plasma.tests": [ + "data/*.dat", + "data/*.txt", + "data/*.yml", + "data/*.h5", + "data/*.dot", + "data/*.tex", + ] + } + -if get_distutils_option('with_openmp', ['build', 'install', 'develop']) is not None: - compile_args = ['-fopenmp', '-W', '-Wall', '-Wmissing-prototypes', '-std=c99'] - link_args = ['-fopenmp'] - define_macros = [('WITHOPENMP', None)] +if ( + get_distutils_option("with_openmp", ["build", "install", "develop"]) + is not None +): + compile_args = [ + "-fopenmp", + "-W", + "-Wall", + "-Wmissing-prototypes", + "-std=c99", + ] + link_args = ["-fopenmp"] + define_macros = [("WITHOPENMP", None)] else: - compile_args = ['-W', '-Wall', '-Wmissing-prototypes', '-std=c99'] + compile_args = ["-W", "-Wall", "-Wmissing-prototypes", "-std=c99"] link_args = [] define_macros = [] def get_extensions(): - sources = ['tardis/plasma/properties/util/macro_atom.pyx'] - return [Extension('tardis.plasma.properties.util.macro_atom', sources, - include_dirs=['numpy'], - extra_compile_args=compile_args, - extra_link_args=link_args)] - + sources = ["tardis/plasma/properties/util/macro_atom.pyx"] + return [ + Extension( + "tardis.plasma.properties.util.macro_atom", + sources, + include_dirs=["numpy"], + extra_compile_args=compile_args, + extra_link_args=link_args, + ) + ] diff --git a/tardis/plasma/standard_plasmas.py b/tardis/plasma/standard_plasmas.py index 08a3cae1070..b92b4b7f0cf 100644 --- a/tardis/plasma/standard_plasmas.py +++ b/tardis/plasma/standard_plasmas.py @@ -8,24 +8,35 @@ from tardis.io.config_reader import ConfigurationError from tardis.util.base import species_string_to_tuple from tardis.plasma import BasePlasma -from tardis.plasma.properties.property_collections import (basic_inputs, - basic_properties, lte_excitation_properties, lte_ionization_properties, - macro_atom_properties, dilute_lte_excitation_properties, - nebular_ionization_properties, non_nlte_properties, - nlte_properties, helium_nlte_properties, helium_numerical_nlte_properties, - helium_lte_properties, detailed_j_blues_properties, - detailed_j_blues_inputs, continuum_interaction_properties) +from tardis.plasma.properties.property_collections import ( + basic_inputs, + basic_properties, + lte_excitation_properties, + lte_ionization_properties, + macro_atom_properties, + dilute_lte_excitation_properties, + nebular_ionization_properties, + non_nlte_properties, + nlte_properties, + helium_nlte_properties, + helium_numerical_nlte_properties, + helium_lte_properties, + detailed_j_blues_properties, + detailed_j_blues_inputs, + continuum_interaction_properties, +) from tardis.plasma.exceptions import PlasmaConfigError from tardis.plasma.properties import ( - LevelBoltzmannFactorNLTE, - JBluesBlackBody, - JBluesDiluteBlackBody, - JBluesDetailed, - RadiationFieldCorrection, - StimulatedEmissionFactor, - HeliumNumericalNLTE, - IonNumberDensity) + LevelBoltzmannFactorNLTE, + JBluesBlackBody, + JBluesDiluteBlackBody, + JBluesDetailed, + RadiationFieldCorrection, + StimulatedEmissionFactor, + HeliumNumericalNLTE, + IonNumberDensity, +) logger = logging.getLogger(__name__) @@ -49,133 +60,153 @@ def assemble_plasma(config, model, atom_data=None): """ # Convert the nlte species list to a proper format. - nlte_species = [species_string_to_tuple(s) for s in - config.plasma.nlte.species] + nlte_species = [ + species_string_to_tuple(s) for s in config.plasma.nlte.species + ] # Convert the continuum interaction species list to a proper format. continuum_interaction_species = [ - species_string_to_tuple(s) for s in - config.plasma.continuum_interaction.species + species_string_to_tuple(s) + for s in config.plasma.continuum_interaction.species ] continuum_interaction_species = pd.MultiIndex.from_tuples( - continuum_interaction_species, names=['atomic_number', 'ion_number'] + continuum_interaction_species, names=["atomic_number", "ion_number"] ) if atom_data is None: - if 'atom_data' in config: + if "atom_data" in config: if os.path.isabs(config.atom_data): atom_data_fname = config.atom_data else: - atom_data_fname = os.path.join(config.config_dirname, - config.atom_data) + atom_data_fname = os.path.join( + config.config_dirname, config.atom_data + ) else: - raise ValueError('No atom_data option found in the configuration.') + raise ValueError("No atom_data option found in the configuration.") - logger.info('Reading Atomic Data from %s', atom_data_fname) + logger.info("Reading Atomic Data from %s", atom_data_fname) try: atom_data = AtomData.from_hdf(atom_data_fname) except TypeError as e: - print(e, 'Error might be from the use of an old-format of the atomic database, \n' - 'please see https://github.com/tardis-sn/tardis-refdata/tree/master/atom_data' - ',for the most recent version.') + print( + e, + "Error might be from the use of an old-format of the atomic database, \n" + "please see https://github.com/tardis-sn/tardis-refdata/tree/master/atom_data" + ",for the most recent version.", + ) raise atom_data.prepare_atom_data( model.abundance.index, line_interaction_type=config.plasma.line_interaction_type, - nlte_species=nlte_species) + nlte_species=nlte_species, + ) # Check if continuum interaction species are in selected_atoms continuum_atoms = continuum_interaction_species.get_level_values( - 'atomic_number' + "atomic_number" ) continuum_atoms_in_selected_atoms = np.all( continuum_atoms.isin(atom_data.selected_atomic_numbers) ) if not continuum_atoms_in_selected_atoms: - raise ConfigurationError('Not all continuum interaction species ' - 'belong to atoms that have been specified ' - 'in the configuration.') - - kwargs = dict(t_rad=model.t_radiative, abundance=model.abundance, - density=model.density, atomic_data=atom_data, - time_explosion=model.time_explosion, - w=model.dilution_factor, link_t_rad_t_electron=0.9, - continuum_interaction_species=continuum_interaction_species) + raise ConfigurationError( + "Not all continuum interaction species " + "belong to atoms that have been specified " + "in the configuration." + ) + + kwargs = dict( + t_rad=model.t_radiative, + abundance=model.abundance, + density=model.density, + atomic_data=atom_data, + time_explosion=model.time_explosion, + w=model.dilution_factor, + link_t_rad_t_electron=0.9, + continuum_interaction_species=continuum_interaction_species, + ) plasma_modules = basic_inputs + basic_properties property_kwargs = {} if config.plasma.continuum_interaction.species: plasma_modules += continuum_interaction_properties - if config.plasma.radiative_rates_type == 'blackbody': + if config.plasma.radiative_rates_type == "blackbody": plasma_modules.append(JBluesBlackBody) - elif config.plasma.radiative_rates_type == 'dilute-blackbody': + elif config.plasma.radiative_rates_type == "dilute-blackbody": plasma_modules.append(JBluesDiluteBlackBody) - elif config.plasma.radiative_rates_type == 'detailed': + elif config.plasma.radiative_rates_type == "detailed": plasma_modules += detailed_j_blues_properties + detailed_j_blues_inputs - kwargs.update(r_inner=model.r_inner, - t_inner=model.t_inner, - volume=model.volume, - j_blue_estimator=None) - property_kwargs[JBluesDetailed] = {'w_epsilon': config.plasma.w_epsilon} + kwargs.update( + r_inner=model.r_inner, + t_inner=model.t_inner, + volume=model.volume, + j_blue_estimator=None, + ) + property_kwargs[JBluesDetailed] = {"w_epsilon": config.plasma.w_epsilon} else: - raise ValueError('radiative_rates_type type unknown - %s', - config.plasma.radiative_rates_type) + raise ValueError(f"radiative_rates_type type unknown - {config.plasma.radiative_rates_type}") - if config.plasma.excitation == 'lte': + if config.plasma.excitation == "lte": plasma_modules += lte_excitation_properties - elif config.plasma.excitation == 'dilute-lte': + elif config.plasma.excitation == "dilute-lte": plasma_modules += dilute_lte_excitation_properties - if config.plasma.ionization == 'lte': + if config.plasma.ionization == "lte": plasma_modules += lte_ionization_properties - elif config.plasma.ionization == 'nebular': + elif config.plasma.ionization == "nebular": plasma_modules += nebular_ionization_properties if nlte_species: plasma_modules += nlte_properties nlte_conf = config.plasma.nlte - plasma_modules.append( - LevelBoltzmannFactorNLTE.from_config(nlte_conf) - ) + plasma_modules.append(LevelBoltzmannFactorNLTE.from_config(nlte_conf)) property_kwargs[StimulatedEmissionFactor] = dict( - nlte_species=nlte_species) + nlte_species=nlte_species + ) else: plasma_modules += non_nlte_properties - if config.plasma.line_interaction_type in ('downbranch', 'macroatom'): + if config.plasma.line_interaction_type in ("downbranch", "macroatom"): plasma_modules += macro_atom_properties - if 'delta_treatment' in config.plasma: + if "delta_treatment" in config.plasma: property_kwargs[RadiationFieldCorrection] = dict( - delta_treatment=config.plasma.delta_treatment) + delta_treatment=config.plasma.delta_treatment + ) - if config.plasma.helium_treatment == 'recomb-nlte': + if config.plasma.helium_treatment == "recomb-nlte": plasma_modules += helium_nlte_properties - elif config.plasma.helium_treatment == 'numerical-nlte': + elif config.plasma.helium_treatment == "numerical-nlte": plasma_modules += helium_numerical_nlte_properties # TODO: See issue #633 - if config.plasma.heating_rate_data_file in ['none', None]: - raise PlasmaConfigError('Heating rate data file not specified') + if config.plasma.heating_rate_data_file in ["none", None]: + raise PlasmaConfigError("Heating rate data file not specified") else: property_kwargs[HeliumNumericalNLTE] = dict( - heating_rate_data_file=config.plasma.heating_rate_data_file) + heating_rate_data_file=config.plasma.heating_rate_data_file + ) else: plasma_modules += helium_lte_properties if model._electron_densities: electron_densities = pd.Series(model._electron_densities.cgs.value) - if config.plasma.helium_treatment == 'numerical-nlte': + if config.plasma.helium_treatment == "numerical-nlte": property_kwargs[IonNumberDensityHeNLTE] = dict( - electron_densities=electron_densities) + electron_densities=electron_densities + ) else: property_kwargs[IonNumberDensity] = dict( - electron_densities=electron_densities) + electron_densities=electron_densities + ) - kwargs['helium_treatment'] = config.plasma.helium_treatment + kwargs["helium_treatment"] = config.plasma.helium_treatment - plasma = BasePlasma(plasma_properties=plasma_modules, - property_kwargs=property_kwargs, **kwargs) + plasma = BasePlasma( + plasma_properties=plasma_modules, + property_kwargs=property_kwargs, + **kwargs, + ) return plasma diff --git a/tardis/plasma/tests/test_complete_plasmas.py b/tardis/plasma/tests/test_complete_plasmas.py index 25f3b8cdb36..fc5ffac6086 100644 --- a/tardis/plasma/tests/test_complete_plasmas.py +++ b/tardis/plasma/tests/test_complete_plasmas.py @@ -9,124 +9,158 @@ from tardis.simulation import Simulation ionization = [ - {'ionization': 'nebular'}, - {'ionization': 'lte'}, + {"ionization": "nebular"}, + {"ionization": "lte"}, ] -excitation = [ - {'excitation': 'lte'}, - {'excitation': 'dilute-lte'} -] +excitation = [{"excitation": "lte"}, {"excitation": "dilute-lte"}] radiative_rates_type = [ - {'radiative_rates_type': 'detailed', 'w_epsilon': 1.0e-10}, - {'radiative_rates_type': 'detailed'}, - {'radiative_rates_type': 'blackbody'}, - {'radiative_rates_type': 'dilute-blackbody'} + {"radiative_rates_type": "detailed", "w_epsilon": 1.0e-10}, + {"radiative_rates_type": "detailed"}, + {"radiative_rates_type": "blackbody"}, + {"radiative_rates_type": "dilute-blackbody"}, ] line_interaction_type = [ - {'line_interaction_type': 'scatter'}, - {'line_interaction_type': 'macroatom'}, - {'line_interaction_type': 'downbranch'} + {"line_interaction_type": "scatter"}, + {"line_interaction_type": "macroatom"}, + {"line_interaction_type": "downbranch"}, ] disable_electron_scattering = [ - {'disable_electron_scattering': True}, - {'disable_electron_scattering': False} + {"disable_electron_scattering": True}, + {"disable_electron_scattering": False}, ] nlte = [ - {'nlte': {'species': ['He I'], 'coronal_approximation': True}}, - {'nlte': {'species': ['He I'], 'classical_nebular': True}}, - {'nlte': {'species': ['He I']}} + {"nlte": {"species": ["He I"], "coronal_approximation": True}}, + {"nlte": {"species": ["He I"], "classical_nebular": True}}, + {"nlte": {"species": ["He I"]}}, ] -initial_t_inner = [ - {'initial_t_inner': '10000 K'} -] +initial_t_inner = [{"initial_t_inner": "10000 K"}] -initial_t_rad = [ - {'initial_t_rad': '10000 K'} -] +initial_t_rad = [{"initial_t_rad": "10000 K"}] helium_treatment = [ - {'helium_treatment': 'recomb-nlte'}, - {'helium_treatment': 'recomb-nlte', 'delta_treatment': 0.5} + {"helium_treatment": "recomb-nlte"}, + {"helium_treatment": "recomb-nlte", "delta_treatment": 0.5}, ] config_list = ( - ionization + excitation + radiative_rates_type + - line_interaction_type + disable_electron_scattering + nlte + - initial_t_inner + initial_t_rad + helium_treatment) + ionization + + excitation + + radiative_rates_type + + line_interaction_type + + disable_electron_scattering + + nlte + + initial_t_inner + + initial_t_rad + + helium_treatment +) def idfn(fixture_value): - ''' + """ This function creates a string from a dictionary. We use it to obtain a readable name for the config fixture. - ''' - return str('-'.join([ - '{}:{}'.format(k, v) for k, v in fixture_value.items()])) + """ + return str( + "-".join(["{}:{}".format(k, v) for k, v in fixture_value.items()]) + ) class TestPlasma(object): - general_properties = ['beta_rad', 'g_electron', 'selected_atoms', - 'number_density', 't_electrons', 'w', 't_rad', 'beta_electron'] - partiton_properties = ['level_boltzmann_factor', 'partition_function'] - atomic_properties = ['excitation_energy', 'lines', 'lines_lower_level_index', - 'lines_upper_level_index', 'atomic_mass', 'ionization_data', - 'nu', 'wavelength_cm', 'f_lu', 'metastability'] - ion_population_properties = ['delta', 'previous_electron_densities', - 'phi', 'ion_number_density', 'electron_densities'] - level_population_properties = ['level_number_density'] - radiative_properties = ['stimulated_emission_factor', 'previous_beta_sobolev', - 'tau_sobolevs', 'beta_sobolev', 'transition_probabilities'] - j_blues_properties = ['j_blues', 'j_blues_norm_factor', 'j_blue_estimator'] - input_properties = ['volume', 'r_inner'] - helium_nlte_properties = ['helium_population', 'helium_population_updated'] + general_properties = [ + "beta_rad", + "g_electron", + "selected_atoms", + "number_density", + "t_electrons", + "w", + "t_rad", + "beta_electron", + ] + partiton_properties = ["level_boltzmann_factor", "partition_function"] + atomic_properties = [ + "excitation_energy", + "lines", + "lines_lower_level_index", + "lines_upper_level_index", + "atomic_mass", + "ionization_data", + "nu", + "wavelength_cm", + "f_lu", + "metastability", + ] + ion_population_properties = [ + "delta", + "previous_electron_densities", + "phi", + "ion_number_density", + "electron_densities", + ] + level_population_properties = ["level_number_density"] + radiative_properties = [ + "stimulated_emission_factor", + "previous_beta_sobolev", + "tau_sobolevs", + "beta_sobolev", + "transition_probabilities", + ] + j_blues_properties = ["j_blues", "j_blues_norm_factor", "j_blue_estimator"] + input_properties = ["volume", "r_inner"] + helium_nlte_properties = ["helium_population", "helium_population_updated"] combined_properties = ( - general_properties + partiton_properties + - atomic_properties + ion_population_properties + - level_population_properties + radiative_properties + - j_blues_properties + input_properties + helium_nlte_properties) - - scalars_properties = ['time_explosion', 'link_t_rad_t_electron'] + general_properties + + partiton_properties + + atomic_properties + + ion_population_properties + + level_population_properties + + radiative_properties + + j_blues_properties + + input_properties + + helium_nlte_properties + ) + + scalars_properties = ["time_explosion", "link_t_rad_t_electron"] @pytest.fixture(scope="class") def chianti_he_db_fpath(self, tardis_ref_path): - return os.path.abspath(os.path.join( - tardis_ref_path, 'atom_data', 'chianti_He.h5')) - - @pytest.fixture( - scope="class", - params=config_list, - ids=idfn - ) + return os.path.abspath( + os.path.join(tardis_ref_path, "atom_data", "chianti_He.h5") + ) + + @pytest.fixture(scope="class", params=config_list, ids=idfn) def config(self, request): config_path = os.path.join( - 'tardis', 'plasma', 'tests', 'data', 'plasma_base_test_config.yml') + "tardis", "plasma", "tests", "data", "plasma_base_test_config.yml" + ) config = Configuration.from_yaml(config_path) - hash_string = '' + hash_string = "" for prop, value in request.param.items(): - hash_string = '_'.join((hash_string, prop)) - if prop == 'nlte': + hash_string = "_".join((hash_string, prop)) + if prop == "nlte": for nlte_prop, nlte_value in request.param[prop].items(): config.plasma.nlte[nlte_prop] = nlte_value - if nlte_prop != 'species': - hash_string = '_'.join((hash_string, nlte_prop)) + if nlte_prop != "species": + hash_string = "_".join((hash_string, nlte_prop)) else: config.plasma[prop] = value - hash_string = '_'.join((hash_string, str(value))) + hash_string = "_".join((hash_string, str(value))) hash_string = os.path.join("plasma_unittest", hash_string) - setattr(config.plasma, 'save_path', hash_string) + setattr(config.plasma, "save_path", hash_string) return config @pytest.fixture(scope="class") - def plasma(self, pytestconfig, chianti_he_db_fpath, config, tardis_ref_data): - config['atom_data'] = chianti_he_db_fpath + def plasma( + self, pytestconfig, chianti_he_db_fpath, config, tardis_ref_data + ): + config["atom_data"] = chianti_he_db_fpath sim = Simulation.from_config(config) if pytestconfig.getvalue("--generate-reference"): sim.plasma.to_hdf(tardis_ref_data, path=config.plasma.save_path) @@ -141,7 +175,7 @@ def test_plasma_properties(self, plasma, tardis_ref_data, config, attr): actual = pd.Series(actual) else: actual = pd.DataFrame(actual) - key = os.path.join(config.plasma.save_path, 'plasma', attr) + key = os.path.join(config.plasma.save_path, "plasma", attr) expected = tardis_ref_data[key] pdt.assert_almost_equal(actual, expected) else: @@ -149,32 +183,28 @@ def test_plasma_properties(self, plasma, tardis_ref_data, config, attr): def test_levels(self, plasma, tardis_ref_data, config): actual = pd.DataFrame(plasma.levels) - key = os.path.join( - config.plasma.save_path, 'plasma', 'levels') + key = os.path.join(config.plasma.save_path, "plasma", "levels") expected = tardis_ref_data[key] pdt.assert_almost_equal(actual, expected) @pytest.mark.parametrize("attr", scalars_properties) def test_scalars_properties(self, plasma, tardis_ref_data, config, attr): actual = getattr(plasma, attr) - if hasattr(actual, 'cgs'): + if hasattr(actual, "cgs"): actual = actual.cgs.value - key = os.path.join( - config.plasma.save_path, 'plasma', 'scalars') + key = os.path.join(config.plasma.save_path, "plasma", "scalars") expected = tardis_ref_data[key][attr] pdt.assert_almost_equal(actual, expected) def test_helium_treatment(self, plasma, tardis_ref_data, config): actual = plasma.helium_treatment - key = os.path.join( - config.plasma.save_path, 'plasma', 'scalars') - expected = tardis_ref_data[key]['helium_treatment'] + key = os.path.join(config.plasma.save_path, "plasma", "scalars") + expected = tardis_ref_data[key]["helium_treatment"] assert actual == expected def test_zeta_data(self, plasma, tardis_ref_data, config): - if hasattr(plasma, 'zeta_data'): + if hasattr(plasma, "zeta_data"): actual = plasma.zeta_data - key = os.path.join( - config.plasma.save_path, 'plasma', 'zeta_data') + key = os.path.join(config.plasma.save_path, "plasma", "zeta_data") expected = tardis_ref_data[key] assert_almost_equal(actual, expected.values) diff --git a/tardis/plasma/tests/test_hdf_plasma.py b/tardis/plasma/tests/test_hdf_plasma.py index 87924d558a0..6d97ef659cb 100644 --- a/tardis/plasma/tests/test_hdf_plasma.py +++ b/tardis/plasma/tests/test_hdf_plasma.py @@ -16,77 +16,105 @@ def to_hdf_buffer(hdf_file_path, simulation_verysimple): plasma_properties_list = [ - 'number_density', 'beta_rad', 'general_level_boltzmann_factor', - 'level_boltzmann_factor', 'stimulated_emission_factor', 't_electrons', - 'wavelength_cm', 'lines_lower_level_index', 'ionization_data', 'density', - 'atomic_mass', 'level_number_density', 'lines_upper_level_index', 'nu', - 'beta_sobolev', 'transition_probabilities', 'phi', 'electron_densities', - 't_rad', 'selected_atoms', 'ion_number_density', 'partition_function', - 'abundance', 'g_electron', 'g', 'lines', 'f_lu', 'tau_sobolevs', - 'j_blues', 'metastability', 'w', 'excitation_energy'] + "number_density", + "beta_rad", + "general_level_boltzmann_factor", + "level_boltzmann_factor", + "stimulated_emission_factor", + "t_electrons", + "wavelength_cm", + "lines_lower_level_index", + "ionization_data", + "density", + "atomic_mass", + "level_number_density", + "lines_upper_level_index", + "nu", + "beta_sobolev", + "transition_probabilities", + "phi", + "electron_densities", + "t_rad", + "selected_atoms", + "ion_number_density", + "partition_function", + "abundance", + "g_electron", + "g", + "lines", + "f_lu", + "tau_sobolevs", + "j_blues", + "metastability", + "w", + "excitation_energy", +] @pytest.mark.parametrize("attr", plasma_properties_list) def test_hdf_plasma(hdf_file_path, simulation_verysimple, attr): if hasattr(simulation_verysimple.plasma, attr): actual = getattr(simulation_verysimple.plasma, attr) - if hasattr(actual, 'cgs'): + if hasattr(actual, "cgs"): actual = actual.cgs.value - path = os.path.join('plasma', attr) + path = os.path.join("plasma", attr) expected = pd.read_hdf(hdf_file_path, path) assert_almost_equal(actual, expected.values) def test_hdf_levels(hdf_file_path, simulation_verysimple): - actual = getattr(simulation_verysimple.plasma, 'levels') - if hasattr(actual, 'cgs'): + actual = getattr(simulation_verysimple.plasma, "levels") + if hasattr(actual, "cgs"): actual = actual.cgs.value - path = os.path.join('plasma', 'levels') + path = os.path.join("plasma", "levels") expected = pd.read_hdf(hdf_file_path, path) pdt.assert_almost_equal(pd.DataFrame(actual), expected) -scalars_list = ['time_explosion', 'link_t_rad_t_electron'] +scalars_list = ["time_explosion", "link_t_rad_t_electron"] @pytest.mark.parametrize("attr", scalars_list) def test_hdf_scalars(hdf_file_path, simulation_verysimple, attr): actual = getattr(simulation_verysimple.plasma, attr) - if hasattr(actual, 'cgs'): + if hasattr(actual, "cgs"): actual = actual.cgs.value - path = os.path.join('plasma', 'scalars') + path = os.path.join("plasma", "scalars") expected = pd.read_hdf(hdf_file_path, path)[attr] assert_almost_equal(actual, expected) def test_hdf_helium_treatment(hdf_file_path, simulation_verysimple): - actual = getattr(simulation_verysimple.plasma, 'helium_treatment') - path = os.path.join('plasma', 'scalars') - expected = pd.read_hdf(hdf_file_path, path)['helium_treatment'] + actual = getattr(simulation_verysimple.plasma, "helium_treatment") + path = os.path.join("plasma", "scalars") + expected = pd.read_hdf(hdf_file_path, path)["helium_treatment"] assert actual == expected def test_atomic_data_uuid(hdf_file_path, simulation_verysimple): - actual = getattr(simulation_verysimple.plasma.atomic_data, 'uuid1') - path = os.path.join('plasma', 'scalars') - expected = pd.read_hdf(hdf_file_path, path)['atom_data_uuid'] + actual = getattr(simulation_verysimple.plasma.atomic_data, "uuid1") + path = os.path.join("plasma", "scalars") + expected = pd.read_hdf(hdf_file_path, path)["atom_data_uuid"] assert actual == expected @pytest.fixture(scope="module", autouse=True) def to_hdf_collection_buffer(hdf_file_path, simulation_verysimple): simulation_verysimple.plasma.to_hdf( - hdf_file_path, name='collection', collection=property_collections.basic_inputs) + hdf_file_path, + name="collection", + collection=property_collections.basic_inputs, + ) -collection_properties = ['t_rad', 'w', 'density'] +collection_properties = ["t_rad", "w", "density"] @pytest.mark.parametrize("attr", collection_properties) def test_collection(hdf_file_path, simulation_verysimple, attr): actual = getattr(simulation_verysimple.plasma, attr) - if hasattr(actual, 'cgs'): + if hasattr(actual, "cgs"): actual = actual.cgs.value - path = os.path.join('collection', attr) + path = os.path.join("collection", attr) expected = pd.read_hdf(hdf_file_path, path) assert_almost_equal(actual, expected.values) diff --git a/tardis/plasma/tests/test_plasma_vboundary.py b/tardis/plasma/tests/test_plasma_vboundary.py index 5fdc76ab3bd..57228a60693 100644 --- a/tardis/plasma/tests/test_plasma_vboundary.py +++ b/tardis/plasma/tests/test_plasma_vboundary.py @@ -7,34 +7,44 @@ from tardis.io.atom_data.base import AtomData from tardis.simulation import Simulation -DATA_PATH = os.path.join(tardis.__path__[0], 'plasma', 'tests', 'data') +DATA_PATH = os.path.join(tardis.__path__[0], "plasma", "tests", "data") @pytest.fixture def config_init_trad_fname(): - return os.path.join(DATA_PATH, 'config_init_trad.yml') + return os.path.join(DATA_PATH, "config_init_trad.yml") -@pytest.mark.parametrize("v_inner_boundary, v_outer_boundary", - [(3350, 3650), - (2900, 3750), - (2900, 3850), - (2900, 3900), - (2950, 3750), - (2950, 3850), - (2950, 3900), - (3050, 3750), - (3050, 3850), - (3050, 3900), - (3150, 3750), - (3150, 3850), - (3150, 3900)]) -def test_plasma_vboundary(config_init_trad_fname, v_inner_boundary, - v_outer_boundary, atomic_data_fname): +@pytest.mark.parametrize( + "v_inner_boundary, v_outer_boundary", + [ + (3350, 3650), + (2900, 3750), + (2900, 3850), + (2900, 3900), + (2950, 3750), + (2950, 3850), + (2950, 3900), + (3050, 3750), + (3050, 3850), + (3050, 3900), + (3150, 3750), + (3150, 3850), + (3150, 3900), + ], +) +def test_plasma_vboundary( + config_init_trad_fname, + v_inner_boundary, + v_outer_boundary, + atomic_data_fname, +): tardis_config = Configuration.from_yaml(config_init_trad_fname) tardis_config.atom_data = atomic_data_fname tardis_config.model.structure.v_inner_boundary = ( - v_inner_boundary * u.km / u.s) + v_inner_boundary * u.km / u.s + ) tardis_config.model.structure.v_outer_boundary = ( - v_outer_boundary * u.km / u.s) + v_outer_boundary * u.km / u.s + ) simulation = Simulation.from_config(tardis_config) diff --git a/tardis/plasma/tests/test_tardis_model_density_config.py b/tardis/plasma/tests/test_tardis_model_density_config.py index 22d0a1130b5..6ac84e1db7b 100644 --- a/tardis/plasma/tests/test_tardis_model_density_config.py +++ b/tardis/plasma/tests/test_tardis_model_density_config.py @@ -5,12 +5,12 @@ from tardis.plasma.standard_plasmas import assemble_plasma from numpy.testing import assert_almost_equal -data_path = os.path.join('tardis', 'io', 'tests', 'data') +data_path = os.path.join("tardis", "io", "tests", "data") @pytest.fixture def tardis_model_density_config(): - filename = 'tardis_configv1_tardis_model_format.yml' + filename = "tardis_configv1_tardis_model_format.yml" return Configuration.from_yaml(os.path.join(data_path, filename)) @@ -21,12 +21,14 @@ def raw_model(tardis_model_density_config): @pytest.fixture() def raw_plasma(tardis_model_density_config, raw_model, kurucz_atomic_data): - return assemble_plasma(tardis_model_density_config, raw_model, kurucz_atomic_data) + return assemble_plasma( + tardis_model_density_config, raw_model, kurucz_atomic_data + ) def test_electron_densities(raw_plasma): - assert_almost_equal(raw_plasma.electron_densities[8], 2.72e+14) - assert_almost_equal(raw_plasma.electron_densities[3], 2.6e+14) + assert_almost_equal(raw_plasma.electron_densities[8], 2.72e14) + assert_almost_equal(raw_plasma.electron_densities[3], 2.6e14) def test_t_rad(raw_plasma): diff --git a/tardis/scripts/cmfgen2tardis.py b/tardis/scripts/cmfgen2tardis.py index cef4ee18809..69546c1f0a1 100644 --- a/tardis/scripts/cmfgen2tardis.py +++ b/tardis/scripts/cmfgen2tardis.py @@ -13,7 +13,7 @@ def get_atomic_number(element): index = -1 for atomic_no, row in atomic_dataset.atom_data.iterrows(): - if element in row['name']: + if element in row["name"]: index = atomic_no break return index @@ -39,71 +39,87 @@ def extract_file_block(f): def convert_format(file_path): quantities_row = [] - prop_list = ['Velocity', 'Density', 'Electron density', 'Temperature'] - with open(file_path, 'r') as f: + prop_list = ["Velocity", "Density", "Electron density", "Temperature"] + with open(file_path, "r") as f: for line in f: - items = line.replace('(', '').replace(')', '').split() + items = line.replace("(", "").replace(")", "").split() n = len(items) - if 'data points' in line: - abundances_df = pd.DataFrame(columns=np.arange(int(items[n - 1])), - index=pd.Index([], - name='element'), - dtype=np.float64) + if "data points" in line: + abundances_df = pd.DataFrame( + columns=np.arange(int(items[n - 1])), + index=pd.Index([], name="element"), + dtype=np.float64, + ) if any(prop in line for prop in prop_list): - quantities_row.append(items[n - 1].replace('gm', 'g')) - if 'Time' in line: + quantities_row.append(items[n - 1].replace("gm", "g")) + if "Time" in line: time_of_model = float(items[n - 1]) - if 'Velocity' in line: + if "Velocity" in line: velocity = extract_file_block(f) - if 'Density' in line: + if "Density" in line: density = extract_file_block(f) - if 'Electron density' in line: + if "Electron density" in line: electron_density = extract_file_block(f) - if 'Temperature' in line: + if "Temperature" in line: temperature = extract_file_block(f) - if 'mass fraction\n' in line: - element_string = items[0] - atomic_no = get_atomic_number(element_string.capitalize()) - element_symbol = atomic_dataset.atom_data.loc[atomic_no]['symbol'] + if "mass fraction\n" in line: + element_string = items[0] + atomic_no = get_atomic_number(element_string.capitalize()) + element_symbol = atomic_dataset.atom_data.loc[atomic_no][ + "symbol" + ] - #Its a Isotope - if n == 4: - element_symbol += items[1] + # Its a Isotope + if n == 4: + element_symbol += items[1] - abundances = extract_file_block(f) - abundances_df.loc[element_symbol] = abundances + abundances = extract_file_block(f) + abundances_df.loc[element_symbol] = abundances density_df = pd.DataFrame.from_records( - [velocity, temperature * 10**4, density, electron_density]).transpose() - density_df.columns = ['velocity', 'temperature', - 'densities', 'electron_densities'] + [velocity, temperature * 10 ** 4, density, electron_density] + ).transpose() + density_df.columns = [ + "velocity", + "temperature", + "densities", + "electron_densities", + ] quantities_row += abundances_df.shape[0] * [1] - return abundances_df.transpose(), density_df, time_of_model, quantities_row + return ( + abundances_df.transpose(), + density_df, + time_of_model, + quantities_row, + ) def parse_file(args): abundances_df, density_df, time_of_model, quantities_row = convert_format( - args.input_path) + args.input_path + ) filename = os.path.splitext(os.path.basename(args.input_path))[0] - save_fname = '.'.join((filename, 'csv')) + save_fname = ".".join((filename, "csv")) resultant_df = pd.concat([density_df, abundances_df], axis=1) resultant_df.columns = pd.MultiIndex.from_tuples( - zip(resultant_df.columns, quantities_row)) + zip(resultant_df.columns, quantities_row) + ) save_file_path = os.path.join(args.output_path, save_fname) - with open(save_file_path, 'w') as f: - f.write(" ".join(('t0:', str(time_of_model), 'day'))) + with open(save_file_path, "w") as f: + f.write(" ".join(("t0:", str(time_of_model), "day"))) f.write("\n") - resultant_df.to_csv(save_file_path, index=False, sep=' ', mode='a') + resultant_df.to_csv(save_file_path, index=False, sep=" ", mode="a") def main(): parser = argparse.ArgumentParser() - parser.add_argument('input_path', help='Path to a CMFGEN file') + parser.add_argument("input_path", help="Path to a CMFGEN file") parser.add_argument( - 'output_path', help='Path to store converted TARDIS format files') + "output_path", help="Path to store converted TARDIS format files" + ) args = parser.parse_args() parse_file(args) diff --git a/tardis/simulation/__init__.py b/tardis/simulation/__init__.py index 074cd702f18..8c88a5ab087 100644 --- a/tardis/simulation/__init__.py +++ b/tardis/simulation/__init__.py @@ -1 +1 @@ -from tardis.simulation.base import Simulation \ No newline at end of file +from tardis.simulation.base import Simulation diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 95cddd93527..623980acf2c 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -10,6 +10,7 @@ from tardis.plasma.standard_plasmas import assemble_plasma from tardis.io.util import HDFWriterMixin from tardis.io.config_reader import ConfigurationError + # Adding logging support logger = logging.getLogger(__name__) @@ -23,14 +24,14 @@ class PlasmaStateStorerMixin(object): the electron density in each cell is provided. Additionally, the temperature at the inner boundary is saved. """ + def __init__(self, iterations, no_of_shells): - self.iterations_w = np.zeros( - (iterations, no_of_shells)) - self.iterations_t_rad = np.zeros( - (iterations, no_of_shells)) * u.K + self.iterations_w = np.zeros((iterations, no_of_shells)) + self.iterations_t_rad = np.zeros((iterations, no_of_shells)) * u.K self.iterations_electron_densities = np.zeros( - (iterations, no_of_shells)) + (iterations, no_of_shells) + ) self.iterations_t_inner = np.zeros(iterations) * u.K def store_plasma_state(self, i, w, t_rad, electron_densities, t_inner): @@ -52,8 +53,7 @@ def store_plasma_state(self, i, w, t_rad, electron_densities, t_inner): """ self.iterations_w[i, :] = w self.iterations_t_rad[i, :] = t_rad - self.iterations_electron_densities[i, :] = \ - electron_densities.values + self.iterations_electron_densities[i, :] = electron_densities.values self.iterations_t_inner[i] = t_inner def reshape_plasma_state_store(self, executed_iterations): @@ -65,13 +65,16 @@ def reshape_plasma_state_store(self, executed_iterations): executed_iterations : int iteration index, i.e. number of iterations executed minus one! """ - self.iterations_w = self.iterations_w[:executed_iterations+1, :] - self.iterations_t_rad = \ - self.iterations_t_rad[:executed_iterations+1, :] - self.iterations_electron_densities = \ - self.iterations_electron_densities[:executed_iterations+1, :] - self.iterations_t_inner = \ - self.iterations_t_inner[:executed_iterations+1] + self.iterations_w = self.iterations_w[: executed_iterations + 1, :] + self.iterations_t_rad = self.iterations_t_rad[ + : executed_iterations + 1, : + ] + self.iterations_electron_densities = self.iterations_electron_densities[ + : executed_iterations + 1, : + ] + self.iterations_t_inner = self.iterations_t_inner[ + : executed_iterations + 1 + ] class Simulation(PlasmaStateStorerMixin, HDFWriterMixin): @@ -97,16 +100,33 @@ class Simulation(PlasmaStateStorerMixin, HDFWriterMixin): .. note:: TARDIS must be built with OpenMP support in order for `nthreads` to have effect. """ - hdf_properties = ['model', 'plasma', 'runner', 'iterations_w', - 'iterations_t_rad', 'iterations_electron_densities', - 'iterations_t_inner'] - hdf_name = 'simulation' - def __init__(self, iterations, model, plasma, runner, - no_of_packets, no_of_virtual_packets, luminosity_nu_start, - luminosity_nu_end, last_no_of_packets, - luminosity_requested, convergence_strategy, - nthreads): + hdf_properties = [ + "model", + "plasma", + "runner", + "iterations_w", + "iterations_t_rad", + "iterations_electron_densities", + "iterations_t_inner", + ] + hdf_name = "simulation" + + def __init__( + self, + iterations, + model, + plasma, + runner, + no_of_packets, + no_of_virtual_packets, + luminosity_nu_start, + luminosity_nu_end, + last_no_of_packets, + luminosity_requested, + convergence_strategy, + nthreads, + ): super(Simulation, self).__init__(iterations, model.no_of_shells) @@ -123,32 +143,35 @@ def __init__(self, iterations, model, plasma, runner, self.luminosity_nu_end = luminosity_nu_end self.luminosity_requested = luminosity_requested self.nthreads = nthreads - if convergence_strategy.type in ('damped'): + if convergence_strategy.type in ("damped"): self.convergence_strategy = convergence_strategy self.converged = False self.consecutive_converges_count = 0 - elif convergence_strategy.type in ('custom'): + elif convergence_strategy.type in ("custom"): raise NotImplementedError( - 'Convergence strategy type is custom; ' - 'you need to implement your specific treatment!' + "Convergence strategy type is custom; " + "you need to implement your specific treatment!" ) else: raise ValueError( - 'Convergence strategy type is ' - 'not damped or custom ' - '- input is {0}'.format(convergence_strategy.type)) + "Convergence strategy type is " + "not damped or custom " + "- input is {0}".format(convergence_strategy.type) + ) self._callbacks = OrderedDict() self._cb_next_id = 0 - def estimate_t_inner(self, input_t_inner, luminosity_requested, - t_inner_update_exponent=-0.5): + def estimate_t_inner( + self, input_t_inner, luminosity_requested, t_inner_update_exponent=-0.5 + ): emitted_luminosity = self.runner.calculate_emitted_luminosity( - self.luminosity_nu_start, - self.luminosity_nu_end) + self.luminosity_nu_start, self.luminosity_nu_end + ) luminosity_ratios = ( - (emitted_luminosity / luminosity_requested).to(1).value) + (emitted_luminosity / luminosity_requested).to(1).value + ) return input_t_inner * luminosity_ratios ** t_inner_update_exponent @@ -158,42 +181,53 @@ def damped_converge(value, estimated_value, damping_factor): # as a method return value + damping_factor * (estimated_value - value) - def _get_convergence_status(self, t_rad, w, t_inner, estimated_t_rad, - estimated_w, estimated_t_inner): + def _get_convergence_status( + self, t_rad, w, t_inner, estimated_t_rad, estimated_w, estimated_t_inner + ): # FIXME: Move the convergence checking in its own class. no_of_shells = self.model.no_of_shells - convergence_t_rad = (abs(t_rad - estimated_t_rad) / - estimated_t_rad).value - convergence_w = (abs(w - estimated_w) / estimated_w) - convergence_t_inner = (abs(t_inner - estimated_t_inner) / - estimated_t_inner).value + convergence_t_rad = ( + abs(t_rad - estimated_t_rad) / estimated_t_rad + ).value + convergence_w = abs(w - estimated_w) / estimated_w + convergence_t_inner = ( + abs(t_inner - estimated_t_inner) / estimated_t_inner + ).value fraction_t_rad_converged = ( np.count_nonzero( - convergence_t_rad < self.convergence_strategy.t_rad.threshold) - / no_of_shells) + convergence_t_rad < self.convergence_strategy.t_rad.threshold + ) + / no_of_shells + ) t_rad_converged = ( - fraction_t_rad_converged > self.convergence_strategy.fraction) + fraction_t_rad_converged > self.convergence_strategy.fraction + ) fraction_w_converged = ( np.count_nonzero( - convergence_w < self.convergence_strategy.w.threshold) - / no_of_shells) + convergence_w < self.convergence_strategy.w.threshold + ) + / no_of_shells + ) - w_converged = ( - fraction_w_converged > self.convergence_strategy.fraction) + w_converged = fraction_w_converged > self.convergence_strategy.fraction t_inner_converged = ( - convergence_t_inner < self.convergence_strategy.t_inner.threshold) + convergence_t_inner < self.convergence_strategy.t_inner.threshold + ) if np.all([t_rad_converged, w_converged, t_inner_converged]): hold_iterations = self.convergence_strategy.hold_iterations self.consecutive_converges_count += 1 - logger.info("Iteration converged {0:d}/{1:d} consecutive " - "times.".format(self.consecutive_converges_count, - hold_iterations + 1)) + logger.info( + "Iteration converged {0:d}/{1:d} consecutive " + "times.".format( + self.consecutive_converges_count, hold_iterations + 1 + ) + ) # If an iteration has converged, require hold_iterations more # iterations to converge before we conclude that the Simulation # is converged. @@ -212,36 +246,56 @@ def advance_state(self): ------- converged : ~bool """ - estimated_t_rad, estimated_w = ( - self.runner.calculate_radiationfield_properties()) + ( + estimated_t_rad, + estimated_w, + ) = self.runner.calculate_radiationfield_properties() estimated_t_inner = self.estimate_t_inner( - self.model.t_inner, self.luminosity_requested, - t_inner_update_exponent=self.convergence_strategy.t_inner_update_exponent) - - converged = self._get_convergence_status(self.model.t_rad, - self.model.w, - self.model.t_inner, - estimated_t_rad, - estimated_w, - estimated_t_inner) + self.model.t_inner, + self.luminosity_requested, + t_inner_update_exponent=self.convergence_strategy.t_inner_update_exponent, + ) + + converged = self._get_convergence_status( + self.model.t_rad, + self.model.w, + self.model.t_inner, + estimated_t_rad, + estimated_w, + estimated_t_inner, + ) # calculate_next_plasma_state equivalent # FIXME: Should convergence strategy have its own class? next_t_rad = self.damped_converge( - self.model.t_rad, estimated_t_rad, - self.convergence_strategy.t_rad.damping_constant) + self.model.t_rad, + estimated_t_rad, + self.convergence_strategy.t_rad.damping_constant, + ) next_w = self.damped_converge( - self.model.w, estimated_w, self.convergence_strategy.w.damping_constant) - if (self.iterations_executed + 1) % self.convergence_strategy.lock_t_inner_cycles == 0: + self.model.w, + estimated_w, + self.convergence_strategy.w.damping_constant, + ) + if ( + self.iterations_executed + 1 + ) % self.convergence_strategy.lock_t_inner_cycles == 0: next_t_inner = self.damped_converge( - self.model.t_inner, estimated_t_inner, - self.convergence_strategy.t_inner.damping_constant) + self.model.t_inner, + estimated_t_inner, + self.convergence_strategy.t_inner.damping_constant, + ) else: next_t_inner = self.model.t_inner - self.log_plasma_state(self.model.t_rad, self.model.w, - self.model.t_inner, next_t_rad, next_w, - next_t_inner) + self.log_plasma_state( + self.model.t_rad, + self.model.w, + self.model.t_inner, + next_t_rad, + next_w, + next_t_inner, + ) self.model.t_rad = next_t_rad self.model.w = next_w self.model.t_inner = next_t_inner @@ -249,45 +303,59 @@ def advance_state(self): # model.calculate_j_blues() equivalent # model.update_plasmas() equivalent # Bad test to see if this is a nlte run - if 'nlte_data' in self.plasma.outputs_dict: + if "nlte_data" in self.plasma.outputs_dict: self.plasma.store_previous_properties() update_properties = dict(t_rad=self.model.t_rad, w=self.model.w) # A check to see if the plasma is set with JBluesDetailed, in which # case it needs some extra kwargs. - if 'j_blue_estimator' in self.plasma.outputs_dict: - update_properties.update(t_inner=next_t_inner, - j_blue_estimator=self.runner.j_blue_estimator) + if "j_blue_estimator" in self.plasma.outputs_dict: + update_properties.update( + t_inner=next_t_inner, + j_blue_estimator=self.runner.j_blue_estimator, + ) self.plasma.update(**update_properties) return converged def iterate(self, no_of_packets, no_of_virtual_packets=0, last_run=False): - logger.info('Starting iteration {0:d}/{1:d}'.format( - self.iterations_executed + 1, self.iterations)) - self.runner.run(self.model, self.plasma, no_of_packets, - no_of_virtual_packets=no_of_virtual_packets, - nthreads=self.nthreads, last_run=last_run) + logger.info( + "Starting iteration {0:d}/{1:d}".format( + self.iterations_executed + 1, self.iterations + ) + ) + self.runner.run( + self.model, + self.plasma, + no_of_packets, + no_of_virtual_packets=no_of_virtual_packets, + nthreads=self.nthreads, + last_run=last_run, + ) output_energy = self.runner.output_energy if np.sum(output_energy < 0) == len(output_energy): logger.critical("No r-packet escaped through the outer boundary.") emitted_luminosity = self.runner.calculate_emitted_luminosity( - self.luminosity_nu_start, self.luminosity_nu_end) + self.luminosity_nu_start, self.luminosity_nu_end + ) reabsorbed_luminosity = self.runner.calculate_reabsorbed_luminosity( - self.luminosity_nu_start, self.luminosity_nu_end) - self.log_run_results(emitted_luminosity, - reabsorbed_luminosity) + self.luminosity_nu_start, self.luminosity_nu_end + ) + self.log_run_results(emitted_luminosity, reabsorbed_luminosity) self.iterations_executed += 1 def run(self): start_time = time.time() - while self.iterations_executed < self.iterations-1: - self.store_plasma_state(self.iterations_executed, self.model.w, - self.model.t_rad, - self.plasma.electron_densities, - self.model.t_inner) + while self.iterations_executed < self.iterations - 1: + self.store_plasma_state( + self.iterations_executed, + self.model.w, + self.model.t_rad, + self.plasma.electron_densities, + self.model.t_inner, + ) self.iterate(self.no_of_packets) self.converged = self.advance_state() self._call_back() @@ -295,22 +363,35 @@ def run(self): if self.convergence_strategy.stop_if_converged: break # Last iteration - self.store_plasma_state(self.iterations_executed, self.model.w, - self.model.t_rad, - self.plasma.electron_densities, - self.model.t_inner) + self.store_plasma_state( + self.iterations_executed, + self.model.w, + self.model.t_rad, + self.plasma.electron_densities, + self.model.t_inner, + ) self.iterate(self.last_no_of_packets, self.no_of_virtual_packets, True) self.reshape_plasma_state_store(self.iterations_executed) - logger.info("Simulation finished in {0:d} iterations " - "and took {1:.2f} s".format( - self.iterations_executed, time.time() - start_time)) + logger.info( + "Simulation finished in {0:d} iterations " + "and took {1:.2f} s".format( + self.iterations_executed, time.time() - start_time + ) + ) self._call_back() - - def log_plasma_state(self, t_rad, w, t_inner, next_t_rad, next_w, - next_t_inner, log_sampling=5): + def log_plasma_state( + self, + t_rad, + w, + t_inner, + next_t_rad, + next_w, + next_t_inner, + log_sampling=5, + ): """ Logging the change of the plasma state @@ -332,31 +413,40 @@ def log_plasma_state(self, t_rad, w, t_inner, next_t_rad, next_w, """ - plasma_state_log = pd.DataFrame(index=np.arange(len(t_rad)), - columns=['t_rad', 'next_t_rad', - 'w', 'next_w']) - plasma_state_log['t_rad'] = t_rad - plasma_state_log['next_t_rad'] = next_t_rad - plasma_state_log['w'] = w - plasma_state_log['next_w'] = next_w + plasma_state_log = pd.DataFrame( + index=np.arange(len(t_rad)), + columns=["t_rad", "next_t_rad", "w", "next_w"], + ) + plasma_state_log["t_rad"] = t_rad + plasma_state_log["next_t_rad"] = next_t_rad + plasma_state_log["w"] = w + plasma_state_log["next_w"] = next_w - plasma_state_log.index.name = 'Shell' + plasma_state_log.index.name = "Shell" plasma_state_log = str(plasma_state_log[::log_sampling]) - plasma_state_log = ''.join(['\t%s\n' % item for item in - plasma_state_log.split('\n')]) + plasma_state_log = "".join( + ["\t%s\n" % item for item in plasma_state_log.split("\n")] + ) - logger.info('Plasma stratification:\n%s\n', plasma_state_log) - logger.info('t_inner {0:.3f} -- next t_inner {1:.3f}'.format( - t_inner, next_t_inner)) + logger.info("Plasma stratification:\n%s\n", plasma_state_log) + logger.info( + "t_inner {0:.3f} -- next t_inner {1:.3f}".format( + t_inner, next_t_inner + ) + ) def log_run_results(self, emitted_luminosity, absorbed_luminosity): - logger.info("Luminosity emitted = {0:.5e} " - "Luminosity absorbed = {1:.5e} " - "Luminosity requested = {2:.5e}".format( - emitted_luminosity, absorbed_luminosity, - self.luminosity_requested)) + logger.info( + "Luminosity emitted = {0:.5e} " + "Luminosity absorbed = {1:.5e} " + "Luminosity requested = {2:.5e}".format( + emitted_luminosity, + absorbed_luminosity, + self.luminosity_requested, + ) + ) def _call_back(self): for cb, args in self._callbacks.values(): @@ -428,54 +518,59 @@ def from_config(cls, config, packet_source=None, **kwargs): """ # Allow overriding some config structures. This is useful in some # unit tests, and could be extended in all the from_config classmethods. - if 'model' in kwargs: - model = kwargs['model'] + if "model" in kwargs: + model = kwargs["model"] else: - if hasattr(config, 'csvy_model'): + if hasattr(config, "csvy_model"): model = Radial1DModel.from_csvy(config) else: model = Radial1DModel.from_config(config) - if 'plasma' in kwargs: - plasma = kwargs['plasma'] + if "plasma" in kwargs: + plasma = kwargs["plasma"] else: - plasma = assemble_plasma(config, model, - atom_data=kwargs.get('atom_data', None)) - if 'runner' in kwargs: + plasma = assemble_plasma( + config, model, atom_data=kwargs.get("atom_data", None) + ) + if "runner" in kwargs: if packet_source is not None: raise ConfigurationError( - 'Cannot specify packet_source and runner at the same time.' + "Cannot specify packet_source and runner at the same time." ) - runner = kwargs['runner'] + runner = kwargs["runner"] else: - runner = MontecarloRunner.from_config(config, - packet_source=packet_source) + runner = MontecarloRunner.from_config( + config, packet_source=packet_source + ) luminosity_nu_start = config.supernova.luminosity_wavelength_end.to( - u.Hz, u.spectral()) + u.Hz, u.spectral() + ) if u.isclose( - config.supernova.luminosity_wavelength_start, 0 * u.angstrom): + config.supernova.luminosity_wavelength_start, 0 * u.angstrom + ): luminosity_nu_end = np.inf * u.Hz else: luminosity_nu_end = ( - const.c / - config.supernova.luminosity_wavelength_start).to(u.Hz) + const.c / config.supernova.luminosity_wavelength_start + ).to(u.Hz) last_no_of_packets = config.montecarlo.last_no_of_packets if last_no_of_packets is None or last_no_of_packets < 0: - last_no_of_packets = config.montecarlo.no_of_packets + last_no_of_packets = config.montecarlo.no_of_packets last_no_of_packets = int(last_no_of_packets) - return cls(iterations=config.montecarlo.iterations, - model=model, - plasma=plasma, - runner=runner, - no_of_packets=int(config.montecarlo.no_of_packets), - no_of_virtual_packets=int( - config.montecarlo.no_of_virtual_packets), - luminosity_nu_start=luminosity_nu_start, - luminosity_nu_end=luminosity_nu_end, - last_no_of_packets=last_no_of_packets, - luminosity_requested=config.supernova.luminosity_requested.cgs, - convergence_strategy=config.montecarlo.convergence_strategy, - nthreads=config.montecarlo.nthreads) + return cls( + iterations=config.montecarlo.iterations, + model=model, + plasma=plasma, + runner=runner, + no_of_packets=int(config.montecarlo.no_of_packets), + no_of_virtual_packets=int(config.montecarlo.no_of_virtual_packets), + luminosity_nu_start=luminosity_nu_start, + luminosity_nu_end=luminosity_nu_end, + last_no_of_packets=last_no_of_packets, + luminosity_requested=config.supernova.luminosity_requested.cgs, + convergence_strategy=config.montecarlo.convergence_strategy, + nthreads=config.montecarlo.nthreads, + ) diff --git a/tardis/simulation/setup_package.py b/tardis/simulation/setup_package.py index fbaa4b0ee0e..10b3b9f8880 100644 --- a/tardis/simulation/setup_package.py +++ b/tardis/simulation/setup_package.py @@ -3,7 +3,6 @@ from astropy_helpers.distutils_helpers import get_distutils_option import numpy as np -def get_package_data(): - return {'tardis.simulation.tests':['data/*.h5']} - +def get_package_data(): + return {"tardis.simulation.tests": ["data/*.h5"]} diff --git a/tardis/simulation/tests/test_simulation.py b/tardis/simulation/tests/test_simulation.py index e9eff0b0574..2f963287285 100644 --- a/tardis/simulation/tests/test_simulation.py +++ b/tardis/simulation/tests/test_simulation.py @@ -10,24 +10,25 @@ import astropy.units as u -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def refdata(tardis_ref_data): def get_ref_data(key): - return tardis_ref_data[os.path.join( - 'test_simulation', key)] + return tardis_ref_data[os.path.join("test_simulation", key)] + return get_ref_data -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def config(): return Configuration.from_yaml( - 'tardis/io/tests/data/tardis_configv1_verysimple.yml') + "tardis/io/tests/data/tardis_configv1_verysimple.yml" + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def simulation_one_loop( - atomic_data_fname, config, - tardis_ref_data, generate_reference): + atomic_data_fname, config, tardis_ref_data, generate_reference +): config.atom_data = atomic_data_fname config.montecarlo.iterations = 2 config.montecarlo.no_of_packets = int(4e4) @@ -40,77 +41,64 @@ def simulation_one_loop( return simulation else: simulation.hdf_properties = [ - 'iterations_w', - 'iterations_t_rad', - 'iterations_electron_densities', - 'iterations_t_inner', + "iterations_w", + "iterations_t_rad", + "iterations_electron_densities", + "iterations_t_inner", ] - simulation.model.hdf_properties = [ - 't_radiative', - 'dilution_factor' - ] + simulation.model.hdf_properties = ["t_radiative", "dilution_factor"] simulation.runner.hdf_properties = [ - 'j_estimator', - 'nu_bar_estimator', - 'output_nu', - 'output_energy' - ] - simulation.to_hdf( - tardis_ref_data, - '', - 'test_simulation' - ) - simulation.model.to_hdf( - tardis_ref_data, - '', - 'test_simulation') - simulation.runner.to_hdf( - tardis_ref_data, - '', - 'test_simulation') - pytest.skip( - 'Reference data was generated during this run.') - - -@pytest.mark.parametrize('name', [ - 'nu_bar_estimator', 'j_estimator', 't_radiative', 'dilution_factor', - 'output_nu', 'output_energy' - ]) -def test_plasma_estimates( - simulation_one_loop, refdata, name): + "j_estimator", + "nu_bar_estimator", + "output_nu", + "output_energy", + ] + simulation.to_hdf(tardis_ref_data, "", "test_simulation") + simulation.model.to_hdf(tardis_ref_data, "", "test_simulation") + simulation.runner.to_hdf(tardis_ref_data, "", "test_simulation") + pytest.skip("Reference data was generated during this run.") + + +@pytest.mark.parametrize( + "name", + [ + "nu_bar_estimator", + "j_estimator", + "t_radiative", + "dilution_factor", + "output_nu", + "output_energy", + ], +) +def test_plasma_estimates(simulation_one_loop, refdata, name): try: - actual = getattr( - simulation_one_loop.runner, name) + actual = getattr(simulation_one_loop.runner, name) except AttributeError: - actual = getattr( - simulation_one_loop.model, name) + actual = getattr(simulation_one_loop.model, name) actual = pd.Series(actual) - pdt.assert_almost_equal( - actual, - refdata(name) - ) + pdt.assert_almost_equal(actual, refdata(name)) -@pytest.mark.parametrize('name', [ - 'iterations_w', 'iterations_t_rad', - 'iterations_electron_densities', 'iterations_t_inner' - ]) -def test_plasma_state_iterations( - simulation_one_loop, refdata, name): - actual = getattr( - simulation_one_loop, name) +@pytest.mark.parametrize( + "name", + [ + "iterations_w", + "iterations_t_rad", + "iterations_electron_densities", + "iterations_t_inner", + ], +) +def test_plasma_state_iterations(simulation_one_loop, refdata, name): + actual = getattr(simulation_one_loop, name) try: actual = pd.Series(actual) except Exception: actual = pd.DataFrame(actual) - pdt.assert_almost_equal( - actual, - refdata(name) - ) + pdt.assert_almost_equal(actual, refdata(name)) @pytest.fixture(scope="module") @@ -121,8 +109,9 @@ def simulation_without_loop(atomic_data_fname, config): return Simulation.from_config(config) -def test_plasma_state_storer_store(atomic_data_fname, config, - simulation_without_loop): +def test_plasma_state_storer_store( + atomic_data_fname, config, simulation_without_loop +): simulation = simulation_without_loop @@ -131,18 +120,21 @@ def test_plasma_state_storer_store(atomic_data_fname, config, electron_densities_test = pd.Series(np.linspace(1e7, 1e6, 20)) t_inner_test = 12500 * u.K - simulation.store_plasma_state(1, w_test, t_rad_test, - electron_densities_test, t_inner_test) + simulation.store_plasma_state( + 1, w_test, t_rad_test, electron_densities_test, t_inner_test + ) np.testing.assert_allclose(simulation.iterations_w[1, :], w_test) np.testing.assert_allclose(simulation.iterations_t_rad[1, :], t_rad_test) - np.testing.assert_allclose(simulation.iterations_electron_densities[1, :], - electron_densities_test) + np.testing.assert_allclose( + simulation.iterations_electron_densities[1, :], electron_densities_test + ) np.testing.assert_allclose(simulation.iterations_t_inner[1], t_inner_test) -def test_plasma_state_storer_reshape(atomic_data_fname, config, - simulation_without_loop): +def test_plasma_state_storer_reshape( + atomic_data_fname, config, simulation_without_loop +): simulation = simulation_without_loop simulation.reshape_plasma_state_store(0) diff --git a/tardis/stats/base.py b/tardis/stats/base.py index a0628a3f70b..e0ea57f684e 100644 --- a/tardis/stats/base.py +++ b/tardis/stats/base.py @@ -1,11 +1,14 @@ import numpy as np + def get_trivial_poisson_uncertainty(model): """ """ emitted_nu = model.montecarlo_nu[model.montecarlo_luminosity >= 0] - emitted_luminosity = model.montecarlo_luminosity[model.montecarlo_luminosity >= 0] + emitted_luminosity = model.montecarlo_luminosity[ + model.montecarlo_luminosity >= 0 + ] freq_bins = model.tardis_config.spectrum.frequency.value bin_counts = np.histogram(emitted_nu, bins=freq_bins)[0] @@ -13,4 +16,3 @@ def get_trivial_poisson_uncertainty(model): uncertainty = np.sqrt(bin_counts) * np.mean(emitted_luminosity) return uncertainty / (freq_bins[1] - freq_bins[0]) - diff --git a/tardis/tests/fixtures/atom_data.py b/tardis/tests/fixtures/atom_data.py index 5851300f504..c0b88f199c3 100644 --- a/tardis/tests/fixtures/atom_data.py +++ b/tardis/tests/fixtures/atom_data.py @@ -5,15 +5,19 @@ from tardis.io.atom_data.base import AtomData -DEFAULT_ATOM_DATA_UUID = '864f1753714343c41f99cb065710cace' +DEFAULT_ATOM_DATA_UUID = "864f1753714343c41f99cb065710cace" + @pytest.fixture(scope="session") def atomic_data_fname(tardis_ref_path): atomic_data_fname = os.path.join( - tardis_ref_path, 'atom_data', 'kurucz_cd23_chianti_H_He.h5') + tardis_ref_path, "atom_data", "kurucz_cd23_chianti_H_He.h5" + ) - atom_data_missing_str = ("{0} atomic datafiles " - "does not seem to exist".format(atomic_data_fname)) + atom_data_missing_str = ( + "{0} atomic datafiles " + "does not seem to exist".format(atomic_data_fname) + ) if not os.path.exists(atomic_data_fname): pytest.exit(atom_data_missing_str) @@ -27,8 +31,10 @@ def atomic_dataset(atomic_data_fname): if atomic_data.md5 != DEFAULT_ATOM_DATA_UUID: pytest.skip( - 'Need default Kurucz atomic dataset (md5="{}"'.format( - DEFAULT_ATOM_DATA_UUID)) + 'Need default Kurucz atomic dataset (md5="{}"'.format( + DEFAULT_ATOM_DATA_UUID + ) + ) else: return atomic_data diff --git a/tardis/tests/integration_tests/conftest.py b/tardis/tests/integration_tests/conftest.py index 29ee6e51a2c..b18b7ee251d 100644 --- a/tardis/tests/integration_tests/conftest.py +++ b/tardis/tests/integration_tests/conftest.py @@ -6,7 +6,10 @@ from tardis import __githash__ as tardis_githash from tardis.tests.integration_tests.report import DokuReport -from tardis.tests.integration_tests.plot_helpers import LocalPlotSaver, RemotePlotSaver +from tardis.tests.integration_tests.plot_helpers import ( + LocalPlotSaver, + RemotePlotSaver, +) def pytest_configure(config): @@ -16,33 +19,44 @@ def pytest_configure(config): os.path.expanduser(integration_tests_configpath) ) config.integration_tests_config = yaml.load( - open(integration_tests_configpath), Loader=yaml.CLoader) + open(integration_tests_configpath), Loader=yaml.CLoader + ) if not config.getoption("--generate-reference"): # Used by DokuReport class to show build environment details in report. config._environment = [] # prevent opening dokupath on slave nodes (xdist) - if not hasattr(config, 'slaveinput'): + if not hasattr(config, "slaveinput"): config.dokureport = DokuReport( - config.integration_tests_config['report']) + config.integration_tests_config["report"] + ) config.pluginmanager.register(config.dokureport) def pytest_unconfigure(config): # Unregister only if it was registered in pytest_configure - if (config.getvalue("integration-tests") and not - config.getoption("--generate-reference")): + if config.getvalue("integration-tests") and not config.getoption( + "--generate-reference" + ): config.pluginmanager.unregister(config.dokureport) def pytest_terminal_summary(terminalreporter): - if (terminalreporter.config.getoption("--generate-reference") and - terminalreporter.config.getvalue("integration-tests")): + if terminalreporter.config.getoption( + "--generate-reference" + ) and terminalreporter.config.getvalue("integration-tests"): # TODO: Add a check whether generation was successful or not. - terminalreporter.write_sep("-", "Generated reference data: {0}".format(os.path.join( - terminalreporter.config.integration_tests_config['reference'], - tardis_githash[:7] - ))) + terminalreporter.write_sep( + "-", + "Generated reference data: {0}".format( + os.path.join( + terminalreporter.config.integration_tests_config[ + "reference" + ], + tardis_githash[:7], + ) + ), + ) @pytest.mark.hookwrapper @@ -60,50 +74,64 @@ def pytest_runtest_makereport(item, call): @pytest.fixture(scope="function") def plot_object(request): integration_tests_config = request.config.integration_tests_config - report_save_mode = integration_tests_config['report']['save_mode'] + report_save_mode = integration_tests_config["report"]["save_mode"] if report_save_mode == "remote": return RemotePlotSaver(request, request.config.dokureport.dokuwiki_url) else: - return LocalPlotSaver(request, os.path.join( - request.config.dokureport.report_dirpath, "assets") + return LocalPlotSaver( + request, + os.path.join(request.config.dokureport.report_dirpath, "assets"), ) -@pytest.fixture(scope="class", params=[ - path for path in glob.glob(os.path.join( - os.path.dirname(os.path.realpath(__file__)), "*")) if os.path.isdir(path) -]) +@pytest.fixture( + scope="class", + params=[ + path + for path in glob.glob( + os.path.join(os.path.dirname(os.path.realpath(__file__)), "*") + ) + if os.path.isdir(path) + ], +) def data_path(request): integration_tests_config = request.config.integration_tests_config hdf_filename = "{0}.h5".format(os.path.basename(request.param)) - if (request.config.getoption("--generate-reference") ): - ref_path = os.path.join(os.path.expandvars(os.path.expanduser( - integration_tests_config['reference'])), tardis_githash[:7] + if request.config.getoption("--generate-reference"): + ref_path = os.path.join( + os.path.expandvars( + os.path.expanduser(integration_tests_config["reference"]) + ), + tardis_githash[:7], ) else: - ref_path = os.path.join(os.path.expandvars( - os.path.expanduser(integration_tests_config['reference'])), hdf_filename + ref_path = os.path.join( + os.path.expandvars( + os.path.expanduser(integration_tests_config["reference"]) + ), + hdf_filename, ) path = { - 'config_dirpath': request.param, - 'reference_path': ref_path, - 'setup_name': hdf_filename[:-3], + "config_dirpath": request.param, + "reference_path": ref_path, + "setup_name": hdf_filename[:-3], # Temporary hack for providing atom data per individual setup. # This url has all the atom data files hosted, for downloading. -# 'atom_data_url': integration_tests_config['atom_data']['atom_data_url'] + # 'atom_data_url': integration_tests_config['atom_data']['atom_data_url'] } # For providing atom data per individual setup. Atom data can be fetched # from a local directory or a remote url. - path['atom_data_path'] = os.path.expandvars(os.path.expanduser( - integration_tests_config['atom_data_path'] - )) - - if (request.config.getoption("--generate-reference") and not - os.path.exists(path['reference_path'])): - os.makedirs(path['reference_path']) + path["atom_data_path"] = os.path.expandvars( + os.path.expanduser(integration_tests_config["atom_data_path"]) + ) + + if request.config.getoption("--generate-reference") and not os.path.exists( + path["reference_path"] + ): + os.makedirs(path["reference_path"]) return path @@ -122,10 +150,12 @@ def reference(request, data_path): return else: try: - reference = pd.HDFStore(data_path['reference_path'], 'r') + reference = pd.HDFStore(data_path["reference_path"], "r") except IOError: - raise IOError('Reference file {0} does not exist and is needed' - ' for the tests'.format(data_path['reference_path'])) + raise IOError( + "Reference file {0} does not exist and is needed" + " for the tests".format(data_path["reference_path"]) + ) else: return reference diff --git a/tardis/tests/integration_tests/plot_helpers.py b/tardis/tests/integration_tests/plot_helpers.py index 69120bba449..d92f066c584 100644 --- a/tardis/tests/integration_tests/plot_helpers.py +++ b/tardis/tests/integration_tests/plot_helpers.py @@ -46,11 +46,21 @@ def save(self, plot, filepath, report): axes = plot.axes[0] if report.passed: - axes.text(0.8, 0.8, 'passed', transform=axes.transAxes, - bbox={'facecolor': 'green', 'alpha': 0.5, 'pad': 10}) + axes.text( + 0.8, + 0.8, + "passed", + transform=axes.transAxes, + bbox={"facecolor": "green", "alpha": 0.5, "pad": 10}, + ) else: - axes.text(0.8, 0.8, 'failed', transform=axes.transAxes, - bbox={'facecolor': 'red', 'alpha': 0.5, 'pad': 10}) + axes.text( + 0.8, + 0.8, + "failed", + transform=axes.transAxes, + bbox={"facecolor": "red", "alpha": 0.5, "pad": 10}, + ) plot.savefig(filepath) @@ -76,7 +86,9 @@ def get_extras(self): class RemotePlotSaver(BasePlotSaver): def __init__(self, request, dokuwiki_url): - super(RemotePlotSaver, self).__init__(request, dokuwiki_url=dokuwiki_url) + super(RemotePlotSaver, self).__init__( + request, dokuwiki_url=dokuwiki_url + ) def upload(self, report): """Upload content of ``self._plots`` to ``self.dokuwiki_url``. @@ -93,14 +105,16 @@ def upload(self, report): self.request.config.dokureport.doku_conn.medias.add( "reports:{0}:{1}.png".format(tardis_githash[:7], name), - plot_file.name + plot_file.name, ) - self.plot_html.append(extras.html( - thumbnail_html_remote.format( - dokuwiki_url=self.dokuwiki_url, - githash=tardis_githash[:7], - name=name) + self.plot_html.append( + extras.html( + thumbnail_html_remote.format( + dokuwiki_url=self.dokuwiki_url, + githash=tardis_githash[:7], + name=name, + ) ) ) plot_file.close() @@ -117,7 +131,9 @@ def upload(self, report): class LocalPlotSaver(BasePlotSaver): def __init__(self, request, assets_dirpath): - super(LocalPlotSaver, self).__init__(request, assets_dirpath=assets_dirpath) + super(LocalPlotSaver, self).__init__( + request, assets_dirpath=assets_dirpath + ) def upload(self, report): """Save content of ``self._plots`` to ``self.assets_dirpath``. @@ -129,10 +145,12 @@ def upload(self, report): """ for plot, name in self._plots: - self.save(plot, os.path.join( - self.assets_dirpath, "{0}.png".format(name)), report + self.save( + plot, + os.path.join(self.assets_dirpath, "{0}.png".format(name)), + report, ) - self.plot_html.append(extras.html( - thumbnail_html_local.format(name=name)) + self.plot_html.append( + extras.html(thumbnail_html_local.format(name=name)) ) diff --git a/tardis/tests/integration_tests/report.py b/tardis/tests/integration_tests/report.py index 2edb443774c..fc5acc51a04 100644 --- a/tardis/tests/integration_tests/report.py +++ b/tardis/tests/integration_tests/report.py @@ -46,7 +46,6 @@ class DokuReport(HTMLReport): - def __init__(self, report_config): """ Initialization of a DokuReport object and registration as a plugin @@ -54,10 +53,11 @@ def __init__(self, report_config): password of dokuwiki is passed through `dokuwiki_details`. """ # This will be either "remote" or "local". - self.save_mode = report_config['save_mode'] + self.save_mode = report_config["save_mode"] if self.save_mode == "remote": import dokuwiki + # Base class accepts a file path to save the report, but we pass an # empty string as it is redundant for this use case. super(DokuReport, self).__init__( @@ -65,33 +65,37 @@ def __init__(self, report_config): ) # Upload the report on a dokuwiki instance. - dokuwiki_details = report_config['dokuwiki'] + dokuwiki_details = report_config["dokuwiki"] try: self.doku_conn = dokuwiki.DokuWiki( - url=dokuwiki_details['url'], - user=dokuwiki_details['username'], - password=dokuwiki_details['password']) + url=dokuwiki_details["url"], + user=dokuwiki_details["username"], + password=dokuwiki_details["password"], + ) except (TypeError, gaierror, dokuwiki.DokuWikiError) as e: raise e self.doku_conn = None self.dokuwiki_url = "" else: - self.dokuwiki_url = dokuwiki_details['url'] + self.dokuwiki_url = dokuwiki_details["url"] else: # Save the html report file locally. self.report_dirpath = os.path.join( - os.path.expandvars(os.path.expanduser(report_config['reportpath'])), - tardis_githash[:7] + os.path.expandvars( + os.path.expanduser(report_config["reportpath"]) + ), + tardis_githash[:7], ) if os.path.exists(self.report_dirpath): shutil.rmtree(self.report_dirpath) os.makedirs(self.report_dirpath) - os.makedirs(os.path.join(self.report_dirpath, 'assets')) + os.makedirs(os.path.join(self.report_dirpath, "assets")) super(DokuReport, self).__init__( logfile=os.path.join(self.report_dirpath, "report.html"), - self_contained=False, has_rerun=False + self_contained=False, + has_rerun=False, ) self.suite_start_time = time.time() @@ -114,10 +118,10 @@ def _generate_report(self, session): # Quick hack for preventing log to be placed in narrow left out space report_content = report_content.replace( - u'class="log"', u'class="log" style="clear: both"' + 'class="log"', 'class="log" style="clear: both"' ) # It was displayed raw on wiki pages, but not needed. - report_content = report_content.replace(u'', u'') + report_content = report_content.replace("", "") return report_content def _save_report(self, report_content): @@ -128,16 +132,19 @@ def _save_report(self, report_content): if self.save_mode == "remote": # Upload the report content to wiki try: - self.doku_conn.pages.set("reports:{0}".format( - tardis_githash[:7]), report_content) + self.doku_conn.pages.set( + "reports:{0}".format(tardis_githash[:7]), report_content + ) except (gaierror, TypeError): pass else: # Save the file locally at "self.logfile" path - with open(self.logfile, 'w') as f: + with open(self.logfile, "w") as f: f.write(report_content) - with open(os.path.join(self.report_dirpath, 'assets', 'style.css'), 'w') as f: + with open( + os.path.join(self.report_dirpath, "assets", "style.css"), "w" + ) as f: f.write(self.style_css) def _wiki_overview_entry(self): @@ -150,7 +157,9 @@ def _wiki_overview_entry(self): else: status = "Errored" - suite_start_datetime = datetime.datetime.utcfromtimestamp(self.suite_start_time) + suite_start_datetime = datetime.datetime.utcfromtimestamp( + self.suite_start_time + ) # Fetch commit message from github. gh_request = requests.get( @@ -160,7 +169,7 @@ def _wiki_overview_entry(self): ) gh_commit_data = json.loads(gh_request.content) # Pick only first line of commit message - gh_commit_message = gh_commit_data['message'].split('\n')[0] + gh_commit_message = gh_commit_data["message"].split("\n")[0] # Truncate long commit messages if len(gh_commit_message) > 60: @@ -173,13 +182,15 @@ def _wiki_overview_entry(self): tardis_githash, gh_commit_message ) # Append start time - row += "{0} | ".format(suite_start_datetime.strftime('%d %b %H:%M:%S')) + row += "{0} | ".format( + suite_start_datetime.strftime("%d %b %H:%M:%S") + ) # Append time elapsed row += "{0:.2f} sec | ".format(self.suite_time_delta) # Append status row += "{0} |\n".format(status) try: - self.doku_conn.pages.append('/', row) + self.doku_conn.pages.append("/", row) except (gaierror, TypeError): pass @@ -204,7 +215,8 @@ def pytest_terminal_summary(self, terminalreporter): if self.save_mode == "remote": try: uploaded_report = self.doku_conn.pages.get( - "reports:{0}".format(tardis_githash[:7])) + "reports:{0}".format(tardis_githash[:7]) + ) except (gaierror, TypeError): uploaded_report = "" @@ -213,12 +225,17 @@ def pytest_terminal_summary(self, terminalreporter): "-", "Successfully uploaded report to Dokuwiki" ) terminalreporter.write_sep( - "-", "URL: {0}doku.php?id=reports:{1}".format( - self.dokuwiki_url, tardis_githash[:7]) + "-", + "URL: {0}doku.php?id=reports:{1}".format( + self.dokuwiki_url, tardis_githash[:7] + ), ) else: terminalreporter.write_sep( - "-", "Connection not established, upload failed.") + "-", "Connection not established, upload failed." + ) else: if os.path.exists(self.logfile): - super(DokuReport, self).pytest_terminal_summary(terminalreporter) + super(DokuReport, self).pytest_terminal_summary( + terminalreporter + ) diff --git a/tardis/tests/integration_tests/runner.py b/tardis/tests/integration_tests/runner.py index ea9392f10fc..51d7f820385 100644 --- a/tardis/tests/integration_tests/runner.py +++ b/tardis/tests/integration_tests/runner.py @@ -13,29 +13,44 @@ logger = logging.getLogger(__name__) parser = argparse.ArgumentParser(description="Run slow integration tests") -parser.add_argument("--integration-tests", dest="yaml_filepath", - help="Path to YAML config file for integration tests.") -parser.add_argument("--tardis-refdata", dest="tardis_refdata", - help="Path to Tardis Reference Data.") -parser.add_argument("--less-packets", action="store_true", default=False, - help="Run integration tests with less packets.") +parser.add_argument( + "--integration-tests", + dest="yaml_filepath", + help="Path to YAML config file for integration tests.", +) +parser.add_argument( + "--tardis-refdata", + dest="tardis_refdata", + help="Path to Tardis Reference Data.", +) +parser.add_argument( + "--less-packets", + action="store_true", + default=False, + help="Run integration tests with less packets.", +) def run_tests(): args = parser.parse_args() - integration_tests_config = yaml.load(open(args.yaml_filepath), Loader=yaml.CLoader) + integration_tests_config = yaml.load( + open(args.yaml_filepath), Loader=yaml.CLoader + ) doku_conn = dokuwiki.DokuWiki( - url=integration_tests_config['dokuwiki']['url'], - user=integration_tests_config['dokuwiki']['username'], - password=integration_tests_config['dokuwiki']['password'] + url=integration_tests_config["dokuwiki"]["url"], + user=integration_tests_config["dokuwiki"]["username"], + password=integration_tests_config["dokuwiki"]["password"], ) less_packets = "--less-packets" if args.less_packets else "" test_command = [ - "python", "setup.py", "test", - "--test-path=tardis/tests/integration_tests/test_integration.py", "--args", + "python", + "setup.py", + "test", + "--test-path=tardis/tests/integration_tests/test_integration.py", + "--args", "--capture=no --integration-tests={0} --tardis-refdata={1} --remote-data " - "{2}".format(args.yaml_filepath, args.tardis_refdata, less_packets) + "{2}".format(args.yaml_filepath, args.tardis_refdata, less_packets), ] subprocess.call(test_command) @@ -45,7 +60,7 @@ def run_tests(): "https://api.github.com/repos/tardis-sn/tardis/branches/master" ) gh_master_head_data = json.loads(gh_request.content) - gh_tardis_githash = gh_master_head_data['commit']['sha'][:7] + gh_tardis_githash = gh_master_head_data["commit"]["sha"][:7] # Check whether a report of this githash is uploaded on dokuwiki. # If not, then this is a new commit and tests should be executed. @@ -56,13 +71,20 @@ def run_tests(): # If dokuwiki returns empty string, then it means that report has not # been created yet. if len(dokuwiki_report) == 0: - subprocess.call([ - "git", "pull", "https://www.github.com/tardis-sn/tardis", "master" - ]) + subprocess.call( + [ + "git", + "pull", + "https://www.github.com/tardis-sn/tardis", + "master", + ] + ) subprocess.call(test_command) else: checked = datetime.datetime.now() - logger.info("Up-to-date. Checked on {0} {1}".format( - checked.strftime("%d-%b-%Y"), checked.strftime("%H:%M:%S") - )) + logger.info( + "Up-to-date. Checked on {0} {1}".format( + checked.strftime("%d-%b-%Y"), checked.strftime("%H:%M:%S") + ) + ) time.sleep(600) diff --git a/tardis/tests/integration_tests/test_integration.py b/tardis/tests/integration_tests/test_integration.py index f20c9a1052e..9c21a7c30a9 100644 --- a/tardis/tests/integration_tests/test_integration.py +++ b/tardis/tests/integration_tests/test_integration.py @@ -3,6 +3,7 @@ import yaml import pytest import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt from numpy.testing import assert_allclose @@ -12,45 +13,50 @@ from tardis.io.config_reader import Configuration quantity_comparison = [ - ('/simulation/runner/last_line_interaction_in_id', - 'runner.last_line_interaction_in_id'), - ('/simulation/runner/last_line_interaction_out_id', - 'runner.last_line_interaction_out_id'), - ('/simulation/runner/last_line_interaction_shell_id', - 'runner.last_line_interaction_shell_id'), - ('/simulation/plasma/j_blues', - 'plasma.j_blues'), - ('/simulation/plasma/j_blue_estimator', - 'plasma.j_blue_estimator'), - ('/simulation/runner/packet_luminosity', - 'runner.packet_luminosity.cgs.value'), - ('/simulation/runner/montecarlo_virtual_luminosity', - 'runner.montecarlo_virtual_luminosity.cgs.value'), - ('/simulation/runner/output_nu', - 'runner.output_nu.cgs.value'), - ('/simulation/plasma/ion_number_density', - 'plasma.ion_number_density'), - ('/simulation/plasma/level_number_density', - 'plasma.level_number_density'), - ('/simulation/plasma/electron_densities', - 'plasma.electron_densities'), - ('/simulation/plasma/tau_sobolevs', - 'plasma.tau_sobolevs'), - ('/simulation/plasma/transition_probabilities', - 'plasma.transition_probabilities'), - ('/simulation/model/t_radiative', - 'model.t_radiative.cgs.value'), - ('/simulation/model/w', - 'model.w'), - ('/simulation/runner/j_estimator', - 'runner.j_estimator'), - ('/simulation/runner/nu_bar_estimator', - 'runner.nu_bar_estimator'), - ('/simulation/plasma/j_blues_norm_factor', - 'plasma.j_blues_norm_factor.cgs.value'), - ('/simulation/plasma/luminosity_inner', - 'plasma.luminosity_inner.cgs.value'), - ] + ( + "/simulation/runner/last_line_interaction_in_id", + "runner.last_line_interaction_in_id", + ), + ( + "/simulation/runner/last_line_interaction_out_id", + "runner.last_line_interaction_out_id", + ), + ( + "/simulation/runner/last_line_interaction_shell_id", + "runner.last_line_interaction_shell_id", + ), + ("/simulation/plasma/j_blues", "plasma.j_blues"), + ("/simulation/plasma/j_blue_estimator", "plasma.j_blue_estimator"), + ( + "/simulation/runner/packet_luminosity", + "runner.packet_luminosity.cgs.value", + ), + ( + "/simulation/runner/montecarlo_virtual_luminosity", + "runner.montecarlo_virtual_luminosity.cgs.value", + ), + ("/simulation/runner/output_nu", "runner.output_nu.cgs.value"), + ("/simulation/plasma/ion_number_density", "plasma.ion_number_density"), + ("/simulation/plasma/level_number_density", "plasma.level_number_density"), + ("/simulation/plasma/electron_densities", "plasma.electron_densities"), + ("/simulation/plasma/tau_sobolevs", "plasma.tau_sobolevs"), + ( + "/simulation/plasma/transition_probabilities", + "plasma.transition_probabilities", + ), + ("/simulation/model/t_radiative", "model.t_radiative.cgs.value"), + ("/simulation/model/w", "model.w"), + ("/simulation/runner/j_estimator", "runner.j_estimator"), + ("/simulation/runner/nu_bar_estimator", "runner.nu_bar_estimator"), + ( + "/simulation/plasma/j_blues_norm_factor", + "plasma.j_blues_norm_factor.cgs.value", + ), + ( + "/simulation/plasma/luminosity_inner", + "plasma.luminosity_inner.cgs.value", + ), +] @pytest.fixture(params=quantity_comparison) @@ -58,8 +64,10 @@ def model_quantities(request): return request.param -@pytest.mark.skipif('not config.getvalue("integration-tests")', - reason="integration tests are not included in this run") +@pytest.mark.skipif( + 'not config.getvalue("integration-tests")', + reason="integration tests are not included in this run", +) @pytest.mark.integration class TestIntegration(object): """Slow integration test for various setups present in subdirectories of @@ -74,22 +82,25 @@ def setup(self, request, reference, data_path, pytestconfig): a single run of integration test. """ # Get capture manager - capmanager = pytestconfig.pluginmanager.getplugin('capturemanager') + capmanager = pytestconfig.pluginmanager.getplugin("capturemanager") # The last component in dirpath can be extracted as name of setup. - self.name = data_path['setup_name'] + self.name = data_path["setup_name"] - self.config_file = os.path.join(data_path['config_dirpath'], "config.yml") + self.config_file = os.path.join( + data_path["config_dirpath"], "config.yml" + ) # A quick hack to use atom data per setup. Atom data is ingested from # local HDF or downloaded and cached from a url, depending on data_path # keys. - atom_data_name = yaml.load( - open(self.config_file), Loader=yaml.CLoader)['atom_data'] + atom_data_name = yaml.load(open(self.config_file), Loader=yaml.CLoader)[ + "atom_data" + ] # Get the path to HDF file: atom_data_filepath = os.path.join( - data_path['atom_data_path'], atom_data_name + data_path["atom_data_path"], atom_data_name ) # Load atom data file separately, pass it for forming tardis config. @@ -106,20 +117,20 @@ def setup(self, request, reference, data_path, pytestconfig): # Check whether current run is with less packets. if request.config.getoption("--less-packets"): - less_packets = request.config.integration_tests_config['less_packets'] - tardis_config['montecarlo']['no_of_packets'] = ( - less_packets['no_of_packets'] - ) - tardis_config['montecarlo']['last_no_of_packets'] = ( - less_packets['last_no_of_packets'] - ) - - - + less_packets = request.config.integration_tests_config[ + "less_packets" + ] + tardis_config["montecarlo"]["no_of_packets"] = less_packets[ + "no_of_packets" + ] + tardis_config["montecarlo"]["last_no_of_packets"] = less_packets[ + "last_no_of_packets" + ] # We now do a run with prepared config and get the simulation object. - self.result = Simulation.from_config(tardis_config, - atom_data=self.atom_data) + self.result = Simulation.from_config( + tardis_config, atom_data=self.atom_data + ) capmanager.suspend_global_capture(True) # If current test run is just for collecting reference data, store the @@ -129,16 +140,19 @@ def setup(self, request, reference, data_path, pytestconfig): self.result.run() if request.config.getoption("--generate-reference"): ref_data_path = os.path.join( - data_path['reference_path'], "{0}.h5".format(self.name) + data_path["reference_path"], "{0}.h5".format(self.name) ) if os.path.exists(ref_data_path): pytest.skip( - 'Reference data {0} does exist and tests will not ' - 'proceed generating new data'.format(ref_data_path)) + "Reference data {0} does exist and tests will not " + "proceed generating new data".format(ref_data_path) + ) self.result.to_hdf(file_path=ref_data_path) - pytest.skip("Reference data saved at {0}".format( - data_path['reference_path'] - )) + pytest.skip( + "Reference data saved at {0}".format( + data_path["reference_path"] + ) + ) capmanager.resume_global_capture() # Get the reference data through the fixture. @@ -147,10 +161,11 @@ def setup(self, request, reference, data_path, pytestconfig): def test_model_quantities(self, model_quantities): reference_quantity_name, tardis_quantity_name = model_quantities if reference_quantity_name not in self.reference: - pytest.skip('{0} not calculated in this run'.format( - reference_quantity_name)) + pytest.skip( + "{0} not calculated in this run".format(reference_quantity_name) + ) reference_quantity = self.reference[reference_quantity_name] - tardis_quantity = eval('self.result.' + tardis_quantity_name) + tardis_quantity = eval("self.result." + tardis_quantity_name) assert_allclose(tardis_quantity, reference_quantity) def plot_t_rad(self): @@ -162,17 +177,28 @@ def plot_t_rad(self): ax.set_ylabel("t_rad") result_line = ax.plot( - self.result.model.t_rad.cgs, color="blue", marker=".", label="Result" + self.result.model.t_rad.cgs, + color="blue", + marker=".", + label="Result", ) reference_line = ax.plot( - self.reference['/simulation/model/t_rad'], - color="green", marker=".", label="Reference" + self.reference["/simulation/model/t_rad"], + color="green", + marker=".", + label="Reference", ) error_ax = ax.twinx() error_line = error_ax.plot( - (1 - self.result.model.t_rad.cgs.value / self.reference['/simulation/model/t_rad']), - color="red", marker=".", label="Rel. Error" + ( + 1 + - self.result.model.t_rad.cgs.value + / self.reference["/simulation/model/t_rad"] + ), + color="red", + marker=".", + label="Rel. Error", ) error_ax.set_ylabel("Relative error (1 - result / reference)") @@ -182,21 +208,25 @@ def plot_t_rad(self): ax.legend(lines, labels, loc="lower left") return figure - def test_spectrum(self, plot_object): plot_object.add(self.plot_spectrum(), "{0}_spectrum".format(self.name)) assert_allclose( - self.reference['/simulation/runner/spectrum/luminosity_density_nu'], - self.result.runner.spectrum.luminosity_density_nu.cgs.value) + self.reference["/simulation/runner/spectrum/luminosity_density_nu"], + self.result.runner.spectrum.luminosity_density_nu.cgs.value, + ) assert_allclose( - self.reference['/simulation/runner/spectrum/wavelength'], - self.result.runner.spectrum.wavelength.cgs.value) + self.reference["/simulation/runner/spectrum/wavelength"], + self.result.runner.spectrum.wavelength.cgs.value, + ) assert_allclose( - self.reference['/simulation/runner/spectrum/luminosity_density_lambda'], - self.result.runner.spectrum.luminosity_density_lambda.cgs.value) + self.reference[ + "/simulation/runner/spectrum/luminosity_density_lambda" + ], + self.result.runner.spectrum.luminosity_density_lambda.cgs.value, + ) def plot_spectrum(self): @@ -209,29 +239,32 @@ def plot_spectrum(self): spectrum_ax.set_ylabel("Flux [cgs]") deviation = 1 - ( - self.result.runner.spectrum.luminosity_density_lambda.cgs.value / - self.reference[ - '/simulation/runner/spectrum/luminosity_density_lambda'] - + self.result.runner.spectrum.luminosity_density_lambda.cgs.value + / self.reference[ + "/simulation/runner/spectrum/luminosity_density_lambda" + ] ) - spectrum_ax.plot( - self.reference['/simulation/runner/spectrum/wavelength'], + self.reference["/simulation/runner/spectrum/wavelength"], self.reference[ - '/simulation/runner/spectrum/luminosity_density_lambda'], - color="black" + "/simulation/runner/spectrum/luminosity_density_lambda" + ], + color="black", ) spectrum_ax.plot( - self.reference['/simulation/runner/spectrum/wavelength'], + self.reference["/simulation/runner/spectrum/wavelength"], self.result.runner.spectrum.luminosity_density_lambda.cgs.value, - color="red" + color="red", ) spectrum_ax.set_xticks([]) deviation_ax = plt.subplot(gs[1]) - deviation_ax.plot(self.reference['/simulation/runner/spectrum/wavelength'], - deviation, color='black') + deviation_ax.plot( + self.reference["/simulation/runner/spectrum/wavelength"], + deviation, + color="black", + ) deviation_ax.set_xlabel("Wavelength [Angstrom]") - return plt.gcf() \ No newline at end of file + return plt.gcf() diff --git a/tardis/tests/setup_package.py b/tardis/tests/setup_package.py index 29a9bf65331..113c3f617f9 100644 --- a/tardis/tests/setup_package.py +++ b/tardis/tests/setup_package.py @@ -1,6 +1,12 @@ def get_package_data(): return { - _ASTROPY_PACKAGE_NAME_ + '.tests': ['coveragerc', 'data/*.h5', - 'data/*.dat', 'data/*.npy', - 'integration_tests/*/*.yml', - 'integration_tests/*/*.dat']} + _ASTROPY_PACKAGE_NAME_ + + ".tests": [ + "coveragerc", + "data/*.h5", + "data/*.dat", + "data/*.npy", + "integration_tests/*/*.yml", + "integration_tests/*/*.dat", + ] + } diff --git a/tardis/tests/test_montecarlo.py b/tardis/tests/test_montecarlo.py deleted file mode 100644 index 65830b57497..00000000000 --- a/tardis/tests/test_montecarlo.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy as np -from tardis import montecarlo -import pytest - -test_line_list = np.array([10, 9, 8, 7, 6, 5, 5, 4, 3, 2, 1]).astype(np.float64) - -# @pytest.mark.parametrize(("insert_value", "expected_insert_position"), [ -# (9.5, 0), -# (8.5, 1), -# (7.5, 2), -# (6.5, 3), -# (5.5, 4), -# (5.2, 4), -# (4.5, 6), -# (3.5, 7), -# (2.5, 8), -# (1.5, 9)]) -# def test_binary_search(insert_value, expected_insert_position): -# insert_position = montecarlo.binary_search_wrapper(test_line_list, insert_value, 0, len(test_line_list)) -# assert insert_position == expected_insert_position - - -# @pytest.mark.parametrize(("insert_value"), [ -# (10.5), -# (0.5)]) -# def test_binary_search_out_of_bounds(insert_value, capsys): -# with pytest.raises(ValueError): -# insert_position = montecarlo.binary_search_wrapper(test_line_list, insert_value, 0, len(test_line_list)-1) - -# @pytest.mark.parametrize(("insert_value", "expected_insert_position"), [ -# (10.5, 0), -# (0.5, len(test_line_list))]) -# def test_line_search_out_of_bounds(insert_value, expected_insert_position): -# insert_position = montecarlo.line_search_wrapper(test_line_list, -# insert_value, len(test_line_list)) -# assert insert_position == expected_insert_position - -# def test_compute_distance2outer(): -# assert montecarlo.compute_distance2outer_wrapper(0.0, 0.5, 1.0) == 1.0 -# assert montecarlo.compute_distance2outer_wrapper(1.0, 0.5, 1.0) == 0.0 -# assert montecarlo.compute_distance2outer_wrapper(0.3, 1.0, 1.0) == 0.7 -# assert montecarlo.compute_distance2outer_wrapper(0.3, -1.0, 1.0) == 1.3 -# assert montecarlo.compute_distance2outer_wrapper(0.5, 0.0, 1.0) == np.sqrt(0.75) - -# def test_compute_distance2inner(): -# assert montecarlo.compute_distance2inner_wrapper(1.5, -1.0, 1.0) == 0.5 -# assert montecarlo.compute_distance2inner_wrapper(0.0, 0.0, 0.0) == montecarlo.miss_distance -# assert montecarlo.compute_distance2inner_wrapper(1.2, -0.7, 1.0) == 0.3246360509309949 - -# def test_compute_distance2line(): -# assert montecarlo.compute_distance2line_wrapper(2.20866912e+15, -0.251699059004, 1.05581082105e+15, 1.06020910733e+15, 1693440.0, 5.90513983371e-07, 1.0602263591e+15, 1.06011723237e+15, 2) == 344430881691490.5 -# assert montecarlo.compute_distance2line_wrapper(2.23434667994e+15, -0.291130548401, 1.05581082105e+15, 1.06733618121e+15, 1693440.0, 5.90513983371e-07, 1.06738407486e+15, 1.06732933961e+15, 3) == 96296282395637.2 -# with pytest.raises(RuntimeError): -# montecarlo.compute_distance2line_wrapper(1.0, 1.0, 1.0, 10.0, 15.0, 1.0 / 15.0, 0.0, 0.0, 0) - -# def test_compute_distance2electron(): -# assert montecarlo.compute_distance2electron_wrapper(0.0, 0.0, 2.0, 2.0) == 4.0 - diff --git a/tardis/tests/test_tardis_full.py b/tardis/tests/test_tardis_full.py index 682e4645728..825909e165c 100644 --- a/tardis/tests/test_tardis_full.py +++ b/tardis/tests/test_tardis_full.py @@ -9,19 +9,19 @@ from tardis.io.config_reader import Configuration -class TestRunnerSimple(): +class TestRunnerSimple: """ Very simple run """ - name = 'test_runner_simple' + + name = "test_runner_simple" @pytest.fixture(scope="class") - def runner( - self, atomic_data_fname, - tardis_ref_data, generate_reference): + def runner(self, atomic_data_fname, tardis_ref_data, generate_reference): config = Configuration.from_yaml( - 'tardis/io/tests/data/tardis_configv1_verysimple.yml') - config['atom_data'] = atomic_data_fname + "tardis/io/tests/data/tardis_configv1_verysimple.yml" + ) + config["atom_data"] = atomic_data_fname simulation = Simulation.from_config(config) simulation.run() @@ -30,45 +30,36 @@ def runner( return simulation.runner else: simulation.runner.hdf_properties = [ - 'j_blue_estimator', - 'spectrum', - 'spectrum_virtual' - ] - simulation.runner.to_hdf( - tardis_ref_data, - '', - self.name) - pytest.skip( - 'Reference data was generated during this run.') - - @pytest.fixture(scope='class') + "j_blue_estimator", + "spectrum", + "spectrum_virtual", + ] + simulation.runner.to_hdf(tardis_ref_data, "", self.name) + pytest.skip("Reference data was generated during this run.") + + @pytest.fixture(scope="class") def refdata(self, tardis_ref_data): def get_ref_data(key): - return tardis_ref_data[os.path.join( - self.name, key)] + return tardis_ref_data[os.path.join(self.name, key)] + return get_ref_data def test_j_blue_estimators(self, runner, refdata): - j_blue_estimator = refdata('j_blue_estimator').values + j_blue_estimator = refdata("j_blue_estimator").values - npt.assert_allclose( - runner.j_blue_estimator, - j_blue_estimator) + npt.assert_allclose(runner.j_blue_estimator, j_blue_estimator) def test_spectrum(self, runner, refdata): - luminosity = u.Quantity(refdata('spectrum/luminosity'), 'erg /s') + luminosity = u.Quantity(refdata("spectrum/luminosity"), "erg /s") - assert_quantity_allclose( - runner.spectrum.luminosity, - luminosity) + assert_quantity_allclose(runner.spectrum.luminosity, luminosity) def test_virtual_spectrum(self, runner, refdata): luminosity = u.Quantity( - refdata('spectrum_virtual/luminosity'), 'erg /s') + refdata("spectrum_virtual/luminosity"), "erg /s" + ) - assert_quantity_allclose( - runner.spectrum_virtual.luminosity, - luminosity) + assert_quantity_allclose(runner.spectrum_virtual.luminosity, luminosity) def test_runner_properties(self, runner): """Tests whether a number of runner attributes exist and also verifies @@ -81,14 +72,16 @@ def test_runner_properties(self, runner): virt_type = np.ndarray - props_required_by_modeltohdf5 = dict([ + props_required_by_modeltohdf5 = dict( + [ ("virt_packet_last_interaction_type", virt_type), ("virt_packet_last_line_interaction_in_id", virt_type), ("virt_packet_last_line_interaction_out_id", virt_type), ("virt_packet_last_interaction_in_nu", virt_type), ("virt_packet_nus", virt_type), ("virt_packet_energies", virt_type), - ]) + ] + ) required_props = props_required_by_modeltohdf5.copy() @@ -96,5 +89,5 @@ def test_runner_properties(self, runner): actual = getattr(runner, prop) assert type(actual) == prop_type, ( "wrong type of attribute '{}':" - "expected {}, found {}".format( - prop, prop_type, type(actual))) + "expected {}, found {}".format(prop, prop_type, type(actual)) + ) diff --git a/tardis/tests/test_tardis_full_formal_integral.py b/tardis/tests/test_tardis_full_formal_integral.py index d3de68501c4..53a7343e7ea 100644 --- a/tardis/tests/test_tardis_full_formal_integral.py +++ b/tardis/tests/test_tardis_full_formal_integral.py @@ -9,18 +9,19 @@ from tardis.io.config_reader import Configuration import astropy -config_line_modes = ['downbranch', 'macroatom'] +config_line_modes = ["downbranch", "macroatom"] interpolate_shells = [-1, 30] -@pytest.fixture(scope='module', params=config_line_modes) +@pytest.fixture(scope="module", params=config_line_modes) def base_config(request): config = Configuration.from_yaml( - 'tardis/io/tests/data/tardis_configv1_verysimple.yml') + "tardis/io/tests/data/tardis_configv1_verysimple.yml" + ) config["plasma"]["line_interaction_type"] = request.param - config["montecarlo"]["no_of_packets"] = 4.0e+4 - config["montecarlo"]["last_no_of_packets"] = 1.0e+5 + config["montecarlo"]["no_of_packets"] = 4.0e4 + config["montecarlo"]["last_no_of_packets"] = 1.0e5 config["montecarlo"]["no_of_virtual_packets"] = 0 config["spectrum"]["method"] = "integrated" config["spectrum"]["integrated"]["points"] = 200 @@ -28,27 +29,31 @@ def base_config(request): return config -@pytest.fixture(scope='module', params=interpolate_shells) + +@pytest.fixture(scope="module", params=interpolate_shells) def config(base_config, request): base_config["spectrum"]["integrated"]["interpolate_shells"] = request.param return base_config -class TestRunnerSimpleFormalInegral(): + +class TestRunnerSimpleFormalInegral: """ Very simple run with the formal integral spectral synthesis method """ - _name = 'test_runner_simple_integral' + + _name = "test_runner_simple_integral" @pytest.fixture(scope="class") def runner( - self, config, atomic_data_fname, - tardis_ref_data, generate_reference): + self, config, atomic_data_fname, tardis_ref_data, generate_reference + ): config.atom_data = atomic_data_fname - self.name = (self._name + - "_{:s}".format(config.plasma.line_interaction_type)) + self.name = self._name + "_{:s}".format( + config.plasma.line_interaction_type + ) if config.spectrum.integrated.interpolate_shells > 0: - self.name += '_interp' + self.name += "_interp" simulation = Simulation.from_config(config) simulation.run() @@ -57,43 +62,40 @@ def runner( return simulation.runner else: simulation.runner.hdf_properties = [ - 'j_blue_estimator', - 'spectrum', - 'spectrum_integrated' - ] - simulation.runner.to_hdf( - tardis_ref_data, - '', - self.name) - pytest.skip( - 'Reference data was generated during this run.') - - @pytest.fixture(scope='class') + "j_blue_estimator", + "spectrum", + "spectrum_integrated", + ] + simulation.runner.to_hdf(tardis_ref_data, "", self.name) + pytest.skip("Reference data was generated during this run.") + + @pytest.fixture(scope="class") def refdata(self, tardis_ref_data): def get_ref_data(key): - return tardis_ref_data[os.path.join( - self.name, key)] + return tardis_ref_data[os.path.join(self.name, key)] + return get_ref_data def test_j_blue_estimators(self, runner, refdata): - j_blue_estimator = refdata('j_blue_estimator').values + j_blue_estimator = refdata("j_blue_estimator").values - npt.assert_allclose( - runner.j_blue_estimator, - j_blue_estimator) + npt.assert_allclose(runner.j_blue_estimator, j_blue_estimator) def test_spectrum(self, runner, refdata): - luminosity = u.Quantity(refdata('spectrum/luminosity'), 'erg /s') + luminosity = u.Quantity(refdata("spectrum/luminosity"), "erg /s") - assert_quantity_allclose( - runner.spectrum.luminosity, - luminosity) + assert_quantity_allclose(runner.spectrum.luminosity, luminosity) def test_spectrum_integrated(self, runner, refdata): luminosity = u.Quantity( - refdata('spectrum_integrated/luminosity'), 'erg /s') + refdata("spectrum_integrated/luminosity"), "erg /s" + ) - print("actual, desired: ", luminosity, runner.spectrum_integrated.luminosity) - assert_quantity_allclose( + print( + "actual, desired: ", + luminosity, runner.spectrum_integrated.luminosity, - luminosity) + ) + assert_quantity_allclose( + runner.spectrum_integrated.luminosity, luminosity + ) diff --git a/tardis/tests/test_util.py b/tardis/tests/test_util.py index 0e6bbd0d44f..820bd6ef5dd 100644 --- a/tardis/tests/test_util.py +++ b/tardis/tests/test_util.py @@ -5,45 +5,72 @@ from astropy import units as u from io import StringIO -from tardis.util.base import MalformedSpeciesError, MalformedElementSymbolError, MalformedQuantityError, int_to_roman, \ - roman_to_int, calculate_luminosity, create_synpp_yaml, intensity_black_body, \ - species_tuple_to_string, species_string_to_tuple, parse_quantity, element_symbol2atomic_number, \ - atomic_number2element_symbol, reformat_element_symbol, quantity_linspace, convert_abundances_format - -data_path = os.path.join('tardis', 'io', 'tests', 'data') +from tardis.util.base import ( + MalformedSpeciesError, + MalformedElementSymbolError, + MalformedQuantityError, + int_to_roman, + roman_to_int, + calculate_luminosity, + create_synpp_yaml, + intensity_black_body, + species_tuple_to_string, + species_string_to_tuple, + parse_quantity, + element_symbol2atomic_number, + atomic_number2element_symbol, + reformat_element_symbol, + quantity_linspace, + convert_abundances_format, +) + +data_path = os.path.join("tardis", "io", "tests", "data") @pytest.fixture def artis_abundances_fname(): - return os.path.join(data_path, 'artis_abundances.dat') + return os.path.join(data_path, "artis_abundances.dat") + def test_malformed_species_error(): - malformed_species_error = MalformedSpeciesError('He') - assert malformed_species_error.malformed_element_symbol == 'He' - assert str(malformed_species_error) == 'Expecting a species notation (e.g. "Si 2", "Si II", "Fe IV") - supplied He' + malformed_species_error = MalformedSpeciesError("He") + assert malformed_species_error.malformed_element_symbol == "He" + assert ( + str(malformed_species_error) + == 'Expecting a species notation (e.g. "Si 2", "Si II", "Fe IV") - supplied He' + ) def test_malformed_elements_symbol_error(): - malformed_elements_symbol_error = MalformedElementSymbolError('Hx') - assert malformed_elements_symbol_error.malformed_element_symbol == 'Hx' - assert str(malformed_elements_symbol_error) == 'Expecting an atomic symbol (e.g. Fe) - supplied Hx' + malformed_elements_symbol_error = MalformedElementSymbolError("Hx") + assert malformed_elements_symbol_error.malformed_element_symbol == "Hx" + assert ( + str(malformed_elements_symbol_error) + == "Expecting an atomic symbol (e.g. Fe) - supplied Hx" + ) def test_malformed_quantity_error(): - malformed_quantity_error = MalformedQuantityError('abcd') - assert malformed_quantity_error.malformed_quantity_string == 'abcd' - assert str(malformed_quantity_error) == 'Expecting a quantity string(e.g. "5 km/s") for keyword - supplied abcd' - - -@pytest.mark.parametrize(['test_input', 'expected_result'], [ - (1, 'I'), - (5, 'V'), - (19, 'XIX'), - (556, 'DLVI'), - (1400, 'MCD'), - (1999, 'MCMXCIX'), - (3000, 'MMM') -]) + malformed_quantity_error = MalformedQuantityError("abcd") + assert malformed_quantity_error.malformed_quantity_string == "abcd" + assert ( + str(malformed_quantity_error) + == 'Expecting a quantity string(e.g. "5 km/s") for keyword - supplied abcd' + ) + + +@pytest.mark.parametrize( + ["test_input", "expected_result"], + [ + (1, "I"), + (5, "V"), + (19, "XIX"), + (556, "DLVI"), + (1400, "MCD"), + (1999, "MCMXCIX"), + (3000, "MMM"), + ], +) def test_int_to_roman(test_input, expected_result): assert int_to_roman(test_input) == expected_result @@ -51,15 +78,18 @@ def test_int_to_roman(test_input, expected_result): int_to_roman(1.5) -@pytest.mark.parametrize(['test_input', 'expected_result'], [ - ('I', 1), - ('V', 5), - ('XIX', 19), - ('DLVI', 556), - ('MCD', 1400), - ('MCMXCIX', 1999), - ('MMM', 3000) -]) +@pytest.mark.parametrize( + ["test_input", "expected_result"], + [ + ("I", 1), + ("V", 5), + ("XIX", 19), + ("DLVI", 556), + ("MCD", 1400), + ("MCMXCIX", 1999), + ("MMM", 3000), + ], +) def test_roman_to_int(test_input, expected_result): assert roman_to_int(test_input) == expected_result @@ -67,117 +97,147 @@ def test_roman_to_int(test_input, expected_result): roman_to_int(1) - -@pytest.mark.parametrize(['string_io', 'distance', 'result'], [ - (StringIO(u'4000 1e-21\n4500 3e-21\n5000 5e-21'), '100 km', (0.0037699111843077517, 4000.0, 5000.0)), - (StringIO(u'7600 2.4e-19\n7800 1.6e-19\n8100 9.1e-20'), '500 km', (2.439446695512474, 7600.0, 8100.0)) -]) +@pytest.mark.parametrize( + ["string_io", "distance", "result"], + [ + ( + StringIO("4000 1e-21\n4500 3e-21\n5000 5e-21"), + "100 km", + (0.0037699111843077517, 4000.0, 5000.0), + ), + ( + StringIO("7600 2.4e-19\n7800 1.6e-19\n8100 9.1e-20"), + "500 km", + (2.439446695512474, 7600.0, 8100.0), + ), + ], +) def test_calculate_luminosity(string_io, distance, result): assert calculate_luminosity(string_io, distance) == result -@pytest.mark.parametrize(['nu', 't', 'i'], [ - (10**6, 1000, 3.072357852080765e-22), - (10**6, 300, 9.21707305730458e-23), - (10**8, 1000, 6.1562660718558254e-24), - (10**8, 300, 1.846869480674048e-24), -]) +@pytest.mark.parametrize( + ["nu", "t", "i"], + [ + (10 ** 6, 1000, 3.072357852080765e-22), + (10 ** 6, 300, 9.21707305730458e-23), + (10 ** 8, 1000, 6.1562660718558254e-24), + (10 ** 8, 300, 1.846869480674048e-24), + ], +) def test_intensity_black_body(nu, t, i): assert np.isclose(intensity_black_body(nu, t), i) - -@pytest.mark.parametrize(['species_tuple', 'roman_numerals', 'species_string'], [ - ((14, 1), True, 'Si II'), - ((14, 1), False, 'Si 1'), - ((14, 3), True, 'Si IV'), - ((14, 3), False, 'Si 3'), - ((14, 8), True, 'Si IX'), - ((14, 8), False, 'Si 8'), -]) +@pytest.mark.parametrize( + ["species_tuple", "roman_numerals", "species_string"], + [ + ((14, 1), True, "Si II"), + ((14, 1), False, "Si 1"), + ((14, 3), True, "Si IV"), + ((14, 3), False, "Si 3"), + ((14, 8), True, "Si IX"), + ((14, 8), False, "Si 8"), + ], +) def test_species_tuple_to_string(species_tuple, roman_numerals, species_string): - assert species_tuple_to_string(species_tuple, roman_numerals=roman_numerals) == species_string + assert ( + species_tuple_to_string(species_tuple, roman_numerals=roman_numerals) + == species_string + ) -@pytest.mark.parametrize(['species_string', 'species_tuple'], [ - ('si ii', (14, 1)), - ('si 2', (14, 1)), - ('si ix', (14, 8)), -]) +@pytest.mark.parametrize( + ["species_string", "species_tuple"], + [("si ii", (14, 1)), ("si 2", (14, 1)), ("si ix", (14, 8)),], +) def test_species_string_to_tuple(species_string, species_tuple): assert species_string_to_tuple(species_string) == species_tuple with pytest.raises(MalformedSpeciesError): - species_string_to_tuple('II') + species_string_to_tuple("II") with pytest.raises(MalformedSpeciesError): - species_string_to_tuple('He Si') + species_string_to_tuple("He Si") with pytest.raises(ValueError): - species_string_to_tuple('He IX') + species_string_to_tuple("He IX") def test_parse_quantity(): - q1 = parse_quantity('5 km/s') - assert q1.value == 5. - assert q1.unit == u.Unit('km/s') + q1 = parse_quantity("5 km/s") + assert q1.value == 5.0 + assert q1.unit == u.Unit("km/s") with pytest.raises(MalformedQuantityError): parse_quantity(5) with pytest.raises(MalformedQuantityError): - parse_quantity('abcd') + parse_quantity("abcd") with pytest.raises(MalformedQuantityError): - parse_quantity('a abcd') + parse_quantity("a abcd") with pytest.raises(MalformedQuantityError): - parse_quantity('5 abcd') + parse_quantity("5 abcd") -@pytest.mark.parametrize(['element_symbol', 'atomic_number'], [ - ('sI', 14), - ('ca', 20), - ('Fe', 26) -]) +@pytest.mark.parametrize( + ["element_symbol", "atomic_number"], [("sI", 14), ("ca", 20), ("Fe", 26)] +) def test_element_symbol2atomic_number(element_symbol, atomic_number): assert element_symbol2atomic_number(element_symbol) == atomic_number with pytest.raises(MalformedElementSymbolError): - element_symbol2atomic_number('Hx') + element_symbol2atomic_number("Hx") def test_atomic_number2element_symbol(): - assert atomic_number2element_symbol(14) == 'Si' - - -@pytest.mark.parametrize(['unformatted_element_string', 'formatted_element_string'], [ - ('si', 'Si'), - ('sI', 'Si'), - ('Si', 'Si'), - ('c', 'C'), - ('C', 'C'), -]) -def test_reformat_element_symbol(unformatted_element_string, formatted_element_string): - assert reformat_element_symbol(unformatted_element_string) == formatted_element_string - - -@pytest.mark.parametrize(['start', 'stop', 'num', 'expected'], [ - (u.Quantity(1, 'km/s'), u.Quantity(5, 'km/s'), 5, u.Quantity(np.array([1., 2., 3., 4., 5.]), 'km/s')), - (u.Quantity(0.5, 'eV'), u.Quantity(0.6, 'eV'), 3, u.Quantity(np.array([0.5, 0.55, 0.6]), 'eV')) -]) + assert atomic_number2element_symbol(14) == "Si" + + +@pytest.mark.parametrize( + ["unformatted_element_string", "formatted_element_string"], + [("si", "Si"), ("sI", "Si"), ("Si", "Si"), ("c", "C"), ("C", "C"),], +) +def test_reformat_element_symbol( + unformatted_element_string, formatted_element_string +): + assert ( + reformat_element_symbol(unformatted_element_string) + == formatted_element_string + ) + + +@pytest.mark.parametrize( + ["start", "stop", "num", "expected"], + [ + ( + u.Quantity(1, "km/s"), + u.Quantity(5, "km/s"), + 5, + u.Quantity(np.array([1.0, 2.0, 3.0, 4.0, 5.0]), "km/s"), + ), + ( + u.Quantity(0.5, "eV"), + u.Quantity(0.6, "eV"), + 3, + u.Quantity(np.array([0.5, 0.55, 0.6]), "eV"), + ), + ], +) def test_quantity_linspace(start, stop, num, expected): obtained = quantity_linspace(start, stop, num) assert obtained.unit == expected.unit assert obtained.value.all() == expected.value.all() with pytest.raises(ValueError): - quantity_linspace(u.Quantity(0.5, 'eV'), '0.6 eV', 3) + quantity_linspace(u.Quantity(0.5, "eV"), "0.6 eV", 3) def test_convert_abundances_format(artis_abundances_fname): abundances = convert_abundances_format(artis_abundances_fname) - assert np.isclose(abundances.loc[3, 'O'], 1.240199e-08, atol=1.e-12) - assert np.isclose(abundances.loc[1, 'Co'], 2.306023e-05, atol=1.e-12) - assert np.isclose(abundances.loc[69, 'Ni'], 1.029928e-17, atol=1.e-12) - assert np.isclose(abundances.loc[2, 'C'], 4.425876e-09, atol=1.e-12) + assert np.isclose(abundances.loc[3, "O"], 1.240199e-08, atol=1.0e-12) + assert np.isclose(abundances.loc[1, "Co"], 2.306023e-05, atol=1.0e-12) + assert np.isclose(abundances.loc[69, "Ni"], 1.029928e-17, atol=1.0e-12) + assert np.isclose(abundances.loc[2, "C"], 4.425876e-09, atol=1.0e-12) diff --git a/tardis/util/__init__.py b/tardis/util/__init__.py index 0d1b60fa94c..6dd4694b3ef 100644 --- a/tardis/util/__init__.py +++ b/tardis/util/__init__.py @@ -1,3 +1 @@ # Utilities for TARDIS - - diff --git a/tardis/util/base.py b/tardis/util/base.py index 7777cf756e1..9f6bca3f3ab 100644 --- a/tardis/util/base.py +++ b/tardis/util/base.py @@ -23,52 +23,66 @@ logger = logging.getLogger(__name__) tardis_dir = os.path.realpath(tardis.__path__[0]) -ATOMIC_SYMBOLS_DATA = pd.read_csv(get_internal_data_path('atomic_symbols.dat'), delim_whitespace=True, - names=['atomic_number', 'symbol']).set_index('atomic_number').squeeze() +ATOMIC_SYMBOLS_DATA = ( + pd.read_csv( + get_internal_data_path("atomic_symbols.dat"), + delim_whitespace=True, + names=["atomic_number", "symbol"], + ) + .set_index("atomic_number") + .squeeze() +) ATOMIC_NUMBER2SYMBOL = OrderedDict(ATOMIC_SYMBOLS_DATA.to_dict()) -SYMBOL2ATOMIC_NUMBER = OrderedDict((y, x) for x, y in ATOMIC_NUMBER2SYMBOL.items()) +SYMBOL2ATOMIC_NUMBER = OrderedDict( + (y, x) for x, y in ATOMIC_NUMBER2SYMBOL.items() +) -synpp_default_yaml_fname = get_internal_data_path('synpp_default.yaml') +synpp_default_yaml_fname = get_internal_data_path("synpp_default.yaml") -NUMERAL_MAP = tuple(zip( - (1000, 900, 500, 400, 100, 90, 50, 40, 10, 9, 5, 4, 1), - ('M', 'CM', 'D', 'CD', 'C', 'XC', 'L', 'XL', 'X', 'IX', 'V', 'IV', 'I') -)) +NUMERAL_MAP = tuple( + zip( + (1000, 900, 500, 400, 100, 90, 50, 40, 10, 9, 5, 4, 1), + ("M", "CM", "D", "CD", "C", "XC", "L", "XL", "X", "IX", "V", "IV", "I"), + ) +) + class MalformedError(Exception): pass class MalformedSpeciesError(MalformedError): - def __init__(self, malformed_element_symbol): self.malformed_element_symbol = malformed_element_symbol def __str__(self): - return ('Expecting a species notation (e.g. "Si 2", "Si II", "Fe IV") ' - '- supplied {0}'.format(self.malformed_element_symbol)) + return ( + 'Expecting a species notation (e.g. "Si 2", "Si II", "Fe IV") ' + "- supplied {0}".format(self.malformed_element_symbol) + ) class MalformedElementSymbolError(MalformedError): - def __init__(self, malformed_element_symbol): self.malformed_element_symbol = malformed_element_symbol def __str__(self): - return ('Expecting an atomic symbol (e.g. Fe) - supplied {0}').format( - self.malformed_element_symbol) + return ("Expecting an atomic symbol (e.g. Fe) - supplied {0}").format( + self.malformed_element_symbol + ) class MalformedQuantityError(MalformedError): - def __init__(self, malformed_quantity_string): self.malformed_quantity_string = malformed_quantity_string def __str__(self): - return ('Expecting a quantity string(e.g. "5 km/s") for keyword ' - '- supplied {0}').format(self.malformed_quantity_string) + return ( + 'Expecting a quantity string(e.g. "5 km/s") for keyword ' + "- supplied {0}" + ).format(self.malformed_quantity_string) def int_to_roman(i): @@ -90,7 +104,8 @@ def int_to_roman(i): count = i // integer result.append(numeral * count) i -= integer * count - return ''.join(result) + return "".join(result) + def roman_to_int(roman_string): """ @@ -110,22 +125,29 @@ def roman_to_int(roman_string): NUMERALS_SET = set(list(zip(*NUMERAL_MAP))[1]) roman_string = roman_string.upper() if len(set(list(roman_string.upper())) - NUMERALS_SET) != 0: - raise ValueError('{0} does not seem to be a roman numeral'.format( - roman_string)) + raise ValueError( + "{0} does not seem to be a roman numeral".format(roman_string) + ) i = result = 0 for integer, numeral in NUMERAL_MAP: - while roman_string[i:i + len(numeral)] == numeral: + while roman_string[i : i + len(numeral)] == numeral: result += integer i += len(numeral) if result < 1: - raise ValueError('Can not interpret Roman Numeral {0}'.format(roman_string)) + raise ValueError( + "Can not interpret Roman Numeral {0}".format(roman_string) + ) return result def calculate_luminosity( - spec_fname, distance, wavelength_column=0, - wavelength_unit=u.angstrom, flux_column=1, - flux_unit=u.Unit('erg / (Angstrom cm2 s)')): + spec_fname, + distance, + wavelength_column=0, + wavelength_unit=u.angstrom, + flux_column=1, + flux_unit=u.Unit("erg / (Angstrom cm2 s)"), +): """ Calculates luminosity of star. @@ -153,13 +175,15 @@ def calculate_luminosity( wavelength.max() : float Maximum value of wavelength of light """ - #BAD STYLE change to parse quantity + # BAD STYLE change to parse quantity distance = u.Unit(distance) - wavelength, flux = np.loadtxt(spec_fname, usecols=(wavelength_column, flux_column), unpack=True) + wavelength, flux = np.loadtxt( + spec_fname, usecols=(wavelength_column, flux_column), unpack=True + ) flux_density = np.trapz(flux, wavelength) * (flux_unit * wavelength_unit) - luminosity = (flux_density * 4 * np.pi * distance**2).to('erg/s') + luminosity = (flux_density * 4 * np.pi * distance ** 2).to("erg/s") return luminosity.value, wavelength.min(), wavelength.max() @@ -184,66 +208,79 @@ def create_synpp_yaml(radial1d_mdl, fname, shell_no=0, lines_db=None): If the current dataset does not contain necessary reference files """ - logger.warning('Currently only works with Si and a special setup') + logger.warning("Currently only works with Si and a special setup") if radial1d_mdl.atom_data.synpp_refs is not None: raise ValueError( - 'The current atom dataset does not contain the ' - 'necessary reference files (please contact the authors)') + "The current atom dataset does not contain the " + "necessary reference files (please contact the authors)" + ) - radial1d_mdl.atom_data.synpp_refs['ref_log_tau'] = -99.0 + radial1d_mdl.atom_data.synpp_refs["ref_log_tau"] = -99.0 for key, value in radial1d_mdl.atom_data.synpp_refs.iterrows(): try: - radial1d_mdl.atom_data.synpp_refs['ref_log_tau'].loc[key] = np.log10( - radial1d_mdl.plasma.tau_sobolevs[0].loc[value['line_id']]) + radial1d_mdl.atom_data.synpp_refs["ref_log_tau"].loc[ + key + ] = np.log10( + radial1d_mdl.plasma.tau_sobolevs[0].loc[value["line_id"]] + ) except KeyError: pass - relevant_synpp_refs = radial1d_mdl.atom_data.synpp_refs[ - radial1d_mdl.atom_data.synpp_refs['ref_log_tau'] > -50] + radial1d_mdl.atom_data.synpp_refs["ref_log_tau"] > -50 + ] with open(synpp_default_yaml_fname) as stream: yaml_reference = yaml.load(stream, Loader=yaml.CLoader) if lines_db is not None: - yaml_reference['opacity']['line_dir'] = os.path.join(lines_db, 'lines') - yaml_reference['opacity']['line_dir'] = os.path.join(lines_db, 'refs.dat') - - yaml_reference['output']['min_wl'] = float( - radial1d_mdl.runner.spectrum.wavelength.to('angstrom').value.min()) - yaml_reference['output']['max_wl'] = float( - radial1d_mdl.runner.spectrum.wavelength.to('angstrom').value.max()) - - - #raise Exception("there's a problem here with units what units does synpp expect?") - yaml_reference['opacity']['v_ref'] = float( - (radial1d_mdl.tardis_config.structure.v_inner[0].to('km/s') / - (1000. * u.km / u.s)).value) - yaml_reference['grid']['v_outer_max'] = float( - (radial1d_mdl.tardis_config.structure.v_outer[-1].to('km/s') / - (1000. * u.km / u.s)).value) - - #pdb.set_trace() - - yaml_setup = yaml_reference['setups'][0] - yaml_setup['ions'] = [] - yaml_setup['log_tau'] = [] - yaml_setup['active'] = [] - yaml_setup['temp'] = [] - yaml_setup['v_min'] = [] - yaml_setup['v_max'] = [] - yaml_setup['aux'] = [] + yaml_reference["opacity"]["line_dir"] = os.path.join(lines_db, "lines") + yaml_reference["opacity"]["line_dir"] = os.path.join( + lines_db, "refs.dat" + ) + + yaml_reference["output"]["min_wl"] = float( + radial1d_mdl.runner.spectrum.wavelength.to("angstrom").value.min() + ) + yaml_reference["output"]["max_wl"] = float( + radial1d_mdl.runner.spectrum.wavelength.to("angstrom").value.max() + ) + + # raise Exception("there's a problem here with units what units does synpp expect?") + yaml_reference["opacity"]["v_ref"] = float( + ( + radial1d_mdl.tardis_config.structure.v_inner[0].to("km/s") + / (1000.0 * u.km / u.s) + ).value + ) + yaml_reference["grid"]["v_outer_max"] = float( + ( + radial1d_mdl.tardis_config.structure.v_outer[-1].to("km/s") + / (1000.0 * u.km / u.s) + ).value + ) + + # pdb.set_trace() + + yaml_setup = yaml_reference["setups"][0] + yaml_setup["ions"] = [] + yaml_setup["log_tau"] = [] + yaml_setup["active"] = [] + yaml_setup["temp"] = [] + yaml_setup["v_min"] = [] + yaml_setup["v_max"] = [] + yaml_setup["aux"] = [] for species, synpp_ref in relevant_synpp_refs.iterrows(): - yaml_setup['ions'].append(100 * species[0] + species[1]) - yaml_setup['log_tau'].append(float(synpp_ref['ref_log_tau'])) - yaml_setup['active'].append(True) - yaml_setup['temp'].append(yaml_setup['t_phot']) - yaml_setup['v_min'].append(yaml_reference['opacity']['v_ref']) - yaml_setup['v_max'].append(yaml_reference['grid']['v_outer_max']) - yaml_setup['aux'].append(1e200) - - with open(fname, 'w') as f: + yaml_setup["ions"].append(100 * species[0] + species[1]) + yaml_setup["log_tau"].append(float(synpp_ref["ref_log_tau"])) + yaml_setup["active"].append(True) + yaml_setup["temp"].append(yaml_setup["t_phot"]) + yaml_setup["v_min"].append(yaml_reference["opacity"]["v_ref"]) + yaml_setup["v_max"].append(yaml_reference["grid"]["v_outer_max"]) + yaml_setup["aux"].append(1e200) + + with open(fname, "w") as f: yaml.dump(yaml_reference, stream=f, explicit_start=True) @@ -269,8 +306,9 @@ def intensity_black_body(nu, T): """ beta_rad = 1 / (k_B_cgs * T) coefficient = 2 * h_cgs / c_cgs ** 2 - intensity = ne.evaluate('coefficient * nu**3 / ' - '(exp(h_cgs * nu * beta_rad) -1 )') + intensity = ne.evaluate( + "coefficient * nu**3 / " "(exp(h_cgs * nu * beta_rad) -1 )" + ) return intensity @@ -295,10 +333,10 @@ def species_tuple_to_string(species_tuple, roman_numerals=True): atomic_number, ion_number = species_tuple element_symbol = ATOMIC_NUMBER2SYMBOL[atomic_number] if roman_numerals: - roman_ion_number = int_to_roman(ion_number+1) - return '{0} {1}'.format(str(element_symbol), roman_ion_number) + roman_ion_number = int_to_roman(ion_number + 1) + return "{0} {1}".format(str(element_symbol), roman_ion_number) else: - return '{0} {1:d}'.format(element_symbol, ion_number) + return "{0} {1:d}".format(element_symbol, ion_number) def species_string_to_tuple(species_string): @@ -322,15 +360,17 @@ def species_string_to_tuple(species_string): """ try: - element_symbol, ion_number_string = re.match(r'^(\w+)\s*(\d+)', - species_string).groups() + element_symbol, ion_number_string = re.match( + r"^(\w+)\s*(\d+)", species_string + ).groups() except AttributeError: try: element_symbol, ion_number_string = species_string.split() except ValueError: raise MalformedSpeciesError( 'Species string "{0}" is not of format ' - ' (e.g. Fe 2, Fe2, ..)'.format(species_string)) + " (e.g. Fe 2, Fe2, ..)".format(species_string) + ) atomic_number = element_symbol2atomic_number(element_symbol) @@ -342,11 +382,14 @@ def species_string_to_tuple(species_string): except ValueError: raise MalformedSpeciesError( "Given ion number ('{}') could not be parsed".format( - ion_number_string)) + ion_number_string + ) + ) if ion_number > atomic_number: raise ValueError( - 'Species given does not exist: ion number > atomic number') + "Species given does not exist: ion number > atomic number" + ) return atomic_number, ion_number - 1 @@ -472,15 +515,18 @@ def quantity_linspace(start, stop, num, **kwargs): ValueError If start and stop values have no unit attribute. """ - if not (hasattr(start, 'unit') and hasattr(stop, 'unit')): - raise ValueError('Both start and stop need to be quantities with a ' - 'unit attribute') + if not (hasattr(start, "unit") and hasattr(stop, "unit")): + raise ValueError( + "Both start and stop need to be quantities with a " "unit attribute" + ) - return (np.linspace(start.value, stop.to(start.unit).value, num, **kwargs) - * start.unit) + return ( + np.linspace(start.value, stop.to(start.unit).value, num, **kwargs) + * start.unit + ) -def convert_abundances_format(fname, delimiter=r'\s+'): +def convert_abundances_format(fname, delimiter=r"\s+"): """ Changes format of file containing abundances into data frame @@ -496,10 +542,9 @@ def convert_abundances_format(fname, delimiter=r'\s+'): DataFrame Corresponding data frame """ - df = pd.read_csv(fname, delimiter=delimiter, comment='#', header=None) + df = pd.read_csv(fname, delimiter=delimiter, comment="#", header=None) # Drop shell index column df.drop(df.columns[0], axis=1, inplace=True) # Assign header row - df.columns = [nucname.name(i) - for i in range(1, df.shape[1] + 1)] - return df \ No newline at end of file + df.columns = [nucname.name(i) for i in range(1, df.shape[1] + 1)] + return df diff --git a/tardis/util/colored_logger.py b/tardis/util/colored_logger.py index 37130295345..f45a7c423e3 100644 --- a/tardis/util/colored_logger.py +++ b/tardis/util/colored_logger.py @@ -1,29 +1,33 @@ import logging -''' + +""" Code for Custom Logger Classes (ColoredFormatter and ColorLogger) and its helper function (formatter_message) is used from this thread http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output -''' +""" + def formatter_message(message, use_color=True): - ''' + """ Helper Function used for Coloring Log Output - ''' - #These are the sequences need to get colored ouput + """ + # These are the sequences need to get colored ouput RESET_SEQ = "\033[0m" BOLD_SEQ = "\033[1m" if use_color: - message = message.replace( - "$RESET", RESET_SEQ).replace("$BOLD", BOLD_SEQ) + message = message.replace("$RESET", RESET_SEQ).replace( + "$BOLD", BOLD_SEQ + ) else: message = message.replace("$RESET", "").replace("$BOLD", "") return message class ColoredFormatter(logging.Formatter): - ''' + """ Custom logger class for changing levels color - ''' + """ + def __init__(self, msg, use_color=True): logging.Formatter.__init__(self, msg) self.use_color = use_color @@ -33,24 +37,26 @@ def format(self, record): RESET_SEQ = "\033[0m" BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) COLORS = { - 'WARNING': YELLOW, - 'INFO': WHITE, - 'DEBUG': BLUE, - 'CRITICAL': YELLOW, - 'ERROR': RED + "WARNING": YELLOW, + "INFO": WHITE, + "DEBUG": BLUE, + "CRITICAL": YELLOW, + "ERROR": RED, } levelname = record.levelname if self.use_color and levelname in COLORS: - levelname_color = COLOR_SEQ % ( - 30 + COLORS[levelname]) + levelname + RESET_SEQ + levelname_color = ( + COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ + ) record.levelname = levelname_color return logging.Formatter.format(self, record) class ColoredLogger(logging.Logger): - ''' + """ Custom logger class with multiple destinations - ''' + """ + FORMAT = "[$BOLD%(name)-20s$RESET][%(levelname)-18s] %(message)s ($BOLD%(filename)s$RESET:%(lineno)d)" COLOR_FORMAT = formatter_message(FORMAT, True) @@ -63,4 +69,4 @@ def __init__(self, name): console.setFormatter(color_formatter) self.addHandler(console) - return \ No newline at end of file + return diff --git a/tardis/widgets/base.py b/tardis/widgets/base.py index e8eb894e6fc..2a55a8aa0da 100644 --- a/tardis/widgets/base.py +++ b/tardis/widgets/base.py @@ -1,6 +1,9 @@ from tardis.base import run_tardis from tardis.io.atom_data.util import download_atom_data -from tardis.util.base import atomic_number2element_symbol, species_tuple_to_string +from tardis.util.base import ( + atomic_number2element_symbol, + species_tuple_to_string, +) from tardis.simulation import Simulation import pandas as pd @@ -59,8 +62,12 @@ def shells_data(self): Dataframe containing Rad. Temp. and W against each shell of simulation model """ - shells_temp_w = pd.DataFrame({"Rad. Temp.": self.t_radiative, "W": self.w}) - shells_temp_w.index = range(1, len(self.t_radiative) + 1) # Overwrite index + shells_temp_w = pd.DataFrame( + {"Rad. Temp.": self.t_radiative, "W": self.w} + ) + shells_temp_w.index = range( + 1, len(self.t_radiative) + 1 + ) # Overwrite index shells_temp_w.index.name = "Shell No." # Format to string to make qgrid show values in scientific notations return shells_temp_w.applymap(lambda x: "{:.6e}".format(x)) @@ -86,11 +93,13 @@ def element_count(self, shell_num): element_count_data.fillna(0, inplace=True) return pd.DataFrame( { - "Element": element_count_data.index.map(atomic_number2element_symbol), - # Format to string to show in scientific notation - "Frac. Ab. (Shell {})".format(shell_num): element_count_data.map( - "{:.6e}".format + "Element": element_count_data.index.map( + atomic_number2element_symbol ), + # Format to string to show in scientific notation + "Frac. Ab. (Shell {})".format( + shell_num + ): element_count_data.map("{:.6e}".format), } ) @@ -152,7 +161,9 @@ def level_count(self, ion, atomic_num, shell_num): level_num_density = self.level_number_density[shell_num - 1].loc[ atomic_num, ion ] - ion_num_density = self.ion_number_density[shell_num - 1].loc[atomic_num, ion] + ion_num_density = self.ion_number_density[shell_num - 1].loc[ + atomic_num, ion + ] level_count_data = level_num_density / ion_num_density # Normalization level_count_data.index.name = "Level" level_count_data.name = "Frac. Ab. (Ion={})".format(ion) @@ -252,7 +263,8 @@ def __init__(self, shell_info_data): # Creating the ion count table widget self.ion_count_table = self.create_table_widget( self.data.ion_count( - self.element_count_table.df.index[0], self.shells_table.df.index[0] + self.element_count_table.df.index[0], + self.shells_table.df.index[0], ), [20, 30, 50], changeable_col={ @@ -280,7 +292,9 @@ def __init__(self, shell_info_data): # element count table "other_names": [ "Frac. Ab. (Ion={})".format(ion) - for ion in range(0, self.element_count_table.df.index.max() + 1) + for ion in range( + 0, self.element_count_table.df.index.max() + 1 + ) ], }, ) @@ -467,7 +481,9 @@ def update_level_count_table(self, event, qgrid_widget): ion = self.ion_count_table.df.index[event["new"][0]] # Update data in level_count_table - self.level_count_table.df = self.data.level_count(ion, atomic_num, shell_num) + self.level_count_table.df = self.data.level_count( + ion, atomic_num, shell_num + ) def display( self, @@ -475,7 +491,7 @@ def display( element_count_table_width="24%", ion_count_table_width="24%", level_count_table_width="18%", - **layout_kwargs + **layout_kwargs, ): """Display the shell info widget by putting all component widgets nicely together and allowing interaction between the table widgets @@ -505,7 +521,9 @@ def display( """ # CSS properties of the layout of shell info tables container tables_container_layout = dict( - display="flex", align_items="flex-start", justify_content="space-between" + display="flex", + align_items="flex-start", + justify_content="space-between", ) tables_container_layout.update(layout_kwargs) @@ -516,9 +534,15 @@ def display( self.level_count_table.layout.width = level_count_table_width # Attach event listeners to table widgets - self.shells_table.on("selection_changed", self.update_element_count_table) - self.element_count_table.on("selection_changed", self.update_ion_count_table) - self.ion_count_table.on("selection_changed", self.update_level_count_table) + self.shells_table.on( + "selection_changed", self.update_element_count_table + ) + self.element_count_table.on( + "selection_changed", self.update_ion_count_table + ) + self.ion_count_table.on( + "selection_changed", self.update_level_count_table + ) # Putting all table widgets in a container styled with tables_container_layout shell_info_tables_container = ipw.Box( diff --git a/tardis/widgets/tests/test_base.py b/tardis/widgets/tests/test_base.py index 1f14b075e62..47c43e74afa 100644 --- a/tardis/widgets/tests/test_base.py +++ b/tardis/widgets/tests/test_base.py @@ -36,7 +36,10 @@ def hdf_shell_info(hdf_file_path, simulation_verysimple): class TestBaseShellInfo: def test_shells_data(self, base_shell_info, simulation_verysimple): shells_data = base_shell_info.shells_data() - assert shells_data.shape == (len(simulation_verysimple.model.t_radiative), 2) + assert shells_data.shape == ( + len(simulation_verysimple.model.t_radiative), + 2, + ) assert np.allclose( shells_data.iloc[:, 0].map(np.float), simulation_verysimple.model.t_radiative.value, @@ -66,7 +69,9 @@ def test_ion_count_data( ion_count_data = base_shell_info.ion_count(atomic_num, shell_num) sim_ion_number_density = simulation_verysimple.plasma.ion_number_density[ shell_num - 1 - ].loc[atomic_num] + ].loc[ + atomic_num + ] sim_element_number_density = simulation_verysimple.plasma.number_density.loc[ atomic_num, shell_num - 1 ] @@ -80,15 +85,26 @@ def test_ion_count_data( ("ion_num", "atomic_num", "shell_num"), [(2, 12, 1), (3, 20, 20)] ) def test_level_count_data( - self, base_shell_info, simulation_verysimple, ion_num, atomic_num, shell_num + self, + base_shell_info, + simulation_verysimple, + ion_num, + atomic_num, + shell_num, ): - level_count_data = base_shell_info.level_count(ion_num, atomic_num, shell_num) + level_count_data = base_shell_info.level_count( + ion_num, atomic_num, shell_num + ) sim_level_number_density = simulation_verysimple.plasma.level_number_density[ shell_num - 1 - ].loc[atomic_num, ion_num] + ].loc[ + atomic_num, ion_num + ] sim_ion_number_density = simulation_verysimple.plasma.ion_number_density[ shell_num - 1 - ].loc[atomic_num, ion_num] + ].loc[ + atomic_num, ion_num + ] assert level_count_data.shape == (len(sim_level_number_density), 1) assert np.allclose( level_count_data.iloc[:, 0].map(np.float), @@ -123,10 +139,14 @@ def shell_info_widget(self, base_shell_info): _ = shell_info_widget.display() return shell_info_widget - def test_selection_on_shells_table(self, base_shell_info, shell_info_widget): + def test_selection_on_shells_table( + self, base_shell_info, shell_info_widget + ): shell_info_widget.shells_table.change_selection([self.select_shell_num]) - expected_element_count = base_shell_info.element_count(self.select_shell_num) + expected_element_count = base_shell_info.element_count( + self.select_shell_num + ) pdt.assert_frame_equal( expected_element_count, shell_info_widget.element_count_table.df ) @@ -134,7 +154,9 @@ def test_selection_on_shells_table(self, base_shell_info, shell_info_widget): expected_ion_count = base_shell_info.ion_count( expected_element_count.index[0], self.select_shell_num ) - pdt.assert_frame_equal(expected_ion_count, shell_info_widget.ion_count_table.df) + pdt.assert_frame_equal( + expected_ion_count, shell_info_widget.ion_count_table.df + ) expected_level_count = base_shell_info.level_count( expected_ion_count.index[0], @@ -145,23 +167,35 @@ def test_selection_on_shells_table(self, base_shell_info, shell_info_widget): expected_level_count, shell_info_widget.level_count_table.df ) - def test_selection_on_element_count_table(self, base_shell_info, shell_info_widget): - shell_info_widget.element_count_table.change_selection([self.select_atomic_num]) + def test_selection_on_element_count_table( + self, base_shell_info, shell_info_widget + ): + shell_info_widget.element_count_table.change_selection( + [self.select_atomic_num] + ) expected_ion_count = base_shell_info.ion_count( self.select_atomic_num, self.select_shell_num ) - pdt.assert_frame_equal(expected_ion_count, shell_info_widget.ion_count_table.df) + pdt.assert_frame_equal( + expected_ion_count, shell_info_widget.ion_count_table.df + ) expected_level_count = base_shell_info.level_count( - expected_ion_count.index[0], self.select_atomic_num, self.select_shell_num + expected_ion_count.index[0], + self.select_atomic_num, + self.select_shell_num, ) pdt.assert_frame_equal( expected_level_count, shell_info_widget.level_count_table.df ) - def test_selection_on_ion_count_table(self, base_shell_info, shell_info_widget): - shell_info_widget.ion_count_table.change_selection([self.select_ion_num]) + def test_selection_on_ion_count_table( + self, base_shell_info, shell_info_widget + ): + shell_info_widget.ion_count_table.change_selection( + [self.select_ion_num] + ) expected_level_count = base_shell_info.level_count( self.select_ion_num, self.select_atomic_num, self.select_shell_num