In [None]:
import os
from itertools import product
import math

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import rotate, affine_transform
from scipy import special as sp
from sympy.physics.quantum.spin import Rotation

from src.models.layers import SHConv3DRadial, BSHConv3D, compute_clebschgordan_matrix

In [None]:
def wigner_matrix(j, alpha, beta, gamma):
    """Tabulate the Wigner D-matrix of degree ``j`` as a complex NumPy array.

    Each entry is obtained by evaluating sympy's
    ``Rotation.D(j, m1, m2, alpha, beta, gamma)`` with ``.doit()``.

    Args:
        j: Non-negative integer degree; the result has shape
            ``(2 * j + 1, 2 * j + 1)``.
        alpha, beta, gamma: Euler angles forwarded unchanged to
            ``Rotation.D`` (the callers in this notebook pass radians).

    Returns:
        ``np.ndarray`` of dtype ``complex128`` where row ``m1 + j`` and
        column ``m2 + j`` hold the element for orders ``(m1, m2)``.
    """
    orders = range(-j, j + 1)
    size = 2 * j + 1
    output = np.zeros((size, size), dtype=np.complex128)
    # enumerate() yields exactly the shifted index m + j for every order m.
    for row, m1 in enumerate(orders):
        for col, m2 in enumerate(orders):
            element = Rotation.D(j, m1, m2, alpha, beta, gamma).doit()
            output[row, col] = element
    return output


In [None]:
def compute_rotation_matrix(alpha, beta, gamma, origin=(0, 0, 0)):
    """Build a 4x4 homogeneous matrix that rotates about ``origin``.

    The rotation part is ``R_z(gamma) @ R_y(beta) @ R_x(alpha)``, i.e. the
    x-rotation is applied first. It is wrapped between a translation that
    moves ``origin`` to zero and the translation that moves it back.

    Args:
        alpha: Rotation about the x axis, in degrees.
        beta: Rotation about the y axis, in degrees.
        gamma: Rotation about the z axis, in degrees.
        origin: 3-element point that the rotation leaves fixed.

    Returns:
        ``np.ndarray`` of shape ``(4, 4)``.
    """
    rad_x, rad_y, rad_z = (np.deg2rad(deg) for deg in (alpha, beta, gamma))
    cos_x, sin_x = math.cos(rad_x), math.sin(rad_x)
    cos_y, sin_y = math.cos(rad_y), math.sin(rad_y)
    cos_z, sin_z = math.cos(rad_z), math.sin(rad_z)

    rot_x = np.array([
        [1, 0, 0],
        [0, cos_x, -sin_x],
        [0, sin_x, cos_x],
    ])
    rot_y = np.array([
        [cos_y, 0, sin_y],
        [0, 1, 0],
        [-sin_y, 0, cos_y],
    ])
    rot_z = np.array([
        [cos_z, -sin_z, 0],
        [sin_z, cos_z, 0],
        [0, 0, 1],
    ])

    pivot = np.array(origin)

    # Homogeneous translation taking the pivot to the coordinate origin ...
    to_zero = np.eye(4)
    to_zero[:3, 3] = -pivot

    # ... and the one that undoes it after rotating.
    from_zero = np.eye(4)
    from_zero[:3, 3] = pivot

    rotation = np.eye(4)
    rotation[:3, :3] = rot_z @ rot_y @ rot_x
    return from_zero @ rotation @ to_zero


In [None]:
# Wigner matrices of degree 1 and degree 4 for the same Euler angles
# (pi/2 for each), plus the matrix returned by
# compute_clebschgordan_matrix for the (1, 4) pair.
quarter_turn = np.pi / 2
euler_angles = (quarter_turn, quarter_turn, quarter_turn)
D1 = wigner_matrix(1, *euler_angles)
D2 = wigner_matrix(4, *euler_angles)
cg_mat = compute_clebschgordan_matrix(1, 4)


In [None]:
# NOTE(review): F is not referenced again in the visible cells (F_1 further
# down is a different variable) — confirm whether it is still needed.
F = np.array([1, 0, 0])
# Display the element-wise magnitudes of the degree-1 Wigner matrix D1.
np.abs(D1)

In [None]:
# Express the Kronecker product of the two Wigner matrices in the basis
# given by cg_mat. The variable name suggests the result is expected to be
# block diagonal; it is inspected visually in the next cell.
kron_d1_d2 = np.kron(D1, D2)
block_diag_matrix = cg_mat.T @ kron_d1_d2 @ cg_mat


