1. Import Libraries

In [None]:
import registration_utils as rut
import numpy as np
import torch
import torch.nn as nn
import pytorch3d
import matplotlib.pyplot as plt
import matplotlib.image 
import imageio

from PIL import Image
from tqdm.notebook import tqdm
from skimage import color, img_as_ubyte
from scipy.spatial.transform import Rotation
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    FoVPerspectiveCameras, RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex, look_at_view_transform, look_at_rotation)

Select the compute device (first GPU if available, otherwise CPU)

In [None]:
# Use the first CUDA GPU when one is available (and make it the current CUDA device);
# otherwise everything runs on the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda:0") if cuda_ok else torch.device("cpu")
if cuda_ok:
    torch.cuda.set_device(device)
    print(cuda_ok)

Load Reference Image

DSA Anteroposterior (AP) Image

In [None]:
# Read the AP DSA frame, show it, then resize it. The resized `dsa_ap` is what later
# cells segment and use as the background of the AP overlay frames.
dsa_ap = rut.xray_read_dt('../data/input/2D_DSA_AP')
print(dsa_ap.shape)

# Two panels side by side (the previous 1x3 grid left the third panel empty).
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.imshow(dsa_ap, cmap='gray')
plt.grid(False)
plt.title("Original Image")

# NOTE(review): the meaning of 1000 (target side length? longest edge?) is defined by
# rut.resize_image, which is not visible here — confirm.
dsa_ap = rut.resize_image(dsa_ap, 1000)
plt.subplot(1, 2, 2)
plt.imshow(dsa_ap, cmap='gray')
plt.grid(False)
plt.title("Resized Image")
plt.show()

Segmentation of DSA AP

In [None]:
# Segment the resized AP frame (sigma is forwarded to rut.DSA_segmentation unchanged)
# and show the mask before (left) and after (right) inversion.
image_ref = rut.DSA_segmentation(dsa_ap, sigma=0.1)
print(image_ref.shape)

plt.figure(figsize=(15, 15))
mask_ax = plt.subplot(1, 2, 1)
mask_ax.imshow(image_ref, cmap='gray')
mask_ax.grid(False)

# Flip the mask values (1 - mask) and display the flipped version on the right.
image_ref = 1 - image_ref
inverted_ax = plt.subplot(1, 2, 2)
inverted_ax.imshow(image_ref, cmap='gray')
inverted_ax.set_title("Inverted Image")

Load and Segment the Lateral (LAT) Image

In [None]:
# Read the lateral (LAT) DSA frame, show it, and resize it. The resized `dsa_lat` is
# used as the background of the LAT overlay frames during optimisation.
dsa_lat = rut.xray_read_dt('../data/input/2D_DSA_LAT')
print(dsa_lat.shape)

# Two panels side by side (the previous 1x3 grid left the third panel empty).
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.imshow(dsa_lat, cmap='gray')
plt.grid(False)
plt.title("Original Image")

dsa_lat = rut.resize_image(dsa_lat, 1000)
plt.subplot(1, 2, 2)
plt.imshow(dsa_lat, cmap='gray')
plt.grid(False)
plt.title("Resized Image")
plt.show()

# Segment the resized LAT frame.
# NOTE(review): sigma here (1e-4) differs from the AP segmentation (0.1) — confirm intended.
image_ref2 = rut.DSA_segmentation(dsa_lat, sigma=1e-4)  # type: ignore

plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.imshow(image_ref2, cmap='gray')
plt.grid(False)

# Inverted image (1 - mask), shown on the right.
image_ref2 = 1 - image_ref2
plt.subplot(1, 2, 2)
plt.imshow(image_ref2, cmap='gray')
plt.title("Inverted Image")
plt.show()

In [None]:
# Optionally write the current (inverted) automatic segmentations to disk as 0/255 PNGs,
# e.g. to hand-correct them before they are re-loaded by the next cell. Off by default so
# that Run All never overwrites manually edited masks.
# (The previously commented-out LAT path was missing its opening quote.)
SAVE_SEGMENTATION = False

