In [None]:
# Setup cell: working directory, imports, IPython autoreload, and the plotting backend.
import os 
# NOTE(review): hard-coded absolute path; the chdir is presumably needed so the local packages
# imported below (trodes2SS, sungod_util, spykshrk) resolve from this directory — confirm.
path_main = '/home/anna/Src/spykshrk_realtime'
os.chdir(path_main)  # navigate to main spykshrk parent directory 

# automatically reload edited modules before executing code
%load_ext autoreload
%autoreload 2
import holoviews as hv
import loren_frank_data_processing as lfdp
from loren_frank_data_processing import Animal
import numpy as np
# globally silence numpy divide-by-zero and invalid-value (e.g. 0/0) floating-point warnings
np.seterr(divide='ignore',invalid='ignore')
import trodes2SS
import scipy as sp
import sungod_util
# Other containers previously imported here (FlatLinearPosition, SpikeFeatures, Posteriors, EncodeSettings,
# SpikeObservation, DayEpochEvent, DayEpochTimeSeries) are not used in the cells of this notebook.
from spykshrk.franklab.data_containers import RippleTimes, pos_col_format
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPEncoder, OfflinePPDecoder
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer

# render holoviews plots with the bokeh backend
hv.extension('bokeh')

In [None]:
# TODO:
# 1. Improve save options; make the posterior-to-xarray conversion more flexible (make its params optional).


In [None]:
#### Define parameters
rat_name = 'remy'
day = 20      # previously: {'remy':[20], 'gus':[28], 'bernard':[23], 'fievel':[19]}
epoch = 2     # previously: {'remy':[4], 'gus':[2], 'bernard':[4], 'fievel':[2]}

# data source file paths
path_base = '/data2/mcoulter/'
raw_directory = ''.join([path_base, 'raw_data/', rat_name, '/'])
linearization_path = path_base + 'maze_info/'
day_ep = '{}_{}'.format(day, epoch)

# Reference tetrode lists used previously:
#   'remy':    [4,6,9,10,11,12,13,14,15,17,19,20,21,22,23,24,25,26,28,29,30]
#              (for a 45 min runtime on virga use tetrodes 4,9,11,13,15,19,21,23,25,28,30)
#   'gus':     [6,7,8,9,10,11,12,17,18,19,21,24,25,26,27,30]
#   'bernard': [1,2,3,4,5,7,8,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29]
#   'fievel':  [1,2,3,5,6,7,8,9,10,11,12,14,15,16,17,18,19,20,22,23,24,25,27,28,29]

# Tetrode selection: None -> every CA1 tetrode with no dead channels for this day/epoch;
# otherwise give an explicit list of tetrode numbers.
tetlist = None
#tetlist = [4,6]

if tetlist is not None:
    tetrodes = tetlist
else:
    animalinfo = {rat_name: Animal(directory=raw_directory, short_name=rat_name)}
    tetinfo = lfdp.tetrodes.make_tetrode_dataframe(animalinfo)
    # True where the deadchans entry is an ndarray
    tetinfo['ndtype'] = tetinfo['deadchans'].apply(lambda chans: isinstance(chans, np.ndarray))
    # Dead-channel count for the ndarray rows. Other rows get NaN through index alignment,
    # so the ndlength==0 filter below excludes them.
    # NOTE(review): confirm that non-ndarray deadchans entries should indeed be excluded.
    has_array = tetinfo['ndtype'].values
    tetinfo['ndlength'] = tetinfo['deadchans'][has_array].apply(len)
    tetrodes = (tetinfo
                .query('area=="ca1" & ndlength==0 & day==@day & epoch==@epoch')
                .index.get_level_values('tetrode_number')
                .unique()
                .tolist())

# binning / velocity / shuffle / classifier parameters
pos_bin_size = 5
velocity_thresh_for_enc_dec = 4
velocity_buffer = 0

shift_amt_for_shuffle = 0

discrete_tm_val = .99   # for classifier

print(tetrodes)

In [None]:
%%time
# IMPORT and process data

#initialize data importer
datasrc = trodes2SS.TrodesImport(raw_directory, rat_name, [day], 
                       [epoch], tetrodes)
# Import marks
marks = datasrc.import_marks()
print('original length: '+str(marks.shape[0]))
# OPTIONAL: to reduce mark number, can filter by size. Current detection threshold is 100  
marks = trodes2SS.threshold_marks(marks, maxthresh=2000,minthresh=100)
# remove any big negative events (artifacts?)
marks = trodes2SS.threshold_marks_negative(marks, negthresh=-999)
print('after filtering: '+str(marks.shape[0]))

# Import trials
trials = datasrc.import_trials()

# Import raw position
linear_pos_raw = datasrc.import_pos(xy='x')   # this is basically just to pull in speed, will be replaced with linearized below
#posY = datasrc.import_pos(xy='y')          #  OPTIONAL; useful for 2d visualization

