In [1]:
import sys
import os
import glob

import numpy as np

sys.path.append("../")

In [2]:
from ext.lab2im import utils, edit_volumes
import nobrainer

2024-04-11 16:59:28.487419: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-11 16:59:28.558000: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-11 16:59:28.558053: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-11 16:59:28.558081: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-11 16:59:28.568797: I tensorflow/core/platform/cpu_feature_g

In [None]:
SOURCE_DIR_00 = "/om2/scratch/tmp/sabeen/kwyk_data/kwyk/rawdata/"

# Alternative, pre-transformed dataset (kept for reference):
# TRANSFORM_DIR = '/om2/user/sabeen/kwyk_data/kwyk_transform_crop_1000'
TRANSFORM_DIR = SOURCE_DIR_00

# When reading the raw data, features and labels come from the same directory;
# a separate transform directory is expected to hold "features/" and "labels/".
if TRANSFORM_DIR == SOURCE_DIR_00:
    FEATURE_TRANFORM_DIR = LABEL_TRANFORM_DIR = SOURCE_DIR_00
else:
    FEATURE_TRANFORM_DIR = f"{TRANSFORM_DIR}/features"
    LABEL_TRANFORM_DIR = f"{TRANSFORM_DIR}/labels"

In [None]:
def get_feature_label_pairs(features_dir=SOURCE_DIR_00, labels_dir=SOURCE_DIR_00, limit=10):
    """
    Get pairs of feature and label filenames.

    Feature files are those whose name contains "orig"; label files are those whose
    name contains "aseg". Both lists are sorted by path and paired positionally.

    :param features_dir: directory searched for "*orig*" files.
    :param labels_dir: directory searched for "*aseg*" files.
    :param limit: maximum number of pairs to return (default 10, the previous
        hard-coded value). ``None`` returns every pair.
    :return: list of ``(feature_path, label_path)`` tuples.
    :raises ValueError: if the two globs find different numbers of files. Positional
        pairing would then silently match features with the wrong labels.
    """
    features = sorted(glob.glob(os.path.join(features_dir, "*orig*")))
    labels = sorted(glob.glob(os.path.join(labels_dir, "*aseg*")))

    if len(features) != len(labels):
        raise ValueError(
            f"found {len(features)} feature files but {len(labels)} label files; "
            "cannot pair them by position"
        )

    return list(zip(features, labels))[:limit]

feature_label_pairs = get_feature_label_pairs(FEATURE_TRANFORM_DIR, LABEL_TRANFORM_DIR)
# Unzip without re-sorting. The pairs already come out in sorted order, and sorting the
# two lists independently could silently break the feature/label correspondence.
feature_files = [feature for feature, _ in feature_label_pairs]
label_files = [label for _, label in feature_label_pairs]

In [None]:
# Inspect only the first feature/label pair.
feature, label = feature_files[0], label_files[0]

# With im_only=False, load_volume returns (volume, affine, header).
label_vol, label_aff, label_hdr = utils.load_volume(label, im_only=False)
feature_vol, feature_aff, feature_hdr = utils.load_volume(feature, im_only=False)

In [None]:
# Cast the label map to int16 and list the distinct label values it contains.
label_vol = label_vol.astype(np.int16)
np.unique(label_vol)

In [None]:
# Foreground mask: 1.0 wherever any (positive) label is present, 0.0 elsewhere.
binary_label_vol = np.where(label_vol > 0, 1.0, 0.0)
np.unique(binary_label_vol)

In [None]:
import matplotlib.pyplot as plt
from nilearn import plotting
import nibabel as nib
from nobrainer.volume import standardize

def load_brains(image_file: str, mask_file: str, file_path: str = ''):
    """Load an image and its mask, and zero the image outside the mask (skull stripping).

    The subject number is the second "_"-separated token of each *file name*
    (a name like "<prefix>_<nr>_..."). The image and the mask must agree on it.

    :param image_file: image file name, or a full path when ``file_path`` is empty.
    :param mask_file: mask file name, or a full path when ``file_path`` is empty.
    :param file_path: optional directory joined in front of both file names.
    :return: ``(brain, brain_mask, image_nr)``: the masked image from ``get_fdata``,
        the mask cast to ``int``, and the subject-number string.
    :raises AssertionError: if the image and mask numbers differ.
    """
    # Split only the base name. Directory components such as "kwyk_data" also contain
    # "_" and would otherwise shift which token gets compared.
    image_nr = os.path.basename(image_file).split("_")[1]
    mask_nr = os.path.basename(mask_file).split("_")[1]
    # Explicit raise (same exception type as the former assert) so the check survives `python -O`.
    if image_nr != mask_nr:
        raise AssertionError("image and mask numbers do not match")

    if file_path:
        image_path = os.path.join(file_path, image_file)
        mask_path = os.path.join(file_path, mask_file)
    else:
        image_path = image_file
        mask_path = mask_file

    brain = nib.load(image_path).get_fdata()
    brain_mask = nib.load(mask_path).get_fdata().astype(int)

    # Apply skull stripping: zero every voxel outside the mask.
    brain[brain_mask == 0] = 0

    return brain, brain_mask, image_nr

