# Figures 5, S6
This notebook runs our analysis of Tre1 intoxication of *E. coli* by *S. proteamaculans*.

In [None]:
!nvcc --version
!nvidia-smi
%load_ext autoreload
%autoreload 2

import numpy as np
import time, os, sys
from urllib.parse import urlparse
from urllib.parse import urlparse
from cellpose import models, core, io
from skimage import measure
from sklearn.cluster import k_means

use_GPU = core.use_gpu()
print('>>> GPU activated? %d'%use_GPU)

In [None]:
# Plotting setup: dark figure style and high-resolution inline figures.
# NOTE(review): statement order is intentional to keep; activating the inline
# backend (`%matplotlib inline`) can apply its own rc defaults, so moving the
# style/rcParams lines relative to it may change the result — confirm first.
import time, os, sys
import skimage.io
import matplotlib.pyplot as plt
plt.style.use('dark_background')
import matplotlib as mpl
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300  # render inline figures at 300 dpi
from cellpose import utils, io, transforms
# from cupyx import scipy

# np.abs(np.array([1+1j,1]))
# np.asarray(1)
# np.asarray(1)

In [None]:
from pathlib import Path

# Find every channel-1 image ('*c1.tif', the phase images) below the data
# directory and derive the matching channel-2 path ('*c2.tif', used as GFP).
# NOTE(review): rglob order comes from the filesystem and is not sorted; later
# cells index regions by position, so this order is assumed stable — confirm.
home = str(Path.home())
basedir = '/home/kcutler/DataDrive/ecVSsp/'
phase = Path(basedir).rglob('*c1.tif')
pfiles = list(map(str, phase))
gfiles = [p[:-len('1.tif')] + '2.tif' for p in pfiles]

In [None]:
# Read each phase/GFP image pair, reduce multi-channel stacks to their first
# channel, and rescale both with Cellpose's percentile normalisation.
files = pfiles
imgs = list(map(skimage.io.imread, pfiles))
gfps = list(map(skimage.io.imread, gfiles))
nimg = len(imgs)
print(nimg)
from cellpose import transforms
from skimage.filters import gaussian


def localnormalize(im, sigma):
    """Return *im* divided by its own Gaussian blur at scale *sigma*.

    Not used by the pipeline below; kept as an optional GFP flat-field step.
    """
    return im / gaussian(im, sigma)


for k, raw_phase in enumerate(imgs):
    phase_k = transforms.move_min_dim(raw_phase)
    gfp_k = transforms.move_min_dim(gfps[k])
    # After move_min_dim, a 3-D array is assumed to carry channels on the last axis.
    if len(phase_k.shape) > 2:
        imgs[k] = phase_k[:, :, 0]
        gfps[k] = gfp_k[:, :, 0]
    imgs[k] = transforms.normalize99(imgs[k], omni=True)
    gfps[k] = transforms.normalize99(gfps[k], omni=True)

In [None]:
# check that everything matches
i = 5
g = gfps[i].copy()
g = g-gaussian(g,15)

fig = plt.figure(figsize=(20,20))
plt.imshow(np.hstack((imgs[i],gfps[i],g)))
# p = plt.hist(g[g>.01],bins=100)
# plt.show()

In [None]:
# Pretrained networks, indexed consistently below: 0 -> Cellpose ('bact'),
# 1 -> Omnipose ('bact_omni').
model_type = ['bact', 'bact_omni']
model = [models.CellposeModel(gpu=use_GPU, model_type=mt) for mt in model_type]
suffix = ['_CP', '_OP']
names = ['Cellpose', 'Omnipose']

In [None]:
# Inputs and loop ranges for the segmentation step.
imglist = list(imgs)  # shallow copy; slice here to segment only a subset
n, N = len(names), len(imgs)
J = range(len(model))

In [None]:
# Per-model result containers. Each entry gets its own list: `[[]]*n` repeats a
# single list object n times, so an in-place change to one model's entry would
# show up in every entry.
masks = [[] for _ in range(n)]
flows = [[] for _ in range(n)]
styles = [[] for _ in range(n)]
d = [[] for _ in range(n)]