# if linearization exists, load it. if not, run the linearization.
lin_output1 = linearization_path + rat_name + '/' + rat_name + '_' + day_ep + '_' + 'linearized_distance.npy'
if os.path.exists(lin_output1) == False:
    print('Linearization result doesnt exist. Doing linearization calculation')
    sungod_util.run_linearization_routine(rat_name, day, epoch, linearization_path, raw_directory, gap_size=20)
else: 
    print('Linearization found. Loading it')
    lin_output2 = linearization_path + rat_name + '/' + rat_name + '_' + day_ep + '_' + 'linearized_track_segments.npy'
    linear_pos_raw['linpos_flat'] = np.load(lin_output1)   #replace x pos with linerized 
    track_segment_ids = np.load(lin_output2)
    
# Import ripples
rips_tmp = datasrc.import_rips(linear_pos_raw, velthresh=4) 
rips = RippleTimes.create_default(rips_tmp,1)  # cast to rippletimes obj
print('Rips less than velocity thresh: '+str(len(rips)))
# generate boundary definitions of each segment
arm_coords, _ = sungod_util.define_segment_coordinates(linear_pos_raw, track_segment_ids)  # optional addition output of all occupied positions (not just bounds)

#bin linear position 
binned_linear_pos, binned_arm_coords, pos_bins = sungod_util.bin_position_data(linear_pos_raw, arm_coords, pos_bin_size)

# calculate bin coverage based on determined binned arm bounds   TO DO: prevent the annnoying "copy of a slice" error [prob need .values rather than a whole column]
pos_bin_delta = sungod_util.define_pos_bin_delta(binned_arm_coords, pos_bins, linear_pos_raw, pos_bin_size)

max_pos = binned_arm_coords[-1][-1]+1

In [None]:
# decide what to use as encoding and decoding data
# NOTE(review): this reassigns marks/binned_linear_pos, so re-running the cell without re-running
# the import cell operates on already-labelled data — confirm assign_enc_dec_set_by_velocity tolerates that.
marks, binned_linear_pos = sungod_util.assign_enc_dec_set_by_velocity(binned_linear_pos, marks, velocity_thresh_for_enc_dec, velocity_buffer)

# rearrange data by trials
# NOTE(review): the trial order is random and no seed is set in this notebook; results may differ between runs.
pos_reordered, marks_reordered, order = sungod_util.reorder_data_by_random_trial_order(trials, binned_linear_pos, marks)

# Split marks by the encoding_set flag, then drop the flag column so it doesn't take up extra space.
# Non-inplace .drop returns a new frame, which avoids SettingWithCopyWarning on the .loc slices.
is_encoding_mark = marks_reordered['encoding_set'] == 1
is_decoding_mark = marks_reordered['encoding_set'] == 0
encoding_marks = marks_reordered.loc[is_encoding_mark].drop(columns='encoding_set')
decoding_marks = marks_reordered.loc[is_decoding_mark].drop(columns='encoding_set')

print('Encoding spikes: '+str(len(encoding_marks)))
print('Decoding spikes: '+str(len(decoding_marks)))

encoding_pos = pos_reordered.loc[pos_reordered['encoding_set']==1]

# apply shift for shuffling
encoding_marks_shifted, shift_amount = sungod_util.shift_enc_marks_for_shuffle(encoding_marks, shift_amt_for_shuffle)
# Put marks back in chronological order; the trial reordering above changes their order.
# NOTE(review): presumably the encoder expects time-sorted marks — confirm.
encoding_marks_shifted = encoding_marks_shifted.sort_index(level='time')

In [None]:
# Populate encoder/decoder settings. Any settable parameter should be defined in the
# parameter cell above and referenced here as a variable.

# indices of valid position bins (not the same thing as the pos_bins returned by bin_position_data)
pos_bin_indices = np.arange(0, max_pos, 1)

encode_settings = trodes2SS.AttrDict({
    'sampling_rate': 3e4,
    'pos_bins': pos_bin_indices,
    'pos_bin_edges': np.arange(0, max_pos + .1, 1),   # valid bin-edge indices, 0 .. max_pos inclusive
    'pos_bin_delta': pos_bin_delta,
    # Gaussian (sd = 1 bin) centred on the middle of the position range (e.g. 90 for 180 bins)
    'pos_kernel': sp.stats.norm.pdf(pos_bin_indices, max_pos/2, 1),
    'pos_kernel_std': 0,          # 0 for histogram encoding model, 1+ for smoothing
    'mark_kernel_std': 20,
    'pos_num_bins': max_pos,
    'pos_col_names': [pos_col_format(bin_idx, max_pos) for bin_idx in range(max_pos)],
    'arm_coordinates': binned_arm_coords,
    'spk_amp': 60,
    'vel': 0,
})

decode_settings = trodes2SS.AttrDict({
    'trans_smooth_std': 2,
    'trans_uniform_gain': 0.0001,
    'time_bin_size': 60,
})

