Skip to content

Commit

Permalink
Merge 825066c into 2ae0110
Browse files Browse the repository at this point in the history
  • Loading branch information
sgbaird committed Jun 24, 2022
2 parents 2ae0110 + 825066c commit 1811a1f
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 6 deletions.
71 changes: 65 additions & 6 deletions src/xtal2png/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,27 @@

# Keys naming each kind of information encoded in the array/image. They are also
# the values accepted in ``mask_types`` (see SUPPORTED_MASK_KEYS below).
ATOM_KEY = "atom"
FRAC_KEY = "frac"
# Lattice-length keys. Each of these used to be assigned twice ("latt_a" and then
# "a", etc.); only the final, effective values are kept.
A_KEY = "a"
B_KEY = "b"
C_KEY = "c"
ANGLES_KEY = "angles"
VOLUME_KEY = "volume"
SPACE_GROUP_KEY = "space_group"
DISTANCE_KEY = "distance"
# Not an information type stored in the array, but a mask option that zeroes the
# lower triangle of the (square) array.
LOWER_TRI_KEY = "lower_tri"

# Every value that may appear in ``mask_types``; anything else is rejected.
SUPPORTED_MASK_KEYS = [
    ATOM_KEY,
    FRAC_KEY,
    A_KEY,
    B_KEY,
    C_KEY,
    ANGLES_KEY,
    VOLUME_KEY,
    SPACE_GROUP_KEY,
    DISTANCE_KEY,
    LOWER_TRI_KEY,
]


