In [1]:
import group_data as gd

In [2]:
def parse_GMR_genotype(genotype):
    """Get the GMR (GAL4) and UAS keys in a given genotype string.

    Parameters
    ----------
    genotype : str
        Genotype description. It is parsed only when it contains both the
        marker 'GMR' and the reporter name 'GCaMP6f'.

    Returns
    -------
    dict or None
        ``{'uas': 'GCaMP6f', 'gal4': <the five characters following 'GMR'>}``
        when both markers are present, otherwise ``None``.
    """
    # Echo the genotype being parsed (kept from the original for progress logging).
    print(genotype)
    gmr_idx = genotype.find('GMR')
    gcamp_idx = genotype.find('GCaMP6f')
    # str.find reports "not found" as -1; index 0 is a legitimate match, so
    # the test must be `< 0` rather than `<= 0`.
    if gmr_idx < 0 or gcamp_idx < 0:
        return None
    # The line identifier is the five characters right after 'GMR'.
    gmr_str = genotype[gmr_idx + 3:gmr_idx + 8]
    return {'uas': 'GCaMP6f', 'gal4': gmr_str}

In [8]:
def get_line_database(line_name,
                      database_path='../mn_expression_matrix_plot/line_database.cpkl'):
    """Load the pickled line database from disk.

    Parameters
    ----------
    line_name : str
        Accepted for call compatibility; the whole database is loaded
        regardless of its value.
    database_path : str, optional
        Location of the pickled database file. Defaults to the path the
        notebook has always used.

    Returns
    -------
    object
        Whatever object was pickled in ``database_path``. ``get_muscle_list``
        indexes it as ``line_database[line_name][key]``.

    Raises
    ------
    IOError
        If ``database_path`` cannot be opened.
    """
    try:
        import cPickle as pickle  # Python 2
    except ImportError:
        import pickle
    # NOTE: unpickling can execute arbitrary code; only point this at
    # trusted, locally generated files.
    # `with` guarantees the handle is closed even if load() raises.
    with open(database_path, 'rb') as f:
        return pickle.load(f)

def get_muscle_list(line_name):
    """Return the sorted keys of ``line_name``'s database entry whose value is > 0."""
    # Entry for this line: a mapping from key (muscle name) to a numeric value.
    expression = get_line_database(line_name)[line_name]
    return sorted(name for name in expression.keys() if expression[name] > 0)

In [4]:
# Bind the first fly of the 'GMR22H05' swarm to the module-level name `fly`.
# NOTE(review): presumably a handle for interactive testing -- confirm it is still needed.
fly = gd.swarms['GMR22H05'].flies[0]

In [None]:
def get_update_list(file_name ='nnls_fits_no_bk_dF_F.cpkl',
                     swarms = gd.swarms,
                     replace = False):
    """Collect the flies whose fit file should be (re)computed.

    Parameters
    ----------
    file_name : str
        Name of the result file looked for inside each fly's ``fly_path``.
    swarms : mapping
        Mapping of swarm name -> swarm; each swarm exposes a ``flies``
        sequence. The default is bound to ``gd.swarms`` when this function
        is defined.
    replace : bool
        If False, only flies that do not yet have ``file_name`` in their
        directory are returned. If True, every fly in ``swarms`` is returned.

    Returns
    -------
    list
        Fly objects, in both modes, so the result can be handed directly to
        ``extract_signals(fly)``. The ``replace`` flag only changes which
        flies are selected, never the element type.
    """
    import os
    update_flylist = list()
    for swarm_name, swarm in swarms.items():
        for fly in swarm.flies:
            try:
                # Honour `file_name` instead of a hard-coded literal.
                # os.path.join works whether or not fly_path ends with a separator.
                if replace or not os.path.exists(os.path.join(fly.fly_path, file_name)):
                    update_flylist.append(fly)
            except Exception as er:
                # Best effort: report the problem and keep scanning the
                # remaining flies.
                print(er)
    return update_flylist

