In [None]:
def _show_with_colorbar(ax, image, cmap):
    """Draw *image* on *ax* and attach a slim colorbar to its right-hand side."""
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    im = ax.imshow(image, cmap=cmap)
    plt.colorbar(im, cax)
    return im


def _normalised_median(hists, axis):
    """Return the median over the first axis of *hists*, rescaled to sum to 1.

    The rescaling is skipped when the median is all zero, so the result is a
    blank map rather than an array of NaNs. The image is transposed unless
    ``axis == 0``, matching the orientation used for the per-dataset panels.
    """
    median = np.median(hists, axis=0)
    total = median.sum()
    if total != 0:
        median = median / total
    return median if axis == 0 else median.T


def plot_binned_average(parent_projection_coords, axs1, col, channel, datasets, area_num, axis, binsize, sigma, cmap):
    """Plot a binned 2-D distribution map per dataset plus two median summaries.

    For every dataset the points are binned in 3-D, normalised to sum to 1,
    optionally smoothed with a Gaussian-shaped kernel, summed along ``axis``,
    upscaled by ``binsize / 10`` and cropped to the parent projection box.
    The maps are drawn in rows ``0 .. len(datasets) - 1`` of column ``col``;
    the plain median goes in row ``-2`` and the median weighted by each
    dataset's ``num_cells()`` in row ``-1``.

    Parameters
    ----------
    parent_projection_coords : pair of 3-tuples
        ``((x_min, y_min, z_min), (x_max, y_max, z_max))`` used to crop the
        upscaled maps (presumably in atlas-resolution pixels -- confirm
        against ``btp.get_projection``).
    axs1 : 2-D array of matplotlib axes
        Indexed as ``axs1[row, col]``.
    col : int
        Column of ``axs1`` to draw into.
    channel : str
        Channel key passed to ``d.cell_coords`` / ``bt._get_cells_in``.
    datasets : sequence
        Dataset objects; nothing is drawn when it is empty.
    area_num : int or None
        Area whose points (together with its children) are plotted; ``None``
        plots every point in ``d.cell_coords[channel]``.
    axis : int
        Axis of the 3-D histogram that is summed away (0, 1 or 2).
    binsize : number
        Bin size handed to ``btp.get_bins``; also sets the upscaling factor.
    sigma : float or None
        Width of the smoothing kernel; ``None`` disables smoothing.
    cmap
        Colormap passed to ``imshow``.
    """
    if len(datasets) == 0:
        # An empty group would otherwise reach imshow with a scalar NaN median.
        return

    # Everything below up to the dataset loop is identical for every dataset.
    if area_num is None:
        areas = None
    else:
        parent, children = bt.children_from(area_num, depth=0)
        areas = [parent] + children
    bins = tuple(btp.get_bins(dim, binsize) for dim in range(3))

    kernel = None
    if sigma is not None:  # 3D smooth; sigma = width of kernel
        offsets = np.arange(-3, 4, 1)  # symmetric around 0 so the kernel is centred
        xx, yy, zz = np.meshgrid(offsets, offsets, offsets)
        # NOTE(review): the kernel is not normalised, so smoothed per-dataset
        # maps no longer sum to 1; only the median panels are renormalised.
        kernel = np.exp(-(xx**2 + yy**2 + zz**2) / (2 * sigma**2))

    atlas_res = 10
    scale = int(binsize / atlas_res)  # factor that brings bins up to atlas resolution

    # Crop window for the two axes that survive the projection.
    (px_min, py_min, pz_min), (px_max, py_max, pz_max) = parent_projection_coords
    if axis == 2:
        crop = (slice(px_min, px_max), slice(py_min, py_max))
    elif axis == 1:
        crop = (slice(px_min, px_max), slice(pz_min, pz_max))
    else:
        crop = (slice(py_min, py_max), slice(pz_min, pz_max))

    hist_list = []
    counts = []
    for d in datasets:
        if areas is None:
            points = np.array(d.cell_coords[channel]).T
        else:
            points = np.array(bt._get_cells_in(areas, d, channel)).T
        # histogramdd needs (N, 3) points, so the first dimension is the point
        # count; keeping it avoids querying the cells again for the annotation.
        counts.append(len(points))

        # `normed` was removed from numpy's histogram functions in NumPy 1.24;
        # the default already returns raw counts.
        hist, _ = np.histogramdd(points, bins=bins, range=((0, 1140), (0, 800), (0, 1320)))

        total = hist.sum()
        if total != 0:
            hist = hist / total  # turn counts into a distribution that sums to 1
        if kernel is not None:
            hist = signal.convolve(hist, kernel, mode='same')
        hist = np.sum(hist, axis=axis)  # sum projection along the viewing axis
        hist = hist.repeat(scale, axis=0).repeat(scale, axis=1)  # multiply up to the atlas resolution
        hist_list.append(hist[crop])
    all_hists = np.array(hist_list)  # one cropped map per dataset

    for i, (d, this_hist, count) in enumerate(zip(datasets, all_hists, counts)):
        ax = axs1[i, col]
        this_hist = this_hist if axis == 0 else this_hist.T  # side-on orientation does not need axis swapping
        ax.set_title(f'{d.group}  -  {d.name}')
        data_type = 'px' if d.fluorescence else 'cells'
        ax.annotate(f'{count} {data_type}', xy=(10, 40))
        _show_with_colorbar(ax, this_hist, cmap)

    # Weighted variant: each map is scaled by that dataset's total count
    # before taking the median.
    weighted_hists = np.array([hist * d.num_cells() for hist, d in zip(all_hists, datasets)])

    summaries = (
        (-2, all_hists, "Median (eq. vmax)"),
        (-1, weighted_hists, "Median weighted (eq. vmax)"),
    )
    for row, hists, title in summaries:
        ax = axs1[row, col]
        _show_with_colorbar(ax, _normalised_median(hists, axis), cmap)
        ax.set_title(title)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