# transition matrix used by the point-process decoder
sungod_trans_mat = sungod_util.calc_sungod_trans_mat(encode_settings, decode_settings)

In [None]:
%%time
# run encoder
print('Starting encoder')

encoder = OfflinePPEncoder(linflat=encoding_pos, dec_spk_amp=decoding_marks, encode_settings=encode_settings, 
                               decode_settings=decode_settings, enc_spk_amp=encoding_marks_shifted, dask_worker_memory=1e9,
                               dask_chunksize = None)

    #new output format from encoder: observ_obj
observ_obj = encoder.run_encoder()


In [None]:
%%time

print('Starting decoder')

decoder = OfflinePPDecoder(observ_obj=observ_obj, trans_mat=sungod_trans_mat, 
                               prob_no_spike=encoder.prob_no_spike,
                               encode_settings=encode_settings, decode_settings=decode_settings, 
                               time_bin_size=decode_settings.time_bin_size, all_linear_position=binned_linear_pos)

posteriors = decoder.run_decoder()
print('Decoder finished!')
print('Posteriors shape: '+ str(posteriors.shape))

In [None]:
# TEMPORARY: save posteriors and position
# Outputs go to <path_base><rat_name>/ so they follow the rat/day/epoch parameters.
# Previously the directory was hard-coded to /data2/mcoulter/remy/ regardless of rat_name.
output_dir = path_base + rat_name + '/'
os.makedirs(output_dir, exist_ok=True)
output_prefix = output_dir + rat_name + '_' + day_ep + '_shuffle_' + str(shift_amount)

posterior_file_name = output_prefix + '_posteriors_functionalized_TEST.nc'

# annotate posteriors with ripple events, flatten the index, then convert to xarray for netcdf export
post1 = posteriors.apply_time_event(rips, event_mask_name='ripple_grp')
post2 = post1.reset_index()
post3 = trodes2SS.convert_dan_posterior_to_xarray(post2, tetrodes,
                                        velocity_thresh_for_enc_dec, encode_settings, decode_settings, sungod_trans_mat, order, shift_amount)
post3.to_netcdf(posterior_file_name)
print('Saved posteriors to '+posterior_file_name)

# to export linearized position to MatLab: again convert to xarray and then save as netcdf
position_file_name = output_prefix + '_linearposition_functionalized_TEST.nc'

linearized_pos1 = binned_linear_pos.apply_time_event(rips, event_mask_name='ripple_grp')
linearized_pos2 = linearized_pos1.reset_index()
linearized_pos3 = linearized_pos2.to_xarray()
linearized_pos3.to_netcdf(position_file_name)
print('Saved linearized position to '+position_file_name)

In [None]:
%%time
# run classifier 
sungod_no_offset = sungod_util.calc_sungod_trans_mat(encode_settings, decode_settings, uniform_gain=0)

causal_state1, causal_state2, causal_state3, acausal_state1, acausal_state2, acausal_state3, trans_mat_dict = sungod_util.decode_with_classifier(decoder.likelihoods, 
                                                                                                                                 sungod_no_offset, 
                                                                                                                                 encoder.occupancy, discrete_tm_val)

In [None]:
# Save classifier outputs to <path_base><rat_name>/. Previously the directory was hard-coded
# to /data2/mcoulter/remy/ regardless of rat_name.
output_dir = path_base + rat_name + '/'
os.makedirs(output_dir, exist_ok=True)
base_name = output_dir + rat_name + '_' + day_ep + '_shuffle_' + str(shift_amount) + '_posterior_'

# the causal and acausal results are saved identically apart from the name tag
classifier_states = (('causal', (causal_state1, causal_state2, causal_state3)),
                     ('acausal', (acausal_state1, acausal_state2, acausal_state3)))
for fname, (state1, state2, state3) in classifier_states:
    trodes2SS.convert_save_classifier(base_name, fname, state1, state2, state3, tetrodes, decoder.likelihoods,
                                      encode_settings, decode_settings, rips, velocity_thresh_for_enc_dec,
                                      velocity_buffer, sungod_no_offset, order, shift_amount)

In [None]:
# show the filename prefix used for the saved classifier outputs
base_name

In [None]:
%%output backend='bokeh' size=400 holomap='scrubber'
%%opts RGB { +framewise} [height=100 width=250 aspect=2 colorbar=True]
%%opts Points [height=100 width=250 aspect=2 ] (marker='o' color='#AAAAFF' size=1 alpha=0.7)
%%opts Polygons (color='grey', alpha=0.5 fill_color='grey' fill_alpha=0.5)
#%%opts Image {+framewise}

# visualize posteriors - note will only work a small chunck of the posteriors table
# currently this is ugly because posteriors, linpos, and rips all refer to different chunks of data

dec_viz = DecodeVisualizer(posteriors[0:200000], linpos=binned_linear_pos.loc[(binned_linear_pos["linvel_flat"]>4)], riptimes=rips[50:100], enc_settings=encode_settings)

dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY(), plt_range=100, slide=10)