In [None]:
import os
from itertools import product

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import rotate
from scipy import special as sp
import SimpleITK as sitk

from src.models.layers_faster import SHConv3DRadial, BSHConv3D, SSHConv3D
from src.models.utils import config_gpu
# from src.models.models import ResidualSLRILayer3D, ResidualBLRILayer3D

%matplotlib inline

In [None]:
# image_sitk = sitk.ReadImage("/home/vscode/python_wkspce/petct-seg/data/processed/CHGJ074_ct.nii.gz")

In [None]:
# Configure the GPU through the project helper before the layer is created.
# NOTE(review): presumably "0" selects the device and memory_limit caps its
# memory (units are not visible here) — confirm in src.models.utils.config_gpu.
config_gpu("0", memory_limit=4)

In [None]:
# Layer under test. The Constant(1.0) kernel initializer is deterministic, so
# the layer does not introduce randomness of its own through that initializer.
# NOTE(review): the two positional arguments are presumably (filters, kernel_size)
# — confirm against the BSHConv3D signature in src.models.layers_faster.
# layer = SSHConv3D(1, 7, max_degree=3, padding="valid", kernel_initializer=tf.keras.initializers.Constant(value=1.0),  project=False)
layer = BSHConv3D(1, 5, max_degree=3, padding="valid", kernel_initializer=tf.keras.initializers.Constant(value=1.0),  project=False)

In [None]:
def rotate_3d(image, angle1, angle2, angle3):
    """Rotate a volume with three successive planar rotations.

    Singleton dimensions are squeezed out first. The result is then rotated
    by ``-angle1`` in the plane of axes (0, 1), by ``angle2`` in the plane of
    axes (1, 2) and finally by ``-angle3`` in the plane of axes (0, 1). Each
    step calls ``scipy.ndimage.rotate`` with ``reshape=False``, so the array
    shape is unchanged by the rotations.

    Args:
        image: array-like volume; passed through ``np.squeeze`` first.
        angle1: first angle in degrees, applied negated in plane (0, 1).
        angle2: second angle in degrees, applied in plane (1, 2).
        angle3: third angle in degrees, applied negated in plane (0, 1).

    Returns:
        The squeezed and rotated array.
    """
    volume = np.squeeze(image)
    # (angle, rotation plane) pairs, listed in the order they are applied.
    steps = (
        (-angle1, (0, 1)),
        (angle2, (1, 2)),
        (-angle3, (0, 1)),
    )
    for theta, plane in steps:
        volume = rotate(volume, theta, axes=plane, reshape=False)
    return volume


def inv_rotate_3d(image, angle1, angle2, angle3):
    """Apply the planar rotations of ``rotate_3d`` in reverse order and sign.

    Singleton dimensions are squeezed out first. The result is then rotated
    by ``angle3`` in the plane of axes (0, 1), by ``-angle2`` in the plane of
    axes (1, 2) and finally by ``angle1`` in the plane of axes (0, 1). Each
    step calls ``scipy.ndimage.rotate`` with ``reshape=False``. Because the
    rotations interpolate, this undoes ``rotate_3d`` only approximately in
    general.

    Args:
        image: array-like volume; passed through ``np.squeeze`` first.
        angle1: first angle (degrees) of the forward rotation, applied last.
        angle2: second angle (degrees) of the forward rotation, applied negated.
        angle3: third angle (degrees) of the forward rotation, applied first.

    Returns:
        The squeezed and back-rotated array.
    """
    volume = np.squeeze(image)
    # (angle, rotation plane) pairs, listed in the order they are applied.
    steps = (
        (angle3, (0, 1)),
        (-angle2, (1, 2)),
        (angle1, (0, 1)),
    )
    for theta, plane in steps:
        volume = rotate(volume, theta, axes=plane, reshape=False)
    return volume

In [None]:
# Show the entry of layer.indices stored at position 3; the same lookup is
# used further down to label feature maps.
layer.indices[3]

