diff --git a/pyiron_atomistics/atomistics/master/murnaghan.py b/pyiron_atomistics/atomistics/master/murnaghan.py index c360f5eab..1d80e403e 100644 --- a/pyiron_atomistics/atomistics/master/murnaghan.py +++ b/pyiron_atomistics/atomistics/master/murnaghan.py @@ -743,7 +743,6 @@ def collect_output(self): erg_lst, vol_lst, err_lst, id_lst = [], [], [], [] for job_id in self.child_ids: ham = self.project_hdf5.inspect(job_id) - print("job_id: ", job_id, ham.status) if "energy_tot" in ham["output/generic"].list_nodes(): energy = ham["output/generic/energy_tot"][-1] elif "energy_pot" in ham["output/generic"].list_nodes(): @@ -774,7 +773,7 @@ def collect_output(self): else: self._fit_eos_general(fittype=self.input["fit_type"]) - def plot(self, num_steps=100, plt_show=True): + def plot(self, num_steps=100, plt_show=True, ax=None, plot_kwargs=None): if not self.status.finished: raise ValueError( "Job must be successfully run, before calling this method." @@ -783,6 +782,11 @@ def plot(self, num_steps=100, plt_show=True): import matplotlib.pylab as plt except ImportError: import matplotlib.pyplot as plt + + if ax is None: + ax = plt.subplot(111) + else: + plt_show = False if not self.fit_dict: if self.input["fit_type"] == "polynomial": self.fit_polynomial(fit_order=self.input["fit_order"]) @@ -791,20 +795,38 @@ def plot(self, num_steps=100, plt_show=True): df = self.output_to_pandas() vol_lst, erg_lst = df["volume"].values, df["energy"].values x_i = np.linspace(np.min(vol_lst), np.max(vol_lst), num_steps) - color = "blue" + + if plot_kwargs is None: + plot_kwargs = {} + + if "color" in plot_kwargs.keys(): + color = plot_kwargs["color"] + del plot_kwargs["color"] + else: + color = "blue" + + if "marker" in plot_kwargs.keys(): + del plot_kwargs["marker"] + + if "label" in plot_kwargs.keys(): + label = plot_kwargs["label"] + del plot_kwargs["label"] + else: + label = self.input["fit_type"] if self.fit_dict is not None: if self.input["fit_type"] == "polynomial": p_fit = np.poly1d(self.fit_dict["poly_fit"]) least_square_error = self.fit_module.get_error(vol_lst, erg_lst, p_fit) - plt.title("Murnaghan: error: " + str(least_square_error)) - plt.plot( + ax.set_title("Murnaghan: error: " + str(least_square_error)) + ax.plot( x_i, p_fit(x_i), "-", - label=self.input["fit_type"], + label=label, color=color, linewidth=3, + **plot_kwargs, ) else: V0 = self.fit_dict["volume_eq"] @@ -814,21 +836,23 @@ def plot(self, num_steps=100, plt_show=True): eng_fit_lst = fitfunction( parameters=[E0, B0, BP, V0], vol=x_i, fittype=self.input["fit_type"] ) - plt.plot( + ax.plot( x_i, eng_fit_lst, "-", - label=self.input["fit_type"], + label=label, color=color, linewidth=3, + **plot_kwargs, ) - plt.plot(vol_lst, erg_lst, "x", color=color, markersize=20) - plt.legend() - plt.xlabel("Volume ($\AA^3$)") - plt.ylabel("energy (eV)") + ax.plot(vol_lst, erg_lst, "x", color=color, markersize=20, **plot_kwargs) + ax.legend() + ax.set_xlabel("Volume ($\AA^3$)") + ax.set_ylabel("energy (eV)") if plt_show: plt.show() + return ax def _get_structure(self, frame=-1, wrap_atoms=True): """ diff --git a/tests/atomistics/master/test_murnaghan.py b/tests/atomistics/master/test_murnaghan.py index a04ae447f..30fdcd8dd 100644 --- a/tests/atomistics/master/test_murnaghan.py +++ b/tests/atomistics/master/test_murnaghan.py @@ -3,9 +3,9 @@ # Distributed under the terms of "New BSD License", see the LICENSE file. import unittest - +import matplotlib +import matplotlib.pylab as plt import numpy as np - from pyiron_atomistics.atomistics.structure.atoms import CrystalStructure from pyiron_base._tests import TestWithProject @@ -90,7 +90,12 @@ def test_fitting_routines(self): murn._hdf5["output/equilibrium_volume"] = 448.4033384110422 murn.status.finished = True - murn.plot(plt_show=False) + self.assertIsInstance(murn.plot(plt_show=False), matplotlib.axes.Axes) + _, ax_list = plt.subplots(ncols=2, nrows=1) + for i, ax in enumerate(ax_list): + ax = murn.plot(ax=ax, plot_kwargs={"color": "black", "label": f"plot{i}", "marker": "x"}) + ax.set_title(f"Axis {i+1}") + self.assertEqual(len(ax.lines), 2) with self.subTest(msg="standard polynomial fit"): self.assertAlmostEqual(-90.71969974284912, murn.equilibrium_energy) self.assertAlmostEqual(448.1341230545222, murn.equilibrium_volume)