In [2]:
import SimpleITK as sitk
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import cv2


In [3]:
def fill_blank(img, threshold=90):
    """Fill blank pixels of every z-slice with the nearest non-blank value along z.

    A pixel counts as blank when its value is ``<= threshold``. A forward sweep
    (z = 0 upwards) replaces each blank pixel with the last non-blank value seen
    at the same (x, y) position; a backward sweep over that result then fills the
    pixels that are still blank from the slices above them. Positions that are
    blank in every slice end up as 0.

    Args:
        img: 3-D SimpleITK image. The per-slice arithmetic mixes the slices with
            UInt16 masks and a UInt16 accumulator, so the image is presumably
            expected to be UInt16 -- TODO confirm against callers.
        threshold: highest pixel value still treated as blank (default 90, the
            previously hard-coded value).

    Returns:
        A 3-D image built with ``sitk.JoinSeries`` from the filled 2-D slices
        (each slice has origin 0 and spacing 1).
    """
    width, height, depth = img.GetSize()

    def sweep(slices):
        # Carry the last non-blank value per pixel through `slices`, in order.
        # A freshly constructed sitk.Image is zero-filled, so it can serve as
        # the initial "nothing seen yet" state.
        state = sitk.Image([width, height], sitk.sitkUInt16)
        filled = []
        for sl in slices:
            # Normalise geometry so that slice and accumulator can be combined.
            sl.SetOrigin([0., 0.])
            sl.SetSpacing([1, 1])
            blank = sitk.Cast(sitk.LessEqual(sl, threshold), sitk.sitkUInt16)
            # Keep the carried value where the slice is blank, else take the slice.
            state = state * blank + sl * sitk.Not(blank)
            filled.append(state)
        return filled

    raw_slices = (sitk.Extract(img, [width, height, 0], [0, 0, z]) for z in range(depth))
    forward = sweep(raw_slices)
    backward = sweep(reversed(forward))
    backward.reverse()  # restore ascending z order
    return sitk.JoinSeries(backward)

