diff --git a/README.md b/README.md index 1b2a859..4718536 100644 --- a/README.md +++ b/README.md @@ -34,20 +34,26 @@ conda activate xtal2png from xtal2png.utils.data import example_structures from xtal2png.core import XtalConverter -xc = XtalConverter() +xc = XtalConverter() # DFT surrogate relaxation via m3gnet by default +data = xc.xtal2png(example_structures, show=True, save=True) +relaxed_decoded_structures = xc.png2xtal(data, save=False) + + +xc = XtalConverter(relax_on_decode=False) data = xc.xtal2png(example_structures, show=True, save=True) decoded_structures = xc.png2xtal(data, save=False) ``` ### Output ```python -print(example_structures[0], decoded_structures[0]) +print(example_structures[0], decoded_structures[0], relaxed_decoded_structures[0]) ``` + +\n", "\n", "\n", + "\n", "\n", "\n", "\n", + "\n", "\n", "
Original Decoded Relaxed Decoded
@@ -58,11 +64,28 @@ Lattice abc : 5.033788 11.523021 10.74117 angles : 90.0 90.0 90.0 volume : 623.0356027127609 - A : 5.033788 0.0 3.082306e-16 - B : 1.853043e-15 11.523021 7.055815e-16 + A : 5.033788 0.0 3.0823061808931787e-16 + B : 1.8530431062799525e-15 11.523021 7.055815392078867e-16 C : 0.0 0.0 10.74117 -PeriodicSite: Zn2+ (0.912, 5.770, 9.126) [0.181, 0.501, 0.850] -PeriodicSite: Zn2+ (4.122, 5.753, 1.616) [0.8188, 0.499, 0.150] +PeriodicSite: Zn2+ (0.9120, 5.7699, 9.1255) [0.1812, 0.5007, 0.8496] +PeriodicSite: Zn2+ (4.1218, 5.7531, 1.6156) [0.8188, 0.4993, 0.1504] +... +``` + + + +```python +Structure Summary +Lattice + abc : 5.0250980392156865 11.533333333333331 10.8 + angles : 90.0 90.0 90.0 + volume : 625.9262117647058 + A : 5.0250980392156865 0.0 0.0 + B : 0.0 11.533333333333331 0.0 + C : 0.0 0.0 10.8 +PeriodicSite: Zn (0.9016, 5.7780, 3.8012) [0.1794, 0.5010, 0.3520] +PeriodicSite: Zn (4.1235, 5.7554, 6.9988) [0.8206, 0.4990, 0.6480] ... ``` @@ -72,14 +95,14 @@ PeriodicSite: Zn2+ (4.122, 5.753, 1.616) [0.8188, 0.499, 0.150] ```python Structure Summary Lattice - abc : 5.058824 11.529412 10.764706 - angles : 90.352941 90.352941 90.352941 - volume : 627.818381 - A : 5.058728 0.0 -0.031162 - B : -0.071459 11.528972 -0.071021 - C : 0.0 0.0 10.764706 -PeriodicSite: Zn (0.877, 5.787, 9.119) [0.180, 0.502, 0.851] -PeriodicSite: Zn (4.111, 5.742, 1.543) [0.820, 0.498, 0.149] + abc : 5.026834307381214 11.578854613685237 10.724087971087924 + angles : 90.0 90.0 90.0 + volume : 624.1953646135236 + A : 5.026834307381214 0.0 0.0 + B : 0.0 11.578854613685237 0.0 + C : 0.0 0.0 10.724087971087924 +PeriodicSite: Zn (0.9050, 5.7978, 3.7547) [0.1800, 0.5007, 0.3501] +PeriodicSite: Zn (4.1218, 5.7810, 6.9693) [0.8200, 0.4993, 0.6499] ... ``` diff --git a/notebooks/1.0-xtal2png-tutorial.ipynb b/notebooks/1.0-xtal2png-tutorial.ipynb index 9b1a102..893f528 100644 --- a/notebooks/1.0-xtal2png-tutorial.ipynb +++ b/notebooks/1.0-xtal2png-tutorial.ipynb @@ -1,27 +1,10 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "xtal2png-tutorial.ipynb", - "provenance": [], - "authorship_tag": "ABX9TyNJG5GbIiRm3b22dp5jDrbI", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", "metadata": { - "id": "view-in-github", - "colab_type": "text" + "colab_type": "text", + "id": "view-in-github" }, "source": [ "\"Open" @@ -29,39 +12,39 @@ }, { "cell_type": "markdown", - "source": [ - "# Encode/decode a crystal structure to/from a grayscale PNG image" - ], "metadata": { "id": "mHJ4l5gKuXx-" - } + }, + "source": [ + "# Encode/decode a crystal structure to/from a grayscale PNG image" + ] }, { "cell_type": "markdown", - "source": [ - "In this notebook, we will install the `xtal2png` package, encode/decode two example pymatgen `Structure` objects, and show some visualizations of the intermediate PNG representations and before/after crystal structure plots. Finally, we comment on how you can use `xtal2png` with state-of-the-art machine learning image models." - ], "metadata": { "id": "pjrtpDdbuyqp" - } + }, + "source": [ + "In this notebook, we will install the `xtal2png` package, encode/decode two example pymatgen `Structure` objects, and show some visualizations of the intermediate PNG representations and before/after crystal structure plots. Finally, we comment on how you can use `xtal2png` with state-of-the-art machine learning image models." + ] }, { "cell_type": "markdown", - "source": [ - "## Installation" - ], "metadata": { "id": "In1KcDcVuxw6" - } + }, + "source": [ + "## Installation" + ] }, { "cell_type": "markdown", - "source": [ - "Install the `xtal2png` package. Optionally install `ase` and `nglview` which can be used to visualize crystal structures. You may need to restart the runtime via `Ctrl+M, .` or `Runtime --> Restart Runtime` (via menubar)." - ], "metadata": { "id": "93hxUzlfqb4v" - } + }, + "source": [ + "Install the `xtal2png` package. Optionally install `ase` and `nglview` which can be used to visualize crystal structures. You may need to restart the runtime via `Ctrl+M, .` or `Runtime --> Restart Runtime` (via menubar)." + ] }, { "cell_type": "code", @@ -75,8 +58,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting xtal2png\n", @@ -143,134 +126,157 @@ }, { "cell_type": "markdown", - "source": [ - "## Encode(/decode) two pymatgen `Structure` objects" - ], "metadata": { "id": "5_ur18igvGap" - } + }, + "source": [ + "## Encode(/decode) two pymatgen `Structure` objects" + ] }, { "cell_type": "markdown", - "source": [ - "Import a list of two example pymatgen `Structure` objects (these correspond to [mp-560471](https://next-gen.materialsproject.org/materials/mp-560471)/$Zn_2B_2PbO_6$ and [mp-7823](https://next-gen.materialsproject.org/materials/mp-7823)/$V_2NiSe_4$, respectively)" - ], "metadata": { "id": "auHOZ-dAhd5t" - } + }, + "source": [ + "Import a list of two example pymatgen `Structure` objects (these correspond to [mp-560471](https://next-gen.materialsproject.org/materials/mp-560471)/$Zn_2B_2PbO_6$ and [mp-7823](https://next-gen.materialsproject.org/materials/mp-7823)/$V_2NiSe_4$, respectively)" + ] }, { "cell_type": "code", - "source": [ - "from xtal2png.utils.data import example_structures\n", - "from xtal2png.core import XtalConverter" - ], + "execution_count": null, "metadata": { "id": "qDET_Hc6UXP7" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "from xtal2png.utils.data import example_structures\n", + "from xtal2png.core import XtalConverter" + ] }, { "cell_type": "markdown", - "source": [ - "Let's take a look at the second `Structure` which has a smaller footprint." - ], "metadata": { "id": "EcxElT0bhyxF" - } + }, + "source": [ + "Let's take a look at the second `Structure` which has a smaller footprint." + ] }, { "cell_type": "code", - "source": [ - "example_structures[1]" - ], + "execution_count": null, "metadata": { "id": "S2I58v39hyNO" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "example_structures[1]" + ] }, { "cell_type": "markdown", - "source": [ - "We will be using the `XtalConverter` class, for which more information including its `__init__` arguments and functions can be displayed via `help(XtalConverter)`. For just the parameters for class instantiation, try `help(XtalConverter.__init__)`. Note that `max_sites` is not tested for values other than `52`." - ], "metadata": { "id": "ulVWYCeoiuTM" - } + }, + "source": [ + "We will be using the `XtalConverter` class, for which more information including its `__init__` arguments and functions can be displayed via `help(XtalConverter)`. For just the parameters for class instantiation, try `help(XtalConverter.__init__)`. Note that `max_sites` is not tested for values other than `52`." + ] }, { "cell_type": "code", - "source": [ - "help(XtalConverter.__init__)" - ], + "execution_count": null, "metadata": { "id": "hZ4LtYcJkoOC" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "help(XtalConverter.__init__)" + ] }, { "cell_type": "markdown", - "source": [ - "Let's specify the save directory (`save_dir`) for the PNG files as `\"data\"`, which will be automatically created. In this case, it will be saved to temporary Google Colab storage." - ], "metadata": { "id": "weBReOogvmeK" - } + }, + "source": [ + "Let's specify the save directory (`save_dir`) for the PNG files as `\"data\"`, which will be automatically created. In this case, it will be saved to temporary Google Colab storage." + ] }, { "cell_type": "code", - "source": [ - "xc = XtalConverter(save_dir=\"data\")\n", - "data = xc.xtal2png(example_structures, save=True)\n", - "decoded_structures = xc.png2xtal(data, save=False)" - ], + "execution_count": null, "metadata": { "id": "B_S0jeVsVfYK" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "xc = XtalConverter(save_dir=\"data\") # DFT surrogate relaxation via m3gnet by default\n", + "data = xc.xtal2png(example_structures, save=True)\n", + "relaxed_decoded_structures = xc.png2xtal(data, save=False)" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ - "## Visualization" - ], + "We also take a look at the unrelaxed decoded structures." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "xc = XtalConverter(save_dir=\"data\", relax_on_decode=False)\n", + "data = xc.xtal2png(example_structures, save=True)\n", + "decoded_structures = xc.png2xtal(data, save=False)" + ] + }, + { + "cell_type": "markdown", "metadata": { "id": "DNL8GcdZwixh" - } + }, + "source": [ + "## Visualization" + ] }, { "cell_type": "markdown", - "source": [ - "For visualization, we'll cover two aspects: the structure-encoded PNG images and visualizing before/after crystal structures." - ], "metadata": { "id": "Cexf9PqBwkMR" - } + }, + "source": [ + "For visualization, we'll cover two aspects: the structure-encoded PNG images and visualizing before/after crystal structures." + ] }, { "cell_type": "markdown", - "source": [ - "### Structure-encoded PNG images" - ], "metadata": { "id": "Xm8i5EZ5wsqJ" - } + }, + "source": [ + "### Structure-encoded PNG images" + ] }, { "cell_type": "markdown", - "source": [ - "Note that images won't show via `im.show()` command _on Google Colab_ even if you specify `xc.xtal2png(..., show=True, ...)`, so for this Colab example we'll open the images ad-hoc based on where they were saved in local Colab storage. We display the images stacked one on top of another using `display(im)` instead of `im.show()`. Note that the filepaths have the chemical formula, `volume`, and a randomly generated `uid` portion to promote uniqueness, especially when dealing with allotropes (same chemical formula, different crystal structure)." - ], "metadata": { "id": "TQ8jkgYrhcg2" - } + }, + "source": [ + "Note that images won't show via `im.show()` command _on Google Colab_ even if you specify `xc.xtal2png(..., show=True, ...)`, so for this Colab example we'll open the images ad-hoc based on where they were saved in local Colab storage. We display the images stacked one on top of another using `display(im)` instead of `im.show()`. Note that the filepaths have the chemical formula, `volume`, and a randomly generated `uid` portion to promote uniqueness, especially when dealing with allotropes (same chemical formula, different crystal structure)." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pV0L8N85kTFk" + }, + "outputs": [], "source": [ "import glob, os\n", "from PIL import Image\n", @@ -279,26 +285,24 @@ " im = im.resize((64*5, 64*5), Image.BOX)\n", " print(fpath)\n", " display(im)" - ], - "metadata": { - "id": "pV0L8N85kTFk" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "oU6FjfDYnnuI" + }, "source": [ "As mentioned in the README, the legend key for these images is as follows:\n", "\n", "        " - ], - "metadata": { - "id": "oU6FjfDYnnuI" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "U-rTkYeMxxAQ" + }, "source": [ "Also described in the README, the match between the encoded and decoded versions is within an expected tolerance, given that PNG images are represented as discrete RGB values between 0 and 255 (i.e. there is a round-off error).\n", "\n", @@ -306,6 +310,7 @@ "
Original Decoded Relaxed Decoded
\n", @@ -316,11 +321,11 @@ " abc : 5.033788 11.523021 10.74117\n", " angles : 90.0 90.0 90.0\n", " volume : 623.0356027127609\n", - " A : 5.033788 0.0 3.082306e-16\n", - " B : 1.853043e-15 11.523021 7.055815e-16\n", + " A : 5.033788 0.0 3.0823061808931787e-16\n", + " B : 1.8530431062799525e-15 11.523021 7.055815392078867e-16\n", " C : 0.0 0.0 10.74117\n", - "PeriodicSite: Zn2+ (0.912, 5.770, 9.126) [0.181, 0.501, 0.850]\n", - "PeriodicSite: Zn2+ (4.122, 5.753, 1.616) [0.8188, 0.499, 0.150]\n", + "PeriodicSite: Zn2+ (0.9120, 5.7699, 9.1255) [0.1812, 0.5007, 0.8496]\n", + "PeriodicSite: Zn2+ (4.1218, 5.7531, 1.6156) [0.8188, 0.4993, 0.1504]\n", "...\n", "```\n", "\n", @@ -330,128 +335,168 @@ "```python\n", "Structure Summary\n", "Lattice\n", - " abc : 5.058824 11.529412 10.764706\n", - " angles : 90.352941 90.352941 90.352941\n", - " volume : 627.818381\n", - " A : 5.058728 0.0 -0.031162\n", - " B : -0.071459 11.528972 -0.071021\n", - " C : 0.0 0.0 10.764706\n", - "PeriodicSite: Zn (0.877, 5.787, 9.119) [0.180, 0.502, 0.851]\n", - "PeriodicSite: Zn (4.111, 5.742, 1.543) [0.820, 0.498, 0.149]\n", + " abc : 5.0250980392156865 11.533333333333331 10.8\n", + " angles : 90.0 90.0 90.0\n", + " volume : 625.9262117647058\n", + " A : 5.0250980392156865 0.0 0.0\n", + " B : 0.0 11.533333333333331 0.0\n", + " C : 0.0 0.0 10.8\n", + "PeriodicSite: Zn (0.9016, 5.7780, 3.8012) [0.1794, 0.5010, 0.3520]\n", + "PeriodicSite: Zn (4.1235, 5.7554, 6.9988) [0.8206, 0.4990, 0.6480]\n", + "...\n", + "```\n", + "\n", + "\n", + "\n", + "```python\n", + "Structure Summary\n", + "Lattice\n", + " abc : 5.026834307381214 11.578854613685237 10.724087971087924\n", + " angles : 90.0 90.0 90.0\n", + " volume : 624.1953646135236\n", + " A : 5.026834307381214 0.0 0.0\n", + " B : 0.0 11.578854613685237 0.0\n", + " C : 0.0 0.0 10.724087971087924\n", + "PeriodicSite: Zn (0.9050, 5.7978, 3.7547) [0.1800, 0.5007, 0.3501]\n", + "PeriodicSite: Zn (4.1218, 5.7810, 6.9693) [0.8200, 0.4993, 0.6499]\n", "...\n", "```\n", "\n", "
" - ], - "metadata": { - "id": "U-rTkYeMxxAQ" - } + ] }, { "cell_type": "markdown", - "source": [ - "### Before/after Crystal Structure Visualization" - ], "metadata": { "id": "6K8M_EWewvgK" - } + }, + "source": [ + "### Before/after Crystal Structure Visualization" + ] }, { "cell_type": "markdown", - "source": [ - "To visualize the crystal structures before and after, we can use `nglview` after a bit of Colab finnagling with external ipywidgets." - ], "metadata": { "id": "e1TFsM32pwsN" - } + }, + "source": [ + "To visualize the crystal structures before and after, we can use `nglview` after a bit of Colab finnagling with external ipywidgets." + ] }, { "cell_type": "code", - "source": [ - "from google.colab import output\n", - "output.enable_custom_widget_manager()" - ], + "execution_count": null, "metadata": { "id": "VDhN-1cFuGIq" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "from google.colab import output\n", + "output.enable_custom_widget_manager()" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kT2kMJkBriRT" + }, + "outputs": [], "source": [ "from pymatgen.io.ase import AseAtomsAdaptor\n", "from ase.visualize import view\n", "aaa = AseAtomsAdaptor()" - ], - "metadata": { - "id": "kT2kMJkBriRT" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "[display(view(aaa.get_atoms(s), viewer='ngl')) for s in example_structures]" - ], + "execution_count": null, "metadata": { "id": "40sCJggKsV5E" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "[display(view(aaa.get_atoms(s), viewer='ngl')) for s in example_structures]" + ] }, { "cell_type": "code", - "source": [ - "[display(view(aaa.get_atoms(s), viewer='ngl')) for s in decoded_structures]" - ], + "execution_count": null, "metadata": { "id": "FWMOd7gHsTB0" }, + "outputs": [], + "source": [ + "[display(view(aaa.get_atoms(s), viewer='ngl')) for s in decoded_structures]" + ] + }, + { + "cell_type": "code", "execution_count": null, - "outputs": [] + "metadata": {}, + "outputs": [], + "source": [ + "[display(view(aaa.get_atoms(s), viewer='ngl')) for s in relaxed_decoded_structures]" + ] }, { "cell_type": "markdown", - "source": [ - "Undo the Colab finnagling of external ipywidgets." - ], "metadata": { "id": "1u_95y04uIRa" - } + }, + "source": [ + "Undo the Colab finnagling of external ipywidgets." + ] }, { "cell_type": "code", - "source": [ - "from google.colab import output\n", - "output.disable_custom_widget_manager()" - ], + "execution_count": null, "metadata": { "id": "IVFYjEvUrfCP" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "from google.colab import output\n", + "output.disable_custom_widget_manager()" + ] }, { "cell_type": "markdown", - "source": [ - "## Final Remarks" - ], "metadata": { "id": "DLAfVjyrw5VK" - } + }, + "source": [ + "## Final Remarks" + ] }, { "cell_type": "markdown", - "source": [ - "This tool makes it possible to use state-of-the-art image-based machine learning models with minimal \"plumbing\" required. Just follow the normal instructions for custom image datasets. For example, this can be used with [Palette](https://iterative-refinement.github.io/palette/), an image-to-image guided diffusion model by Google, which has an unofficial implementation [here](https://github.com/Janspiry/Palette-Image-to-Image-Diffusion-Models)." - ], "metadata": { "id": "7QHtKaklxKip" - } + }, + "source": [ + "This tool makes it possible to use state-of-the-art image-based machine learning models with minimal \"plumbing\" required. Just follow the normal instructions for custom image datasets. For example, this can be used with [Palette](https://iterative-refinement.github.io/palette/), an image-to-image guided diffusion model by Google, which has an unofficial implementation [here](https://github.com/Janspiry/Palette-Image-to-Image-Diffusion-Models)." + ] } - ] + ], + "metadata": { + "colab": { + "authorship_tag": "ABX9TyNJG5GbIiRm3b22dp5jDrbI", + "include_colab_link": true, + "name": "xtal2png-tutorial.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/setup.cfg b/setup.cfg index 3a1b2da..2c04569 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,6 +54,7 @@ install_requires = pymatgen plotly kaleido + m3gnet [options.packages.find] diff --git a/src/xtal2png/core.py b/src/xtal2png/core.py index 26dbc46..6331d67 100644 --- a/src/xtal2png/core.py +++ b/src/xtal2png/core.py @@ -12,12 +12,16 @@ from warnings import warn import numpy as np +import pandas as pd +import tensorflow as tf +from m3gnet.models import Relaxer from numpy.typing import NDArray from PIL import Image from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.io.cif import CifWriter from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +from tqdm import tqdm from xtal2png import __version__ from xtal2png.utils.data import dummy_structures, rgb_scaler, rgb_unscaler @@ -59,7 +63,7 @@ DISTANCE_KEY = "distance" -def construct_save_name(s: Structure): +def construct_save_name(s: Structure) -> str: save_name = f"{s.formula.replace(' ', '')},volume={int(np.round(s.volume))},uid={str(uuid4())[0:4]}" # noqa: E501 return save_name @@ -105,14 +109,28 @@ class XtalConverter: save_dir : Union[str, 'PathLike[str]'] Directory to save PNG files via :func:``xtal2png``, by default path.join("data", "interim") - symprec : float, optional + symprec : Union[float, Tuple[float, float]], optional The symmetry precision to use when decoding `pymatgen` structures via - ``func:pymatgen.symmetry.analyzer.SpaceGroupAnalyzer.get_refined_structure``. By - default 0.1. - angle_tolerance : Union[float, int], optional + ``func:pymatgen.symmetry.analyzer.SpaceGroupAnalyzer.get_refined_structure``. If + specified as a tuple, then ``symprec[0]`` applies to encoding and ``symprec[1]`` + applies to decoding. By default 0.1. + angle_tolerance : Union[float, int, Tuple[float, float], Tuple[int, int]], optional The angle tolerance (degrees) to use when decoding `pymatgen` structures via - ``func:pymatgen.symmetry.analyzer.SpaceGroupAnalyzer.get_refined_structure``. By - default 5.0. + ``func:pymatgen.symmetry.analyzer.SpaceGroupAnalyzer.get_refined_structure``. If + specified as a tuple, then ``angle_tolerance[0]`` applies to encoding and + ``angle_tolerance[1]`` applies to decoding. By default 5.0. + encode_as_primitive : bool, optional + Encode structures as symmetrized, primitive structures. Uses ``symprec`` if + ``symprec`` is of type float, else uses ``symprec[0]`` if ``symprec`` is of type + tuple. Same applies for ``angle_tolerance``. By default True + decode_as_primitive : bool, optional + Decode structures as symmetrized, primitive structures. Uses ``symprec`` if + ``symprec`` is of type float, else uses ``symprec[1]`` if ``symprec`` is of type + tuple. Same applies for ``angle_tolerance``. By default True + relax_on_decode: bool, optional + Use m3gnet to relax the decoded crystal structures. + verbose: bool, optional + Whether to print verbose debugging information or not. Examples -------- @@ -134,8 +152,12 @@ def __init__( distance_range: Tuple[float, float] = (0.0, 18.0), max_sites: int = 52, save_dir: Union[str, "PathLike[str]"] = path.join("data", "preprocessed"), - symprec: float = 0.1, - angle_tolerance: float = 5.0, + symprec: Union[float, Tuple[float, float]] = 0.1, + angle_tolerance: Union[float, int, Tuple[float, float], Tuple[int, int]] = 5.0, + encode_as_primitive: bool = False, + decode_as_primitive: bool = False, + relax_on_decode: bool = True, + verbose: bool = True, ): """Instantiate an XtalConverter object with desired ranges and ``max_sites``.""" self.atom_range = atom_range @@ -149,16 +171,31 @@ def __init__( self.distance_range = distance_range self.max_sites = max_sites self.save_dir = save_dir - self.symprec = symprec - self.angle_tolerance = angle_tolerance + + if isinstance(symprec, (float, int)): + self.encode_symprec = symprec + self.decode_symprec = symprec + elif isinstance(symprec, tuple): + self.encode_symprec = symprec[0] + self.decode_symprec = symprec[1] + + if isinstance(angle_tolerance, (float, int)): + self.encode_angle_tolerance = angle_tolerance + self.decode_angle_tolerance = angle_tolerance + elif isinstance(angle_tolerance, tuple): + self.encode_angle_tolerance = angle_tolerance[0] + self.decode_angle_tolerance = angle_tolerance[1] + + self.encode_as_primitive = encode_as_primitive + self.decode_as_primitive = decode_as_primitive + self.relax_on_decode = relax_on_decode + self.verbose = verbose Path(save_dir).mkdir(exist_ok=True, parents=True) def xtal2png( self, - structures: Union[ - List[Union[Structure, str, "PathLike[str]"]], str, "PathLike[str]" - ], + structures: List[Union[Structure, str, "PathLike[str]"]], show: bool = False, save: bool = True, ): @@ -167,8 +204,7 @@ def xtal2png( Parameters ---------- structures : List[Union[Structure, str, PathLike[str]]] - pymatgen Structure objects or path to CIF files or path to directory - containing CIF files. + pymatgen Structure objects or path to CIF files. show : bool, optional Whether to display the PNG-encoded file, by default False save : bool, optional @@ -200,10 +236,11 @@ def xtal2png( >>> xc = XtalConverter() >>> xc.xtal2png(structures, show=False, save=True) """ + save_names, S = self.process_filepaths_or_structures(structures) # convert structures to 3D NumPy Matrices - self.data, self.id_data, self.id_keys = self.structures_to_arrays(S) + self.data, self.id_data, self.id_mapper = self.structures_to_arrays(S) mn, mx = self.data.min(), self.data.max() if mn < 0: warn( @@ -232,19 +269,140 @@ def xtal2png( return imgs - def process_filepaths_or_structures(self, structures): + def fit( + self, + structures: List[Union[Structure, str, "PathLike[str]"]], + y=None, + fit_quantiles=(0.00, 0.99), + verbose=None, + ): + verbose = self.verbose if verbose is None else verbose + + _, S = self.process_filepaths_or_structures(structures) + + # TODO: deal with arbitrary site_properties + atomic_numbers = [] + a = [] + b = [] + c = [] + space_group = [] + volume = [] + distance = [] + num_sites = [] + + for s in tqdm(S): + atomic_numbers.append(s.atomic_numbers) + lattice = s.lattice + a.append(lattice.a) + b.append(lattice.b) + c.append(lattice.c) + space_group.append(s.get_space_group_info()[1]) + volume.append(lattice.volume) + distance.append(s.distance_matrix) + num_sites.append(len(list(s.sites))) + + if verbose: + print("range of atomic_numbers is: ", min(a), "-", max(a)) + print("range of a is: ", min(a), "-", max(a)) + print("range of b is: ", min(b), "-", max(b)) + print("range of c is: ", min(c), "-", max(c)) + print("range of space_group is: ", min(space_group), "-", max(space_group)) + print("range of volume is: ", min(volume), "-", max(volume)) + print("range of num_sites is: ", min(num_sites), "-", max(num_sites)) + + dis_min_tmp = [] + dis_max_tmp = [] + for d in tqdm(range(len(distance))): + dis_min_tmp.append(min(distance[d][np.nonzero(distance[d])])) + dis_max_tmp.append(max(distance[d][np.nonzero(distance[d])])) + + atoms = np.array(atomic_numbers, dtype="object") + self.atom_range = (min(np.min(atoms)), max(np.max(atoms))) + self.space_group_range = (np.min(space_group), np.max(space_group)) + + self.num_sites = np.max(num_sites) + + df = pd.DataFrame( + dict( + a=a, + b=b, + c=c, + volume=volume, + min_distance=dis_min_tmp, + max_distance=dis_max_tmp, + ) + ) + + low_quantile, upp_quantile = fit_quantiles + + low_df = ( + df.apply(lambda a: np.quantile(a, low_quantile)) + .drop(["max_distance"]) + .rename(index={"min_distance": "distance"}) + ) + upp_df = ( + df.apply(lambda a: np.quantile(a, upp_quantile)) + .drop(["min_distance"]) + .rename(index={"max_distance": "distance"}) + ) + low_df.name = "low" + upp_df.name = "upp" + + range_df = pd.concat((low_df, upp_df), axis=1) + + for name, bounds in range_df.iterrows(): + setattr(self, name + "_range", tuple(bounds)) + + def process_filepaths_or_structures( + self, + structures: List[Union[Structure, str, "PathLike[str]"]], + ) -> Tuple[List[str], List[Structure]]: + """Extract (or create) save names and convert/passthrough the structures. + + Parameters + ---------- + structures : Union[PathLike, Structure] + List of filepaths or list of structures to be processed. + + Returns + ------- + save_names : List[str] + Save names of the files if filepaths are passed, otherwise some relatively + unique names (due to 4 random characters being appended at the end) for each + structure. See ``construct_save_name``. + + S : List[Structure] + Processed structures. + + Raises + ------ + ValueError + "structures should be of same datatype, either strs or pymatgen Structures. + structures[0] is {type(structures[0])}, but got type {type(s)} for entry + {i}" + ValueError + "structures should be of same datatype, either strs or pymatgen Structures. + structures[0] is {type(structures[0])}, but got type {type(s)} for entry + {i}" + ValueError + "structures should be of type `str`, `os.PathLike` or + `pymatgen.core.structure.Structure`, not {type(structures[i])} (entry {i})" + + Examples + -------- + >>> save_names, structures = process_filepaths_or_structures(structures) + """ save_names: List[str] = [] - S: List[Structure] = [] + first_is_structure = isinstance(structures[0], Structure) for i, s in enumerate(structures): if isinstance(s, str) or isinstance(s, PathLike): if first_is_structure: raise ValueError( - f"structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}" # noqa + f"structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}" # noqa: E501 ) - # load the CIF and convert to a pymatgen Structure - S.append(Structure.from_file(s)) + structures[i] = Structure.from_file(s) save_names.append(Path(str(s)).stem) elif isinstance(s, Structure): @@ -253,14 +411,20 @@ def process_filepaths_or_structures(self, structures): f"structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}" # noqa ) - S.append(s) + structures[i] = s save_names.append(construct_save_name(s)) else: raise ValueError( - f"structures should be of type `str`, `os.PathLike` or `pymatgen.core.structure.Structure`, not {type(S)} (entry {i})" # noqa + f"structures should be of type `str`, `os.PathLike` or `pymatgen.core.structure.Structure`, not {type(structures[i])} (entry {i})" # noqa ) - return save_names, S + for i, s in enumerate(structures): + assert isinstance( + s, Structure + ), f"structures[{i}]: {type(s)}, expected: Structure" + assert not isinstance(s, str) and not isinstance(s, PathLike) + + return save_names, structures # type: ignore def png2xtal( self, images: List[Union[Image.Image, "PathLike"]], save: bool = False @@ -297,14 +461,19 @@ def png2xtal( for s in S: fpath = path.join(self.save_dir, construct_save_name(s) + ".cif") CifWriter( - s, symprec=self.symprec, angle_tolerance=self.angle_tolerance + s, + symprec=self.decode_symprec, + angle_tolerance=self.decode_angle_tolerance, ).write_file(fpath) return S # unscale values - def structures_to_arrays(self, structures: Sequence[Structure]): + def structures_to_arrays( + self, + structures: Sequence[Structure], + ): """Convert pymatgen Structure to scaled 3D array of crystallographic info. ``atomic_numbers`` and ``distance_matrix` get padded or cropped as appropriate, @@ -314,6 +483,38 @@ def structures_to_arrays(self, structures: Sequence[Structure]): ---------- S : Sequence[Structure] Sequence (e.g. list) of pymatgen Structure object(s) + + Returns + ------- + data : ArrayLike + RGB-scaled arrays with first dimension corresponding to each crystal + structure. + + id_data : ArrayLike + Same shape as ``data``, except one-hot encoded to distinguish between the + various types of information contained in ``data``. See ``id_mapper`` for + the "legend" for this data. + + id_mapper : ArrayLike + Dictionary containing the legend/key between the names of the blocks and the + corresponding numbers in ``id_data``. + + Raises + ------ + ValueError + "`structures` should be a list of pymatgen Structure(s)" + ValueError + "crystal supplied with {n_sites} sites, which is more than {self.max_sites} + sites. Remove crystal or increase `max_sites`." + ValueError + "len(atomic_numbers) {n_sites} and distance_matrix.shape[0] + {s.distance_matrix.shape[0]} do not match" + + Examples + -------- + >>> xc = XtalConverter() + >>> data = xc.structures_to_arrays(structures) + OUTPUT """ if isinstance(structures, Structure): raise ValueError("`structures` should be a list of pymatgen Structure(s)") @@ -321,14 +522,29 @@ def structures_to_arrays(self, structures: Sequence[Structure]): # extract crystallographic information atomic_numbers: List[List[int]] = [] frac_coords_tmp: List[NDArray] = [] - latt_a: List[List[float]] = [] - latt_b: List[List[float]] = [] - latt_c: List[List[float]] = [] + latt_a: List[float] = [] + latt_b: List[float] = [] + latt_c: List[float] = [] angles: List[List[float]] = [] volume: List[float] = [] space_group: List[int] = [] distance_matrix_tmp: List[NDArray[np.float64]] = [] + sym_structures = [] + for s in structures: + spa = SpacegroupAnalyzer( + s, + symprec=self.encode_symprec, + angle_tolerance=self.encode_angle_tolerance, + ) + if self.encode_as_primitive: + s = spa.get_primitive_standard_structure() + else: + s = spa.get_refined_structure() + sym_structures.append(s) + + structures = sym_structures + for s in structures: n_sites = len(s.atomic_numbers) if n_sites > self.max_sites: @@ -509,18 +725,6 @@ def disassemble_blocks( id_data is not None and id_mapper is not None ), "id_data and id_mapper should not be None at this point" - # keys = [ - # ATOM_KEY, - # FRAC_KEY, - # A_KEY, - # B_KEY, - # C_KEY, - # ANGLES_KEY, - # VOLUME_KEY, - # SPACE_GROUP_KEY, - # DISTANCE_KEY, - # ] - [a.shape for a in np.array_split(data, [12], axis=1)] zero_pad = 12 @@ -567,15 +771,33 @@ def average_vert_horz(vert, horz): distance_arr, ) - def arrays_to_structures(self, data: np.ndarray): + def arrays_to_structures( + self, + data: np.ndarray, + id_data: Optional[np.ndarray] = None, + id_mapper: Optional[dict] = None, + ): """Convert scaled crystal (xtal) arrays to pymatgen Structures. Parameters ---------- data : np.ndarray 3D array containing crystallographic information. + + id_data : ArrayLike + Same shape as ``data``, except one-hot encoded to distinguish between the + various types of information contained in ``data``. See ``id_mapper`` for + the "legend" for this data. + + id_mapper : ArrayLike + Dictionary containing the legend/key between the names of the blocks and the + corresponding numbers in ``id_data``. """ - arrays = self.disassemble_blocks(data) + if not isinstance(data, np.ndarray): + raise ValueError( + f"`data` should be of type `np.ndarray`. Received type {type(data)}. Maybe you passed a tuple of (data, id_data, id_mapper) returned from `structures_to_arrays()` by accident?" # noqa: E501 + ) + arrays = self.disassemble_blocks(data, id_data=id_data, id_mapper=id_mapper) ( atom_scaled, @@ -619,10 +841,13 @@ def arrays_to_structures(self, data: np.ndarray): atomic_numbers = np.round(atomic_numbers).astype(int) - # REVIEW: round fractional coordinates to nearest multiple? - # TODO: tweak lattice parameters to match predicted space group rules + if self.relax_on_decode: + if not self.verbose: + tf.get_logger().setLevel(logging.ERROR) + relaxer = Relaxer() # This loads the default pre-trained model + # build Structure-s S: List[Structure] = [] for i in range(len(atomic_numbers)): @@ -642,12 +867,39 @@ def arrays_to_structures(self, data: np.ndarray): a=a, b=b, c=c, alpha=alpha, beta=beta, gamma=gamma ) structure = Structure(lattice, at, fr) + + # REVIEW: round fractional coordinates to nearest multiple? + if self.relax_on_decode: + relaxed_results = relaxer.relax(structure, verbose=self.verbose) + structure = relaxed_results["final_structure"] + + # relax_results = relaxer.relax() + # final_structure = relax_results["final_structure"] + # final_energy = relax_results["trajectory"].energies[-1] / 2 + + # print( + # f"Relaxed lattice parameter is + # {final_structure.lattice.abc[0]:.3f} Å" + # ) + # # TODO: print the initial energy as well (assuming it's available) + # print(f"Final energy is {final_energy.item(): .3f} eV/atom") + spa = SpacegroupAnalyzer( - structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance + structure, + symprec=self.decode_symprec, + angle_tolerance=self.decode_angle_tolerance, ) - structure = spa.get_refined_structure() + if self.decode_as_primitive: + structure = spa.get_primitive_standard_structure() + else: + structure = spa.get_refined_structure() + S.append(structure) + if self.relax_on_decode: + # restore default https://stackoverflow.com/a/51340381/13697228 + sys.stdout = sys.__stdout__ + return S diff --git a/tests/xtal2png_test.py b/tests/xtal2png_test.py index d6b6e36..7e3052e 100644 --- a/tests/xtal2png_test.py +++ b/tests/xtal2png_test.py @@ -5,11 +5,13 @@ from warnings import warn import plotly.express as px -from numpy.testing import assert_allclose, assert_equal +from numpy.testing import assert_allclose, assert_array_equal, assert_equal from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from xtal2png.core import XtalConverter from xtal2png.utils.data import ( + dummy_structures, element_wise_scaler, element_wise_unscaler, example_structures, @@ -22,16 +24,20 @@ rgb_loose_tol = 1.5 / 255 -def assert_structures_approximate_match(example_structures, structures): +def assert_structures_approximate_match( + example_structures, structures, tol_multiplier=1.0 +): for i, (s, structure) in enumerate(zip(example_structures, structures)): - # d = np.linalg.norm(s._lattice.abc) - # sm = StructureMatcher( - # ltol=rgb_loose_tol * d, - # stol=rgb_loose_tol * d, - # angle_tol=rgb_loose_tol * 180, - # comparator=ElementComparator(), - # ) - sm = StructureMatcher(comparator=ElementComparator()) + dummy_matcher = StructureMatcher() + ltol = dummy_matcher.ltol * tol_multiplier + stol = dummy_matcher.stol * tol_multiplier + angle_tol = dummy_matcher.angle_tol * tol_multiplier + sm = StructureMatcher( + ltol=ltol, + stol=stol, + angle_tol=angle_tol, + comparator=ElementComparator(), + ) is_match = sm.fit(s, structure) if not is_match: warn( @@ -60,35 +66,35 @@ def assert_structures_approximate_match(example_structures, structures): assert_allclose( a_check, latt_a, - rtol=rgb_loose_tol, + rtol=rgb_loose_tol * tol_multiplier, err_msg="lattice parameter length `a` not all close", ) assert_allclose( b_check, latt_b, - rtol=rgb_loose_tol, + rtol=rgb_loose_tol * tol_multiplier, err_msg="lattice parameter length `b` not all close", ) assert_allclose( c_check, latt_c, - rtol=rgb_loose_tol * 2, + rtol=rgb_loose_tol * 2 * tol_multiplier, err_msg="lattice parameter length `c` not all close", ) assert_allclose( angles_check, angles, - rtol=rgb_loose_tol, + rtol=rgb_loose_tol * tol_multiplier, err_msg="lattice parameter angles not all close", ) assert_allclose( atomic_numbers_check, atomic_numbers, - rtol=rgb_loose_tol, + rtol=rgb_loose_tol * tol_multiplier, err_msg="atomic numbers not all close", ) @@ -96,7 +102,7 @@ def assert_structures_approximate_match(example_structures, structures): assert_allclose( frac_coords_check, frac_coords, - atol=rgb_tol, + atol=rgb_tol * tol_multiplier, err_msg="atomic numbers not all close", ) @@ -108,54 +114,54 @@ def assert_structures_approximate_match(example_structures, structures): def test_structures_to_arrays(): - xc = XtalConverter() - data = xc.structures_to_arrays(example_structures) + xc = XtalConverter(relax_on_decode=False) + data, _, _ = xc.structures_to_arrays(example_structures) return data def test_structures_to_arrays_single(): - xc = XtalConverter() + xc = XtalConverter(relax_on_decode=False) data, _, _ = xc.structures_to_arrays([example_structures[0]]) return data def test_arrays_to_structures(): - xc = XtalConverter() - data, _, _ = xc.structures_to_arrays(example_structures) - structures = xc.arrays_to_structures(data) + xc = XtalConverter(relax_on_decode=False) + data, id_data, id_mapper = xc.structures_to_arrays(example_structures) + structures = xc.arrays_to_structures(data, id_data, id_mapper) assert_structures_approximate_match(example_structures, structures) return structures def test_arrays_to_structures_single(): - xc = XtalConverter() - data, _, _ = xc.structures_to_arrays([example_structures[0]]) - structures = xc.arrays_to_structures(data) + xc = XtalConverter(relax_on_decode=False) + data, id_data, id_mapper = xc.structures_to_arrays([example_structures[0]]) + structures = xc.arrays_to_structures(data, id_data, id_mapper) assert_structures_approximate_match([example_structures[0]], structures) return structures def test_xtal2png(): - xc = XtalConverter() + xc = XtalConverter(relax_on_decode=False) imgs = xc.xtal2png(example_structures, show=False, save=True) return imgs def test_xtal2png_single(): - xc = XtalConverter() + xc = XtalConverter(relax_on_decode=False) imgs = xc.xtal2png([example_structures[0]], show=False, save=True) return imgs def test_png2xtal(): - xc = XtalConverter() + xc = XtalConverter(relax_on_decode=False) imgs = xc.xtal2png(example_structures, show=True, save=True) decoded_structures = xc.png2xtal(imgs) assert_structures_approximate_match(example_structures, decoded_structures) def test_png2xtal_single(): - xc = XtalConverter() + xc = XtalConverter(relax_on_decode=False) imgs = xc.xtal2png([example_structures[0]], show=True, save=True) decoded_structures = xc.png2xtal(imgs, save=False) assert_structures_approximate_match([example_structures[0]], decoded_structures) @@ -163,7 +169,7 @@ def test_png2xtal_single(): def test_png2xtal_rgb_image(): - xc = XtalConverter() + xc = XtalConverter(relax_on_decode=False) imgs = xc.xtal2png(example_structures, show=False, save=False) imgs = [img.convert("RGB") for img in imgs] decoded_structures = xc.png2xtal(imgs) @@ -171,6 +177,71 @@ def test_png2xtal_rgb_image(): return decoded_structures +def test_primitive_encoding(): + xc = XtalConverter( + symprec=0.1, + angle_tolerance=5.0, + encode_as_primitive=True, + decode_as_primitive=False, + relax_on_decode=False, + ) + input_structures = [ + SpacegroupAnalyzer( + s, symprec=0.1, angle_tolerance=5.0 + ).get_conventional_standard_structure() + for s in example_structures + ] + data, id_data, id_mapper = xc.structures_to_arrays(input_structures) + decoded_structures = xc.arrays_to_structures(data, id_data, id_mapper) + assert_structures_approximate_match( + example_structures, decoded_structures, tol_multiplier=2.0 + ) + return decoded_structures + + +def test_primitive_decoding(): + xc = XtalConverter( + symprec=0.1, + angle_tolerance=5.0, + encode_as_primitive=False, + decode_as_primitive=True, + relax_on_decode=False, + ) + input_structures = [ + SpacegroupAnalyzer( + s, symprec=0.1, angle_tolerance=5.0 + ).get_conventional_standard_structure() + for s in example_structures + ] + data, id_data, id_mapper = xc.structures_to_arrays(input_structures) + decoded_structures = xc.arrays_to_structures(data, id_data, id_mapper) + # decoded has to be conventional too for compatibility with `get_s1_like_s2` + decoded_structures = [ + SpacegroupAnalyzer( + s, symprec=0.1, angle_tolerance=5.0 + ).get_conventional_standard_structure() + for s in decoded_structures + ] + assert_structures_approximate_match( + example_structures, decoded_structures, tol_multiplier=2.0 + ) + return decoded_structures + + +def test_fit(): + xc = XtalConverter(relax_on_decode=False) + xc.fit(example_structures + dummy_structures) + assert_array_equal((14, 82), xc.atom_range) + assert_allclose((3.84, 12.718448099999998), xc.a_range) + assert_allclose((3.395504, 11.292530369999998), xc.b_range) + assert_allclose((3.84, 10.6047314973), xc.c_range) + assert_array_equal((0.0, 180.0), xc.angles_range) + assert_allclose((12, 227), xc.space_group_range) + assert_allclose((40.03858081023111, 611.6423774462978), xc.volume_range) + assert_allclose((1.383037596160554, 7.8291318247510695), xc.distance_range) + assert_equal(44, xc.num_sites) + + def test_element_wise_scaler_unscaler(): check_input = [[1, 2], [3, 4]] feature_range = [1, 4] @@ -204,6 +275,16 @@ def test_rgb_scaler_unscaler(): assert_allclose(check_output, scaled) +def test_relax_on_decode(): + xc = XtalConverter(relax_on_decode=True) + imgs = xc.xtal2png(example_structures, show=False, save=False) + decoded_structures = xc.png2xtal(imgs) + assert_structures_approximate_match( + example_structures, decoded_structures, tol_multiplier=4.0 + ) + return decoded_structures + + def test_plot_and_save(): df = px.data.tips() fig = px.histogram(df, x="day") @@ -214,6 +295,10 @@ def test_plot_and_save(): if __name__ == "__main__": + test_relax_on_decode() + test_primitive_decoding() + test_primitive_encoding() + test_fit() test_png2xtal_rgb_image() test_element_wise_scaler_unscaler() test_rgb_scaler_unscaler() @@ -229,3 +314,10 @@ def test_plot_and_save(): test_png2xtal_single() 1 + 1 + +# %% Code Graveyard +# from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +# spa = SpacegroupAnalyzer(s, symprec=0.1, angle_tolerance=5.0) +# s = spa.get_refined_structure() +# spa = SpacegroupAnalyzer(structure, symprec=0.1, angle_tolerance=5.0) +# structure = spa.get_refined_structure()