In [None]:
import os
from itertools import product

import h5py
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import tensorflow as tf
from scipy.ndimage import rotate

from src.data.tf_data import TFDataCreator
from src.data.data_augmentation import preprocess_ct, RightAngleRotation
from src.models.utils import config_gpu
from src.models.cubenet.layers import GroupConv 
from src.models.layers_faster import SHConv3DRadial, BSHConv3D, SSHConv3D

%matplotlib inline

In [None]:
# Task identifier; later cells use it to build the HDF5 path and to pick the
# TFDataCreator.
task = "Task04_Hippocampus"
# NOTE(review): presumably selects GPU "0"; the meaning of the second argument
# (4) is not visible here — confirm in src.models.utils.config_gpu.
config_gpu("0", 4)

In [None]:
# Open the processed training HDF5 file read-only. The handle is intentionally
# not closed here: it is handed to the data creator in the next cell (which
# presumably reads from it lazily — confirm before closing it anywhere).
file = h5py.File(f"../data/processed/{task}/{task}_training.hdf5", mode="r")

In [None]:
# TFDataCreator.get() is keyed by the task prefix before the first "_"
# (e.g. "Task04" for "Task04_Hippocampus"); it returns a callable (presumably
# a creator class) that is instantiated on the open HDF5 file.
data_creator = TFDataCreator.get(task.partition("_")[0])(
    file,
    shuffle=True,
    params_augmentation={"rotation": True, "random_center": False},
)
# Augmented tf.data pipeline (rotation enabled, random centering disabled).
ds = data_creator.get_tf_data(data_augmentation=True)

In [None]:
# Endless (repeated) stream of NumPy batches with 4 samples each.
np_iterator = ds.batch(batch_size=4).repeat().as_numpy_iterator()

In [None]:
# Draw one batch from the NumPy iterator; presumably x holds the images and
# y_gt the matching ground-truth channels.
x, y_gt = next(np_iterator)

In [None]:
# Visual check of one augmented sample: one slice of the image channel and of
# the first three ground-truth channels, each with its own title and colorbar.
s = 32  # slice index along the third spatial axis (assumes that axis has > 32 voxels — TODO confirm)
b = 0  # sample index within the batch
panels = (
    ("x[..., 0]", x[b, :, :, s, 0]),
    ("y_gt[..., 0]", y_gt[b, :, :, s, 0]),
    ("y_gt[..., 1]", y_gt[b, :, :, s, 1]),
    ("y_gt[..., 2]", y_gt[b, :, :, s, 2]),
)
fig, axes = plt.subplots(1, len(panels), figsize=(16, 4), layout='constrained')
for ax, (title, image) in zip(axes, panels):
    im = ax.imshow(image)
    ax.set_title(title)
    # One colorbar per panel: the channels can have different value ranges.
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.show()

In [None]:
def rotate_3d(image, angle1, angle2, angle3):
    """Rotate a volume by three successive in-plane rotations (angles in degrees).

    The input is squeezed first. It is then rotated by ``-angle1`` in the
    (0, 1) plane, by ``angle2`` in the (1, 2) plane and by ``-angle3`` in the
    (0, 1) plane. ``reshape=False`` keeps the original shape and ``order=0``
    uses nearest-neighbour sampling, so voxel values are never interpolated.
    Because the shape is kept, a 90-degree rotation only maps the volume onto
    itself when the two rotated axes have equal length.

    Returns the rotated, squeezed ndarray; ``inv_rotate_3d`` undoes it.
    """
    volume = np.squeeze(image)
    steps = ((-angle1, (0, 1)), (angle2, (1, 2)), (-angle3, (0, 1)))
    for angle, plane in steps:
        volume = rotate(volume, angle, axes=plane, reshape=False, order=0)
    return volume


def inv_rotate_3d(image, angle1, angle2, angle3):
    """Undo ``rotate_3d(image, angle1, angle2, angle3)``.

    Applies the inverse rotations in reverse order: ``+angle3`` in the (0, 1)
    plane, ``-angle2`` in the (1, 2) plane, then ``+angle1`` in the (0, 1)
    plane. It uses the same squeeze, ``reshape=False`` and nearest-neighbour
    settings.
    """
    volume = np.squeeze(image)
    steps = ((angle3, (0, 1)), (-angle2, (1, 2)), (angle1, (0, 1)))
    for angle, plane in steps:
        volume = rotate(volume, angle, axes=plane, reshape=False, order=0)
    return volume