In [None]:
# Segmentation switch: set clean = True to re-run both networks on every image.
# With the default False, previously saved masks are read from disk further down.
clean = False
if clean:
    chans = [0, 0]
    for j in J:
        net = model[j]
        # j doubles as the omni flag: 0 -> Cellpose, 1 -> Omnipose.
        result = net.eval(imglist, channels=chans, rescale=None, mask_threshold=-1,
                          transparency=True, flow_threshold=0, omni=j,
                          resample=False, tile=True)
        masks[j], flows[j], styles[j] = result

In [None]:
# Write masks, flows and outlines for each model (only when segmentation was re-run).
if clean:
    savedir = '/home/kcutler/DataDrive/omnipose_paper/Figure 5/'
    # makedirs creates missing parent folders and is a no-op if the folder exists,
    # unlike os.mkdir; it only needs to run once, not once per model.
    os.makedirs(savedir, exist_ok=True)
    for j in J:
        io.save_masks(imglist, masks[j], flows[j], pfiles, suffix=suffix[j], save_flows=1,
                      save_outlines=1, savedir=savedir, in_folders=True, save_txt=False)

## Load masks

In [None]:
# Segmentation was already run: reload the saved Cellpose/Omnipose masks and
# order them to match the phase images in `pfiles`.
def getname(path,suffix=''):
    """Return the file stem of *path* with every occurrence of *suffix* removed."""
    return os.path.splitext(Path(path).name)[0].replace(suffix,'')


def _match_indices(names, targets):
    """Return, for each name in *targets*, the index of its single match in *names*.

    Raises ValueError when a target has no match or several. Either case would
    otherwise make the index array ragged or the wrong length, which silently
    misaligns masks and images or fails with an opaque error.
    """
    positions = {}
    for idx, s in enumerate(names):
        positions.setdefault(s, []).append(idx)
    out = []
    for p in targets:
        hits = positions.get(p, [])
        if len(hits) != 1:
            raise ValueError('expected exactly one mask file for {!r}, found {}'.format(p, len(hits)))
        out.append(hits[0])
    return np.array(out, dtype=int)


pnames = [getname(p) for p in pfiles]

maskdir = '/home/kcutler/DataDrive/omnipose_paper/Figure 5/masks/'
om_files = [str(s) for s in Path(maskdir).rglob('*masks_OP.png')]
cp_files = [str(s) for s in Path(maskdir).rglob('*masks_CP.png')]
om_names = [getname(s,suffix='_cp_masks_OP') for s in om_files]
cp_names = [getname(s,suffix='_cp_masks_CP') for s in cp_files]
om_ind = _match_indices(om_names, pnames)
cp_ind = _match_indices(cp_names, pnames)
om_files = [om_files[i] for i in om_ind]
cp_files = [cp_files[i] for i in cp_ind]

om_masks = [skimage.io.imread(f) for f in om_files]
cp_masks = [skimage.io.imread(f) for f in cp_files]

In [None]:
# StarDist masks, reordered to match the phase images (`pnames`).
maskdir_sd = '/home/kcutler/DataDrive/omnipose_paper/Figure 5/StarDist/'
sd_files = [str(s) for s in Path(maskdir_sd).rglob('*.tif')]
sd_names = [getname(s,suffix='_stardist_masks') for s in sd_files]
sd_ind = []
for pname in pnames:
    # Require exactly one mask per image; a missing or duplicated match would
    # otherwise silently drop/shift masks relative to images or fail opaquely.
    hits = [idx for idx,s in enumerate(sd_names) if s==pname]
    if len(hits) != 1:
        raise ValueError('expected exactly one StarDist mask for {!r}, found {}'.format(pname, len(hits)))
    sd_ind.append(hits[0])
sd_ind = np.array(sd_ind, dtype=int)
sd_files = [sd_files[i] for i in sd_ind]
sd_masks = [skimage.io.imread(f) for f in sd_files]

In [None]:
# MiSiC masks, reordered to match the phase images (`pnames`).
maskdir_ms = '/home/kcutler/DataDrive/omnipose_paper/Figure 5/MiSiC/'
ms_files = [str(s) for s in Path(maskdir_ms).rglob('*.tif')]
ms_names = [getname(s,suffix='_masks') for s in ms_files]
ms_ind = []
for pname in pnames:
    # Require exactly one mask per image; a missing or duplicated match would
    # otherwise silently drop/shift masks relative to images or fail opaquely.
    hits = [idx for idx,s in enumerate(ms_names) if s==pname]
    if len(hits) != 1:
        raise ValueError('expected exactly one MiSiC mask for {!r}, found {}'.format(pname, len(hits)))
    ms_ind.append(hits[0])
