In [1]:
from pycudadecon import decon
import tifffile as tf
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import describe
from skimage import exposure
from tkinter import filedialog
import os
from tqdm import tqdm
import GPUtil
%matplotlib notebook

In [2]:
def get_gpu_size():
    gpus = GPUtil.getGPUs()
    
    mem = gpus[0].memoryTotal

    return mem

In [3]:
def load_image(image_path):
    data = {}
    with tf.TiffFile(image_path) as tif:
            data["volume"] = tif.asarray()
            data["axes"] = tif.series[0].axes
            #data["imagej_metadata"] = tif.imagej_metadata
            
            
            ij_data = tif.imagej_metadata
            
            try:
                x_res = tif.pages[0].tags["XResolution"]   
                data["x_res"] = x_res.value[1]/x_res.value[0]
                y_res = tif.pages[0].tags["YResolution"]
                data["y_res"] = y_res.value[1]/y_res.value[0]
                data["X"] = tif.pages[0].tags["ImageWidth"].value
                data["Y"] = tif.pages[0].tags["ImageLength"].value
                data["z_step"] = ij_data["spacing"]
                data["bits"] = ij_data["Info"].split("\n")[0].split("=")[1]
                data["name"] = ij_data["Info"].split("\n")[6].split("=")[1]
                wavelengths = [item for item in data["name"].split("_") if 
                           (item=="470")|(item =="560")|(item=="630")]
                data["wavelengths"] = wavelengths
            except:
                data["x_res"] = 1.0
                data["y_res"] = 1.0
                data["z_step"] = 1.0
                data["bits"] = 8.0
            
            
    return data
  

In [4]:
def split_channels(im, wavelengths):
    """Split channels- assumes shape CZYX"""
    channels = {}
    for channel in range(im.shape[1]):
        c = im[channel]
        c = im[:, channel, :, :]
        channels[wavelengths[channel]] = c
    return channels
        

In [5]:
def plot_slices(channel):
    """Plot a selection of slices in a channel to compare"""
    center_pos = int(channel.shape[0]/2)
    mid_right_pos = int((channel.shape[0]/2)+channel.shape[0]/4)
    mid_left_pos = int((channel.shape[0]/2)-channel.shape[0]/4)
    fig, axes = plt.subplots(nrows = 1, ncols = 3, figsize = (10, 5))
    for n, pos in enumerate([mid_left_pos, center_pos, mid_right_pos]):
        ax = axes[n]
        ax.imshow(channel[pos], vmin = 0., vmax = 65536, cmap = plt.cm.gray)
        

In [6]:
def deconvolution(channel, psf):
    left_edge = False
    right_edge = False
    top_edge = False
    bottom_edge = False
    
    if (channel.shape[1] % 5) == 0:
        y_subsets = 5
        y_size = int(channel.shape[1] / 5)
    elif (channel.shape[1] % 4) == 0:
        y_subsets = 4
        y_size = int(channel.shape[1] / 4)
    elif (channel.shape[1] % 3) == 0:
        y_subsets = 3
        y_size = int(channel.shape[1] / 3)
        
    
    if (channel.shape[2] % 5) == 0:
        x_subsets = 5
        x_size = int(channel.shape[2] / 5)
    elif (channel.shape[2] % 4) == 0:
        x_subsets = 4
        x_size = int(channel.shape[2] / 4)
    elif (channel.shape[2] % 3) == 0:
        x_subsets = 3
        x_size = int(channel.shape[2] / 3)
    
    deconvolved_im = np.zeros([channel.shape[0]-10, channel.shape[1], channel.shape[2]])
    
    # how to add overlap
    for y_no in range(y_subsets):
        y2 = (y_no+1) * y_size
        y1 = y2 - y_size
        
        

        for x_no in range(x_subsets):
            #print([y1, y2, x1, x2])
            x2 = (x_no+1) * x_size
            x1 = x2 -x_size
            
            subset = channel[:, y1:y2, x1:x2]
            print("subset shape" + str(subset.shape))
            

            result = decon(subset, psf["volume"], 
            dx_data = data["x_res"], 
            dz_data = data["z_step"],
            dx_psf = psf["x_res"], 
            dz_psf = psf["z_step"],
            background = 0, n_iters = 15)
            
            
            print("result shape" +str(result.shape))
            deconvolved_im[:, y1:y2, x1:x2] = result   


            
    #deconvolved_im = deconvolved_im.astype("uint8")
    
    return deconvolved_im
        