In [None]:
# Group-convolution model: lifting layer + one S4 group convolution (both with
# the same configuration), followed by a max over the last axis.
# NOTE(review): "S4" is presumably the 24-element rotation group of the cube,
# and the last axis presumably the group axis — confirm in cubenet.layers.
model = tf.keras.Sequential(layers=[
    # need this for the lifting layer: insert an extra axis before the last one
    tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-2)),
    # First element: lifting layer; second element: g-conv layer.
    *[
        GroupConv(
            5,
            kernel_size=(3, 3, 3),
            group="S4",
            activation="relu",
            use_bias=True,
            bias_initializer=tf.keras.initializers.Constant(0.1),
            share_weights=True,
        ) for _layer_role in ("lifting", "g-conv")
    ],
    tf.keras.layers.Lambda(lambda x: tf.reduce_max(x, axis=-1)),
])

# Single bispectrum layer with constant (all-ones) kernels and "valid" padding.
# NOTE(review): the positional 1 and 3 are presumably filters and kernel size —
# confirm against the BSHConv3D signature.
model_bispectrum = BSHConv3D(
    1,
    3,
    max_degree=5,
    padding="valid",
    kernel_initializer=tf.keras.initializers.Constant(value=1.0),
    project=False,
)


In [None]:
# Rotate every (sample, channel) volume of the batch by the same three angles.
angle1, angle2, angle3 = 90, 90, 90
x_rotated = np.zeros_like(x)
for b in range(x.shape[0]):
    for c in range(x.shape[-1]):
        x_rotated[b, :, :, :, c] = rotate_3d(
            x[b, :, :, :, c], angle1, angle2, angle3)

In [None]:
# Sanity check: undoing the rotation must give back the input exactly, so
# every printed summed absolute error should be 0.
for b, c in product(range(x.shape[0]), range(x.shape[-1])):
    recovered = inv_rotate_3d(x_rotated[b, :, :, :, c], angle1, angle2, angle3)
    error = np.abs(x[b, :, :, :, c] - recovered).sum()
    print(f"Error for sample {b} and channel {c} is {error}")

In [None]:
# Group-conv model applied to the rotated batch and to the original batch.
y_rotated, y = model(x_rotated), model(x)

In [None]:
# Rotate each (sample, channel) output map of the rotated input back. If the
# model is equivariant to this rotation, the result should match y voxel-wise.
y_unrotated = np.zeros_like(y)
for b in range(y.shape[0]):
    for c in range(y.shape[-1]):
        y_unrotated[b, :, :, :, c] = inv_rotate_3d(
            y_rotated[b, :, :, :, c], angle1, angle2, angle3)