def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
    """
    Save a volume, either as compressed ``.npz`` or as a NIfTI-1 image.

    :param volume: volume to save (numpy array).
    :param aff: affine matrix of the volume. If None, an identity affine is used.
        The string 'FS' selects the affine matrix of FreeSurfer outputs. Any other
        string raises ValueError.
    :param header: NIfTI header. If None, a blank ``nib.Nifti1Header`` is used.
    :param path: destination path. If it contains ".npz", the array is saved with
        ``np.savez_compressed`` under the key ``vol_data``. Otherwise it is written
        with ``nib.save``.
    :param res: currently unused (updating the header resolution is disabled).
    :param dtype: (optional) dtype for the saved NIfTI volume, as a string or numpy
        dtype. Integer dtypes are rounded before casting.
    :param n_dims: currently unused. Kept for interface compatibility (default 3).
    :return: ``(nifty, path)``: the ``Nifti1Image`` that was written (None for
        ``.npz``) and ``path``, unchanged.
    """
    if ".npz" in path:
        np.savez_compressed(path, vol_data=volume)
        # No NIfTI image is built on this branch. Previously `nifty` was unbound here
        # and the return raised UnboundLocalError.
        return None, path

    if header is None:
        header = nib.Nifti1Header()
    if isinstance(aff, str):
        if aff != "FS":
            raise ValueError(f"unsupported affine specifier {aff!r}; expected 'FS'")
        aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
    elif aff is None:
        aff = np.eye(4)

    if dtype is not None:
        # np.dtype() accepts both strings and dtype objects. The previous `"int" in dtype`
        # raised TypeError for dtype objects. Kinds 'i'/'u' cover both int and uint.
        if np.dtype(dtype).kind in "iu":
            volume = np.round(volume)
        volume = volume.astype(dtype=dtype)
        nifty = nib.Nifti1Image(volume, aff, header)
        nifty.set_data_dtype(dtype)
    else:
        nifty = nib.Nifti1Image(volume, aff, header)

    nib.save(nifty, path)
    return nifty, path

In [None]:
# Save the unrotated binary mask and overlay it for a visual check.
nifty, path = save_volume(
    binary_label_vol,
    None,
    None,
    '/om2/user/sabeen/tissue_labeling/misc/test_binary_label_vol_unrot',
)

fig = plt.figure(figsize=(12, 6))
plotting.plot_roi(
    nifty,
    # The mask was saved to an extension-less path; the background is read from the
    # ".nii" file at that same path. Deriving it from `path` keeps the two in sync.
    bg_img=f"{path}.nii",
    alpha=0.4,
    vmin=0,
    vmax=5,
    figure=fig,
    title="Binary label mask (unrotated)",
)

In [None]:
def nobrainer_rotate(feature_vol, label_vol, angles_deg=(20.0, 20.0, 20.0)):
    """Rotate a feature volume and its label volume with the same nobrainer affine.

    Parameters
    ----------
    feature_vol: array-like intensity volume. Warped with ``order=1``.
    label_vol: array-like label volume. Warped with ``order=0`` so that labels stay
        discrete, then cast to int16.
    angles_deg: three rotation angles in degrees, one per axis. Converted to radians
        for ``nobrainer.transform.get_affine``. The default of 20 degrees matches the
        previously hard-coded value.

    Returns
    -------
    (rotated_feature_vol, rotated_label_vol) as numpy arrays.

    Raises
    ------
    ValueError: if the spatial shapes of the two volumes differ. The affine is built
        from ``feature_vol.shape`` and would be wrong for the labels.
    """
    if np.shape(feature_vol)[:3] != np.shape(label_vol)[:3]:
        raise ValueError(
            f"feature shape {np.shape(feature_vol)} and label shape "
            f"{np.shape(label_vol)} differ"
        )

    angles = np.radians(np.asarray(angles_deg, dtype=float))
    print('angles', angles)

    affine = nobrainer.transform.get_affine(feature_vol.shape, rotation=angles)
    rotated_feature_vol = np.array(nobrainer.transform.warp(feature_vol, affine, order=1))
    rotated_label_vol = np.array(
        nobrainer.transform.warp(label_vol, affine, order=0)
    ).astype("int16")
    return rotated_feature_vol, rotated_label_vol

In [None]:
rotated_feature_vol, rotated_label_vol = nobrainer_rotate(feature_vol=feature_vol, label_vol=label_vol)

In [None]:
# Foreground mask of the rotated labels (1.0 where label > 0, else 0.0).
binary_label_vol_rot = np.where(rotated_label_vol > 0, 1.0, 0.0)
np.unique(binary_label_vol_rot)

In [None]:
# Save the nobrainer-rotated binary mask and overlay it for a visual check.
nifty, path = save_volume(
    binary_label_vol_rot,
    None,
    None,
    '/om2/user/sabeen/tissue_labeling/misc/test_binary_label_vol_rot',
)

fig = plt.figure(figsize=(12, 6))
plotting.plot_roi(
    nifty,
    # Background is the ".nii" file at the path just saved; derived from `path` so they can't drift.
    bg_img=f"{path}.nii",
    alpha=0.4,
    vmin=0,
    vmax=5,
    figure=fig,
    title="Binary label mask (nobrainer rotation)",
)

In [None]:
from ext.lab2im import utils, edit_volumes

def synthseg_rotation(feature_vol, label_vol):
    """Resample ``feature_vol`` and ``label_vol`` through a lab2im rotation affine.

    Builds a 3-D affine from three 20-degree rotations (passed as radians) and
    calls ``edit_volumes.resample_volume`` with it: linear interpolation for the
    features, nearest for the labels, 1x1x1 target voxel size, no blurring. The
    affines returned by ``resample_volume`` are discarded.

    NOTE(review): the affine is handed to ``resample_volume`` as the volume's own
    affine together with a target voxel size. Confirm that this actually re-grids the
    data along rotated axes; the cell that compares the result with ``label_vol``
    suggests it may not. Also confirm whether
    ``create_affine_transformation_matrix`` expects degrees or radians.
    """
    rotation_rad = np.radians(np.array([20.0, 20.0, 20.0]))
    rot_affine = utils.create_affine_transformation_matrix(n_dims=3, rotation=rotation_rad)

    rotated_feature_vol, _feature_aff = edit_volumes.resample_volume(
        feature_vol, rot_affine, new_vox_size=[1, 1, 1], interpolation='linear', blur=False
    )
    rotated_label_vol, _label_aff = edit_volumes.resample_volume(
        label_vol, rot_affine, new_vox_size=[1, 1, 1], interpolation='nearest', blur=False
    )
    return rotated_feature_vol, rotated_label_vol

