In [6]:
from intanutil.load_intan_rhd_format import read_data

import numpy as np
import matplotlib.pyplot as plt
import glob
import re # Regular expression operations
import pandas as pd
from scipy import signal
import mne
from tqdm import tqdm
import os
from shapely.geometry import Point, Polygon
import geopandas as gpd

# Global matplotlib styling applied to every figure in this notebook.
plt.rcParams.update({
    "axes.labelsize": 16,
    "font.size": 14,
    "font.family": "Arial",
})

In [79]:
def init_ezm_gem():
    """Return the four EZM arena regions as shapely polygons.

    EZM presumably means elevated zero maze. The vertices are presumably in the
    tracking/video pixel coordinates used by the location file (TODO confirm).

    Returns
    -------
    list of shapely.geometry.Polygon
        ``[open_area1, open_area2, close_area1, close_area2]``.
    """
    open_area1 = Polygon([(78, 40), (128, 110), (298, 119), (360, 60), (280, 0), (156, 0)])
    open_area2 = Polygon([(280, 291), (338, 369), (260, 400), (106, 400), (49, 335), (113, 277)])
    close_area1 = Polygon([(49, 335), (113, 277), (128, 110), (78, 40), (0, 100), (0, 270)])
    close_area2 = Polygon([(360, 60), (298, 119), (280, 291), (338, 369), (400, 280), (400, 160)])
    # The return must be indented inside the function. At column 0 it is a
    # SyntaxError that prevents the whole cell from running.
    return [open_area1, open_area2, close_area1, close_area2]
def init_electrod_gem(vertical_distance=25, horizontal_distance=5):
    '''
    Build the probe geometry: a map from Intan amplifier channel to pad position.

    ``pad_to_channel[k]`` is the Intan amplifier channel wired to electrode pad
    ``k``. Pad 0 is the deepest electrode (Ch 51 in the Intan software) and
    pad 63 is the most superficial one (Ch 28 in the Intan software).

    Parameters
    ----------
    vertical_distance : float
        Spacing between consecutive pads along the shank, in um (default 25).
    horizontal_distance : float
        Lateral offset added per pad, in um (default 5).

    Returns
    -------
    list
        A one-element list ``[geometry]``. ``geometry`` maps each Intan channel
        number to an ``(x, y)`` tuple in um. The most superficial pad (Ch 28)
        sits at ``(0, 0)``, and both coordinates grow with depth.
    '''
    pad_to_channel = [
        51, 53, 48, 37, 61, 59, 36, 44,
        63, 58, 55, 50, 47, 42, 39, 34,
        62, 57, 54, 49, 46, 41, 38, 33,
        52, 60, 40, 35, 56, 43, 45, 32,
        3,  4,  5,  16, 13, 0,  24, 19,
        30, 25, 22, 17, 14, 9,  6,  2,
        31, 26, 23, 18, 15, 10, 8,  1,
        7,  11, 21, 20, 12, 27, 29, 28,
    ]

    geometry = {}
    # Walk from the most superficial pad (63) down to the deepest (0).
    for rank, channel in enumerate(reversed(pad_to_channel)):
        geometry[channel] = (rank * horizontal_distance, rank * vertical_distance)

    # Wrapped in a list so callers can unpack with `[geometry] = init_electrod_gem()`.
    return [geometry]


