In [None]:
# default_exp util

# Utils for plotting

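> Plotting helpers: S-matrix beam tableaus, interactive 3D/4D stack viewers, HSV colour conversion for complex images, image mosaics, and complex / abs-phase figure utilities.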


In [None]:
#export 
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
import numpy as np
from PIL import Image
from matplotlib.widgets import Slider
import torch as th
import matplotlib.cm as cma
from matplotlib_scalebar.scalebar import ScaleBar

def plot_tableau(S: th.tensor,
                 s_matrix_meta,
                 cbar_label: str,
                 plot_title='',
                 nstd=4,
                 dx_scale_label='nm',
                 relative_axis_size=0.12,
                 return_fig=False):
    """Plot one image per parent beam, each placed at its beam position in the figure.

    Every image in ``S`` is drawn in a small borderless axes whose centre is offset
    from the figure centre by the (normalised) parent-beam coordinate, and is
    labelled with its beam vector ``q``. An extra axes in the lower-right corner
    carries a scale bar, and a colour bar is placed in the lower-left corner. All
    images share the colour range ``mean(S) +/- nstd * std(S)``.

    Args:
        S: stack of 2D images, iterated along its first axis (one per parent beam).
        s_matrix_meta: metadata object. This function reads ``dx[0]`` (divided by
            10 before being given to the scale bar), ``parent_beams_coords`` and
            ``parent_beams_q`` (torch tensors, moved to CPU here).
        cbar_label: text written under the colour bar.
        plot_title: text written at the top-left of the figure.
        nstd: half-width of the colour range in standard deviations of ``S``.
        dx_scale_label: unit string for the scale bar.
        relative_axis_size: width and height of each image axes, as a figure fraction.
        return_fig: if True, return the figure; otherwise return None.

    Returns:
        The matplotlib Figure if ``return_fig`` is True, else None.
    """
    dx = s_matrix_meta.dx[0] / 10

    vmean = np.mean(S)
    vstd = np.std(S)
    vmin = vmean - nstd * vstd
    vmax = vmean + nstd * vstd
    cmap = plt.get_cmap('inferno')

    # Beam positions: y is flipped so +y points up in figure coordinates. The
    # positions are scaled so the largest coordinate becomes 1/2.5, which keeps the
    # tiles on the canvas.
    r = s_matrix_meta.parent_beams_coords.cpu().float().numpy() * np.array([1, -1])
    r_max = r.max()
    if r_max != 0:  # e.g. a single beam at the origin would otherwise give 0/0
        r /= r_max
    r /= 2.5

    q = s_matrix_meta.parent_beams_q.cpu().numpy() * np.array([1, -1])

    fig = plt.figure(dpi=600)
    fig.set_size_inches(18.3 / 2.54, 18.3 / 2.54, forward=True)  # 18.3 cm square
    # One axes per beam plus two extras: [-2] hosts the scale bar, [-1] the colour bar.
    axs = np.ravel(fig.subplots(r.shape[0] + 2))

    rel_size = np.array([relative_axis_size, relative_axis_size])

    for axi, ri, qi, si in zip(axs, r, q, S):
        axi.imshow(si, interpolation='nearest', cmap=cmap, vmin=vmin, vmax=vmax)
        axi.set_xticks([])
        axi.set_yticks([])
        for spine in axi.spines.values():
            spine.set_visible(False)
        psix = 0.5 - rel_size[0] / 2 + ri[0]
        psiy = 0.5 - rel_size[1] / 2 + ri[1]
        axi.set_position([psix, psiy, rel_size[0], rel_size[1]])
        # Write on *this* figure; plt.figure(1) may be an unrelated figure.
        fig.text(psix + .005, psiy + rel_size[1] - .015,
                 f'q = ({qi[1]:2.1f},{qi[0]:2.1f})' + r' $\mathrm{\AA^{-1}}$',
                 transform=fig.transFigure, fontsize=6)

    scale_ax, cbar_ax = axs[-2], axs[-1]

    # An all-NaN image of tile size gives the scale bar correct data coordinates
    # without drawing any pixels.
    blank = np.full(np.shape(S[0]), np.nan)
    scale_ax.imshow(blank, interpolation='nearest', cmap=cmap, vmin=vmin, vmax=vmax)
    scale_ax.set_xticks([])
    scale_ax.set_yticks([])
    for spine in scale_ax.spines.values():
        spine.set_visible(False)
    scale_ax.add_artist(ScaleBar(dx, dx_scale_label, length_fraction=0.75, height_fraction=0.1,
                                 location='lower right'))
    scale_ax.set_position([0.85, 0.03, rel_size[0], rel_size[1]])

    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)
    fig.colorbar(cma.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax)
    cbar_ax.set_position([0.01, 0.03, 0.02, 0.2])
    fig.text(0.01, 0.01, cbar_label, transform=fig.transFigure, fontsize=8)

    # Label -> axes lookup, as expected by pylustrator-generated code.
    fig.ax_dict = {a.get_label(): a for a in fig.axes}
    fig.text(0.020760, 0.962231, plot_title, transform=fig.transFigure)

    return fig if return_fig else None

def show3DStack(image_3d, axis = 2, cmap = "gray", clim = None, extent = (0, 1, 0, 1)):
    """Interactively browse the 2D slices of a 3D array.

    A slider below the image selects the slice. The left/right arrow keys step one
    slice back or forward.

    Args:
        image_3d: 3D array to browse.
        axis: axis to slice along. 0 or 1 select that axis; any other value slices
            along axis 2.
        cmap: colormap passed to ``imshow``.
        clim: ``(low, high)`` colour limits; defaults to the min and max of ``image_3d``.
        extent: ``imshow`` extent used for every slice.

    Returns:
        The ``Slider`` widget. Keep a reference to it: matplotlib widgets stop
        responding once they are garbage-collected.
    """
    if clim is None:
        clim = (np.min(image_3d), np.max(image_3d))
    if axis == 0:
        image = lambda index: image_3d[index, :, :]
    elif axis == 1:
        image = lambda index: image_3d[:, index, :]
    else:
        image = lambda index: image_3d[:, :, index]
    # Count slices along the axis actually being sliced. The fallback branch above
    # always slices axis 2.
    n_layers = image_3d.shape[axis if axis in (0, 1) else 2]

    # Callback state lives in this closure (``nonlocal`` below), not in a module
    # global. Several viewers can then coexist, and the arrow keys work before the
    # slider has been touched.
    current_idx = 0
    fig, ax = plt.subplots(1, 1, figsize=(6.5, 5))
    fig.subplots_adjust(left=0.15, bottom=0.15)
    layer_img = ax.imshow(image(current_idx), cmap=cmap, clim=clim, extent=extent)
    ax.set_title("layer: " + str(current_idx))
    fig.colorbar(layer_img, ax=ax)
    ax.axis('off')
    ax_slider = fig.add_axes([0.15, 0.1, 0.65, 0.03])
    slider_obj = Slider(ax_slider, "layer", 0, n_layers - 1, valinit=current_idx, valfmt='%d')

    def update_image(index):
        nonlocal current_idx
        current_idx = int(index)
        ax.set_title("layer: " + str(current_idx))
        layer_img.set_data(image(current_idx))

    def arrow_key(event):
        nonlocal current_idx
        if event.key == "left":
            if current_idx - 1 >= 0:
                current_idx -= 1
        elif event.key == "right":
            if current_idx + 1 < n_layers:
                current_idx += 1
        # Moving the slider fires update_image, which redraws the slice.
        slider_obj.set_val(current_idx)

    slider_obj.on_changed(update_image)
    fig.canvas.mpl_connect("key_release_event", arrow_key)
    plt.show()
    return slider_obj

def compare3DStack(stack_1, stack_2, axis = 2, cmap = "gray", clim = (0, 1), extent = (0, 1, 0, 1) , colorbar = True, flag_vertical = False):
    """Interactively compare matching 2D slices of two equally shaped 3D stacks.

    Both stacks are shown side by side (or stacked, if ``flag_vertical``) with shared
    x/y axes. A slider selects the slice index, and the left/right arrow keys step
    it by one.

    Args:
        stack_1, stack_2: 3D arrays of identical shape.
        axis: axis to slice along. 0 or 1 select that axis; any other value slices
            along axis 2.
        cmap, clim, extent: passed to ``imshow`` for both panels.
        colorbar: whether to draw a colour bar next to each panel.
        flag_vertical: place the panels in a 2x1 column instead of a 1x2 row.

    Returns:
        The ``Slider`` widget. Keep a reference to it so it stays responsive.

    Raises:
        AssertionError: if the two stacks differ in shape.
    """
    assert stack_1.shape == stack_2.shape, "shape of two input stacks should be the same!"

    if axis == 0:
        image_1 = lambda index: stack_1[index, :, :]
        image_2 = lambda index: stack_2[index, :, :]
    elif axis == 1:
        image_1 = lambda index: stack_1[:, index, :]
        image_2 = lambda index: stack_2[:, index, :]
    else:
        image_1 = lambda index: stack_1[:, :, index]
        image_2 = lambda index: stack_2[:, :, index]
    # Number of slices along the axis actually sliced above.
    n_layers = stack_1.shape[axis if axis in (0, 1) else 2]

    # Callback state kept in this closure (``nonlocal``), not in a module global.
    current_idx = 0
    if flag_vertical:
        fig, ax = plt.subplots(2, 1, figsize=(10, 2.5), sharex='all', sharey='all')
    else:
        fig, ax = plt.subplots(1, 2, figsize=(9, 5), sharex='all', sharey='all')
    fig.subplots_adjust(left=0.15, bottom=0.15)

    img_1 = ax[0].imshow(image_1(current_idx), cmap=cmap, clim=clim, extent=extent)
    ax[0].axis("off")
    ax[0].set_title("stack 1, layer: " + str(current_idx))
    if colorbar:
        fig.colorbar(img_1, ax=ax[0])
    img_2 = ax[1].imshow(image_2(current_idx), cmap=cmap, clim=clim, extent=extent)
    ax[1].axis("off")
    ax[1].set_title("stack 2, layer: " + str(current_idx))
    if colorbar:
        fig.colorbar(img_2, ax=ax[1])

    ax_slider = fig.add_axes([0.10, 0.05, 0.65, 0.03])
    slider_obj = Slider(ax_slider, 'layer', 0, n_layers - 1, valinit=current_idx, valfmt='%d')

    def update_image(index):
        nonlocal current_idx
        current_idx = int(index)
        ax[0].set_title("stack 1, layer: " + str(current_idx))
        img_1.set_data(image_1(current_idx))
        ax[1].set_title("stack 2, layer: " + str(current_idx))
        img_2.set_data(image_2(current_idx))

    def arrow_key(event):
        nonlocal current_idx
        if event.key == "left":
            if current_idx - 1 >= 0:
                current_idx -= 1
        elif event.key == "right":
            if current_idx + 1 < n_layers:
                current_idx += 1
        # Setting the slider triggers update_image, which redraws both panels.
        slider_obj.set_val(current_idx)

    slider_obj.on_changed(update_image)
    fig.canvas.mpl_connect("key_release_event", arrow_key)
    plt.show()
    return slider_obj

def compare4DStack(stack_1, stack_2, cmap = "gray", clim = (0, 1), extent = (0, 1, 0, 1) , colorbar = True, flag_vertical = False):
    """Interactively compare two equally shaped 4D stacks, one 2D frame at a time.

    The frame shown is ``stack[:, :, i, j]``. The upper slider selects ``i`` (axis 2),
    the lower slider selects ``j`` (axis 3), and the left/right arrow keys step ``i``.

    Args:
        stack_1, stack_2: 4D arrays of identical shape.
        cmap, clim, extent: passed to ``imshow`` for both panels.
        colorbar: whether to draw a colour bar next to each panel.
        flag_vertical: place the panels in a 2x1 column instead of a 1x2 row.

    Returns:
        ``(slider_obj, slider_obj2)`` for axes 2 and 3. Keep references to them so
        they stay responsive.

    Raises:
        AssertionError: if the two stacks differ in shape.
    """
    assert stack_1.shape == stack_2.shape, "shape of two input stacks should be the same!"
    axis, axis2 = 2, 3
    image_1 = lambda index: stack_1[:, :, index[0], index[1]]
    image_2 = lambda index: stack_2[:, :, index[0], index[1]]

    # Callback state kept in this closure (``nonlocal`` below). Module globals were
    # never initialised, so the first slider move or key press raised NameError.
    current_idx1, current_idx2 = 0, 0
    if flag_vertical:
        fig, ax = plt.subplots(2, 1, figsize=(10, 2.5), sharex='all', sharey='all')
    else:
        fig, ax = plt.subplots(1, 2, figsize=(9, 5), sharex='all', sharey='all')
    fig.subplots_adjust(left=0.15, bottom=0.15)

    def set_titles():
        # Title both panels with the current (i, j) frame indices.
        layers = str(current_idx1) + " and " + str(current_idx2)
        ax[0].set_title("stack 1, layer: " + layers)
        ax[1].set_title("stack 2, layer: " + layers)

    fig_1 = ax[0].imshow(image_1((current_idx1, current_idx2)), cmap=cmap, clim=clim, extent=extent)
    ax[0].axis("off")
    if colorbar:
        fig.colorbar(fig_1, ax=ax[0])
    fig_2 = ax[1].imshow(image_2((current_idx1, current_idx2)), cmap=cmap, clim=clim, extent=extent)
    ax[1].axis("off")
    if colorbar:
        fig.colorbar(fig_2, ax=ax[1])
    set_titles()

    def redraw():
        # Refresh titles and both images for the current (i, j).
        index = (current_idx1, current_idx2)
        set_titles()
        fig_1.set_data(image_1(index))
        fig_2.set_data(image_2(index))

    def update_image(index):
        nonlocal current_idx1
        current_idx1 = int(index)
        redraw()

    def update_image2(index):
        nonlocal current_idx2
        current_idx2 = int(index)
        redraw()

    ax_slider = fig.add_axes([0.10, 0.10, 0.65, 0.03])
    slider_obj = Slider(ax_slider, 'layer', 0, stack_1.shape[axis] - 1, valinit=current_idx1, valfmt='%d')
    ax_slider2 = fig.add_axes([0.10, 0.05, 0.65, 0.03])
    slider_obj2 = Slider(ax_slider2, 'layer', 0, stack_1.shape[axis2] - 1, valinit=current_idx2, valfmt='%d')

    def arrow_key(event):
        nonlocal current_idx1
        if event.key == "left":
            if current_idx1 - 1 >= 0:
                current_idx1 -= 1
        elif event.key == "right":
            if current_idx1 + 1 < stack_1.shape[axis]:
                current_idx1 += 1
        # Setting the slider triggers update_image, which redraws both panels.
        slider_obj.set_val(current_idx1)

    slider_obj.on_changed(update_image)
    slider_obj2.on_changed(update_image2)
    fig.canvas.mpl_connect("key_release_event", arrow_key)
    plt.show()
    return slider_obj, slider_obj2






In [None]:
#export 
def HSV_to_RGB(cin):
    """\
    HSV to RGB transformation.

    Parameters
    ----------
    cin : tuple (h, s, v)
        Hue, saturation and value arrays with the same shape as ``h`` (s and v
        normally in [0, 1]). Hue is periodic: any value wraps modulo 1, so
        h = 1 (or h = -1) gives the same colour as h = 0.

    Returns
    -------
    ndarray
        Array of shape ``h.shape + (3,)`` and dtype ``h.dtype`` holding the
        R, G, B channels scaled to [0, 255]. Inputs of any rank are accepted.
    """
    h, s, v = (np.asarray(c) for c in cin)

    # Split 6*h into an integer sector and a fractional position f in [0, 1).
    # floor (not int truncation) keeps f in range for negative hues as well.
    h6 = 6. * h
    sector = np.floor(h6)
    f = h6 - sector
    i = sector.astype(int) % 6  # wrap every sector into 0..5, not only 6 -> 0
    p = v * (1. - s)
    q = v * (1. - s * f)
    t = v * (1. - s * (1. - f))
    i0 = (i == 0)
    i1 = (i == 1)
    i2 = (i == 2)
    i3 = (i == 3)
    i4 = (i == 4)
    i5 = (i == 5)

    imout = np.zeros(h.shape + (3,), dtype=h.dtype)
    # Ellipsis indexing supports any input rank, not just 2D images.
    imout[..., 0] = 255 * (i0 * v + i1 * q + i2 * p + i3 * p + i4 * t + i5 * v)
    imout[..., 1] = 255 * (i0 * t + i1 * v + i2 * v + i3 * q + i4 * p + i5 * p)
    imout[..., 2] = 255 * (i0 * p + i1 * p + i2 * t + i3 * v + i4 * v + i5 * q)

    return imout

In [None]:
#export 
def mosaic(data):
    """Tile a stack of equally sized 2D frames into one square grid image.

    Frames are placed row by row on a ``side x side`` grid, where ``side`` is
    ``sqrt(n)`` rounded up; cells without a frame stay zero.

    Args:
        data: array of shape ``(n, rows, cols)``.

    Returns:
        Array of shape ``(side * rows, side * cols)`` with ``data.dtype``.
    """
    n_frames, tile_h, tile_w = data.shape
    root = np.sqrt(n_frames)
    side = int(root)
    if root - side > 1e-6:  # not a perfect square: grow the grid by one
        side += 1
    canvas = np.zeros((side * tile_h, side * tile_w)).astype(data.dtype)
    for k in range(min(n_frames, side * side)):
        row, col = divmod(k, side)
        canvas[row * tile_h:(row + 1) * tile_h, col * tile_w:(col + 1) * tile_w] = data[k]
    return canvas

In [None]:
#export 
def plotmosaic(img, title='Image', savePath=None, cmap='hot', show=True, figsize=(10, 10), vmax=None):
    """Display a stack of 2D images as one square mosaic with a colour bar.

    Args:
        img: stack of shape ``(n, rows, cols)``, tiled with ``mosaic``.
        title: axes title.
        savePath: if given, the figure is saved to ``savePath + '.png'`` at 600 dpi.
        cmap: colormap name or Colormap object.
        show: whether to call ``plt.show()`` (previously ignored; always shown).
        figsize: figure size in inches.
        vmax: upper colour limit passed to ``imshow`` (automatic when None).
    """
    fig, ax = plt.subplots(figsize=figsize)
    mos = mosaic(img)
    # plt.get_cmap: matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9.
    im = ax.imshow(mos, interpolation='nearest', cmap=plt.get_cmap(cmap), vmax=vmax)
    fig.colorbar(im, ax=ax)
    ax.set_title(title)
    ax.grid(False)
    # Save before showing; some backends release the figure once it has been shown.
    if savePath is not None:
        fig.savefig(savePath + '.png', dpi=600)
    if show:
        plt.show()

In [None]:
#export 
def P1A_to_HSV(cin, vmin=None, vmax=None):
    """\
    Transform a complex array into an RGB image,
    mapping phase to hue, amplitude to value and
    keeping maximum saturation.

    Parameters
    ----------
    cin : array_like (complex)
        Input field.
    vmin, vmax : float, optional
        Amplitude range mapped linearly onto value 0..1; amplitudes are clipped
        to it. Defaults: 0 and the maximum amplitude of ``cin``.

    Returns
    -------
    ndarray
        RGB array produced by ``HSV_to_RGB`` (channels scaled to [0, 255]).
        If ``vmax`` is not given and no amplitude exceeds ``vmin`` (for example
        an all-zero input), the image is black instead of raising.

    Raises
    ------
    AssertionError
        If the effective range has ``vmin >= vmax`` (for example an explicitly
        passed, inverted range).
    """
    cin = np.asarray(cin)
    # HSV channels: phase in [-pi, pi] maps to hue in [0, 1].
    h = .5 * np.angle(cin) / np.pi + .5
    s = np.ones(cin.shape)

    v = np.abs(cin)
    if vmin is None:
        vmin = 0.
    if vmax is None:
        vmax = v.max()
        if vmax <= vmin:
            # Degenerate automatic range (e.g. all-zero amplitude): nothing is
            # brighter than vmin, so every pixel gets value 0.
            return HSV_to_RGB((h, s, np.zeros(cin.shape)))
    assert vmin < vmax
    v = (v.clip(vmin, vmax) - vmin) / (vmax - vmin)

    return HSV_to_RGB((h, s, v))

In [None]:
#export 
def imsave(a, filename=None, vmin=None, vmax=None, cmap=None):
    """
    imsave(a) converts array a into a PIL image and returns it.
    imsave(a, filename) returns the image and also saves it to filename.
    imsave(a, ..., vmin=vmin, vmax=vmax) clips the array to values between vmin and vmax.
    imsave(a, ..., cmap=cmap) uses a matplotlib colormap.

    Complex arrays are rendered with P1A_to_HSV (phase -> hue, amplitude -> value)
    into an RGB image, and ``cmap`` is ignored. Real arrays are scaled linearly
    from [vmin, vmax] (defaults: the array's min and max) to 0..255. A constant
    array (vmin == vmax) maps to 0 instead of producing NaNs. ``cmap`` must be
    callable on floats in [0, 1] and return an RGB(A) tuple, e.g. a matplotlib
    Colormap object.
    """

    if a.dtype.kind == 'c':
        # Image is complex
        if cmap is not None:
            print('imsave: Ignoring provided cmap - input array is complex')
        i = P1A_to_HSV(a, vmin, vmax)
        im = Image.fromarray(np.uint8(i), mode='RGB')

    else:
        if vmin is None:
            vmin = a.min()
        if vmax is None:
            vmax = a.max()
        span = vmax - vmin
        if span == 0:
            # Constant image: 0/0 would give NaN, whose uint8 cast is undefined.
            scaled = np.zeros(a.shape)
        else:
            scaled = 255 * (a.clip(vmin, vmax) - vmin) / span
        im = Image.fromarray(scaled.astype('uint8'))
        if cmap is not None:
            # Evaluate the colormap once per grey level (256 calls instead of 768),
            # then map each channel through that table.
            lut = [cmap(level / 255.0) for level in range(256)]
            r = im.point(lambda x: lut[x][0] * 255)
            g = im.point(lambda x: lut[x][1] * 255)
            b = im.point(lambda x: lut[x][2] * 255)
            im = Image.merge("RGB", (r, g, b))

    if filename is not None:
        im.save(filename)
    return im

In [None]:
#export 
def plot_complex_multi(x, title='_', figsize=(10, 10), savePath=None, scale=None, show=True):
    """Plot a stack of (complex) images on a grid, one panel per image.

    Each ``x[i]`` is converted to an image with ``imsave`` and shown without ticks.
    The grid has ``floor(sqrt(n))`` rows and just enough columns for all ``n``
    images; any leftover cells are hidden.

    Args:
        x: array of shape ``(n, h, w)``.
        title: figure suptitle.
        figsize: figure size in inches (previously ignored).
        savePath: if given, the figure is saved to ``savePath + '.png'`` and ``'.pdf'``.
        scale: optional ``(size, label)`` for an ``AnchoredSizeBar`` drawn on the
            first panel; ``size`` is in pixels (data units).
        show: whether to call ``plt.show()``.

    Returns:
        The matplotlib Figure.
    """
    n, h, w = x.shape
    rows = max(1, int(np.floor(np.sqrt(n))))
    # ceil(n / rows): the previous ``n // rows + 1`` always added a spare column.
    cols = max(1, -(-n // rows))
    fontprops = fm.FontProperties(size=18)
    fig = plt.figure(figsize=figsize, dpi=300, constrained_layout=True)
    gs1 = gridspec.GridSpec(rows, cols)
    gs1.update(wspace=0.1, hspace=0.1)
    for i in range(rows * cols):  # row-major over the grid
        ax = fig.add_subplot(gs1[i])
        if i < n:
            ax.imshow(imsave(x[i]), interpolation='nearest')
            if scale is not None and i == 0:
                scalebar = AnchoredSizeBar(ax.transData,
                                           scale[0], scale[1], 'lower right',
                                           pad=0.1,
                                           color='white',
                                           frameon=False,
                                           # 1/40 of the image height, matching plot()/zplot()
                                           size_vertical=h / 40,
                                           fontproperties=fontprops)
                ax.add_artist(scalebar)
        else:
            ax.axis('off')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)

    fig.suptitle(title)
    fig.set_constrained_layout_pads(w_pad=0.1, h_pad=0.1)
    if savePath is not None:
        fig.savefig(savePath + '.png')
        fig.savefig(savePath + '.pdf')
    if show:
        plt.show()
    return fig

In [None]:
#export 
def plot(img, title='Image', savePath=None, cmap='inferno', show=True, vmax=None, figsize=(10, 10), scale=None):
    """Show a single 2D image with a colour bar and an optional scale bar.

    Args:
        img: 2D array to display.
        title: axes title.
        savePath: if given, the figure is saved as ``savePath + '.pdf'`` (dpi=600).
        cmap: colormap name or Colormap object.
        show: whether to call ``plt.show()``.
        vmax: upper colour limit passed to ``imshow`` (automatic when None).
        figsize: figure size in inches.
        scale: optional ``(size, label)`` for an ``AnchoredSizeBar`` in the lower
            right; ``size`` is in pixels (data units).

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=figsize)
    # plt.get_cmap: matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9.
    im = ax.imshow(img, interpolation='nearest', cmap=plt.get_cmap(cmap), vmax=vmax)
    # Colour-bar axes 5% as wide as the image, attached to its right edge.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    ax.set_title(title)
    if scale is not None:
        scalebar = AnchoredSizeBar(ax.transData,
                                   scale[0], scale[1], 'lower right',
                                   pad=0.1,
                                   color='white',
                                   frameon=False,
                                   # bar thickness: 1/40 of the image height
                                   size_vertical=img.shape[0] / 40,
                                   fontproperties=fm.FontProperties(size=18))
        ax.add_artist(scalebar)
    ax.grid(False)
    if savePath is not None:
        fig.savefig(savePath + '.pdf', dpi=600)
    if show:
        plt.show()
    return fig

In [None]:
#export 
def zplot(imgs, suptitle='Image', savePath=None, cmap=('hot', 'hsv'), title=('Abs', 'Phase'), show=True,
          figsize=(9, 5), scale=None):
    """Plot two images side by side, each with a colour bar on its outer edge.

    Args:
        imgs: pair ``(im1, im2)`` of 2D arrays (e.g. amplitude and phase).
        suptitle: figure title.
        savePath: if given, the figure is saved as ``savePath + '.png'``.
        cmap: pair of colormap names/objects for ``im1`` and ``im2``.
        title: pair of axes titles.
        show: whether to call ``plt.show()``.
        figsize: figure size in inches (the figure uses dpi=300).
        scale: optional ``(size, label)`` for an ``AnchoredSizeBar`` on the left
            image; ``size`` is in pixels (data units).
    """
    im1, im2 = imgs
    fig = plt.figure(figsize=figsize, dpi=300)
    fig.suptitle(suptitle, fontsize=15, y=0.8)
    gs1 = gridspec.GridSpec(1, 2)
    gs1.update(wspace=0, hspace=0)  # set the spacing between axes.
    ax1 = fig.add_subplot(gs1[0])
    ax2 = fig.add_subplot(gs1[1])
    div1 = make_axes_locatable(ax1)
    div2 = make_axes_locatable(ax2)

    # plt.get_cmap: matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9.
    imax1 = ax1.imshow(im1, interpolation='nearest', cmap=plt.get_cmap(cmap[0]))
    imax2 = ax2.imshow(im2, interpolation='nearest', cmap=plt.get_cmap(cmap[1]))

    # Colour bars on the outer sides: left of ax1, right of ax2.
    cax1 = div1.append_axes("left", size="10%", pad=0.4)
    cax2 = div2.append_axes("right", size="10%", pad=0.4)
    fig.colorbar(imax1, cax=cax1)
    fig.colorbar(imax2, cax=cax2)

    cax1.yaxis.set_ticks_position('left')
    # The panels touch (wspace=0); presumably ax2's y ticks go right so they do not
    # overlap ax1 — confirm the intent.
    ax2.yaxis.set_ticks_position('right')

    ax1.set_title(title[0])
    ax2.set_title(title[1])

    if scale is not None:
        fontprops = fm.FontProperties(size=18)
        scalebar = AnchoredSizeBar(ax1.transData,
                                   scale[0], scale[1], 'lower right',
                                   pad=0.1,
                                   color='white',
                                   frameon=False,
                                   size_vertical=im1.shape[0] / 40,
                                   fontproperties=fontprops)
        ax1.add_artist(scalebar)

    ax1.grid(False)
    ax2.grid(False)

    # Save before showing; some backends release the figure once it has been shown.
    if savePath is not None:
        fig.savefig(savePath + '.png')
    if show:
        plt.show()


def plotAbsAngle(img, suptitle='Image', savePath=None, cmap=['gray', 'gray'], title=['Abs', 'Phase'], show=True,
                 figsize=(10, 10), scale=None):
    """Show the modulus and the phase of a complex image side by side.

    Thin wrapper around ``zplot``: the left panel is ``np.abs(img)``, the right
    panel ``np.angle(img)``. All remaining arguments are forwarded unchanged.
    """
    magnitude = np.abs(img)
    phase = np.angle(img)
    zplot([magnitude, phase], suptitle=suptitle, savePath=savePath, cmap=cmap, title=title,
          show=show, figsize=figsize, scale=scale)