ms_ind = np.array(ms_ind, dtype=int)
ms_files = [ms_files[i] for i in ms_ind]
ms_masks = [skimage.io.imread(f) for f in ms_files]

## Define experimental categories 
The image set contains both the initial time point and the 20 hr time point for wild-type and mutant (control) co-cultures. Here we separate the images into four groups and analyse only the two 20 hr groups.

In [None]:
# Collect the four methods' masks in a fixed order; j indexes this order below.
names = ['StarDist','Cellpose','Omnipose','MiSiC']
abbrev =  ['SD','CP','OP','MS']
masks = [sd_masks,cp_masks,om_masks,ms_masks]
J = range(len(masks))
# getname() is the helper defined in the mask-loading cell; no need to redefine it.
img_names = [getname(f) for f in pfiles]

cat1 = ['wt','mut']
cat2 = '20hr'
K = range(len(cat1))
# indices[k][0]: images of strain cat1[k] whose path lacks cat2 (initial time point)
# indices[k][1]: images of strain cat1[k] whose path contains cat2 (20 hr)
# NOTE(review): matching is a substring test on the full path, so directory
# names containing these tokens also match — confirm the paths are unambiguous.
indices = [[[], []] for _ in K]
for k in K:
    indices[k][0] = [i for i, s in enumerate(pfiles) if (cat1[k] in s) and (cat2 not in s)]
    indices[k][1] = [i for i, s in enumerate(pfiles) if (cat1[k] in s) and (cat2 in s)]

## Collect data

In [None]:
# Per-cell measurements for every segmentation method j and strain k, using
# only the 20 hr images (indices[k][1]). For each kept region we record:
#   intensity        - mean GFP intensity over the whole region
#   intensity_eroded - mean GFP intensity over a thinned version of the region
#   area             - region area in pixels
#   regions          - the regionprops object, tagged with J/K/img_index/img_name
intensity = [[[] for k in K] for j in J]
intensity_eroded = [[[] for k in K] for j in J]
area = [[[] for k in K] for j in J]
regions = [[[] for k in K] for j in J]
species = [[[] for k in K] for j in J]

from scipy.ndimage import binary_erosion
from skimage.morphology import thin
for j in J:
    print(names[j])
    for k in K:
        inds = indices[k][1] #[1]->just process 20hr images
        for i in inds: # loop over each group of images 
            regs = measure.regionprops(masks[j][i],intensity_image=gfps[i])
            # NOTE(review): scikit-image documents that regionprops ignores
            # label 0, so regs[0] is presumably the first labelled cell, not the
            # background, and regs[1:] drops it. Left unchanged because the
            # hard-coded region positions used for the figure panels depend on
            # the current ordering — confirm before changing.
            for r in regs[1:]: #no background
                # mask = binary_erosion(r.image,iterations=2)
                # Thin the region's own mask (max_num_iter=6) so the intensity
                # average is taken away from the cell boundary.
                mask = thin(r.image,max_num_iter=6)
                if mask is not None and np.sum(mask)>0: 
                    # Tag the region so it can be traced back to its method,
                    # strain and source image later.
                    r.J = j
                    r.K = k
                    r.img_index = i
                    r.img_name = img_names[i]
                    intensity[j][k].append(r.mean_intensity)
                    intensity_eroded[j][k].append(np.mean(r.intensity_image[mask])) #avoid edge artifacts from bleedover/GFP registration 
                    area[j][k].append(r.area)
                    regions[j][k].append(r)
                    # Sanity check: the per-cell lists must stay aligned.
                    if len(area[j][k])!=len(intensity[j][k]):
                        print(i,j,k)

In [None]:
# Sanity check over every method/strain pair (not just the last j, k left over
# from the previous loop): areas and intensities must have equal lengths.
all(len(area[j][k]) == len(intensity[j][k]) for j in J for k in K)

