In [None]:
import numpy as np
import torch
from e3nn.o3 import spherical_harmonics
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Build a regular 3-D lattice of (x, y, z) sample points centred on the origin.
# Each axis spans [-GRID_HALF_WIDTH, GRID_HALF_WIDTH] with GRID_POINTS samples,
# so with the values below the spacing is 200 / 20 = 10 (not 0.1). An odd
# GRID_POINTS places a sample exactly at the origin.
GRID_HALF_WIDTH = 100.0
GRID_POINTS = 21

x = torch.linspace(-GRID_HALF_WIDTH, GRID_HALF_WIDTH, GRID_POINTS)
y = torch.linspace(-GRID_HALF_WIDTH, GRID_HALF_WIDTH, GRID_POINTS)
z = torch.linspace(-GRID_HALF_WIDTH, GRID_HALF_WIDTH, GRID_POINTS)

# indexing="ij" is torch's default, so the result is unchanged; passing it
# explicitly silences the UserWarning torch emits when it is omitted. With "ij",
# rows of the flattened grid vary fastest in z and slowest in x, so `grid`
# reshapes back to (len(x), len(y), len(z), 3).
grid = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1).reshape(-1, 3)
print(grid.shape)

In [None]:
# Evaluate the degree l=1 spherical harmonics of e3nn at every grid point.
# With normalize=True, e3nn scales each input vector to unit length first, so
# only the direction of each point matters, not its distance from the origin.
# NOTE(review): the grid contains the origin, whose direction is undefined —
# confirm what e3nn returns for the zero vector before reading values there.
s_sh = spherical_harmonics(1, grid, normalize=True)
print(s_sh.shape)

# Hand plotly plain NumPy arrays instead of torch tensors. This avoids relying
# on plotly express interpreting a tensor as a data frame with integer columns.
grid_np = grid.detach().cpu().numpy()
s_sh_np = s_sh.detach().cpu().numpy()

# 3-D scatter of the grid, coloured by component 2 of the l=1 output.
fig_3d = px.scatter_3d(
    x=grid_np[:, 0],
    y=grid_np[:, 1],
    z=grid_np[:, 2],
    color=s_sh_np[:, 2],
    labels={"x": "x", "y": "y", "z": "z", "color": "Y_1 component 2"},
    title="l=1 spherical harmonic, component 2",
)
fig_3d.show()

# Average component 2 over the x and y axes to get a 1-D profile along z.
# The reshape assumes `grid` came from meshgrid with indexing="ij" (torch's
# default): axes are then (x, y, z, harmonic component). Sizes are derived from
# the axis tensors instead of a hard-coded 21.
reshaped_s = s_sh.reshape(x.numel(), y.numel(), z.numel(), s_sh.shape[-1])
xy_averaged_s = reshaped_s[..., 2].mean(dim=(0, 1))
fig_profile = px.scatter(
    x=z.detach().cpu().numpy(),
    y=xy_averaged_s.detach().cpu().numpy(),
    labels={"x": "z", "y": "mean over x, y of component 2"},
    title="l=1 spherical harmonic, component 2, averaged over x and y",
)
fig_profile.show()

In [None]:
# Iso-surface of component 0 of the l=1 spherical harmonic values in `s_sh`,
# sampled on `grid`. Inputs are converted to NumPy before being handed to plotly.
# NOTE(review): with surface_count=1, confirm which level in [isomin, isomax]
# plotly actually draws.
# To hide the end caps where the surface meets the grid boundary, add
# caps=dict(x_show=False, y_show=False, z_show=False) to go.Isosurface.
grid_np = grid.detach().cpu().numpy()
fig = go.Figure(
    data=go.Isosurface(
        x=grid_np[:, 0],
        y=grid_np[:, 1],
        z=grid_np[:, 2],
        value=s_sh.detach().cpu().numpy()[:, 0],
        isomin=-0.1,
        isomax=0.1,
        surface_count=1,
    )
)
fig.update_layout(title="Iso-surface of l=1 spherical harmonic, component 0")
fig.show()