In [1]:
######################
trials = ['011620_135607']  # trial IDs (date_time prefixes of data/<id>_<arena>.hdf5) to process
######################

In [26]:
# scrap trial lists: other trial IDs, kept commented out (copy one into `trials` above to process it)
#     '112119_112752',
#     '112119_134527',
#     '112219_115341', 
#     '112219_135444',
#     '112519_105926',
#     '112519_132930',
#     '112519_160932',
#     '120919_101920',
#     '120919_134003',
#     '120919_162513',
#     '121219_115309'

In [27]:
# imports
import json
import math
import os
import pickle
import sys
import warnings

import h5py
import matplotlib
import matplotlib.colors as colors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import scipy.interpolate
import yaml
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from scipy import signal

warnings.filterwarnings('ignore')
#warnings.filterwarnings(action='once')

# %pylab inline

In [28]:
# TODO: convert to mm (not implemented yet; this cell is intentionally empty)


In [29]:
# DEF convert trial schedule to seconds, and get total trial duration
def convert_schedule_to_sec(trial_schedule):
    """Convert a trial schedule's phase durations from minutes to seconds.

    Parameters
    ----------
    trial_schedule : sequence of dict
        Phases indexable by position 0..n-1, each with a 'duration' (minutes)
        and a 'state'.

    Returns
    -------
    tuple
        (schedule_in_sec, t_total) where schedule_in_sec maps
        phase index -> {'duration': seconds, 'state': state}, and t_total is the
        summed duration in seconds (0 for an empty schedule).
    """
    schedule_in_sec = {
        idx: {
            'duration': trial_schedule[idx]['duration'] * 60.0,
            'state': trial_schedule[idx]['state'],
        }
        for idx in range(len(trial_schedule))
    }
    # Sum in phase order, starting from int 0, exactly like a running total.
    t_total = sum(phase['duration'] for phase in schedule_in_sec.values())
    return schedule_in_sec, t_total

In [30]:
# DEF radial distance
def create_radial_array(center, x, y):
    """Euclidean distance of each (x, y) point from `center`.

    Parameters
    ----------
    center : sequence
        (x_center, y_center); only the first two elements are used.
    x, y : array_like
        Point coordinates (lists or numpy arrays; broadcast against each other).

    Returns
    -------
    numpy.ndarray
        hypot(x - x_center, y - y_center), same shape as the broadcast of x and y.
    """
    x_center, y_center = center[0], center[1]
    # asarray lets plain Python lists work (list - scalar would raise TypeError)
    return np.hypot(np.asarray(x) - x_center, np.asarray(y) - y_center)

In [31]:
# DEF go through trial and create dictionary with experiment data
def _read_arena(exp, trial):
    """Build one arena's data dict from an open h5py file; raises if a field is missing."""
    # PyYAML >= 6 requires an explicit Loader for yaml.load (the old call raised
    # TypeError, which the previous bare `except` silently swallowed). FullLoader
    # matches the old PyYAML 5 default; these attrs come from our own recordings
    # (trusted) -- switch to yaml.safe_load if that ever changes.
    param = yaml.load(exp.attrs['param'], Loader=yaml.FullLoader)
    schedule, t_total = convert_schedule_to_sec(param['trial_schedule'])
    classifier_cfg = yaml.load(exp.attrs['classifier'], Loader=yaml.FullLoader)

    entry = {
        # placeholders, populated by later processing steps (e.g. interpolate_data)
        'interpolation': {},  # frequency of subsampling
        'times': {},  # interpolated time grid
        'x': {},
        'y': {},
        'radial': {},
        'speed': {},
        # trial data
        'all_param': exp.attrs['param'],  # raw trial-level param string, kept just in case
        'date_time': trial,
        'trial_schedule': schedule,
        't_total': t_total,  # total time (sec)
        'region_height': param['regions']['height'],
        'region_width': param['regions']['width'],
        # arena-specific metadata
        'bg_image': np.asarray(exp['bg_image']),
        'fly': exp.attrs['fly'],
        'arena': exp.attrs['index'] + 1,  # stored index + 1
        'center': exp.attrs['center'],
        'led_policy': exp.attrs['led_policy'],
        'classifier_type': exp.attrs['classifier'],  # raw classifier YAML string
        'param': yaml.load(classifier_cfg['param'], Loader=yaml.FullLoader),  # classifier params
        'notes': exp.attrs['notes'],
    }
    # arena-specific raw timeseries (not yet filtered or interpolated)
    for key in ('elapsed_t', 'object_found', 'led_enabled', 'classifier',
                'fly_x', 'fly_y', 'ball_x', 'ball_y', 'led'):
        entry[key] = np.asarray(exp[key])
    return entry


