/
visualization.py
178 lines (154 loc) · 5.13 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""
pyscal visualization module
---------------------------
Used for visualization of system objects in jupyter notebooks and lab. Uses plotly.
"""
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
import itertools
def create_box_plot(box, origin=[0,0,0]):
    """
    Create traces which correspond to the simulation cell

    Parameters
    ----------
    box : list
        the three cell vectors of the simulation box; ``box[i]`` has to be a
        3-component vector
    origin : list, optional
        Origin of the simulation box. Default [0, 0, 0]. Every corner of the
        cell is offset by this vector.

    Returns
    -------
    traces : list of Scatter3d objects
        one closed line trace per face outline of the cell (six in total)
    """
    box = np.array(box)
    # np.array copies its input, so the shared list default is never mutated
    origin = np.array(origin)

    faces = []
    for i, j in itertools.combinations(range(3), 2):
        # index of the remaining cell vector, since {i, j, k} == {0, 1, 2}
        k = 3 - (i + j)
        # Outline of the face spanned by box[i] and box[j]; the first corner
        # is repeated at the end so the line closes. All corners carry the
        # origin offset (previously only the first/last one did, which
        # distorted the cell for any non-zero origin).
        outline = np.array([
            origin,
            origin + box[i],
            origin + box[i] + box[j],
            origin + box[j],
            origin,
        ])
        faces.append(outline)
        # opposite face: the same outline translated along the third vector
        faces.append(outline + box[k])

    traces = [
        go.Scatter3d(
            x=face[:, 0],
            y=face[:, 1],
            z=face[:, 2],
            mode='lines',
            name='lines',
            line=dict(width=2.0, color='#263238'),
            showlegend=False,
        )
        for face in faces
    ]
    return traces
def plot_3d(pos, color=None, radius=17,
            colorscale='Spectral', opacity=1.0,
            traces=None, cmap_title=None):
    """
    Plot the atoms along with the simulation box

    Parameters
    ----------
    pos : list of positions
        list of atomic positions, one 3-component position per atom
    color : list
        list of colors to use for plotting
    radius : int, optional
        radius of plotted atom objects (passed as the marker size)
    colorscale : string, optional
        color map for coloring atoms
    opacity : float, optional
        opacity of atoms
    traces : box plot objects
        traces drawn together with the atoms, for example the output of
        `create_box_plot`. The list that is passed in is not modified.
        Default None, in which case only the atoms are drawn.
    cmap_title : string
        title of cmap

    Returns
    -------
    None
        the figure is displayed through ``fig.show()``
    """
    # accept plain nested lists as well as arrays; an empty selection is
    # given an explicit (0, 3) shape so the column slicing below still works
    pos = np.asarray(pos)
    if pos.size == 0:
        pos = pos.reshape(0, 3)

    atom_trace = go.Scatter3d(
        x=pos[:, 0],
        y=pos[:, 1],
        z=pos[:, 2],
        mode='markers',
        opacity=1.0,
        marker=dict(
            sizemode='diameter',
            sizeref=750,
            size=radius,
            color=color,
            opacity=opacity,
            colorscale=colorscale,
            colorbar=dict(thickness=20, title=cmap_title),
            line=dict(width=0.5, color='#455A64'),
        ),
    )

    # Work on a copy: appending to the caller's list would stack one more
    # atom trace onto it every time the plot is rendered again with the same
    # list, and the None default has no append at all.
    all_traces = [] if traces is None else list(traces)
    all_traces.append(atom_trace)

    fig = go.Figure(data=all_traces)

    # the three axes share the same minimal styling
    axis_style = dict(
        showticklabels=False,
        showbackground=False,
        zerolinecolor="#455A64",
    )
    fig.update_layout(
        scene=dict(
            xaxis_title="",
            yaxis_title="",
            zaxis_title="",
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
        ),
        width=700,
        margin=dict(r=10, l=10, b=10, t=10),
    )
    fig.update_layout(showlegend=False)
    fig.show()
def plot_system(sys, colorby=None, filterby=None):
    """
    Plot the system

    Parameters
    ----------
    sys : System object
    colorby : string, optional
        property over which the atoms are to be colored. It can be any
        attribute of Atom, a custom attribute, or calculated q values which can be accessed
        as `qx` or `aqx` where x stands for the q number.
    filterby : string, optional
        property over which the atoms are to be filtered before plotting.
        It can be any attribute of atom, or a custom value of atom. It should provide
        a True or False value. Atoms for which the value is False are left
        out of the plot.

    Returns
    -------
    None

    Notes
    -----
    Both `colorby` and `filterby` are looked up per atom through
    ``sys.get_custom(atom, [name])[0]``. The plot itself is rendered by
    `plot_3d` from an ipywidgets ``interact_manual`` control with a radius
    slider.
    """
    positions = []
    colors = []
    for atom in sys.atoms:
        # Skip a filtered-out atom completely, so that positions and colors
        # always have one entry per plotted atom (colors used to be recorded
        # for every atom, which misaligned them whenever a filter was set).
        if filterby is not None and not sys.get_custom(atom, [filterby])[0]:
            continue
        positions.append(np.array(sys.remap_atom(atom.pos)))
        if colorby is not None:
            colors.append(sys.get_custom(atom, [colorby])[0])
        else:
            # uniform colour value when no property is selected
            colors.append(1)

    colors = np.array(colors).astype(float)
    positions = np.array(positions)
    if positions.size == 0:
        # keep a 2D shape so the column slicing in plot_3d works when no
        # atom passes the filter
        positions = positions.reshape(0, 3)

    boxtraces = create_box_plot(sys.box)
    ctitle = "" if colorby is None else colorby

    radius = widgets.FloatSlider(min=1, max=30, step=1)
    # label of the button which triggers rendering; this changes the shared
    # interact_manual options
    widgets.interact_manual.opts['manual_name'] = 'Render plot'
    # NOTE(review): plot_3d has no `description` parameter - confirm whether
    # this keyword has any effect before removing it.
    widgets.interact_manual(plot_3d, radius=radius,
        pos=widgets.fixed(positions), cmap_title=widgets.fixed(ctitle),
        color=widgets.fixed(colors), opacity=widgets.fixed(1.0),
        traces=widgets.fixed(boxtraces), description="test")