In [None]:
# Queue of entries to process, as returned by get_update_list.
# (The notebook defines get_update_list; no get_update_paths exists here.)
flypaths = get_update_list(swarms = gd.swarms)

In [None]:
#fit to each fly in serial (block) but break up the data to run in parallel within a fly.
# Connect to the IPython.parallel cluster.
# NOTE(review): assumes a cluster is already running -- confirm before Run All.
from IPython.parallel import Client
clients = Client() 
# Blocking mode, per the note above: each fly is fitted in serial.
clients.block = True
# Show the ids of the engines the client can see.
print clients.ids
# `v` is a view over all engines; extract_signals calls v.map on it to
# distribute image chunks.
v = clients[:]

In [None]:
def fit_to_model(imchunk,model, mode = 'pinv',fit_pix_mask = None,baseline = None):
    """Fit every frame of an image chunk to a linear combination of model rows.

    Each frame is converted to dF/F, ``(frame - baseline) / baseline``, and
    flattened to a pixel vector. Non-finite values (e.g. from a zero
    baseline pixel) are set to 0 before fitting, in both fitting modes.

    Parameters
    ----------
    imchunk : array-like, shape (n_frames, height, width)
        Stack of image frames.
    model : array-like, shape (n_components, height * width)
        One flattened component per row.
    mode : str
        ``'nnls'`` fits each frame with ``scipy.optimize.nnls``; any other
        value uses a pseudo-inverse (unconstrained least squares).
    fit_pix_mask : boolean array, shape (height * width,), optional
        When given, only the selected pixels take part in the fit.
    baseline : array-like
        Baseline image, broadcastable against ``imchunk``. Required despite
        the ``None`` default, which only keeps the signature stable.

    Returns
    -------
    numpy.ndarray, shape (n_components, n_frames)
        Fitted coefficient of every component for every frame.

    Raises
    ------
    TypeError
        If ``baseline`` is None.
    """
    # Imports are local so the function is self-contained when it is handed
    # to v.map (see extract_signals) and run outside this namespace.
    import numpy as np
    if baseline is None:
        raise TypeError('fit_to_model requires a baseline image to compute dF/F')
    # Force floating point so the division is a true division for integer images.
    imchunk = np.asarray(imchunk, dtype=float)
    baseline = np.asarray(baseline, dtype=float)
    model = np.asarray(model)
    # Division by zero is expected for dark baseline pixels; the resulting
    # non-finite values are zeroed just below.
    with np.errstate(divide='ignore', invalid='ignore'):
        im_array = (imchunk - baseline) / baseline
    n_pix = im_array.shape[1] * im_array.shape[2]
    im_array = im_array.reshape((-1, n_pix))
    im_array[~np.isfinite(im_array)] = 0
    # Slice model and data once, rather than on every loop iteration.
    if fit_pix_mask is not None:
        fit_model = model[:, fit_pix_mask]
        fit_data = im_array[:, fit_pix_mask]
    else:
        fit_model = model
        fit_data = im_array
    if mode == 'nnls':
        from scipy.optimize import nnls
        design = fit_model.T  # (n_pixels, n_components)
        fits = np.empty((np.shape(model)[0], np.shape(im_array)[0]))
        for i, frame in enumerate(fit_data):
            fits[:, i] = nnls(design, frame)[0]
    else:
        from numpy.linalg import pinv
        print(np.shape(im_array))
        # pinv(model).T is (n_components, n_pixels); the data must be
        # transposed to (n_pixels, n_frames) in the masked and unmasked case alike.
        fits = np.dot(pinv(fit_model).T, fit_data.T)
    return fits

