In [None]:
# default_exp util

# Utils for plotting

> Plotting helpers, Fourier-grid generation, and FFT-based 3D image rotation.


In [None]:
#export 
import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
import numpy as np
from PIL import Image
from matplotlib.widgets import Slider
import smpr3d.operators as op
import torch as th

# Element budget for a single FFT tile in ImageRotation: volumes holding more
# elements than this are rotated in slabs along the rotation axis.
MAX_DIM = 512 * 512 * 850

def show3DStack(image_3d, axis = 2, cmap = "gray", clim = None, extent = (0, 1, 0, 1)):
    """
    Interactively browse a 3D stack one 2D slice at a time.

    A slider under the image selects the slice index along `axis`; releasing the
    left/right arrow keys steps the slider one slice down/up.
    Inputs:
        image_3d - 3D array to display
        axis     - 0 or 1 slices along that axis; any other value slices along axis 2
        cmap     - colormap passed to imshow
        clim     - (low, high) color limits; defaults to the min and max of image_3d
        extent   - imshow extent of the displayed slice
    Outputs:
        slider_obj - the Slider widget controlling the displayed slice
    """
    if clim is None:
        clim = (np.min(image_3d), np.max(image_3d))
    if axis == 0:
        image = lambda index: image_3d[index, :, :]
    elif axis == 1:
        image = lambda index: image_3d[:, index, :]
    else:
        image = lambda index: image_3d[:, :, index]

    current_idx = 0
    figure, ax  = plt.subplots(1, 1, figsize=(6.5, 5))
    plt.subplots_adjust(left=0.15, bottom=0.15)
    img_artist  = ax.imshow(image(current_idx), cmap = cmap, clim = clim, extent = extent)
    ax.set_title("layer: " + str(current_idx))
    plt.colorbar(img_artist, ax=ax)
    plt.axis('off')
    ax_slider   = plt.axes([0.15, 0.1, 0.65, 0.03])
    slider_obj  = Slider(ax_slider, "layer", 0, image_3d.shape[axis]-1, valinit=current_idx, valfmt='%d')

    def update_image(index):
        # nonlocal (not global): the index belongs to this viewer, so the key
        # handler works before the slider is touched and viewers don't share state
        nonlocal current_idx
        current_idx = int(index)
        ax.set_title("layer: " + str(current_idx))
        img_artist.set_data(image(current_idx))

    def arrow_key(event):
        nonlocal current_idx
        if event.key == "left":
            if current_idx - 1 >= 0:
                current_idx -= 1
        elif event.key == "right":
            if current_idx + 1 < image_3d.shape[axis]:
                current_idx += 1
        # moving the slider triggers update_image, which redraws
        slider_obj.set_val(current_idx)

    slider_obj.on_changed(update_image)
    figure.canvas.mpl_connect("key_release_event", arrow_key)
    plt.show()
    return slider_obj

