In [1]:
import numpy as np
import numba as nb
import cupy as cy
from numba import jit, njit, vectorize, prange, objmode
import PIL.Image as im
from matplotlib import pyplot as plt
from fast_histogram import histogram2d
from scipy.ndimage import gaussian_filter
from scipy.interpolate import interpn


@njit()
def B0(u):
    return (1.0 - u) ** 3.0 / 6.0


@njit()
def B1(u):
    return (3.0 * u ** 3.0 - 6.0 * u ** 2.0 + 4.0) / 6.0


@njit()
def B2(u):
    return (-3.0 * u ** 3.0 + 3.0 * u ** 2.0 + 3.0 * u + 1.0) / 6.0


@njit()
def B3(u):
    return u ** 3.0 / 6.0


@njit(parallel=True)
def B(val, order=0):
    """Evaluate a cubic B-spline basis polynomial element-wise over a 2-D array.

    :param val: 2-D array of local coordinates.
    :param order: 0, 1 or 2 select B0, B1 or B2; any other value selects B3.
    :return: float64 array with the same shape as ``val``.
    """
    res = np.zeros_like(val, dtype=np.float64)
    n_rows = res.shape[0]
    n_cols = res.shape[1]
    for r in prange(n_rows):
        for c in range(n_cols):
            x = val[r, c]
            if order == 0:
                res[r, c] = B0(x)
            elif order == 1:
                res[r, c] = B1(x)
            elif order == 2:
                res[r, c] = B2(x)
            else:
                res[r, c] = B3(x)
    return res


@njit(parallel=True)
def indexing_arr(arr, dimY, dimX):
    out = np.zeros_like(dimY)
    for i in prange(dimY.shape[0]):
        for j in prange(dimY.shape[1]):
            out[i, j] = arr[dimY[i, j], dimX[i, j]]
    return out


@njit(parallel=True)
def indexing_arr_mutate(arr, dimY, dimX):
    out = np.zeros((dimY.size, 1), dtype=np.float64)
    for i in prange(dimY.shape[0]):
        out[i] = arr[dimY[i], dimX[i]]
    return out


@njit(parallel=True)
def meshgrid(iX, iY):
    oX = np.zeros((len(iY), len(iX)))
    oY = np.zeros((len(iY), len(iX)))
    for i in prange(len(iY)):
        oX[i, :] = iX[:]
    for j in prange(len(iX)):
        oY[:, j] = iY[:]
    return oX, oY


@njit()
def getCubicBSpline2DGrid(iImageSize, iStep):
    """Regular grid of control-point coordinates covering an image.

    Both axes start one step before 0 and stop (exclusively) at
    ``floor(size / step + 3) * step``.

    :param iImageSize: image shape ``(rows, cols)``.
    :param iStep: spacing; ``iStep[0]`` is used along x (columns), ``iStep[1]`` along y (rows).
    :return: ``(x grid, y grid)`` as produced by ``meshgrid``.
    """
    dy, dx = iImageSize
    step_x = iStep[0]
    step_y = iStep[1]
    xs = np.arange(-step_x, (np.floor(dx / step_x + 3) * step_x), step_x)
    ys = np.arange(-step_y, (np.floor(dy / step_y + 3) * step_y), step_y)
    return meshgrid(xs, ys)


@njit()
def generate_cromosomes(num, init_locx, init_locy, rng=2):
    tmp = np.random.random((2, num, init_locx.shape[0], init_locx.shape[1])) * 2 * rng - rng
    tmp[0, :, :, :] += init_locx
    tmp[1, :, :, :] += init_locy
    return tmp