In [7]:
def deconvolution_with_overlap(data, channel, psf, overlap):
    left_edge = False
    right_edge = False
    top_edge = False
    bottom_edge = False
    
    gpu_available = get_gpu_size()
    channel_size = (channel.size * channel.itemsize)/1e6 # in mb
    mem_ratio = int(np.ceil(gpu_available/channel_size))+1
    
    
    
    if (channel.shape[1] % mem_ratio) == 0:
        y_subsets = mem_ratio
        y_size = int(channel.shape[1] / mem_ratio)
    elif (channel.shape[1] % (mem_ratio-1)) == 0:
        y_subsets = mem_ratio-1
        y_size = int(channel.shape[1] / mem_ratio-1)
    elif (channel.shape[1] % (mem_ratio-2)) == 0:
        y_subsets = mem_ratio-2
        y_size = int(channel.shape[1] / mem_ratio-2)
        
    
    if (channel.shape[2] % mem_ratio) == 0:
        x_subsets = mem_ratio
        x_size = int(channel.shape[2] / mem_ratio)
    elif (channel.shape[2] % (mem_ratio-1)) == 0:
        x_subsets = mem_ratio-1
        x_size = int(channel.shape[2] / mem_ratio-1)
    elif (channel.shape[2] % (mem_ratio-2)) == 0:
        x_subsets = mem_ratio-2
        x_size = int(channel.shape[2] / mem_ratio-2)
    
    #find largest stack range that matches up - run test decon to get shape that comes out
    stack_subset = channel[:, :y_size, :x_size]
    stack_result = decon(stack_subset, psf["volume"], 
            dx_data = data["x_res"], 
            dz_data = data["z_step"],
            dx_psf = psf["x_res"], 
            dz_psf = psf["z_step"],
            background = 0, n_iters = 10)
    
    """if stack_subset.shape == stack_result.shape:
        decon_im_shape = stack_subset.shape
    
    else:
        
        print(stack_subset.shape)
        print(stack_result.shape)
        diff = stack_subset.shape[0]-stack_subset.shape[1]
        stack_subset = stack_subset[:-diff]
        stack_result = decon(stack_subset, psf["volume"], 
            dx_data = data["x_res"], 
            dz_data = data["z_step"],
            dx_psf = psf["x_res"], 
            dz_psf = psf["z_step"],
            background = 0, n_iters = 10)
        if stack_subset.shape == stack_result.shape:
            decon_im_shape = stack_subset.shape
        
        else:
            print(stack_subset.shape)
            print(stack_result.shape)
            diff = stack_subset.shape[0]-stack_subset.shape[1]
            stack_subset = stack_subset[:-diff]
            stack_result = decon(stack_subset, psf["volume"], 
            dx_data = data["x_res"], 
            dz_data = data["z_step"],
            dx_psf = psf["x_res"], 
            dz_psf = psf["z_step"],
            background = 0, n_iters = 10)
            
            if stack_subset.shape == stack_result.shape:
                decon_im_shape = stack_subset.shape
                
            else:
                print(stack_subset.shape)
                print(stack_result.shape)
                diff = stack_subset.shape[0]-stack_subset.shape[1]
                stack_subset = stack_subset[:-diff]
                stack_result = decon(stack_subset, psf["volume"], 
                dx_data = data["x_res"], 
                dz_data = data["z_step"],
                dx_psf = psf["x_res"], 
                dz_psf = psf["z_step"],
                background = 0, n_iters = 10)

                if stack_subset.shape == stack_result.shape:
                    decon_im_shape = stack_subset.shape
                    
                else:
                    print(stack_subset.shape)
                    print(stack_result.shape)
                    print("couldnt find stack size")"""
            
        
    deconvolved_im = np.zeros([stack_result.shape[0], channel.shape[1], channel.shape[2]])
    deconvolved_im[deconvolved_im == 0] = np.nan
    
    # how to add overlap
    for y_no in tqdm(range(y_subsets+1)):
        y_overlap = int(y_size * (overlap/100))
        y1 = y_no * (y_size - y_overlap)
        y2 = y1 + y_size
        #print("y1 = {} and y2 = {}".format(y1, y2))
        
        
        if y1 == 0:
            top_edge = True
                
            print("top edge")
        else:
            top_edge = False
            
        if y1 > channel.shape[1] - y_size:
            bottom_edge = True
            print("bottom edge")
            y2 = channel.shape[1]
            
        
        else:
            bottom_edge = False
            
        

        for x_no in tqdm(range(x_subsets+1)):
            
            x_overlap = int(x_size * (overlap/100))
            x1 = x_no * (x_size - x_overlap)
            x2 = x1 + x_size
            #print("x1 = {} and x2 = {}".format(x1, x2))
            
            
            if x1 == 0:
                left_edge = True
                
                print("left_edge")
            else:
                left_edge = False
            if x1 > channel.shape[2] -x_size:
                right_edge = True
                print("right edge")
                x2 = channel.shape[2]
                
            
            else:
                right_edge = False
                
            
            subset = channel[:, y1:y2, x1:x2]
            #print("subset shape" + str(subset.shape))
            #print("subset dtype is {}".format(subset.dtype))
            
            result = decon(subset, psf["volume"], 
            dx_data = data["x_res"], 
            dz_data = data["z_step"],
            dx_psf = psf["x_res"], 
            dz_psf = psf["z_step"],
            background = 0, n_iters = 10) # try change to 50
            
            
            diff = np.array(subset.shape) - np.array(result.shape)
            
            #print("result dtype is {}".format(result.dtype))
            result2 = np.nanmedian([deconvolved_im[:, y1:y2-diff[1], x1:x2-diff[2]], result], axis=0)
                
            #print("result2 dtype is {}".format(result2.dtype))
            #print("result shape" +str(result.shape))
            deconvolved_im[:, y1:y2-diff[1], x1:x2-diff[2]] = result2   


    
    return deconvolved_im
        

