# Part 3 - Using RAPIDS and Accelerated Tools to Align Spatial Transcriptomics Images

When it comes to Spatial Transcriptomics, there is usually a need to align channels that contain different information, such as Z-planes, DAPI and H&E stained tissue. DAPI (or, to use its less catchy name, 4′,6-diamidino-2-phenylindole) is a fluorescent stain that binds to DNA and so highlights cell nuclei, whilst in H&E (Hematoxylin and Eosin) staining the hematoxylin picks out cell nuclei and the eosin picks out cytoplasm and extracellular material.
During sample preparation, adjacent slices may be stained with these two methods and imaged by a slide scanner. However, the resulting images need to be registered to counter the effects of distortions that can occur during the preparation and scanning process.

This presents a few challenges:
* The images are very large
* The distortions are not rigid/affine
* The images are quite dissimilar

Nevertheless, it is possible to register these images with high accuracy, albeit only after a large number of expertly annotated co-registration anchor points have been defined (which is extremely labour intensive) and then some heavy computation has been done to actually perform the image transformation.

In this part of the lab we are going to consider some different techniques for performing this alignment that could be used to speed it up.

We could use a generic cell segmenter, such as those in the previous notebooks, to identify potential cell outlines and then compute the centroid of each cell in each of the images. Although the two sets of centroids will not match exactly, it should be possible to find a best fit between them that generates a vector map with which to register the images.
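The centroid idea can be sketched with SciPy alone. This is a minimal sketch, assuming each modality has already been segmented into an integer label image (0 = background); the helper names `centroids` and `match_centroids` are hypothetical, and the naive nearest-neighbour pairing (with a `max_dist` cut-off) stands in for the more robust matching that dissimilar H&E and DAPI images would need in practice.

```python
import numpy as np
from scipy import ndimage
from scipy.spatial import cKDTree


def centroids(labels):
    """Return the (row, col) centroid of every labelled object 1..N."""
    idx = np.arange(1, labels.max() + 1)
    return np.array(ndimage.center_of_mass(np.ones_like(labels), labels, idx))


def match_centroids(src_pts, dst_pts, max_dist=10.0):
    """Pair each source centroid with its nearest destination centroid.

    Returns the kept source points and their displacement vectors
    (dst - src); pairs further apart than max_dist are discarded.
    """
    dist, j = cKDTree(dst_pts).query(src_pts, k=1)
    keep = dist < max_dist
    return src_pts[keep], dst_pts[j[keep]] - src_pts[keep]


# toy example: two square "cells", shifted by (2, 3) pixels in the second image
labels_a = np.zeros((50, 50), dtype=int)
labels_a[5:10, 5:10] = 1
labels_a[30:36, 20:26] = 2
labels_b = np.zeros_like(labels_a)
labels_b[7:12, 8:13] = 1
labels_b[32:38, 23:29] = 2

toy_anchors, toy_vecs = match_centroids(centroids(labels_a), centroids(labels_b))
print(toy_anchors)
print(toy_vecs)
```

The resulting anchor points and displacement vectors are exactly the kind of sparse input that the interpolation steps in the rest of this notebook turn into a dense warp.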

We won't have time in this lab to explore more sophisticated Deep Learning approaches, but hopefully you will have gained enough momentum to try out your own experiments. Future workshops are also expected to cover this topic, so stay tuned...

Let's create an image that we are going to deform, along with some deformation vectors, and see how we can apply them to the image.


In [None]:
from matplotlib import pyplot as plt
from timeit import default_timer as timer
import numpy as np

# create the 2D image
imgrid = np.zeros((256,256,3),dtype=np.uint8)
imgrid[:,:,:] = 255

# draw some lines on the image
for x in range(1,13):
    imgrid[20*x,:,0:2] = 0
    imgrid[:,20*x,0:2] = 0

plt.imshow(imgrid)

Imagine that we have pairs of coordinates that represent the expected locations of features and the actual locations of those features. The difference between each pair is a displacement vector, and we can plot these vectors on a grid.

In [None]:
# Create a grid of coordinates from 0 to 250 with 26 points (a spacing of 10)
x = np.linspace(0, 250, num=26)
y = np.linspace(0, 250, num=26)
X, Y = np.meshgrid(x, y)

# create a horizontal (vx) and vertical (vy) value at each grid point
vx = np.zeros(X.shape)
vy = np.zeros(X.shape)

# create a sparse set of vectors (x, y components) at a few (row, col) locations
vectors = np.array([[5.,7.],[-5.5,-3.3],[-7.5,3.]])
coords = np.array([[20,20],[180,50],[20, 180]])

# place each vector on the grid, dividing its coordinates by the grid spacing (10)
for i,c in enumerate(coords):
    vx[c[0]//10][c[1]//10]=vectors[i,0]
    vy[c[0]//10][c[1]//10]=vectors[i,1]

# creating plot
fig, ax = plt.subplots(figsize =(5, 5))
ax.quiver(X, Y, vx, vy,angles='xy', scale_units='xy', scale=1)
 
ax.axis([0, 250, 0, 250])
ax.set_aspect('equal')
 
# show plot
plt.show()

This plot shows the vectors we just created, drawn at their coordinates. If we know that an image needs to be adjusted by these vectors at these locations, then it makes sense to smoothly interpolate the vectors for the points in between. There are a number of algorithms we can use for this, and many tools provide implementations. SciPy has a linear interpolator, `LinearNDInterpolator`, that we can use here.

We can treat each axis separately and follow these steps:
* Create a grid of values that corresponds to the locations of the pixels in the image that we want to warp, using linspace and meshgrid
* Create an interpolator, using the vectors we defined and the grid
* Generate the vectors for the whole grid
* Plot them out


To make things a little clearer, we can create a function that maps each vector's direction to a specific color (hue).

In [None]:
import matplotlib.colors

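def vector_to_rgb(x,y):
    # direction of the vector in radians, in the range [-pi, pi]
    angle = np.arctan2(y,x)

    # normalize the angle to [0, 2*pi); the modulo already makes it non-negative
    angle = angle % (2 * np.pi)

    # map the direction to a hue; saturation and value are fixed
    return matplotlib.colors.hsv_to_rgb((angle / 2 / np.pi, 1, 0.5))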

In [None]:
from scipy.interpolate import LinearNDInterpolator

# interpolate the x-axis values
interpx = LinearNDInterpolator(list(coords), vectors[:,0])
Zx = interpx(Y, X)

# interpolate the y-axis values
interpy = LinearNDInterpolator(list(coords), vectors[:,1])
Zy = interpy(Y, X)

# creating plot
fig, ax = plt.subplots(figsize =(5, 5))
c1 = np.array(list(map(vector_to_rgb, Zx.flatten(),Zy.flatten())))
ax.quiver(X, Y, Zx, Zy,angles='xy', scale_units='xy', scale=1, color=c1)
 
ax.axis([0, 255, 0, 255])
ax.set_aspect('equal')
 
# show plot
plt.show()

You will notice that the interpolation is only done for the points that lie within the region defined by the vector locations (the convex hull of the anchor points); outside it, `LinearNDInterpolator` returns NaN by default. This raises the question: what should happen to the rest of the points? They could be extrapolated, but this can give unexpected results, since there is not enough information for points that lie outside the interpolation region. There are a few approaches that could be taken here:
* Set each pixel to the vector of its nearest neighbour
* Set each vector at the extremities of the image to 0,0 and interpolate the unspecified values
* Use the gradient of the vector to extend the values to the borders
* Set all undefined values to a fixed value
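The first option (nearest-neighbour fill) can be combined with linear interpolation inside the hull. A minimal sketch, using SciPy's `LinearNDInterpolator` for points inside the convex hull of the anchors and `NearestNDInterpolator` for the NaNs it leaves outside; the helper name `interp_with_nearest_fill` is hypothetical and, as in the notebook cells, anchor coordinates are stored as (row, col) and the interpolators are evaluated at `(Y, X)`.

```python
import numpy as np
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator


def interp_with_nearest_fill(points, values, rows, cols):
    """Linear interpolation inside the convex hull, nearest anchor outside it."""
    out = LinearNDInterpolator(points, values)(rows, cols)  # NaN outside the hull
    outside = np.isnan(out)
    if outside.any():
        nearest = NearestNDInterpolator(points, values)
        out[outside] = nearest(rows[outside], cols[outside])
    return out


# the same three anchors (row, col) and x-components used in the cells above
anchor_rc = np.array([[20., 20.], [180., 50.], [20., 180.]])
anchor_vx = np.array([5., -5.5, -7.5])

ticks = np.linspace(0, 250, num=26)
Xg, Yg = np.meshgrid(ticks, ticks)  # Xg varies along columns, Yg along rows
Zx_filled = interp_with_nearest_fill(anchor_rc, anchor_vx, Yg, Xg)
print(np.isnan(Zx_filled).any())
```

The trade-off is that nearest-neighbour filling produces flat patches with abrupt edges where the Voronoi regions of the outer anchors meet.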

There are pros and cons to each of these approaches, and it may not matter too much in the real world, since you would probably have defined anchor points within all the regions of the image that actually matter; anything outside them is probably background of less relevance. Nevertheless, it's useful to understand the implications, because some of the choices above could lead to unexpected results. Let's try a couple of options to illustrate this. First, we will set the vectors at the corners of the image to 0 and see what happens:

In [None]:
# create some vectors at certain points
vectors = np.array([[5.,7.],[-5.5,-3.3],[-7.5,3.],[0,0],[0,0],[0,0], [0,0]])
coords = np.array([[20,20],[180,50],[20, 180],[0,0],[0,255],[255,0], [255,255]])

interpx = LinearNDInterpolator(list(coords), vectors[:,0])
Zxi = interpx(Y, X)

interpy = LinearNDInterpolator(list(coords), vectors[:,1])
Zyi = interpy(Y, X)

# creating plot
fig, ax = plt.subplots(figsize =(5, 5))
c1 = np.array(list(map(vector_to_rgb, Zxi.flatten(),Zyi.flatten())))
ax.quiver(X, Y, Zxi, Zyi,angles='xy', scale_units='xy', scale=1, color=c1)
 
ax.axis([0, 255, 0, 255])
ax.set_aspect('equal')
 
# show plot
plt.show()

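There are many different interpolation methods, each with its own characteristics. For example, the cell below uses SciPy's `CloughTocher2DInterpolator`, a piecewise cubic scheme that produces a smooth (continuously differentiable) surface rather than the piecewise-planar result of linear interpolation.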

In [None]:
from scipy.interpolate import CloughTocher2DInterpolator as CT
# reuse the vectors and coordinates (including the zeroed corners) from the previous cell

interpx = CT(list(coords), vectors[:,0])
Zxi = interpx(Y, X)

interpy = CT(list(coords), vectors[:,1])
Zyi = interpy(Y, X)

# creating plot
fig, ax = plt.subplots(figsize =(5, 5))
c1 = np.array(list(map(vector_to_rgb, Zxi.flatten(),Zyi.flatten())))
ax.quiver(X, Y, Zxi, Zyi,angles='xy', scale_units='xy', scale=1, color=c1)
 
ax.axis([0, 255, 0, 255])
ax.set_aspect('equal')
 
# show plot
plt.show()

Of course, estimating the deformation field is only part of the registration problem; the other part is to actually deform the source image to match the destination. To do this we need a warp function. Again, there are many algorithms that can be used, such as those in scikit-image's `skimage.transform` module.
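Two details of scikit-image's `warp` are easy to trip over: the transform passed to it is treated as an *inverse* map (from output coordinates back to input coordinates), and transform coordinates are (x, y) = (col, row), the opposite of NumPy's (row, col) indexing. This small sketch, using a pure translation `AffineTransform` on a toy image, shows both.

```python
import numpy as np
from skimage.transform import AffineTransform, warp

img = np.zeros((32, 32))
img[10, 20] = 1.0  # a single bright pixel at row 10, column 20

# warp() uses the transform as an *inverse* map: each output pixel (x, y) =
# (col, row) is filled by sampling the input at tform((x, y)).
tform_demo = AffineTransform(translation=(5, 0))  # x -> x + 5, y unchanged

# order=0 (nearest neighbour) keeps the single pixel crisp
shifted = warp(img, tform_demo, order=0, preserve_range=True)

# output column c reads input column c + 5, so the pixel moves 5 columns left
print(np.argwhere(shifted == 1.0))
```

The same conventions apply to the `PiecewiseAffineTransform` in the next cell: the source and destination control points are given in (col, row) order.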

In [None]:
from timeit import default_timer as timer
import matplotlib.pyplot as plt
from skimage.transform import PiecewiseAffineTransform, warp
    
step_size = 20
vectors = np.array([[3.0,1.0],[-5.,-1.3],[-3.5,8.3],[0,0],[0,0],[0,0], [0,0]])
coords = np.array([[20,20],[180,50],[20, 180],[0,0],[0,255],[255,0], [255,255]])

# Creating grid
x = np.linspace(0, 255, num=step_size)
y = np.linspace(0, 255, num=step_size)
X, Y = np.meshgrid(x, y)

interpx = LinearNDInterpolator(list(coords), vectors[:,0])
Zxi = interpx(Y, X)

interpy = LinearNDInterpolator(list(coords), vectors[:,1])
Zyi = interpy(Y, X)

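# skimage transforms use (x, y) = (col, row) ordering for control points
src = np.column_stack((X.reshape(-1), Y.reshape(-1)))

# add the interpolated offsets: Zxi shifts the columns (x), Zyi shifts the rows (y)
dst_cols = X + Zxi
dst_rows = Y + Zyi

dst = np.column_stack([dst_cols.reshape(-1), dst_rows.reshape(-1)])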

tform = PiecewiseAffineTransform()

start = timer()
tform.estimate(src, dst)
end = timer()
t = end-start
print("cpu estimate took {}s".format(t))
start = timer()
out = warp(imgrid, tform, output_shape=imgrid.shape[:2])
end = timer()
t = end-start
print("cpu warp took {}s".format(t))

fig, ax = plt.subplots()
ax.imshow(out)
ax.axis((0, 255, 255, 0))
plt.show()

Next we will create a fictitious transformation and apply it to a real image to see the problem we might be faced with. We will apply a small rotation to each coordinate and add some random noise as well. We will then sample some points from this function, interpolate a dense deformation field from them and apply it to the H&E image from a pair of H&E and DAPI images that have already been registered. This gives us a 'known' transformation that we can try to detect and correct.

In [None]:
# generate a grid
y = np.linspace(-3, 3)
x = np.linspace(-3, 3)
X, Y = np.meshgrid(x, y)

# vector direction based on position
p = np.arctan2(X,Y)
d = np.sqrt(X**2 + Y**2)
# strength of vector based on distance from centre
u = np.cos(p)*d 
v = -np.sin(p)*d 

fig, ax = plt.subplots(figsize =(5, 5))
c1 = np.array(list(map(vector_to_rgb, u.flatten(),v.flatten())))
ax.quiver(X, Y, u, v, angles='xy', scale_units='xy', scale=3, color=c1)
 
ax.axis([-3, 3, -3, 3])
ax.set_aspect('equal')
 
# show plot
plt.show()

You can see that the further a coordinate is from the centre, the more it is offset. Obviously we could do this with a simple rotation, but where's the fun in that? Anyway, this is just one part of the transformation.
To add noise, we can just add a small random number to each coordinate's x and y values.
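It is worth noting what the rotation cell above actually computes. Because `np.arctan2(X, Y)` takes X as its first argument, sin(p) = X/d and cos(p) = Y/d, so (u, v) = (Y, -X): every vector is perpendicular to its position vector, with a length equal to its distance from the centre. The sketch below checks this numerically; the helper name `rotation_field` is hypothetical, and the optional noise follows the notebook's `np.random.rand(...) * g` recipe (uniform and non-negative, not zero-mean).

```python
import numpy as np


def rotation_field(X, Y, noise=0.0, rng=None):
    """Rotational displacement field as built in the notebook, plus optional noise."""
    p = np.arctan2(X, Y)        # note the (X, Y) argument order
    d = np.hypot(X, Y)          # distance from the centre
    u = np.cos(p) * d           # equals Y
    v = -np.sin(p) * d          # equals -X
    if noise > 0:
        rng = np.random.default_rng() if rng is None else rng
        u = u + rng.uniform(0.0, noise, size=u.shape)
        v = v + rng.uniform(0.0, noise, size=v.shape)
    return u, v


Xr, Yr = np.meshgrid(np.linspace(-3, 3, 7), np.linspace(-3, 3, 7))
ur, vr = rotation_field(Xr, Yr)
print(np.allclose(ur, Yr), np.allclose(vr, -Xr))
```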

In the more realistic case below, we sample 200 random points, compute their (noisy) displacement vectors as outlined above and interpolate them onto a 100 x 100 grid.

In [None]:
from scipy.interpolate import griddata

# level of noise
g=0.3

#number of points
n=200

# create grid
grid_x, grid_y = np.meshgrid(np.linspace(-3, 3, 100),
                             np.linspace(-3, 3, 100), indexing='ij')

# create some points to compute
points = np.random.rand(n,2)*6
points = points-3.
noise = np.random.rand(n,2)*g

#for each point compute a vector
p = np.arctan2(points[:,0],points[:,1])
d = np.sqrt(points[:,0]**2 + points[:,1]**2)
u = np.cos(p)*d + noise[:,0]
v = -np.sin(p)*d + noise[:,1]

# interpolate the vector for the whole grid from the points
print(p.shape,u.shape,grid_x.shape,grid_y.shape)
plot_x = griddata(points, u, (grid_x, grid_y), method='linear',fill_value=0.0)
plot_y = griddata(points, v, (grid_x, grid_y), method='linear',fill_value=0.0)

# color encodes direction (hue); saturation and value are fixed
c1 = np.array(list(map(vector_to_rgb, plot_x.flatten(), plot_y.flatten())))

fig, ax = plt.subplots(figsize =(10, 10))
ax.quiver(grid_x, grid_y, plot_x, plot_y,angles='xy', scale_units='xy', scale=10.0,pivot='mid',color=c1)
 
ax.axis([-3.0, 3.0, -3.0, 3.0])
ax.set_aspect('equal')
 
# show plot
plt.show()

Of course, the more grid points we need to interpolate, the slower the computation. To get a feel for this, we can vary the number of data points and plot the timings.

In [None]:
tp = 8
n = 80
times = [0] * tp
ns = [0] * tp

for i in range(tp):
    
    grid_x, grid_y = np.meshgrid(np.linspace(-3, 3, n),
                             np.linspace(-3, 3, n*2), indexing='ij')
    
    # create some points to compute
    points = np.random.rand(n,2)*6
    points = points-3.
    noise = np.random.rand(n,2)*g

    #for each point compute a vector
    p = np.arctan2(points[:,0],points[:,1])
    d = np.sqrt(points[:,0]**2 + points[:,1]**2)
    u = np.cos(p)*d + noise[:,0]
    v = -np.sin(p)*d + noise[:,1]

    start = timer()
    # interpolate the vector for the whole grid from the points
    plot_x = griddata(points, u, (grid_x, grid_y), method='cubic')
    plot_y = griddata(points, v, (grid_x, grid_y), method='cubic')

    end = timer()
    times[i] = end-start
    ns[i]=n
    print(n,times[i])
    n *=2

plt.plot(ns, times, '-ok')
plt.xlabel("n (grid is n x 2n points)")
plt.ylabel("Compute time (s)")
plt.yscale("linear")

plt.show()

This is actually pretty good. Each time n doubles, the grid has four times as many points, and the compute time also grows by roughly a factor of 4, which is what we would expect; the scaling doesn't get any worse, even at the largest grid (10,240 x 20,480 points). Nevertheless, ~50 seconds is quite a long time and precludes any sort of real-time interaction, especially if we want to work with larger images.
Let's try some other ways of computing the same thing, i.e. a per-pixel vector that could be used to warp one image to align with another.
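Rather than eyeballing the ratios, the scaling can be summarised by fitting a straight line in log-log space. A minimal sketch (the helper name `scaling_exponent` is hypothetical): after running the timing cell above, `scaling_exponent(ns, times)` gives the empirical exponent k in t ∝ n^k. Since the grid has n x 2n points, k ≈ 2 means the time grows roughly linearly with the number of grid points.

```python
import numpy as np


def scaling_exponent(sizes, times):
    """Fit log(t) = k * log(n) + c and return the exponent k."""
    k, _ = np.polyfit(np.log(np.asarray(sizes, dtype=float)),
                      np.log(np.asarray(times, dtype=float)), 1)
    return k


# synthetic check: times that grow exactly with n**2
demo_sizes = np.array([80, 160, 320, 640])
demo_times = 1e-3 * demo_sizes**2.0
print(scaling_exponent(demo_sizes, demo_times))
```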

Next we will load the H&E slide from the Xenium sample dataset. This dataset is associated with the preprint "High resolution mapping of the breast cancer tumor microenvironment using integrated single cell, spatial and in situ analysis of FFPE tissue", posted to bioRxiv November 03, 2022, in which a single breast cancer FFPE tissue block (Sample #1) was assayed with a trio of complementary technologies (Chromium, Visium, Xenium). In the manuscript revision (submitted June 23, 2023), a second breast cancer block (Sample #2) was analyzed with Xenium. See the preprint for full details on methods and results.
In this instance we will use tifffile to load the image, but we could also use cuCIM, the RAPIDS image loading and processing library, which can speed things up for .svs and .tiff images.

In [None]:
import cupy
from matplotlib import pyplot as plt
import numpy as np
import tifffile as tiff 

he_file1 = "./data/he-registered-to-dapi-shlee.tif"
he_img = np.array(tiff.imread(he_file1))
h1 = he_img.shape[0]
w1 = he_img.shape[1]

# Show the image
plt.figure(figsize=(10,10))
plt.imshow(he_img)
plt.title('he-registered-to-dapi-shlee.tif')
print("Width = {}, Height = {}, Channels = {}".format(w1, h1, he_img.shape[2]))
plt.show()

Once the image is loaded, we can apply a simulated deformation to it by generating a set of random points and then adding some rotation plus some noise (as before). The speed of the interpolation may not seem directly relevant, but it actually might be: in some cases, images are aligned by defining anchor points on pairs of corresponding images and then deforming one of them to match the other. To do this smoothly, you need to interpolate each pixel's offset vector from the sparse, manually defined points.

In [None]:
from scipy.ndimage import map_coordinates

grid_y, grid_x = np.meshgrid(np.linspace(0, h1-1, h1),
                             np.linspace(0, w1-1, w1), indexing='ij')
g=30. # level of noise
n=200 #number of points

y = np.random.uniform(0, h1, size=n)
x = np.random.uniform(0, w1, size=n)
noise = (np.random.rand(n+4,2)-0.5)*g

#add points at the extremities to avoid edge artefacts
y = np.concatenate((y,[0],[0],[h1-1],[h1-1]))
x = np.concatenate((x,[0],[w1-1],[0],[w1-1]))

# compute an angle and distance from the centre
p = np.arctan2(w1//2-x,h1//2-y)
d = np.sqrt((h1//2-y)**2 + (w1//2-x)**2)
# create a horizontal and vertical vector
u = np.cos(p)*d 
v = -np.sin(p)*d

# scale the offsets so that the largest is max_offset pixels, then add noise
umax = np.max(u)
vmax = np.max(v)
max_offset = 120
u = (u/umax) * max_offset + noise[:,0]
v = (v/vmax) * max_offset + noise[:,1]

# interpolate the vectors at the rest of the points
plot_x = griddata((y,x), u, (grid_y, grid_x), method='linear',fill_value=0.0)
plot_y = griddata((y,x), v, (grid_y, grid_x), method='linear',fill_value=0.0)

gx = grid_x + plot_x
gy = grid_y + plot_y
coords = np.stack((gy,gx),axis=0)
coords = np.concatenate((coords,np.zeros((1,h1,w1))),axis=0)
coords = np.stack((coords,coords,coords), axis=3)
# the third coordinate selects the color channel (0, 1, 2) sampled for each output pixel
coords[2,:,:,0]=0.
coords[2,:,:,1]=1.
coords[2,:,:,2]=2.

start = timer()
host_image = map_coordinates(he_img, coords)
print("warp image - ", timer()-start)
plt.figure(figsize=(10,10))
plt.imshow(host_image)

Okay, so that worked, but let's see if we can speed up the warp using a CUDA version of the `map_coordinates` function from `cupyx.scipy.ndimage`. Notice that we need to convert all of the NumPy arrays to CuPy arrays, which copies them onto the GPU.
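Building a (3, H, W, 3) coordinate array just to carry the channel index triples the memory used for coordinates. A possible alternative, sketched below under the assumption that the same displacement applies to every channel, is to warp each channel separately with a (2, H, W) coordinate array. The helper `warp_rgb` is hypothetical; it takes the array module and ndimage module as parameters because `cupyx.scipy.ndimage.map_coordinates` mirrors SciPy's signature, so the same code could be called with `xp=cp, ndimage=cupyx.scipy.ndimage` on GPU arrays. It is shown here with NumPy/SciPy so it runs anywhere.

```python
import numpy as np
import scipy.ndimage as ndi


def warp_rgb(image, dy, dx, xp=np, ndimage=ndi, order=1):
    """Warp an (H, W, C) image so that output[r, c] = image[r + dy, c + dx].

    dy and dx are (H, W) per-pixel offsets in rows and columns.
    """
    h, w = image.shape[:2]
    rows, cols = xp.meshgrid(xp.arange(h), xp.arange(w), indexing="ij")
    coords = xp.stack((rows + dy, cols + dx))  # shape (2, H, W)
    channels = [ndimage.map_coordinates(image[..., ch], coords, order=order)
                for ch in range(image.shape[2])]
    return xp.stack(channels, axis=-1)


rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(6, 8, 3), dtype=np.uint8)
zero = np.zeros((6, 8))

same = warp_rgb(rgb, zero, zero)                 # zero offset: identity
shifted_rgb = warp_rgb(rgb, zero, zero + 1.0)    # sample one column to the right
```

Note that `order=1` (linear) is used here; the notebook cells rely on the default `order=3`, which is smoother but slower.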

In [None]:
from cupyx.scipy.ndimage import map_coordinates as cu_map_coordinates
import cupy as cp

cu_plot_x = cp.array(plot_x)
cu_plot_y = cp.array(plot_y)

gx = cp.array(grid_x) + cu_plot_x
gy = cp.array(grid_y) + cu_plot_y
coords = cp.stack((gy,gx),axis=0)
coords = cp.concatenate((coords,cp.zeros((1,h1,w1))),axis=0)
coords = cp.stack((coords,coords,coords), axis=3)
coords[2,:,:,0]=0.
coords[2,:,:,1]=1.
coords[2,:,:,2]=2.

start = timer()
cu_he_img = cp.array(he_img)
warped_image = cu_map_coordinates(cu_he_img, coords)
host_image = warped_image.get()
end = timer()
print("warp image - ",end-start)

plt.figure(figsize=(10,10))
plt.imshow(host_image)

As you will have noticed, that produced a huge speed-up! From over a minute to less than a second!

Let's visualize the image distortion by setting the blue channel to the amount of offset in the x direction and the red channel to the offset in the y direction.

In [None]:
vec_img = np.zeros((h1,w1,3),dtype=np.uint8)
mx = np.max(plot_x.flatten())
mn = np.min(plot_x.flatten())

vec_img[:,:,2] = (plot_x-mn)*255/(mx-mn)
vec_img[:,:,1] = 0
mx = np.max(plot_y.flatten())
mn = np.min(plot_y.flatten())
vec_img[:,:,0] = (plot_y-mn)*255/(mx-mn)

plt.figure(figsize =(10, 10))

# show plot
plt.imshow(vec_img)

And let's plot out the difference between the original and the warped image

In [None]:
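# cast to a signed type first: subtracting uint8 arrays wraps around below 0
diff = np.abs(host_image.astype(np.int16) - he_img.astype(np.int16)).astype(np.uint8)
plt.figure(figsize=(10,10))
plt.imshow(diff)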

You'll notice that the majority of the time was spent on the interpolation (griddata) rather than the actual image warp. At the moment we are using a CPU-based interpolation. Although there are some GPU equivalents, there are none in cucim.skimage that can accept sparse points for 2D datasets. Other libraries can do this sort of thing (e.g. OpenCV has many GPU-accelerated routines), but we'll look at some other techniques which might be useful to know.

As an experiment, we can create our own interpolation algorithm and see what sort of performance we can get using a couple of different techniques. First, we find the three sample points nearest to each pixel in the image. RAPIDS has some great solutions for this sort of computation, and RAFT provides just such an algorithm.
Although there are even faster heuristics that could be used here, even the brute-force algorithm should be well within our latency budget.

What we do below is create a set of points representing every pixel in the image, send the set of deformation points over to the GPU, and then let RAFT do its thing. Given that there are >41 million points to evaluate against ~200 possible nearest neighbors, this could take some time...
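If you want to double-check the GPU result on a small subset, SciPy's `cKDTree` provides a CPU k-nearest-neighbour query. A sketch (the helper `knn_cpu` is hypothetical): comparing the first few thousand rows of `cp.asarray(neighbors)` against `knn_cpu(data.get(), points[:5000].get())` should agree, assuming both return neighbours sorted nearest-first; rows where two anchors are (almost) equidistant may legitimately differ in order.

```python
import numpy as np
from scipy.spatial import cKDTree


def knn_cpu(anchors, queries, k=3):
    """Indices of the k nearest anchors to each query point, nearest first.

    anchors: (M, 2) array of sample points; queries: (N, 2) array of pixels.
    """
    _, idx = cKDTree(anchors).query(queries, k=k)
    return idx


# cross-check against a brute-force NumPy computation on random data
rng = np.random.default_rng(1)
demo_anchors = rng.uniform(0, 100, size=(50, 2))
demo_queries = rng.uniform(0, 100, size=(200, 2))
dists = np.linalg.norm(demo_queries[:, None, :] - demo_anchors[None, :, :], axis=-1)
brute = np.argsort(dists, axis=1)[:, :3]
print(np.array_equal(knn_cpu(demo_anchors, demo_queries), brute))
```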

In [None]:
%%time
import cupy as cp
from pylibraft.neighbors.brute_force import knn

# Create a point for every pixel in the image
x1 = cp.linspace(0, w1-1, num=w1)
y1 = cp.linspace(0, h1-1, num=h1)
X,Y = cp.meshgrid(x1, y1)

# Formulate the points as an array of x,y coordinates (e.g. shape = (height * width,2))
points = cp.stack((Y.flatten(),X.flatten()),axis=-1)
points = points.astype(cp.float32)

# Add the randomly created sample points to a GPU array
data = cp.stack((cp.array(y,dtype=cp.float32),cp.array(x,dtype=cp.float32)),axis=1)
k = 3 # find nearest three sample points from each pixel location
_, neighbors = knn(data, points, k)

print(cp.asarray(neighbors).shape)

In fact, it only takes a second or two to find the 3 nearest neighbors for all 41,151,487 pixels in the image, which is pretty impressive!

As a sanity check on the nearest-neighbor computation, we can plot the nearest-neighbor indexes as RGB values (this works because there are fewer than 256 sample points, so each index fits in a uint8, and each pixel has exactly three neighbors). If it's working as expected, we should see something resembling a Voronoi diagram.

In [None]:
%%time
nn_image = np.zeros((h1,w1,3),dtype=np.uint8)
i=0

ns = np.array(cp.asarray(neighbors).get())

# use row/col (not y/x) so the anchor coordinate arrays x and y are not overwritten
for row in range(h1):
    for col in range(w1):
        nn_image[row,col]=ns[i]
        i+=1
        
plt.imshow(nn_image)

So that took a little while to compute. One of the first tools we can try to make it faster is [Numba](https://numba.pydata.org/). Numba supports a few different underlying targets, including both CPUs and GPUs.

In [None]:
%%time
from numba import jit

im = np.zeros((h1,w1,3),dtype=np.uint8)
ns = np.array(cp.asarray(neighbors).get())

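@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def fast_image(a,b): # Function is compiled to machine code when called the first time
    # take the dimensions from the output array rather than from globals,
    # which Numba would freeze at compile time
    h, w = a.shape[0], a.shape[1]
    for y in range(h):
        for x in range(w):
            i = y * w + x
            a[y,x]=b[i]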
        
fast_image(im,ns)
plt.imshow(im)

That was definitely quicker. In this case Numba simply compiled the loop to machine code for the CPU, but it can also target the GPU. You can explore this in more detail [here](https://numba.pydata.org/numba-doc/latest/cuda/index.html).
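For this particular loop, compilation isn't strictly necessary: because the pixels were flattened in row-major order (index = row x width + col), which is NumPy's default reshape order, the whole image can be produced with a single reshape. A sketch with a hypothetical helper name:

```python
import numpy as np


def neighbors_to_image(ns, h, w):
    """Turn (h*w, k) per-pixel neighbour indices into an (h, w, k) uint8 image.

    Assumes pixel i corresponds to row i // w, column i % w and that every
    index fits in a uint8 (i.e. fewer than 256 anchor points).
    """
    return ns.reshape(h, w, -1).astype(np.uint8)


demo_ns = np.arange(2 * 3 * 3).reshape(6, 3)  # 6 pixels, 3 neighbours each
demo_img = neighbors_to_image(demo_ns, 2, 3)
print(demo_img[1, 2])
```

In the notebook this would be `plt.imshow(neighbors_to_image(ns, h1, w1))`; the Numba version remains the better illustration for loops that cannot be expressed as a reshape.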

Now we can weight the vectors from the three nearest points using interpolation. This will provide us with the grid we need to perform the image warp.

To apply the correct weight at each point, we can linearly interpolate the values by treating the triangle formed by the three nearest neighbours as the 2D projection of a triangle lying on a plane, in which the vector values form the z dimension. Using the three points, with their vector values as the z coordinate, we can get the equation of the plane, which then allows us to plug in new x, y values and obtain the z value.

<img src="./images/interp.png" alt="Interpolation" style="width: 300px;"/>
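In symbols, with the three neighbours A, B, C written as 3D points (x, y, z), where z is one component of the displacement vector, the normal and the interpolated value are

$$\mathbf{n} = (A - B) \times (B - C), \qquad z(x, y) = z_A - \frac{n_x (x - x_A) + n_y (y - y_A)}{n_z}.$$

This gives the same value as barycentric (linear) interpolation inside the triangle. The sketch below checks that on one triangle with NumPy; the function names are hypothetical, and vertices are stored here as (x, y), whereas the notebook's `data` array stores (y, x).

```python
import numpy as np


def plane_interp(tri_xy, tri_z, x, y):
    """Value at (x, y) of the plane through three (x, y, z) points."""
    A, B, C = (np.append(tri_xy[i], tri_z[i]) for i in range(3))
    n = np.cross(A - B, B - C)  # normal of the plane through A, B, C
    return A[2] - (n[0] * (x - A[0]) + n[1] * (y - A[1])) / n[2]


def barycentric_interp(tri_xy, tri_z, x, y):
    """The same quantity via barycentric weights, for comparison."""
    T = np.column_stack((tri_xy[0] - tri_xy[2], tri_xy[1] - tri_xy[2]))
    l1, l2 = np.linalg.solve(T, np.array([x, y]) - tri_xy[2])
    return l1 * tri_z[0] + l2 * tri_z[1] + (1 - l1 - l2) * tri_z[2]


tri_xy = np.array([[20.0, 20.0], [180.0, 50.0], [20.0, 180.0]])
tri_z = np.array([5.0, -5.5, -7.5])
print(plane_interp(tri_xy, tri_z, 60.0, 70.0),
      barycentric_interp(tri_xy, tri_z, 60.0, 70.0))
```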

In [None]:
def get_z(x, y, n, a):
    # evaluate the plane through point a with normal n at (x, y);
    # a and n are ordered (y, x, value), matching how `data` stores the anchors
    return a[2] - ((n[0] * (y - a[0]) + n[1] * (x - a[1])) / n[2])

def apply_weighted_avg(u, v, nns, x, y):
    # nns holds the (y, x) locations of the 3 nearest anchors;
    # add the vector component as the z dimension
    nnu = cp.column_stack((nns, u))
    
    # get the vectors along two edges of the triangle formed by the 3 points
    ab_u = nnu[0] - nnu[1]
    bc_u = nnu[1] - nnu[2]
    # compute the cross product (the normal of the plane through the points)
    n_u = cp.cross(ab_u, bc_u)
    
    nnv = cp.column_stack((nns, v))
    ab_v = nnv[0] - nnv[1]
    bc_v = nnv[1] - nnv[2]
    n_v = cp.cross(ab_v, bc_v)
   
    ua = get_z(x, y, n_u, nnu[0])
    va = get_z(x, y, n_v, nnv[0])
   
    # return the interpolated vectors (u and v) at this x and y coordinate
    return cp.stack((ua, va))

# only evaluate a small number of pixels to gauge the speed of this approach
n_test = 1000

# assign an array to return the result into (one row per evaluated pixel)
grid = cp.zeros((n_test, 2), dtype=cp.float32)

# convert to GPU arrays
cneighbors = cp.asarray(neighbors)
cu = cp.array(u)
cv = cp.array(v)

start = timer()

# the pixels were flattened row-major (index = row * w1 + col) when the
# nearest neighbors were computed, so walk them in that same order
for ix in range(n_test):
    row, col = divmod(ix, w1)
    nn = cneighbors[ix]
    grid[ix] = apply_weighted_avg(cu[nn], cv[nn], data[nn], col, row)
        
t = timer() - start
                                                
print("Computing {} points took {} seconds".format(n_test, t))

Oh, that could take quite a while, given that we only computed a tiny fraction of the full image. Luckily, we can speed this up by writing a kernel (a GPU function). There are several ways of doing this, from C++ CUDA or PyCUDA to a Python-based syntax like Numba's. Since we are already using CuPy, let's see how we might do that using CuPy's kernel methods.

The key concept to grasp with writing kernels is that we don't loop through all the function calls in our host (CPU) code. Instead we define a function, pass (references to) the whole arrays we want to process and let the GPU figure out how to schedule the execution using the number of threads and thread blocks we specify. Within the kernel we also need to map the current thread of execution to a specific element within the arrays we have passed. Note also that we don't return anything - we pass the output in as a parameter instead. This is always the case with kernel functions.

One other thing to note with this type of CuPy raw kernel: we can't use array functions inside the kernel, which is why we compute the cross product 'by hand'.
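The `for i in range(tid, size, ntid)` pattern in the kernel is often called a grid-stride loop: thread `tid` handles elements tid, tid + ntid, tid + 2·ntid and so on, where ntid = gridDim.x x blockDim.x is the total number of threads launched. This plain-Python simulation (the function names are hypothetical; nothing runs on the GPU) checks that, for a given launch configuration, every element is visited exactly once, whether there are more elements than threads or fewer.

```python
def grid_stride_indices(tid, ntid, size):
    """Indices visited by one thread in `for i in range(tid, size, ntid)`."""
    return list(range(tid, size, ntid))


def check_coverage(blocks, threads_per_block, size):
    """Simulate a [blocks, threads_per_block] launch; count visits per element."""
    ntid = blocks * threads_per_block          # gridDim.x * blockDim.x
    visits = [0] * size
    for block in range(blocks):
        for thread in range(threads_per_block):
            tid = block * threads_per_block + thread  # blockIdx.x * blockDim.x + threadIdx.x
            for i in grid_stride_indices(tid, ntid, size):
                visits[i] += 1
    return visits


print(all(v == 1 for v in check_coverage(4, 8, 100)))
```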


In [None]:
# import the Just-In-Time compiler for cupy
from cupyx import jit
from cupyx.profiler import benchmark

# Let the compiler know that we want to create a kernel from this function
@jit.rawkernel()
def elementwise_uv(uv, u, v, n, d, p, size):
    # get the current thread id 
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    # get the total number of threads available
    ntid = jit.gridDim.x * jit.blockDim.x
    
    # if there are more elements to process than the total number of threads 
    # available then we need to loop over the array a few times (a grid-stride loop)
    for i in range(tid, size, ntid):
        if i<size:
            # index to neighbors
            ia = n[i,0] 
            ib = n[i,1]
            ic = n[i,2]
            # current point
            yy = p[i,0]
            xx = p[i,1]

            # neighbor location (x,y)
            a = d[ia]
            b = d[ib]
            c = d[ic]

            # compute vector between pairs of the 3 vectors
            ab_u = u[ia] - u[ib]
            bc_u = u[ib] - u[ic]
            ab_v = v[ia] - v[ib]
            bc_v = v[ib] - v[ic]
            # compute vector between pairs of the 3 points
            ab_y = a[0] - b[0]
            ab_x = a[1] - b[1]
            bc_y = b[0] - c[0]
            bc_x = b[1] - c[1]

            # elementwise cross product computation
            xp_ux = (ab_y * bc_u) - (ab_u * bc_y)
            xp_uy = (ab_u * bc_x) - (ab_x * bc_u)
            xp_uz = (ab_x * bc_y) - (ab_y * bc_x)
            xp_vx = (ab_y * bc_v) - (ab_v * bc_y)
            xp_vy = (ab_v * bc_x) - (ab_x * bc_v)
            xp_vz = (ab_x * bc_y) - (ab_y * bc_x)

            # evaluate the plane z = z_a - (n_x*(x - x_a) + n_y*(y - y_a)) / n_z;
            # a[1] is the x coordinate of neighbor a and a[0] its y coordinate.
            # n_z is kept away from 0 to avoid division by zero for collinear neighbors
            if xp_uz>0:
                ua = u[ia]-((xp_ux*(xx-a[1])+xp_uy*(yy-a[0]))/max(xp_uz,0.001))
            else:
                ua = u[ia]-((xp_ux*(xx-a[1])+xp_uy*(yy-a[0]))/min(xp_uz,-0.001))
            
            if xp_vz>0:
                va = v[ia]-((xp_vx*(xx-a[1])+xp_vy*(yy-a[0]))/max(xp_vz,0.001))
            else:
                va = v[ia]-((xp_vx*(xx-a[1])+xp_vy*(yy-a[0]))/min(xp_vz,-0.001))
            
            # if the point lies outside the triangle formed by its neighbors the plane
            # extrapolates, so clamp the result to the range of the neighbor values
            ua = min(max(ua, min(u[ia],u[ib],u[ic])), max(u[ia],u[ib],u[ic]))
            va = min(max(va, min(v[ia],v[ib],v[ic])), max(v[ia],v[ib],v[ic]))

            # set the values in the results array
            uv[i,0] = cp.float32(va)
            uv[i,1] = cp.float32(ua)

grid2 = cp.zeros((points.shape),dtype=cp.float32)

cneighbors = cp.asarray(neighbors)
cu = cp.array(u)
cv = cp.array(v)

#print(benchmark(elementwise_uv[128, 1024], (grid2, cu, cv, cneighbors, data, points, grid2.shape[0]), n_repeat=10))
elementwise_uv[128, 1024](grid2, cu, cv, cneighbors, data, points, grid2.shape[0])  #  Numba style
grid2.get()
print("Completed")

Amazingly, this kernel speeds up the computation by several orders of magnitude! For precise timings you can uncomment the benchmark call, which times repeated calls to the kernel to give a more accurate performance figure.

If it actually works, then reversing the vectors and applying them to the warped H&E image should return it to something like its original shape. Strictly speaking, negating a displacement field only approximates its inverse (each offset is looked up at the warped location rather than the original one), but for a smooth deformation like this one the approximation should be reasonable. Some pixels will also have been lost because they were warped outside the bounds of the image, but let's try it.

First, though, we will plot the pixel offset map that this technique generated. As before, the blue and red channels represent the horizontal and vertical offsets.

In [None]:
coords = np.moveaxis(grid2.get().reshape(h1,w1,2),(0,1,2),(1,2,0))
mx = np.max(coords.flatten())
mn = np.min(coords.flatten())

vec_img = np.zeros((h1,w1,3),dtype=np.uint8)
vec_img[:,:,0] = (coords[0]-mn)*255/(mx-mn)
vec_img[:,:,1] = 0
vec_img[:,:,2] = (coords[1]-mn)*255/(mx-mn)

plt.figure(figsize =(10, 10))

# show plot
plt.imshow(vec_img)

The output shows that it is doing what is expected. Note that the edges of the nearest-neighbor regions are less smooth than in the offset map produced earlier with SciPy's griddata. The upside is that the anchor points remain exactly where they were defined. We could smooth the output in several ways, but smoothing may slightly move the anchor points, so there are pros and cons.
One method that could be explored would be to choose the four nearest points, compute three sets of interpolations and take the average. This is slightly more complex, but it ought to smooth things out. Something to try when you have time?
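One possible reading of that suggestion, sketched for a single pixel with NumPy (the function names are hypothetical, and a real version would go into the GPU kernel): take the four nearest anchors, form the three triangles that contain the nearest one, evaluate each triangle's plane at the pixel and average the results. Because every triangle includes the nearest anchor, the value at an anchor's own location is still reproduced exactly, which addresses the concern about smoothing shifting the anchor points.

```python
from itertools import combinations

import numpy as np


def plane_value(P, Z, x, y):
    """Value at (x, y) of the plane z = a*x + b*y + c through 3 points.

    P: (3, 2) points as (x, y); Z: (3,) values. Returns None if collinear.
    """
    M = np.column_stack((P, np.ones(3)))
    if abs(np.linalg.det(M)) < 1e-9:
        return None
    a, b, c = np.linalg.solve(M, Z)
    return a * x + b * y + c


def averaged_interp(P4, Z4, x, y):
    """Average the planes of the 3 triangles that include the nearest anchor.

    P4: (4, 2) the four nearest anchors, nearest first; Z4: (4,) their values.
    """
    vals = []
    for j, k in combinations(range(1, 4), 2):  # (1, 2), (1, 3), (2, 3)
        idx = [0, j, k]
        val = plane_value(P4[idx], Z4[idx], x, y)
        if val is not None:
            vals.append(val)
    return float(np.mean(vals)) if vals else float(Z4[0])


P4 = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
Z_lin = 2 * P4[:, 0] - 3 * P4[:, 1] + 1   # a linear field is reproduced exactly
print(averaged_interp(P4, Z_lin, 0.3, 0.4))
```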

In [None]:
coords = cp.moveaxis(cp.negative(grid2).reshape(h1,w1,2),(0,1,2),(1,2,0))
coords[1,:,:]+=cp.array(grid_x) 
coords[0,:,:]+=cp.array(grid_y)
print(coords.shape)
coords = cp.concatenate((coords,cp.zeros((1,h1,w1))),axis=0)

coords = cp.stack((coords,coords,coords), axis=3)
coords[2,:,:,0]=0.
coords[2,:,:,1]=1.
coords[2,:,:,2]=2.

start = timer()
unwarped_image = cu_map_coordinates(warped_image, coords)
host_image = unwarped_image.get()
end = timer()
print("warp image - ",end-start)

# plot the unwarped image
plt.figure(figsize=(10,10))
plt.imshow(host_image)

Now we can look at the difference between the original and the de-warped image

In [None]:
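# cast to a signed type first: subtracting uint8 arrays wraps around below 0
diff = np.abs(host_image.astype(np.int16) - he_img.astype(np.int16)).astype(np.uint8)
plt.figure(figsize=(10,10))
plt.imshow(diff)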

So the quality of the actual unwarping isn't too bad, but the point is that it was performed extremely quickly. The algorithm can certainly be improved, but we have sped things up without having to leave the comfort of our Python armchair :)

Okay, now we are going to load the DAPI channel image and resize it to match the H&E image. To do the resizing we will write our own routine, mainly to show another easy source of acceleration: Numba's @jit decorator. We could create kernels in a similar way to how we did previously with CuPy, but perhaps the best acceleration-to-effort ratio comes from one simple function decoration.
First, run the code cell below as it is, then run the next cell to load and process the whole DAPI image, and note how long it takes.
Then comment out the @jit line, rerun both cells and compare.

In [None]:
import math
import numpy as np
from numba import jit, prange

# TODO - Comment out the line below to prevent Numba from compiling a parallel version of this function
@jit(nopython=True, parallel=True)
def resize_dapi(input_image, new_height, new_width):
    """Bilinear resize of a 2-D image to (new_height, new_width)."""
    output_image = np.zeros((new_height, new_width), dtype=input_image.dtype)
    old_height, old_width = input_image.shape
    scale_factor_y = old_height/new_height
    scale_factor_x = old_width/new_width

    # Numba runs the outer prange loop in parallel across CPU threads
    # (without @jit, prange simply behaves like range)
    for new_y in prange(new_height):
        old_y = new_y * scale_factor_y
        y_fraction = old_y - math.floor(old_y)
        for new_x in range(new_width):
            # use the scale factor to find the source pixels
            old_x = new_x * scale_factor_x
            x_fraction = old_x - math.floor(old_x)

            # Sample four neighbouring pixels (clamped at the right and bottom edges):
            left_upper = input_image[math.floor(old_y), math.floor(old_x)]
            right_upper = input_image[math.floor(old_y), min(old_width - 1, math.ceil(old_x))]
            left_lower = input_image[min(old_height - 1, math.ceil(old_y)), math.floor(old_x)]
            right_lower = input_image[min(old_height - 1, math.ceil(old_y)), min(old_width - 1, math.ceil(old_x))]

            # Interpolate horizontally (x_fraction is the weight of the right-hand pixel):
            blend_top = (right_upper * x_fraction) + (left_upper * (1.0 - x_fraction))
            blend_bottom = (right_lower * x_fraction) + (left_lower * (1.0 - x_fraction))
            # Interpolate vertically (y_fraction is the weight of the lower row):
            final_blend = (blend_top * (1.0 - y_fraction)) + (blend_bottom * y_fraction)
            output_image[new_y, new_x] = final_blend

    return output_image

def normalize(x, high=255):
    """Linearly rescale x to the range [0, high]."""
    # Work in float64: integer arithmetic such as uint16 * 255 would overflow
    x = np.asarray(x, dtype=np.float64)
    mx = np.max(x)
    mn = np.min(x)
    if mx == mn:
        # Constant image: avoid dividing by zero
        return np.zeros_like(x)
    return (x - mn)*high/(mx-mn)
    

That cell defined the resizing and normalizing functions. Now we can load the DAPI channel image and apply them.

In [None]:
import numpy as np
import tifffile as tiff
from matplotlib import pyplot as plt
from timeit import default_timer as timer

input_file2 = "./data/morphology_mip.ome.3.invert.tif"
dp_img = np.array(tiff.imread(input_file2))

# Dimensions of the whole slide at full resolution
h2 = dp_img.shape[0]
w2 = dp_img.shape[1]

# Resize to the smaller of the two (the H & E size, h1 x w1) using bilinear interpolation.
# With @jit active, the first call also includes Numba's compilation time.
start = timer()
dp_img = resize_dapi(dp_img, h1, w1)
dp_img = normalize(dp_img, 255).astype(np.uint8)
end = timer()
print("resize and normalize - ", end - start)

# Show the image
plt.figure(figsize=(10,10))
plt.imshow(dp_img)
plt.title('morphology_mip.ome.3.invert.tif')
print("Original width = {}, height = {}".format(w2, h2))
plt.show()

So, although the acquisition modality is different, you can still see large-scale features that correspond to the H & E image.

What you should have observed is that the jit-compiled version runs quickly, whilst the non-jit version takes several minutes (feel free to interrupt that cell if you get bored). That's because, with the @jit decorator in place, Numba compiles the function to machine code. Because of parallel=True and prange, it also splits the outer loop across CPU threads. Without the decorator, every pixel is processed one at a time by the Python interpreter.
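A compiler is not the only route to this kind of speed-up. The same bilinear mapping can be written with NumPy array operations, so that the per-pixel loop runs inside NumPy's compiled code. The sketch below is not part of the original lab. The name `resize_bilinear_numpy` is hypothetical. It returns float64 rather than the input dtype, and it materialises several output-sized temporary arrays, so memory use grows with the output size.

```python
import numpy as np

def resize_bilinear_numpy(img, new_height, new_width):
    """Bilinear resize of a 2-D array using NumPy vectorisation instead of explicit loops."""
    old_height, old_width = img.shape
    # Source coordinate for every output row and column (same mapping as resize_dapi)
    ys = np.arange(new_height) * (old_height / new_height)
    xs = np.arange(new_width) * (old_width / new_width)
    y0 = np.floor(ys).astype(np.int64)
    x0 = np.floor(xs).astype(np.int64)
    # Neighbouring row/column, clamped at the bottom and right edges
    y1 = np.minimum(y0 + 1, old_height - 1)
    x1 = np.minimum(x0 + 1, old_width - 1)
    fy = (ys - y0)[:, None]  # column vector: weight of the lower row
    fx = (xs - x0)[None, :]  # row vector: weight of the right-hand column

    src = img.astype(np.float64)
    # np.ix_ gathers the four neighbours for every output pixel at once, shape (new_height, new_width)
    top = src[np.ix_(y0, x0)] * (1 - fx) + src[np.ix_(y0, x1)] * fx
    bottom = src[np.ix_(y1, x0)] * (1 - fx) + src[np.ix_(y1, x1)] * fx
    return top * (1 - fy) + bottom * fy

# Toy usage: upsample a horizontal ramp
ramp = np.tile(np.arange(4.0), (4, 1))
print(resize_bilinear_numpy(ramp, 4, 8)[0])
```

Vectorisation avoids a compiler dependency but runs on a single core. The Numba version can also use every CPU thread and keeps memory use down to the output array.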

Finally, we are going to play with another tool that can speed up potentially expensive operations by distributing the work across whatever compute resources are available. Here we will just use CPU threads, but the same approach works on a GPU cluster too, which is very powerful. The tool is Dask. It has many uses and techniques, but we will look at one simple example. Dask works out the dependencies between the various computations in your code and builds an execution graph, which can be displayed, as the following toy example shows.

Imagine that we want to sum a series of integers. Naively, you would iterate over the elements one at a time, adding each to a running total, so the run time grows linearly with the number of elements N. A better way is to add neighbouring pairs concurrently, level by level, until only one element is left. Given enough parallel workers, this needs only about log2(N) steps. By providing a few basic commands, you can let Dask work out the execution graph for you. Let's look at a concrete example.
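Before bringing Dask in, the pairwise scheme itself can be sketched in plain Python. This sketch runs sequentially. Its purpose is to show the level structure and that the number of levels grows as ceil(log2(N)). The function name `pairwise_sum` is hypothetical, and it assumes a non-empty input.

```python
def pairwise_sum(values):
    """Sum a non-empty sequence by adding neighbouring pairs level by level (a tree reduction).

    Returns (total, number_of_levels); the level count is ceil(log2(N)) for N values.
    """
    level = list(values)
    depth = 0
    while len(level) > 1:
        # Every pair in this level is independent, so these additions could run concurrently
        nxt = [level[i] + level[i + 1] for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:
            # Odd count: carry the last element up to the next level unchanged
            nxt.append(level[-1])
        level = nxt
        depth += 1
    return level[0], depth

total, depth = pairwise_sum(range(1, 17))
print(total, depth)  # 136 4
```

Sixteen values need four levels (16, 8, 4, 2, 1). A running total would need fifteen sequential additions.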

We can write the code to do the adding for us using a Dask delayed function. This means that a graph is constructed before any result is calculated. Dask then maps this graph onto the available compute (e.g. processes, threads or GPUs).

In [None]:
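import dask
from dask import delayed

@dask.delayed
def add(x, y):
    return x + y

a = [i+1 for i in range(16)]

# Pairwise tree reduction: each pass halves the list. Calling add() only records a task; nothing is computed yet
while len(a) > 1:
    b = [add(a[i], a[i+1]) for i in range(0, len(a) - 1, 2)]
    if len(a) % 2:
        # Odd count: carry the last element into the next level
        b.append(a[-1])
    a = b

result = a[0]

# Draw the task graph (needs the graphviz package)
result.visualize()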

At this point no computation has been done, only graph construction. Building the graph up front lets Dask see which tasks are independent and schedule them concurrently. You can see from the graph that the additions at each level can run in parallel, and that each subsequent addition depends only on its ancestors. To actually do the computation, we need to call compute().

In [None]:
%%timeit
result.compute()

Note that, in this toy example, the overhead of organising the concurrency far outweighs any gain; this technique is only worthwhile for larger problems. So, let's apply it to something a little more challenging.
Let's compute the local variance of the H & E and DAPI channels and see how they compare. The intuition is that, even though the modalities differ, both images should show similar patterns of local variation, and this could provide some common ground with which to align them.
We will compute the variance over non-overlapping 8x8 tiles, starting with the DAPI channel.
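The Dask cells that follow compute one tile's variance per task and write it into a shared array. For comparison, the same quantity can be computed in a single vectorised NumPy call by exposing the tiles as extra array axes. This is a sketch for reference, not part of the original lab. `block_variance` is a hypothetical helper, and it drops any partial tiles at the right and bottom edges.

```python
import numpy as np

def block_variance(img, sz=8):
    """Variance of each non-overlapping sz x sz tile of a 2-D image."""
    h, w = img.shape
    th, tw = h // sz, w // sz
    # Crop to a whole number of tiles, then reshape so that axes 1 and 3 run within a tile
    tiles = img[:th * sz, :tw * sz].astype(np.float64).reshape(th, sz, tw, sz)
    return tiles.var(axis=(1, 3))

# Toy usage on a random 8-bit image
rng = np.random.default_rng(0)
demo = rng.integers(0, 256, size=(20, 27)).astype(np.uint8)
print(block_variance(demo).shape)
```

Each output element comes from one tile, exactly as in the per-task version. All of the work stays inside NumPy, so there is no scheduling overhead.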

This time we create a Dask cluster and client, which also gives us access to the handy dashboard.


In [None]:
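from dask.distributed import Client, LocalCluster

# Set up a local cluster. processes=False runs the workers as threads inside this process,
# so tasks can write directly into NumPy arrays defined in the notebook
cluster = LocalCluster(dashboard_address=":8787", processes=False)
client = Client(cluster)
client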

The dashboard URL above is fine for local installations, but it may not work when the server is hosted remotely. In that case, paste the URL printed by the cell below into the DASK DASHBOARD URL field of the Dask tool in the sidebar.

In [None]:
import subprocess

# Look up this machine's public IP address and build the dashboard URL from it
ip_lookup = subprocess.run(['curl', '-s', 'ipecho.net/plain'], stdout=subprocess.PIPE)
url = "http://" + ip_lookup.stdout.decode('utf-8').strip() + ":8787"
print(url)

In [None]:
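import numpy as np
import dask.bag as db
sz = 8

def compute_variance(arr):
    # arr is the (x, y) top-left corner of one sz x sz tile
    x1, y1 = arr
    x2 = x1 + sz
    y2 = y1 + sz
    # Workers are threads (processes=False), so they can write into the shared array
    nn_image[y1//sz, x1//sz] = dp_img[y1:y2, x1:x2].var()

# float32: tile variances can be far larger than 255 and would overflow uint8
nn_image = np.zeros((h1//sz, w1//sz), dtype=np.float32)

# Top-left corner of every complete tile (the +1 includes the last full tile)
start_loc_data = [(x, y) for x in range(0, w1-sz+1, sz) for y in range(0, h1-sz+1, sz)]

b = db.from_sequence(start_loc_data, npartitions=100)
print("bagged")
b = b.map(compute_variance)
print("mapped")
results_bag = b.compute()
print("computed")
plt.imshow(nn_image)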

Next we can do the same thing with the H&E image and see how the two compare.

In [None]:
import dask.bag as db

def compute_variance(arr):
    # Same tiling as above, but on channel index 2 of the H&E image
    x1, y1 = arr
    x2 = x1 + sz
    y2 = y1 + sz
    n2_image[y1//sz, x1//sz] = he_img[y1:y2, x1:x2, 2].var()

# float32 for the same overflow reason as nn_image
n2_image = np.zeros((h1//sz, w1//sz), dtype=np.float32)

b = db.from_sequence(start_loc_data, npartitions=200)
print("bagged")
b = b.map(compute_variance)
print("mapped")
results_bag = b.compute()
print("computed")
plt.imshow(n2_image)

Now we can look at the similarity of the images to see whether this is a sufficiently common representation to use for registration.

In [None]:
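import numpy as np

# Cast to float so the subtraction cannot wrap around, whatever dtype the two maps have
diff = np.abs(nn_image.astype(np.float32) - n2_image.astype(np.float32))
print(diff.max(), diff.min())

plt.imshow(diff)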

Although this operation is not especially well suited to Dask (each task handles a single 8x8 tile, so scheduling overhead is large compared with the work done), it does show how Dask can distribute processing that might otherwise take a long time.

The outputs of this might be used to identify points that would make good candidates for anchor points (i.e. those that correspond well between modalities). You could apply some sort of thresholding, find the maxima and train a kernel that can find these points based on one of the modalities. However, that is a rich topic for another day...!
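A minimal sketch of the "threshold and find the maxima" step is shown below. It assumes a 2-D score map in which higher values mean better cross-modal correspondence. The diff above is a dissimilarity, so it would first need turning into a score (for example -diff); that choice is an assumption, not something the lab prescribes. The helper name, threshold and neighbourhood size are hypothetical. Local maxima are found with SciPy's `maximum_filter`.

```python
import numpy as np
from scipy.ndimage import maximum_filter

def candidate_anchor_points(score, threshold, size=5):
    """(row, col) positions that are local maxima of `score` and exceed `threshold`."""
    # A pixel is a local maximum if it equals the largest value in its size x size neighbourhood
    local_max = score == maximum_filter(score, size=size, mode="nearest")
    # Keep only strong responses; flat regions would otherwise count as maxima too
    return np.argwhere(local_max & (score > threshold))

# Toy score map with two peaks
score = np.zeros((10, 10))
score[2, 3] = 5.0
score[7, 7] = 3.0
score[7, 8] = 1.0
print(candidate_anchor_points(score, threshold=2.0))
```

The neighbourhood size sets the minimum spacing between candidate anchors, and the threshold controls how many candidates survive.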

# Final Exercise
Given a convolution kernel, can you create an accelerated function that detects the highest correspondence between the R channel of the H&E and the DAPI channel?

The motivation is that if we can detect regions of the image that are quite similar, then we might be able to use them as anchor points. In reality, it might make sense to train a network to learn which type of kernel(s) gives the most similar activations across the two modalities. For this exercise, you can just pass a fixed kernel over the two images and see which locations produce the highest joint activation. At each location, this is simply the product of the kernel-weighted sums from the two images. Feel free to experiment with different convolution filters, but the main objective is to see how quickly you can do this for all the pixels of the image.
Use whichever method(s) you prefer.

A serial version of the code is provided below.

In [None]:
%%time
import numpy as np
from matplotlib import pyplot as plt

def joint_activation(h_image, d_image, k_filter):
    # Serial version: at each location, multiply the kernel-weighted sums from the two images
    ksize = k_filter.shape[0]
    new_height = h_image.shape[0]-k_filter.shape[0]+1
    new_width = h_image.shape[1]-k_filter.shape[1]+1

    # reduced size to prevent excessive execution time!!!
    # remove when you are confident of a speed up!!
    new_height = min(new_height, 1000)
    new_width = min(new_width, 1000)

    output_image = np.zeros((new_height, new_width), dtype=np.float64)

    for y in range(new_height):
        for x in range(new_width):
            sum_h = 0.
            sum_d = 0.
            for i in range(ksize):
                for j in range(ksize):
                    # Use the k_filter argument rather than the global kernel
                    sum_h += h_image[y+i,x+j] * k_filter[i,j]
                    sum_d += d_image[y+i,x+j] * k_filter[i,j]

            # Joint activation: product of the two summed activations, each scaled by 1/255
            output_image[y,x] = (sum_h/255) * (sum_d/255)

    return output_image

kernel = np.array([[0.4,0.5,0.4],[0.5,1.,0.5],[0.4,0.5,0.4]])
activation_image = joint_activation(he_img[:,:,0],dp_img[:,:],kernel)

plt.imshow(activation_image)
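When you write an accelerated version, it helps to check it against an independent result on a small crop. Below is one vectorised NumPy sketch that follows the exercise description: the product of the kernel-weighted sums from each image, with each sum divided by 255 purely as a scaling choice. It assumes NumPy 1.20 or later for `sliding_window_view`. `joint_activation_reference` is a hypothetical name. This is a correctness reference rather than the fastest possible answer, since it still runs on a single core.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def joint_activation_reference(h_image, d_image, k_filter):
    """Product of the kernel-weighted window sums of two images, each scaled by 1/255."""
    kh, kw = k_filter.shape
    # Read-only views of every kh x kw window, shape (H-kh+1, W-kw+1, kh, kw)
    h_win = sliding_window_view(h_image.astype(np.float64), (kh, kw))
    d_win = sliding_window_view(d_image.astype(np.float64), (kh, kw))
    # Multiply each window by the kernel and sum over the two window axes
    h_sum = np.einsum("yxij,ij->yx", h_win, k_filter)
    d_sum = np.einsum("yxij,ij->yx", d_win, k_filter)
    return (h_sum / 255.0) * (d_sum / 255.0)

# Toy usage on small random images
rng = np.random.default_rng(0)
h_demo = rng.integers(0, 256, size=(6, 6)).astype(np.uint8)
d_demo = rng.integers(0, 256, size=(6, 6)).astype(np.uint8)
k_demo = np.array([[0.4, 0.5, 0.4], [0.5, 1.0, 0.5], [0.4, 0.5, 0.4]])
print(joint_activation_reference(h_demo, d_demo, k_demo).shape)
```

Suppose your function computes the same product of sums. On a crop such as he_img[:200, :200, 0] and dp_img[:200, :200], np.allclose between its output and this reference should then return True.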

Thank you for completing this lab today. We hope you enjoyed the content, learned something and will come back for more sessions in the future!