In [1]:
%reload_ext autoreload
%autoreload 2
from cudants.io.image import Image, BatchedImages
from cudants.registration.rigid import RigidRegistration
from cudants.registration.affine import AffineRegistration
import torch
from torch.optim import SGD, Adam
from torch.nn import functional as F
from torch import nn

In [3]:
from pathlib import Path

# Root of the BraTS 2021 training set. Change this one constant to run the
# notebook on another machine instead of editing every path.
DATA_DIR = Path('/data/BRATS2021/training')


def brats_t1_path(case_id):
    """Return the T1 volume path (as str) for BraTS 2021 case `case_id`, e.g. '00598'."""
    case_name = f'BraTS2021_{case_id}'
    return str(DATA_DIR / case_name / f'{case_name}_t1.nii.gz')


img1 = Image.load_file(brats_t1_path('00598'))  # used as the fixed image below
img2 = Image.load_file(brats_t1_path('00599'))  # used as the moving image below

In [4]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, IntSlider, Layout

def browse_image_slices(image_3d, cmap='gray', vmin=None, vmax=None):
    """Interactively browse the three orthogonal slices of a 3D volume.

    Shows one panel per axis -- ``image_3d[x, :, :]``, ``image_3d[:, y, :]``
    and ``image_3d[:, :, z]`` -- driven by three integer sliders that start
    at the centre of each axis.

    Parameters
    ----------
    image_3d : array-like
        Volume to display. It must expose a 3-element ``shape`` and support
        ``image_3d[i, :, :]``-style indexing (e.g. a NumPy array).
    cmap : str, optional
        Matplotlib colormap used for all three panels. Default ``'gray'``.
    vmin, vmax : float, optional
        Intensity limits shared by all panels. With the default ``None``, each
        panel autoscales on its own, so brightness is not comparable between
        panels or slider positions. Fix them, and consider a diverging ``cmap``
        such as ``'bwr'``, when viewing signed difference images.

    Raises
    ------
    ValueError
        If ``image_3d`` is not 3-dimensional or any axis has length 0.
    """
    shape = tuple(image_3d.shape)
    if len(shape) != 3:
        raise ValueError(f"browse_image_slices expects a 3D volume, got shape {shape}")
    if min(shape) < 1:
        # An empty axis would produce a slider with max=-1 < min=0.
        raise ValueError(f"browse_image_slices needs a non-empty 3D volume, got shape {shape}")
    dim_x, dim_y, dim_z = shape

    def plot_slice(x, y, z):
        """Draw the X, Y and Z slices through voxel (x, y, z)."""
        _, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (image_3d[x, :, :], 'Slice at X = %d' % x),
            (image_3d[:, y, :], 'Slice at Y = %d' % y),
            (image_3d[:, :, z], 'Slice at Z = %d' % z),
        )
        for axis, (plane, title) in zip(axes, panels):
            # vmin/vmax of None means matplotlib autoscales this panel.
            axis.imshow(plane, cmap=cmap, vmin=vmin, vmax=vmax)
            axis.set_title(title)
        plt.show()

    def centred_slider(dim):
        """Slider over the valid indices 0..dim-1 of one axis, starting mid-way."""
        return IntSlider(min=0, max=dim - 1, value=dim // 2, layout=Layout(width='600px'))

    # interact re-runs plot_slice whenever a slider moves.
    interact(plot_slice,
             x=centred_slider(dim_x),
             y=centred_slider(dim_y),
             z=centred_slider(dim_z))

In [5]:
# Visual sanity check of the fixed volume (img1). `[0, 0]` drops the two
# leading axes (presumably batch and channel) to get a 3D array;
# `.detach()` is the supported way to leave autograd before `.numpy()`.
browse_image_slices(img1.array[0, 0].detach().cpu().numpy())

interactive(children=(IntSlider(value=77, description='x', layout=Layout(width='600px'), max=154), IntSlider(v…

In [6]:
# Visual sanity check of the moving volume (img2). `[0, 0]` drops the two
# leading axes (presumably batch and channel) to get a 3D array;
# `.detach()` is the supported way to leave autograd before `.numpy()`.
browse_image_slices(img2.array[0, 0].detach().cpu().numpy())

interactive(children=(IntSlider(value=77, description='x', layout=Layout(width='600px'), max=154), IntSlider(v…

In [65]:
fixed = BatchedImages([img1])
moving = BatchedImages([img2])

# Alternative: rigid-only registration with a shorter schedule.
# transform = RigidRegistration([8, 4, 2], [100, 50, 25], fixed, moving,
#                               loss_type='cc', optimizer='Adam', optimizer_lr=1e-3, optimizer_momentum=0.0)
transform = AffineRegistration(
    [8, 4, 2, 1],           # scales, run in this order (coarsest first)
    [1000, 500, 250, 100],  # max iterations at each corresponding scale
    fixed, moving,
    loss_type='cc', optimizer='Adam', optimizer_lr=1e-3, optimizer_momentum=0.0,
)
moved = transform.optimize(save_transformed=True)

scale: 8, iter: 999/1000, loss: -0.099450: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 333.52it/s]
scale: 4, iter: 288/500, loss: -0.051869:  58%|██████████████████████████████████████████████████████▎                                       | 289/500 [00:00<00:00, 339.14it/s]
scale: 2, iter: 104/250, loss: -0.048560:  42%|███████████████████████████████████████▍                                                      | 105/250 [00:00<00:00, 155.22it/s]
scale: 1, iter: 92/100, loss: -0.046981:  93%|██████████████████████████████████████████████████████████████████████████████████████████▏      | 93/100 [00:03<00:00, 24.46it/s]


In [66]:
# Residual after registration: last entry of `moved` (assumed to be the
# finest-scale warped moving image) minus the fixed image; ideally close to
# zero wherever structures line up. To view the warped image itself, pass
# moved[-1][0, 0] alone without the subtraction.
browse_image_slices(moved[-1][0, 0].detach().cpu().numpy() - img1.array[0, 0].detach().cpu().numpy())

interactive(children=(IntSlider(value=77, description='x', layout=Layout(width='600px'), max=154), IntSlider(v…

In [67]:
# Change introduced by the registration: last entry of `moved` (assumed to be
# the finest-scale warped moving image) minus the original, unwarped moving
# image (img2). To view the warped image itself, pass moved[-1][0, 0] alone.
browse_image_slices(moved[-1][0, 0].detach().cpu().numpy() - img2.array[0, 0].detach().cpu().numpy())

interactive(children=(IntSlider(value=77, description='x', layout=Layout(width='600px'), max=154), IntSlider(v…

In [None]:
# def get_rotation_matrix(rotation, dims, N):
#     if dims == 2:
#         rotmat = torch.zeros((N, 3, 3), device='cuda')
#         rotmat[:, 2, 2] = 1
#         cos, sin = torch.cos(rotation[:, 0]), torch.sin(rotation[:, 0])
#         rotmat[:, 0, 0] = cos
#         rotmat[:, 0, 1] = -sin
#         rotmat[:, 1, 0] = sin
#         rotmat[:, 1, 1] = cos
#     elif dims == 3:
#         rotmat = torch.zeros((N, 4, 4), device='cuda')
#         skew = torch.zeros((N, 3, 3), device='cuda')
#         norm = torch.norm(rotation, dim=-1)+1e-8  # [N]
#         angle = norm[:, None, None]
#         skew[:, 0, 1] = -rotation[:, 2]/norm
#         skew[:, 0, 2] = rotation[:, 1]/norm
#         skew[:, 1, 0] = rotation[:, 2]/norm
#         skew[:, 1, 2] = -rotation[:, 0]/norm
#         skew[:, 2, 0] = -rotation[:, 1]/norm
#         skew[:, 2, 1] = rotation[:, 0]/norm
#         rotmat[:, :3, :3] = torch.eye(3, device=rotation.device)[None] + torch.sin(angle) * skew + torch.matmul(skew, skew) * (1 - torch.cos(angle))
#         rotmat[:, 3, 3] = 1
#     else:
#         raise ValueError(f"Dimensions {dims} not supported")
#     return rotmat

In [9]:
fixed_arrays = fixed()
moving_arrays = moving()
fixed_t2p = fixed.get_torch2phy()
moving_p2t = moving.get_phy2torch()
# Spatial size only: skip the two leading axes (presumably batch and channel).
fixed_size = fixed_arrays.shape[2:]

# Identity affine matrices, one per batch item: [N, dims, dims+1]. Deriving the
# dimensionality from the data, instead of hard-coding 3, keeps this valid for 2D.
n_dims = len(fixed_size)
init_grid = torch.eye(n_dims, n_dims + 1, device=fixed.device).unsqueeze(0).repeat(fixed.size(), 1, 1)

In [10]:
scale = 8
# Shrink every spatial axis by `scale`, but never below 32 voxels per axis.
size_down = list(map(lambda s: max(32, int(s / scale)), fixed_size))

In [11]:
# Target spatial size after dividing by `scale`, with each axis floored at 32.
size_down

[32, 32, 32]

In [12]:
# Resample the fixed volume to `size_down`. F.interpolate only accepts
# align_corners for the interpolating modes and raises if it is passed with
# 'nearest'/'area', so set it only where it applies.
interp_mode = fixed.interpolate_mode
interp_align = True if interp_mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
fixed_image_down = F.interpolate(fixed_arrays, size=size_down, mode=interp_mode, align_corners=interp_align)

In [13]:
# View the downsampled fixed volume (first item of the first two axes).
# `.detach()` is the supported way to leave autograd before `.numpy()`.
browse_image_slices(fixed_image_down[0, 0].detach().cpu().numpy())

interactive(children=(IntSlider(value=16, description='x', layout=Layout(width='600px'), max=31), IntSlider(va…

In [14]:
# Identity sampling grid over the downsampled volume, in normalized [-1, 1]
# coordinates: [N, D, H, W, 3]. The last dim is ordered (x, y, z), i.e.
# reversed relative to the array axes (x runs along W, z along D).
fixed_image_coords = F.affine_grid(theta=init_grid, size=fixed_image_down.size(), align_corners=True)

In [15]:
# Grid coordinate at array index (D, H, W) = (0, last, last). The result
# [1, 1, -1] confirms the (x, y, z) ordering: x follows W, y follows H, z follows D.
fixed_image_coords[0, 0, -1, -1]

tensor([ 1.,  1., -1.], device='cuda:0')

In [16]:
# Append a constant 1 to every coordinate vector, (x, y, z) -> (x, y, z, 1),
# so the 4x4 torch<->physical matrices can be applied with a single contraction.
fixed_image_coords_homo = torch.cat(
    (
        fixed_image_coords,
        torch.ones((*fixed_image_coords.shape[:-1], 1), device=fixed_image_coords.device),
    ),
    dim=-1,
)

In [17]:
# Homogeneous coordinate at (D, H, W) = (4, 0, last): x = 1, y = -1,
# z = -1 + 2*4/31 ~= -0.7419, followed by the appended 1.
fixed_image_coords_homo[0, 4, 0, -1]

tensor([ 1.0000, -1.0000, -0.7419,  1.0000], device='cuda:0')

In [18]:
# Map every homogeneous normalized point into physical space with the fixed
# image's per-batch matrix: out[n, ..., t] = sum_d fixed_t2p[n, t, d] * p[n, ..., d].
fixed_image_coords_homo_phy = torch.einsum(
    'ntd,n...d->n...t',
    fixed_t2p,
    fixed_image_coords_homo,
)
print(f'{fixed_t2p.shape} {fixed_image_coords_homo.shape}')
print(fixed_t2p)

torch.Size([1, 4, 4]) torch.Size([1, 32, 32, 32, 4])
tensor([[[ 119.5000,    0.0000,    0.0000,  119.5000],
         [   0.0000,  119.5000,    0.0000, -119.5000],
         [   0.0000,    0.0000,   77.0000,   77.0000],
         [   0.0000,    0.0000,    0.0000,    1.0000]]], device='cuda:0')


In [19]:
# The same grid point as the normalized-coordinate check, now in physical units:
# x = 119.5*1 + 119.5 = 239, y = 119.5*(-1) - 119.5 = -239, z = 77*(-0.7419) + 77 ~= 19.87.
fixed_image_coords_homo_phy[0, 4, 0, -1]

tensor([ 239.0000, -239.0000,   19.8710,    1.0000], device='cuda:0')

In [20]:
# Bring the physical points into the moving image's normalized coordinates using
# its phy->torch matrix; the result stays homogeneous (last dim of size 4).
moved_image_coords_homo = torch.einsum(
    'ntd,n...d->n...t',
    moving_p2t,
    fixed_image_coords_homo_phy,
)

In [21]:
# Point (D, H, W) = (0, 0, last) starts at normalized (1, -1, -1) in the fixed grid
# and comes back unchanged here, which suggests img1 and img2 share the same
# torch<->physical mapping (TODO confirm against their headers).
moved_image_coords_homo[0, 0, 0, -1]

tensor([ 1., -1., -1.,  1.], device='cuda:0')

In [22]:
# The two private matrices composed into torch2phy. From the values shown,
# _torch2px maps normalized [-1, 1] to 0..239 on the first two coordinates and
# 0..154 on the third; _px2phy here only shifts the second coordinate by -239.
img1._px2phy, img1._torch2px

(array([[   1.,    0.,    0.,    0.],
        [   0.,    1.,    0., -239.],
        [   0.,    0.,    1.,    0.],
        [   0.,    0.,    0.,    1.]]),
 array([[119.5,   0. ,   0. , 119.5],
        [  0. , 119.5,   0. , 119.5],
        [  0. ,   0. ,  77. ,  77. ],
        [  0. ,   0. ,   0. ,   1. ]]))

In [23]:
# torch->physical should equal px->physical applied after torch->px. Print the
# composition for inspection and assert it matches the public `torch2phy`
# (batched, so compare against its first entry) instead of checking by eye.
composed_t2p = img1._px2phy @ img1._torch2px
print(np.around(composed_t2p, 2))
np.testing.assert_allclose(composed_t2p, img1.torch2phy[0].detach().cpu().numpy(), rtol=1e-5, atol=1e-4)

[[ 119.5    0.     0.   119.5]
 [   0.   119.5    0.  -119.5]
 [   0.     0.    77.    77. ]
 [   0.     0.     0.     1. ]]


In [24]:
# Same two private matrices as the earlier inspection, listed in the opposite order.
img1._torch2px, img1._px2phy

(array([[119.5,   0. ,   0. , 119.5],
        [  0. , 119.5,   0. , 119.5],
        [  0. ,   0. ,  77. ,  77. ],
        [  0. ,   0. ,   0. ,   1. ]]),
 array([[   1.,    0.,    0.,    0.],
        [   0.,    1.,    0., -239.],
        [   0.,    0.,    1.,    0.],
        [   0.,    0.,    0.,    1.]]))

In [25]:
# Public batched torch->physical matrix ([1, 4, 4]); its values equal
# _px2phy @ _torch2px as printed earlier.
img1.torch2phy

tensor([[[ 119.5000,    0.0000,    0.0000,  119.5000],
         [   0.0000,  119.5000,    0.0000, -119.5000],
         [   0.0000,    0.0000,   77.0000,   77.0000],
         [   0.0000,    0.0000,    0.0000,    1.0000]]], device='cuda:0')

In [26]:
# Inverse of torch2phy: diagonal 1/119.5 ~= 0.0084 and 1/77 ~= 0.0130, with
# offsets that send physical coordinates back into [-1, 1].
img1.phy2torch

tensor([[[ 0.0084,  0.0000,  0.0000, -1.0000],
         [ 0.0000,  0.0084,  0.0000,  1.0000],
         [ 0.0000,  0.0000,  0.0130, -1.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]]], device='cuda:0')