In [None]:
synthseg_rotated_feature_vol, synthseg_rotated_label_vol = synthseg_rotation(feature_vol=feature_vol, label_vol=label_vol)

In [None]:
# Foreground mask of the synthseg-resampled labels (1.0 where label > 0, else 0.0).
binary_synthseg_rotated_label_vol = np.where(synthseg_rotated_label_vol > 0, 1.0, 0.0)
np.unique(binary_synthseg_rotated_label_vol)

In [None]:
# Save the synthseg-resampled binary mask and overlay it for a visual check.
nifty, path = save_volume(
    binary_synthseg_rotated_label_vol,
    None,
    None,
    '/om2/user/sabeen/tissue_labeling/misc/test_binary_label_vol_rot_synth',
)

fig = plt.figure(figsize=(12, 6))
plotting.plot_roi(
    nifty,
    # Background is the ".nii" file at the path just saved; derived from `path` so they can't drift.
    bg_img=f"{path}.nii",
    alpha=0.4,
    vmin=0,
    vmax=5,
    figure=fig,
    title="Binary label mask (synthseg rotation)",
)

In [None]:
np.all(synthseg_rotated_label_vol == label_vol)  # True means the "rotation" left every label unchanged

In [3]:
import tensorflow as tf

In [None]:
def nobrainer_get_coordinates(volume_shape):
    """Get coordinates that represent every voxel in a volume with shape
    `volume_shape`.

    Parameters
    ----------
    volume_shape: sequence of at least 3 ints. Only the first three are used.

    Returns
    -------
    coords: float32 Tensor with shape `(prod(volume_shape[:3]), 3)`, one `(x, y, z)` row
        per voxel in row-major ("ij") order.
    grid: the same coordinates before flattening, shape `(*volume_shape[:3], 3)`.

    Raises
    ------
    ValueError: if `volume_shape` has fewer than three items.
    """
    if len(volume_shape) < 3:
        raise ValueError("shape must have at least 3 items.")
    dtype = tf.float32
    rows, cols, depth = volume_shape[:3]

    axes = tf.meshgrid(
        tf.range(rows, dtype=dtype),
        tf.range(cols, dtype=dtype),
        tf.range(depth, dtype=dtype),
        indexing="ij",
    )
    # Stack once and reuse it for both outputs (previously stacked twice).
    grid = tf.stack(axes, axis=3)
    return tf.reshape(grid, shape=(-1, 3)), grid

def nobrainer_warp_coords(matrix, volume_shape):
    """Build the coordinates for a affine transform on volumetric data.

    Parameters
    ----------
    matrix: (4, 4) affine, as a tensor or an array. It is cast to the coordinate dtype
        (float32), so float64 numpy affines are accepted too.
    volume_shape: sequence of at least 3 ints, shape of the output volume.

    Returns
    -------
    float32 Tensor of transformed coordinates with shape `(prod(volume_shape[:3]), 3)`,
    one row per voxel, in the same order as `nobrainer_get_coordinates`.
    """
    coords, _ = nobrainer_get_coordinates(volume_shape=volume_shape)
    # Append ones to play nicely with 4x4 affine.
    ones = tf.ones((coords.shape[0], 1), dtype=coords.dtype)
    coords_homogeneous = tf.concat([coords, ones], axis=1)
    matrix = tf.cast(matrix, coords.dtype)
    # Row-vector form: (N, 4) @ M^T == (M @ coords^T)^T; drop the homogeneous column.
    return (coords_homogeneous @ tf.transpose(matrix))[..., :3]