def calc_surface_height_map(img: sitk.Image, slice_thickness=300, internal_pixel_size=4, internal_downsample=(2, 2, 1),
                            n_iterations=8000):
    """Estimate the upper and lower surface height maps of a volume (PyTorch version).

    The volume is downsampled, blank-filled with ``fill_blank``, log-scaled and
    clamped to [0, 255]. Two height fields are then evolved with a fixed number
    of momentum steps: ``u`` starts at z = 0 and ``l`` at z = depth. Every step
    adds (a) an edge response sampled from a pre-computed z-gradient volume at
    the current height, (b) a smoothness term (3x3 kernel applied to the height
    field) and (c) a term driven by ``l - u - thickness``.

    Runs on CUDA when available and otherwise on the CPU.

    Args:
        img: 3-D image. NOTE: its spacing and origin are reset in place
            (to 1 and 0). ``fill_blank`` works with UInt16 slices, so a UInt16
            image is presumably expected -- TODO confirm.
        slice_thickness: expected thickness between the two surfaces, presumably
            in the same physical unit as ``internal_pixel_size`` -- TODO confirm.
        internal_pixel_size: pixel size used to convert ``slice_thickness`` into
            voxels and to scale the smoothness term.
        internal_downsample: (x, y, z) downsampling factors applied before
            processing.
        n_iterations: number of update steps (default 8000, the previously
            hard-coded value).

    Returns:
        ``(umap, lmap)``: two 2-D SimpleITK images with the x/y size of ``img``
        holding the upper and lower surface heights, scaled back by
        ``internal_downsample[2]`` and shifted by the fixed offsets +0.5 / -1.5.
    """
    # Use the GPU when there is one; fall back to the CPU instead of failing in .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 3x3 smoothness kernel: outer product of a Gaussian with the total weight
    # subtracted from the centre tap, laid out as (out_ch, in_ch, 3, 3) for conv2d.
    k_surface_band = cv2.getGaussianKernel(3, -1)
    k_surface_band = np.matmul(k_surface_band, np.transpose(k_surface_band))
    k_surface_band[1, 1] = k_surface_band[1, 1] - np.sum(k_surface_band)
    k = torch.FloatTensor(np.array([[k_surface_band]])).to(device)

    # 11-tap kernel along z (the transpose puts the taps on the array's first axis).
    # NOTE(review): the first and the last tap are both +0.5, so the kernel is not
    # antisymmetric and its taps sum to 1.0 -- possibly a typo for -0.5. Left
    # unchanged because it affects the numerical result; confirm before editing.
    k_grad = np.float32([[[0.5, 1, 2, 4, 8, 0, -8, -4, -2, -1, 0.5]]])
    k_grad = np.transpose(k_grad, [2, 1, 0])
    k_grad = sitk.GetImageFromArray(k_grad)

    # 3-tap central difference along z.
    k_grad2 = np.float32([[[1, 0, -1]]])
    k_grad2 = np.transpose(k_grad2, [2, 1, 0])
    k_grad2 = sitk.GetImageFromArray(k_grad2)

    thickness = int(slice_thickness / internal_pixel_size / internal_downsample[2])
    surface_pixel_size = float(internal_pixel_size / np.sqrt(internal_downsample[0] * internal_downsample[1]))

    # Work in index space and downsample by the requested factors.
    img.SetSpacing([1, 1, 1])
    img.SetOrigin([0, 0, 0])
    size_ = img.GetSize()
    tf = sitk.AffineTransform(3)
    i_ = list(internal_downsample)
    tf.Scale(i_)
    proc_size = [int(size_[i] / i_[i]) for i in range(3)]
    img = sitk.Resample(img, proc_size, tf)
    img = fill_blank(img)
    img = sitk.Cast(img, sitk.sitkFloat32)
    # Log-scale the intensities, then clamp the result to [0, 255].
    img = (sitk.Log(img) - 4.6) * 39.4
    img = sitk.Clamp(img, sitk.sitkFloat32, 0, 255)

    def get_edge_grad(img_: sitk.Image, ul):
        # Edge-response volume for one surface: keep only the positive (ul == 1)
        # or only the negative (otherwise) part of the first convolution, then
        # differentiate once more and multiply by ul.
        grad_m = sitk.Convolution(img_, k_grad)
        if ul == 1:
            grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, 0, 65535)
        else:
            grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, -65535, 0)
        grad_m = ul * sitk.Convolution(grad_m, k_grad2)
        return torch.FloatTensor(sitk.GetArrayFromImage(grad_m)).to(device)

    u_grad_m = get_edge_grad(img, 1)
    l_grad_m = get_edge_grad(img, -1)

    shape = u_grad_m.shape  # (z, y, x) of the downsampled volume

    # Height fields, shape (1, 1, y, x): upper starts at 0, lower at the volume depth.
    # (The original built these from random tensors multiplied by 0.)
    u = torch.zeros(1, 1, shape[1], shape[2], device=device)
    l = torch.full((1, 1, shape[1], shape[2]), float(shape[0]), device=device)

    lr = 0.001
    momentum = 0.9
    lr_decay = 0.0001
    u_grad = torch.zeros(1, 1, shape[1], shape[2], device=device)
    l_grad = torch.zeros(1, 1, shape[1], shape[2], device=device)

    # Sampling grid of shape (1, 1, y, x, 3) with (x, y, z) in the last axis.
    # x and y are fixed; the z channel is overwritten on every calc_grad call.
    grid = torch.meshgrid(torch.Tensor([0.0]).float().to(device),
                          torch.linspace(-1, 1, shape[1]).float().to(device),
                          torch.linspace(-1, 1, shape[2]).float().to(device))
    grid = torch.stack([grid[2 - i] for i in range(3)], 3)[None,]

    u_grad_m = u_grad_m[None, None,]
    l_grad_m = l_grad_m[None, None,]

    def calc_grad(s, grad, grad_m, ul):
        # One momentum update direction for the height field `s`.
        # Reads the current `u`, `l` and `grid` from the enclosing scope.
        g = torch.clamp(s[0], 0, shape[0] - 1)
        bend = F.pad(s, (1, 1, 1, 1), mode='reflect')
        bend = F.conv2d(bend, k, padding=0)[0].data
        grid[:, :, :, :, 2].copy_(g * (2 / shape[0]) - 1)
        # NOTE(review): align_corners is not passed, so the sampled position depends
        # on the installed PyTorch version's default -- confirm which one is intended.
        edge = F.grid_sample(grad_m, grid)[0]
        return edge + \
               (1200 * surface_pixel_size) * bend + \
               0.004 * thickness * ul * torch.clamp(l - u - thickness, -1000, 10) + \
               momentum * grad

    for _ in range(n_iterations):
        # Both directions are computed from the same (u, l) before either is moved.
        u_grad = calc_grad(u, u_grad, u_grad_m, 1)
        l_grad = calc_grad(l, l_grad, l_grad_m, -1)
        u += lr * u_grad
        l += lr * l_grad
        u = torch.clamp(u, 0, shape[0] - 1)
        l = torch.clamp(l, 0, shape[0] - 1)
        lr *= (1 - lr_decay)

    # Back to input z scale and input x/y size (cv2.resize takes (width, height)).
    umap = np.float32((u * internal_downsample[2] + 0.5).cpu().numpy()[0][0])
    lmap = np.float32((l * internal_downsample[2] - 1.5).cpu().numpy()[0][0])
    umap = cv2.resize(umap, (size_[0], size_[1]), interpolation=cv2.INTER_CUBIC)
    lmap = cv2.resize(lmap, (size_[0], size_[1]), interpolation=cv2.INTER_CUBIC)
    umap = sitk.GetImageFromArray(umap)
    lmap = sitk.GetImageFromArray(lmap)

    return umap, lmap