def load_all_data(dataset):
    """Load the Intan recording and the tracking table for one session.

    Parameters
    ----------
    dataset : dict
        Session description.
        - Recording path: ``dataset['lfp_file']`` when present. Otherwise it is
          built as ``os.path.join(dataset['lfp_dir'], dataset['lfp_file_list'][0])``,
          which matches the keys the session dicts in this notebook define.
          A list-valued ``lfp_dir`` uses its first entry.
        - ``dataset['loc_file']`` is read with ``pd.read_hdf``.

    Returns
    -------
    dict
        A copy of ``dataset``, so metadata such as ``brain_area`` and
        ``electrod_geometry`` is kept, plus the loaded fields:
        - ``lfp_file``
        - ``lfp``: the dict returned by ``read_data``
        - ``loc``
        - ``ch_name``
        - ``ch_impedence``: impedance magnitude / 1000, presumably kOhm
        - ``event_raw``
        - ``sampling_rate``
    """
    lfp_file = dataset.get('lfp_file')
    if lfp_file is None:
        lfp_dir = dataset['lfp_dir']
        if isinstance(lfp_dir, (list, tuple)):
            lfp_dir = lfp_dir[0]
        lfp_file = os.path.join(lfp_dir, dataset['lfp_file_list'][0])

    data     = read_data(lfp_file)
    location = pd.read_hdf(dataset['loc_file'])

    channels     = data['amplifier_channels']
    ch_name      = [ch['native_channel_name'] for ch in channels]
    ch_impedence = np.array([int(ch['electrode_impedance_magnitude']) for ch in channels]) / 1000
    event_raw    = np.array(data['board_dig_in_data'])

    # Prefer the rate stored in the recording and fall back to the previously
    # hard-coded 20 kHz. NOTE(review): the key names follow the Intan RHD
    # Python loader's output dict; confirm for your loader version.
    sampling_rate = data.get('frequency_parameters', {}).get('amplifier_sample_rate', 20000)
    if float(sampling_rate).is_integer():
        sampling_rate = int(sampling_rate)

    return {**dataset,
            'lfp_file'     : lfp_file,
            'loc_file'     : dataset['loc_file'],
            'lfp'          : data,
            'loc'          : location,
            'ch_name'      : ch_name,
            'ch_impedence' : ch_impedence,
            'event_raw'    : event_raw,
            'sampling_rate': sampling_rate}

def process_downsample_by_50(dataset):
    """Downsample the LFP by a total factor of 50 (e.g. 20 kHz -> 400 Hz).

    Decimation runs in two FIR stages (10, then 5), because SciPy's docs
    recommend calling ``decimate`` repeatedly for factors above 13.
    ``dataset['lfp']`` is updated in place.

    Parameters
    ----------
    dataset : dict
        Must contain ``dataset['lfp']['amplifier_data']`` (time on the last
        axis) and ``dataset['sampling_rate']`` in Hz.

    Returns
    -------
    dict
        The same ``dataset``, with the decimated ``amplifier_data`` and
        ``sampling_rate`` divided by 50. The old code always set 400, whatever
        the input rate was.
    """
    stages = (10, 5)
    downsampled = dataset['lfp']['amplifier_data']
    total_factor = 1
    for q in stages:
        downsampled = signal.decimate(downsampled, q, ftype='fir', axis=-1, zero_phase=True)
        total_factor *= q

    lfp = dataset['lfp']
    lfp['amplifier_data'] = downsampled
    dataset['lfp'] = lfp

    new_rate = dataset['sampling_rate'] / total_factor
    # Keep an int when the division is exact (e.g. 20000 -> 400), matching the old value type.
    if float(new_rate).is_integer():
        new_rate = int(new_rate)
    dataset['sampling_rate'] = new_rate
    return dataset

def process_band_pass_filer(dataset, band_names=('theta', 'beta', 'gamma')):
    """Band-pass filter the LFP into the requested frequency bands.

    Parameters
    ----------
    dataset : dict
        Must contain ``dataset['lfp']['amplifier_data']`` (one row per channel,
        time on the last axis) and ``dataset['sampling_rate']`` in Hz.
    band_names : iterable of str
        Bands to compute. Each must be a key of ``freqs`` below. The default
        is the three bands the notebook has always produced.

    Returns
    -------
    dict
        The same ``dataset`` with ``dataset['bands'] = {name: {'value': filtered}}``.
        Each ``filtered`` array has the shape of ``amplifier_data``.
    """
    # Pass-band edges in Hz.
    freqs = {
        'delta': {'low': 1,  'high': 4},
        'theta': {'low': 4,  'high': 8},
        'alpha': {'low': 8,  'high': 13},
        'beta' : {'low': 13, 'high': 30},
        'gamma': {'low': 30, 'high': 70},
    }
    lfp = dataset['lfp']['amplifier_data']
    sampling_rate = dataset['sampling_rate']

    bands = {}
    for frq in tqdm(band_names):
        # filter_data filters every row along the last axis in a single call,
        # so looping over channels one by one is unnecessary.
        bands[frq] = {'value': mne.filter.filter_data(lfp,
                                                      sfreq=sampling_rate,
                                                      l_freq=freqs[frq]['low'],
                                                      h_freq=freqs[frq]['high'],
                                                      n_jobs=1)}
    dataset['bands'] = bands
    return dataset
    