def nobrainer_trilinear_interpolation(volume, coords):
    """Trilinear interpolation.

    Implemented according to
    https://en.wikipedia.org/wiki/Trilinear_interpolation#Method
    https://github.com/Ryo-Ito/spatial_transformer_network/blob/2555e846b328e648a456f92d4c80fce2b111599e/warp.py#L137-L222

    Parameters
    ----------
    volume: tensor/array of rank 3 `(X, Y, Z)` or rank 4 `(X, Y, Z, C)`. Cast to float32.
    coords: `(N, 3)` sample positions in voxel units, e.g. the output of
        `nobrainer_warp_coords`. Cast to float32. The result is reshaped to
        `volume.shape`, so N must equal X * Y * Z.

    Returns
    -------
    float32 tensor with the same shape as `volume`.

    Notes
    -----
    Corner indices are clipped to the volume. For a sample outside the volume along an
    axis, both corners on that axis clip to the same edge index. The sample therefore
    takes the edge value along that axis (clamp-to-edge) instead of being zeroed.
    """
    volume = tf.cast(volume, tf.float32)
    coords = tf.cast(coords, tf.float32)
    coords_floor = tf.floor(coords)

    shape = tf.shape(volume)
    xlen = shape[0]
    ylen = shape[1]
    zlen = shape[2]

    # Get lattice points. x0 is point below x, and x1 is point above x. Same for y and
    # z.
    x0 = tf.cast(coords_floor[:, 0], tf.int32)
    x1 = x0 + 1
    y0 = tf.cast(coords_floor[:, 1], tf.int32)
    y1 = y0 + 1
    z0 = tf.cast(coords_floor[:, 2], tf.int32)
    z1 = z0 + 1

    # Clip values to the size of the volume array.
    x0 = tf.clip_by_value(x0, 0, xlen - 1)
    x1 = tf.clip_by_value(x1, 0, xlen - 1)
    y0 = tf.clip_by_value(y0, 0, ylen - 1)
    y1 = tf.clip_by_value(y1, 0, ylen - 1)
    z0 = tf.clip_by_value(z0, 0, zlen - 1)
    z1 = tf.clip_by_value(z1, 0, zlen - 1)

    # Get the indices at corners of cube.
    # Row-major flat index: x * (Y * Z) + y * Z + z.
    i000 = x0 * ylen * zlen + y0 * zlen + z0
    i001 = x0 * ylen * zlen + y0 * zlen + z1
    i010 = x0 * ylen * zlen + y1 * zlen + z0
    i011 = x0 * ylen * zlen + y1 * zlen + z1
    i100 = x1 * ylen * zlen + y0 * zlen + z0
    i101 = x1 * ylen * zlen + y0 * zlen + z1
    i110 = x1 * ylen * zlen + y1 * zlen + z0
    i111 = x1 * ylen * zlen + y1 * zlen + z1

    # Get volume values at corners of cube.
    # For rank-4 input, keep the trailing channel axis so each gather returns (N, C).
    if len(volume.shape) == 3:
        volume_flat = tf.reshape(volume, [-1])
    else:
        volume_flat = tf.reshape(volume, [-1, volume.shape[-1]])

    c000 = tf.gather(volume_flat, i000)
    c001 = tf.gather(volume_flat, i001)
    c010 = tf.gather(volume_flat, i010)
    c011 = tf.gather(volume_flat, i011)
    c100 = tf.gather(volume_flat, i100)
    c101 = tf.gather(volume_flat, i101)
    c110 = tf.gather(volume_flat, i110)
    c111 = tf.gather(volume_flat, i111)

    # Fractional offsets, measured from the (clipped) lower corner.
    xd = coords[:, 0] - tf.cast(x0, tf.float32)
    yd = coords[:, 1] - tf.cast(y0, tf.float32)
    zd = coords[:, 2] - tf.cast(z0, tf.float32)

    if len(volume.shape) == 4:
        # Add a channels axis for proper broadcasting
        xd = xd[:, tf.newaxis]
        yd = yd[:, tf.newaxis]
        zd = zd[:, tf.newaxis]

    # Interpolate along x-axis.
    c00 = c000 * (1 - xd) + c100 * xd
    c01 = c001 * (1 - xd) + c101 * xd
    c10 = c010 * (1 - xd) + c110 * xd
    c11 = c011 * (1 - xd) + c111 * xd

    # Interpolate along y-axis.
    c0 = c00 * (1 - yd) + c10 * yd
    c1 = c01 * (1 - yd) + c11 * yd

    # Interpolate along z-axis.
    c = c0 * (1 - zd) + c1 * zd

    return tf.reshape(c, volume.shape)


In [5]:
# Small synthetic volume for comparing the nobrainer and NumPy code paths.
# A fixed seed makes the printed comparisons reproducible across kernel restarts.
vol_shape = (5, 10, 15)
rng = np.random.default_rng(0)
volume = rng.random(vol_shape)
affine = nobrainer.transform.get_affine(vol_shape, rotation=np.radians([20, 20, 20]))

2024-04-11 16:44:50.777824: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2024-04-11 16:44:50.777874: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: node078
2024-04-11 16:44:50.777882: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: node078
2024-04-11 16:44:50.778061: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: 550.54.14
2024-04-11 16:44:50.778086: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 550.54.14
2024-04-11 16:44:50.778092: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:309] kernel version seems to match DSO: 550.54.14


In [None]:
nobrainer_coords, nobrainer_unreshaped = nobrainer_get_coordinates(vol_shape)
# Append a column of ones so every coordinate can be multiplied by a 4x4 affine.
nob_coords_homogeneous = tf.concat(
    [
        nobrainer_coords,
        tf.ones((nobrainer_coords.shape[0], 1), dtype=nobrainer_coords.dtype),
    ],
    axis=1,
)
print(nob_coords_homogeneous.shape)

In [None]:
nobrainer_warped_coords = nobrainer_warp_coords(matrix=affine, volume_shape=vol_shape)
print(nobrainer_warped_coords.shape)

In [None]:
nobrainer_vol = nobrainer_trilinear_interpolation(volume=volume, coords=nobrainer_warped_coords)
print(nobrainer_vol.shape)

In [None]:
# MATCHES NOBRAINER RESULTS
# NumPy equivalent of nobrainer_get_coordinates + nobrainer_warp_coords.
Nx, Ny, Nz = vol_shape
x, y, z = (np.arange(n, dtype=np.float64) for n in (Nx, Ny, Nz))
xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
coor = np.stack((xx, yy, zz), axis=-1)
coor_reshaped = coor.reshape(-1, 3)
# Homogeneous coordinates: append a column of ones, then apply the affine in row-vector form.
coor_homog = np.concatenate(
    [coor_reshaped, np.ones((coor_reshaped.shape[0], 1), dtype=coor_reshaped.dtype)], axis=1
)
warped_coords = (coor_homog @ np.transpose(affine))[..., :3]

In [None]:
(warped_coords[500], nobrainer_warped_coords[500])  # spot-check one voxel: NumPy vs nobrainer