In [None]:
# Summarise image and cell counts per segmentation method and strain. Echo the
# summary to the notebook and write the same lines to Tre1_stats.txt.
text = ['Tre 1 experiment','image count: {}'.format(len(imgs)),'']

for j in J:
    text += ['{} masks\n'.format(names[j])]
    cellcounts = [len(area[j][k]) for k in K]
    for k in K:
        text+=['\t{} cell count: {}'.format(cat1[k],cellcounts[k])]
    text+=['\ttotal cell count: {}\n'.format(np.sum(cellcounts))]

basedir = '/home/kcutler/DataDrive/omnipose_all/'
with open(os.path.join(basedir,'Tre1_stats.txt'), "w") as text_file:
    # A plain loop instead of a list comprehension used only for print side effects.
    for t in text:
        print(t)
        print(t, file=text_file)

In [None]:
# Convert the per-method, per-strain lists into NumPy arrays:
#   I - thinned-region mean GFP intensity, with NaN replaced by 0
#   A - cell area in pixels (multiply by pixelArea for physical area)
#   R - the matching regionprops objects
J = range(len(masks))
pixelArea = 0.065**2  # area of one pixel; 0.065 is presumably the pixel size in µm — confirm
I = [[np.nan_to_num(intensity_eroded[jj][kk]) for kk in K] for jj in J]
A = [[np.array(area[jj][kk]) for kk in K] for jj in J]
R = [[np.array(regions[jj][kk]) for kk in K] for jj in J]


In [None]:
%matplotlib inline
# len(I[0][0]),len(I[1][0]),len(I[2][0])
# np.argwhere(A[0][1]==A[1][1])
bins=50
# p = plt.hist(I[1],bins=bins) 
j = 2
print(names[j])
fig,ax = plt.subplots()
p2 = plt.hist(I[j],bins=bins,density=True)
ax.set_xscale('log')
# len(intensity[1][1]),len(intensity[0][0])

In [None]:
# Per-strain reference intensity (50th percentile) used to normalise GFP values.
# NOTE(review): zip(J, K) pairs method 0 with strain 0 and method 1 with
# strain 1 (StarDist/wt and Cellpose/mut with the orderings above), so the two
# references come from different segmentations — confirm this is intended.
J = range(len(masks))
Is = [I[pair[0]][pair[1]] for pair in zip(J, K)]
mI = [np.percentile(Is[kk], 50) for kk in K]

In [None]:
# define species per mask
species = [[[] for k in K] for j in J]
labels = [[[] for k in K] for j in J]
        
for j in J:    
    for k in K:
        i = I[j][k].copy()/mI[k]
        X = i.reshape(-1, 1)
        km = k_means(X,4,random_state=42) # k-means binning; km[1] is an index list of bins
        idx = np.argsort(km[0].flatten())
        # inds = np.logical_or(km[1]==idx[0], km[1]==idx[1])
        inds = km[1]==idx[0]
        species[j][k] = inds
        labels[j][k] = [r.label for r in R[j][k]]

## Make Figure 5, S6 segmentation plots

In [None]:
from cellpose import plot, metrics
import omnipose
from omnipose.utils import rescale
import ncolor
from scipy.optimize import linear_sum_assignment
import fastremap
from scipy.ndimage import binary_dilation
# pyplot.get_cmap works across Matplotlib versions; matplotlib.cm.get_cmap was
# deprecated in 3.7 and removed in 3.9.
cmap = plt.get_cmap('plasma')
cmap2 = plt.get_cmap('gray')
outl_col = cmap(0.85)[:3]  # RGB outline colour (alpha dropped)
pad = 10  # margin in pixels around the cell + neighbours crop
bg = 0.5  # phase images are gamma-corrected so the mean background maps to this value
savedir = '/home/kcutler/DataDrive/omnipose_all/tre1/bigcells'
io.check_dir(savedir)