def process_hilbert_tranform(dataset):
    """Add instantaneous power and phase to every band in ``dataset['bands']``.

    For each band, the analytic signal of ``value`` is computed along the last
    axis after scaling from uV to mV.
    - ``power``: squared magnitude of the analytic signal (mV^2).
    - ``phase``: angle of the analytic signal, in radians in (-pi, pi].

    Parameters
    ----------
    dataset : dict
        Must contain ``dataset['bands'][name]['value']`` arrays, as produced by
        ``process_band_pass_filer``. Values are assumed to be in uV.

    Returns
    -------
    dict
        The same ``dataset``. Each band dict gains ``'power'`` and ``'phase'``
        arrays with the shape of ``'value'``.
    """
    # 1 uV = 1e-3 mV. Note that 10e-3 would be 0.01, which inflates power 100x.
    uv_to_mv = 1e-3
    for band in dataset['bands'].values():
        analytic = signal.hilbert(np.asarray(band['value']) * uv_to_mv, axis=-1)
        # Square the amplitude to get the power.
        band['power'] = np.abs(analytic) ** 2
        band['phase'] = np.angle(analytic)
    return dataset


def plot_location(dataset):
    """Scatter-plot the tracked 'shoulder' x/y positions of a session.

    Assumes ``dataset['loc']`` has three-level (scorer, bodypart, coord) column
    keys. This is presumably DeepLabCut output (TODO confirm). The scorer name
    is taken from the first column.
    """
    tracking = dataset['loc']
    scorer = tracking.columns[0][0]
    shoulder_x = tracking[scorer, 'shoulder', 'x'][:]
    shoulder_y = tracking[scorer, 'shoulder', 'y'][:]
    plt.scatter(x=shoulder_x, y=shoulder_y)

def plot_hist(dataset, band_frq):
    """Plot, one panel per channel, a histogram of the phase lag in ``band_frq``.

    For each ``ele`` in 30..59, panel ``ele - 30`` shows the phase of row 0
    minus the phase of row ``ele - 30``, wrapped into (-pi, pi].
    NOTE(review): the row/label mapping is kept from the original code and the
    "Ch64" title label is preserved verbatim; confirm which rows are intended.

    Parameters
    ----------
    dataset : dict
        Needs ``dataset['bands'][band_frq]['phase']``, as written by
        ``process_hilbert_tranform``, and ``dataset['sampling_rate']`` in Hz.
    band_frq : str
        Band key, e.g. ``'theta'``.
    """
    working_ch_hipp = range(30, 60)          # was `range(30) + 30`, a TypeError
    phase = dataset['bands'][band_frq]['phase']
    srate = dataset['sampling_rate']

    fig, axs = plt.subplots(len(working_ch_hipp), 1, figsize=(4, 4 * 8), squeeze=False)

    for ax, ele in zip(axs[:, 0], working_ch_hipp):
        raw_diff = phase[0, :] - phase[ele - 30, :]
        # Wrap into (-pi, pi]. Otherwise one lag shows up as separate modes near x and x +/- 2*pi.
        phase_diff = np.angle(np.exp(1j * raw_diff))
        # Number of bins = 0.5 per second of recording (at least 1).
        n_bins = max(1, round(len(phase_diff) / srate * 0.5))
        ax.hist(phase_diff, bins=n_bins, alpha=1)
        ax.set_title(band_frq.capitalize() + ' phase lag / Ch64 vs. Ch' + str(ele))
        ax.set_xlabel('Phase lag (rad)')
        ax.grid(True)

    fig.tight_layout()
    plt.show()

def plot_event(dataset, sample_rate=20000):
    """Plot the first digital-input trace of ``dataset['event_raw']`` against time.

    ``event_raw`` is not touched by ``process_downsample_by_50``, so it is not
    timed with ``dataset['sampling_rate']``. Instead, ``sample_rate`` (default
    20000 Hz, the value used before) is passed explicitly.

    Parameters
    ----------
    dataset : dict
        Must contain ``event_raw`` with samples on the last axis.
    sample_rate : float
        Sampling rate of ``event_raw`` in Hz.
    """
    events = dataset['event_raw']
    # Integer arange divided by the rate always yields exactly one timestamp
    # per sample. A float-step arange can produce an extra element.
    time = np.arange(events.shape[-1]) / sample_rate
    plt.figure(figsize=(11, 4))
    plt.plot(time, events[0], '--k', alpha=.4)
    plt.title('Trigger + video_frame')
    plt.xlabel('Time (s)')


In [None]:
# initialization
[open_area1, open_area2, close_area1, close_area2] = init_ezm_gem()
[geometry] = init_electrod_gem()

