In [None]:
import numpy as np
from scipy import special
from e3nn import o3
import torch

In [None]:
def euler_rotation_matrix(alpha, beta, gamma):
    """Return the 3x3 rotation matrix R = Rz(gamma) @ Ry(beta) @ Rx(alpha).

    Equivalently: rotate about the fixed x axis by ``alpha``, then about the
    fixed y axis by ``beta``, then about the fixed z axis by ``gamma``
    (x-y-z extrinsic Tait-Bryan angles, not the z-y-z convention).

    Parameters
    ----------
    alpha, beta, gamma : float
        Rotation angles in radians.

    Returns
    -------
    np.ndarray
        A (3, 3) float64 rotation matrix.
    """
    # Evaluate each trig function once; products keep the original operand order.
    ca, sa = np.cos(alpha), np.sin(alpha)
    cb, sb = np.cos(beta), np.sin(beta)
    cg, sg = np.cos(gamma), np.sin(gamma)

    return np.array(
        [
            [cb * cg, sa * sb * cg - ca * sg, ca * sb * cg + sa * sg],
            [cb * sg, sa * sb * sg + ca * cg, ca * sb * sg - sa * cg],
            [-sb, sa * cb, ca * cb],
        ],
        dtype=float,
    )

def cartesian_to_spherical(cartesian):
    """Convert a cartesian vector (x, y, z) to spherical coordinates.

    Parameters
    ----------
    cartesian : array_like of length 3
        The vector (x, y, z).

    Returns
    -------
    np.ndarray
        ``[r, theta, phi]``, where ``r`` is the Euclidean norm, ``theta`` is
        the azimuthal angle measured from +x in the xy-plane (via ``arctan2``,
        range (-pi, pi]), and ``phi`` is the polar angle measured from +z
        (range [0, pi]). This (azimuth, polar) naming/order matches the
        ``theta, phi`` arguments of ``scipy.special.sph_harm``.

    Notes
    -----
    For the zero vector the angles are undefined; both are returned as 0
    instead of producing nan (and a divide-by-zero warning).
    """
    cartesian = np.asarray(cartesian, dtype=float)
    r = np.linalg.norm(cartesian)
    if r == 0.0:
        return np.array([0.0, 0.0, 0.0])
    # Clip guards against |z / r| drifting past 1 by rounding, which would make arccos return nan.
    phi = np.arccos(np.clip(cartesian[2] / r, -1.0, 1.0))
    theta = np.arctan2(cartesian[1], cartesian[0])
    return np.array([r, theta, phi])


def convert_to_real_spherical_harmonic(complex_spherical_harmonic):
    """Convert complex spherical harmonics Y_l^m to real spherical harmonics Y_{lm}.

    Parameters
    ----------
    complex_spherical_harmonic : array_like of complex, leading length 2l+1
        Complex harmonics of a single degree l, ordered m = -l, ..., l along the
        leading axis, carrying the Condon-Shortley phase (the convention of
        ``scipy.special.sph_harm``). Extra trailing axes are allowed.

    Returns
    -------
    np.ndarray of float
        Real harmonics in the same m = -l..l order and shape, using

        * m < 0: (i / sqrt2) (Y_l^m - conj(Y_l^m))        = -sqrt2 Im Y_l^m
        * m = 0: Y_l^0
        * m > 0: ((-1)^m / sqrt2) (Y_l^m + conj(Y_l^m))   = (-1)^m sqrt2 Re Y_l^m

    Raises
    ------
    ValueError
        If the leading axis does not have odd length 2l+1.
    """
    values = np.asarray(complex_spherical_harmonic, dtype=complex)
    if values.ndim == 0 or values.shape[0] % 2 == 0:
        raise ValueError("expected 2l+1 complex values ordered m = -l..l along the leading axis")
    degree = (values.shape[0] - 1) // 2
    # m laid out along the leading axis so it broadcasts over any trailing axes.
    m = np.arange(-degree, degree + 1).reshape((-1,) + (1,) * (values.ndim - 1))
    # (-1)**m computed explicitly: the literal `-1**m` parses as -(1**m) == -1.
    parity = np.where(m % 2 == 0, 1.0, -1.0)
    sqrt2 = np.sqrt(2.0)
    return np.where(
        m < 0,
        -sqrt2 * values.imag,
        np.where(m > 0, parity * sqrt2 * values.real, values.real),
    )

In [None]:
# Seeded generator so that "Restart & Run All" reproduces the same rotation.
SEED = 0
rng = np.random.default_rng(SEED)

cartesian = np.array([0.2, 0.6, 0.3])
print("Original cartesian coordinates:", cartesian)
angles = 2 * np.pi * rng.random(3)
print("Rotated by angles:", angles)
rotation_matrix = euler_rotation_matrix(*angles)
rotated_cartesian = rotation_matrix @ cartesian
print("Rotated cartesian coordinates:", rotated_cartesian)

