Skip to content

Commit

Permalink
Remove ase dependency and update ROY dataset (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
GardevoirX authored Jun 15, 2024
1 parent bc76938 commit f28cfbd
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 61 deletions.
10 changes: 5 additions & 5 deletions examples/selection/GCH-ROY.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@

roy_data = load_roy_dataset()

structures = roy_data["structures"]

density = np.array([s.info["density"] for s in structures])
energy = np.array([s.info["energy"] for s in structures])
structype = np.array([s.info["type"] for s in structures])
density = roy_data["densities"]
energy = roy_data["energies"]
structype = roy_data["structure_types"]
iknown = np.where(structype == "known")[0]
iothers = np.where(structype != "known")[0]

Expand Down Expand Up @@ -247,3 +245,5 @@
},
)
"""

# %%
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ dynamic = ["version"]

[project.optional-dependencies]
examples = [
"ase",
"matplotlib",
"pandas",
"tqdm",
Expand Down
39 changes: 17 additions & 22 deletions src/skmatter/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,32 +119,27 @@ def load_who_dataset():


def load_roy_dataset():
"""Load and returns the ROY dataset, which contains structures,
energies and SOAP-derived descriptors for 264 polymorphs of ROY, from [Beran et Al,
Chemical Science (2022)](https://doi.org/10.1039/D1SC06074K)
"""Load and returns the ROY dataset, which contains densities,
energies and SOAP-derived descriptors for 264 structures of polymorphs of ROY,
from [Beran et Al, Chemical Science (2022)](https://doi.org/10.1039/D1SC06074K)
Each structure is labeled as "Known" or "Unknown".
Returns
-------
roy_dataset : sklearn.utils.Bunch
Dictionary-like object, with the following attributes:
structures : `ase.Atoms` -- the roy structures as ASE objects
features: `np.array` -- SOAP-derived descriptors for the structures
energies: `np.array` -- energies of the structures
densities : `np.array` -- the densities of the structures
structure_types : `np.array` -- the type of the structures
features : `np.array` -- SOAP-derived descriptors for the structures
energies : `np.array` -- energies of the structures
"""
module_path = dirname(__file__)
target_structures = join(module_path, "data", "beran_roy_structures.xyz.bz2")

try:
from ase.io import read
except ImportError:
raise ImportError("load_roy_dataset requires the ASE package.")

import bz2

structures = read(bz2.open(target_structures, "rt"), ":", format="extxyz")
energies = np.array([f.info["energy"] for f in structures])

target_features = join(module_path, "data", "beran_roy_features.npz")
features = np.load(target_features)["feats"]

return Bunch(structures=structures, features=features, energies=energies)
target_properties = join(module_path, "data", "beran_roy_properties.npz")
properties = np.load(target_properties)

return Bunch(
densities=properties["densities"],
energies=properties["energies"],
structure_types=properties["structure_types"],
features=properties["feats"],
)
Binary file removed src/skmatter/datasets/data/beran_roy_features.npz
Binary file not shown.
Binary file not shown.
Binary file not shown.
37 changes: 4 additions & 33 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,45 +107,16 @@ class ROYTests(unittest.TestCase):
def setUpClass(cls):
cls.size = 264
cls.shape = (264, 32)
try:
from ase.io import read # NoQa: F401

cls.has_ase = True
cls.roy = load_roy_dataset()
except ImportError:
cls.has_ase = False

def test_load_dataset_without_ase(self):
"""Check if the correct exception occurs when ase isn't present."""
with unittest.mock.patch.dict("sys.modules", {"ase.io": None}):
with self.assertRaises(ImportError) as cm:
_ = load_roy_dataset()
self.assertEqual(
str(cm.exception), "load_roy_dataset requires the ASE package."
)
cls.roy = load_roy_dataset()

def test_dataset_content(self):
"""Check if the correct number of datapoints are present in the dataset.
Also check if the size of the dataset is correct.
"""
if self.has_ase is True:
self.assertEqual(len(self.roy["structures"]), self.size)
self.assertEqual(self.roy["features"].shape, self.shape)
self.assertEqual(len(self.roy["energies"]), self.size)

def test_dataset_consistency(self):
"""Check if the energies in the structures are the same as in the explicit
array.
"""
if self.has_ase is True:
self.assertTrue(
np.allclose(
self.roy["energies"],
[f.info["energy"] for f in self.roy["structures"]],
rtol=1e-6,
)
)
self.assertEqual(len(self.roy["structure_types"]), self.size)
self.assertEqual(self.roy["features"].shape, self.shape)
self.assertEqual(len(self.roy["energies"]), self.size)


if __name__ == "__main__":
Expand Down

0 comments on commit f28cfbd

Please sign in to comment.