import cv2
import numpy as np
import SimpleITK as sitk

def calc_surface_height_map_CPU(img: sitk.Image, slice_thickness=300, internal_pixel_size=4, internal_downsample=(2, 2, 1),
                                n_iterations=8000):
    """Estimate the upper and lower surface height maps of a volume (NumPy/OpenCV version).

    CPU counterpart of ``calc_surface_height_map``: the volume is downsampled,
    blank-filled with ``fill_blank``, log-scaled and clamped to [0, 255]. Two
    height fields are then evolved with a fixed number of momentum steps: ``u``
    starts at z = 0 and ``l`` at z = depth. Every step adds (a) an edge response
    linearly interpolated along z from a pre-computed gradient volume at the
    current height, (b) a smoothness term (3x3 kernel applied to the height
    field with reflected borders) and (c) a term driven by ``l - u - thickness``.

    Args:
        img: 3-D image. NOTE: its spacing and origin are reset in place
            (to 1 and 0). ``fill_blank`` works with UInt16 slices, so a UInt16
            image is presumably expected -- TODO confirm.
        slice_thickness: expected thickness between the two surfaces, presumably
            in the same physical unit as ``internal_pixel_size`` -- TODO confirm.
        internal_pixel_size: pixel size used to convert ``slice_thickness`` into
            voxels and to scale the smoothness term.
        internal_downsample: (x, y, z) downsampling factors applied before
            processing.
        n_iterations: number of update steps (default 8000, the previously
            hard-coded value).

    Returns:
        ``(umap, lmap)``: two 2-D SimpleITK images with the x/y size of ``img``
        holding the upper and lower surface heights, scaled back by
        ``internal_downsample[2]`` and shifted by the fixed offsets +0.5 / -1.5.
    """
    # 3x3 smoothness kernel: outer product of a Gaussian with the total weight
    # subtracted from the centre tap.
    k_surface_band = cv2.getGaussianKernel(3, -1)
    k_surface_band = np.matmul(k_surface_band, np.transpose(k_surface_band))
    k_surface_band[1, 1] = k_surface_band[1, 1] - np.sum(k_surface_band)
    k_surface_band = np.float32(k_surface_band)

    # 11-tap kernel along z (the transpose puts the taps on the array's first axis).
    # NOTE(review): the first and the last tap are both +0.5, so the kernel is not
    # antisymmetric and its taps sum to 1.0 -- possibly a typo for -0.5. Left
    # unchanged because it affects the numerical result; confirm before editing.
    k_grad = np.float32([[[0.5, 1, 2, 4, 8, 0, -8, -4, -2, -1, 0.5]]])
    k_grad = np.transpose(k_grad, [2, 1, 0])
    k_grad = sitk.GetImageFromArray(k_grad)

    # 3-tap central difference along z.
    k_grad2 = np.float32([[[1, 0, -1]]])
    k_grad2 = np.transpose(k_grad2, [2, 1, 0])
    k_grad2 = sitk.GetImageFromArray(k_grad2)

    thickness = int(slice_thickness / internal_pixel_size / internal_downsample[2])
    surface_pixel_size = float(internal_pixel_size / np.sqrt(internal_downsample[0] * internal_downsample[1]))

    # Work in index space and downsample by the requested factors.
    img.SetSpacing([1, 1, 1])
    img.SetOrigin([0, 0, 0])
    size_ = img.GetSize()
    tf = sitk.AffineTransform(3)
    i_ = list(internal_downsample)
    tf.Scale(i_)
    proc_size = [int(size_[i] / i_[i]) for i in range(3)]
    img = sitk.Resample(img, proc_size, tf)
    img = fill_blank(img)
    img = sitk.Cast(img, sitk.sitkFloat32)
    # Log-scale the intensities, then clamp the result to [0, 255].
    img = (sitk.Log(img) - 4.6) * 39.4
    img = sitk.Clamp(img, sitk.sitkFloat32, 0, 255)

    def get_edge_grad(img_: sitk.Image, ul):
        # Edge-response volume for one surface: keep only the positive (ul == 1)
        # or only the negative (otherwise) part of the first convolution, then
        # differentiate once more and multiply by ul. Returns a (z, y, x) array.
        grad_m = sitk.Convolution(img_, k_grad)
        if ul == 1:
            grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, 0, 65535)
        else:
            grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, -65535, 0)
        grad_m = ul * sitk.Convolution(grad_m, k_grad2)
        return sitk.GetArrayFromImage(grad_m)

    u_grad_m = get_edge_grad(img, 1)
    l_grad_m = get_edge_grad(img, -1)

    shape = u_grad_m.shape  # (z, y, x) of the downsampled volume
    depth = shape[0]

    # Height fields, shape (1, 1, y, x): upper starts at 0, lower at the volume depth.
    u = np.zeros((1, 1, shape[1], shape[2]))
    l = np.full((1, 1, shape[1], shape[2]), float(depth))

    lr = 0.001
    momentum = 0.9
    lr_decay = 0.0001
    u_grad = np.zeros((1, 1, shape[1], shape[2]), dtype=np.float32)
    l_grad = np.zeros((1, 1, shape[1], shape[2]), dtype=np.float32)

    def sample_along_z(volume, z):
        # Linearly interpolate `volume` (z, y, x) at the fractional height z[y, x]
        # of every pixel. `z` must already lie within [0, depth - 1].
        z0 = np.clip(np.floor(z).astype(np.intp), 0, depth - 1)
        z1 = np.minimum(z0 + 1, depth - 1)
        frac = z - z0
        lower = np.take_along_axis(volume, z0[np.newaxis], axis=0)[0]
        upper = np.take_along_axis(volume, z1[np.newaxis], axis=0)[0]
        return lower + frac * (upper - lower)

    def calc_grad(s, grad, grad_m, ul):
        # One momentum update direction for the height field `s`.
        # Reads the current `u` and `l` from the enclosing scope.
        g = np.clip(s[0, 0], 0, depth - 1)
        # filter2D keeps the input size; BORDER_REFLECT_101 mirrors without
        # repeating the edge pixel, i.e. the same as a 1-pixel 'reflect' padding
        # followed by an unpadded convolution.
        bend = cv2.filter2D(s[0, 0], -1, k_surface_band, borderType=cv2.BORDER_REFLECT_101)
        edge = sample_along_z(grad_m, g)
        # (y, x) terms broadcast against the (1, 1, y, x) terms.
        return edge + \
               (1200 * surface_pixel_size) * bend + \
               0.004 * thickness * ul * np.clip(l - u - thickness, -1000, 10) + \
               momentum * grad

    for _ in range(n_iterations):
        # Both directions are computed from the same (u, l) before either is moved.
        u_grad = calc_grad(u, u_grad, u_grad_m, 1)
        l_grad = calc_grad(l, l_grad, l_grad_m, -1)
        u += lr * u_grad
        l += lr * l_grad
        u = np.clip(u, 0, depth - 1)
        l = np.clip(l, 0, depth - 1)
        lr *= (1 - lr_decay)

    # Back to input z scale and input x/y size (cv2.resize takes (width, height)).
    umap = np.float32((u * internal_downsample[2] + 0.5)[0, 0])
    lmap = np.float32((l * internal_downsample[2] - 1.5)[0, 0])
    umap = cv2.resize(umap, (size_[0], size_[1]), interpolation=cv2.INTER_CUBIC)
    lmap = cv2.resize(lmap, (size_[0], size_[1]), interpolation=cv2.INTER_CUBIC)
    umap = sitk.GetImageFromArray(umap)
    lmap = sitk.GetImageFromArray(lmap)

    return umap, lmap