def construct_save_name(s: Structure) -> str:
Expand Down Expand Up @@ -145,6 +159,11 @@ class XtalConverter:
func:``XtalConverter().arrays_to_structures`` directly instead.
verbose: bool, optional
Whether to print verbose debugging information or not.
mask_types : List[str], optional
List of information types to mask out (assign as 0) from the array/image. Supported
values are "atom", "frac", "a", "b", "c", "angles", "volume", "space_group",
"distance", "lower_tri", and None. If None, then no masking is applied. If
"lower_tri" is present, then zeros out the lower triangle. By default, [] (no masking).
Examples
--------
Expand Down Expand Up @@ -173,6 +192,7 @@ def __init__(
relax_on_decode: bool = False,
channels: int = 1,
verbose: bool = True,
mask_types: List[str] = [],
):
"""Instantiate an XtalConverter object with desired ranges and ``max_sites``."""
self.atom_range = atom_range
Expand Down Expand Up @@ -213,6 +233,14 @@ def __init__(
else:
self.tqdm_if_verbose = lambda x: x

# Normalize ``mask_types`` before validating it:
# - ``None`` is documented as "no masking", so treat it as an empty list;
# - a bare string is taken as a single mask type instead of being iterated
#   character by character later on;
# - any other iterable is copied so the instance never shares the caller's (or
#   the default argument's) list object.
if mask_types is None:
    mask_types = []
elif isinstance(mask_types, str):
    mask_types = [mask_types]
else:
    mask_types = list(mask_types)

# Reject anything that is not one of the supported keys before storing it.
unsupported_mask_types = np.setdiff1d(mask_types, SUPPORTED_MASK_KEYS).tolist()
if unsupported_mask_types:
    raise ValueError(
        f"{unsupported_mask_types} is/are not a valid mask type. Expected one of {SUPPORTED_MASK_KEYS}. Received {mask_types}"  # noqa: E501
    )

self.mask_types = mask_types

Path(save_dir).mkdir(exist_ok=True, parents=True)

def xtal2png(
Expand Down Expand Up @@ -299,9 +327,29 @@ def fit(
self,
structures: List[Union[Structure, str, "PathLike[str]"]],
y=None,
fit_quantiles=(0.00, 0.99),
verbose=None,
fit_quantiles: Tuple[float, float] = (0.00, 0.99),
verbose: Optional[bool] = None,
):
"""Find optimal range parameters for encoding crystal structures.
Parameters
----------
structures : List[Union[Structure, str, "PathLike[str]"]]
List of pymatgen Structure objects.
y : NoneType, optional
No effect, for compatibility only, by default None
fit_quantiles : Tuple[float,float], optional
The lower and upper quantiles to use for fitting ranges to the data, by
default (0.00, 0.99)
verbose : Optional[bool], optional
Whether to print information about the fitted ranges. If None, then defaults
to ``self.verbose``. By default None
Examples
--------
>>> fit(structures, y=None, fit_quantiles=(0.00, 0.99), verbose=None)
OUTPUT
"""
verbose = self.verbose if verbose is None else verbose

_, S = self.process_filepaths_or_structures(structures)
Expand Down Expand Up @@ -576,7 +624,7 @@ def structures_to_arrays(
n_sites = len(s.atomic_numbers)
if n_sites > self.max_sites:
raise ValueError(
f"crystal supplied with {n_sites} sites, which is more than {self.max_sites} sites. Remove crystal or increase `max_sites`." # noqa
f"crystal supplied with {n_sites} sites, which is more than {self.max_sites} sites. Remove the offending crystal(s), increase `max_sites`, or use a more compact cell_type (see encode_cell_type and decode_cell_type kwargs)." # noqa: E501
)
atomic_numbers.append(
np.pad(
Expand Down Expand Up @@ -720,6 +768,17 @@ def structures_to_arrays(
data = np.expand_dims(data, 1)
id_data = np.expand_dims(id_data, 1)

# Zero out every requested kind of information in ``data``.
for mask_type in self.mask_types:
    if mask_type == LOWER_TRI_KEY:
        # Each ``d`` is a view into ``data``, so assigning into it masks ``data``
        # in place.
        for d in data:
            if d.shape[1] != d.shape[2]:
                raise ValueError(
                    f"Expected square matrix in last two dimensions, received {d.shape}"  # noqa: E501
                )
            # The row and column indices must be passed as two separate indices.
            # Passing the (rows, cols) tuple as a single index would be read as
            # one 2-D integer index on axis 1 and wipe entire rows instead of
            # only the lower triangle (diagonal included).
            rows, cols = np.mask_indices(d.shape[1], np.tril)
            d[:, rows, cols] = 0.0
    else:
        # All other mask types are located through the id array: every element
        # whose id matches the requested type is zeroed.
        data[id_data == id_mapper[mask_type]] = 0.0

data = np.repeat(data, self.channels, 1)
id_data = np.repeat(id_data, self.channels, 1)

Expand Down
35 changes: 35 additions & 0 deletions tests/xtal2png_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,45 @@ def test_plot_and_save():
plot_and_save("reports/figures/tmp", fig, mpl_kwargs={})


def test_distance_mask():
    """Encode the example structures with the "distance" mask and verify it.

    Two checks are made on the converter's ``data`` after encoding: every
    element labelled as "distance" in ``id_data`` must be zero, and so must the
    block from index 12 onward along the last two axes (presumably the
    distance region of the image -- the offset is hardcoded here).

    Returns the encoded images so the function can also be called by hand.
    """
    converter = XtalConverter(mask_types=["distance"])
    images = converter.xtal2png(example_structures)

    # Check via the id lookup.
    distance_id = converter.id_mapper["distance"]
    labelled_values = converter.data[converter.id_data == distance_id]
    if np.any(labelled_values != 0):
        raise ValueError("Distance mask not applied correctly (id_mapper)")

    # Check via the fixed offset of the trailing block.
    trailing_block = converter.data[:, :, 12:, 12:]
    if np.any(trailing_block != 0):
        raise ValueError("Distance mask not applied correctly (hardcoded)")

    return images


def test_lower_tri_mask():
    """Encode the example structures with the "lower_tri" mask and verify it.

    After encoding, the lower triangle (diagonal included) of the last two axes
    of the converter's ``data`` must be all zeros.

    Returns the encoded images so the function can also be called by hand.

    Raises
    ------
    ValueError
        If any element in the lower triangle is non-zero.
    """
    xc = XtalConverter(mask_types=["lower_tri"])
    imgs = xc.xtal2png(example_structures)

    # ``np.tril`` keeps the lower triangle of the last two axes and zeroes the
    # rest, so an all-zero result means the masked region is entirely zero. The
    # previous check used the array *values* as an index into ``xc.data``, which
    # does not test the lower triangle at all.
    if not np.all(np.tril(xc.data) == 0):
        raise ValueError("Lower triangle mask not applied correctly")

    return imgs


def test_mask_error():
    """Round-trip the example structures with the "atom" mask applied.

    The structures are encoded with ``mask_types=["atom"]`` and decoded again;
    none of the decoded structures may contain any site.
    """
    converter = XtalConverter(mask_types=["atom"])
    images = converter.xtal2png(example_structures)
    decoded = converter.png2xtal(images)

    # A single structure that still has sites is enough to fail.
    if any(structure.num_sites > 0 for structure in decoded):
        raise ValueError("Atom mask should have wiped out atomic sites.")


# TODO: test_matplotlibify with assertion


if __name__ == "__main__":
test_lower_tri_mask()
test_mask_error()
test_distance_mask()
test_xtal2png_three_channels()
test_png2xtal_three_channels()
test_structures_to_arrays_zero_one()
Expand Down

0 comments on commit 1811a1f

Please sign in to comment.