def extract_trial_data(trial, data_dir='data', max_arenas=100):
    """Load all per-arena HDF5 files of one trial into a dict keyed by arena index.

    Probes ``<data_dir>/<trial>_<arena>.hdf5`` for arena in ``range(max_arenas)``.
    Arenas without a file are skipped. An arena whose file exists but cannot be
    read is reported on stderr and left out. Entries are only stored once fully
    read, so no half-populated entries remain.

    Parameters
    ----------
    trial : str
        Trial identifier used as the file-name prefix (e.g. '011620_135607').
    data_dir : str
        Directory holding the HDF5 files (default 'data').
    max_arenas : int
        Number of arena indices to probe (default 100).

    Returns
    -------
    dict
        {arena_index: arena_dict} with metadata and raw timeseries arrays.
    """
    data = {}
    for arena in range(max_arenas):
        path = os.path.join(data_dir, trial + '_' + str(arena) + '.hdf5')
        if not os.path.isfile(path):
            continue  # no recording for this arena index
        try:
            # `with` closes the HDF5 file even if reading fails part-way
            with h5py.File(path, 'r') as exp:
                data[arena] = _read_arena(exp, trial)
        except Exception as err:
            # Report rather than silently drop; warnings are globally ignored in
            # this notebook, so write to stderr instead of warnings.warn.
            print('extract_trial_data: skipping {}: {!r}'.format(path, err), file=sys.stderr)
    return data

In [32]:
# DEF filter data: only indices with object found, cut at trial end
def filter_data(data):
    """Keep only the samples in which the object was found (object_found == 1).

    For each arena, every raw timeseries ('elapsed_t', 'object_found',
    'led_enabled', 'classifier', 'fly_x', 'fly_y', 'ball_x', 'ball_y', 'led')
    is indexed with the positions where object_found == 1. No trimming at the
    trial end ('t_total') is applied. Arenas whose 'object_found' is not a
    numpy array (an unpopulated placeholder) are left untouched.

    Modifies ``data`` in place and returns it.
    """
    timeseries_keys = ('elapsed_t', 'object_found', 'led_enabled', 'classifier',
                       'fly_x', 'fly_y', 'ball_x', 'ball_y', 'led')
    for entry in data.values():
        found = entry.get('object_found')
        if not isinstance(found, np.ndarray):
            continue  # placeholder entry; indexing a dict would raise
        keep = np.flatnonzero(found == 1)
        for key in timeseries_keys:
            entry[key] = entry[key][keep]
    return data

In [33]:
# DEF interpolate data (to make regular time steps)
def interpolate_data(data, Hz, min_samples=100):
    """Resample each arena's timeseries onto a regular time grid.

    Every arena gets 'interpolation' = Hz. Arenas with more than ``min_samples``
    object-found samples are resampled onto
    ``times = linspace(0, floor(elapsed_t[-1]), int(floor(elapsed_t[-1]) * Hz) + 1)``.
    Categorical series ('object_found', 'led_enabled', 'classifier', 'led') use
    nearest-neighbour interpolation. Position series ('fly_x', 'fly_y',
    'ball_x', 'ball_y') use linear interpolation. Arenas with too few samples
    keep their raw arrays and their 'times' entry is not set.

    Parameters
    ----------
    data : dict
        {arena: arena_dict} as produced by extract_trial_data / filter_data.
        Modified in place and returned.
    Hz : int or float
        Resampling frequency (samples per second).
    min_samples : int
        Interpolate only if the number of object-found samples exceeds this.

    Returns
    -------
    dict
        The same ``data`` object.
    """
    categorical = ('object_found', 'led_enabled', 'classifier', 'led')
    continuous = ('fly_x', 'fly_y', 'ball_x', 'ball_y')
    for entry in data.values():
        entry['interpolation'] = Hz

        # Count samples. The old check summed sample *indices* (np.sum(np.where(...)[0])),
        # which exceeds 100 with only ~15 samples.
        n_found = np.count_nonzero(entry['object_found'] == 1)
        if n_found <= min_samples:
            continue

        t_end = math.floor(entry['elapsed_t'][-1])
        entry['times'] = np.linspace(0, t_end, int(t_end * Hz) + 1)

        # Anchor the first raw sample at t=0 so the grid's first point lies inside
        # the interpolation range. Use a copy so the stored raw 'elapsed_t' is kept.
        t_raw = np.array(entry['elapsed_t'], dtype=float)
        t_raw[0] = 0.0

        for key in categorical:
            interpolator = scipy.interpolate.interp1d(t_raw, entry[key], kind='nearest')
            entry[key] = interpolator(entry['times'])
        for key in continuous:
            interpolator = scipy.interpolate.interp1d(t_raw, entry[key])
            entry[key] = interpolator(entry['times'])
    return data

