In [1]:
from MERFISH_Objects.FISHData import *
import os
from analysis_scripts.classify import *
import random
from tqdm import tqdm
from scipy.spatial.distance import cdist
import multiprocessing
import sys

In [2]:
def populate_decoys(pos, nbits, iterations, name, seed=None):
    """Generate unique random binary decoy codewords with a fixed number of on-bits.

    Each iteration shuffles a template holding ``pos`` ones and
    ``nbits - pos`` zeros. Duplicates are removed afterwards, so at most
    ``iterations`` distinct decoys are returned.

    Parameters
    ----------
    pos : int
        Number of bits set to 1 in every decoy (must satisfy 0 <= pos <= nbits).
    nbits : int
        Codeword length.
    iterations : int
        Number of random shuffles to draw.
    name : str
        Prefix for the decoy names (``name + '_' + index``).
    seed : int or None, optional
        Seed for a private ``random.Random`` instance, for reproducible decoys.
        ``None`` (default) keeps using the global ``random`` module state.

    Returns
    -------
    decoys_array : np.ndarray of int, shape (n_decoys, nbits)
        Decoy codewords, rows ordered lexicographically by their bit strings.
    decoy_names : list of str
        One name per row of ``decoys_array``.

    Raises
    ------
    ValueError
        If ``pos`` is outside ``[0, nbits]``.
    """
    if not 0 <= pos <= nbits:
        # str.zfill never truncates, so pos > nbits would silently produce
        # codewords longer than nbits.
        raise ValueError('pos must be between 0 and nbits (got pos=%d, nbits=%d)' % (pos, nbits))
    rng = random if seed is None else random.Random(seed)
    base = list((pos * "1").zfill(nbits))
    decoys = []
    for _ in tqdm(range(iterations), total=iterations):
        rng.shuffle(base)
        decoys.append(''.join(base))
    # np.unique both de-duplicates and sorts the bit strings.
    decoys = list(np.unique(decoys))
    decoy_names = [name + '_' + str(i) for i in range(len(decoys))]
    if decoys:
        decoys_array = np.array([[int(bit) for bit in decoy] for decoy in decoys])
    else:
        # Keep a 2-D shape so the result can still be concatenated with a codebook.
        decoys_array = np.zeros((0, nbits), dtype=int)
    return decoys_array, decoy_names

