Skip to content
Merged
6 changes: 3 additions & 3 deletions structuretoolkit/analyse/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,17 @@ def set_to_high_symmetry_points(positions, structure, neigh, decimals=4):
def cluster_by_steinhardt(positions, neigh, l_values, q_eps, var_ratio, min_samples):
"""
Clusters candidate positions via Steinhardt parameters and the variance in distances to host atoms.

The cluster that has the lowest variance is returned, i.e. those positions that have the most "regular" coordination polyhedra.

Args:
positions (array): candidate positions
neigh (Neighbor): neighborhood information of the candidate positions
l_values (list of int): which steinhardt parameters to use for clustering
q_eps (float): maximum intercluster distance in steinhardt parameters for DBSCAN clustering
var_ratio (float): multiplier to make steinhardt's and distance variance numerically comparable
min_samples (int): minimum size of clusters

Returns:
array: Positions of the most likely interstitial sites
"""
Expand Down
1 change: 0 additions & 1 deletion structuretoolkit/analyse/strain.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class Strain:

"""
Calculate local strain of each atom following the Lagrangian strain tensor:

Expand Down
1 change: 0 additions & 1 deletion structuretoolkit/analyse/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


class Symmetry(dict):

"""

Return a class for operations related to box symmetries. Main attributes:
Expand Down
22 changes: 22 additions & 0 deletions structuretoolkit/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def plot3d(
elif mode == "plotly":
return _plot3d_plotly(
structure=structure,
show_cell=show_cell,
camera=camera,
particle_size=particle_size,
select_atoms=select_atoms,
Expand All @@ -143,8 +144,20 @@ def plot3d(
raise ValueError("plot method not recognized")


def _get_box_skeleton(cell):
lines_dz = np.stack(np.meshgrid(*3 * [[0, 1]], indexing="ij"), axis=-1)
# eight corners of a unit cube, paired as four z-axis lines

all_lines = np.reshape(
[np.roll(lines_dz, i, axis=-1) for i in range(3)], (-1, 2, 3)
)
# All 12 two-point lines on the unit square
return all_lines @ cell


def _plot3d_plotly(
structure,
show_cell=True,
scalar_field=None,
select_atoms=None,
particle_size=1.0,
Expand Down Expand Up @@ -177,6 +190,7 @@ def _plot3d_plotly(
"""
try:
import plotly.express as px
import plotly.graph_objects as go
except ModuleNotFoundError:
raise ModuleNotFoundError("plotly not installed - use plot3d instead")
if select_atoms is None:
Expand All @@ -196,6 +210,13 @@ def _plot3d_plotly(
scale=particle_size / (0.1 * structure.get_volume() ** (1 / 3)),
),
)
if show_cell:
data = fig.data
for lines in _get_box_skeleton(structure.cell):
fig = px.line_3d(**{xx: vv for xx, vv in zip(["x", "y", "z"], lines.T)})
fig.update_traces(line_color="#000000")
data = fig.data + data
fig = go.Figure(data=data)
fig.layout.scene.camera.projection.type = camera
rot = _get_orientation(view_plane).T
rot[0, :] *= distance_from_camera * 1.25
Expand All @@ -206,6 +227,7 @@ def _plot3d_plotly(
fig.update_layout(scene_camera=angle)
fig.update_traces(marker=dict(line=dict(width=0.1, color="DarkSlateGrey")))
fig.update_scenes(aspectmode="data")
fig.update_layout(legend={"itemsizing": "constant"})
return fig


Expand Down
21 changes: 20 additions & 1 deletion tests/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import unittest
import numpy as np
from structuretoolkit.visualize import _get_flattened_orientation
from structuretoolkit.visualize import _get_flattened_orientation, _get_box_skeleton


class TestAtoms(unittest.TestCase):
Expand All @@ -21,6 +21,25 @@ def test_get_flattened_orientation(self):
R = np.array(_get_flattened_orientation(R, 1)).reshape(4, 4)
self.assertAlmostEqual(np.linalg.det(R), 1)

def test_get_frame(self):
frame = _get_box_skeleton(np.eye(3))
self.assertLessEqual(
np.unique(frame.reshape(-1, 6), axis=0, return_counts=True)[1].max(),
1
)
dx, counts = np.unique(
np.diff(frame, axis=-2).squeeze().astype(int), axis=0, return_counts=True
)
self.assertEqual(
dx.ptp(), 1, msg="Frames not drawn along the nearest edges"
)
msg = (
"There must be four lines along each direction"
+ " (4 x [1, 0, 0], 4 x [0, 1, 0] and 4 x [0, 0, 1])"
)
self.assertEqual(counts.min(), 4, msg=msg)
self.assertEqual(counts.max(), 4, msg=msg)


if __name__ == "__main__":
unittest.main()