# Rotate and shift 3d examples

In [1]:
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ipywidgets import interact, FloatSlider, IntSlider
from torch_transform_image import (
    rotate_then_shift_image_3d,
    shift_then_rotate_image_3d,
)

## Setup

In [2]:
def master_plot_function(
        volume: torch.Tensor, 
        gray_min: float | None = None, 
        gray_max: float | None = None,
        gray_mid: float | None = None,
):
    max_idx = torch.nonzero(volume == volume.max())
    print("Location(s) of maximum value in result:", max_idx.tolist())

    gray_min = volume.min().item() if gray_min is None else gray_min
    gray_max = volume.max().item() if gray_max is None else gray_max
    gray_mid = (gray_min + gray_max) / 2 if gray_mid is None else gray_mid

    plane_min = 0
    plane_max = volume.shape[2] - 1
    plane_mid = int(plane_max / 2) + 1

    # Wrapped function that captures the volume tensor
    def plot_wrapper(zmin, zplane):
        fig = go.Figure(
            data=go.Heatmap(
                z=volume[zplane, :, :], # because tensors are [ z, y, x ]
                colorscale="viridis",
                zmin=zmin,
                zmax=volume.max().item(),
            )
        )
        fig.update_layout(
            title=f"Slice {zplane} with threshold zmin={zmin:.2f}", width=800, height=600
        )
        return fig

    interact(
        plot_wrapper,
        zmin=FloatSlider(
            min=gray_min,
            max=gray_max,
            step=(gray_max - gray_min) / 100,
            value=gray_mid,
            description="Threshold:",
            continuous_update=True,
        ),
        zplane=IntSlider(
            min=plane_min,
            max=plane_max,
            step=1,
            value=plane_mid,
            description="Plane:",
            continuous_update=True,
        ),
    )

In [3]:
# Test volume: a single unit voxel at [z, y, x] = [14, 7, 14].
# torch.zeros(..., dtype=torch.float32) is already float32, so no extra cast is needed.
image = torch.zeros((28, 28, 28), dtype=torch.float32)
image[14, 7, 14] = 1

In [4]:
# Marker volume: a single unit voxel at [z, y, x] = [14, 14, 14], used as a visual reference.
# The tensor is created as float32, so no extra cast is needed.
center_dot = torch.zeros((28, 28, 28), dtype=torch.float32)
center_dot[14, 14, 14] = 1

In [5]:
# Show the test voxel together with a faint (x0.1) marker from center_dot;
# the threshold slider starts at 0.
master_plot_function(volume=image + 0.1 * center_dot, gray_mid=0)

Location(s) of maximum value in result: [[14, 7, 14]]