# Figure panels for a few hand-picked Omnipose cells. For each one: crop the
# cell and its neighbours, save phase/GFP crops and every method's outlines, and
# record in cell_error_list the (area, intensity) points of other methods'
# cells that mostly lie inside the Omnipose cell. Per-species ncolor masks are
# also saved.
J = range(len(masks))
# for k in K:
# To plot these points on the graph, I either need to track a bunch of indices,
# or make a supplemental list of areas vs fluorescence (normalized also by mI).
# Because I compressed R, I, and A, I don't think it will be easy to find
# indices. So for each Omnipose long cell, we make a list of corresponding
# area/intensity points to plot.


for k in [0]:
    # big_ind = np.argwhere(A[j][k]>15).flatten()
    
    # Hand-picked positions into R[2][k] (Omnipose regions of strain k). They
    # depend on the file discovery order and region ordering upstream.
    big_ind = [16550,29239,38647,51014,164489,203681,76074,100162,101979] #[5885,5901] #[5885,5901] 100162 
    # big_ind = [big_ind[4]]#4,-2
    big_ind = [big_ind[i] for i in [0,4,6]]#4,-2
    dotcolors = ['#FF0800','#FF903E','#BBBDBF']  # one colour per selected cell
    

    # labels = [r.label for r in R[j][k][species[j][k]]]
    # NOTE(review): j is not assigned in this cell before this line; it holds
    # whatever a previously executed cell left behind (e.g. the last index of a
    # `for j in J` loop), so labels0 may not be the Omnipose labels (j = 2 is
    # only set inside the loop below). labels0 also pools labels from every
    # image of strain k, not just the current one. Confirm.
    labels0 = labels[j][k]
    L = len(big_ind)
    # cell_error_list[method][selected cell] -> list of [area in pixels, GFP intensity]
    cell_error_list = [[[] for l in range(L)] for j in J]
    for idx,i in enumerate(big_ind):
    # for i in [big_ind[0]]:
    # for i in [16557]:
    # for i in range(16540,16557):
        print(i)
        j = 2 #omni 
        reg = R[j][k][i]
        l = reg.label
        print('Cell info',l,reg.area*pixelArea)
        
        # extract data
        mask = masks[j][reg.img_index]
        p = imgs[reg.img_index]
        # Gamma-correct the phase image so the mean over the (eroded) background
        # region maps to bg.
        im = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mask==0)])))
        gf = gfps[reg.img_index]
        ly,lx = im.shape
        
        #put in coordinates of the cell of interest (Omnipose)
        cell_error_list[2][idx] = [[A[j][k][i],I[j][k][i]]] #list with one
        
        #get single mask and all the neighbors too 
        cellmask = mask==l
        touching = np.unique(mask[binary_dilation(cellmask)])
        neighbor_labels = [ll for ll in labels0 if ll in touching]
        # zeros_like of a boolean mask, so this is boolean too; only its nonzero
        # footprint is used below, to compute the crop window.
        neighbor_mask = np.zeros_like(cellmask) 
        for nl in neighbor_labels:
            neighbor_mask[mask==nl] = nl
            
        # Crop window: bounding box of cell + neighbours, padded and clipped to the image.
        y,x = np.nonzero(neighbor_mask)
        y0 = max(np.min(y)-pad,0)
        y1 = min(np.max(y)+pad,ly)
        x0 = max(np.min(x)-pad,0)
        x1 = min(np.max(x)+pad,lx)
            
        img0 = np.stack([im[y0:y1,x0:x1]]*3,axis=-1)
        gfp0 = np.stack([gf[y0:y1,x0:x1]]*3,axis=-1)
        om_m = mask[y0:y1,x0:x1] 
        # bin0 = np.logical_or(om_m==l,om_m==l-1) # capture both ong cells in this FoV
        bin0 = om_m==l
        a = np.count_nonzero(bin0)*pixelArea
        name = 'area_'+str(round(a))+'_cell'+str(i)
        outline = plot.outline_view(img0,om_m,color=outl_col,mode='thick')
        skimage.io.imsave(os.path.join(savedir,name+'_'+abbrev[2]+'_outline'+'.png'),np.uint8(outline*255)) #omnipose
        
        # Relabel the Omnipose crop; remap is used below as original label -> new label.
        mgt, remap = fastremap.renumber(om_m) #take omnipose as ground truth 
        # mgt = om_m.copy()
        for ii in [1,0,-1]: #cellpose, stardist, misic
            print(names[ii])
            mpred, rmp = fastremap.renumber(masks[ii][reg.img_index][y0:y1,x0:x1])