In [None]:
def setup_plot():
    """Create the summary figure and split the datasets into its four columns.

    The columns are, in order: first group / non-fluorescence, second group /
    non-fluorescence, first group / fluorescence, second group / fluorescence,
    where the groups come from ``btp.__get_bt_groups()``. The grid has one row
    per dataset of the largest column plus two extra rows at the bottom;
    per-dataset axes that a shorter column does not need are removed.

    Returns
    -------
    axs1 : 2-D array of matplotlib axes
        Shape ``(max_column_length + 2, 4)``; removed axes stay in the array.
    dsets : list of four lists
        The datasets belonging to each column.
    """
    # Look the groups up once instead of once per dataset per column.
    group_names = btp.__get_bt_groups()

    def _select(group, fluorescent):
        # Equality (not truthiness) on purpose: a dataset whose flag is neither
        # True nor False lands in no column, exactly as before.
        return [d for d in bt.datasets if d.group == group and d.fluorescence == fluorescent]

    dsets_g1 = _select(group_names[0], False)
    dsets_g2 = _select(group_names[1], False)
    dsets_g3 = _select(group_names[0], True)
    dsets_g4 = _select(group_names[1], True)
    columns = [dsets_g1, dsets_g2, dsets_g3, dsets_g4]

    column_sizes = [len(column) for column in columns]
    max_dset_len = max(column_sizes)
    f, axs1 = plt.subplots(max_dset_len+2, 4, figsize=(14,(max_dset_len*2)+4))

    # Remove the per-dataset rows that shorter columns do not use; the last
    # two rows of every column are kept.
    for icol, size in enumerate(column_sizes):
        for irow in range(size, max_dset_len):
            f.delaxes(axs1[irow, icol])

    f.tight_layout(rect=[0, 0, 1, 0.95])  # leave the top 5% free for the headings
    f.suptitle("Individual datasets")
    plt.text(x=0.25, y=0.96, s="Retrograde", fontsize=15, ha="center", transform=f.transFigure)
    plt.text(x=0.75, y=0.96, s= "Anterograde", fontsize=15, ha="center", transform=f.transFigure)
    for ax in axs1.flat:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    return axs1, columns