In [None]:
# Look up the entry keyed by the tuple (1, 2, 3).
# NOTE(review): presumably indices_inverse maps an index tuple back to its
# feature-map position — confirm in the layer implementation.
layer.indices_inverse[(1,2,3)]
# layer.indices_inverse[1]

In [None]:
# image = np.transpose(sitk.GetArrayFromImage(image_sitk), (2, 1, 0))
# image = image[60:92, 60:92, 70:102]
# Synthetic 32x32x32 test volume with values drawn uniformly from [0, 1).
# The generator is seeded so that every downstream number is reproducible
# after Restart & Run All; change the seed to probe a different volume.
image = np.random.default_rng(0).random((32, 32, 32))
# Show one slice along the last axis as a sanity check; the trailing ';'
# hides the AxesImage repr in the cell output.
plt.imshow(image[:, :, 15]);

In [None]:
# image = np.random.rand(image.shape[0], image.shape[1], image.shape[2])

In [None]:
# Rotation angles handed to rotate_3d (and later to inv_rotate_3d), which
# forwards them to scipy.ndimage.rotate, i.e. they are in degrees.
angle1, angle2, angle3 = 90, 90, 0
# NOTE(review): `axes` is not referenced by any other visible cell — possibly
# a leftover from an earlier experiment; confirm before removing.
axes = (2, 1)
# image = np.random.uniform(size=(32, 32, 32))
image_rotated = rotate_3d(image, angle1, angle2, angle3)


In [None]:
image.shape

In [None]:
# Run the layer on the original and on the rotated volume. The two np.newaxis
# entries add a leading and a trailing singleton axis around the 3-D volume.
# NOTE(review): presumably the layer expects (batch, x, y, z, channels) — confirm.
output = layer(image[np.newaxis, :, :, :, np.newaxis])
output_rotated = layer(image_rotated[np.newaxis, :, :, :, np.newaxis])