In [34]:
# DEF get speed arrays
def get_speed(data, x_key='fly_x', y_key='fly_y'):
    """Compute per-step speed for every arena from its position timeseries.

    speed[i] = hypot(x[i+1] - x[i], y[i+1] - y[i]) / dt, with
    dt = 1 / data[arena]['interpolation']. The result has one element fewer than
    the position arrays. It is only meaningful once the arena has been resampled
    onto a regular grid at that rate (see interpolate_data).

    Parameters
    ----------
    data : dict
        {arena: arena_dict}; modified in place ('speed' is set) and returned.
    x_key, y_key : str
        Keys of the position arrays (default the fly position 'fly_x'/'fly_y').
    """
    # Previously read 'x'/'y', which extract_trial_data leaves as empty-dict
    # placeholders, so np.diff raised. Positions are stored under 'fly_x'/'fly_y'.
    for entry in data.values():
        dt = 1 / float(entry['interpolation'])
        x = np.asarray(entry[x_key])
        y = np.asarray(entry[y_key])
        entry['speed'] = np.hypot(np.diff(x), np.diff(y)) / dt
    return data

In [35]:
# create pickle for each trial
# For each trial: load raw arenas -> keep object-found samples -> resample to
# 20 Hz, then save as data/<trial>.pickle. `data` keeps the last trial for later cells.
for trial in trials:
    data = interpolate_data(filter_data(extract_trial_data(trial)), Hz=20)
    data_path = 'data/' + trial + '.pickle'
    with open(data_path, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print('pickle created for: {}'.format(trial))

pickle created for: 011620_135607


In [36]:
# get all unique values of key parameters in dataset
# Distinct values of key per-arena parameters in `data` (the last trial processed above).
fly = {arena: entry['fly'] for arena, entry in data.items()}
print('fly:', set(fly.values()))

# 'classifier_type' is the raw classifier YAML string. PyYAML >= 6 requires an
# explicit Loader (plain yaml.load(s) raises TypeError). FullLoader matches the old default.
classifier = {
    arena: yaml.load(entry['classifier_type'], Loader=yaml.FullLoader)['type']
    for arena, entry in data.items()
}
print('classifier:', set(classifier.values()))

fly: {'Gr5a'}
classifier: {'ficfruit_touch'}


In [31]:
# Show the 'fly' attribute recorded for each arena.
for arena_idx, entry in data.items():
    print(arena_idx)
    print(entry['fly'])

0
C-137
1
C-137
2
C-137
3
C-137
4
C-137
5
C-137
6
C-137
7
C-137
8
C-137
9
C-137
10
C-137
11
C-137
12
C-137
13
C-137
14
C-137
15
C-137
16
C-137
17
C-137
18
C-137
19
C-137
20
C-137
21
C-137
22
C-137
23
C-137
24
C-137
25
C-137
26
C-137
27
C-137
28
C-137
29
C-137
30
C-137
31
C-137
32
C-137
33
C-137
34
C-137


In [25]:
# Print arenas whose 'fly' attribute is empty (e.g. never-populated placeholders).
# Iterate over the actual keys: arena indices can have gaps when some arena
# files are missing, so range(len(data)) could skip arenas or raise KeyError.
for a in data:
    if len(data[a]['fly']) == 0:
        print(a)