In [None]:
def np_trilinear_interpolation(volume, coords): # MATCHES NOBRAINER RESULTS
    """Trilinear interpolation (NumPy port of ``nobrainer_trilinear_interpolation``).

    Implemented according to
    https://en.wikipedia.org/wiki/Trilinear_interpolation#Method
    https://github.com/Ryo-Ito/spatial_transformer_network/blob/2555e846b328e648a456f92d4c80fce2b111599e/warp.py#L137-L222

    Parameters
    ----------
    volume: array of shape ``(X, Y, Z)`` or ``(X, Y, Z, C)``. Cast to float32.
    coords: array of shape ``(X*Y*Z, 3)``, one ``(x, y, z)`` sample position (voxel
        units) per output voxel. Cast to float32.

    Returns
    -------
    float32 array with the same shape as ``volume``. Corner indices are clipped to the
    volume, so positions outside it take the value of the nearest edge voxels.
    """
    # Use the `coords` argument. Previously the notebook-global `warped_coords` was read instead.
    volume = np.asarray(volume, dtype=np.float32)
    coords = np.asarray(coords, dtype=np.float32)
    coords_floor = np.floor(coords)

    xlen, ylen, zlen = volume.shape[:3]

    # Lattice points: *0 is the point below each sample, *1 the point above,
    # both clipped into the volume.
    x0, y0, z0 = (coords_floor[:, axis].astype(np.int32) for axis in range(3))
    x1, y1, z1 = x0 + 1, y0 + 1, z0 + 1
    x0, x1 = np.clip(x0, 0, xlen - 1), np.clip(x1, 0, xlen - 1)
    y0, y1 = np.clip(y0, 0, ylen - 1), np.clip(y1, 0, ylen - 1)
    z0, z1 = np.clip(z0, 0, zlen - 1), np.clip(z1, 0, zlen - 1)

    if volume.ndim == 3:
        volume_flat = volume.reshape(-1)
    else:
        volume_flat = volume.reshape(-1, volume.shape[-1])

    def corner(xi, yi, zi):
        # Values at lattice points (xi, yi, zi). axis=0 keeps any channel axis intact.
        return np.take(volume_flat, (xi * ylen + yi) * zlen + zi, axis=0)

    xd = coords[:, 0] - x0.astype(np.float32)
    yd = coords[:, 1] - y0.astype(np.float32)
    zd = coords[:, 2] - z0.astype(np.float32)

    if volume.ndim == 4:
        # Add a channels axis for broadcasting, as in the TensorFlow version.
        xd, yd, zd = xd[:, np.newaxis], yd[:, np.newaxis], zd[:, np.newaxis]

    # Interpolate along x-axis.
    c00 = corner(x0, y0, z0) * (1 - xd) + corner(x1, y0, z0) * xd
    c01 = corner(x0, y0, z1) * (1 - xd) + corner(x1, y0, z1) * xd
    c10 = corner(x0, y1, z0) * (1 - xd) + corner(x1, y1, z0) * xd
    c11 = corner(x0, y1, z1) * (1 - xd) + corner(x1, y1, z1) * xd

    # Interpolate along y-axis.
    c0 = c00 * (1 - yd) + c10 * yd
    c1 = c01 * (1 - yd) + c11 * yd

    # Interpolate along z-axis.
    c = c0 * (1 - zd) + c1 * zd

    return np.reshape(c, volume.shape)

In [None]:
# Step-by-step version of np_trilinear_interpolation, kept for inspecting intermediates.
# Works on a float32 copy instead of rebinding `volume`, so running this cell does not
# change the dtype of the shared test volume used by later cells.
volume_f32 = volume.astype(np.float32)
coords = warped_coords.astype(np.float32)
coords_floor = np.floor(coords)

shape = volume_f32.shape
xlen = shape[0]
ylen = shape[1]
zlen = shape[2]

# Get lattice points. x0 is point below x, and x1 is point above x. Same for y and
# z.
x0 = coords_floor[:, 0].astype(np.int32)
x1 = x0 + 1
y0 = coords_floor[:, 1].astype(np.int32)
y1 = y0 + 1
z0 = coords_floor[:, 2].astype(np.int32)
z1 = z0 + 1

# Clip values to the size of the volume array.
x0 = np.clip(x0, 0, xlen - 1)
x1 = np.clip(x1, 0, xlen - 1)
y0 = np.clip(y0, 0, ylen - 1)
y1 = np.clip(y1, 0, ylen - 1)
z0 = np.clip(z0, 0, zlen - 1)
z1 = np.clip(z1, 0, zlen - 1)

# Row-major flat indices of the eight cube corners.
i000 = x0 * ylen * zlen + y0 * zlen + z0
i001 = x0 * ylen * zlen + y0 * zlen + z1
i010 = x0 * ylen * zlen + y1 * zlen + z0
i011 = x0 * ylen * zlen + y1 * zlen + z1
i100 = x1 * ylen * zlen + y0 * zlen + z0
i101 = x1 * ylen * zlen + y0 * zlen + z1
i110 = x1 * ylen * zlen + y1 * zlen + z0
i111 = x1 * ylen * zlen + y1 * zlen + z1

if len(volume_f32.shape) == 3:
    volume_flat = np.reshape(volume_f32, [-1])
else:
    volume_flat = np.reshape(volume_f32, [-1, volume_f32.shape[-1]])

c000 = np.take(volume_flat, i000)
c001 = np.take(volume_flat, i001)
c010 = np.take(volume_flat, i010)
c011 = np.take(volume_flat, i011)
c100 = np.take(volume_flat, i100)
c101 = np.take(volume_flat, i101)
c110 = np.take(volume_flat, i110)
c111 = np.take(volume_flat, i111)

xd = coords[:, 0] - x0.astype(np.float32)
yd = coords[:, 1] - y0.astype(np.float32)
zd = coords[:, 2] - z0.astype(np.float32)

# Interpolate along x-axis.
c00 = c000 * (1 - xd) + c100 * xd
c01 = c001 * (1 - xd) + c101 * xd
c10 = c010 * (1 - xd) + c110 * xd
c11 = c011 * (1 - xd) + c111 * xd

# Interpolate along y-axis.
c0 = c00 * (1 - yd) + c10 * yd
c1 = c01 * (1 - yd) + c11 * yd

# Interpolate along z-axis.
c = c0 * (1 - zd) + c1 * zd

interp = np.reshape(c, volume_f32.shape)

In [None]:
interp = np_trilinear_interpolation(volume=volume, coords=warped_coords)