@njit(parallel=True)
def interpolate2d(iImage, iY, iX):
    '''
    :param iImage:
    :param iY: meshgrid od Y
    :param iX: meshgrid od X
    :return:
    '''
    out = np.zeros(iY.shape)
    start_row = 0
    end_row = 0
    start_col = 0
    end_col = 0
    for i in prange(len(iY[:, 1])):
        for j in prange(len(iX[:, 1])):
            if iY[i, i] % 1 == 0:
                start_row = int(iY[i, j])
                end_row = int(iY[i, j]) + 1
            else:
                start_row = int(np.floor(iY[i, j]))
                end_row = int(np.ceil(iY[i, j]))
            if iX[i, j] % 1 == 0:
                start_col = int(iX[i, j])
                end_col = int(iX[i, j]) + 1
            else:
                start_col = int(np.floor(iX[i, j]))
                end_col = int(np.ceil(iX[i, j]))
            if iImage[start_row:end_row + 1, start_col:end_col + 1].size == 0:
                out[i, j] = 0.0
            else:
                out[i, j] = np.mean(iImage[start_row:end_row + 1, start_col:end_col + 1])
    return out


@njit()
def getCubicBSpline2DDeformation(iImageSize, iCPx, iCPy, iStep):
    """Evaluate the cubic B-spline mapping at every pixel of an image grid.

    :param iImageSize: image shape ``(rows, cols)``.
    :param iCPx: control-point x coordinates, indexed ``[row, col]`` of the control grid.
    :param iCPy: control-point y coordinates, same layout as ``iCPx``.
    :param iStep: control-point spacing; ``iStep[0]`` divides x (column) coordinates and
        ``iStep[1]`` divides y (row) coordinates.
    :return: ``(oGx, oGy)`` arrays of shape ``(rows, cols)`` holding the mapped x / y
        coordinate of each pixel.
    """
    dy, dx = iImageSize
    gx, gy = meshgrid(np.arange(dx), np.arange(dy))
    oGx, oGy = np.zeros_like(gx), np.zeros_like(gy)
    # Control-cell indices and local coordinates do not depend on (l, m). Compute them
    # once instead of 16 times inside the loops.
    i = np.asarray(np.floor(gx / iStep[0]), dtype=np.int64)
    j = np.asarray(np.floor(gy / iStep[1]), dtype=np.int64)
    u = gx / iStep[0] - i
    v = gy / iStep[1] - j
    # Tensor-product sum over the 4x4 control points influencing each pixel.
    for l in (0, 1, 2, 3):
        Bu = B(u, order=l)
        for m in (0, 1, 2, 3):
            weight = Bu * B(v, order=m)
            oGx += weight * indexing_arr(iCPx, j + m, i + l)
            oGy += weight * indexing_arr(iCPy, j + m, i + l)
    return oGx, oGy

# An earlier variant of this function resampled with interpolate2d() instead of scipy's interpn.
@njit()
def deformImageBSpline2D(iImage, iCPx, iCPy, iStep):
    """Warp ``iImage`` with the cubic B-spline deformation defined by control points.

    :param iImage: 2-D image to deform.
    :param iCPx: control-point x coordinates (e.g. a perturbed getCubicBSpline2DGrid grid).
    :param iCPy: control-point y coordinates, same shape as ``iCPx``.
    :param iStep: control-point spacing ``(x step, y step)``.
    :return: float64 image with the shape of ``iImage``; samples falling outside the
        image are filled with 0 (linear interpolation via ``scipy.interpolate.interpn``).
    """
    dy, dx = iImage.shape
    oGx, oGy = getCubicBSpline2DDeformation(iImage.shape, iCPx, iCPy, iStep)
    gx, gy = meshgrid(np.arange(dx), np.arange(dy))
    # Approximate inverse of the mapping: if oG = g + d, then 2g - oG = g - d,
    # which is first-order accurate only for small displacements d.
    oGx = 2 * gx - oGx
    oGy = 2 * gy - oGy
    with objmode(oImage='f8[:,::1]'):
        oImage = interpn((np.arange(dy), np.arange(dx)), iImage.astype('float'), np.dstack((oGy, oGx)), method='linear', bounds_error=False, fill_value=0)
    return oImage