In [None]:
def psnr(x, y, axis=(1, 2, 3)):
    """Peak signal-to-noise ratio, in dB, of ``y`` measured against reference ``x``.

    Computes ``20 * log10(max|x| / sqrt(mean(|x - y|**2)))``, where the max and
    the mean are taken over ``axis``.

    Args:
        x: Reference array, or any array-like (e.g. an eager tensor).
        y: Array-like with the same shape as ``x``.
        axis: Axes to reduce over. The default ``(1, 2, 3)`` yields one value
            per index of the remaining axes, e.g. per (sample, channel) for
            (batch, d0, d1, d2, channel) arrays.

    Returns:
        ndarray of PSNR values: ``inf`` where ``y`` equals ``x`` exactly,
        ``-inf`` where the peak is zero but the error is not, and ``nan``
        where both are zero.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Compute in at least float64. Integer/unsigned inputs would otherwise wrap
    # around in ``x - y`` and overflow when squared. Complex inputs stay
    # complex, so ``np.abs`` still measures the error magnitude.
    dtype = np.result_type(x, y, np.float64)
    x = x.astype(dtype, copy=False)
    y = y.astype(dtype, copy=False)
    mse = np.mean(np.abs(x - y) ** 2, axis=axis)
    max_image = np.max(np.abs(x), axis=axis)
    # A perfect match (mse == 0) legitimately gives inf; don't warn about it.
    with np.errstate(divide="ignore", invalid="ignore"):
        return 20 * np.log10(max_image / np.sqrt(mse))

In [None]:
# Inspect the shape of the group-conv model output y.
y.shape

In [None]:
# PSNR (dB) of the un-rotated output against y, reduced over axes 1-3; higher
# means closer to exact equivariance (inf when the two are identical).
psnr(y, y_unrotated)

In [None]:
# Pick one sample and one output channel of the group-conv model to inspect.
b, c = 0, 2
fmap, fmap_unrotated = y[b, :, :, :, c], y_unrotated[b, :, :, :, c]

In [None]:
difference = fmap - fmap_unrotated
# Rank voxels by the *absolute* error. The signed maximum only finds the
# largest positive deviation and ignores voxels where fmap < fmap_unrotated,
# so the chosen slice could miss the worst error.
abs_difference = np.abs(difference.numpy())
indices_max_error = np.where(abs_difference == abs_difference.max())
s = indices_max_error[2][0]  # check the first slice with the maximum error
print(f"Coordinate of the maximum errors: {indices_max_error}")

# Side by side on that slice: |difference|, |y| and |y_unrotated|.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, layout='constrained', sharey=True)
fig.set_size_inches(18.5, 10.5)

z1_plot = ax1.imshow(abs_difference[:, :, s])
fig.colorbar(z1_plot, ax=ax1, fraction=0.046, pad=0.04)
ax1.set_title("Differences")

z2_plot = ax2.imshow(np.abs(fmap[:, :, s]))
fig.colorbar(z2_plot, ax=ax2, fraction=0.046, pad=0.04)
ax2.set_title("y")

z3_plot = ax3.imshow(np.abs(fmap_unrotated[:, :, s]))
fig.colorbar(z3_plot, ax=ax3, fraction=0.046, pad=0.04)
ax3.set_title("y_unrotated")
plt.show()

In [None]:
# Same equivariance experiment with the bispectrum layer.
y_rotated, y = model_bispectrum(x_rotated), model_bispectrum(x)

In [None]:
# Rotate each (sample, channel) bispectrum output of the rotated input back so
# it can be compared voxel-wise with y.
y_unrotated = np.zeros_like(y)
for b in range(y.shape[0]):
    for c in range(y.shape[-1]):
        y_unrotated[b, :, :, :, c] = inv_rotate_3d(
            y_rotated[b, :, :, :, c], angle1, angle2, angle3)

In [None]:
# PSNR (dB) of the un-rotated bispectrum output against y, reduced over axes 1-3.
psnr(y, y_unrotated)

In [None]:
# Pick one sample and one output channel of the bispectrum layer to inspect
# (assumes the output has at least 10 channels — TODO confirm).
b, c = 0, 9
fmap, fmap_unrotated = y[b, :, :, :, c], y_unrotated[b, :, :, :, c]

In [None]:
difference = fmap - fmap_unrotated
# Rank voxels by the *absolute* error. The signed maximum only finds the
# largest positive deviation and ignores voxels where fmap < fmap_unrotated,
# so the chosen slice could miss the worst error.
abs_difference = np.abs(difference.numpy())
indices_max_error = np.where(abs_difference == abs_difference.max())
s = indices_max_error[2][0]  # check the first slice with the maximum error
print(f"Coordinate of the maximum errors: {indices_max_error}")

# Side by side on that slice: |difference|, |y| and |y_unrotated|.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, layout='constrained', sharey=True)
fig.set_size_inches(18.5, 10.5)

z1_plot = ax1.imshow(abs_difference[:, :, s])
fig.colorbar(z1_plot, ax=ax1, fraction=0.046, pad=0.04)
ax1.set_title("Differences")

z2_plot = ax2.imshow(np.abs(fmap[:, :, s]))
fig.colorbar(z2_plot, ax=ax2, fraction=0.046, pad=0.04)
ax2.set_title("y")

z3_plot = ax3.imshow(np.abs(fmap_unrotated[:, :, s]))
fig.colorbar(z3_plot, ax=ax3, fraction=0.046, pad=0.04)
ax3.set_title("y_unrotated")
plt.show()

In [None]:
# Voxel coordinates (not a flat index) of the largest absolute error.
abs_difference = np.abs(difference.numpy())
np.unravel_index(abs_difference.argmax(), abs_difference.shape)

In [None]:
# Every voxel attaining the largest absolute error, as (i, j, k) tuples.
abs_difference = np.abs(difference.numpy())
indices_max_error = np.where(abs_difference == abs_difference.max())
coords_max_error = list(zip(indices_max_error[0], indices_max_error[1], indices_max_error[2]))

In [None]:
# Display coords_max_error: one (i, j, k) tuple per voxel matching the maximum.
coords_max_error

In [None]:
# Display the raw np.where result: one index array per axis of difference.
indices_max_error