if SAVE_SEGMENTATION:
    Image.fromarray(image_ref2.astype(np.uint8) * 255).save('../data/input/DSA_LAT_seg.png')
    Image.fromarray(image_ref.astype(np.uint8) * 255).save('../data/input/DSA_AP_seg.png')

In [None]:
# Input manually segmented images. These replace the automatic masks computed above,
# so the registration uses whatever is stored on disk.
def load_binary_mask(path, threshold=127):
    """Read an image file as a 2-D uint8 mask containing only 0 and 1.

    The image is converted to 8-bit greyscale first, so 1-bit, 8-bit, RGB and RGBA PNGs
    all yield a 2-D array. Pixels strictly brighter than ``threshold`` become 1.
    (The previous ``/255.0`` + ``astype(uint8)`` kept only pixels equal to 255, turning a
    1-bit PNG into all zeros and leaving colour PNGs 3-D.)
    """
    with Image.open(path) as img:
        pixels = np.array(img.convert('L'))
    return (pixels > threshold).astype(np.uint8)


image_ref = load_binary_mask('../data/input/DSA_AP_seg.png')
image_ref2 = load_binary_mask('../data/input/DSA_LAT_seg.png')

# The renderers are sized from image_ref, and their output is also compared against
# image_ref2, so both masks must have the same shape.
assert image_ref.shape == image_ref2.shape, (image_ref.shape, image_ref2.shape)

Load Mesh

In [None]:
# Trimmed mesh used for optimisation; load_obj's third result (textures/materials) is ignored.
verts, faces_idx, _ = load_obj("../data/input/dsa_mesh_zaxis_up_trimmed.obj")
faces = faces_idx.verts_idx

# Every vertex coloured white (all ones), batched as (1, V, 3).
verts_rgb = torch.ones_like(verts).unsqueeze(0)
textures = TexturesVertex(verts_features=verts_rgb.to(device))

dsa_mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)], textures=textures)

Set Up Cameras and Renderers

In [None]:
# Shared perspective camera with default parameters; R and T are passed per render call.
cameras = FoVPerspectiveCameras(device=device)

# Blend parameters controlling the opacity and sharpness of silhouette edges.
blend_params = BlendParams(sigma=1e-7, gamma=1e-5)

# Both renderers render at the AP reference size. The model compares renders from this
# size against both image_ref and image_ref2, so the two masks must share this shape.
# Soft silhouette rasterisation: the blur radius is derived from the blend sigma, and up
# to 100 faces are kept per pixel.
silhouette_raster_settings = RasterizationSettings(
    image_size=image_ref.shape,
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=100,
    bin_size=0
)

# Silhouette renderer used by the optimisation loss.
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=silhouette_raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)

# Phong renderer for the visual overlays: no blur, a single face per pixel.
# (Kept under its own name instead of re-binding one `raster_settings` variable.)
phong_raster_settings = RasterizationSettings(
    image_size=image_ref.shape,
    blur_radius=0.0,
    faces_per_pixel=1,
    bin_size=0
)
# One point light at (2, 2, -2).
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=phong_raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
)


Define Model