In [None]:
def nobrainer_get_voxels(volume, coords): # MATCHES NOBRAINER RESULTS
    """Get voxels from volume at points. These voxels are in a flat tensor.

    Parameters
    ----------
    volume: tensor/array of rank >= 3, `(rows, cols, depth, *channels)`. Cast to float32.
    coords: `(N, 3)` voxel positions. Callers in this notebook pass `tf.round(...)`.
        Cast to float32.

    Returns
    -------
    Flat float32 tensor of gathered values. Entries flagged as out of frame are
    multiplied by 0.

    Raises
    ------
    ValueError: if `volume` has rank < 3 or `coords` is not `(N, 3)`.

    NOTE(review): this mirrors nobrainer's implementation, presumably kept verbatim so
    the results match. Two quirks follow from that:
    (1) the out-of-frame test uses `>` rather than `>=`, so a coordinate equal to
        `rows`/`cols`/`depth` is not zeroed and reads a neighbouring (wrapped or
        clipped) voxel;
    (2) `fcoords` is clipped to the number of coordinates N, not the number of voxels.
        The two only coincide when N == rows * cols * depth.
    """
    x = tf.cast(volume, tf.float32)
    coords = tf.cast(coords, tf.float32)

    if len(x.shape) < 3:
        raise ValueError("`volume` must be at least rank 3")
    if len(coords.shape) != 2 or coords.shape[1] != 3:
        raise ValueError("`coords` must have shape `(N, 3)`.")

    rows, cols, depth, *n_channels = x.shape

    # Points in flattened array representation.
    fcoords = coords[:, 0] * cols * depth + coords[:, 1] * depth + coords[:, 2]

    # Some computed finds are out of range of the image's flattened size.
    # Zero those so we don't get errors. These points in the volume are filled later.
    fcoords_size = tf.size(fcoords, out_type=fcoords.dtype)
    fcoords = tf.clip_by_value(fcoords, 0, fcoords_size - 1)
    xflat = tf.squeeze(tf.reshape(x, [tf.math.reduce_prod(x.shape[:3]), -1]))

    # Reorder image data to transformed space.
    xflat = tf.gather(params=xflat, indices=tf.cast(fcoords, tf.int32))

    # Zero image data that was out of frame.
    outofframe = (
        tf.reduce_any(coords < 0, -1)
        | (coords[:, 0] > rows)
        | (coords[:, 1] > cols)
        | (coords[:, 2] > depth)
    )

    # Repeat the per-point mask across channels so it matches xflat's shape.
    if n_channels:
        outofframe = tf.stack([outofframe for _ in range(n_channels[0])], axis=-1)

    xflat = tf.multiply(xflat, tf.cast(tf.logical_not(outofframe), xflat.dtype))

    return xflat

In [None]:
nobrainer_voxels = nobrainer_get_voxels(volume, tf.round(nobrainer_warped_coords))

In [None]:
def np_get_voxels(volume, coords):
    """Gather voxel values from ``volume`` at ``coords`` into a flat array.

    NumPy port of ``nobrainer_get_voxels``.

    Parameters
    ----------
    volume: array of rank >= 3, ``(rows, cols, depth, *channels)``. Cast to float32.
    coords: ``(N, 3)`` array of voxel positions (callers round them first). Cast to float32.

    Returns
    -------
    Flat float32 array of gathered values. Points with a negative coordinate, or a
    coordinate greater than the matching dimension, are multiplied by 0.

    Raises
    ------
    ValueError: if ``volume`` has rank < 3 or ``coords`` is not ``(N, 3)``.

    NOTE(review): like the TensorFlow version, this keeps three quirks so the results
    match nobrainer:
    - the out-of-frame test uses ``>`` rather than ``>=``;
    - the flat index is clipped to N - 1 rather than to the voxel count;
    - ``np.take`` without ``axis`` flattens multi-channel data.
    """
    vol = volume.astype(np.float32)
    pts = coords.astype(np.float32)

    if len(vol.shape) < 3:
        raise ValueError("`volume` must be at least rank 3")
    if len(pts.shape) != 2 or pts.shape[1] != 3:
        raise ValueError("`coords` must have shape `(N, 3)`.")

    rows, cols, depth = vol.shape[:3]
    n_channels = list(vol.shape[3:])

    # Row-major flat index of every point (same float32 operation order as nobrainer).
    flat_index = pts[:, 0] * cols * depth + pts[:, 1] * depth + pts[:, 2]
    # Clip so the gather never goes out of bounds; invalid points are zeroed below.
    flat_index = np.clip(flat_index, 0, np.size(flat_index) * 1.0 - 1)

    values = np.squeeze(np.reshape(vol, [np.prod(vol.shape[:3]), -1]))
    values = np.take(values, indices=flat_index.astype(np.int32))

    out_of_frame = np.any(pts < 0, -1)
    for axis, limit in enumerate((rows, cols, depth)):
        out_of_frame = out_of_frame | (pts[:, axis] > limit)

    if n_channels:
        out_of_frame = np.stack([out_of_frame] * n_channels[0], axis=-1)

    return values * (~out_of_frame).astype(values.dtype)

In [None]:
np_voxels = np_get_voxels(volume=volume, coords=np.round(warped_coords))

In [None]:
def nobrainer_nearest_neighbor_interpolation(volume, coords):
    """Three-dimensional nearest-neighbour interpolation.

    Rounds each coordinate to the closest lattice point, gathers those voxels with
    ``nobrainer_get_voxels`` and reshapes the flat result to ``volume.shape``.
    """
    nearest_points = tf.round(coords)
    flat_values = nobrainer_get_voxels(volume=volume, coords=nearest_points)
    return tf.reshape(flat_values, volume.shape)

nobrainer_nearest = nobrainer_nearest_neighbor_interpolation(volume=volume, coords=nobrainer_warped_coords)

In [None]:
def np_nearest_neighbor_interpolation(volume, coords): #MATCHES NOBRAINER VERSION
    """Three-dimensional nearest-neighbour interpolation (NumPy port).

    Rounds each coordinate to the closest lattice point, gathers those voxels with
    ``np_get_voxels`` and reshapes the flat result to ``volume.shape``.
    """
    nearest_points = np.round(coords)
    flat_values = np_get_voxels(volume=volume, coords=nearest_points)
    return flat_values.reshape(volume.shape)