In [8]:
def select_file():
    image_path = filedialog.askopenfilename()
    return image_path

In [9]:
def run(stitched= False, wv = None):
    
    """Select File"""
    image = select_file()
    
    """Load Image"""
    data = load_image(image)
    
    if stitched:
        channels = {wv : data["volume"]}
    else:
        """Split into separate channels"""
        channels = split_channels(data["volume"], data["wavelengths"])
    
    """Assign psf file locations"""
    psf_470 = '/Pierce/Lightsheet PSF_Macros/PSF 470 16 bit.tif'
    psf_560 = '/Pierce/Lightsheet PSF_Macros/PSF 560 16 bit.tif'
    psf_630 = '/Pierce/Lightsheet PSF_Macros/PSF 630 16 bit.tif'
    
    psf_dict= {"470": psf_470,
               "560": psf_560,
               "630": psf_630}
    
    """Loop through each channel, load the associated psf"""
    for wavelength, channel in tqdm(channels.items()):
        psf_path = psf_dict[wavelength]
        psf = load_image(psf_path)
        decon_im = deconvolution_with_overlap(data, channel, psf, overlap =10)
        print(decon_im.dtype)
        decon_im = decon_im.astype("float32")
        #plot_slices(decon_im)
        """Export deconvolved image for imagej"""
        tf.imwrite(image[:-4]+"_"+wavelength+".tif", decon_im, 
                   imagej=True, resolution=(1/data["x_res"], 1/data["y_res"]),
                   metadata={'spacing': data["z_step"], 'unit': 'um', 'axes': 'ZYX'})
        channels[wavelength] = decon_im
    
    return channels, data
        
        
    
    