In [None]:
class Model(nn.Module):
    """Two-view silhouette registration model with one optimisable camera triplet.

    The only trainable parameter is ``camera_position``, a (3, 3) tensor whose rows are
    camera 1's eye position, look-at point and up vector. On every forward pass a second
    camera is derived from the same three vectors, and both views of the mesh are
    rendered. Both reference silhouettes therefore constrain one set of parameters.

    Args:
        meshes: mesh object to render; its ``.device`` is used for all tensors and it is
            ``.clone()``-d before each render.
        renderer: callable taking ``meshes_world``, ``R`` and ``T``; channel 3 of its
            output is compared against the reference images.
        image_ref: numpy reference silhouette for camera 1.
        image_ref2: numpy reference silhouette for camera 2.
        camera_initial_position: 3x3 array-like (tensor, ndarray or nested list) with the
            initial eye, at and up vectors as rows.
    """

    def __init__(self, meshes, renderer, image_ref, image_ref2, camera_initial_position):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer

        # Weights of the squared silhouette error of camera 1 and camera 2.
        # Camera 2 starts switched off (0.0).
        self.camera1_weight = 1.0
        self.camera2_weight = 0.0

        # Camera-1 loss (Python float) appended on every forward() call.
        self.loss_graph = []

        # Reference silhouettes as float32 buffers: they follow .to(device) but are not trained.
        self.register_buffer('image_ref', torch.from_numpy(image_ref.astype(np.float32)))
        self.register_buffer('image_ref2', torch.from_numpy(image_ref2.astype(np.float32)))

        # Optimisable (3, 3) triplet: rows are eye position, look-at point and up vector.
        # as_tensor accepts tensors, ndarrays and nested lists; clone() keeps the parameter
        # independent of the caller's object (previously built element-wise via .item()).
        initial = torch.as_tensor(camera_initial_position, dtype=torch.float32).detach()
        self.camera_position = nn.Parameter(initial[:3, :3].clone().to(self.device))

    @staticmethod
    def _second_camera_position(c1, at, up):
        """Derive camera 2's eye position from camera 1's eye ``c1``, ``at`` and ``up`` (3,).

        The point ``at - cross(up, at - c1)`` is normalised and scaled to the camera-1
        viewing distance ``|at - c1|``.

        NOTE(review): the normalisation is about the world origin, so the result lies at
        distance |at - c1| from the origin, not from ``at``. It is an exact 90-degree
        rotation of c1 about ``at`` only when ``at`` is the origin. Confirm this is intended
        (``at - normalize(cross(up, at - c1)) * |at - c1|`` would rotate about ``at``).
        """
        offset = at - torch.linalg.cross(up, at - c1)
        return offset / torch.norm(offset) * torch.norm(at - c1)

    def _look_at_RT(self, eye, at, up):
        """Return (R, T) for a camera at ``eye`` looking at ``at`` with up vector ``up``.

        Inputs are (3,) tensors. R is a batch of one 3x3 rotation, and T = -(R^T @ eye)
        has shape (1, 3).
        """
        R = look_at_rotation(eye[None, :], at[None, :], up[None, :], device=self.device)
        T = -torch.bmm(R.transpose(1, 2), eye[None, :, None])[:, :, 0]
        return R, T

    def forward(self):
        """Render both views and compute the weighted silhouette loss.

        Returns:
            (loss, image, image2, loss_first_camera):
            ``loss`` is sqrt of the summed weighted squared errors of both views;
            ``image`` and ``image2`` are the renderer outputs for camera 1 and camera 2;
            ``loss_first_camera`` is the same quantity for camera 1 only. Its value is also
            appended to ``self.loss_graph``.
        """
        eye = self.camera_position[0]
        at = self.camera_position[1]
        up = self.camera_position[2]

        # Camera 1 uses the optimised triplet directly.
        R, T = self._look_at_RT(eye, at, up)

        # Camera 2 shares `at` and `up` with camera 1 and sits at a derived position.
        c2 = self._second_camera_position(eye, at, up)
        R2, T2 = self._look_at_RT(c2, at, up)

        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)
        image2 = self.renderer(meshes_world=self.meshes.clone(), R=R2, T=T2)

        # Channel 3 of each render is compared with its reference silhouette. Camera 1's
        # weighted error is computed once and reused for both losses.
        weighted_err1 = self.camera1_weight * (image[..., 3] - self.image_ref) ** 2
        weighted_err2 = self.camera2_weight * (image2[..., 3] - self.image_ref2) ** 2
        loss = torch.sqrt(torch.sum(weighted_err1 + weighted_err2))
        loss_first_camera = torch.sqrt(torch.sum(weighted_err1))
        self.loss_graph.append(loss_first_camera.item())
        return loss, image, image2, loss_first_camera

Initialize Model and Optimizer

