Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Converted the reference data to regression in test TARDIS full #2611

Merged
2 changes: 1 addition & 1 deletion tardis/tests/fixtures/regression_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def sync_hdf_store(self, tardis_module, update_fname=True):
with pd.HDFStore(self.fpath, mode="w") as store:
tardis_module.to_hdf(store, overwrite=True)
pytest.skip(
f"Skipping test to generate regression_data {self.fpath} data"
f"Skipping test to generate regression data: {self.fpath}"
)
else:
return pd.HDFStore(self.fpath, mode="r")
Expand Down
71 changes: 33 additions & 38 deletions tardis/tests/test_tardis_full.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from pathlib import Path

import pytest
import numpy as np
import numpy.testing as npt
import pandas as pd
import pytest
from astropy import units as u
from astropy.tests.helper import assert_quantity_allclose

from tardis.simulation.base import Simulation
from tardis.io.configuration.config_reader import Configuration

from tardis import run_tardis
from tardis.io.configuration.config_reader import Configuration
from tardis.simulation.base import Simulation
from tardis.tests.fixtures.regression_data import RegressionData


def test_run_tardis_from_config_obj(
Expand All @@ -35,64 +35,59 @@ class TestTransportSimple:
Very simple run
"""

name = "test_transport_simple"
regression_data: RegressionData = None

@pytest.fixture(scope="class")
def transport(
def transport_state(
self,
request,
atomic_data_fname,
tardis_ref_data,
generate_reference,
example_configuration_dir: Path,
):
config = Configuration.from_yaml(
example_configuration_dir / "tardis_configv1_verysimple.yml"
str(example_configuration_dir / "tardis_configv1_verysimple.yml")
)
config["atom_data"] = atomic_data_fname

simulation = Simulation.from_config(config)
simulation.run_convergence()
simulation.run_final()
if not generate_reference:
return simulation.transport
else:
simulation.transport.hdf_properties = [
"transport_state",
]
simulation.transport.to_hdf(
tardis_ref_data, "", self.name, overwrite=True
)
pytest.skip("Reference data was generated during this run.")

@pytest.fixture(scope="class")
def refdata(self, tardis_ref_data):
def get_ref_data(key):
return tardis_ref_data[f"{self.name}/{key}"]
transport_state = simulation.transport.transport_state
request.cls.regression_data = RegressionData(request)
request.cls.regression_data.sync_hdf_store(transport_state)

return transport_state

return get_ref_data
def get_expected_data(self, key: str):
return pd.read_hdf(self.regression_data.fpath, key)

def test_j_blue_estimators(self, transport, refdata):
j_blue_estimator = refdata("transport_state/j_blue_estimator").values
def test_j_blue_estimators(self, transport_state):
key = "transport_state/j_blue_estimator"
expected = self.get_expected_data(key)

npt.assert_allclose(
transport.transport_state.radfield_mc_estimators.j_blue_estimator,
j_blue_estimator,
transport_state.radfield_mc_estimators.j_blue_estimator,
expected.values,
)

def test_spectrum(self, transport, refdata):
luminosity = u.Quantity(
refdata("transport_state/spectrum/luminosity"), "erg /s"
)
def test_spectrum(self, transport_state):
key = "transport_state/spectrum/luminosity"
expected = self.get_expected_data(key)

luminosity = u.Quantity(expected, "erg /s")

assert_quantity_allclose(
transport.transport_state.spectrum.luminosity, luminosity
transport_state.spectrum.luminosity, luminosity
)

def test_virtual_spectrum(self, transport, refdata):
luminosity = u.Quantity(
refdata("transport_state/spectrum_virtual/luminosity"), "erg /s"
)
def test_virtual_spectrum(self, transport_state):
key = "transport_state/spectrum_virtual/luminosity"
expected = self.get_expected_data(key)

luminosity = u.Quantity(expected, "erg /s")

assert_quantity_allclose(
transport.transport_state.spectrum_virtual.luminosity, luminosity
transport_state.spectrum_virtual.luminosity, luminosity
)
Loading