In [10]:
channels, data = run()
#channels, data = run(True, "560")

  0%|                                                                                            | 0/3 [00:00<?, ?it/s]
  0%|                                                                                            | 0/3 [00:00<?, ?it/s][A

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A[A

top edge
left_edge




 25%|█████████████████████                                                               | 1/4 [01:34<04:42, 94.19s/it][A[A

 50%|██████████████████████████████████████████                                          | 2/4 [03:10<03:11, 95.60s/it][A[A

 75%|███████████████████████████████████████████████████████████████                     | 3/4 [04:48<01:36, 96.44s/it][A[A

right edge




100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [05:08<00:00, 77.07s/it][A[A

 33%|███████████████████████████▋                                                       | 1/3 [05:08<10:16, 308.27s/it][A

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A[A

left_edge




 25%|█████████████████████                                                               | 1/4 [01:35<04:46, 95.60s/it][A[A

 50%|██████████████████████████████████████████                                          | 2/4 [03:12<03:12, 96.09s/it][A[A

 75%|███████████████████████████████████████████████████████████████                     | 3/4 [04:46<01:35, 95.47s/it][A[A

right edge




100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [05:07<00:00, 76.80s/it][A[A

 67%|███████████████████████████████████████████████████████▎                           | 2/3 [10:15<05:07, 307.64s/it][A

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A[A

left_edge




 25%|█████████████████████                                                               | 1/4 [01:34<04:44, 94.76s/it][A[A

 50%|██████████████████████████████████████████                                          | 2/4 [03:10<03:10, 95.44s/it][A[A

 75%|███████████████████████████████████████████████████████████████                     | 3/4 [04:44<01:34, 94.88s/it][A[A

right edge




100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [05:05<00:00, 76.33s/it][A[A

100%|███████████████████████████████████████████████████████████████████████████████████| 3/3 [15:20<00:00, 306.93s/it][A


float64


 33%|███████████████████████████▋                                                       | 1/3 [16:37<33:14, 997.23s/it]
  0%|                                                                                            | 0/3 [00:00<?, ?it/s][A

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A[A

top edge
left_edge




 25%|█████████████████████                                                               | 1/4 [01:31<04:33, 91.27s/it][A[A

 50%|██████████████████████████████████████████                                          | 2/4 [03:06<03:07, 93.74s/it][A[A

 75%|███████████████████████████████████████████████████████████████                     | 3/4 [04:44<01:35, 95.50s/it][A[A

right edge




100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [05:04<00:00, 76.08s/it][A[A

 33%|███████████████████████████▋                                                       | 1/3 [05:04<10:08, 304.31s/it][A

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A[A

left_edge




 25%|█████████████████████                                                               | 1/4 [01:32<04:36, 92.21s/it][A[A

 50%|██████████████████████████████████████████                                          | 2/4 [03:07<03:08, 94.10s/it][A[A

 75%|███████████████████████████████████████████████████████████████                     | 3/4 [04:41<01:33, 93.79s/it][A[A

right edge




100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [05:00<00:00, 75.17s/it][A[A

 67%|███████████████████████████████████████████████████████▎                           | 2/3 [10:05<05:02, 302.18s/it][A

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A[A

left_edge




 25%|█████████████████████                                                               | 1/4 [01:32<04:37, 92.38s/it][A[A

 50%|██████████████████████████████████████████                                          | 2/4 [03:02<03:02, 91.33s/it][A[A

 75%|███████████████████████████████████████████████████████████████                     | 3/4 [04:34<01:31, 91.25s/it][A[A

right edge




100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [04:53<00:00, 73.32s/it][A[A

100%|███████████████████████████████████████████████████████████████████████████████████| 3/3 [14:58<00:00, 299.43s/it][A


float64


 67%|███████████████████████████████████████████████████████▎                           | 2/3 [32:48<16:21, 981.87s/it]
  0%|                                                                                            | 0/3 [00:00<?, ?it/s][A

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A[A

top edge
left_edge




 25%|█████████████████████                                                               | 1/4 [01:31<04:33, 91.24s/it][A[A

 50%|██████████████████████████████████████████                                          | 2/4 [03:03<03:03, 91.72s/it][A[A

 75%|███████████████████████████████████████████████████████████████                     | 3/4 [04:33<01:30, 90.91s/it][A[A

right edge




100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [04:52<00:00, 73.16s/it][A[A

 33%|███████████████████████████▋                                                       | 1/3 [04:52<09:45, 292.65s/it][A

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A[A

left_edge




 25%|█████████████████████                                                               | 1/4 [01:30<04:31, 90.49s/it][A[A

 50%|██████████████████████████████████████████                                          | 2/4 [03:03<03:03, 92.00s/it][A[A

 75%|███████████████████████████████████████████████████████████████                     | 3/4 [04:35<01:32, 92.16s/it][A[A

right edge




100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [04:55<00:00, 73.89s/it][A[A

 67%|███████████████████████████████████████████████████████▎                           | 2/3 [09:48<04:54, 294.37s/it][A

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A[A

left_edge




 25%|█████████████████████                                                               | 1/4 [01:29<04:27, 89.18s/it][A[A

 50%|██████████████████████████████████████████                                          | 2/4 [02:59<02:59, 89.85s/it][A[A

 75%|███████████████████████████████████████████████████████████████                     | 3/4 [04:30<01:30, 90.47s/it][A[A

right edge




100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [04:49<00:00, 72.40s/it][A[A

100%|███████████████████████████████████████████████████████████████████████████████████| 3/3 [14:37<00:00, 292.61s/it][A


float64


100%|███████████████████████████████████████████████████████████████████████████████████| 3/3 [48:37<00:00, 972.62s/it]


In [11]:
ikjhkj
image = select_file()
    
"""Load Image"""
data = load_image(image)

NameError: name 'ikjhkj' is not defined

In [None]:
with tf.TiffFile(image) as tif:
            print(tif.pages[0].tags)
            ij_data = tif.imagej_metadata

#ij_data            

In [None]:
name = ij_data["Info"].split("\n")[6].split("=")[1]
wavelengths = [item for item in name.split("_") if 
                           (item=="470")|(item =="560")|(item=="630")]
wavelengths

### Load data

In [None]:
#image = '/Pierce/DRG/Rat2/DRG1/DRG1.ome'
#psf_path = '/Pierce/Lightsheet PSF_Macros/560 PSF 8bit.tif'
#psf_path = '/Pierce/Lightsheet PSF_Macros/PSF-16bit.tif'
#result = decon(image_path, psf_path)
#data = load_image(image)

In [None]:
#gpus = get_gpu_size()
#mem = gpus[0].memoryTotal



### Split channels

In [None]:
#channels = split_channels(data["volume"], data["wavelengths"])

In [None]:
c1 = channels["560"]

In [None]:
file_size = (c1.size * c1.itemsize)/1e6

In [None]:
file_size/4

In [None]:
np.ceil(mem/file_size)

In [None]:
##data["x_res"]


In [None]:
#data["z_step"]

In [None]:
channels["560"] = channels["560"].astype("uint16")

In [None]:
channels["560"]

In [None]:
#tf.imwrite(image[-4:]+"_560"+".tif", channels["560"], 
                   imagej=True, resolution=(1/data["x_res"], 1/data["y_res"]),
                   metadata={'spacing': data["z_step"], 'unit': 'um', 'axes': 'ZYX'})

### Load psf image

In [None]:
#psf = load_image(psf_path)

### Deconvolve and stitch

In [None]:
#deconvolved_im = deconvolution(c1, psf) # i think decon is changing output shape 
#maybe try ensuring shape output is right
# maybe just try stitching with no overlap

In [None]:
#decon_im = deconvolution_with_overlap(c1, psf, overlap =10) 

In [None]:
image

In [None]:
#np.array(c1.shape)-np.array(c1.shape)

In [None]:
2160/5

In [None]:
2560/5

In [None]:
#c1_8bit[0,0, :2160]

In [None]:
"""Gamma then rolling ball"""

In [None]:
256*256

In [None]:
#plot_slices(c1)
#plot_slices(deconvolved_im)
#plot_slices(decon_im_overlap)
#plot_slices(result)
#plot_slices(bk)
#plot_slices(bk_im)


## Now check in ImageJ - do further processing such as background subtraction and gamma adjustment etc if neccesary

In [None]:

#result = exposure.adjust_gamma(deconvolved_im, 0.5)
#result = (result*256).astype("uint16")
#deconvolved_im