In [None]:
# GIF outputs for the AP and LAT overlay animations, plus one writer for each.
filename_output = "../data/output/real/registration_real_AP.gif"
filename_output2 = "../data/output/real/registration_real_LAT.gif"
writer = imageio.get_writer(filename_output, mode='I')
writer2 = imageio.get_writer(filename_output2, mode='I')

# Initial camera triplet, one row each: eye position, look-at point, up vector.
first_camera_initial_position = torch.tensor([
    [-23.187042236328125, -162.63189697265625, 5.946759223937988],    # eye
    [0, 0, 40],                                                        # at
    [-0.06765928119421005, 0.031012320891022682, 1.623045802116394],   # up
])

# Registration model: trimmed mesh + silhouette renderer, compared to both references.
model = Model(
    meshes=dsa_mesh,
    renderer=silhouette_renderer,
    image_ref=image_ref,
    image_ref2=image_ref2,
    camera_initial_position=first_camera_initial_position,
).to(device)

# Adam over the model's parameters (the camera triplet), learning rate 0.05.
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

Visualize Starting Position and Reference Silhouettes

In [None]:
# Render both views at the initial camera parameters and show them next to the
# reference silhouettes. This is only a preview, so no autograd graph is needed.
with torch.no_grad():
    _, image_init, image_init2, _ = model()
# model() appends to model.loss_graph on every call. Drop this preview entry so that
# loss_graph[i] corresponds to optimisation iteration i in the loss-vs-iteration plot.
model.loss_graph.pop()

panels = [
    (image_init.detach().squeeze().cpu().numpy()[..., 3], "Starting position AP"),
    (model.image_ref.cpu().numpy().squeeze(), "Reference silhouette AP"),
    (image_init2.detach().squeeze().cpu().numpy()[..., 3], "Starting position LAT"),
    (model.image_ref2.cpu().numpy().squeeze(), "Reference silhouette LAT"),
]

plt.figure(figsize=(20, 20))
for position, (picture, title) in enumerate(panels, start=1):
    plt.subplot(1, 4, position)
    plt.imshow(picture)
    plt.grid(False)
    plt.title(title)
plt.show()

In [None]:
# Untrimmed mesh, used only for the Phong overlay frames (not for the optimisation loss).
verts, faces_idx, _ = load_obj("../data/input/dsa_mesh_zaxis_up.obj")
faces = faces_idx.verts_idx

# White per-vertex colour, batched as (1, V, 3).
verts_rgb = torch.ones_like(verts).unsqueeze(0)
textures = TexturesVertex(verts_features=verts_rgb.to(device))

dsa_mesh_for_gif = Meshes(verts=[verts.to(device)], faces=[faces.to(device)], textures=textures)

Run Optimization

In [None]:
# Early-stopping and scheduling state.
strikes_to_break = 0
strike_out = 10
loss_previous1 = 0.0
second_camera_added_i = 0
second_registration_mode = False # True for both AP + LAT at the same time, False for AP then LAT
if second_registration_mode:
    model.camera1_weight = 1.0
    model.camera2_weight = 1.0
total_iterations = 1000
index_for_gif = 0
fontsize = 20


def _render_phong(eye, at, up):
    """Render dsa_mesh_for_gif with phong_renderer from a look-at camera; returns uint8 RGB."""
    R = look_at_rotation(eye[None, :], at[None, :], up[None, :], device=model.device)
    T = -torch.bmm(R.transpose(1, 2), eye[None, :, None])[:, :, 0]   # (1, 3)
    rgb = phong_renderer(meshes_world=dsa_mesh_for_gif, R=R, T=T)
    return img_as_ubyte(rgb[0, ..., :3].detach().squeeze().cpu().numpy())


def _draw_overlay(rendered_rgb, background, title):
    """Draw a rendered frame (greyscale, 50% alpha) over a DSA background in a new figure."""
    plt.figure(figsize=(10, 10))
    plt.imshow(color.rgb2gray(rendered_rgb), cmap='Greys', alpha=0.5)
    plt.imshow(background, cmap='gray', alpha=0.5)
    plt.title(title, size=fontsize)
    plt.axis("off")