In [None]:
def equalise_vlims(ax1, ax2):
    """Give the first image on each of two axes the same colour limits.

    Both images get ``vmin=0`` and a shared ``vmax`` equal to the mean of
    their current upper limits, so the image with the higher maximum will
    saturate at the top of the colour scale. If either axes holds no image
    the function returns without changing anything.
    """
    images1 = ax1.get_images()
    images2 = ax2.get_images()
    if len(images1) == 0 or len(images2) == 0:
        # Nothing was drawn on at least one axes; there is nothing to match.
        return

    im1, im2 = images1[0], images2[0]
    _, vmax1 = im1.get_clim()
    _, vmax2 = im2.get_clim()
    vmax = (vmax1 + vmax2) / 2
    for im in (im1, im2):
        im.set_clim(vmin=0, vmax=vmax)

In [None]:
def region_pmap_summary(num, binsize, sigma, axis=2):
    """Build the four-column summary figure for area ``num``.

    Creates the figure with ``setup_plot``, fetches the crop box for the area
    from ``btp.get_projection`` (called with ``axis=2-axis``), draws channel
    ``'r'`` of each dataset column with ``plot_binned_average`` (columns
    alternate between ``cmap_reds`` and ``cmap_blue``) and finally matches the
    colour limits of the two bottom rows within each pair of columns
    (0 with 1, 2 with 3).
    """
    axs1, dsets = setup_plot()

    # Only the bounding coordinates are needed; the projection itself is unused.
    _, min_coords, max_coords = btp.get_projection(num, padding=10, axis=2-axis)
    crop_box = (min_coords, max_coords)

    palettes = (cmap_reds, cmap_blue, cmap_reds, cmap_blue)
    for column, palette in enumerate(palettes):
        plot_binned_average(crop_box, axs1, column, 'r', dsets[column], num, axis, binsize, sigma, cmap=palette)

    # Rows -2 and -1 hold the median and weighted-median panels.
    for left in (0, 2):
        for row in (-2, -1):
            equalise_vlims(axs1[row, left], axs1[row, left + 1])

In [None]:
# Area identifiers taken from the index of bt.area_indexes, as a plain Python list.
indexes = bt.area_indexes.index.tolist()
# Positions into `indexes`; the driver loop iterates over these to process every area.
rng = range(len(indexes))

In [None]:
def save_summary_pmaps(area_num):
    """Render and save the three summary figures for one area as PDFs.

    Parameters
    ----------
    area_num : int
        Position in the module-level ``indexes`` list (not the area id
        itself); the area id is looked up from it.

    The area name has ``/`` and spaces replaced by ``_`` to form the file
    stem. One figure is produced per projection axis and saved through
    ``btf.save`` with the suffixes ``crnl`` (axis 2), ``hriz`` (axis 1) and
    ``sgtl`` (axis 0) -- presumably coronal, horizontal and sagittal views.
    """
    area = indexes[area_num]
    name = bt.get_area_info(area)[0][0]
    fname = name.replace('/', '_').replace(' ', '_')

    for axis, suffix in ((2, 'crnl'), (1, 'hriz'), (0, 'sgtl')):
        region_pmap_summary(area, binsize=50, sigma=1.5, axis=axis)
        btf.save(f'{fname}_pmaps_{suffix}', as_type='pdf')
        # Each summary creates a new pyplot figure; close it once saved so
        # figures do not pile up when this runs for every area.
        # NOTE(review): assumes btf.save writes the current figure -- confirm.
        plt.close()

In [None]:
# Render and save the summary figures for every area, with a tqdm progress bar.
for i in tqdm(rng):
    save_summary_pmaps(i)