# Import packages

In [1]:
import sys

# Put the parent directory on the import path — presumably where the local
# `membranequant` package (imported by later cells) lives; TODO confirm.
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt

# Load image

In [2]:
from membranequant.funcs import load_image

embryo_path = '../test_datasets/dataset2_par2_neon/00/'
img_path = embryo_path + '/af_corrected.tif'
img = load_image(file_path)

%matplotlib inline
plt.imshow(img, cmap='gray')
plt.gcf().set_size_inches(10,10)

NameError: name 'file_path' is not defined

# Specify rough ROI

This does not need to be precise; 4 points are usually enough. We will refine the ROI computationally later.

In [None]:
from membranequant.roi import def_roi

# Passed on to the smoothing and spline steps below; presumably marks the ROI
# as a closed contour around the cell.
periodic = True

# Optional: draw a new ROI interactively by uncommenting these lines.
# `%matplotlib tk` switches to a separate-window backend so def_roi can take
# input; remember to switch back to `%matplotlib inline` afterwards.
# %matplotlib tk
# roi = def_roi(img, spline=True, periodic=periodic)
# print(roi.shape)

# Default: load a previously saved ROI (columns 0 and 1 are plotted as x and y below)
roi = np.loadtxt(embryo_path + 'ROI.txt')

In [None]:
%matplotlib inline
plt.imshow(img, cmap='gray')
plt.plot(roi[:, 0], roi[:, 1], c='aqua')
plt.scatter(roi[0, 0], roi[0, 1], c='aqua')
plt.gcf().set_size_inches(10,10)

# Quantify membrane concentrations around the cell

### Straighten

In [None]:
from membranequant.funcs import straighten, rolling_ave_2d

thickness = 50

straight = straighten(img, roi=roi, thickness=thickness, interp='cubic')

%matplotlib inline
plt.imshow(straight, cmap='gray')
plt.gcf().set_size_inches(15,15)

### Preprocess straight image for fitting

In [None]:
rol_ave = 20


def preprocess(straight, rol_ave=rol_ave, bg_subtract=bg_subtract):
    # Smoothen
    straight_filtered = rolling_ave_2d(straight, window=20, periodic=periodic)
    
    # Normalise
    norm = np.max(straight_filtered)
    target = straight_filtered / norm
    
    return target, norm

target, norm = preprocess(straight)

%matplotlib inline
plt.imshow(target, cmap='gray')
plt.gcf().set_size_inches(15,15)

### Set up model

Create a differentiable model using TensorFlow

In [None]:
import tensorflow as tf

sigma = 2

def sim_img(cyts, mems, offsets, sigma=sigma, thickness=thickness):
    """Simulate a straightened image from per-column model parameters.

    Each column is the sum of two profiles along the row axis. Both are
    centred at row thickness / 2 and shifted by that column's offset:
      - a Gaussian peak of width `sigma`, scaled by `mems`
      - an error-function step rising from 0 to 1, scaled by `cyts`

    Parameters
    ----------
    cyts, mems, offsets : 1D float64 tensors/variables, one entry per column
    sigma : float
        Width of the Gaussian and scale of the erf step.
    thickness : int
        Number of rows in the simulated image.

    Returns
    -------
    Tensor of shape (thickness, nfits). It is built only from TF ops, so
    gradients flow back to all three parameter vectors.
    """
    # Row coordinates shifted per column by its offset: shape (nfits, thickness)
    rows = tf.constant(np.arange(thickness, dtype=np.float64))
    positions = tf.expand_dims(rows, 0) + tf.expand_dims(offsets, -1)
    distance = positions - thickness / 2

    # NOTE(review): the erf argument is divided by sigma, not sigma * sqrt(2),
    # so this step is steeper than the CDF of the Gaussian above — confirm
    # that this is the intended model.
    mem_profile = tf.math.exp(-(distance ** 2) / (2 * sigma ** 2))
    cyt_profile = (1 + tf.math.erf(distance / sigma)) / 2

    # Scale each column by its own concentrations, then put rows first
    combined = (mem_profile * tf.expand_dims(mems, axis=-1)
                + cyt_profile * tf.expand_dims(cyts, axis=-1))
    return tf.transpose(combined)

def loss_function(target_image, cyts, mems, offsets):
    """Mean squared error between the simulated image and `target_image`."""
    residuals = sim_img(cyts, mems, offsets) - target_image
    return tf.math.reduce_mean(residuals ** 2)

### Initialise parameters

Calculate a rough initial guess that will serve as the starting conditions for gradient-descent optimisation.

In [None]:
def init_params(target):
    nfits = target.shape[1]
    offsets = tf.Variable(np.zeros(nfits))
    cyts = tf.Variable(np.mean(target[-5:, :], axis=0))
    mems = tf.Variable(np.max(target, axis=0) - 0.5 * cyts)
    return cyts, mems, offsets
    