In [4]:
# Load the input volume. fill_blank and the calc_surface_height_map* functions
# are defined in the previous cell.

# NOTE(review): absolute, user-specific path -- adjust before running elsewhere.
input_image_path = '/home/wanqing.yu/AC_Project/ac_materialization/flattening_code/mip4.tif'
img = sitk.ReadImage(input_image_path)
img

# Disabled: compute both height maps with the CPU implementation ...
# umap, lmap = calc_surface_height_map_CPU(img)

# ... and write them next to the notebook.
# output_upper_surface_path = 'upper_surface.tif'
# output_lower_surface_path = 'lower_surface.tif'

# sitk.WriteImage(umap, output_upper_surface_path)
# sitk.WriteImage(lmap, output_lower_surface_path)

# print(f"Upper surface saved to {output_upper_surface_path}")
# print(f"Lower surface saved to {output_lower_surface_path}")

<SimpleITK.SimpleITK.Image; proxy of <Swig Object of type 'itk::simple::Image *' at 0x7f7e58a16d90> >

In [5]:
# First step of calc_surface_height_map_CPU, unrolled at notebook level:
# build the smoothness kernel as a float32 array of shape (1, 1, 3, 3).

k_surface_band = cv2.getGaussianKernel(3, -1)
# Outer product of the Gaussian with itself gives the 3x3 kernel ...
k_surface_band = k_surface_band @ k_surface_band.T
# ... whose centre tap is reduced by the sum of all taps.
k_surface_band[1, 1] -= k_surface_band.sum()
k_surface_band = k_surface_band[np.newaxis, np.newaxis].astype(np.float32)