In [None]:
def psnr(x, y):
    """Peak signal-to-noise ratio, in dB, of ``y`` with respect to ``x``.

    Axes 1, 2 and 3 are treated as the volume axes and are reduced; all
    other axes are kept. For 5-D inputs the result therefore holds one value
    per (axis 0, axis 4) pair.

    Args:
        x: reference array-like with at least 4 dimensions. Its largest
            absolute value over axes (1, 2, 3) is used as the peak.
        y: array-like compared against ``x``; must broadcast with ``x``.

    Returns:
        ``10 * log10(peak**2 / mse)`` as a NumPy array. Entries are ``inf``
        where the two inputs are identical (zero mean squared error).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    volume_axes = (1, 2, 3)
    # Average over the true voxel count so that non-cubic volumes are handled
    # correctly (dividing by shape[1] ** 3 is only right for cubes).
    mse = np.mean(np.abs(x - y) ** 2, axis=volume_axes)
    peak = np.max(np.abs(x), axis=volume_axes)
    return 10 * np.log10(peak**2 / mse)

In [None]:
# Display feature map 3 of the first batch element.
# NOTE(review): this prints the full 3-D map; a summary (shape, min/max) may
# be easier to read in the rendered notebook.
output[0,:,:,:,3]

In [None]:
# PSNR between the layer responses to the original and to the rotated volume.
# NOTE(review): output_rotated is compared here without being rotated back
# (inv_rotate_3d is only applied to a single feature map further down) —
# confirm this is the intended comparison.
psnr(output, output_rotated)

In [None]:
def check_output_diff(output, indices=None):
    """Print the label of every feature map that is identically zero.

    A feature map is one slice along the last axis of ``output``. It is
    reported when the sum of its absolute values equals zero.

    Args:
        output: array-like whose last axis enumerates the feature maps.
        indices: optional sequence giving the label printed for each feature
            map. When omitted, ``layer.indices`` of the notebook-level
            ``layer`` is used, so existing one-argument calls are unchanged.

    Returns:
        None. Findings are printed, one line per zero map.
    """
    if indices is None:
        # Fall back to the notebook's global layer for backward compatibility.
        indices = layer.indices
    for channel in range(output.shape[-1]):
        if np.sum(np.abs(output[..., channel])) == 0:
            print(f"map {indices[channel]} is zero")


In [None]:
output.shape

In [None]:
# Report every feature map of the un-rotated response that is entirely zero.
check_output_diff(output)

In [None]:
output.shape

In [None]:
# Largest real part of each feature map (maximum taken over axes 0-3),
# printed next to the matching entry of layer.indices.
# NOTE(review): `s` is reassigned further down as a slice index — a more
# specific name here would avoid confusion.
s = np.max(np.real(output), axis=(0, 1, 2, 3))
for i in range(s.shape[-1]):
    print(f"{layer.indices[i]}: {s[i]}")

In [None]:
# s = np.sum(np.imag(output), axis=(0,1,2,3))
# for i in range(s.shape[-1]):
#     print(f"{layer.indices[i]}: {s[i]}")

In [None]:
# f_ind = layer.indices_inverse[(1, 2,3)]
# f_ind = layer.indices_inverse[(2, 3, 5)]
# Position (on the last axis) of the feature map to inspect.
f_ind = 1
# Same feature map taken from the response to the original and to the
# rotated volume, first batch element only.
fmap = output[0, :, :, :, f_ind]
fmap_rotated = output_rotated[0, :, :, :, f_ind]
# Rotate the response to the rotated input back with the same angles, so it
# can be compared element-wise with `fmap`.
fmap_unrotated = inv_rotate_3d(fmap_rotated, angle1, angle2, angle3)
# fmap_unrotated = rotate(fmap_rotated, -90, axes=(0, 1), reshape=False)

In [None]:
# Index along the last axis of the slice shown in all three panels.
s = 15
difference = fmap - fmap_unrotated
# One row of three titled panels, each with its own colorbar: the absolute
# difference, the original response and the back-rotated response.
panels = (
    ("|fmap - fmap_unrotated|", difference),
    ("|fmap|", fmap),
    ("|fmap_unrotated|", fmap_unrotated),
)
# Named panel_axes (not `axes`) so the notebook-level `axes` tuple is untouched.
fig, panel_axes = plt.subplots(1, 3, figsize=(24, 4))
for ax, (panel_title, volume) in zip(panel_axes, panels):
    im = ax.imshow(np.abs(volume[:, :, s]))
    ax.set_title(f"{panel_title}, slice {s}")
    fig.colorbar(im, ax=ax)

In [None]:
# Sum of all elements of the selected feature map.
np.sum(fmap[...])

In [None]:
# Unit impulse: a 5-D array of zeros with a single 1 at index
# kernel_size // 2 on each of the three middle axes.
# NOTE(review): despite its name, kernel_size sets the side length of this
# input volume here; it is not passed to the layer.
kernel_size = 32
dirac = np.zeros((1, kernel_size, kernel_size, kernel_size, 1))
dirac[0, kernel_size // 2, kernel_size // 2, kernel_size // 2, 0] = 1
plt.imshow(dirac[0, :, :, kernel_size//2, 0])

In [None]:
# Response of the layer to the unit impulse.
impulse_response = layer(dirac)

In [None]:
# Slice kernel_size // 2 (fourth axis) of feature map 1 of the impulse response.
plt.imshow(impulse_response[0, :, :, kernel_size // 2, 1])
plt.colorbar()

In [None]:
impulse_response.shape

In [None]:
# Fetch the `atoms` tensor held by the layer's conv_sh attribute as a NumPy array.
atoms = layer.conv_sh.atoms.numpy()

In [None]:
# Imaginary part of one 2-D slice of the atoms.
# NOTE(review): the meaning of the five atom axes is not visible here —
# confirm in src.models.layers_faster.
plt.imshow(np.imag(atoms)[:,:,1,0, 1])
plt.colorbar()

In [None]:
layer.conv_sh.n_radial_profiles