Skip to content

Commit

Permalink
Merge pull request #111 from sparks-baird/rgb-averaging
Browse files Browse the repository at this point in the history
support RGB (3-channel) averaging in addition to grayscale images
  • Loading branch information
sgbaird authored Jun 18, 2022
2 parents 2e3e9bf + c30f7a3 commit 03a397b
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 6 deletions.
7 changes: 5 additions & 2 deletions scripts/denoising_diffusion_pytorch_pretrained_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pymatviz.elements import ptable_heatmap_plotly

from xtal2png.core import XtalConverter
from xtal2png.utils.data import rgb_scaler
from xtal2png.utils.data import get_image_mode, rgb_scaler

mpt = MPTimeSplit()
mpt.load()
Expand Down Expand Up @@ -63,7 +63,10 @@
unscaled_arrays = np.squeeze(img_arrays_torch.cpu().numpy())
rgb_arrays = rgb_scaler(unscaled_arrays, data_range=(0, 1))

sampled_images = [Image.fromarray(arr, "I") for arr in rgb_arrays]
mode = get_image_mode(rgb_arrays[0])
if mode == "RGB":
rgb_arrays = [arr.transpose(1, 2, 0) for arr in rgb_arrays]
sampled_images = [Image.fromarray(arr, mode) for arr in rgb_arrays]

