In [None]:
import os
from itertools import product

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import rotate
from scipy import special as sp
import SimpleITK as sitk

from src.models.cubenet.layers import GroupConv 
from src.models.models import ResidualGLayer3D
from src.models.utils import config_gpu
from src.models.models import GUnet, Unet

%matplotlib inline

In [None]:
# image_sitk = sitk.ReadImage("/home/vscode/python_wkspce/petct-seg/data/processed/CHGJ074_ct.nii.gz")

In [None]:
# Configure the GPU before any model is built in this notebook.
# NOTE(review): presumably selects device "0" and caps its memory at 16
# (units are not visible here) — confirm against src.models.utils.config_gpu.
config_gpu("0", memory_limit=16)

In [None]:
def _make_s4_group_conv():
    """Return a fresh GroupConv layer with the configuration shared by both conv stages.

    A new layer (and a new Constant initializer) is created on every call so the
    two stages of the model do not share any objects.
    """
    return GroupConv(
        5,
        kernel_size=(3, 3, 3),
        group="S4",
        activation="relu",
        use_bias=True,
        bias_initializer=tf.keras.initializers.Constant(0.1),
        share_weights=True,
    )


# Minimal network used for the equivariance check: two identically configured
# S4 group convolutions wrapped between two axis-manipulating Lambda layers.
model = tf.keras.Sequential(layers=[
    # Insert a singleton axis just before the last axis of the input.
    # NOTE(review): presumably the group axis expected by GroupConv — confirm.
    tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-2)),
    _make_s4_group_conv(),
    _make_s4_group_conv(),
    # Collapse the last axis by taking its maximum.
    tf.keras.layers.Lambda(lambda x: tf.reduce_max(x, axis=-1)),
])


In [None]:
# layer = tf.keras.Sequential(layers=[
#     tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-2)),
#     ResidualGLayer3D(8,
#                      3,
#                      group="S4",
#                      activation="relu",
#                      use_bias=True,
#                      use_batch_norm=True,
#                      bias_initializer=tf.keras.initializers.Constant(0.1)),
#     tf.keras.layers.Lambda(lambda x: tf.reduce_max(x, axis=-1)),
# ])
# model = GUnet(1, n_features=[2, 4, 8, 16, 32])


In [None]:
# model_standard = Unet(1, n_features=[2, 4, 8, 16, 32])

In [None]:
def rotate_3d(image, angle1, angle2, angle3):
    """Rotate a volume by three successive in-plane rotations.

    Singleton dimensions are squeezed out first. The volume is then rotated,
    in this order, by ``-angle1`` in the plane of axes (0, 1), by ``angle2``
    in the plane of axes (1, 2) and by ``-angle3`` in the plane of axes (0, 1).
    Every step uses ``scipy.ndimage.rotate`` with ``reshape=False``, so the
    array shape is preserved.

    Args:
        image: array-like volume (3-D after squeezing).
        angle1: first rotation angle in degrees (applied negated).
        angle2: second rotation angle in degrees.
        angle3: third rotation angle in degrees (applied negated).

    Returns:
        The rotated volume as a NumPy array.
    """
    volume = np.squeeze(image)
    # (angle, rotation plane) pairs, applied from first to last.
    steps = (
        (-angle1, (0, 1)),
        (angle2, (1, 2)),
        (-angle3, (0, 1)),
    )
    for theta, plane in steps:
        volume = rotate(volume, theta, axes=plane, reshape=False)
    return volume


def inv_rotate_3d(image, angle1, angle2, angle3):
    """Apply the three rotations of ``rotate_3d`` in reverse order with opposite signs.

    Singleton dimensions are squeezed out first. The volume is then rotated,
    in this order, by ``angle3`` in the plane of axes (0, 1), by ``-angle2``
    in the plane of axes (1, 2) and by ``angle1`` in the plane of axes (0, 1).
    Every step uses ``scipy.ndimage.rotate`` with ``reshape=False``, so the
    array shape is preserved.

    Args:
        image: array-like volume (3-D after squeezing).
        angle1: first angle (degrees) that was passed to ``rotate_3d``.
        angle2: second angle (degrees) that was passed to ``rotate_3d``.
        angle3: third angle (degrees) that was passed to ``rotate_3d``.

    Returns:
        The back-rotated volume as a NumPy array.
    """
    volume = np.squeeze(image)
    # (angle, rotation plane) pairs, applied from first to last.
    steps = (
        (angle3, (0, 1)),
        (-angle2, (1, 2)),
        (angle1, (0, 1)),
    )
    for theta, plane in steps:
        volume = rotate(volume, theta, axes=plane, reshape=False)
    return volume

In [None]:
# Alternative input: a 32x32x32 crop of a real CT volume (requires image_sitk).
# image = np.transpose(sitk.GetArrayFromImage(image_sitk), (2, 1, 0))
# image = image[60:92, 60:92, 70:102]

# Synthetic 32x32x32 test volume with uniform values in [0, 1).
# A fixed seed is used so that re-running the notebook from a fresh kernel
# feeds the model exactly the same input and reproduces the same figures.
image = np.random.default_rng(0).random((32, 32, 32))

# Show one slice along the last axis; the trailing ';' hides the AxesImage repr.
plt.imshow(image[:, :, 15]);