@njit()
def mi(ima, imb, b=32):
    """Mutual information (natural log) between two images from a ``b`` x ``b`` joint histogram.

    :param ima: first image.
    :param imb: second image (same shape).
    :param b: number of histogram bins per axis over intensities ``[0, 255]``.
    :return: sum over non-empty bins of ``p(x, y) * log(p(x, y) / (p(x) p(y)))``.
    """
    # NOTE(review): confirm whether fast_histogram counts values equal to the upper
    # range bound (255). If not, saturated pixels are dropped from the histogram.
    with objmode(hist_2d='f8[:,::1]'):
        hist_2d = histogram2d(ima, imb, b, [[0, 255], [0, 255]])

    pxy = hist_2d / float(np.sum(hist_2d))
    px = np.sum(pxy, axis=1)
    py = np.sum(pxy, axis=0)
    px_py = px.reshape(b, 1) * py.reshape(1, b)
    # numba only supports 1-D boolean masks, so flatten before masking out empty bins
    # (which would otherwise give log(0)).
    nonzero = (pxy > 0).ravel()
    pxy_flat = pxy.ravel()
    pxpy_flat = px_py.ravel()
    return np.sum(pxy_flat[nonzero] * np.log(pxy_flat[nonzero] / pxpy_flat[nonzero]))


@njit()
def deform_ims(fixed_im, floating_im, cromosomes, iStep):
    """Deform ``floating_im`` once per individual of the population.

    :param fixed_im: reference image; only its shape is used here.
    :param floating_im: image to be warped.
    :param cromosomes: population ``(2, n, gy, gx)``; index 0 holds x, index 1 holds y control points.
    :param iStep: control-point spacing passed to deformImageBSpline2D.
    :return: float64 array ``(n, rows, cols)`` of deformed images.
    """
    reference = np.asarray(fixed_im, dtype=np.float64)
    moving_img = np.asarray(floating_im, dtype=np.float64)
    n_individuals = cromosomes.shape[1]
    stack = np.zeros((n_individuals, reference.shape[0], reference.shape[1]), dtype=np.float64)
    for k in range(n_individuals):
        ctrl_x = cromosomes[0, k, :, :]
        ctrl_y = cromosomes[1, k, :, :]
        stack[k, :, :] = deformImageBSpline2D(moving_img, ctrl_x, ctrl_y, iStep)
    return stack


@njit(parallel=True)
def fitness(fixed, images):
    """Fitness of each image as the reciprocal of its mutual information with ``fixed``.

    Lower values mean higher mutual information.
    NOTE(review): an MI of exactly 0.0 makes the division fail (likely ZeroDivisionError
    under numba's default error model). Consider guarding it.

    :param fixed: reference image.
    :param images: stack ``(n, rows, cols)``.
    :return: float64 array of length ``n``.
    """
    count = images.shape[0]
    scores = np.zeros(count, dtype=np.float64)
    for k in prange(count):
        mutual_info = mi(fixed, images[k, :, :])
        scores[k] = 1.0 / mutual_info
    return scores


@njit(parallel=True)
def filter_ims(fixed, images):
    filtered = np.zeros((images.shape[0], images.shape[1], images.shape[2]), dtype=np.float64)
    for i in range(images.shape[0]):
        razlika = np.abs(fixed - images[i, :, :])
        filt = np.array((razlika.shape[0], razlika.shape[1]), dtype=np.float64)
        with objmode(filt='f8[:,::1]'):
            filt = gaussian_filter(razlika, 20)
        filt = (filt - np.min(filt)) / (np.max(filt) - np.min(filt))
        filtered[i, :, :] = filt
    return filtered

@njit()
def generate_points(iCPx, iCPy):
    '''
    Generacija matrike tock preko matrik kontrolnih tock, za lazje racunanje z ostalimi funkcijami.
    :param iCPx: matrika x kontrolnih tock
    :param iCPy: matrika y kontrolnih tock
    :return: matrika tock dimenzije [stevilo tock, 2]
    '''
    tCPx = iCPx[1:-2, 1:-2].flatten()
    tCPy = iCPy[1:-2, 1:-2].flatten()
    src = np.zeros((tCPx.size, 2), dtype=np.uint8)
    src[:,0] = tCPx
    src[:,1] = tCPy
    return src