In [None]:
# Show the magnitude of every entry of block_diag_matrix. An explicit
# Figure/Axes pair is used instead of pyplot's implicit "current axes", and
# the panel is titled so the figure can be read on its own.
fig, ax = plt.subplots()
im = ax.imshow(np.abs(block_diag_matrix))
ax.set_title("|cg_mat.T @ kron(D1, D2) @ cg_mat|")
fig.colorbar(im, ax=ax);

In [None]:
# Display the degree-2 Wigner matrix for Euler angles (pi, 0, 0) as a quick
# visual check of wigner_matrix.
wigner_matrix(2, np.pi, 0, 0)

In [None]:
# Expose only the CUDA device with index 2 to this process.
# NOTE(review): this runs after `import tensorflow`; presumably it still
# takes effect because no GPU has been initialised yet — confirm.
GPU_ID = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID

In [None]:
# Layer under test, with every weight initialised to the constant 1.0 so
# that its response is deterministic.
# NOTE(review): the first positional argument (1) is presumably the number
# of output filters — confirm against SHConv3DRadial's signature.
kernel_size = 17
constant_one = tf.keras.initializers.Constant(value=1.0)
layer = SHConv3DRadial(
    1,
    kernel_size,
    max_degree=3,
    padding="same",
    initializer=constant_one,
)


In [None]:
# Discrete unit impulse: a 5-D array of zeros with a single 1 at the middle
# index of the three axes of length image_size.
image_size = 31
center = image_size // 2
dirac = np.zeros((1,) + (image_size,) * 3 + (1,))
dirac[0, center, center, center, 0] = 1
# Middle slice along the fourth array axis; it contains the impulse.
plt.imshow(dirac[0, :, :, center, 0])


In [None]:
# Feed the impulse through the layer and drop all singleton axes from the
# result (np.squeeze).
layer_output = layer(dirac)
impulse_response = np.squeeze(layer_output)

In [None]:
# Shape of the squeezed layer output (singleton axes already removed).
impulse_response.shape