%matplotlib inline
plt.imshow(sim_img(*init_params(target)), cmap='gray')
plt.gcf().set_size_inches(15,15)

### Optimise parameters by gradient descent

Perform gradient descent to optimise the cytoplasmic concentration, membrane concentration and offset parameters.

In [None]:
lr = 0.01
iterations = 2000

def optimise(target, lr=lr, iterations=iterations):
    
    # Init parameters
    cyts, mems, offsets = init_params(target)
    
    # Gradient descent
    opt = tf.keras.optimizers.Adam(learning_rate=lr)
    losses = np.zeros(iterations)
    for i in range(iterations):   
        with tf.GradientTape() as tape:
            loss = loss_function(target, cyts, mems, offsets)
            losses[i] = loss
            var_list = [offsets, cyts, mems]
            grads = tape.gradient(loss, var_list)
            opt.apply_gradients(list(zip(grads, var_list)))
        
        # Plot fit
        if (i + 1) % 200 == 0: 
            plt.imshow(sim_img(cyts, mems, offsets), cmap='gray')
            plt.title('Iteration = ' + str(i + 1))
            plt.gcf().set_size_inches(15,15)
            plt.show()
            
    return cyts, mems, offsets, losses
            
%matplotlib inline
cyts, mems, offsets, losses = optimise(target)

In [None]:
%matplotlib inline
plt.plot(np.log10(losses))
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.gcf().set_size_inches(15,10)

# Results: membrane concentrations

Concentrations must be rescaled by the normalisation factor (`norm`) calculated earlier.

In [None]:
% matplotlib inline
plt.plot(cyts.numpy() * norm)
plt.axhline(0, c='k', linestyle='--')
plt.xlabel('Position')
plt.ylabel('Membrane concentration (a.u.)')
plt.gcf().set_size_inches(15,10)

# Optional: Refine ROI

Use the fitted offset parameters to refine the ROI.

In [None]:
# Fitted per-column offset parameters (the shift applied to the model profile in each column)
plt.plot(offsets.numpy())
ax = plt.gca()
ax.set_xlabel('Position')
ax.set_ylabel('Offset')
plt.gcf().set_size_inches(15, 10)

In [None]:
from membranequant.funcs import offset_coordinates, spline_roi

roi_new = offset_coordinates(roi, offsets)
roi_new = spline_roi(roi=roi_new, periodic=periodic, s=100)

%matplotlib inline
plt.imshow(img, cmap='gray')
plt.plot(roi[:, 0], roi[:, 1], c='r', label='Original ROI')
plt.plot(roi_new[:, 0], roi_new[:, 1], c='aqua', label='New ROI')
plt.legend()
plt.gcf().set_size_inches(10,10)

### Redo fitting with new ROI

In [None]:
# Second round: re-straighten along the refined ROI, then repeat preprocessing
# and fitting with the same settings as the first round.
straight_iteration2 = straighten(img, roi=roi_new, thickness=thickness, interp='cubic')
target_iteration2, norm_iteration2 = preprocess(straight_iteration2)
cyts_iteration2, mems_iteration2, offsets_iteration2, losses_iteration2 = optimise(target_iteration2)

In [None]:
%matplotlib inline
plt.plot(np.log10(losses_iteration2))
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.gcf().set_size_inches(15,10)

In [None]:
% matplotlib inline
plt.plot(mems_iteration2.numpy() * norm_iteration2)
plt.axhline(0, c='k', linestyle='--')
plt.xlabel('Position')
plt.ylabel('Membrane concentration (a.u.)')
plt.gcf().set_size_inches(15,10)

# ImageQuant class

We can perform the above optimisation in a single line using the `ImageQuant` class.

### Set up class

In [None]:
from membranequant.quantification import ImageQuant

# Reuses the parameters defined in the manual walkthrough above.
# NOTE(review): iterations=2 here, versus 2000 descent steps in `optimise`.
# Presumably ImageQuant's `iterations` counts fit + ROI-refinement rounds
# (matching the two rounds above) — confirm against the ImageQuant docs.
iq = ImageQuant(img=img, roi=roi, sigma=sigma, thickness=thickness, periodic=periodic, rol_ave=rol_ave, iterations=2, lr=lr)

### Perform optimisation

In [None]:
# Run the full quantification; results are then read from attributes such as iq.mems (below)
iq.run()

### View quantification

In [None]:
%matplotlib inline
plt.plot(iq.mems)
plt.plot(mems_iteration2.numpy() * norm_iteration2)
plt.xlabel('Position')
plt.ylabel('Membrane concentration (a.u.)')
plt.gcf().set_size_inches(15,10)

### View segmentation

In [None]:
% matplotlib inline

### Assess quality of fitting

In [None]:
% matplotlib tk