In [None]:
# Unrolled body of calc_surface_height_map_CPU, kept at notebook level so the
# intermediate arrays can be inspected. Needs `img` (image-loading cell),
# `fill_blank` (function cell) and `k_surface_band` (kernel cell).
# NOTE: `img` is modified and rebound below, so reload the image before
# running this cell a second time.

# Parameters; same values as the defaults of calc_surface_height_map_CPU.
slice_thickness = 300
internal_pixel_size = 4
internal_downsample = (2, 2, 1)

# 11-tap kernel along z (the transpose puts the taps on the array's first axis).
# NOTE(review): the first and the last tap are both +0.5, so the kernel is not
# antisymmetric and its taps sum to 1.0 -- possibly a typo for -0.5. Left
# unchanged because it affects the numerical result; confirm before editing.
k_grad = np.float32([[[0.5, 1, 2, 4, 8, 0, -8, -4, -2, -1, 0.5]]])
k_grad = np.transpose(k_grad, [2, 1, 0])
k_grad = sitk.GetImageFromArray(k_grad)

# 3-tap central difference along z.
k_grad2 = np.float32([[[1, 0, -1]]])
k_grad2 = np.transpose(k_grad2, [2, 1, 0])
k_grad2 = sitk.GetImageFromArray(k_grad2)

thickness = int(slice_thickness / internal_pixel_size / internal_downsample[2])
surface_pixel_size = float(internal_pixel_size / np.sqrt(internal_downsample[0] * internal_downsample[1]))