gen_path = path.join(
"data", "preprocessed", "mp-time-split", "ddpm", f"fold={fold}", uid
Expand Down
48 changes: 44 additions & 4 deletions src/xtal2png/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
dummy_structures,
element_wise_scaler,
element_wise_unscaler,
get_image_mode,
rgb_scaler,
rgb_unscaler,
)
Expand Down Expand Up @@ -133,6 +134,13 @@ class XtalConverter:
tuple. Same applies for ``angle_tolerance``. By default True
relax_on_decode: bool, optional
Use m3gnet to relax the decoded crystal structures.
channels : int, optional
Number of channels, a positive integer. Typically choices would be 1 (grayscale)
or 3 (RGB), and are the only compatible choices when using
func:``XtalConverter().xtal2png`` and func:``XtalConverter().png2xtal``. For
positive integers other than 1 or 3, use
func:``XtalConverter().structures_to_arrays`` and
func:``XtalConverter().arrays_to_structures`` directly instead.
verbose: bool, optional
Whether to print verbose debugging information or not.
Expand Down Expand Up @@ -161,6 +169,7 @@ def __init__(
encode_as_primitive: bool = False,
decode_as_primitive: bool = False,
relax_on_decode: bool = True,
channels: int = 1,
verbose: bool = True,
):
"""Instantiate an XtalConverter object with desired ranges and ``max_sites``."""
Expand Down Expand Up @@ -193,6 +202,8 @@ def __init__(
self.encode_as_primitive = encode_as_primitive
self.decode_as_primitive = decode_as_primitive
self.relax_on_decode = relax_on_decode

self.channels = channels
self.verbose = verbose

Path(save_dir).mkdir(exist_ok=True, parents=True)
Expand Down Expand Up @@ -263,7 +274,11 @@ def xtal2png(
# convert to PNG images. Save and/or show, if applicable
imgs: List[Image.Image] = []
for d, save_name in zip(self.data, save_names):
img = Image.fromarray(d, mode="L")
mode = get_image_mode(d)
d = np.squeeze(d)
if mode == "RGB":
d = d.transpose(1, 2, 0)
img = Image.fromarray(d, mode=mode)
imgs.append(img)
if save:
savepath = path.join(self.save_dir, save_name + ".png")
Expand Down Expand Up @@ -449,16 +464,30 @@ def png2xtal(
OUTPUT
"""
data_tmp = []
if self.channels == 1:
mode = "L"
elif self.channels == 3:
mode = "RGB"
else:
raise ValueError(
f"expected grayscale (1-channel) or RGB (3-channels) image, but got {self.channels}-channels. Either set channels to 1 or 3 or use xc.structures_to_arrays and xc.arrays_to_structures directly instead of xc.xtal2png and xc.png2xtal" # noqa: E501
)
for img in images:
if isinstance(img, str):
# load image from file
with Image.open(img).convert("L") as im:
data_tmp.append(np.asarray(im))
with Image.open(img).convert(mode) as im:
arr = np.asarray(im)
elif isinstance(img, Image.Image):
data_tmp.append(np.asarray(img.convert("L")))
arr = np.asarray(img.convert(mode))
if mode == "RGB":
arr = arr.transpose(2, 0, 1)
data_tmp.append(arr)

data = np.stack(data_tmp, axis=0)

if mode == "L":
data = np.expand_dims(data, 1)

S = self.arrays_to_structures(data)

if save:
Expand Down Expand Up @@ -695,6 +724,12 @@ def structures_to_arrays(
]
id_data = self.assemble_blocks(*id_blocks)

data = np.expand_dims(data, 1)
id_data = np.expand_dims(id_data, 1)

data = np.repeat(data, self.channels, 1)
id_data = np.repeat(id_data, self.channels, 1)

return data, id_data, id_mapper

def assemble_blocks(
Expand Down Expand Up @@ -843,6 +878,11 @@ def arrays_to_structures(
raise ValueError(
f"`data` should be of type `np.ndarray`. Received type {type(data)}. Maybe you passed a tuple of (data, id_data, id_mapper) returned from `structures_to_arrays()` by accident?" # noqa: E501
)

# convert to single channel and remove singleton dimension before disassembly
data = np.mean(data, axis=1)
if id_data is not None:
id_data = np.mean(id_data, axis=1)
arrays = self.disassemble_blocks(data, id_data=id_data, id_mapper=id_mapper)

(
Expand Down
15 changes: 15 additions & 0 deletions src/xtal2png/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,18 @@ def rgb_unscaler(
rgb_range = [0, 255]
X_scaled = element_wise_unscaler(X, data_range=data_range, feature_range=rgb_range)
return X_scaled


def get_image_mode(d):
if d.ndim != 3:
raise ValueError("expected an array with 3 dimensions, received {d.ndim} dims")
if d.shape[0] == 3:
mode = "RGB"
elif d.shape[0] == 1:
mode = "L"
else:
raise ValueError(
f"Expected a single-channel or 3-channel array, but received a {d.ndim}-channel array." # noqa: E501
)

return mode
18 changes: 18 additions & 0 deletions tests/xtal2png_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ def test_xtal2png_single():
return imgs


def test_xtal2png_three_channels():
xc = XtalConverter(relax_on_decode=False, channels=3)
imgs = xc.xtal2png(example_structures, show=False, save=False)
return imgs


def test_png2xtal():
xc = XtalConverter(relax_on_decode=False)
imgs = xc.xtal2png(example_structures, show=True, save=True)
Expand All @@ -203,6 +209,16 @@ def test_png2xtal_rgb_image():
return decoded_structures


def test_png2xtal_three_channels():
xc = XtalConverter(relax_on_decode=False, channels=3)
imgs = xc.xtal2png(example_structures, show=False, save=False)
img_shape = np.asarray(imgs[0]).shape
if img_shape != (64, 64, 3):
raise ValueError(f"Expected image shape: (3, 64, 64), received: {img_shape}")
decoded_structures = xc.png2xtal(imgs)
assert_structures_approximate_match(example_structures, decoded_structures)


def test_primitive_encoding():
xc = XtalConverter(
symprec=0.1,
Expand Down Expand Up @@ -321,6 +337,8 @@ def test_plot_and_save():


if __name__ == "__main__":
test_xtal2png_three_channels()
test_png2xtal_three_channels()
test_structures_to_arrays_zero_one()
test_arrays_to_structures_zero_one()
test_relax_on_decode()
Expand Down

0 comments on commit 03a397b

Please sign in to comment.