@njit(parallel=True)
def modify_filtered_vals(vals):
    tmp = np.zeros((vals.shape[0], vals.shape[1]), dtype=np.float64)
    for i in prange(1, vals.shape[0]-1):
        for j in prange(1, vals.shape[1]-1):
            if np.abs(vals[i, j] - np.max(vals[i-1:i+2, j-1:j+2]))>0.5:
                tmp[i, j] = np.mean(vals[i-1:i+2, j-1:j+2])
                #tmp[i, j] = np.max(vals[i - 1:i + 2, j - 1:j + 2]) / 2.0
            else:
                tmp[i, j] = vals[i, j]
    return tmp



@njit()
def get_filtered_vals(cromosomes, im_filtered, filter_pts):
    """Per-control-point mutation weights sampled from a filtered difference map.

    The steps are:
    1. Sample ``im_filtered`` at ``filter_pts``.
    2. Rescale the samples to ``[0.2, 1.0]``.
    3. Place them in the interior of a control-grid-sized array (zeros on the first
       row/column and the last two rows/columns).
    4. Smooth outliers with ``modify_filtered_vals``.

    Only ``cromosomes.shape[2:]`` is used.
    NOTE(review): the reshape-then-transpose only fits the ``[1:-2, 1:-2]`` slot when the
    control grid is square. A constant sample set also divides by zero here (NaN).
    """
    n_r = cromosomes.shape[2]
    n_c = cromosomes.shape[3]
    weights = np.zeros((n_r, n_c), dtype=np.float64)
    # filter_pts columns are (x, y) but are passed as (row, col); the transpose below
    # undoes that swap.
    sampled = indexing_arr_mutate(im_filtered, filter_pts[:, 0], filter_pts[:, 1])
    grid_vals = sampled.reshape((n_r - 3, n_c - 3)).transpose()
    lo = grid_vals.min()
    hi = grid_vals.max()
    scaled = (grid_vals - lo) / (hi - lo) * 0.8 + 0.2
    weights[1:-2, 1:-2] = scaled
    return modify_filtered_vals(weights)


@njit(parallel=True)
def mutate(dst_points, rng, filt_pts, filtered):
    """Randomly perturb a population of control-point grids.

    :param dst_points: population ``(2, n, gy, gx)``; index 0 holds x grids, index 1 holds y grids.
    :param rng: half-width of the base uniform perturbation ``[-rng, rng)``.
    :param filt_pts: integer sample points from generate_points().
    :param filtered: per-individual normalised difference maps (filter_ims output);
        ``filtered[i]`` scales the perturbation of individual ``i``.
    :return: new population array. Each coordinate is perturbed with probability 1/2.
    """
    n = dst_points.shape[1]
    mutation = np.random.random((2, n, dst_points.shape[2], dst_points.shape[3])) * 2 * rng - rng
    for i in range(n):
        # Use dst_points (not the notebook global ``cromosomes``, which numba would freeze
        # at compile time) for the grid shape. Compute the weights once for both axes.
        scale = 2 * get_filtered_vals(dst_points, filtered[i], filt_pts)
        mutation[0, i, :, :] = np.multiply(mutation[0, i, :, :], scale)
        mutation[1, i, :, :] = np.multiply(mutation[1, i, :, :], scale)
    # 0/1 mask: only about half of the coordinates are actually mutated.
    mutation_locations = np.random.randint(0, 2, (2, n, dst_points.shape[2], dst_points.shape[3]))
    return dst_points + mutation * mutation_locations