def compare3DStack(stack_1, stack_2, axis = 2, cmap = "gray", clim = (0, 1), extent = (0, 1, 0, 1) , colorbar = True, flag_vertical = False):
    """
    Compare two 3D stacks of equal shape slice by slice.

    Both stacks are sliced at the same index along `axis`, chosen with a slider;
    releasing the left/right arrow keys steps the slider one slice down/up.
    Inputs:
        stack_1, stack_2   - 3D arrays of identical shape
        axis               - 0 or 1 slices along that axis; any other value slices along axis 2
        cmap, clim, extent - passed to imshow for both panels
        colorbar           - draw a colorbar next to each panel
        flag_vertical      - place the panels one above the other instead of side by side
    Outputs:
        slider_obj - the Slider widget controlling the displayed slice
    """
    assert stack_1.shape == stack_2.shape, "shape of two input stacks should be the same!"

    if axis == 0:
        image_1 = lambda index: stack_1[index, :, :]
        image_2 = lambda index: stack_2[index, :, :]
    elif axis == 1:
        image_1 = lambda index: stack_1[:, index, :]
        image_2 = lambda index: stack_2[:, index, :]
    else:
        image_1 = lambda index: stack_1[:, :, index]
        image_2 = lambda index: stack_2[:, :, index]

    current_idx = 0
    if flag_vertical:
        figure, ax = plt.subplots(2, 1, figsize=(10, 2.5), sharex = 'all', sharey = 'all')
    else:
        figure, ax = plt.subplots(1, 2, figsize=(9, 5), sharex = 'all', sharey = 'all')
    plt.subplots_adjust(left=0.15, bottom=0.15)
    fig_1 = ax[0].imshow(image_1(current_idx), cmap = cmap, clim = clim, extent = extent)
    ax[0].axis("off")
    ax[0].set_title("stack 1, layer: " + str(current_idx))
    if colorbar:
        plt.colorbar(fig_1, ax = ax[0])
    fig_2 = ax[1].imshow(image_2(current_idx), cmap = cmap, clim = clim, extent = extent)
    ax[1].axis("off")
    ax[1].set_title("stack 2, layer: " + str(current_idx))
    if colorbar:
        plt.colorbar(fig_2, ax = ax[1])
    ax_slider  = plt.axes([0.10, 0.05, 0.65, 0.03])
    slider_obj = Slider(ax_slider, 'layer', 0, stack_1.shape[axis]-1, valinit=current_idx, valfmt='%d')

    def update_image(index):
        # nonlocal (not global): keeps the index private to this viewer
        nonlocal current_idx
        current_idx = int(index)
        ax[0].set_title("stack 1, layer: " + str(current_idx))
        fig_1.set_data(image_1(current_idx))
        ax[1].set_title("stack 2, layer: " + str(current_idx))
        fig_2.set_data(image_2(current_idx))

    def arrow_key(event):
        nonlocal current_idx
        if event.key == "left":
            if current_idx - 1 >= 0:
                current_idx -= 1
        elif event.key == "right":
            if current_idx + 1 < stack_1.shape[axis]:
                current_idx += 1
        # moving the slider triggers update_image, which redraws
        slider_obj.set_val(current_idx)

    slider_obj.on_changed(update_image)
    figure.canvas.mpl_connect("key_release_event", arrow_key)
    plt.show()
    return slider_obj

def compare4DStack(stack_1, stack_2, cmap = "gray", clim = (0, 1), extent = (0, 1, 0, 1) , colorbar = True, flag_vertical = False):
    """
    Compare two 4D stacks of equal shape, displaying stack[:, :, i, j].

    The upper slider selects i (axis 2) and the lower slider selects j (axis 3);
    releasing the left/right arrow keys steps i only.
    Inputs:
        stack_1, stack_2   - 4D arrays of identical shape
        cmap, clim, extent - passed to imshow for both panels
        colorbar           - draw a colorbar next to each panel
        flag_vertical      - place the panels one above the other instead of side by side
    Outputs:
        slider_obj, slider_obj2 - the sliders for axis 2 and axis 3
    """
    assert stack_1.shape == stack_2.shape, "shape of two input stacks should be the same!"
    axis, axis2 = 2, 3
    image_1 = lambda index: stack_1[:, :, index[0], index[1]]
    image_2 = lambda index: stack_2[:, :, index[0], index[1]]

    current_idx1, current_idx2 = 0, 0
    if flag_vertical:
        figure, ax = plt.subplots(2, 1, figsize=(10, 2.5), sharex = 'all', sharey = 'all')
    else:
        figure, ax = plt.subplots(1, 2, figsize=(9, 5), sharex = 'all', sharey = 'all')
    plt.subplots_adjust(left=0.15, bottom=0.15)
    fig_1 = ax[0].imshow(image_1((current_idx1, current_idx2)), cmap = cmap, clim = clim, extent = extent)
    ax[0].axis("off")
    ax[0].set_title("stack 1, layer: " + str(current_idx1) + " and " + str(current_idx2))
    if colorbar:
        plt.colorbar(fig_1, ax = ax[0])
    fig_2 = ax[1].imshow(image_2((current_idx1, current_idx2)), cmap = cmap, clim = clim, extent = extent)
    ax[1].axis("off")
    ax[1].set_title("stack 2, layer: " + str(current_idx1) + " and " + str(current_idx2))
    if colorbar:
        plt.colorbar(fig_2, ax = ax[1])
    ax_slider   = plt.axes([0.10, 0.10, 0.65, 0.03])
    slider_obj  = Slider(ax_slider, 'layer', 0, stack_1.shape[axis]-1, valinit=current_idx1, valfmt='%d')
    ax_slider2  = plt.axes([0.10, 0.05, 0.65, 0.03])
    slider_obj2 = Slider(ax_slider2, 'layer', 0, stack_1.shape[axis2]-1, valinit=current_idx2, valfmt='%d')

    def refresh():
        # redraw titles and both panels for the current (current_idx1, current_idx2)
        layers = str(current_idx1) + " and " + str(current_idx2)
        ax[0].set_title("stack 1, layer: " + layers)
        fig_1.set_data(image_1((current_idx1, current_idx2)))
        ax[1].set_title("stack 2, layer: " + layers)
        fig_2.set_data(image_2((current_idx1, current_idx2)))

    def update_image(index):
        # nonlocal (not global): the previous `global` declarations read module-level
        # names that usually do not exist, raising NameError on key presses
        nonlocal current_idx1
        current_idx1 = int(index)
        refresh()

    def update_image2(index):
        nonlocal current_idx2
        current_idx2 = int(index)
        refresh()

    def arrow_key(event):
        nonlocal current_idx1
        if event.key == "left":
            if current_idx1 - 1 >= 0:
                current_idx1 -= 1
        elif event.key == "right":
            if current_idx1 + 1 < stack_1.shape[axis]:
                current_idx1 += 1
        # moving the first slider triggers update_image, which redraws
        slider_obj.set_val(current_idx1)

    slider_obj.on_changed(update_image)
    slider_obj2.on_changed(update_image2)
    figure.canvas.mpl_connect("key_release_event", arrow_key)
    plt.show()
    return slider_obj, slider_obj2