interactive(children=(FloatSlider(value=0.0, description='Threshold:', max=1.0, step=0.01), IntSlider(value=14…

## Order of operations

In [6]:
# The same rotation ([90, 0, 0]) and shift ([0, 0, 5]) applied alone and in both orders.
# The first three results all come from shift_then_rotate_image_3d.
rotate_only, shift_only, shift_then_rotate = (
    shift_then_rotate_image_3d(image=image, rotate_zyx=angles, shifts_zyx=offsets)
    for angles, offsets in (
        ([90, 0, 0], [0, 0, 0]),  # rotation with a zero shift
        ([0, 0, 0], [0, 0, 5]),   # shift with a zero rotation
        ([90, 0, 0], [0, 0, 5]),  # both: shift first, then rotate
    )
)
# Same parameters, opposite order of operations.
rotate_then_shift = rotate_then_shift_image_3d(
    image=image, rotate_zyx=[90, 0, 0], shifts_zyx=[0, 0, 5]
)

In [7]:
# Panels are filled row by row in a 2x2 grid, in the same order as `titles`.
images = [rotate_only, shift_only, shift_then_rotate, rotate_then_shift]
titles = ["rotate_only", "shift_only", "shift_then_rotate", "rotate_then_shift"]

fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=titles,
)
fig.update_layout(
    width=1000,
    height=1000,
    title="<sup>Starting location = cyan<br>Center = purple<br>Final location = yellow",
    # One colour axis shared by all four heatmaps: a single colorbar instead of
    # four overlapping ones, and the same value-to-colour mapping in every panel.
    coloraxis=dict(colorscale="viridis"),
)

for idx, transformed in enumerate(images):
    # Showing original point and center as lighter points
    overlay = transformed + center_dot * 0.1 + image * 0.5
    grid_row, grid_col = divmod(idx, 2)
    fig.add_trace(
        go.Heatmap(
            z=overlay[14, :, :],  # plane z=14, where image and center_dot are non-zero
            coloraxis="coloraxis",
        ),
        row=grid_row + 1,  # plotly subplot rows/cols are 1-based
        col=grid_col + 1,
    )
fig.show()

The order of operations looks correct.

## Axis order

In [8]:
# Volume for the axis-order checks: a single unit voxel at [z, y, x] = [14, 10, 13].
# The tensor is created as float32, so no extra cast is needed.
order_image = torch.zeros((28, 28, 28), dtype=torch.float32)
order_image[14, 10, 13] = 1

### Shift

Showing original location in cyan, final location in yellow. Order is `[ z, y, x ]`.

[n,0,0] shifts in z

In [9]:
# Shift by 5 using only the first entry of shifts_zyx; no rotation argument is passed.
result = rotate_then_shift_image_3d(
    image=order_image,
    shifts_zyx=[5, 0, 0],
)
# Overlay the unshifted voxel at half intensity for comparison.
master_plot_function(volume=result + 0.5 * order_image, gray_mid=0)

Location(s) of maximum value in result: [[19, 10, 13]]


interactive(children=(FloatSlider(value=0.0, description='Threshold:', max=1.0, step=0.01), IntSlider(value=14…

[0,n,0] shifts in y

In [10]:
# Shift by 5 using only the second entry of shifts_zyx; no rotation argument is passed.
result = rotate_then_shift_image_3d(
    image=order_image,
    shifts_zyx=[0, 5, 0],
)
# Overlay the unshifted voxel at half intensity for comparison.
master_plot_function(volume=result + 0.5 * order_image, gray_mid=0)

Location(s) of maximum value in result: [[14, 15, 13]]


interactive(children=(FloatSlider(value=0.0, description='Threshold:', max=1.0, step=0.01), IntSlider(value=14…

[0,0,n] shifts in x

In [11]:
# Shift by 5 using only the third entry of shifts_zyx; no rotation argument is passed.
result = rotate_then_shift_image_3d(
    image=order_image,
    shifts_zyx=[0, 0, 5],
)
# Overlay the unshifted voxel at half intensity for comparison.
master_plot_function(volume=result + 0.5 * order_image, gray_mid=0)

Location(s) of maximum value in result: [[14, 10, 18]]


interactive(children=(FloatSlider(value=0.0, description='Threshold:', max=1.0, step=0.01), IntSlider(value=14…

Thus, order appears to be `[ z, y, x ]` for shifts

### Rotate

Rotate [n,0,0] = around z axis

In [12]:
# Sum the rotated volumes over a sweep of the first rotate_zyx entry
# (0, 10, ..., 350) so the path traced by the voxel shows up in one volume.
accum = torch.zeros_like(order_image)
for angle in range(0, 360, 10):
    result = rotate_then_shift_image_3d(
        image=order_image, rotate_zyx=[angle, 0, 0], shifts_zyx=[0, 0, 0]
    )
    accum.add_(result)  # in-place accumulation
master_plot_function(volume=accum, gray_mid=0)

Location(s) of maximum value in result: [[14, 15, 10]]


interactive(children=(FloatSlider(value=0.0, description='Threshold:', max=1.5029023885726929, step=0.01502902…

Rotate [0,n,0] = around y axis

In [13]:
# Sum the rotated volumes over a sweep of the second rotate_zyx entry
# (0, 10, ..., 350) so the path traced by the voxel shows up in one volume.
accum = torch.zeros_like(order_image)
for angle in range(0, 360, 10):
    result = rotate_then_shift_image_3d(
        image=order_image, rotate_zyx=[0, angle, 0], shifts_zyx=[0, 0, 0]
    )
    accum.add_(result)  # in-place accumulation
master_plot_function(volume=accum, gray_mid=0)

Location(s) of maximum value in result: [[14, 10, 13]]


interactive(children=(FloatSlider(value=0.0, description='Threshold:', max=5.758772850036621, step=0.057587728…

Rotate [0,0,n] = around x axis

In [14]:
# Sum the rotated volumes over a sweep of the third rotate_zyx entry
# (0, 10, ..., 350) so the path traced by the voxel shows up in one volume.
accum = torch.zeros_like(order_image)
for angle in range(0, 360, 10):
    result = rotate_then_shift_image_3d(
        image=order_image, rotate_zyx=[0, 0, angle], shifts_zyx=[0, 0, 0]
    )
    accum.add_(result)  # in-place accumulation
master_plot_function(volume=accum, gray_mid=0)

Location(s) of maximum value in result: [[14, 18, 13]]


interactive(children=(FloatSlider(value=0.0, description='Threshold:', max=1.5736968517303467, step=0.01573696…

Thus, order appears to be `[ z, y, x ]` for rotation

## Analyzing `test_rotate_shift_image_3d()`

In [49]:
# Input for the test reproduction: a single unit voxel at [z, y, x] = [14, 8, 10].
# The tensor is created as float32, so no extra cast is needed.
test_image = torch.zeros((28, 28, 28), dtype=torch.float32)
test_image[14, 8, 10] = 1

# NOTE(review): this call and the assertions below appear to reproduce the
# library's `test_rotate_shift_image_3d()` — confirm against the test suite.
result = rotate_then_shift_image_3d(
    image=test_image,
    rotate_zyx=[90, 0, 0],
    shifts_zyx=[0, 0, 5],
    interpolation="trilinear",
)
# NOTE(review): this inspects the untouched input `test_image`, not `result`,
# so it cannot fail for this input — confirm that is intended.
assert test_image[14, 14, 25] == 0
# The transformed voxel is expected at [14, 10, 25] with (near) unit intensity.
assert torch.allclose(result[14, 10, 25], torch.tensor(1.0), atol=1e-6)
# The voxel's original location must be empty in the output.
assert result[14, 8, 10] == 0

In [50]:
# View the transformed volume with the input voxel overlaid at half intensity.
master_plot_function(volume=result + 0.5 * test_image, gray_mid=0)

Location(s) of maximum value in result: [[14, 10, 25]]


interactive(children=(FloatSlider(value=0.0, description='Threshold:', max=1.0, step=0.01), IntSlider(value=14…