np_nearest = np_nearest_neighbor_interpolation(volume=volume, coords=warped_coords)

In [None]:
np.all(np_nearest == np.array(nobrainer_nearest))  # NumPy vs nobrainer nearest-neighbour results

In [6]:
from TissueLabeling.brain_utils import np_warp_features_labels
from nobrainer.transform import warp_features_labels as nob_warp_features_labels

In [8]:
# Reference result from nobrainer (the test volume is used as both feature and label).
nobrainer_feature, nobrainer_label = nob_warp_features_labels(
    volume, volume, affine
)

In [7]:
# Same inputs through the NumPy port, for comparison with the nobrainer result.
np_feature, np_label = np_warp_features_labels(
    volume, volume, affine
)

In [13]:
np_feature[:, :, 0]  # first slice along the last axis

array([[0.6768483 , 0.25291976, 0.30628937, 0.76994455, 0.8470567 ,
        0.35417992, 0.29429865, 0.388851  , 0.2596641 , 0.13360512],
       [0.52823603, 0.5099588 , 0.6822168 , 0.7956048 , 0.9229262 ,
        0.45054898, 0.41697913, 0.25189778, 0.21098478, 0.16766718],
       [0.33755022, 0.73677367, 0.8790368 , 0.56680334, 0.55487216,
        0.56641364, 0.46990442, 0.3239779 , 0.22389337, 0.20172921],
       [0.06480993, 0.35477626, 0.8948826 , 0.20037226, 0.2864908 ,
        0.5252845 , 0.569538  , 0.30177903, 0.19813123, 0.22233976],
       [0.4530878 , 0.50801575, 0.7702455 , 0.45247576, 0.64028263,
        0.5992667 , 0.38509613, 0.41447946, 0.2223226 , 0.19657756]],
      dtype=float32)

In [15]:
nobrainer_feature[:, :, 0]  # first slice along the last axis

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[0.6768483 , 0.25291976, 0.30628937, 0.76994455, 0.8470567 ,
        0.35417968, 0.29429883, 0.38885075, 0.2596641 , 0.13360512],
       [0.52823603, 0.5099588 , 0.6822168 , 0.7956048 , 0.92292637,
        0.45054898, 0.41697913, 0.25189784, 0.21098478, 0.16766718],
       [0.33755022, 0.73677367, 0.8790367 , 0.56680334, 0.55487216,
        0.56641364, 0.46990442, 0.3239779 , 0.22389337, 0.20172921],
       [0.06480998, 0.35477632, 0.8948827 , 0.20037216, 0.28649077,
        0.5252845 , 0.569538  , 0.30177897, 0.19813125, 0.22233976],
       [0.4530878 , 0.5080158 , 0.7702453 , 0.45247588, 0.64028263,
        0.5992667 , 0.38509613, 0.4144795 , 0.2223226 , 0.19657758]],
      dtype=float32)>

In [15]:
# NOTE(review): np_get_affine is defined in a cell further down this notebook. This cell
# only works after that definition has run; a top-to-bottom run raises NameError here.
vol_shape = (5, 10, 15)
affine2 = np_get_affine(vol_shape, rotation=np.radians([20, 20, 20]))
affine2

array([[ 0.8830222 , -0.21147065,  0.41898912, -1.7473502 ],
       [ 0.3213938 ,  0.9230309 , -0.21147065,  1.1838677 ],
       [-0.34202012,  0.3213938 ,  0.8830222 ,  0.05661297],
       [ 0.        ,  0.        ,  0.        ,  1.        ]],
      dtype=float32)

In [16]:
# NOTE(review): nob_get_affine is defined in a cell further down; run that cell first.
# This also rebinds the global `affine` used by the earlier comparison cells.
affine = nob_get_affine(vol_shape, rotation=np.radians([20, 20, 20]))
affine

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[ 0.8830222 , -0.21147065,  0.41898912, -1.7473502 ],
       [ 0.3213938 ,  0.9230309 , -0.21147065,  1.1838677 ],
       [-0.34202012,  0.3213938 ,  0.8830222 ,  0.05661297],
       [ 0.        ,  0.        ,  0.        ,  1.        ]],
      dtype=float32)>