# Intan channel labels assigned to each brain area
mPFC_labels = [
 'A-000', 'A-001', 'A-002', 'A-003', 'A-005', 'A-006', 'A-007', 'A-008', 'A-009', 'A-010', 'A-011',
 'A-012', 'A-013', 'A-014', 'A-015', 'A-016', 'A-017', 'A-018', 'A-019', 'A-020', 'A-021', 'A-022',
 'A-023', 'A-024', 'A-025', 'A-026', 'A-027', 'A-028', 'A-029', 'A-030', 'A-031']

hipp_labels = ['A-036', 'A-044', 'A-050', 'A-055', 'A-058', 'A-059', 'A-061', 'A-063']

arena_data = {'lfp_dir'           : 'raw_data/mBWfus004_Arena/',
              'lfp_file_list'     : ['mBWfus004_arena_201224_150113.rhd'],
              'loc_file'          : 'Trajectory/mBWfus004_arena.h5',
              'brain_area'        : {'mPFC': mPFC_labels,
                                     'hipp': hipp_labels},
              'electrod_geometry' : geometry,
              'timing'            : {'video': 2.0,
                                     'lfp'  : 18.31}
             }

# The key is now 'lfp_dir' (it had trailing spaces) and the value is a plain string, as in arena_data.
ezm_data   = {'lfp_dir'           : 'raw_data/mBWfus004_EZM/',
              'lfp_file_list'     : ['mBWfus004_EZM_201224_182952.rhd'],
              'loc_file'          : 'Trajectory/mBWfus004_ezm.h5',
              'brain_area'        : {'mPFC': mPFC_labels,
                                     'hipp': hipp_labels},
              'electrod_geometry' : geometry,
              'ezm_gem'           : [open_area1, open_area2, close_area1, close_area2],
              'timing'            : {'video': 6.0,
                                     'lfp'  : 15.36}
             }

# load_all_data() reads dataset['lfp_file'], so point it at the first recording of each session.
for session in (arena_data, ezm_data):
    session['lfp_file'] = os.path.join(session['lfp_dir'], session['lfp_file_list'][0])

# Load raw data. The merge keeps the session metadata above (brain_area,
# geometry, timing, ...) even if load_all_data() returns only the keys it loads.
arena_data = {**arena_data, **load_all_data(arena_data)}
ezm_data   = {**ezm_data,   **load_all_data(ezm_data)}

# NOTE(review): the following functions are not defined in the visible cells
# of this notebook. TODO: define or import them before running this cell.
#   - process_get_spike_train
#   - plot_compare
#   - process_compare_lfp_vs_loc
#   - plot_compare_* / stat_compare_*

# Process spike detection
ezm_data   = process_get_spike_train(ezm_data)
arena_data = process_get_spike_train(arena_data)

# Process EZM data
ezm_data = process_downsample_by_50(ezm_data)
ezm_data = process_band_pass_filer(ezm_data)
ezm_data = process_hilbert_tranform(ezm_data)

# Process ARENA data
arena_data = process_downsample_by_50(arena_data)
arena_data = process_band_pass_filer(arena_data)
arena_data = process_hilbert_tranform(arena_data)

# Compare the two sessions (arena vs. EZM)
plot_compare(arena_data, ezm_data)

# Align location and LFP
ezm_data = process_compare_lfp_vs_loc(ezm_data)

# Compare power vs. location (hue: frequency band)
plot_compare_lfp_power_vs_loc(ezm_data, grouby_brain_area = ['mPFC','hipp'])
stat_compare_lfp_power_vs_loc(ezm_data, grouby_brain_area = ['mPFC','hipp'])

# Compare power coherence (hue: frequency band)
plot_compare_lfp_power_coh_vs_loc(ezm_data,
                                  brain_area_1 = 'mPFC',
                                  brain_area_2 = 'hipp')

stat_compare_lfp_power_coh_vs_loc(ezm_data,
                                  brain_area_1 = 'mPFC',
                                  brain_area_2 = 'hipp')

# Compare phase coherence (hue: frequency band)
plot_compare_lfp_phase_coh_vs_loc(ezm_data,
                                  brain_area_1 = 'mPFC',
                                  brain_area_2 = 'hipp')

stat_compare_lfp_phase_coh_vs_loc(ezm_data,
                                  brain_area_1 = 'mPFC',
                                  brain_area_2 = 'hipp')

# Compare spikes vs. location
plot_compare_fire_rate_vs_loc(ezm_data, grouby_brain_area = ['mPFC','hipp'])
