Skip to content

Commit

Permalink
Merge 1eadcc1 into f58fc69
Browse files Browse the repository at this point in the history
  • Loading branch information
sgbaird committed Jun 17, 2022
2 parents f58fc69 + 1eadcc1 commit 10c9cec
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 34 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ install_requires =
pymatgen
plotly
kaleido
m3gnet


[options.packages.find]
Expand Down
90 changes: 68 additions & 22 deletions src/xtal2png/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import numpy as np
import pandas as pd
import tensorflow as tf
from m3gnet.models import Relaxer
from numpy.typing import NDArray
from PIL import Image
from pymatgen.core.lattice import Lattice
Expand Down Expand Up @@ -61,7 +63,7 @@
DISTANCE_KEY = "distance"


def construct_save_name(s: Structure) -> str:
    """Build a short, human-readable save name for a structure.

    Parameters
    ----------
    s : Structure
        Object exposing ``formula`` (str) and ``volume`` (numeric) attributes.

    Returns
    -------
    str
        ``"<formula>,volume=<volume>,uid=<uid>"`` where the formula has its
        spaces removed, the volume is rounded to the nearest integer and the
        uid is the first four characters of a freshly generated UUID4.
    """
    formula = s.formula.replace(" ", "")
    volume = int(np.round(s.volume))
    # Short random suffix so structures sharing formula and volume do not collide.
    uid = str(uuid4())[0:4]
    return f"{formula},volume={volume},uid={uid}"

