diff --git a/structuretoolkit/analyse/spatial.py b/structuretoolkit/analyse/spatial.py index 5b4fc988f..252819a1f 100644 --- a/structuretoolkit/analyse/spatial.py +++ b/structuretoolkit/analyse/spatial.py @@ -88,9 +88,9 @@ def set_to_high_symmetry_points(positions, structure, neigh, decimals=4): def cluster_by_steinhardt(positions, neigh, l_values, q_eps, var_ratio, min_samples): """ Clusters candidate positions via Steinhardt parameters and the variance in distances to host atoms. - + The cluster that has the lowest variance is returned, i.e. those positions that have the most "regular" coordination polyhedra. - + Args: positions (array): candidate positions neigh (Neighbor): neighborhood information of the candidate positions @@ -98,7 +98,7 @@ def cluster_by_steinhardt(positions, neigh, l_values, q_eps, var_ratio, min_samp q_eps (float): maximum intercluster distance in steinhardt parameters for DBSCAN clustering var_ratio (float): multiplier to make steinhardt's and distance variance numerically comparable min_samples (int): minimum size of clusters - + Returns: array: Positions of the most likely interstitial sites """ diff --git a/structuretoolkit/analyse/strain.py b/structuretoolkit/analyse/strain.py index 3632c9dda..f9ebb6aaa 100644 --- a/structuretoolkit/analyse/strain.py +++ b/structuretoolkit/analyse/strain.py @@ -6,7 +6,6 @@ class Strain: - """ Calculate local strain of each atom following the Lagrangian strain tensor: diff --git a/structuretoolkit/analyse/symmetry.py b/structuretoolkit/analyse/symmetry.py index 9fd72b1d2..4bb4e6f4d 100644 --- a/structuretoolkit/analyse/symmetry.py +++ b/structuretoolkit/analyse/symmetry.py @@ -24,7 +24,6 @@ class Symmetry(dict): - """ Return a class for operations related to box symmetries. Main attributes: diff --git a/structuretoolkit/visualize.py b/structuretoolkit/visualize.py index ae688c441..dfb59c6f1 100644 --- a/structuretoolkit/visualize.py +++ b/structuretoolkit/visualize.py @@ -120,6 +120,7 @@ def plot3d( elif mode == "plotly": return _plot3d_plotly( structure=structure, + show_cell=show_cell, camera=camera, particle_size=particle_size, select_atoms=select_atoms, @@ -143,8 +144,20 @@ def plot3d( raise ValueError("plot method not recognized") +def _get_box_skeleton(cell): + lines_dz = np.stack(np.meshgrid(*3 * [[0, 1]], indexing="ij"), axis=-1) + # eight corners of a unit cube, paired as four z-axis lines + + all_lines = np.reshape( + [np.roll(lines_dz, i, axis=-1) for i in range(3)], (-1, 2, 3) + ) + # All 12 two-point lines on the unit square + return all_lines @ cell + + def _plot3d_plotly( structure, + show_cell=True, scalar_field=None, select_atoms=None, particle_size=1.0, @@ -177,6 +190,7 @@ def _plot3d_plotly( """ try: import plotly.express as px + import plotly.graph_objects as go except ModuleNotFoundError: raise ModuleNotFoundError("plotly not installed - use plot3d instead") if select_atoms is None: @@ -196,6 +210,13 @@ def _plot3d_plotly( scale=particle_size / (0.1 * structure.get_volume() ** (1 / 3)), ), ) + if show_cell: + data = fig.data + for lines in _get_box_skeleton(structure.cell): + fig = px.line_3d(**{xx: vv for xx, vv in zip(["x", "y", "z"], lines.T)}) + fig.update_traces(line_color="#000000") + data = fig.data + data + fig = go.Figure(data=data) fig.layout.scene.camera.projection.type = camera rot = _get_orientation(view_plane).T rot[0, :] *= distance_from_camera * 1.25 @@ -206,6 +227,7 @@ def _plot3d_plotly( fig.update_layout(scene_camera=angle) fig.update_traces(marker=dict(line=dict(width=0.1, color="DarkSlateGrey"))) fig.update_scenes(aspectmode="data") + fig.update_layout(legend={"itemsizing": "constant"}) return fig diff --git a/tests/test_visualize.py b/tests/test_visualize.py index 5abc2ae58..c649865a8 100644 --- a/tests/test_visualize.py +++ b/tests/test_visualize.py @@ -4,7 +4,7 @@ import unittest import numpy as np -from structuretoolkit.visualize import _get_flattened_orientation +from structuretoolkit.visualize import _get_flattened_orientation, _get_box_skeleton class TestAtoms(unittest.TestCase): @@ -21,6 +21,25 @@ def test_get_flattened_orientation(self): R = np.array(_get_flattened_orientation(R, 1)).reshape(4, 4) self.assertAlmostEqual(np.linalg.det(R), 1) + def test_get_frame(self): + frame = _get_box_skeleton(np.eye(3)) + self.assertLessEqual( + np.unique(frame.reshape(-1, 6), axis=0, return_counts=True)[1].max(), + 1 + ) + dx, counts = np.unique( + np.diff(frame, axis=-2).squeeze().astype(int), axis=0, return_counts=True + ) + self.assertEqual( + dx.ptp(), 1, msg="Frames not drawn along the nearest edges" + ) + msg = ( + "There must be four lines along each direction" + + " (4 x [1, 0, 0], 4 x [0, 1, 0] and 4 x [0, 0, 1])" + ) + self.assertEqual(counts.min(), 4, msg=msg) + self.assertEqual(counts.max(), 4, msg=msg) + if __name__ == "__main__": unittest.main()