#             iou = metrics._intersection_over_union(mgt, mpred)
#             th = 0
#             n_min = min(iou.shape[0], iou.shape[1])
#             costs = -(iou >= th).astype(float) - iou / (2*n_min)
#             true_ind, pred_ind = linear_sum_assignment(costs)
#             miou = iou[true_ind, pred_ind]
            ind = remap[l]
#             # ind = l
#             # s_iou = iou[ind]
#             # hits = np.argwhere(s_iou>.2)
#             # argpart = np.argpartition(s_iou,3)[:3] # DON"T DO IOU - need overlap code for errors
    
#             # print(ind,true_ind[ind],pred_ind[ind],miou[ind],hits)
            
            
            fov_regs = measure.regionprops(mpred,intensity_image=gf[y0:y1,x0:x1])
            # Presumably a pixel-overlap table (rows: mgt labels, columns: mpred
            # labels); [1:,1:] drops the background row and column.
            ovp = metrics._label_overlap(mgt, mpred)[1:,1:] #throw out columns corresponding to zero  
            # One area per predicted region; assumes renumbered labels are
            # consecutive so areas[b] belongs to predicted label b+1 — confirm.
            areas = np.array([r.area for r in fov_regs])
            # print(len(fov_regs),areas.shape,len(np.unique(mpred)),len(np.unique(mgt)),mgt.shape,mpred.shape,ovp.shape)
            # Fraction of each predicted cell's area shared with each Omnipose cell.
            ovr = ovp / areas[np.newaxis,:] # Overlap Ratio
            # p = plt.figure()
            # plt.imshow(np.hstack((utils.rescale(ovp),utils.rescale(ovr>.5))))
            # plt.imshow(utils.rescale(ovr))
            # p.show()
            # ind2 = np.argmax(np.sum(ovr>.75,axis=1))
            s_ovr = ovr[ind-1]
            # print(ind,ind2,s_ovr[s_ovr>0])#np.argmax(np.sum(ovr,axis=1))
            # Predicted cells more than half covered by the selected Omnipose cell.
            hits = np.argwhere(s_ovr>.5).flatten() #these columns should correspond to the remapped prediction labels 
            # print(hits)
            hitareas = np.array([areas[h] for h in hits]).flatten()
            print('Areas',hitareas)


            errors = []
            hitmask = np.zeros_like(om_m) # for debugging
            # for fov in fov_regs:
            #     lb = fov.label
            #     if lb in hits:
            #         # print('sffd',lb in np.unique(mpred))
            #         hitmask[mpred==lb+1] = lb+1 #label indexes are offset from real labels 
            #         a = fov.area
            #         i = fov.mean_intensity
            #         print('Point',a,i)
            #         print('isinthere',a in hitareas)
            #         errors.append([a*pixelArea,i/mI[k]])
            
            hit_regs = [r for r in fov_regs if r.area in hitareas] # label selection does not work; area does, but I worry about this (could over-select)
            for fov in hit_regs:
                
                # sr,sc = reg.slice
                # y0 = max(sr.start-pad,0)
                # y1 = min(sr.stop+pad,ly)
                # x0 = max(sc.start-pad-specialpad,0)
                # x1 = min(sc.stop+pad,lx)
                coords = fov.coords
                y,x = zip(*(p for p in coords))
                hitmask[y,x] = fov.label #label indexes are offset from real labels 
                a_ = fov.area
                i_ = fov.mean_intensity
                # print('Point',a,i)
                # print('isinthere',a in hitareas)
                # Raw pixel area and whole-region mean GFP (not thinned, not normalised).
                errors.append([a_,i_])
            # for lb in hits
       
            # for h in hits:
            #     hitmask[mpred==h+1] = h+1

            cell_error_list[ii][idx] = errors
            
            # if np.any(hitmask):
            #     p = plt.figure(figsize=(10,10))
            #     h = plot.outline_view(img0,hitmask)
            #     ha = plot.outline_view(img0,mpred)
            #     t = plot.outline_view(img0,mgt==ind)
            #     plt.imshow(np.hstack((h,ha,t)))
            #     plt.axis('off')
            #     plt.show()
            
            
            outline = plot.outline_view(img0,mpred,color=outl_col,mode='thick')
            skimage.io.imsave(os.path.join(savedir,name+'_'+abbrev[ii]+'_outline'+'.png'),np.uint8(outline*255))
            
            

        skimage.io.imsave(os.path.join(savedir,name+'_phase'+'.png'),np.uint8(img0*255))
        skimage.io.imsave(os.path.join(savedir,name+'_gfp'+'.png'),np.uint8(gfp0*255))
        
        
        ec_masks = [np.zeros_like(om_m)]*3  # unused below; the three entries are the same array
        species_name = ['sp','ec']

        if 1:
            S = range(len(species_name))
            species_masks = [[np.zeros_like(om_m) for j in J] for s in S] # want to separate Ec and Sp masks
            for j in J:
                print(names[j])
                m = masks[j][reg.img_index][y0:y1,x0:x1]
                # NOTE(review): i is a position in Omnipose's list R[2][k]; for
                # other j, R[j][k][i] is an unrelated region (possibly from a
                # different image, or out of range), so reg0.img_index may differ
                # from reg.img_index used for m above — likely should be reg; confirm.
                reg0 = R[j][k][i]
                # as of now, the plotting in the cell below needs to be run first for species partitioning...
                #need to pull that out to run first
                
                # species[j][k] is boolean, so s = 0 selects cells outside the
                # flagged k-means cluster and s = 1 selects cells inside it.
                labels1 = [[r.label for r in R[j][k][species[j][k]==s] if r.img_index == reg0.img_index and r.label in np.unique(m)] for s in S]
                for s in S:
                    for l in labels1[s]:
                        species_masks[s][j][m==l] = 1
                    
                    nc = ncolor.label(m*species_masks[s][j])
                    print(species_name[s])
                    io.imsave(os.path.join(savedir,name+'_'+names[j]+'_mask_'+species_name[s]+'.png'),np.uint8(cmap2(1-rescale(nc))[:,:,:3]*255))
                    plt.imshow(nc)
                    plt.show()
            # outline = plot.outline_view(img0,m,color=outl_col,mode='thick')
            # io.imsave(os.path.join(savedir,name+'_'+names[j]+'_outline_'+species_name[s]+'.png'),np.uint8(outline*255))
        

