In [1]:
##Import libraries
import torch
from torch import nn
import os
import h5py
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
import matplotlib.image as mpimg
import numpy as np
import cv2
import math
import time
from scipy.interpolate import UnivariateSpline, interp1d
from scipy import interpolate
from scipy import optimize
from scipy import integrate
from scipy import signal
#from arc_length_1 import *
from scipy import stats as st
from tkinter import *
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg, NavigationToolbar2Tk)
import warnings

In [2]:
os.chdir("C:/Users/roryg/Desktop/Zhen Lab/2022/Kymograph_v1/Kymograph_v1")
from image import *
from process import *
from classes import *



In [3]:
#Debugging 
img_dict = {}
width_dict = {}

In [4]:
warnings.filterwarnings("ignore")
device = torch.device("cpu")

In [5]:
green_filename = 'h5_greenonly20230111_4.h5' #green dataset
red_filename = 'h5_redonly20230111_4.h5' # red dataset, use red for

file_path_model = "D:/Models/ver15.pth"
file_path_data = "D:/Datasets/"

In [6]:
#Set up segmentation model
segnet = SegNet()
segnet.load_state_dict(torch.load(file_path_model,map_location=device))
segnet.eval()

SegNet(
  (conv1_1): KerasConv(
    (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (activation): ReLU()
  )
  (conv1_2): KerasConv(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (activation): ReLU()
  )
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (layer_1): Sequential(
    (0): KerasConv(
      (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (activation): ReLU()
    )
    (1): KerasConv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (activation): ReLU()
    )
  )
  (conv2_1): KerasConv(
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (activation): ReLU()
  )
  (conv2_2): KerasConv(
    (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (activation): ReLU()
  )
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode

In [7]:
#Load Data
r = h5py.File(file_path_data + red_filename, 'r') # Load h5
g = h5py.File(file_path_data + green_filename,"r")

In [8]:
total_frames = len(r.keys())-1
spacing = 3
start = 5300
j = '05300'
showUI = True

activation = {}

if not showUI:
    first_intestine_ind = 1310

In [9]:
lag_lengths = np.zeros(total_frames)

In [10]:
#UI functions 

canvas = None
fig = Figure()
window = Tk()
window.title("Identify pharyngeal intestine junction")



def plotGraphUI():
    global canvas, ROI
    canvas = FigureCanvasTkAgg(fig, master = window)

    plot1 = fig.add_subplot(111)
    flatten = ROI.flatten()
    #mode = st.mode(np.rint(flatten[flatten > 20]))[0][0] + threshold
    #print(mode)
    FI = ROI.copy()
    #FI[FI < mode] = 0 #try no threshold
    plot1.imshow(FI)
    #To get xticks range
    max_x = FI.shape[1]
    plot1.set_xticks(range(0,max_x, 100))

    canvas.draw()
    canvas.get_tk_widget().pack(fill = BOTH, expand = True)

def clearGraphUI():
    global canvas
    print("cleared")
    canvas.get_tk_widget().destroy()
    cavas = None


def getInput():
    global plot_exists, first_intestine_ind
    input = textbox.get(1.0, "end-1c")
    first_intestine_ind = int(input)
    #print(plot_exists)
    if plot_exists:
        clearGraphUI()
    plotGraphUI()
    plot_exists = True
    
textbox = Text(window, height = 5, width = 20)
textbox.pack()
plot_exists = False
plot_btn = Button(window, command = getInput, text = "Identify intestine region")
plot_btn.pack()
    

In [11]:
#Process first frame
arr = get_image(r, 0)
green_arr = get_image(g,0)
mask = getPrediction(arr, segnet, torch.device("cpu"))
mask = ChooseLargestBlob(mask)
mask_filled = fillHoles(mask)
mask_erode = erode(mask_filled)
mask_skeleton = skeletonize(mask_erode)
body_points = sortSkeleton(mask_skeleton)
original_img = padEdges(arr)
original_green_img = padEdges(green_arr)

[spline_x, spline_y, t_vals, x_vals, y_vals] = fitSpline(body_points)
ROI = displayThreshold(original_img,  spline_x, spline_y, t_vals, x_vals, y_vals)

0.2486902203969274


In [12]:
if showUI:
    window.mainloop()

cleared


In [31]:
[t_vals_1, body_points] = pointsOnSpline(spline_x, spline_y, max(t_vals), resolution = 4)
arc_lengths = arcLengthSpline(spline_x, spline_y, spline_x(t_vals_1), spline_y(t_vals_1)) #t_vals is 4 times longer than x_vals
[integrated_vals, ROI_l, width_dist] = integrateImage([original_img, original_green_img], spline_x, spline_y, t_vals_1, arc_lengths)
xpoints_1 = arc_lengths #variable switch here because I was too lazy
total_val = integrated_vals[:,0]
total_green_val = integrated_vals[:,1]

max_arc_length = math.floor(0.9*xpoints_1.max()) #first frame sets the kymograph y dim for all frames.
[waveform,green_waveform] = interpolateIntensity([total_val, total_green_val], xpoints_1, max_arc_length, spacing)


In [32]:
img_dict[start] = ROI
red_kym = np.zeros([total_frames, max_arc_length*spacing+400])
green_kym = np.zeros([total_frames, max_arc_length*spacing+400])




In [33]:
intestine_target = int(xpoints_1[first_intestine_ind])

kymo_length = max_arc_length *spacing
        #plotting just entry is good enough

#For alignment
kernel7 = np.tile(np.arange(-3,4),(7,1))
res7_1 = signal.correlate2d(ROI_l, kernel7)
avg_act1 = np.mean(res7_1, axis = 0)
lag_lengths[start] = np.argmax(avg_act1[intestine_target-20:intestine_target+20]) + intestine_target - 20
#plt.imshow(ROI)
#plt.show()
#plt.plot(avg_act1)
#plt.show()
red_kym[start] = waveform
green_kym[start] = green_waveform
width_dict[start] = width_dist
activation[start] = avg_act1

In [34]:

for j in range(start+1, 5500): #Loop over how many frames you want to analyse
    k = j
    arr = np.zeros(np.shape(r['t00000']['s00']['0']['cells'])[1:])
    green_arr = np.zeros(np.shape(g['t00000']['s00']['0']['cells'])[1:])
    j = str(j)
    lenj = len(j)
    for instance in range(5-lenj):
        j = '0' + j
    #print(j)
    for i in range(np.shape(r['t00000']['s00']['0']['cells'])[0]): #Sum up z stack
        t = r['t'+ j]['s00']['0']['cells']
        t_g = g['t'+ j]['s00']['0']['cells']
        green_arr = green_arr + t_g[i]
        arr = arr + t[i] # arr holds the summed up z stack image
        #print(arr)

    start_time = time.time()
    arr = arr.T #for test only#start_time = time.time()
    green_arr = green_arr.T
    original_dim = arr.shape

##Over here we select whether it is red or green we integrate
    original_arr = np.array(arr) #red
    original_green_arr = np.array(green_arr) #green

    mask = getPrediction(original_arr, segnet, torch.device("cpu"))
    mask = ChooseLargestBlob(mask)
    mask_filled = fillHoles(mask)
    mask_erode = erode(mask_filled)
    mask_skeleton = skeletonize(mask_erode)
    body_points = sortSkeleton(mask_skeleton)
    original_img = padEdges(original_arr)
    original_green_img = padEdges(original_green_arr)

    [spline_x, spline_y, t_vals, x_vals, y_vals] = fitSpline(body_points)

    [t_vals, body_points] = pointsOnSpline(spline_x, spline_y, max(t_vals), resolution = 4)
    x_points = arcLengthSpline(spline_x, spline_y, spline_x(t_vals), spline_y(t_vals)) #t_vals is 4 times longer than x_vals
    #print("global var? ",arcLengthSpline(spline_x, spline_y, spline_x(t_vals), spline_y(t_vals)).shape, xpoints.shape)
    [integrated_vals, ROI_l, width_dist] = integrateImage([original_img, original_green_img], spline_x, spline_y, t_vals, x_points)
    #print(integrated_vals.shape, xpoints.shape)
    cur_arc_length = body_points.max()
    
    total_val = integrated_vals[:,0]
    total_green_val = integrated_vals[:,1]


    #Cross correlate entire waveform to get approximate displacement vector
    res7 = signal.correlate2d(ROI_l, kernel7)
    avg_act = np.mean(res7, axis = 0)
    corr = signal.correlate(avg_act[20:-20], activation[k-1][20:-20], mode = 'full')
    lags = signal.correlation_lags(len(avg_act[20:-20]),len(activation[k-1][20:-20]), mode='full')
    lag = lags[np.argmax(corr)]
    #print(lag) #Should be 19
    smooth_actROI = signal.savgol_filter(avg_act, 15, 3)[int(lag_lengths[k-1])+lag-40:int(lag_lengths[k-1])+lag+40]
    maxima = np.array(signal.argrelextrema(smooth_actROI, np.greater)[0].tolist())

    prev_loc = np.array([40,activation[k-1][int(lag_lengths[k-1])]/1000])

    maximaXandY = (np.concatenate((np.expand_dims(maxima,axis = 1),np.expand_dims(smooth_actROI[maxima]/1000,axis = 1)),axis = 1))
    maximaDist = np.linalg.norm(maximaXandY - prev_loc,axis = 1) #y is distorted, divide by 100
    intestine2_len = maxima[np.argmin(maximaDist)] + int(lag_lengths[k-1])+lag-40 # Just x distance doesnt work, use y distance

    activation[k] = avg_act



    lag_lengths[k] = intestine2_len

    #Need to factor in previous start point
    xpoints = xpoints + (intestine_target - intestine2_len)
    
    print(xpoints.shape, total_val.shape, total_green_val.shape)
    [waveform,green_waveform] = interpolateIntensity([total_val, total_green_val], xpoints, max_arc_length, spacing)

    red_kym[k] = waveform
    green_kym[k] = green_waveform
    secs = time.gmtime((total_frames - k)*(time.time()-start_time))
    time_remaining = time.strftime("%H:%M:%S", secs)
    print("Done t=" + str(k) + " out of "+ str(total_frames) + "Estimated time remaining:" + str(time_remaining))

global var?  (2836,) (2528,)
(2836, 2) (2528,)
(2528,) (2836,) (2836,)


ValueError: x and y arrays must be equal in length along interpolation axis.

In [None]:
red_kym = red_kym.T
green_kym = green_kym.T
r.close()
g.close()

np.save("red_kym.npy", red_kym)
np.save("green_kym.npy", green_kym)
np.save("lag_indices.npy", lag_lengths)