-
Notifications
You must be signed in to change notification settings - Fork 3
/
core.py
1030 lines (885 loc) · 37.1 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Crystal to PNG conversion core functions and scripts."""
import argparse
import logging
import sys
from glob import glob
# from itertools import zip_longest
from os import PathLike, path
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
from uuid import uuid4
from warnings import warn
import numpy as np
import pandas as pd
from m3gnet.models import Relaxer
from numpy.typing import NDArray
from PIL import Image
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from tqdm import tqdm
from xtal2png import __version__
from xtal2png.utils.data import dummy_structures, rgb_scaler, rgb_unscaler
# from sklearn.preprocessing import MinMaxScaler
__author__ = "sgbaird"
__copyright__ = "sgbaird"
__license__ = "MIT"
_logger = logging.getLogger(__name__)
# ---- Python API ----
# The functions defined in this section can be imported by users in their
# Python scripts/interactive interpreter, e.g. via `from xtal2png.core import
# XtalConverter`, when using this Python module as a library.
# Integer labels for the blocks of an encoded array. ``structures_to_arrays``
# fills ``id_data`` with these values so that every entry can be traced back
# to the kind of crystallographic information stored at that position.
ATOM_ID = 1
FRAC_ID = 2
A_ID = 3
B_ID = 4
C_ID = 5
ANGLES_ID = 6
VOLUME_ID = 7
SPACE_GROUP_ID = 8
DISTANCE_ID = 9
# Human-readable names of the same blocks. ``structures_to_arrays`` returns an
# ``id_mapper`` dictionary that maps each ``*_KEY`` to the matching ``*_ID``.
ATOM_KEY = "atom"
FRAC_KEY = "frac"
A_KEY = "latt_a"
B_KEY = "latt_b"
C_KEY = "latt_c"
ANGLES_KEY = "angles"
VOLUME_KEY = "volume"
SPACE_GROUP_KEY = "space_group"
DISTANCE_KEY = "distance"
def construct_save_name(s: Structure):
    """Build a short, mostly-unique file stem for a structure.

    The name has the form ``<formula>,volume=<volume>,uid=<uid>`` where the
    formula has its spaces removed, the volume is rounded to an integer and the
    uid is the first four characters of a freshly generated UUID4.

    Parameters
    ----------
    s : Structure
        Object exposing ``formula`` and ``volume`` attributes.

    Returns
    -------
    str
        The save name (without file extension).
    """
    compact_formula = s.formula.replace(" ", "")
    rounded_volume = int(np.round(s.volume))
    short_uid = str(uuid4())[:4]
    return f"{compact_formula},volume={rounded_volume},uid={short_uid}"
class XtalConverter:
    """Convert between pymatgen Structure object and PNG-encoded representation.

    Note that if you modify the ranges to be different than their defaults, you have
    effectively created a new representation. In the future, anytime you use
    :class:`XtalConverter` with a dataset that used modified range(s), you will need
    to specify the same ranges; otherwise, your data will be decoded (unscaled)
    incorrectly. In other words, make sure you're using the same
    :class:`XtalConverter` object for both encoding and decoding.

    We encourage you to use the default ranges, which were carefully selected based on a
    trade-off between keeping the range as low as possible and trying to incorporate as
    much of what's been observed on Materials Project with no more than 52 sites. For
    more details, see the corresponding notebook in the ``notebooks`` directory:
    https://github.com/sparks-baird/xtal2png/tree/main/notebooks

    Parameters
    ----------
    atom_range : Tuple[int, int], optional
        Expected range for atomic number, by default (0, 117)
    frac_range : Tuple[float, float], optional
        Expected range for fractional coordinates, by default (0.0, 1.0)
    a_range : Tuple[float, float], optional
        Expected range for lattice parameter length a, by default (2.0, 15.3)
    b_range : Tuple[float, float], optional
        Expected range for lattice parameter length b, by default (2.0, 15.0)
    c_range : Tuple[float, float], optional
        Expected range for lattice parameter length c, by default (2.0, 36.0)
    angles_range : Tuple[float, float], optional
        Expected range for lattice parameter angles, by default (0.0, 180.0)
    volume_range : Tuple[float, float], optional
        Expected range for unit cell volumes, by default (0.0, 1500.0)
    space_group_range : Tuple[int, int], optional
        Expected range for space group numbers, by default (1, 230)
    distance_range : Tuple[float, float], optional
        Expected range for pairwise distances between sites, by default (0.0, 18.0)
    max_sites : int, optional
        Maximum number of sites to accommodate in encoding, by default 52
    save_dir : Union[str, 'PathLike[str]']
        Directory to save PNG files via :func:`xtal2png` and CIF files via
        :func:`png2xtal`. The directory is created on instantiation if it does not
        exist. By default path.join("data", "preprocessed")
    symprec : Union[float, Tuple[float, float]], optional
        The symmetry precision passed to
        :class:`pymatgen.symmetry.analyzer.SpacegroupAnalyzer` when symmetrizing
        structures. If specified as a tuple, then ``symprec[0]`` applies to encoding
        and ``symprec[1]`` applies to decoding. By default 0.1.
    angle_tolerance : Union[float, int, Tuple[float, float], Tuple[int, int]], optional
        The angle tolerance (degrees) passed to
        :class:`pymatgen.symmetry.analyzer.SpacegroupAnalyzer` when symmetrizing
        structures. If specified as a tuple, then ``angle_tolerance[0]`` applies to
        encoding and ``angle_tolerance[1]`` applies to decoding. By default 5.0.
    encode_as_primitive : bool, optional
        Encode structures as symmetrized, primitive structures (otherwise the refined
        structure is encoded). Uses ``symprec`` if ``symprec`` is of type float, else
        uses ``symprec[0]`` if ``symprec`` is of type tuple. Same applies for
        ``angle_tolerance``. By default False
    decode_as_primitive : bool, optional
        Decode structures as symmetrized, primitive structures (otherwise the refined
        structure is returned). Uses ``symprec`` if ``symprec`` is of type float, else
        uses ``symprec[1]`` if ``symprec`` is of type tuple. Same applies for
        ``angle_tolerance``. By default False
    relax_on_decode : bool, optional
        Use m3gnet to relax the decoded crystal structures. By default True

    Examples
    --------
    >>> xc = XtalConverter()
    >>> xc = XtalConverter(atom_range=(0, 83)) # assumes no radioactive elements in data
    """
def __init__(
self,
atom_range: Tuple[int, int] = (0, 117),
frac_range: Tuple[float, float] = (0.0, 1.0),
a_range: Tuple[float, float] = (2.0, 15.3),
b_range: Tuple[float, float] = (2.0, 15.0),
c_range: Tuple[float, float] = (2.0, 36.0),
angles_range: Tuple[float, float] = (0.0, 180.0),
volume_range: Tuple[float, float] = (0.0, 1500.0),
space_group_range: Tuple[int, int] = (1, 230),
distance_range: Tuple[float, float] = (0.0, 18.0),
max_sites: int = 52,
save_dir: Union[str, "PathLike[str]"] = path.join("data", "preprocessed"),
symprec: Union[float, Tuple[float, float]] = 0.1,
angle_tolerance: Union[float, int, Tuple[float, float], Tuple[int, int]] = 5.0,
encode_as_primitive: bool = False,
decode_as_primitive: bool = False,
relax_on_decode: bool = True,
):
"""Instantiate an XtalConverter object with desired ranges and ``max_sites``."""
self.atom_range = atom_range
self.frac_range = frac_range
self.a_range = a_range
self.b_range = b_range
self.c_range = c_range
self.angles_range = angles_range
self.volume_range = volume_range
self.space_group_range = space_group_range
self.distance_range = distance_range
self.max_sites = max_sites
self.save_dir = save_dir
if isinstance(symprec, (float, int)):
self.encode_symprec = symprec
self.decode_symprec = symprec
elif isinstance(symprec, tuple):
self.encode_symprec = symprec[0]
self.decode_symprec = symprec[1]
if isinstance(angle_tolerance, (float, int)):
self.encode_angle_tolerance = angle_tolerance
self.decode_angle_tolerance = angle_tolerance
elif isinstance(angle_tolerance, tuple):
self.encode_angle_tolerance = angle_tolerance[0]
self.decode_angle_tolerance = angle_tolerance[1]
self.encode_as_primitive = encode_as_primitive
self.decode_as_primitive = decode_as_primitive
self.relax_on_decode = relax_on_decode
Path(save_dir).mkdir(exist_ok=True, parents=True)
def xtal2png(
    self,
    structures: Union[
        List[Union[Structure, str, "PathLike[str]"]], str, "PathLike[str]"
    ],
    show: bool = False,
    save: bool = True,
):
    """Encode crystal (via CIF filepath or Structure object) as PNG file.

    The structures are converted with ``structures_to_arrays``; the results are
    kept on the instance as ``self.data``, ``self.id_data`` and
    ``self.id_mapper``. Values outside of 0-255 are thresholded (with a
    warning) before ``self.data`` is cast to ``uint8`` and turned into
    grayscale images.

    Parameters
    ----------
    structures : List[Union[Structure, str, PathLike[str]]]
        pymatgen Structure objects or paths to CIF files, as accepted by
        ``process_filepaths_or_structures``.
    show : bool, optional
        Whether to display the PNG-encoded file, by default False
    save : bool, optional
        Whether to save the PNG-encoded file to ``self.save_dir``, by default True

    Returns
    -------
    imgs : List[Image.Image]
        PIL images that (approximately) encode the supplied crystal structures.

    Raises
    ------
    ValueError
        Propagated from ``process_filepaths_or_structures`` when the entries
        are of mixed or unsupported types, and from ``structures_to_arrays``
        when a structure cannot be encoded.

    Examples
    --------
    >>> coords = [[0, 0, 0], [0.75,0.5,0.75]]
    >>> lattice = Lattice.from_parameters(
    ...     a=3.84, b=3.84, c=3.84, alpha=120, beta=90, gamma=60
    ... )
    >>> structures = [Structure(lattice, ["Si", "Si"], coords),
    ... Structure(lattice, ["Ni", "Ni"], coords)]
    >>> xc = XtalConverter()
    >>> xc.xtal2png(structures, show=False, save=True)
    """
    save_names, structures = self.process_filepaths_or_structures(structures)  # type: ignore # noqa: E501

    # convert structures to 3D NumPy Matrices
    encoded = self.structures_to_arrays(structures)
    self.data, self.id_data, self.id_mapper = encoded

    # both extrema are measured before any thresholding takes place
    mn = self.data.min()
    mx = self.data.max()

    if mn < 0:
        warn(
            f"lower RGB value(s) OOB ({mn} less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)"  # noqa: E501
        )  # noqa
        below = self.data < 0
        self.data[below] = 0

    if mx > 255:
        warn(
            f"upper RGB value(s) OOB ({mx} greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)"  # noqa: E501
        )  # noqa
        above = self.data > 255
        self.data[above] = 255

    self.data = self.data.astype(np.uint8)

    # convert to PNG images. Save and/or show, if applicable
    imgs: List[Image.Image] = []
    for pixels, stem in zip(self.data, save_names):
        picture = Image.fromarray(pixels, mode="L")
        imgs.append(picture)
        if save:
            picture.save(path.join(self.save_dir, stem + ".png"))
        if show:
            picture.show()

    return imgs
def fit(
    self,
    structures: Union[
        List[Union[Structure, str, "PathLike[str]"]], str, "PathLike[str]"
    ],
    y=None,
    fit_quantiles=(0.00, 0.99),
    verbose=True,
):
    """Fit the scaling ranges of this converter to a set of structures.

    ``atom_range`` and ``space_group_range`` are set to the exact minimum and
    maximum found in the data. ``a_range``, ``b_range``, ``c_range``,
    ``volume_range`` and ``distance_range`` are set to the requested lower and
    upper quantiles. For distances, the lower bound is the quantile of the
    per-structure smallest nonzero distance and the upper bound is the quantile
    of the per-structure largest nonzero distance. The largest number of sites
    is stored as ``self.num_sites``.
    NOTE(review): ``max_sites`` is left untouched - confirm that is intended.

    Parameters
    ----------
    structures : List[Union[Structure, str, PathLike[str]]]
        pymatgen Structure objects or paths to CIF files, as accepted by
        ``process_filepaths_or_structures``.
    y : Any, optional
        Unused by this method.
    fit_quantiles : Tuple[float, float], optional
        Lower and upper quantiles (between 0 and 1) used for the continuous
        ranges, by default (0.00, 0.99)
    verbose : bool, optional
        Whether to print the observed (unclipped) ranges, by default True

    Raises
    ------
    ValueError
        If no structures are supplied.
    """
    _, structures = self.process_filepaths_or_structures(structures)  # type: ignore
    if len(structures) == 0:
        raise ValueError("`structures` should contain at least one structure")

    # TODO: deal with arbitrary site_properties
    # atomic numbers are pooled into one flat list: structures may have
    # different numbers of sites, so a per-structure array would be ragged
    atomic_numbers = []
    a = []
    b = []
    c = []
    space_group = []
    volume = []
    num_sites = []
    min_distance = []
    max_distance = []

    for s in tqdm(structures):
        atomic_numbers.extend(s.atomic_numbers)
        lattice = s.lattice
        a.append(lattice.a)
        b.append(lattice.b)
        c.append(lattice.c)
        space_group.append(s.get_space_group_info()[1])
        volume.append(lattice.volume)
        num_sites.append(len(list(s.sites)))

        # zeros (e.g. the diagonal) carry no information about the range
        distances = s.distance_matrix
        nonzero = distances[np.nonzero(distances)]
        if nonzero.size > 0:
            min_distance.append(nonzero.min())
            max_distance.append(nonzero.max())
        else:
            # e.g. a single-site structure; ignored by the NaN-aware quantiles
            min_distance.append(np.nan)
            max_distance.append(np.nan)

    if verbose:
        print(
            "range of atomic_numbers is: ",
            min(atomic_numbers),
            "-",
            max(atomic_numbers),
        )
        print("range of a is: ", min(a), "-", max(a))
        print("range of b is: ", min(b), "-", max(b))
        print("range of c is: ", min(c), "-", max(c))
        print("range of space_group is: ", min(space_group), "-", max(space_group))
        print("range of volume is: ", min(volume), "-", max(volume))
        print("range of num_sites is: ", min(num_sites), "-", max(num_sites))

    self.atom_range = (min(atomic_numbers), max(atomic_numbers))
    self.space_group_range = (np.min(space_group), np.max(space_group))
    self.num_sites = np.max(num_sites)

    low_quantile, upp_quantile = fit_quantiles
    # name -> (values for the lower bound, values for the upper bound)
    sources = {
        "a": (a, a),
        "b": (b, b),
        "c": (c, c),
        "volume": (volume, volume),
        "distance": (min_distance, max_distance),
    }
    for name, (low_values, upp_values) in sources.items():
        bounds = (
            np.nanquantile(low_values, low_quantile),
            np.nanquantile(upp_values, upp_quantile),
        )
        setattr(self, name + "_range", bounds)
def process_filepaths_or_structures(
    self, structures: Union[List[str], List[PathLike], List[Structure]]
) -> Tuple[List[str], List[Structure]]:
    """Extract (or create) save names and convert/passthrough the structures.

    The input is never modified; a new list of structures is returned. A single
    filepath (``str`` or ``os.PathLike``) is treated as a one-element list, and
    an empty input yields two empty lists.

    Parameters
    ----------
    structures : Union[List[str], List[PathLike], List[Structure]]
        List of filepaths or list of structures to be processed.

    Returns
    -------
    save_names : List[str]
        Stems of the files if filepaths are passed, otherwise some relatively
        unique names (due to 4 random characters being appended at the end) for each
        structure. See ``construct_save_name``.
    S : List[Structure]
        Processed structures (loaded via ``Structure.from_file`` for filepaths).

    Raises
    ------
    ValueError
        If the entries mix filepaths and Structures, i.e. an entry's kind
        differs from that of the first entry.
    ValueError
        If an entry is not of type `str`, `os.PathLike` or
        `pymatgen.core.structure.Structure`.

    Examples
    --------
    >>> save_names, structures = xc.process_filepaths_or_structures(structures)
    """
    if isinstance(structures, (str, PathLike)):
        # a lone path would otherwise be iterated character by character
        structures = [structures]
    else:
        # work on a copy so that the caller's list keeps its original entries
        structures = list(structures)

    if not structures:
        return [], []

    save_names: List[str] = []
    processed: List[Structure] = []
    first_is_structure = isinstance(structures[0], Structure)

    for i, s in enumerate(structures):
        if isinstance(s, (str, PathLike)):
            if first_is_structure:
                raise ValueError(
                    f"structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}"  # noqa: E501
                )
            processed.append(Structure.from_file(s))
            save_names.append(Path(str(s)).stem)
        elif isinstance(s, Structure):
            if not first_is_structure:
                raise ValueError(
                    f"structures should be of same datatype, either strs or pymatgen Structures. structures[0] is {type(structures[0])}, but got type {type(s)} for entry {i}"  # noqa
                )
            processed.append(s)
            save_names.append(construct_save_name(s))
        else:
            raise ValueError(
                f"structures should be of type `str`, `os.PathLike` or `pymatgen.core.structure.Structure`, not {type(structures[i])} (entry {i})"  # noqa
            )

    return save_names, processed
def png2xtal(
    self, images: List[Union[Image.Image, "PathLike"]], save: bool = False
):
    """Decode PNG files as Structure objects.

    Parameters
    ----------
    images : List[Union[Image.Image, 'PathLike']]
        PIL images, or paths (``str`` or ``os.PathLike``) to image files, that
        (approximately) encode crystal structures. A single image or path is
        treated as a one-element list.
    save : bool, optional
        Whether to write each decoded structure as a CIF file into
        ``self.save_dir`` (named via ``construct_save_name``), by default False

    Returns
    -------
    S : List[Structure]
        The decoded structures, see ``arrays_to_structures``.

    Raises
    ------
    ValueError
        If an entry is neither a PIL image nor a path.

    Examples
    --------
    >>> from xtal2png.utils.data import example_structures
    >>> xc = XtalConverter()
    >>> imgs = xc.xtal2png(example_structures)
    >>> xc.png2xtal(imgs)
    OUTPUT
    """
    if isinstance(images, (str, PathLike, Image.Image)):
        images = [images]

    data_tmp = []
    for i, img in enumerate(images):
        if isinstance(img, (str, PathLike)):
            # load image from file; the context manager is on the opened file
            # (not on the converted copy) so that the handle gets closed
            with Image.open(img) as im:
                data_tmp.append(np.asarray(im.convert("L")))
        elif isinstance(img, Image.Image):
            data_tmp.append(np.asarray(img.convert("L")))
        else:
            # previously such entries were dropped without any notice
            raise ValueError(
                f"images should be of type `str`, `os.PathLike` or `PIL.Image.Image`, not {type(img)} (entry {i})"  # noqa: E501
            )

    data = np.stack(data_tmp, axis=0)
    S = self.arrays_to_structures(data)

    if save:
        for s in S:
            fpath = path.join(self.save_dir, construct_save_name(s) + ".cif")
            CifWriter(
                s,
                symprec=self.decode_symprec,
                angle_tolerance=self.decode_angle_tolerance,
            ).write_file(fpath)

    return S
# unscale values
def structures_to_arrays(
    self,
    structures: Sequence[Structure],
):
    """Convert pymatgen Structures to scaled 3D arrays of crystallographic info.

    Every structure is first symmetrized with ``SpacegroupAnalyzer`` (primitive
    standard structure if ``self.encode_as_primitive``, refined structure
    otherwise). The site-dependent quantities (atomic numbers, fractional
    coordinates and the distance matrix) are zero-padded up to
    ``self.max_sites``; a structure with more sites raises an error.

    Parameters
    ----------
    structures : Sequence[Structure]
        Sequence (e.g. list) of pymatgen Structure object(s)

    Returns
    -------
    data : np.ndarray
        Arrays scaled via ``rgb_scaler`` with first dimension corresponding to
        each crystal structure.
    id_data : np.ndarray
        Same layout as ``data``, except that each entry holds the integer ID of
        the block it belongs to. See ``id_mapper`` for the "legend" for this
        data.
    id_mapper : dict
        Dictionary containing the legend/key between the names of the blocks and the
        corresponding numbers in ``id_data``.

    Raises
    ------
    ValueError
        "`structures` should be a list of pymatgen Structure(s)"
    ValueError
        "crystal supplied with {n_sites} sites, which is more than {self.max_sites}
        sites. Remove crystal or increase `max_sites`."
    ValueError
        "len(atomic_numbers) {n_sites} and distance_matrix.shape[0]
        {s.distance_matrix.shape[0]} do not match"

    Examples
    --------
    >>> xc = XtalConverter()
    >>> data, id_data, id_mapper = xc.structures_to_arrays(structures)
    """
    if isinstance(structures, Structure):
        raise ValueError("`structures` should be a list of pymatgen Structure(s)")

    # symmetrize everything up front
    symmetrized = []
    for raw in structures:
        analyzer = SpacegroupAnalyzer(
            raw,
            symprec=self.encode_symprec,
            angle_tolerance=self.encode_angle_tolerance,
        )
        if self.encode_as_primitive:
            symmetrized.append(analyzer.get_primitive_standard_structure())
        else:
            symmetrized.append(analyzer.get_refined_structure())
    structures = symmetrized

    # extract crystallographic information
    atomic_numbers: List[List[int]] = []
    frac_coords_tmp: List[NDArray] = []
    latt_a: List[float] = []
    latt_b: List[float] = []
    latt_c: List[float] = []
    angles: List[List[float]] = []
    volume: List[float] = []
    space_group: List[int] = []
    distance_matrix_tmp: List[NDArray[np.float64]] = []

    for s in structures:
        n_sites = len(s.atomic_numbers)
        if n_sites > self.max_sites:
            raise ValueError(
                f"crystal supplied with {n_sites} sites, which is more than {self.max_sites} sites. Remove crystal or increase `max_sites`."  # noqa
            )
        n_missing = self.max_sites - n_sites

        padded_numbers = np.pad(list(s.atomic_numbers), (0, n_missing))
        atomic_numbers.append(padded_numbers.tolist())
        frac_coords_tmp.append(np.pad(s.frac_coords, ((0, n_missing), (0, 0))))

        latt_a.append(s._lattice.a)
        latt_b.append(s._lattice.b)
        latt_c.append(s._lattice.c)
        angles.append(list(s._lattice.angles))
        volume.append(s.volume)
        space_group.append(s.get_space_group_info()[1])

        if n_sites != s.distance_matrix.shape[0]:
            raise ValueError(
                f"len(atomic_numbers) {n_sites} and distance_matrix.shape[0] {s.distance_matrix.shape[0]} do not match"  # noqa
            )  # noqa
        # assume that distance matrix is square (same padding on both axes)
        distance_matrix_tmp.append(np.pad(s.distance_matrix, (0, n_missing)))

    frac_coords = np.stack(frac_coords_tmp)
    distance_matrix = np.stack(distance_matrix_tmp)

    # REVIEW: consider using modified pettifor scale instead of atomic numbers
    # REVIEW: consider using feature_range=atom_range or 2*atom_range
    # REVIEW: since it introduces a sort of non-linearity b.c. of rounding
    atom_scaled = rgb_scaler(atomic_numbers, data_range=self.atom_range)
    frac_scaled = rgb_scaler(frac_coords, data_range=self.frac_range)
    a_scaled = rgb_scaler(latt_a, data_range=self.a_range)
    b_scaled = rgb_scaler(latt_b, data_range=self.b_range)
    c_scaled = rgb_scaler(latt_c, data_range=self.c_range)
    angles_scaled = rgb_scaler(angles, data_range=self.angles_range)
    volume_scaled = rgb_scaler(volume, data_range=self.volume_range)
    space_group_scaled = rgb_scaler(space_group, data_range=self.space_group_range)
    # NOTE: extra (repeated) rows such as a max distance, or normalizing the
    # NOTE: distance matrix by cell volume, are possible future additions
    distance_scaled = rgb_scaler(distance_matrix, data_range=self.distance_range)

    def _column_per_site(per_structure):
        # one value per structure, repeated down a column of ``max_sites`` rows
        return np.repeat(
            np.expand_dims(per_structure, (1, 2)), self.max_sites, axis=1
        )

    blocks = (
        np.expand_dims(atom_scaled, 2),
        frac_scaled,
        _column_per_site(a_scaled),
        _column_per_site(b_scaled),
        _column_per_site(c_scaled),
        np.repeat(np.expand_dims(angles_scaled, 1), self.max_sites, axis=1),
        _column_per_site(volume_scaled),
        _column_per_site(space_group_scaled),
        distance_scaled,
    )
    data = self.assemble_blocks(*blocks)

    block_ids = (
        ATOM_ID,
        FRAC_ID,
        A_ID,
        B_ID,
        C_ID,
        ANGLES_ID,
        VOLUME_ID,
        SPACE_GROUP_ID,
        DISTANCE_ID,
    )
    id_mapper = dict(
        zip(
            (
                ATOM_KEY,
                FRAC_KEY,
                A_KEY,
                B_KEY,
                C_KEY,
                ANGLES_KEY,
                VOLUME_KEY,
                SPACE_GROUP_KEY,
                DISTANCE_KEY,
            ),
            block_ids,
        )
    )

    # same block shapes as ``data``, but filled with the ID of each block
    id_blocks = [
        np.ones_like(block) * block_id for block, block_id in zip(blocks, block_ids)
    ]
    id_data = self.assemble_blocks(*id_blocks)

    return data, id_data, id_mapper
def assemble_blocks(
    self,
    atom_arr,
    frac_arr,
    a_arr,
    b_arr,
    c_arr,
    angles_arr,
    volume_arr,
    space_group_arr,
    distance_arr,
):
    """Arrange the individual blocks into one square array per structure.

    The eight attribute blocks are joined along their last axis into a strip.
    The result holds a block of zeros in the top-left corner, the transposed
    strip to its right, the strip below the zeros and ``distance_arr`` in the
    bottom-right corner.

    Parameters
    ----------
    atom_arr, frac_arr, a_arr, b_arr, c_arr, angles_arr, volume_arr, space_group_arr
        3D arrays sharing their first two dimensions; they are concatenated
        along axis 2.
    distance_arr
        3D array placed below the transposed strip.

    Returns
    -------
    data : np.ndarray
        The assembled 3D array.
    """
    attribute_blocks = (
        atom_arr,
        frac_arr,
        a_arr,
        b_arr,
        c_arr,
        angles_arr,
        volume_arr,
        space_group_arr,
    )
    strip_width = sum(block.shape[2] for block in attribute_blocks)
    corner = np.zeros((atom_arr.shape[0], strip_width, strip_width), dtype=int)

    # the same strip is used twice: as-is (left) and transposed (top right)
    strip = np.concatenate(attribute_blocks, axis=2)
    left_half = np.concatenate((corner, strip), axis=1)
    right_half = np.concatenate((np.moveaxis(strip, 1, 2), distance_arr), axis=1)

    return np.concatenate((left_half, right_half), axis=2)
def disassemble_blocks(
    self, data, id_data: Optional[NDArray] = None, id_mapper: Optional[dict] = None
):
    """Split assembled arrays back into their individual blocks.

    This reverses ``assemble_blocks``: the attribute strip is stored twice
    (once below the zero corner and once, transposed, to the right of it), and
    the two copies are averaged as ``float64`` to obtain each attribute block.
    The bottom-right square is returned unchanged as the distance block.

    Parameters
    ----------
    data : np.ndarray
        3D array laid out as produced by ``assemble_blocks``.
    id_data : Optional[NDArray]
        Block IDs matching ``data``. Only checked for consistency with
        ``id_mapper``; the block layout itself is fixed.
    id_mapper : Optional[dict]
        Legend for ``id_data``. Must be given if and only if ``id_data`` is.

    Returns
    -------
    tuple
        ``(atom_arr, frac_arr, a_arr, b_arr, c_arr, angles_arr, volume_arr,
        space_group_arr, distance_arr)``

    Raises
    ------
    ValueError
        If only one of ``id_data`` and ``id_mapper`` is supplied.
    """
    if (id_data is None) is not (id_mapper is None):
        raise ValueError(
            f"id_data (type: {type(id_data)}) and id_mapper (type: {type(id_mapper)}) should either both be assigned or both be None."  # noqa
        )
    # NOTE: neither argument is needed below, so nothing is derived from dummy
    # structures when they are omitted (that encode was costly and unused).

    # widths of: atom, frac, a, b, c, angles, volume, space_group
    widths = [1, 3, 1, 1, 1, 3, 1, 1]
    zero_pad = sum(widths)
    split_points = np.cumsum(widths)[:-1]

    top_rows, bottom_rows = np.array_split(data, [zero_pad], axis=1)
    _, transposed_strip = np.array_split(top_rows, [zero_pad], axis=2)
    strip, distance_arr = np.array_split(bottom_rows, [zero_pad], axis=2)

    transposed_parts = np.array_split(transposed_strip, split_points, axis=1)
    parts = np.array_split(strip, split_points, axis=2)

    def average_copies(transposed_part, part):
        # cast first so that e.g. uint8 image data cannot overflow when summed
        flipped = transposed_part.astype(np.float64).swapaxes(1, 2)
        return (flipped + part.astype(np.float64)) / 2

    (
        atom_arr,
        frac_arr,
        a_arr,
        b_arr,
        c_arr,
        angles_arr,
        volume_arr,
        space_group_arr,
    ) = [average_copies(t, p) for t, p in zip(transposed_parts, parts)]

    return (
        atom_arr,
        frac_arr,
        a_arr,
        b_arr,
        c_arr,
        angles_arr,
        volume_arr,
        space_group_arr,
        distance_arr,
    )
def arrays_to_structures(
    self,
    data: np.ndarray,
    id_data: Optional[np.ndarray] = None,
    id_mapper: Optional[dict] = None,
):
    """Convert scaled crystal (xtal) arrays to pymatgen Structures.

    The blocks obtained from ``disassemble_blocks`` are reduced to one value
    per structure where applicable, unscaled with ``rgb_unscaler`` and used to
    build a ``Structure`` from the lattice parameters, the rounded atomic
    numbers and the fractional coordinates. Sites whose rounded atomic number
    is not positive are treated as padding and left out. Each structure is
    then symmetrized (primitive standard structure if
    ``self.decode_as_primitive``, refined structure otherwise) and, if
    ``self.relax_on_decode``, relaxed with an m3gnet ``Relaxer``.

    Parameters
    ----------
    data : np.ndarray
        3D array containing crystallographic information.
    id_data : Optional[np.ndarray]
        Block IDs with the same layout as ``data``; passed through to
        ``disassemble_blocks``.
    id_mapper : Optional[dict]
        Dictionary containing the legend/key between the names of the blocks and the
        corresponding numbers in ``id_data``; passed through to
        ``disassemble_blocks``.

    Returns
    -------
    S : List[Structure]
        One decoded structure per entry along the first axis of ``data``.

    Raises
    ------
    ValueError
        If ``data`` is not a ``np.ndarray``.
    """
    if not isinstance(data, np.ndarray):
        raise ValueError(
            f"`data` should be of type `np.ndarray`. Received type {type(data)}. Maybe you passed a tuple of (data, id_data, id_mapper) returned from `structures_to_arrays()` by accident?"  # noqa: E501
        )

    def _drop_singleton_axis(block):
        # blocks that are one entry wide lose their last axis
        if block.shape[2] == 1:
            return np.squeeze(block, axis=2)
        return block

    def _mean_of_nonzero(rows):
        # zero entries are left out of the mean
        return np.mean(rows, axis=1, where=rows != 0)

    blocks = self.disassemble_blocks(data, id_data=id_data, id_mapper=id_mapper)
    (
        atom_scaled,
        frac_scaled,
        a_rows,
        b_rows,
        c_rows,
        angle_rows,
        volume_rows,
        space_group_rows,
        distance_scaled,
    ) = map(_drop_singleton_axis, blocks)

    a_scaled = _mean_of_nonzero(a_rows)
    b_scaled = _mean_of_nonzero(b_rows)
    c_scaled = _mean_of_nonzero(c_rows)
    angles_scaled = _mean_of_nonzero(angle_rows)
    volume_scaled = np.mean(volume_rows, axis=1)
    space_group_scaled = np.round(np.mean(space_group_rows, axis=1)).astype(int)

    atomic_numbers = rgb_unscaler(atom_scaled, data_range=self.atom_range)
    frac_coords = rgb_unscaler(frac_scaled, data_range=self.frac_range)
    latt_a = rgb_unscaler(a_scaled, data_range=self.a_range)
    latt_b = rgb_unscaler(b_scaled, data_range=self.b_range)
    latt_c = rgb_unscaler(c_scaled, data_range=self.c_range)
    angles = rgb_unscaler(angles_scaled, data_range=self.angles_range)

    # volume, space_group, distance_matrix unnecessary for making Structure
    volume = rgb_unscaler(volume_scaled, data_range=self.volume_range)
    space_group = rgb_unscaler(
        space_group_scaled, data_range=self.space_group_range
    )
    distance_matrix = rgb_unscaler(distance_scaled, data_range=self.distance_range)
    for matrix in distance_matrix:
        np.fill_diagonal(matrix, 0.0)
    # technically unused, but referenced to keep linters quiet for now
    _ = (volume, space_group, distance_matrix)

    atomic_numbers = np.round(atomic_numbers).astype(int)

    # TODO: tweak lattice parameters to match predicted space group rules
    # the default pre-trained model is only loaded when relaxation is requested
    relaxer = Relaxer() if self.relax_on_decode else None

    # build Structure-s
    S: List[Structure] = []
    for i, numbers in enumerate(atomic_numbers):
        occupied = np.where(numbers > 0)
        species = numbers[occupied]
        coords = frac_coords[i][occupied]

        alpha, beta, gamma = angles[i]
        lattice = Lattice.from_parameters(
            a=latt_a[i],
            b=latt_b[i],
            c=latt_c[i],
            alpha=alpha,
            beta=beta,
            gamma=gamma,
        )
        analyzer = SpacegroupAnalyzer(
            Structure(lattice, species, coords),
            symprec=self.decode_symprec,
            angle_tolerance=self.decode_angle_tolerance,
        )
        if self.decode_as_primitive:
            structure = analyzer.get_primitive_standard_structure()
        else:
            structure = analyzer.get_refined_structure()
        # REVIEW: round fractional coordinates to nearest multiple?

        if self.relax_on_decode:
            # TODO: report the initial and final energies of the relaxation
            structure = relaxer.relax(structure)["final_structure"]

        S.append(structure)

    return S
# ---- CLI ----
# The functions defined in this section are wrappers around the main Python
# API allowing them to be called directly from the terminal as a CLI
# executable/script.
def parse_args(args):
    """Parse command line parameters.

    Args:
      args (List[str]): command line parameters as list of strings
          (for example  ``["--help"]``).

    Returns:
      :obj:`argparse.Namespace`: command line parameters namespace
    """
    # (flags, keyword options) for every argument, in the order they are registered
    argument_specs = [
        (
            ("--version",),
            dict(
                action="version",
                version="xtal2png {ver}".format(ver=__version__),
            ),
        ),
        (
            ("-p", "--path"),
            dict(
                dest="fpath",
                help="Crystallographic information file (CIF) filepath (extension must be .cif or .CIF) or path to directory containing .cif files or processed PNG filepath or path to directory containing processed .png files (extension must be .png or .PNG). Assumes CIFs if --encode flag is used. Assumes PNGs if --decode flag is used.",  # noqa: E501
                type=str,
                metavar="STRING",
            ),
        ),
        (
            ("-s", "--save-dir"),
            dict(
                dest="save_dir",
                default=".",
                help="Directory to save processed PNG files or decoded CIFs to.",
                type=str,
                metavar="STRING",
            ),
        ),
        (
            ("--encode",),
            dict(action="store_true", help="Encode CIF files as PNG images."),
        ),
        (
            ("--decode",),
            dict(action="store_true", help="Decode PNG images as CIF files."),
        ),
        (
            ("-v", "--verbose"),
            dict(
                dest="loglevel",
                help="set loglevel to INFO",
                action="store_const",
                const=logging.INFO,
            ),
        ),
        (
            ("-vv", "--very-verbose"),
            dict(
                dest="loglevel",
                help="set loglevel to DEBUG",
                action="store_const",
                const=logging.DEBUG,
            ),
        ),
    ]

    parser = argparse.ArgumentParser(description="Crystal to PNG encoder/decoder.")
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args(args)
def setup_logging(loglevel):
    """Configure the root logger to write formatted records to stdout.

    Args:
      loglevel (int): minimum loglevel for emitting messages
    """
    settings = {
        "level": loglevel,
        "stream": sys.stdout,
        "format": "[%(asctime)s] %(levelname)s:%(name)s:%(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**settings)
def main(args):
"""Wrapper allowing :func:`XtalConverter()` :func:`xtal2png()` and
:func:`png2xtal()` methods to be called with string arguments in a CLI fashion.
Args:
args (List[str]): command line parameters as list of strings
(for example ``["--verbose", "example.cif"]``).
"""
args = parse_args(args)
setup_logging(args.loglevel)
_logger.debug("Beginning conversion to PNG format")
if args.encode and args.decode:
raise ValueError("Specify --encode or --decode, not both.")
if args.encode:
ext = ".cif"
elif args.decode:
ext = ".png"
else:
raise ValueError("Specify at least one of --encode or --decode")
if Path(args.fpath).suffix in [ext, ext.upper()]:
fpaths = [args.fpath]
elif path.isdir(args.fpath):
fpaths = glob(path.join(args.fpath, f"*{ext}"))
if fpaths == []:
raise ValueError(
f"Assuming --path input is directory to files. No files of type {ext} present in {args.fpath}" # noqa: E501
)
else:
raise ValueError(
f"Input should be a path to a single {ext} file or a path to a directory containing {ext} file(s). Received: {args.fpath}" # noqa: E501
)