In [None]:
%matplotlib inline
p = plt.figure()
J = len(masks)
for j in range(J):
    errors = cell_error_list[j]
    if errors:
        for l in range(L):
            pts = errors[l]
            if pts: #if not empty
                print(len(pts))
                x, y = zip(*(p for p in pts))
                plt.scatter(np.array(x)*pixelArea,y,facecolor=dotcolors[l],edgecolor=None)
# cell_error_list[0]
# p.show()

## Make Figure 5, S6 scatter plots

In [None]:
%matplotlib inline
# Figure 5 / S6 scatter plots. For each segmentation method and strain, plot
# cell area (pixels * pixelArea) against GFP intensity normalised by mI[k],
# coloured by a 4-cluster k-means split, and save one PDF per panel.
# Statement order matters here (rcParams reset, style, show-before-save,
# bounding-box adjustment), so the code is left as-is.

labelsize = 9
markersize = 5
mpl.rcParams.update(mpl.rcParamsDefault)
darkmode = 1
if darkmode:
    plt.style.use('dark_background')
    axcol = 'k'
    background_color = [0,0,0,0]  # fully transparent figure background
    suffix = '_dark_mode'
else:
    mpl.rcParams.update(mpl.rcParamsDefault)
    axcol = 'k'
    background_color = np.array([1,1,1,1])
    suffix = ''
    
# When True, overlay the highlighted cells' points stored in cell_error_list (strain 0 only).
do_errors = False
if not do_errors:
    suffix+='_no_error_cell_colors'
    
J = len(masks)
# sz = 3.25
sz = 1  # panel size unit; each panel is 3*sz by sz inches
xmax = 30
# fig,axs = plt.subplots(J,2,figsize=(len(names)*sz,2*sz))
# colors = [[['#8856a7','darkred'],['darkgray','dimgray']],
#           [['limegreen','darkgreen'],['darkgray','dimgray']],
#           [['orange','darkorange'],['darkgray','dimgray']],
#          ]
# colors[k] = [colour for cells outside the flagged cluster, colour for cells inside it]
colors = [['#8BC53F','#58595B'],['#009345','#58595B']]
# colors = ['#009345','#58595B']
# cmap = mpl.cm.get_cmap('Greys')
# cmap2 = mpl.cm.get_cmap('Greens')

