Skip to content

Commit

Permalink
Merge branch 'main' into ddpm
Browse files Browse the repository at this point in the history
  • Loading branch information
sgbaird committed Jun 11, 2022
2 parents 7cd400c + 96e7dbb commit 06a978c
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: '^docs/conf.py'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v4.3.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
Expand Down
21 changes: 20 additions & 1 deletion src/xtal2png/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

from xtal2png import __version__
from xtal2png.utils.data import dummy_structures, rgb_scaler, rgb_unscaler
Expand Down Expand Up @@ -104,6 +105,14 @@ class XtalConverter:
save_dir : Union[str, 'PathLike[str]']
Directory to save PNG files via ``func:xtal2png``,
by default path.join("data", "interim")
symprec : float, optional
The symmetry precision to use when decoding `pymatgen` structures via
``func:pymatgen.symmetry.analyzer.SpaceGroupAnalyzer.get_refined_structure``. By
default 0.1.
angle_tolerance : Union[float, int], optional
The angle tolerance (degrees) to use when decoding `pymatgen` structures via
``func:pymatgen.symmetry.analyzer.SpaceGroupAnalyzer.get_refined_structure``. By
default 5.0.
Examples
--------
Expand All @@ -125,6 +134,8 @@ def __init__(
distance_range: Tuple[float, float] = (0.0, 18.0),
max_sites: int = 52,
save_dir: Union[str, "PathLike[str]"] = path.join("data", "preprocessed"),
symprec: float = 0.1,
angle_tolerance: float = 5.0,
):
"""Instantiate an XtalConverter object with desired ranges and ``max_sites``."""
self.atom_range = atom_range
Expand All @@ -138,6 +149,8 @@ def __init__(
self.distance_range = distance_range
self.max_sites = max_sites
self.save_dir = save_dir
self.symprec = symprec
self.angle_tolerance = angle_tolerance

Path(save_dir).mkdir(exist_ok=True, parents=True)

Expand Down Expand Up @@ -283,7 +296,9 @@ def png2xtal(
if save:
for s in S:
fpath = path.join(self.save_dir, construct_save_name(s) + ".cif")
CifWriter(s).write_file(fpath)
CifWriter(
s, symprec=self.symprec, angle_tolerance=self.angle_tolerance
).write_file(fpath)

return S

Expand Down Expand Up @@ -625,6 +640,10 @@ def arrays_to_structures(self, data: np.ndarray):
a=a, b=b, c=c, alpha=alpha, beta=beta, gamma=gamma
)
structure = Structure(lattice, at, fr)
spa = SpacegroupAnalyzer(
structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance
)
structure = spa.get_refined_structure()
S.append(structure)

return S
Expand Down
34 changes: 27 additions & 7 deletions tests/xtal2png_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# (test), and back to crystal Structure (test)


from warnings import warn

import plotly.express as px
from numpy.testing import assert_allclose, assert_equal
from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher

from xtal2png.core import XtalConverter
from xtal2png.utils.data import (
Expand All @@ -20,7 +23,24 @@


def assert_structures_approximate_match(example_structures, structures):
for s, structure in zip(example_structures, structures):
for i, (s, structure) in enumerate(zip(example_structures, structures)):
# d = np.linalg.norm(s._lattice.abc)
# sm = StructureMatcher(
# ltol=rgb_loose_tol * d,
# stol=rgb_loose_tol * d,
# angle_tol=rgb_loose_tol * 180,
# comparator=ElementComparator(),
# )
sm = StructureMatcher(comparator=ElementComparator())
is_match = sm.fit(s, structure)
if not is_match:
warn(
f"{i}-th original and decoded structures do not match according to StructureMatcher(comparator=ElementComparator()).fit(s, structure).\n\nOriginal (s): {s}\n\nDecoded (structure): {structure}" # noqa: E501
)

sm = StructureMatcher(primitive_cell=False, comparator=ElementComparator())
s2 = sm.get_s2_like_s1(s, structure)

a_check = s._lattice.a
b_check = s._lattice.b
c_check = s._lattice.c
Expand All @@ -29,12 +49,12 @@ def assert_structures_approximate_match(example_structures, structures):
frac_coords_check = s.frac_coords
space_group_check = s.get_space_group_info()[1]

latt_a = structure._lattice.a
latt_b = structure._lattice.b
latt_c = structure._lattice.c
angles = structure._lattice.angles
atomic_numbers = structure.atomic_numbers
frac_coords = structure.frac_coords
latt_a = s2._lattice.a
latt_b = s2._lattice.b
latt_c = s2._lattice.c
angles = s2._lattice.angles
atomic_numbers = s2.atomic_numbers
frac_coords = s2.frac_coords
space_group = s.get_space_group_info()[1]

assert_allclose(
Expand Down

0 comments on commit 06a978c

Please sign in to comment.