In [25]:
import torch
import plotly.express as px
import plotly.graph_objects as go
from ipywidgets import interact, FloatSlider, IntSlider
from torch_transform_image import (
    affine_transform_image_2d, 
    affine_transform_image_3d, 
    rotate_then_shift_image_2d, 
    shift_then_rotate_image_2d, 
    rotate_then_shift_image_3d,
    # shift_then_rotate_image_3d,
)

In [2]:
def master_plot_function(
        volume: torch.Tensor, 
        gray_min: float | None = None, 
        gray_max: float | None = None,
        gray_mid: float | None = None,
):
    gray_min = volume.min().item() if gray_min is None else gray_min
    gray_max = volume.max().item() if gray_max is None else gray_max
    gray_mid = (gray_min + gray_max) / 2 if gray_mid is None else gray_mid

    plane_min = 0
    plane_max = volume.shape[2] - 1
    plane_mid = int(plane_max / 2) 

    # Wrapped function that captures the volume tensor
    def plot_wrapper(zmin, zplane):
        fig = go.Figure(
            data=go.Heatmap(
                z=volume[:, :, zplane],
                colorscale="viridis",
                zmin=zmin,
                zmax=volume.max().item(),
            )
        )
        fig.update_layout(
            title=f"Slice {zplane} with threshold zmin={zmin:.2f}", width=800, height=600
        )
        return fig

    interact(
        plot_wrapper,
        zmin=FloatSlider(
            min=gray_min,
            max=gray_max,
            step=(gray_max - gray_min) / 100,
            value=gray_mid,
            description="Threshold:",
            continuous_update=True,
        ),
        zplane=IntSlider(
            min=plane_min,
            max=plane_max,
            step=1,
            value=plane_mid,
            description="Plane:",
            continuous_update=True,
        ),
    )

In [3]:
image = torch.zeros((28, 28, 28), dtype=torch.float32)
image[18, 13, 13] = 1
image = image.float()

In [4]:
max_idx = torch.nonzero(image == image.max())
print("Location(s) of maximum value in result:", max_idx.tolist())
master_plot_function(image)

Location(s) of maximum value in result: [[18, 13, 13]]


interactive(children=(FloatSlider(value=0.5, description='Threshold:', max=1.0, step=0.01), IntSlider(value=13…

In [5]:
result = rotate_then_shift_image_3d(
    image=image,
    rotate_zyx=[0, 0, 0],
    shift_zyx=[1, 5, 5], # order is y, x, z
    interpolation="trilinear",
)
max_idx = torch.nonzero(result == result.max())
print("Location(s) of maximum value in result:", max_idx.tolist())
master_plot_function(result)

Location(s) of maximum value in result: [[19, 18, 18]]


interactive(children=(FloatSlider(value=0.5, description='Threshold:', max=1.0, step=0.01), IntSlider(value=13…

In [24]:
accum = torch.zeros_like(image)
for a in range(0,360,10):
    result = rotate_then_shift_image_3d(
        image=image,
        rotate_zyx=[0, 0, a], # order is z, y, x
        shift_zyx=[0, 0, 0], # order is y, x, z
        interpolation="trilinear",
    )
    accum += result
max_idx = torch.nonzero(result == result.max())
print("Location(s) of maximum value in result:", max_idx.tolist())
master_plot_function(accum, gray_mid=0)

Location(s) of maximum value in result: [[18, 14, 13]]


interactive(children=(FloatSlider(value=0.0, description='Threshold:', max=1.5029011964797974, step=0.01502901…

In [7]:
from torch_affine_utils.transforms_3d import Rx, Ry, Rz, T

In [8]:
Rz_20 = Rz(20)
print(Rz_20)

tensor([[[ 0.9397, -0.3420,  0.0000,  0.0000],
         [ 0.3420,  0.9397,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  1.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]]])