@njit()
def select_mating(population, fit, num_parents, deformed):
    """Select the ``num_parents`` individuals with the smallest ``|fit|`` as parents.

    :param population: population ``(2, n, gy, gx)``.
    :param fit: fitness per individual (lower ``|fit|`` is better).
    :param num_parents: number of parents to keep.
    :param deformed: deformed image per individual ``(n, rows, cols)``.
    :return: ``(parents, parent_fit, parent_images, parent_filtered, idxs)``, where
        ``parent_fit`` has shape ``(num_parents, 1)`` and ``parent_filtered`` is
        filter_ims() of the parents' deformed images against ``fixed``.
    """
    # indices of the smallest criterion-function values (best individuals first)
    min_idxs = np.abs(fit).argsort()[:num_parents]
    # Read the images from the ``deformed`` argument, not a notebook global. numba freezes
    # global arrays at compile time, so a global would keep yielding the images seen at
    # the first compilation instead of the current generation's.
    parent_ims = deformed[min_idxs, :, :]
    # NOTE(review): ``fixed`` is still the module-level global and is likewise frozen at
    # first compilation. This is fine while it is never reassigned; consider passing it in.
    filtered = filter_ims(fixed, parent_ims)
    parents = population[:, min_idxs, :, :].reshape(2, num_parents, population.shape[2], population.shape[3])
    parent_fit = fit[min_idxs].reshape(num_parents, 1)
    # return the best individuals as parents
    return parents, parent_fit, parent_ims, filtered, min_idxs


@njit(parallel=True)
def generate_pairs(noOfChildren, noOfParents):
    pairs = np.zeros((noOfChildren, 2), dtype=np.uint8)
    for i in prange(noOfChildren):
        id1 = 0
        id2 = 0
        while id1 == id2:
            id1 = int(np.random.random()*noOfParents)
            id2 = int(np.random.random()*noOfParents)
        pairs[i, 0] = id1
        pairs[i, 1] = id2
    return pairs


@njit(parallel=True)
def crossover(parents, noOfChildren, noOfParents):
    """Uniform crossover between random pairs of distinct parents.

    For each child a random 0/1 mask is drawn. Where the mask is 1 the child takes the
    first parent's control point, elsewhere the second parent's. The same mask is used
    for the x and y grids, so a control point's coordinates stay together.

    :param parents: parent population ``(2, n_parents, gy, gx)``.
    :param noOfChildren: number of children to create.
    :param noOfParents: number of parents to pair from.
    :return: float64 children ``(2, noOfChildren, gy, gx)``.
    """
    n_rows = parents.shape[2]
    n_cols = parents.shape[3]
    children = np.zeros((2, noOfChildren, n_rows, n_cols), dtype=np.float64)
    # randint's upper bound is exclusive, so (0, 2) draws 0/1. (0, 1) would always draw 0.
    # There is one mask per child: with one mask per parent a position could receive both
    # parents' values (their sum) or neither (0).
    masks = np.random.randint(0, 2, (noOfChildren, n_rows, n_cols))
    pairs = generate_pairs(noOfChildren, noOfParents)
    for i in prange(noOfChildren):
        first = pairs[i, 0]
        second = pairs[i, 1]
        keep = masks[i, :, :]
        children[0, i, :, :] = parents[0, first, :, :] * keep + parents[0, second, :, :] * (1 - keep)
        children[1, i, :, :] = parents[1, first, :, :] * keep + parents[1, second, :, :] * (1 - keep)
    return children


@njit()
def find_best(cromosomes, fitness, deformed):
    fmin = fitness.min()
    idx = np.where(fitness == fmin)[0][0]
    T_optx = cromosomes[0, idx, :, :]
    T_opty = cromosomes[1, idx, :, :]
    def_opt = deformed[idx, :, :]
    return fmin, T_opty, T_optx, def_opt

In [2]:
noCromo = 10
noParents = 5
noChildren = 5
# Each generation writes parents into cromosomes[:, :noParents] and children into
# cromosomes[:, noParents:], so the two counts must add up to the population size.
assert noParents + noChildren == noCromo

fixed = np.array(im.open('mr1.png'))
moving = np.array(im.open('mr7.png'))
# np.float (a deprecated alias of the builtin float) was removed in NumPy 1.24.
# np.float64 is the same dtype.
fixed = np.asarray(fixed, dtype=np.float64)
moving = np.asarray(moving, dtype=np.float64)
# Assumes single-channel (grayscale) images; an RGB PNG would add a third axis here.
height, width = fixed.shape
# Control-point spacing: (size - 1) / 5 pixels per axis.
# NOTE(review): iStep[0] is derived from the row count, but getCubicBSpline2DGrid and
# getCubicBSpline2DDeformation use iStep[0] as the x (column) step. This only matches
# for square images.
iStep = (((fixed.shape[0]-1)/5), ((fixed.shape[1]-1)/5))
oCPx, oCPy = getCubicBSpline2DGrid(fixed.shape, iStep)
source_pts = generate_points(oCPx, oCPy)
cromosomes = generate_cromosomes(noCromo, oCPx, oCPy, rng=2)