Expand Down Expand Up @@ -125,6 +127,10 @@ class XtalConverter:
Decode structures as symmetrized, primitive structures. Uses ``symprec`` if
``symprec`` is of type float, else uses ``symprec[1]`` if ``symprec`` is of type
tuple. Same applies for ``angle_tolerance``. By default True
relax_on_decode : bool, optional
Whether to relax the decoded crystal structures with m3gnet, by default True
verbose : bool, optional
Whether to print verbose debugging information, by default True
Examples
--------
Expand All @@ -150,6 +156,8 @@ def __init__(
angle_tolerance: Union[float, int, Tuple[float, float], Tuple[int, int]] = 5.0,
encode_as_primitive: bool = False,
decode_as_primitive: bool = False,
relax_on_decode: bool = True,
verbose: bool = True,
):
"""Instantiate an XtalConverter object with desired ranges and ``max_sites``."""
self.atom_range = atom_range
Expand Down Expand Up @@ -180,14 +188,14 @@ def __init__(

self.encode_as_primitive = encode_as_primitive
self.decode_as_primitive = decode_as_primitive
self.relax_on_decode = relax_on_decode
self.verbose = verbose

Path(save_dir).mkdir(exist_ok=True, parents=True)

def xtal2png(
self,
structures: Union[
List[Union[Structure, str, "PathLike[str]"]], str, "PathLike[str]"
],
structures: List[Union[Structure, str, "PathLike[str]"]],
show: bool = False,
save: bool = True,
):
Expand All @@ -196,8 +204,7 @@ def xtal2png(
Parameters
----------
structures : List[Union[Structure, str, PathLike[str]]]
pymatgen Structure objects or path to CIF files or path to directory
containing CIF files.
pymatgen Structure objects or path to CIF files.
show : bool, optional
Whether to display the PNG-encoded file, by default False
save : bool, optional
Expand Down Expand Up @@ -229,10 +236,11 @@ def xtal2png(
>>> xc = XtalConverter()
>>> xc.xtal2png(structures, show=False, save=True)
"""
save_names, structures = self.process_filepaths_or_structures(structures)

save_names, S = self.process_filepaths_or_structures(structures)

# convert structures to 3D NumPy Matrices
self.data, self.id_data, self.id_mapper = self.structures_to_arrays(structures)
self.data, self.id_data, self.id_mapper = self.structures_to_arrays(S)
mn, mx = self.data.min(), self.data.max()
if mn < 0:
warn(
Expand Down Expand Up @@ -263,14 +271,14 @@ def xtal2png(

def fit(
self,
structures: Union[
List[Union[Structure, str, "PathLike[str]"]], str, "PathLike[str]"
],
structures: List[Union[Structure, str, "PathLike[str]"]],
y=None,
fit_quantiles=(0.00, 0.99),
verbose=True,
verbose=None,
):
_, structures = self.process_filepaths_or_structures(structures)
verbose = self.verbose if verbose is None else verbose

_, S = self.process_filepaths_or_structures(structures)

# TODO: deal with arbitrary site_properties
atomic_numbers = []
Expand All @@ -282,7 +290,7 @@ def fit(
distance = []
num_sites = []

for s in tqdm(structures):
for s in tqdm(S):
atomic_numbers.append(s.atomic_numbers)
lattice = s.lattice
a.append(lattice.a)
Expand Down Expand Up @@ -346,7 +354,8 @@ def fit(
setattr(self, name + "_range", tuple(bounds))

def process_filepaths_or_structures(
self, structures: Union[List[PathLike], List[Structure]]
self,
structures: List[Union[Structure, str, "PathLike[str]"]],
) -> Tuple[List[str], List[Structure]]:
"""Extract (or create) save names and convert/passthrough the structures.
Expand All @@ -368,23 +377,29 @@ def process_filepaths_or_structures(
Raises
------
ValueError
_description_
"structures should be of same datatype, either strs or pymatgen Structures.
structures[0] is {type(structures[0])}, but got type {type(s)} for entry
{i}"
ValueError
_description_
"structures should be of same datatype, either strs or pymatgen Structures.
structures[0] is {type(structures[0])}, but got type {type(s)} for entry
{i}"
ValueError
_description_
"structures should be of type `str`, `os.PathLike` or
`pymatgen.core.structure.Structure`, not {type(structures[i])} (entry {i})"
Examples
--------
>>> save_names, structures = process_filepaths_or_structures(structures)
"""
save_names: List[str] = []

first_is_structure = isinstance(structures[0], Structure)
for i, s in enumerate(structures):
if isinstance(s, str) or isinstance(s, PathLike):
if first_is_structure:
raise ValueError(
f"structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}" # noqa
f"structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}" # noqa: E501
)

structures[i] = Structure.from_file(s)
Expand All @@ -403,7 +418,13 @@ def process_filepaths_or_structures(
f"structures should be of type `str`, `os.PathLike` or `pymatgen.core.structure.Structure`, not {type(structures[i])} (entry {i})" # noqa
)

return save_names, structures
for i, s in enumerate(structures):
assert isinstance(
s, Structure
), f"structures[{i}]: {type(s)}, expected: Structure"
assert not isinstance(s, str) and not isinstance(s, PathLike)

return save_names, structures # type: ignore

def png2xtal(
self, images: List[Union[Image.Image, "PathLike"]], save: bool = False
Expand Down Expand Up @@ -820,10 +841,13 @@ def arrays_to_structures(

atomic_numbers = np.round(atomic_numbers).astype(int)

# REVIEW: round fractional coordinates to nearest multiple?

# TODO: tweak lattice parameters to match predicted space group rules

if self.relax_on_decode:
if not self.verbose:
tf.get_logger().setLevel(logging.ERROR)
relaxer = Relaxer() # This loads the default pre-trained model

# build Structure-s
S: List[Structure] = []
for i in range(len(atomic_numbers)):
Expand All @@ -843,6 +867,23 @@ def arrays_to_structures(
a=a, b=b, c=c, alpha=alpha, beta=beta, gamma=gamma
)
structure = Structure(lattice, at, fr)

# REVIEW: round fractional coordinates to nearest multiple?
if self.relax_on_decode:
relaxed_results = relaxer.relax(structure, verbose=self.verbose)
structure = relaxed_results["final_structure"]

# relax_results = relaxer.relax()
# final_structure = relax_results["final_structure"]
# final_energy = relax_results["trajectory"].energies[-1] / 2

# print(
# f"Relaxed lattice parameter is
# {final_structure.lattice.abc[0]:.3f} Å"
# )
# # TODO: print the initial energy as well (assuming it's available)
# print(f"Final energy is {final_energy.item(): .3f} eV/atom")

spa = SpacegroupAnalyzer(
structure,
symprec=self.decode_symprec,
Expand All @@ -852,8 +893,13 @@ def arrays_to_structures(
structure = spa.get_primitive_standard_structure()
else:
structure = spa.get_refined_structure()

S.append(structure)

if self.relax_on_decode:
# restore default https://stackoverflow.com/a/51340381/13697228
sys.stdout = sys.__stdout__

return S


Expand Down
36 changes: 24 additions & 12 deletions tests/xtal2png_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,62 +114,62 @@ def assert_structures_approximate_match(


def test_structures_to_arrays():
    """Encode all example structures to arrays and return the data array."""
    # Build the converter once, with m3gnet relaxation switched off.
    xc = XtalConverter(relax_on_decode=False)
    data, _, _ = xc.structures_to_arrays(example_structures)
    return data


def test_structures_to_arrays_single():
    """Encode a one-element list of structures and return the data array."""
    # Build the converter once, with m3gnet relaxation switched off.
    xc = XtalConverter(relax_on_decode=False)
    data, _, _ = xc.structures_to_arrays([example_structures[0]])
    return data


def test_arrays_to_structures():
    """Round-trip the example structures through array encoding and decoding."""
    # Build the converter once, with m3gnet relaxation switched off.
    xc = XtalConverter(relax_on_decode=False)
    data, id_data, id_mapper = xc.structures_to_arrays(example_structures)
    structures = xc.arrays_to_structures(data, id_data, id_mapper)
    assert_structures_approximate_match(example_structures, structures)
    return structures


def test_arrays_to_structures_single():
    """Round-trip a single structure through array encoding and decoding."""
    # Build the converter once, with m3gnet relaxation switched off.
    xc = XtalConverter(relax_on_decode=False)
    data, id_data, id_mapper = xc.structures_to_arrays([example_structures[0]])
    structures = xc.arrays_to_structures(data, id_data, id_mapper)
    assert_structures_approximate_match([example_structures[0]], structures)
    return structures


def test_xtal2png():
    """Encode the example structures as PNG images (saved, not shown)."""
    # Build the converter once, with m3gnet relaxation switched off.
    xc = XtalConverter(relax_on_decode=False)
    imgs = xc.xtal2png(example_structures, show=False, save=True)
    return imgs


def test_xtal2png_single():
    """Encode a one-element list of structures as PNG (saved, not shown)."""
    # Build the converter once, with m3gnet relaxation switched off.
    xc = XtalConverter(relax_on_decode=False)
    imgs = xc.xtal2png([example_structures[0]], show=False, save=True)
    return imgs


def test_png2xtal():
    """Round-trip the example structures through PNG encoding and decoding."""
    # Build the converter once, with m3gnet relaxation switched off.
    xc = XtalConverter(relax_on_decode=False)
    imgs = xc.xtal2png(example_structures, show=True, save=True)
    decoded_structures = xc.png2xtal(imgs)
    assert_structures_approximate_match(example_structures, decoded_structures)


def test_png2xtal_single():
    """Round-trip a single structure through PNG encoding and decoding."""
    # Build the converter once, with m3gnet relaxation switched off.
    xc = XtalConverter(relax_on_decode=False)
    imgs = xc.xtal2png([example_structures[0]], show=True, save=True)
    decoded_structures = xc.png2xtal(imgs, save=False)
    assert_structures_approximate_match([example_structures[0]], decoded_structures)
    return decoded_structures


def test_png2xtal_rgb_image():
xc = XtalConverter()
xc = XtalConverter(relax_on_decode=False)
imgs = xc.xtal2png(example_structures, show=False, save=False)
imgs = [img.convert("RGB") for img in imgs]
decoded_structures = xc.png2xtal(imgs)
Expand All @@ -183,6 +183,7 @@ def test_primitive_encoding():
angle_tolerance=5.0,
encode_as_primitive=True,
decode_as_primitive=False,
relax_on_decode=False,
)
input_structures = [
SpacegroupAnalyzer(
Expand All @@ -195,7 +196,7 @@ def test_primitive_encoding():
assert_structures_approximate_match(
example_structures, decoded_structures, tol_multiplier=2.0
)
1 + 1
return decoded_structures


def test_primitive_decoding():
Expand All @@ -204,6 +205,7 @@ def test_primitive_decoding():
angle_tolerance=5.0,
encode_as_primitive=False,
decode_as_primitive=True,
relax_on_decode=False,
)
input_structures = [
SpacegroupAnalyzer(
Expand All @@ -224,11 +226,10 @@ def test_primitive_decoding():
example_structures, decoded_structures, tol_multiplier=2.0
)
return decoded_structures
1 + 1


def test_fit():
xc = XtalConverter()
xc = XtalConverter(relax_on_decode=False)
xc.fit(example_structures + dummy_structures)
assert_array_equal((14, 82), xc.atom_range)
assert_allclose((3.84, 12.718448099999998), xc.a_range)
Expand Down Expand Up @@ -274,6 +275,16 @@ def test_rgb_scaler_unscaler():
assert_allclose(check_output, scaled)


def test_relax_on_decode():
    """Round-trip the example structures through PNG with relaxation enabled.

    The match against the inputs is checked with ``tol_multiplier=4.0``.
    """
    converter = XtalConverter(relax_on_decode=True)
    images = converter.xtal2png(example_structures, show=False, save=False)
    relaxed_structures = converter.png2xtal(images)
    assert_structures_approximate_match(
        example_structures,
        relaxed_structures,
        tol_multiplier=4.0,
    )
    return relaxed_structures


def test_plot_and_save():
df = px.data.tips()
fig = px.histogram(df, x="day")
Expand All @@ -284,6 +295,7 @@ def test_plot_and_save():


if __name__ == "__main__":
test_relax_on_decode()
test_primitive_decoding()
test_primitive_encoding()
test_fit()
Expand Down

0 comments on commit 10c9cec

Please sign in to comment.