def generate_grid_1d(shape, pixel_size = 1, flag_fourier = False, dtype = th.float32, device = th.device('cuda')):
    """
    This function generates a 1D real-space or Fourier-space coordinate grid.

    Real space: coordinates (i - shape//2) * pixel_size, so index shape//2 holds 0.
    Fourier space: spacing 1/(pixel_size*shape), circularly shifted into FFT order
    (zero frequency first, negative frequencies last), the same ordering as
    numpy.fft.fftfreq for both even and odd lengths.
    Inputs:
        shape        - length of the array
        pixel_size   - real-space pixel size
    Optional parameters:
        flag_fourier - False for real-space coordinates,
                       True for Fourier coordinates (circularly shifted)
        dtype        - dtype of the returned tensor
        device       - device of the returned tensor
    Outputs:
        x_lin        - 1D tensor of length `shape` (real or Fourier grid)
    """
    if flag_fourier:
        pixel_size = 1. / pixel_size / shape
    x_lin = (th.arange(shape, dtype=dtype, device=device) - shape // 2) * pixel_size
    if flag_fourier:
        # Undo the centering above. The former `-1 * int(shape)//2` parsed as
        # (-shape)//2, which overshoots by one sample when shape is odd.
        x_lin = th.roll(x_lin, -(int(shape) // 2))
    return x_lin

def generate_grid_2d(shape, pixel_size = 1, flag_fourier = False, dtype = th.float32, device = th.device('cuda')):
    """
    This function generates a 2D real-space or Fourier-space coordinate grid.

    Each axis is built with generate_grid_1d (same pixel_size and flag_fourier for
    both) and the two 1D grids are expanded with th.meshgrid.
    Inputs:
        shape         - shape of the grid (number of y pixels, number of x pixels)
        pixel_size    - pixel size, shared by both axes
    Optional parameters:
        flag_fourier  - False for real-space coordinates,
                        True for Fourier coordinates (circularly shifted as in generate_grid_1d)
        dtype, device - dtype and device of the returned tensors
    Outputs:
        y_lin, x_lin  - 2D grids
    Usage:
        y_lin, x_lin = generate_grid_2d(...)
    """
    assert len(shape) == 2, "shape should be two dimensional!"
    grid_options = dict(flag_fourier = flag_fourier, dtype = dtype, device = device)
    axes_1d = [generate_grid_1d(n_pixels, pixel_size, **grid_options) for n_pixels in shape]
    y_grid, x_grid = th.meshgrid(*axes_1d)
    return y_grid, x_grid


class ImageRotation:
    """
    Rotate 3D volumes about one coordinate axis using three FFT-based shears.

    The three array axes are referred to internally as (y, x, z) = (0, 1, 2).
    With `pad=True` the two axes perpendicular to the rotation axis are padded by
    ceil(size/2) on each side (presumably to limit wrap-around of the sheared
    content). Volumes whose padded size exceeds MAX_DIM elements are processed in
    slabs along the rotation axis.

    Inputs:
        shape     - shape (3 entries) of the volumes that will be rotated
        axis      - rotation axis: 0, 1 or 2
        pad       - pad the two axes perpendicular to `axis`
        pad_value - fill value of the padded work buffer used by `forward`
        dtype     - real dtype of the coordinate grids and work buffer
        device    - device on which the rotation is computed
    Usage:
        rot = ImageRotation(vol.shape, axis=0)
        vol = rot.forward(vol, theta)   # theta in degrees; vol is overwritten
        vol = rot.backward(vol)         # rotate by minus the last forward angle
    """
    def __init__(self, shape, axis = 0, pad = True, pad_value = 0, dtype = th.float32, device = th.device('cuda')):
        self.dim       = np.array(shape)
        self.axis      = axis
        self.pad_value = pad_value
        if pad:
            # pad both axes perpendicular to the rotation axis by ceil(size/2) per side
            self.pad_size            = np.ceil(self.dim / 2.0).astype('int')
            self.pad_size[self.axis] = 0
            self.dim                += 2*self.pad_size
        else:
            self.pad_size  = np.asarray([0,0,0])

        self.dim          = [int(size) for size in self.dim]

        # where the un-padded volume sits inside the padded work buffer
        self.range_crop_y = slice(self.pad_size[0],self.pad_size[0] + shape[0])
        self.range_crop_x = slice(self.pad_size[1],self.pad_size[1] + shape[1])
        self.range_crop_z = slice(self.pad_size[2],self.pad_size[2] + shape[2])

        # real-space coordinates, each shaped to broadcast along its own axis
        self.y            = generate_grid_1d(self.dim[0], dtype=dtype, device=device).unsqueeze(-1).unsqueeze(-1)
        self.x            = generate_grid_1d(self.dim[1], dtype=dtype, device=device).unsqueeze(0).unsqueeze(-1)
        self.z            = generate_grid_1d(self.dim[2], dtype=dtype, device=device).unsqueeze(0).unsqueeze(0)

        # Fourier coordinates with the same broadcasting layout
        self.ky           = generate_grid_1d(self.dim[0], flag_fourier = True, dtype=dtype, device=device).unsqueeze(-1).unsqueeze(-1)
        self.kx           = generate_grid_1d(self.dim[1], flag_fourier = True, dtype=dtype, device=device).unsqueeze(0).unsqueeze(-1)
        self.kz           = generate_grid_1d(self.dim[2], flag_fourier = True, dtype=dtype, device=device).unsqueeze(0).unsqueeze(0)

        # Compute FFTs slab by slab if the object is too large. At least one slice
        # per tile, otherwise the tiling loop would call range() with step 0.
        slices_in_budget    = np.floor(MAX_DIM * self.dim[self.axis] / np.prod(self.dim))
        self.slice_per_tile = max(1, int(min(slices_in_budget, self.dim[self.axis])))
        self.dtype          = dtype
        self.device         = th.device(device)

        # -2*pi*k*r phase ramps for the two shears in the plane perpendicular to `axis`
        if self.axis == 0:
            self.coord_phase_1 = -2.0 * np.pi * self.kz * self.x
            self.coord_phase_2 = -2.0 * np.pi * self.kx * self.z
        elif self.axis == 1:
            self.coord_phase_1 = -2.0 * np.pi * self.kz * self.y
            self.coord_phase_2 = -2.0 * np.pi * self.ky * self.z
        elif self.axis == 2:
            self.coord_phase_1 = -2.0 * np.pi * self.kx * self.y
            self.coord_phase_2 = -2.0 * np.pi * self.ky * self.x

    def _shear_phases(self, theta):
        """Return the two complex shear kernels for a rotation by `theta` degrees."""
        theta_rad     = theta * np.pi / 180.0
        alpha         = 1.0 * np.tan(theta_rad / 2.0)
        beta          = np.sin(-1.0 * theta_rad)
        shear_phase_1 = th.exp(1j * self.coord_phase_1 * alpha)
        shear_phase_2 = th.exp(1j * self.coord_phase_2 * beta)
        return shear_phase_1, shear_phase_2

    def _rotate_3d(self, obj, shear_phase_1, shear_phase_2):
        """
        Rotate one slab by three successive shears applied through op.convolve_kernel.

        With alpha = tan(theta/2) and beta = -sin(theta):
        [ 1  alpha ] * [ 1     0 ] * [ 1  alpha ] = [  cos(theta)  sin(theta) ]
        [ 0    1   ]   [ beta  1 ]   [ 0    1   ]   [ -sin(theta)  cos(theta) ]
        (the resulting sense of rotation also depends on the phase-sign convention
        of op.convolve_kernel — see smpr3d.operators).

        Each shear is a 1D shift that varies along another axis; it is applied as a
        phase ramp (shear_phase_1 / shear_phase_2) through op.convolve_kernel.
        Inputs:
          obj           - slab to rotate; its size must match the current crop ranges
          shear_phase_1 - first/third shear kernel (from _shear_phases)
          shear_phase_2 - middle shear kernel (from _shear_phases)
        Output:
          obj, overwritten with the cropped rotated slab (real part only if obj is real)
        """
        flag_complex = obj.is_complex()
        self.obj_rotate[self.range_crop_y, self.range_crop_x, self.range_crop_z] = op.r2c(obj)
        if self.axis == 0:
            self.obj_rotate = op.convolve_kernel(self.obj_rotate, shear_phase_1) #y,x,z
            self.obj_rotate = op.convolve_kernel(self.obj_rotate.permute([0,2,1]), shear_phase_2.permute([0,2,1])) #y,z,x
            self.obj_rotate = op.convolve_kernel(self.obj_rotate.permute([0,2,1]), shear_phase_1) #y,x,z

        elif self.axis == 1:
            self.obj_rotate = op.convolve_kernel(self.obj_rotate.permute([1,0,2]), shear_phase_1.permute([1,0,2])) #x,y,z
            self.obj_rotate = op.convolve_kernel(self.obj_rotate.permute([0,2,1]), shear_phase_2.permute([1,2,0])) #x,z,y
            self.obj_rotate = op.convolve_kernel(self.obj_rotate.permute([0,2,1]), shear_phase_1.permute([1,0,2])) #x,y,z
            self.obj_rotate = self.obj_rotate.permute([1,0,2])

        elif self.axis == 2:
            self.obj_rotate = op.convolve_kernel(self.obj_rotate.permute([2,0,1]), shear_phase_1.permute([2,0,1])) #z,y,x
            self.obj_rotate = op.convolve_kernel(self.obj_rotate.permute([0,2,1]), shear_phase_2.permute([2,1,0])) #z,x,y
            self.obj_rotate = op.convolve_kernel(self.obj_rotate.permute([0,2,1]), shear_phase_1.permute([2,0,1])) #z,y,x
            self.obj_rotate = self.obj_rotate.permute([1,2,0])
        if flag_complex:
            obj[:] = self.obj_rotate[self.range_crop_y, self.range_crop_x, self.range_crop_z]
        else:
            obj[:] = self.obj_rotate[self.range_crop_y, self.range_crop_x, self.range_crop_z].real
        return obj

    def _rotate_tiles(self, obj, shear_phase_1, shear_phase_2, fill_value):
        """Rotate `obj` in place, one slab of at most `slice_per_tile` slices along `self.axis` at a time."""
        n_slices = obj.shape[self.axis]
        self.dim[self.axis] = self.slice_per_tile
        self.obj_rotate = op.r2c(th.ones([self.dim[0], self.dim[1], self.dim[2]], dtype=self.dtype, device=self.device) * fill_value)

        for idx_start in range(0, n_slices, self.slice_per_tile):
            idx_end  = min(n_slices, idx_start + self.slice_per_tile)
            tile_len = int(idx_end - idx_start)
            self.dim[self.axis] = tile_len

            # the slab fills the buffer from index 0 along the rotation axis
            crop = slice(0, tile_len)
            if self.axis == 0:
                self.range_crop_y = crop
            elif self.axis == 1:
                self.range_crop_x = crop
            elif self.axis == 2:
                self.range_crop_z = crop

            index = [slice(None)] * 3
            index[self.axis] = slice(idx_start, idx_end)
            index = tuple(index)

            # compute on self.device, then write back on obj's own device
            tile = obj[index].to(self.device)
            obj[index] = self._rotate_3d(tile, shear_phase_1, shear_phase_2).to(obj.device)
            self.obj_rotate[:] = fill_value + 0.j

        self.dim[self.axis] = n_slices
        self.obj_rotate = None
        return obj

    def forward(self, obj, theta):
        """
        Rotate `obj` by `theta` degrees about `self.axis`.

        `obj` is overwritten slab by slab; each slab is moved to `self.device` for the
        computation and copied back to `obj`'s device. `theta` is stored for `backward`.
        Inputs:
            obj   - tensor with the shape given to the constructor
            theta - rotation angle in degrees; 0 returns `obj` untouched
        Outputs:
            obj   - the rotated tensor (same object, on its original device)
        """
        self.theta = theta
        if theta == 0:
            return obj
        shear_phase_1, shear_phase_2 = self._shear_phases(theta)
        obj = self._rotate_tiles(obj, shear_phase_1, shear_phase_2, self.pad_value)
        if self.device.type == 'cuda':
            th.cuda.empty_cache()
        return obj

    def backward(self, obj):
        """
        Rotate `obj` by minus the angle of the last `forward` call (zero padding).

        `obj` is moved to `self.device` for the computation and the result is returned
        on `obj`'s original device. Requires a previous call to `forward`.
        """
        theta = -1 * self.theta
        if theta == 0:
            return obj
        input_device = obj.device
        obj = obj.to(self.device)
        shear_phase_1, shear_phase_2 = self._shear_phases(theta)
        # previously _rotate_3d was called with 5 arguments here (TypeError)
        obj = self._rotate_tiles(obj, shear_phase_1, shear_phase_2, 0.0)
        return obj.to(input_device)



In [None]:
#export 
def HSV_to_RGB(cin):
    """\
    HSV to RGB transformation.

    Inputs:
        cin   - tuple (h, s, v) of numpy arrays (or values broadcastable to h),
                expected in [0, 1] as produced by P1A_to_HSV
    Outputs:
        imout - array of shape h.shape + (3,) and dtype h.dtype holding R, G, B
                scaled to [0, 255]; works for inputs of any rank
    """

    # HSV channels
    h, s, v = cin

    i = (6. * h).astype(int)
    f = (6. * h) - i
    p = v * (1. - s)
    q = v * (1. - s * f)
    t = v * (1. - s * (1. - f))
    # h == 1 gives i == 6, which wraps to the same sector as 0
    i0 = (i % 6 == 0)
    i1 = (i == 1)
    i2 = (i == 2)
    i3 = (i == 3)
    i4 = (i == 4)
    i5 = (i == 5)

    imout = np.zeros(h.shape + (3,), dtype=h.dtype)
    # `...` instead of `:, :` so that 1D, 2D, 3D, ... inputs are all supported
    imout[..., 0] = 255 * (i0 * v + i1 * q + i2 * p + i3 * p + i4 * t + i5 * v)
    imout[..., 1] = 255 * (i0 * t + i1 * v + i2 * v + i3 * q + i4 * p + i5 * p)
    imout[..., 2] = 255 * (i0 * p + i1 * p + i2 * t + i3 * v + i4 * v + i5 * q)

    return imout

In [None]:
#export 
def P1A_to_HSV(cin, vmin=None, vmax=None):
    """\
    Transform a complex array into an RGB image,
    mapping phase to hue, amplitude to value and
    keeping maximum saturation.

    Despite the name, the return value is RGB (HSV_to_RGB is applied).
    Inputs:
        cin  - complex numpy array
        vmin - amplitude mapped to zero brightness (default 0)
        vmax - amplitude mapped to full brightness (default: maximum amplitude)
    Outputs:
        RGB array returned by HSV_to_RGB (channels on the last axis, values in [0, 255])
    Raises:
        AssertionError if vmin >= vmax (except when vmax is automatic and no
        amplitude exceeds vmin, which renders black)
    """
    # HSV channels: phase in [-pi, pi] -> hue in [0, 1], full saturation
    h = .5 * np.angle(cin) / np.pi + .5
    s = np.ones(cin.shape)

    v = abs(cin)
    if vmin is None:
        vmin = 0.
    auto_vmax = vmax is None
    if auto_vmax:
        vmax = v.max()
    if auto_vmax and vmax == vmin:
        # flat amplitude (e.g. an all-zero array): nothing to stretch, render
        # black instead of failing the range assertion
        v = np.zeros_like(v)
    else:
        assert vmin < vmax
        v = (v.clip(vmin, vmax) - vmin) / (vmax - vmin)

    return HSV_to_RGB((h, s, v))

In [None]:
#export 
def imsave(a, filename=None, vmin=None, vmax=None, cmap=None):
    """
    imsave(a) converts array a into, and returns a PIL image
    imsave(a, filename) returns the image and also saves it to filename
    imsave(a, ..., vmin=vmin, vmax=vmax) clips the array to values between vmin and vmax.
    imsave(a, ..., cmap=cmap) uses a matplotlib colormap.

    Complex arrays are rendered with P1A_to_HSV (phase as hue, amplitude as
    brightness) and ignore cmap. Real arrays are stretched linearly from
    [vmin, vmax] to [0, 255]; cmap must then be a callable colormap returning RGBA
    values in [0, 1]. A constant real image (vmin == vmax) is rendered black.
    """

    if a.dtype.kind == 'c':
        # Image is complex
        if cmap is not None:
            print('imsave: Ignoring provided cmap - input array is complex')
        rgb = P1A_to_HSV(a, vmin, vmax)
        # uint8 with a trailing 3-channel axis is interpreted as RGB by PIL
        im = Image.fromarray(np.uint8(rgb))

    else:
        if vmin is None:
            vmin = a.min()
        if vmax is None:
            vmax = a.max()
        span = vmax - vmin
        if span == 0:
            # constant image: avoid 0/0 (NaN) before the uint8 cast
            scaled = np.zeros(a.shape)
        else:
            scaled = 255 * (a.clip(vmin, vmax) - vmin) / span
        im = Image.fromarray(scaled.astype('uint8'))
        if cmap is not None:
            r = im.point(lambda x: cmap(x / 255.0)[0] * 255)
            g = im.point(lambda x: cmap(x / 255.0)[1] * 255)
            b = im.point(lambda x: cmap(x / 255.0)[2] * 255)
            im = Image.merge("RGB", (r, g, b))

    if filename is not None:
        im.save(filename)
    return im

In [None]:
#export 
def plot_complex_multi(x, title='_', figsize=(10, 10), savePath=None, scale=None, show=True):
    """
    Plot a stack of complex 2D images on a grid (phase as hue, amplitude as brightness).

    Inputs:
        x        - array of shape (n, h, w); each x[i] is rendered with imsave
        title    - figure suptitle
        figsize  - figure size in inches (the figure is created at dpi=300)
        savePath - if given, the figure is saved to savePath + '.png' and savePath + '.pdf'
        scale    - optional (length, label) scale bar drawn on the first panel
        show     - call plt.show() when True
    Outputs:
        fig      - the matplotlib Figure
    """
    n, h, w = x.shape
    rows = int(np.floor(np.sqrt(n)))
    # ceil(n / rows): just enough columns for n panels (n // rows + 1 always
    # added an empty column when rows divided n)
    cols = -(-n // rows)
    fontprops = fm.FontProperties(size=18)
    fig = plt.figure(figsize=figsize, dpi=300, constrained_layout=True)
    gs1 = gridspec.GridSpec(rows, cols)
    gs1.update(wspace=0.1, hspace=0.1)
    for r in range(rows):
        for c in range(cols):
            i = cols * r + c
            ax = plt.subplot(gs1[i])
            if i < n:
                ax.imshow(imsave(x[i]), interpolation='nearest')
                if scale is not None and i == 0:
                    scalebar = AnchoredSizeBar(ax.transData,
                                               scale[0], scale[1], 'lower right',
                                               pad=0.1,
                                               color='white',
                                               frameon=False,
                                               # bar thickness relative to the image height
                                               size_vertical=h / 40,
                                               fontproperties=fontprops)
                    ax.add_artist(scalebar)
            else:
                ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_xticks([])
            ax.set_yticks([])

    plt.suptitle(title)
    plt.grid(False)
    # Figure.set_constrained_layout_pads is deprecated in newer Matplotlib;
    # use the layout engine when it is available
    get_layout_engine = getattr(fig, 'get_layout_engine', None)
    if get_layout_engine is not None and get_layout_engine() is not None:
        get_layout_engine().set(w_pad=0.1, h_pad=0.1)
    else:
        fig.set_constrained_layout_pads(w_pad=0.1, h_pad=0.1)
    if savePath is not None:
        fig.savefig(savePath + '.png')
        fig.savefig(savePath + '.pdf')
    if show:
        plt.show()
    return fig

In [None]:
#export 
def plot(img, title='Image', savePath=None, cmap='inferno', show=True, vmax=None, figsize=(10, 10), scale=None):
    """
    Show one real-valued image with a colorbar and an optional scale bar.

    Inputs:
        img      - 2D array
        title    - axes title
        savePath - if given, the figure is saved to savePath + '.pdf' at 600 dpi
        cmap     - colormap name or Colormap object
        show     - call plt.show() when True
        vmax     - upper color limit (None: image maximum)
        figsize  - figure size in inches
        scale    - optional (length, label) scale bar in data units
    Outputs:
        fig      - the matplotlib Figure
    """
    fig, ax = plt.subplots(figsize=figsize)
    # imshow accepts a name or a Colormap directly; plt.cm.get_cmap was
    # removed in Matplotlib 3.9
    im = ax.imshow(img, interpolation='nearest', cmap=cmap, vmax=vmax)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ax.set_title(title)
    if scale is not None:
        fontprops = fm.FontProperties(size=18)
        scalebar = AnchoredSizeBar(ax.transData,
                                   scale[0], scale[1], 'lower right',
                                   pad=0.1,
                                   color='white',
                                   frameon=False,
                                   size_vertical=img.shape[0] / 40,
                                   fontproperties=fontprops)
        ax.add_artist(scalebar)
    ax.grid(False)
    if savePath is not None:
        fig.savefig(savePath + '.pdf', dpi=600)
    if show:
        plt.show()
    return fig

In [None]:
#export 
def zplot(imgs, suptitle='Image', savePath=None, cmap=['hot', 'hsv'], title=['Abs', 'Phase'], show=True,
          figsize=(9, 5), scale=None):
    """
    Show two images side by side, each with its own colorbar
    (on the left of the first image and on the right of the second).

    Inputs:
        imgs     - pair (im1, im2) of 2D arrays
        suptitle - figure title
        savePath - if given, the figure is saved to savePath + '.png'
        cmap     - pair of colormaps (names or Colormap objects) for im1 and im2
        title    - pair of panel titles
        show     - call plt.show() when True
        figsize  - figure size in inches (the figure is created at dpi=300)
        scale    - optional (length, label) scale bar drawn on the first panel
    """
    im1, im2 = imgs
    fig = plt.figure(figsize=figsize, dpi=300)
    fig.suptitle(suptitle, fontsize=15, y=0.8)
    gs1 = gridspec.GridSpec(1, 2)
    gs1.update(wspace=0, hspace=0)  # set the spacing between axes.
    ax1 = plt.subplot(gs1[0])
    ax2 = plt.subplot(gs1[1])
    div1 = make_axes_locatable(ax1)
    div2 = make_axes_locatable(ax2)

    # imshow accepts names or Colormap objects directly; plt.cm.get_cmap was
    # removed in Matplotlib 3.9
    imax1 = ax1.imshow(im1, interpolation='nearest', cmap=cmap[0])
    imax2 = ax2.imshow(im2, interpolation='nearest', cmap=cmap[1])

    cax1 = div1.append_axes("left", size="10%", pad=0.4)
    cax2 = div2.append_axes("right", size="10%", pad=0.4)

    plt.colorbar(imax1, cax=cax1)
    plt.colorbar(imax2, cax=cax2)

    cax1.yaxis.set_ticks_position('left')
    ax2.yaxis.set_ticks_position('right')

    ax1.set_title(title[0])
    ax2.set_title(title[1])

    if scale is not None:
        fontprops = fm.FontProperties(size=18)
        scalebar = AnchoredSizeBar(ax1.transData,
                                   scale[0], scale[1], 'lower right',
                                   pad=0.1,
                                   color='white',
                                   frameon=False,
                                   size_vertical=im1.shape[0] / 40,
                                   fontproperties=fontprops)

        ax1.add_artist(scalebar)

    ax1.grid(False)
    ax2.grid(False)

    # save before showing (as plot() does): a figure passed through plt.show()
    # may already have been closed by the backend
    if savePath is not None:
        fig.savefig(savePath + '.png')
    if show:
        plt.show()


def plotAbsAngle(img, suptitle='Image', savePath=None, cmap=['gray', 'gray'], title=['Abs', 'Phase'], show=True,
                 figsize=(10, 10), scale=None):
    """Show the amplitude and phase of the complex array `img` side by side using zplot."""
    amplitude = np.abs(img)
    phase = np.angle(img)
    zplot([amplitude, phase], suptitle=suptitle, savePath=savePath, cmap=cmap, title=title,
          show=show, figsize=figsize, scale=scale)