# Work in index space and downsample by the requested factors.
img.SetSpacing([1, 1, 1])
img.SetOrigin([0, 0, 0])
size_ = img.GetSize()
tf = sitk.AffineTransform(3)
i_ = list(internal_downsample)
tf.Scale(i_)
proc_size = [int(size_[i] / i_[i]) for i in range(3)]
img = sitk.Resample(img, proc_size, tf)
img = fill_blank(img)
img = sitk.Cast(img, sitk.sitkFloat32)
# Log-scale the intensities, then clamp the result to [0, 255].
img = (sitk.Log(img) - 4.6) * 39.4
img = sitk.Clamp(img, sitk.sitkFloat32, 0, 255)

def get_edge_grad(img_: sitk.Image, ul):
    """Return the (z, y, x) edge-response array for one surface (ul = 1 or -1)."""
    grad_m = sitk.Convolution(img_, k_grad)
    # Keep only the positive part for ul == 1, only the negative part otherwise.
    if ul == 1:
        grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, 0, 65535)
    else:
        grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, -65535, 0)
    grad_m = ul * sitk.Convolution(grad_m, k_grad2)
    grad_m = sitk.GetArrayFromImage(grad_m)
    return grad_m

u_grad_m = get_edge_grad(img, 1)
l_grad_m = get_edge_grad(img, -1)

shape = u_grad_m.shape  # (z, y, x) of the downsampled volume

# Height fields, shape (1, 1, y, x): upper starts at 0, lower at the volume depth.
u = np.zeros((1, 1, shape[1], shape[2]))
l = np.full((1, 1, shape[1], shape[2]), float(shape[0]))

lr = 0.001
momentum = 0.9
lr_decay = 0.0001
u_grad = np.zeros((1, 1, shape[1], shape[2]), dtype=np.float32)
l_grad = np.zeros((1, 1, shape[1], shape[2]), dtype=np.float32)

def sample_along_z(volume, z):
    """Linearly interpolate `volume` (z, y, x) at the fractional height z[y, x].

    `z` must already lie within [0, volume.shape[0] - 1].
    """
    top = volume.shape[0] - 1
    z0 = np.clip(np.floor(z).astype(np.intp), 0, top)
    z1 = np.minimum(z0 + 1, top)
    frac = z - z0
    lower = np.take_along_axis(volume, z0[np.newaxis], axis=0)[0]
    upper = np.take_along_axis(volume, z1[np.newaxis], axis=0)[0]
    return lower + frac * (upper - lower)

def calc_grad(s, grad, grad_m, ul):
    """Return (update direction, clipped heights) for the height field `s`.

    Reads the current notebook-level `u` and `l`.
    """
    g = np.clip(s[0, 0], 0, shape[0] - 1)
    # filter2D keeps the input size; BORDER_REFLECT_101 mirrors without repeating
    # the edge pixel, i.e. the same as a 1-pixel 'reflect' padding followed by an
    # unpadded convolution.
    u_plane_bend = cv2.filter2D(s[0, 0], -1, k_surface_band[0, 0], borderType=cv2.BORDER_REFLECT_101)
    u_edge = sample_along_z(grad_m, g)
    # (y, x) terms broadcast against the (1, 1, y, x) terms.
    grad = u_edge + \
            (1200 * surface_pixel_size) * u_plane_bend + \
            0.004 * thickness * ul * np.clip(l - u - thickness, -1000, 10) + \
            momentum * grad
    return grad, g

for i in range(8000):
    # Both directions are computed from the same (u, l) before either is moved.
    u_grad, gu = calc_grad(u, u_grad, u_grad_m, 1)
    l_grad, gl = calc_grad(l, l_grad, l_grad_m, -1)
    u += lr * u_grad
    l += lr * l_grad
    u = np.clip(u, 0, shape[0] - 1)
    l = np.clip(l, 0, shape[0] - 1)
    lr *= (1 - lr_decay)

# Back to input z scale and input x/y size (cv2.resize takes (width, height)).
umap = np.float32((u * internal_downsample[2] + 0.5)[0, 0])
lmap = np.float32((l * internal_downsample[2] - 1.5)[0, 0])
umap = cv2.resize(umap, (size_[0], size_[1]), interpolation=cv2.INTER_CUBIC)
lmap = cv2.resize(lmap, (size_[0], size_[1]), interpolation=cv2.INTER_CUBIC)
umap = sitk.GetImageFromArray(umap)
lmap = sitk.GetImageFromArray(lmap)