savedir = '/home/kcutler/DataDrive/omnipose_all/Tre1/plots/'

io.check_dir(savedir)



for j in range(J):
    print(names[j])
    errors = cell_error_list[j]

    for k in K:
# for j in [1]:
#     for k in [1]:
        # ax = axs[j][k]
        fig,ax = plt.subplots(figsize=(3*sz,sz))
        fig.patch.set_facecolor(background_color)

        a = A[j][k].copy()*pixelArea
        i = I[j][k].copy()/mI[k]
        # X = np.log(i).reshape(-1, 1)
        # Same k-means split as the species cell; overwrites species[j][k].
        X = i.reshape(-1, 1)
        km = k_means(X,4,random_state=42) # k-means binning; km[1] is an index list of bins
        idx = np.argsort(km[0].flatten())
        # inds = np.logical_or(km[1]==idx[0], km[1]==idx[1])
        inds = km[1]==idx[0]
        species[j][k] = inds
        ax.scatter(a[inds],i[inds],s=markersize,c=colors[k][1],rasterized=True)
        ax.scatter(a[~inds],i[~inds],s=markersize,c=colors[k][0],rasterized=True)
        
        # if pts and k==0: #if not empty
        #     x, y = zip(*(p for p in pts))
        #     ax.scatter(x,y,c='r',s=markersize,rasterized=True)
            
        # NOTE(review): for the Omnipose entry the stored y is the thinned-region
        # intensity (I), while for other methods it is regionprops mean_intensity,
        # so the overlaid points are not the same measure — confirm.
        if errors and k==0 and do_errors:
            for l in range(L):
                pts = errors[l]
                if pts: #if not empty
                    x,y = zip(*(p for p in pts))
                    a = np.array(x)*pixelArea
                    i = np.array(y)/mI[k]
                    ax.scatter(a,i,marker='o',facecolors=dotcolors[l], edgecolors='none',s=markersize*3,rasterized=False)
#         scatter_heat(ax,a[inds],i[inds],s=markersize,cmap=cmap)
#         scatter_heat(ax,a[~inds],i[~inds],s=markersize,cmap=cmap2)
#         density_scatter(a[~inds],i[~inds],ax,s=markersize,cmap=cmap2)
#         density_scatter(a[inds],i[inds],ax,s=markersize,cmap=cmap)

        ax.tick_params(axis='both', which='major', labelsize=labelsize,
                       length=3, direction="out",colors=axcol,bottom=True,left=True)
#         ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_ylim(.01,10)
        ax.set_xlim(0,xmax) #110 covers all
        ax.tick_params(which='both', colors=axcol)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_color(axcol)
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_color(axcol)
        
        ax.patch.set_alpha(0.0)
        # ax.set_title(names[j] + ' ' + cat1[k]+' '+str(a.size)+' %.2f'%(np.sum(inds)/a.size*100)) #+str(j)+str(k)
        # ax.set_title(names[j] + ' ' + cat1[k] ) #+str(j)+str(k)
        
#         ax.set_axis_off()
#         fig.add_axes(ax)
        plt.show()
        # Padding around the tight bounding box; reuses the name a (areas above).
        a = 0
        tight_bbox_raw = ax.get_tightbbox(fig.canvas.get_renderer())
        tight_bbox_raw._points+=[[-a,-a],[a,a]]
        # Scale by 1/dpi so the box can be passed as bbox_inches.
        tight_bbox = mpl.transforms.TransformedBbox(tight_bbox_raw, mpl.transforms.Affine2D().scale(1./fig.dpi))
        fig.savefig(os.path.join(savedir,names[j]+'_'+cat1[k]+suffix+'.pdf'),bbox_inches=tight_bbox, dpi=1000)
#         ax.set_axis_on()
# plt.subplots_adjust(hspace = .7,wspace = 0.25)
# plt.show()
   