# Every 10 iterations the camera triplet (eye, at, up; one list per line) is written here.
with open('../data/output/real/RT_data_real.txt', 'w') as f:
    try:
        loop = tqdm(range(total_iterations))
        for i in loop:

            # Every 10 iterations (from i = 20 on), compare the latest loss with the value
            # recorded 10 iterations earlier: a rise of more than 10, or an absolute change
            # below 0.08, adds one strike.
            if i >= 11 and i % 10 == 0:
                if loss.item() - loss_previous1 > 10:
                    strikes_to_break += 1
                    print("strike "+str(strikes_to_break)+" because loss increased at i = "+str(i)+".")

                if abs(loss.item() - loss_previous1) < 0.08:
                    strikes_to_break += 1
                    print("strike "+str(strikes_to_break)+" because loss difference is small at i = "+str(i)+".")

            if second_registration_mode:
                # Both views active: striking out ends the optimisation.
                if strikes_to_break >= strike_out:
                    break
            elif strikes_to_break >= strike_out or i == total_iterations // 2:
                # AP-only phase: striking out, or reaching the halfway point, switches LAT on.
                print("Adding second image to registration.")
                model.camera1_weight = 1.0
                model.camera2_weight = 1.0
                strikes_to_break = 0
                second_registration_mode = True
                second_camera_added_i = i

            if i >= 1 and i % 10 == 0:
                loss_previous1 = loss.item()

            optimizer.zero_grad()
            loss, image_sillouette, image_sillouette2, loss_first_camera = model()
            loss.backward()
            optimizer.step()

            loop.set_description('Optimizing (loss %.4f)' % loss.item())

            # Every 10 iterations: overlay frames and a camera-triplet log entry.
            if i % 10 == 0:
                # Visualisation only, so skip building autograd graphs.
                with torch.no_grad():
                    eye, at, up = model.camera_position.detach()
                    image = _render_phong(eye, at, up)
                _draw_overlay(image, dsa_ap, "AP Views")
                # plt.savefig('../data/output/real/temp_gif_frame_AP.png', bbox_inches='tight')  # Save to a temporary location
                # writer.append_data(imageio.v2.imread('../data/output/real/temp_gif_frame_AP.png'))  # Read and append to GIF
                # Show now (which also closes the figure) instead of keeping 100+ figures open.
                plt.show()

                for row in model.camera_position.detach().cpu().tolist():
                    f.write(str(row))
                    f.write('\n')

                if second_registration_mode:
                    with torch.no_grad():
                        # Same second-camera formula as in Model.forward: same `at` and `up`.
                        offset = at - torch.linalg.cross(up, at - eye)
                        c2 = offset / torch.norm(offset) * torch.norm(at - eye)
                        image2 = _render_phong(c2, at, up)
                    _draw_overlay(image2, dsa_lat, "LAT Views")
                    # plt.savefig('../data/output/real/temp_gif_frame_LAT.png', bbox_inches='tight')  # Save to a temporary location
                    # writer2.append_data(imageio.v2.imread('../data/output/real/temp_gif_frame_LAT.png'))  # Read and append to the LAT GIF
                    plt.show()
                index_for_gif += 1
    finally:
        # Close the GIF writers even if the optimisation raises.
        writer.close()
        writer2.close()

In [None]:
# Camera-1 silhouette loss recorded by the model on each forward pass.
plt.plot(model.loss_graph)
plt.title('First Camera Loss by Iteration')
plt.xlabel('Iteration')
plt.ylabel('First Camera Loss')

# Annotate only when the LAT view was switched on part-way through the run. In
# simultaneous mode second_camera_added_i stays 0, and an arrow there would be misleading.
if 0 < second_camera_added_i < len(model.loss_graph):
    loss_at_switch = model.loss_graph[second_camera_added_i]
    plt.annotate('Second Camera Added',
                 xy=(second_camera_added_i, loss_at_switch + 10),
                 xytext=(second_camera_added_i, loss_at_switch + 30),
                 arrowprops=dict(facecolor='black', shrink=0.05))
plt.show()