In [3]:
def classify_pixels(cstk, normalization_factors, codeword_vectors, rel_peak_thresh=99, chunk_size=4096):
    """Assign each (bright) pixel vector to its nearest codeword by Euclidean distance.

    Parameters
    ----------
    cstk : np.ndarray
        Either a 3-D stack indexed ``cstk[row, col, bit]`` or a 2-D matrix of
        pixel vectors ``(n_pixels, n_bits)``.
    normalization_factors : np.ndarray, shape (n_bits,)
        Per-bit divisors applied to the pixel vectors before comparison.
    codeword_vectors : np.ndarray, shape (n_codewords, n_bits)
        Codebook to compare against.
    rel_peak_thresh : float, optional
        3-D input only: percentile of the per-pixel max intensity; only pixels
        strictly above it are classified.
    chunk_size : int, optional
        Number of pixel vectors compared per ``cdist`` call. Chunking bounds
        memory use. With tens of thousands of codewords (decoys included), the
        full distance matrix can reach many GB. Results do not depend on it.

    Returns
    -------
    dimg, cimg
        3-D input: two ``(rows, cols)`` float images holding the distance to the
        nearest codeword and its index. Unclassified pixels are NaN / -1.
        2-D input: 1-D arrays of the minimum distance and argmin index per row.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive.
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be a positive integer')
    is_image = len(cstk.shape) == 3
    if is_image:
        max_mask = np.max(cstk, axis=2)
        mask = max_mask > np.percentile(max_mask.ravel(), rel_peak_thresh)
        x, y = np.where(mask)
        vectors = cstk[x, y, :]
    else:
        vectors = cstk
    vectors = np.divide(vectors.astype('float32'), normalization_factors.astype('float32'))

    n_vectors = vectors.shape[0]
    min_dist = np.empty(n_vectors, dtype=float)
    best_idx = np.empty(n_vectors, dtype=np.intp)
    # Each column of the distance matrix is independent, so reducing it
    # chunk by chunk gives exactly the same min/argmin as one big cdist.
    for start in range(0, n_vectors, chunk_size):
        stop = start + chunk_size
        d = cdist(codeword_vectors, vectors[start:stop])
        min_dist[start:stop] = np.min(d, axis=0)
        best_idx[start:stop] = np.argmin(d, axis=0)

    if not is_image:
        return min_dist, best_idx
    dimg = np.nan * np.ones((cstk.shape[0], cstk.shape[1]))
    dimg[x, y] = min_dist
    cimg = -1 * np.ones((cstk.shape[0], cstk.shape[1]))
    cimg[x, y] = best_idx
    return dimg, cimg

In [4]:
def parse_classification_image(class_img, cstk, cvectors, genes, zindex, distance_img):
    """Turn a per-pixel codeword classification into one row per connected spot.

    Pixels are grouped with ``label(..., connectivity=1)`` over the shifted
    class image. Each single-class region becomes one spot.

    Parameters
    ----------
    class_img : np.ndarray, 2-D
        Codeword index per pixel, -1 for unclassified pixels (as produced by
        ``classify_pixels``).
    cstk : np.ndarray, 3-D
        Image stack indexed as ``cstk[row, col, bit]``.
    cvectors : np.ndarray, shape (n_codewords, n_bits)
        Codebook; entries > 0 mark the bits that are "on" for a codeword.
    genes : sequence of str
        Name for each codeword row of ``cvectors``.
    zindex :
        Unused; kept for interface compatibility.
    distance_img : np.ndarray, 2-D
        Per-pixel distance to the assigned codeword.

    Returns
    -------
    pd.DataFrame
        Columns ``gene, ssum, centroid, ave, npixels, cword_idx, cword_dist``:
        gene name, sum and mean of the spot's pixel intensities over the
        codeword's on-bits, region centroid, pixel count, codeword index and
        mean codeword distance over the spot.
    """
    # Shift so unclassified pixels (-1) become background label 0. Use int32
    # rather than uint16 so codebooks with more than 65534 entries cannot
    # silently wrap around.
    class_labels = (class_img + 1).astype('int32')
    label2d = label(class_labels, connectivity=1)
    gene_call_rows = []
    n_multiclass = 0
    for prop in regionprops(label2d, class_labels):
        classes = list(set(prop.intensity_image.flatten()) - {0})
        if len(classes) == 0:
            print('Label with no classes.')
            continue
        if len(classes) > 1:
            # skimage's label() groups equal-valued neighbours, so this should not
            # occur. Count and skip instead of stopping in a debugger, which would
            # stall multiprocessing workers.
            n_multiclass += 1
            continue
        codeword_idx = int(classes[0]) - 1
        gene = genes[codeword_idx]
        coords = prop.coords
        rows, cols = coords[:, 0], coords[:, 1]
        # Intensities of this spot's pixels in each "on" bit of its codeword.
        bits = np.where(cvectors[codeword_idx] > 0)[0]
        spot_pixel_values = [cstk[rows, cols, b] for b in bits]
        spot_sum = np.sum(spot_pixel_values)
        spot_mean = np.mean(spot_pixel_values)
        spot_distance = np.mean(distance_img[rows, cols])
        # No intensity threshold is applied here; every single-class region is kept.
        gene_call_rows.append([gene, spot_sum, prop.centroid, spot_mean,
                               len(coords), codeword_idx, spot_distance])
    if n_multiclass:
        print('Skipped %d multi-class labels.' % n_multiclass)
    return pd.DataFrame(gene_call_rows,
                        columns=['gene', 'ssum', 'centroid', 'ave', 'npixels', 'cword_idx', 'cword_dist'])

In [5]:
def load_codestack(fishdata, bitmap, dataset, posname, zindex, image_shape=(2048, 2048)):
    """Assemble one z-plane's image stack, one channel per codeword bit.

    Parameters
    ----------
    fishdata : FISHData
        Object whose ``load_data('image', ...)`` returns an image or ``None``.
    bitmap : sequence of (seq, hybe, channel)
        One entry per bit; ``seq`` is ignored.
    dataset, posname, zindex :
        Passed through to ``fishdata.load_data``.
    image_shape : tuple of int, optional
        ``(rows, cols)`` of each image; defaults to the previous hard-coded
        2048 x 2048.

    Returns
    -------
    np.ndarray, shape image_shape + (len(bitmap),)
        Float stack. Bits whose image could not be loaded (``None``) stay zero.
    """
    cstk = np.zeros(tuple(image_shape) + (len(bitmap),))
    for bit_idx, (_, hybe, channel) in enumerate(bitmap):
        image = fishdata.load_data('image', dataset=dataset, posname=posname,
                                   hybe=hybe, channel=channel, zindex=zindex)
        if image is not None:
            cstk[:, :, bit_idx] = image
    return cstk

In [6]:
def classify(data):
    """Classify pixels and save spot calls for one (position, z-plane) task.

    Used as a ``multiprocessing`` worker: every input comes from ``data``.
    Failures during image processing are reported and swallowed, so one bad
    tile does not abort the whole pool.

    Parameters
    ----------
    data : dict
        Required keys: ``fishdata_path``, ``cword_config`` (importable module
        exposing ``bitmap`` and ``nbits``), ``dataset``, ``posname``, ``zindex``
        and ``rel_peak_thresh``. Optional key: ``normalization_factor``
        (default 1000), the divisor applied to every bit's intensity.

    Returns
    -------
    None
        Non-decoy spot calls are written with
        ``fishdata.add_and_save_data(..., 'spotcalls', ...)``.
    """
    import importlib  # imported here: the notebook's import cell does not import it explicitly
    import traceback

    fishdata_path = data['fishdata_path']
    fishdata = FISHData(fishdata_path)
    seqfish_config = importlib.import_module(data['cword_config'])
    bitmap = seqfish_config.bitmap
    nbits = seqfish_config.nbits
    dataset = data['dataset']
    posname = data['posname']
    zindex = data['zindex']
    rel_peak_thresh = data['rel_peak_thresh']
    # Codebook (genes + decoys) saved to disk by classify_wrapper.
    codeword_vectors = np.load(os.path.join(fishdata_path, 'codeword_vectors.npy'))
    genes = np.load(os.path.join(fishdata_path, 'genes.npy'))
    normalization_factors = data.get('normalization_factor', 1000) * np.ones(nbits)
    try:
        cstk = load_codestack(fishdata, bitmap, dataset, posname, zindex)
        dimg, cimg = classify_pixels(cstk, normalization_factors, codeword_vectors,
                                     rel_peak_thresh=rel_peak_thresh)
        df = parse_classification_image(cimg, cstk, codeword_vectors, genes, zindex, dimg)
        df['posname'] = posname
        # Decoy codewords only serve as competitors during classification; drop their calls.
        good_genes = [g for g in genes if 'decoy' not in g]
        df = df[df['gene'].isin(good_genes)]
        fishdata.add_and_save_data(df, 'spotcalls', dataset=dataset, posname=posname,
                                   hybe='all', channel='all', zindex=zindex)
    except Exception as e:
        print(posname, zindex, 'Failed')
        print(e)
        # Keep the full traceback; the message alone is rarely enough to debug a long run.
        traceback.print_exc()
    return None


In [7]:
def _build_codebook(seqfish_config):
    """Return (codeword_vectors, genes): binary gene codewords followed by low and high decoys."""
    try:
        genes = seqfish_config.gids + seqfish_config.bids
    except AttributeError:
        # Configs without a ``bids`` attribute only define ``gids``.
        genes = seqfish_config.gids
    nbits = seqfish_config.nbits
    # The decoys are 0/1 vectors, and downstream code treats entries > 0 as
    # on-bits, so binarise the gene codewords the same way. Dividing by the
    # global max and truncating to int would zero out any codeword whose
    # normalised entries are smaller than that max.
    gene_vectors = (np.asarray(seqfish_config.norm_all_codeword_vectors) > 0).astype(int)
    low_decoys, low_decoy_names = populate_decoys(1, nbits, 100, 'low_decoy')
    high_decoys, high_decoy_names = populate_decoys(8, nbits, 1000000, 'high_decoy')
    codeword_vectors = np.concatenate((gene_vectors, low_decoys, high_decoys))
    genes = list(genes) + low_decoy_names + high_decoy_names
    return codeword_vectors, genes


def _collect_tasks(dataset, fishdata_path, cword_config, rel_peak_thresh):
    """Return one ``classify`` task dict per (flag-'Passed' position, z-plane) found on disk."""
    filenames = [f for f in os.listdir(fishdata_path) if 'Pos' in f]
    posnames = list(np.unique([f.split('_')[2] for f in filenames]))
    zindexes = list(np.unique([f.split('_')[-2] for f in filenames]))
    # Skip z tokens containing 'all' or 'X' (presumably aggregate / non-plane files).
    zindexes = [z for z in zindexes if 'all' not in z and 'X' not in z]
    fishdata = FISHData(fishdata_path)
    tasks = []
    for posname in posnames:
        try:
            flag = fishdata.load_data('flag', dataset=dataset, posname=posname)
        except Exception as e:
            print(posname, 'Flag could not be loaded:', e)
            continue
        if flag != 'Passed':
            print(posname, 'Failed')
            continue
        for zindex in zindexes:
            tasks.append({'dataset': dataset,
                          'posname': posname,
                          'zindex': zindex,
                          'cword_config': cword_config,
                          'rel_peak_thresh': rel_peak_thresh,
                          'fishdata_path': fishdata_path})
    return tasks


def _run_tasks(tasks, ncpu):
    """Run ``classify`` on every task, serially or with a ``multiprocessing`` pool."""
    if ncpu == 1:
        for task in tqdm(tasks, total=len(tasks), desc='Classifying'):
            classify(task)
        return
    with multiprocessing.Pool(ncpu) as pool:
        sys.stdout.flush()
        # Results are saved by classify itself; iterate only to drive the progress bar.
        for _ in tqdm(pool.imap(classify, tasks, chunksize=1), total=len(tasks)):
            pass
        pool.close()
        pool.join()
    sys.stdout.flush()


def classify_wrapper(dataset, fishdata_path, cword_config, rel_peak_thresh=99, ncpu=1):
    """Build the codebook (genes + decoys) and classify every passed position / z-plane.

    Parameters
    ----------
    dataset : str
        Dataset name passed to FISHData.
    fishdata_path : str
        Directory holding the fishdata files. ``codeword_vectors.npy`` and
        ``genes.npy`` are written here.
    cword_config : str
        Importable module name providing ``gids`` (optionally ``bids``),
        ``nbits`` and ``norm_all_codeword_vectors``.
    rel_peak_thresh : float, optional
        Percentile threshold forwarded to ``classify_pixels``.
    ncpu : int, optional
        1 runs serially; otherwise the number of pool workers.
    """
    import importlib  # imported here: the notebook's import cell does not import it explicitly

    seqfish_config = importlib.import_module(cword_config)
    codeword_vectors, genes = _build_codebook(seqfish_config)
    print(codeword_vectors.shape)
    np.save(os.path.join(fishdata_path, 'codeword_vectors.npy'), codeword_vectors)
    np.save(os.path.join(fishdata_path, 'genes.npy'), genes)
    # Release the codebook before forking workers; each worker reloads it from disk.
    del codeword_vectors, genes
    tasks = _collect_tasks(dataset, fishdata_path, cword_config, rel_peak_thresh)
    _run_tasks(tasks, ncpu)

In [8]:
# Run configuration: the A4_2020Jun28 zebra finch dataset, classified with 5 workers.
dataset = 'A4_2020Jun28'
fishdata_path = '/hybedata/Images/Zach/ZebraFinch/' + dataset + '/fishdata/'
cword_config = 'seqfish_config_zebrafinch'
classify_wrapper(dataset, fishdata_path, cword_config, ncpu=5)

100%|██████████| 100/100 [00:00<00:00, 32256.43it/s]
100%|██████████| 1000000/1000000 [00:26<00:00, 37751.30it/s]


(43963, 18)
Pos157 Failed
Pos25 Failed
Pos262 Failed
Pos263 Failed
Pos264 Failed
Pos39 Failed
Pos53 Failed
Pos66 Failed


100%|██████████| 4144/4144 [37:16:11<00:00, 32.38s/it]     