In [None]:
# Rotation angles (degrees) used both to rotate the input here and to undo the
# rotation on the model output later via inv_rotate_3d.
angle1, angle2, angle3 = 90, 90, 90
# NOTE(review): `axes` is not read by any visible cell — possibly a leftover;
# confirm before removing.
axes = (2, 1)
# image = np.random.uniform(size=(32, 32, 32))
# Rotated copy of the input volume (same shape, since rotate_3d uses reshape=False).
image_rotated = rotate_3d(image, angle1, angle2, angle3)


In [None]:
# Shape of the input volume before the leading/trailing singleton axes are added.
image.shape

In [None]:
# Run the original and the rotated volume through the same model.
# np.newaxis adds a leading and a trailing singleton axis to each 3-D volume
# (presumably batch and channel axes — confirm against GroupConv's expected input).
output = model(image[np.newaxis, :, :, :, np.newaxis])
output_rotated = model(image_rotated[np.newaxis, :, :, :, np.newaxis])

In [None]:
# Print the layer-by-layer summary; placed after the model has been called on data.
model.summary()

In [None]:
# Shape of the model output for the un-rotated input.
# (Previously repeated in three consecutive identical cells; one is enough.)
output.shape

In [None]:
# f_ind = layer.indices_inverse[(1, 1,2)]
# f_ind = layer.indices_inverse[(2, 3, 5)]
# Index along the last axis of the model output that is compared below.
f_ind = 0
# Take the first element along the leading axis and entry f_ind on the last
# axis, keeping the three middle axes.
fmap = output[0, :, :, :, f_ind]
fmap_rotated = output_rotated[0, :, :, :, f_ind]
# Undo the input rotation on the response to the rotated input so it can be
# compared element-wise with fmap.
fmap_unrotated = inv_rotate_3d(fmap_rotated, angle1, angle2, angle3)
# fmap_unrotated = rotate(fmap_rotated, -90, axes=(0, 1), reshape=False)

In [None]:
# Visual equivariance check on one slice: the response to the original input
# is compared with the back-rotated response to the rotated input.
s = 15  # slice index along the last axis shown in every panel
difference = fmap - fmap_unrotated

# (title, volume) for each panel, left to right.
_panels = (
    ("|fmap - fmap_unrotated|", difference),
    ("|fmap|", fmap),
    ("|fmap_unrotated|", fmap_unrotated),
)

plt.figure(figsize=(24, 4))
for _col, (_title, _volume) in enumerate(_panels, start=1):
    plt.subplot(1, 3, _col)
    plt.imshow(np.abs(_volume[:, :, s]))
    # Label each panel so the figure can be read without consulting the code.
    plt.title(f"{_title}, slice {s}")
    plt.colorbar()

In [None]:
# Sum over all entries of the selected feature map.
np.sum(fmap[...])

In [None]:
# Unit impulse: an all-zero array with a single 1 at the centre of the three
# middle axes, with singleton leading and trailing axes.
# NOTE(review): despite its name, `kernel_size` is used here as the edge length
# of the cubic volume.
kernel_size = 32
_center = kernel_size // 2
dirac = np.zeros((1,) + (kernel_size,) * 3 + (1,))
dirac[0, _center, _center, _center, 0] = 1
# Central slice along the third of the middle axes; the impulse should be visible.
plt.imshow(dirac[0, :, :, _center, 0])

In [None]:
# Response of `layer` to the unit impulse.
# NOTE(review): `layer` is not defined in any visible cell (the only definition
# is in a commented-out cell), so on a fresh kernel this raises NameError unless
# `layer` is created elsewhere — confirm which layer is meant and define it.
impulse_response = layer(dirac)

In [None]:
# Inspect entry 7 of the layer's `indices` attribute.
# NOTE(review): relies on `layer`, which no visible cell defines.
layer.indices[7]

In [None]:
# Central slice (index kernel_size // 2 on the fourth axis) of the impulse
# response, for entry 1 on the last axis.
plt.imshow(impulse_response[0, :, :, kernel_size // 2, 1])
plt.colorbar()

In [None]:
# Shape of the impulse response.
impulse_response.shape

In [None]:
# Pull the filters of the layer's `conv_sh` sub-module out as a NumPy array.
# NOTE(review): relies on `layer`, which no visible cell defines.
yo = layer.conv_sh.filters.numpy()


In [None]:
# Shape of the extracted filter array.
yo.shape

In [None]:
# Show three orthogonal slices of the real part of one filter.
# NOTE(review): relies on `layer`, which no visible cell defines.
# Presumably the flat index of the spherical-harmonic filter with
# (degree, order) = (3, 0) — confirm against ravel_sh_index.
i = layer.conv_sh.ravel_sh_index(3, 0)
# Slice index used below; note this overwrites the `s` set in the earlier
# comparison-figure cell.
s = 1
# Slice at index s along the third axis.
plt.subplot(131)
plt.imshow(np.real(yo[:, :, s, 0, 0, i]))
plt.colorbar()
# Slice at index s along the second axis.
plt.subplot(132)
plt.imshow(np.real(yo[:, s, :, 0, 0, i]))
plt.colorbar()
# Slice at index s along the first axis.
plt.subplot(133)
plt.imshow(np.real(yo[s, :, :, 0, 0, i]))
plt.colorbar()

In [None]:
# Inspect the `n_radial_profiles` attribute of the layer's `conv_sh` sub-module.
# NOTE(review): relies on `layer`, which no visible cell defines.
layer.conv_sh.n_radial_profiles