def extract_signals(fly):
    """Fit the muscle model of a fly's GAL4 line to its recorded image stack.

    Parameters
    ----------
    fly : object
        Must provide ``fly_path`` (directory holding 'fly_record.hdf5',
        'basis_fits.cpkl' and 'epoch_data.cpkl') and ``get_genotype()``.

    Returns
    -------
    fits
        The result of ``v.map(fit_to_model, ...)``: one fit result per chunk
        of up to 2000 frames, fitted in 'nnls' mode.
    muscles : list
        Names of the muscles used as model rows, in row order.

    Notes
    -----
    Relies on the module-level names ``v`` (parallel view),
    ``parse_GMR_genotype``, ``get_muscle_list`` and ``fit_to_model``.
    """
    import os
    try:
        import cPickle as pickle  # Python 2
    except ImportError:
        import pickle
    import h5py
    import numpy as np
    import muscle_model as mm

    confocal_model = mm.GeometricModel(filepath = 'model_data.cpkl')
    # NOTE(review): the view is never used below; the call is kept in case
    # constructing it has side effects on the model -- confirm and remove.
    confocal_view = mm.ModelViewMPL(confocal_model)
    fly_path = fly.fly_path
    line_name = parse_GMR_genotype(fly.get_genotype())['gal4']

    # Muscles listed for this line, minus any whose name contains 'DVM',
    # 'DLM' or 'ps'.
    excluded_tags = ('DVM', 'DLM', 'ps')
    muscles = [m for m in get_muscle_list(line_name)
               if not any(tag in m for tag in excluded_tags)]

    # Load the fly's reference frame, used to obtain the muscle masks.
    fly_frame = mm.Frame()
    fly_frame.load(os.path.join(fly_path, 'basis_fits.cpkl'))

    # Index of the baseline epoch; binary mode is the portable way to read
    # pickles, and `with` closes the file even if load() raises.
    with open(os.path.join(fly_path, 'epoch_data.cpkl'), 'rb') as f:
        baseline_range = pickle.load(f)['baseline_F']

    chnk_sz = 2000
    fly_record = h5py.File(os.path.join(fly_path, 'fly_record.hdf5'), 'r')
    try:
        # Image data of the first experiment group returned by .values().
        exp_record = list(fly_record['experiments'].values())[0]
        imgs = exp_record['tiff_data']['images']
        # Masks of all muscles, then one flattened row per selected muscle.
        masks = confocal_model.get_masks(fly_frame, np.shape(imgs[0]))
        model = np.vstack([masks[mask_key].T.ravel().astype(float)
                           for mask_key in muscles])
        # Only pixels covered by at least one selected muscle are fitted.
        fit_pix_mask = np.sum(model, axis=0) > 0
        # Mean image over the baseline epoch.
        baseln = np.mean(imgs[baseline_range], axis = 0)

        num_samps = np.shape(imgs)[0]
        chunks = [slice(x, min(x + chnk_sz, num_samps))
                  for x in range(0, num_samps, chnk_sz)]
        # Copy every chunk into memory so the file can be closed before the
        # parallel map runs.
        img_chunks = [np.array(imgs[chunk]) for chunk in chunks]
    finally:
        fly_record.close()

    # v.map takes one sequence per fit_to_model argument, so the shared
    # arguments are repeated once per chunk.
    n_chunks = len(img_chunks)
    fits = v.map(fit_to_model,
                 img_chunks,
                 [model] * n_chunks,
                 ['nnls'] * n_chunks,
                 [fit_pix_mask] * n_chunks,
                 [baseln] * n_chunks)
    return fits,muscles

In [None]:
import os
import numpy as np
try:
    import cPickle as cpkl  # Python 2
except ImportError:
    import pickle as cpkl

# Fit every queued fly and store the result in its own directory.
# Each entry of `flypaths` must be a fly object: extract_signals(fly) takes a
# single fly, and its `fly_path` decides where the result is written.
for queued_fly in flypaths:
    print(queued_fly.fly_path)
    try:
        fits,muscles = extract_signals(queued_fly)
        # Build the payload before opening the file, so a failure here does
        # not leave an empty result file behind.
        result = {'fits':np.hstack(fits),'muscles':muscles}
        with open(os.path.join(queued_fly.fly_path, 'nnls_fits_no_bk_dF_F.cpkl'),'wb') as f:
            cpkl.dump(result,f)
    except IOError as err:
        # Missing or unreadable files: report and move on to the next fly.
        print(queued_fly.fly_path)
        print(err)