In [3]:
# First call of the njit pipeline triggers numba compilation. Presumably a warm-up so the
# %%timeit cell below measures run time only.
def_ims = deform_ims(fixed, moving, cromosomes, iStep)

In [4]:
%%timeit -r 10
# One full generation of the genetic algorithm:
# deform -> score (1 / MI) -> pick best -> select parents -> crossover -> mutate.
# NOTE: the last two lines write into `cromosomes` in place, so every timed loop also
# advances the population by one generation.
def_ims = deform_ims(fixed, moving, cromosomes, iStep)
fit = fitness(fixed, def_ims)
fmin, T_opty, T_optx, def_opt = find_best(cromosomes, fit, def_ims)
parents, par_fit, par_def, par_filt, par_idxs = select_mating(cromosomes, fit, noParents, def_ims)
children = crossover(parents, noChildren, noParents)
parents = mutate(parents, 2, source_pts, par_filt)
cromosomes[:,0:noParents,:,:] = parents
cromosomes[:,noParents:,:,:] = children

316 ms ± 16.8 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [5]:
# Timings of the three image-level stages: deformation, MI fitness, difference filtering.
%timeit -r 10 def_ims = deform_ims(fixed, moving, cromosomes, iStep)
%timeit -r 10 fit = fitness(fixed, def_ims)
%timeit -r 10 filtered = filter_ims(fixed, def_ims)

305 ms ± 13.5 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)
768 µs ± 18 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
54.3 ms ± 1.38 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [6]:
# Run one generation step outside %%timeit so its intermediates (parents, children,
# par_filt, ...) are available to the per-stage timings below.
# Unlike the timed cell above, `cromosomes` is not updated with the new parents/children here.
def_ims = deform_ims(fixed, moving, cromosomes, iStep)
fit = fitness(fixed, def_ims)
fmin, T_opty, T_optx, def_opt = find_best(cromosomes, fit, def_ims)
parents, par_fit, par_def, par_filt, par_idxs = select_mating(cromosomes, fit, noParents, def_ims)
children = crossover(parents, noChildren, noParents)
parents = mutate(parents, 2, source_pts, par_filt)


In [14]:
# Per-stage timings of one GA generation, using the intermediates from the cell above.
# The mutate timing assigns to `parents2` so that `parents` is left unchanged.
%timeit -r 10 def_ims = deform_ims(fixed, moving, cromosomes, iStep)
%timeit -r 10 fit = fitness(fixed, def_ims)
%timeit -r 10 fmin, T_opty, T_optx, def_opt = find_best(cromosomes, fit, def_ims)
%timeit -r 10 parents, par_fit, par_def, par_filt, par_idxs = select_mating(cromosomes, fit, noParents, def_ims)
%timeit -r 10 children = crossover(parents, noChildren, noParents)
%timeit -r 10 parents2 = mutate(parents, 2, source_pts, par_filt)


272 ms ± 10.1 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)
744 µs ± 2.77 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
1.74 µs ± 25 ns per loop (mean ± std. dev. of 10 runs, 1000000 loops each)
27.7 ms ± 297 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
231 µs ± 2.42 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
2.36 ms ± 16.5 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)


In [11]:
# NOTE: overwrites `parents` with a mutated copy. Re-running this cell mutates the
# already-mutated parents again (not idempotent).
parents = mutate(parents, 2, source_pts, par_filt)
#%timeit -r 10 children = crossover(parents, noChildren, noParents)

In [13]:
# Timing of mutate(); the result goes to `parents2`, so `parents` is left unchanged.
%timeit -r 10 parents2 = mutate(parents, 2, source_pts, par_filt)

2.4 ms ± 41.5 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)