In [13]:
def nob_get_affine(volume_shape, rotation=[0, 0, 0], translation=[0, 0, 0]):
    """Return 4x4 affine, which encodes rotation and translation of 3D tensors.

    The rotation is applied about the centre of the volume
    (``volume_shape[:3] / 2 - 0.5``). The result is
    ``T(translation) @ T(center) @ rz @ ry @ rx @ T(-center)``.

    Parameters
    ----------
    volume_shape: iterable of at least three numbers. The first three define the
        centre of rotation.
    rotation: iterable of three numbers, the yaw, pitch, and roll,
        respectively, in radians. In the code they are the angles of ``rx``,
        ``ry`` and ``rz``, i.e. rotations about the x, y and z axes.
    translation: iterable of three numbers, the number of voxels to translate
        in the x, y, and z directions.

    Returns
    -------
    Tensor with shape `(4, 4)` and dtype float32.

    Raises
    ------
    ValueError: if `volume_shape` has fewer than three values, or `rotation` /
        `translation` do not have exactly three.
    """
    volume_shape = tf.cast(volume_shape, tf.float32)
    rotation = tf.cast(rotation, tf.float32)
    translation = tf.cast(translation, tf.float32)
    if volume_shape.shape[0] < 3:
        raise ValueError("`volume_shape` must have at least three values")
    if rotation.shape[0] != 3:
        raise ValueError("`rotation` must have three values")
    if translation.shape[0] != 3:
        raise ValueError("`translation` must have three values")

    # ROTATION
    # yaw (rotation about the x axis)
    rx = tf.convert_to_tensor(
        [
            [1, 0, 0, 0],
            [0, tf.math.cos(rotation[0]), -tf.math.sin(rotation[0]), 0],
            [0, tf.math.sin(rotation[0]), tf.math.cos(rotation[0]), 0],
            [0, 0, 0, 1],
        ],
        dtype=tf.float32,
    )

    # pitch (rotation about the y axis)
    ry = tf.convert_to_tensor(
        [
            [tf.math.cos(rotation[1]), 0, tf.math.sin(rotation[1]), 0],
            [0, 1, 0, 0],
            [-tf.math.sin(rotation[1]), 0, tf.math.cos(rotation[1]), 0],
            [0, 0, 0, 1],
        ],
        dtype=tf.float32,
    )

    # roll (rotation about the z axis)
    rz = tf.convert_to_tensor(
        [
            [tf.math.cos(rotation[2]), -tf.math.sin(rotation[2]), 0, 0],
            [tf.math.sin(rotation[2]), tf.math.cos(rotation[2]), 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ],
        dtype=tf.float32,
    )

    # Rotation around origin.
    # rx is applied first, then ry, then rz (column-vector convention).
    transform = rz @ ry @ rx

    # Centre of the voxel grid, e.g. 2.0 for an axis of length 5 (voxels 0..4).
    center = tf.convert_to_tensor(volume_shape[:3] / 2 - 0.5, dtype=tf.float32)
    neg_center = tf.math.negative(center)
    center_to_origin = tf.convert_to_tensor(
        [
            [1, 0, 0, neg_center[0]],
            [0, 1, 0, neg_center[1]],
            [0, 0, 1, neg_center[2]],
            [0, 0, 0, 1],
        ],
        dtype=tf.float32,
    )

    origin_to_center = tf.convert_to_tensor(
        [
            [1, 0, 0, center[0]],
            [0, 1, 0, center[1]],
            [0, 0, 1, center[2]],
            [0, 0, 0, 1],
        ],
        dtype=tf.float32,
    )

    # Rotation around center of volume.
    transform = origin_to_center @ transform @ center_to_origin

    # TRANSLATION
    # Note: this rebinds `translation` from the 3-vector to the 4x4 translation matrix.
    translation = tf.convert_to_tensor(
        [
            [1, 0, 0, translation[0]],
            [0, 1, 0, translation[1]],
            [0, 0, 1, translation[2]],
            [0, 0, 0, 1],
        ],
        dtype=tf.float32,
    )

    # Translation is applied after the centred rotation.
    transform = translation @ transform

    # REFLECTION
    #
    # TODO.
    # See http://web.iitd.ac.in/~hegde/cad/lecture/L6_3dtrans.pdf#page=7
    # and https://en.wikipedia.org/wiki/Transformation_matrix#Reflection_2

    return transform

In [14]:
def np_get_affine(volume_shape, rotation=[0, 0, 0], translation=[0, 0, 0]): # MATCHES NOBRAINER OUTPUT
    """Build a 4x4 float32 affine that rotates about the volume centre, then translates.

    NumPy re-implementation of ``nob_get_affine``. The matrix is
    ``T(translation) @ T(center) @ Rz @ Ry @ Rx @ T(-center)``, where
    ``center = volume_shape[:3] / 2 - 0.5``.

    Parameters
    ----------
    volume_shape: iterable of at least three numbers. The first three define the
        centre of rotation.
    rotation: three angles in radians (yaw, pitch, roll), applied as rotations about
        the x, y and z axes respectively.
    translation: three numbers, the voxels to translate along x, y and z.

    Returns
    -------
    numpy.ndarray with shape ``(4, 4)`` and dtype float32.

    Raises
    ------
    ValueError: if ``volume_shape`` has fewer than three values, or ``rotation`` /
        ``translation`` do not have exactly three.
    """
    shape_f = np.array(volume_shape).astype(np.float32)
    angles = np.array(rotation).astype(np.float32)
    offset = np.array(translation).astype(np.float32)

    if shape_f.shape[0] < 3:
        raise ValueError("`volume_shape` must have at least three values")
    if angles.shape[0] != 3:
        raise ValueError("`rotation` must have three values")
    if offset.shape[0] != 3:
        raise ValueError("`translation` must have three values")

    def _as_matrix(rows):
        # Pack four 4-element rows into a float32 homogeneous matrix.
        return np.array(rows, dtype=np.float32)

    def _shift(vec):
        # Homogeneous translation by a 3-vector.
        return _as_matrix(
            [
                [1, 0, 0, vec[0]],
                [0, 1, 0, vec[1]],
                [0, 0, 1, vec[2]],
                [0, 0, 0, 1],
            ]
        )

    cos_x, sin_x = np.cos(angles[0]), np.sin(angles[0])
    cos_y, sin_y = np.cos(angles[1]), np.sin(angles[1])
    cos_z, sin_z = np.cos(angles[2]), np.sin(angles[2])

    rot_x = _as_matrix(
        [[1, 0, 0, 0], [0, cos_x, -sin_x, 0], [0, sin_x, cos_x, 0], [0, 0, 0, 1]]
    )
    rot_y = _as_matrix(
        [[cos_y, 0, sin_y, 0], [0, 1, 0, 0], [-sin_y, 0, cos_y, 0], [0, 0, 0, 1]]
    )
    rot_z = _as_matrix(
        [[cos_z, -sin_z, 0, 0], [sin_z, cos_z, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
    )
    # Same left-to-right multiplication order as nobrainer, so float32 results agree.
    about_origin = (rot_z @ rot_y) @ rot_x

    center = (shape_f[:3] / 2 - 0.5).astype(np.float32)
    about_center = (_shift(center) @ about_origin) @ _shift(-center)

    # Reflection is not implemented.
    return _shift(offset) @ about_center