# A proper rotation must preserve the vector's length.
assert np.isclose(np.linalg.norm(rotated_cartesian), np.linalg.norm(cartesian))

In [None]:
def scipy_complex_spherical_harmonics(spherical_coords, degree=2):
    """Return scipy's complex Y_degree^m at one point, for m = -degree..degree.

    ``spherical_coords`` is ``[r, theta, phi]`` as produced by
    ``cartesian_to_spherical`` (theta azimuthal, phi polar); ``r`` is unused.
    """
    _, azimuth, polar = spherical_coords
    orders = range(-degree, degree + 1)
    if hasattr(special, "sph_harm_y"):
        # SciPy >= 1.15 deprecates sph_harm; sph_harm_y takes (n, m, polar, azimuth).
        return [special.sph_harm_y(degree, m, polar, azimuth) for m in orders]
    # Older SciPy: sph_harm(m, n, azimuth, polar).
    return [special.sph_harm(m, degree, azimuth, polar) for m in orders]


spherical = cartesian_to_spherical(cartesian)
scipy_sph = scipy_complex_spherical_harmonics(spherical)
scipy_spherical_harmonic = convert_to_real_spherical_harmonic(scipy_sph)
print("Scipy spherical harmonic:", scipy_spherical_harmonic)

spherical_rotated = cartesian_to_spherical(rotated_cartesian)
scipy_sph_rotated = scipy_complex_spherical_harmonics(spherical_rotated)
scipy_spherical_harmonic_rotated = convert_to_real_spherical_harmonic(scipy_sph_rotated)
print("Scipy spherical harmonic rotated:", scipy_spherical_harmonic_rotated)

In [None]:
def rotation_matrix_from_vectors(vec1, vec2):
    """ Find the rotation matrix that aligns vec1 to vec2
    :param vec1: A 3d "source" vector
    :param vec2: A 3d "destination" vector
    :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
    :raises ValueError: if either vector has zero length (no direction to align).

    Uses Rodrigues' formula R = I + K + K^2 (1 - c) / s^2, where K is the
    cross-product matrix of v = a x b, c = a . b and s = |v| (a, b unit vectors).
    The formula is 0/0 when a and b are parallel or antiparallel, so those
    cases are handled separately.
    """
    a = np.asarray(vec1, dtype=float).reshape(3)
    b = np.asarray(vec2, dtype=float).reshape(3)
    norm_a, norm_b = np.linalg.norm(a), np.linalg.norm(b)
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cannot align a zero-length vector")
    a, b = a / norm_a, b / norm_b

    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)

    if s < 1e-12:
        if c > 0:
            # Already aligned.
            return np.eye(3)
        # Antiparallel: rotate by pi about any unit axis n perpendicular to a,
        # i.e. R = 2 n n^T - I. Cross with the basis vector least aligned with a.
        helper = np.eye(3)[np.argmin(np.abs(a))]
        axis = np.cross(a, helper)
        axis /= np.linalg.norm(axis)
        return 2.0 * np.outer(axis, axis) - np.eye(3)

    kmat = np.array([[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]])
    return np.eye(3) + kmat + kmat @ kmat * ((1.0 - c) / (s ** 2))

In [None]:
# Reorder components (x, y, z) -> (y, z, x) before handing the vectors to e3nn.
# NOTE(review): presumably this lines e3nn's output component order up with the
# m = -2..2 order of the scipy cell above — confirm against the e3nn docs.
tensor_cartesian = torch.tensor(cartesian)[[1, 2, 0]]
tensor_rotated_cartesian = torch.tensor(rotated_cartesian)[[1, 2, 0]]

spherical_harmonic = o3.spherical_harmonics(2, tensor_cartesian, normalize=True)
rotated_spherical_harmonic = o3.spherical_harmonics(2, tensor_rotated_cartesian, normalize=True)

print("e3nn spherical harmonic:", spherical_harmonic)
print("Rotated e3nn spherical harmonic:", rotated_spherical_harmonic)

In [None]:
# Determine a rotation matrix taking the direction of tensor_cartesian onto that of
# tensor_rotated_cartesian. The harmonics (normalize=True) depend only on direction,
# so any such R should satisfy the equivariance Y(R x) = D(R) Y(x), with D(R) the
# l=2 Wigner-D matrix — i.e. the two printed vectors should agree.
rotation_matrix = rotation_matrix_from_vectors(tensor_cartesian.numpy(), tensor_rotated_cartesian.numpy())
rotation_matrix = torch.tensor(rotation_matrix)
wigner_d_matrix = o3.Irreps('1x2e').D_from_matrix(rotation_matrix)
DY0 = wigner_d_matrix @ spherical_harmonic
print(DY0)
print(rotated_spherical_harmonic)
assert torch.allclose(DY0, rotated_spherical_harmonic, atol=1e-6), "D(R) Y(x) does not match Y(R x)"