In [None]:
# Real and imaginary parts of one channel of the impulse response, taken
# at the middle index of the third axis. Explicit Axes handles replace the
# pyplot state machine and each panel is titled.
f = 0  # index along the last axis of impulse_response to display
central_slice = impulse_response[:, :, image_size // 2, f]
fig, axes = plt.subplots(1, 2)
for ax, (label, part) in zip(axes, (("real", np.real), ("imag", np.imag))):
    im = ax.imshow(part(central_slice))
    ax.set_title(f"impulse_response, {label} part (f={f})")
    fig.colorbar(im, ax=ax)

In [None]:
# Same shape inspection as in an earlier cell; impulse_response has not been
# reassigned in between.
impulse_response.shape

In [None]:
# Shape of the layer's `filters` attribute, for comparison with the shape of
# the impulse response.
layer.filters.shape

In [None]:
# Keep entries 1, 2 and 3 of the last axis from both the layer's filters and
# its impulse response.
# NOTE(review): presumably these three channels are the degree-1 components
# (they are later multiplied by a degree-1 Wigner matrix) — confirm.
channel_ids = [1, 2, 3]
squeezed_filters = np.squeeze(layer.filters.numpy())
F_1 = squeezed_filters[..., channel_ids]
H_1 = impulse_response[..., channel_ids]


In [None]:
# Shape of the selected filter channels; the last axis has length 3 because
# three indices were selected.
F_1.shape

In [None]:
def rotate_3d(image, angle1, angle2, angle3):
    """Rotate an array with three successive in-plane rotations.

    Singleton axes are removed first. The array is then rotated by
    ``-angle1`` in the plane of axes (0, 1), by ``angle2`` in the plane of
    axes (1, 2) and by ``-angle3`` in the plane of axes (0, 1), always with
    ``reshape=False`` so the output keeps the squeezed input's shape.

    Args:
        image: Array to rotate; needs at least three axes after squeezing.
        angle1, angle2, angle3: Angles in degrees, as expected by
            ``scipy.ndimage.rotate``.

    Returns:
        The rotated array.
    """
    steps = (
        (-angle1, (0, 1)),
        (angle2, (1, 2)),
        (-angle3, (0, 1)),
    )
    rotated = np.squeeze(image)
    for degrees, plane in steps:
        rotated = rotate(rotated, degrees, axes=plane, reshape=False)
    return rotated


def inv_rotate_3d(image, angle1, angle2, angle3):
    """Apply the three rotations of ``rotate_3d`` in reverse, with opposite signs.

    Singleton axes are removed first. The array is then rotated by
    ``angle3`` in the plane of axes (0, 1), by ``-angle2`` in the plane of
    axes (1, 2) and by ``angle1`` in the plane of axes (0, 1), always with
    ``reshape=False`` so the output keeps the squeezed input's shape.

    Args:
        image: Array to rotate; needs at least three axes after squeezing.
        angle1, angle2, angle3: The same angles (in degrees) that were
            passed to ``rotate_3d``.

    Returns:
        The rotated array.
    """
    steps = (
        (angle3, (0, 1)),
        (-angle2, (1, 2)),
        (angle1, (0, 1)),
    )
    restored = np.squeeze(image)
    for degrees, plane in steps:
        restored = rotate(restored, degrees, axes=plane, reshape=False)
    return restored

In [None]:
# Mix the three selected channels with the degree-1 Wigner matrix for the
# chosen Euler angles (given in degrees, converted to radians here).
# NOTE(review): this rebinds D1, which an earlier cell also defines and uses.
angle1, angle2, angle3 = 0, 90, 0
euler_rad = [deg * np.pi / 180 for deg in (angle1, angle2, angle3)]
D1 = wigner_matrix(1, *euler_rad)
# Alternative kept from the original cell: rotate the filters spatially with
# rotate_3d(F_1, angle1, angle2, angle3) instead of mixing channels.
F_1_rotated = F_1 @ D1
H_1_rotated = H_1 @ D1
f = 0

In [None]:
# Real and imaginary parts of channel f of the original filters F_1, at
# the middle index of the third axis. Explicit Axes handles replace the
# pyplot state machine and each panel is titled.
filter_slice = F_1[:, :, kernel_size // 2, f]
fig, axes = plt.subplots(1, 2)
for ax, (label, part) in zip(axes, (("real", np.real), ("imag", np.imag))):
    im = ax.imshow(part(filter_slice))
    ax.set_title(f"F_1, {label} part (f={f})")
    fig.colorbar(im, ax=ax)

In [None]:
# Real and imaginary parts of channel f of F_1_rotated (F_1 multiplied by
# D1), at the middle index of the third axis. Explicit Axes handles
# replace the pyplot state machine and each panel is titled.
rotated_slice = F_1_rotated[:, :, kernel_size // 2, f]
fig, axes = plt.subplots(1, 2)
for ax, (label, part) in zip(axes, (("real", np.real), ("imag", np.imag))):
    im = ax.imshow(part(rotated_slice))
    ax.set_title(f"F_1_rotated, {label} part (f={f})")
    fig.colorbar(im, ax=ax)

In [None]:
# Apply the inverse spatial rotation to the channel-mixed arrays.
# NOTE(review): presumably an equivariance check — if mixing channels with
# D1 matches a spatial rotation, these should reproduce F_1 and H_1.
euler_deg = (angle1, angle2, angle3)
F1_unrotated = inv_rotate_3d(F_1_rotated, *euler_deg)
H1_unrotated = inv_rotate_3d(H_1_rotated, *euler_deg)

In [None]:
# Real and imaginary parts of channel f of F1_unrotated, at the middle
# index of the third axis. Explicit Axes handles replace the pyplot state
# machine and each panel is titled.
unrotated_slice = F1_unrotated[:, :, kernel_size // 2, f]
fig, axes = plt.subplots(1, 2)
for ax, (label, part) in zip(axes, (("real", np.real), ("imag", np.imag))):
    im = ax.imshow(part(unrotated_slice))
    ax.set_title(f"F1_unrotated, {label} part (f={f})")
    fig.colorbar(im, ax=ax)

In [None]:
# Magnitude of the residual between the impulse response channels and their
# rotated-then-unrotated counterpart, shown on the three central planes.
# Alternative kept from the original cell: compare the filters instead with
# difference = F_1 - F1_unrotated (then slice at kernel_size // 2).
difference = H_1 - H1_unrotated
mid = image_size // 2
central_planes = (
    ("[:, :, mid]", difference[:, :, mid, f]),
    ("[:, mid, :]", difference[:, mid, :, f]),
    ("[mid, :, :]", difference[mid, :, :, f]),
)
# One titled panel per plane, drawn on explicit Axes rather than through
# pyplot's implicit current-axes state.
fig, axes = plt.subplots(1, 3)
for ax, (label, plane) in zip(axes, central_planes):
    im = ax.imshow(np.abs(plane))
    ax.set_title(f"|difference{label}| (f={f})")
    fig.colorbar(im, ax=ax)