In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString, MultiPoint
import os
import numpy as np
import scipy.io as sio
import scipy.signal as signal
import vdmlab as vdm

In [None]:
import info.r063d2 as r063d2
info = r063d2

In [None]:
# pos_filename = 'C:/Users/Emily/Desktop/R063-2015-03-20-VT1.nvt'
# position = vdm.load_position(pos_filename, info.pxl_to_cm)

In [None]:
# NOTE(review): absolute, machine-specific paths; the notebook only runs
# where these two recording files exist at exactly these locations.
evt_filename = 'C:/Users/Emily/Desktop/R063-2015-03-20-Events.nev'
# Maps a short event name to the exact event string passed to
# vdm.load_events (presumably matched against the strings stored in the
# .nev file -- confirm against vdm.load_events).
labels = dict(led1='TTL Output on AcqSystem1_0 board 0 port 2 value (0x0001).',
              led2='TTL Output on AcqSystem1_0 board 0 port 2 value (0x0002).',
              ledoff='TTL Output on AcqSystem1_0 board 0 port 2 value (0x0000).',
              pb1id='TTL Input on AcqSystem1_0 board 0 port 1 value (0x0040).',
              pb2id='TTL Input on AcqSystem1_0 board 0 port 1 value (0x0020).',
              pboff='TTL Input on AcqSystem1_0 board 0 port 1 value (0x0000).',
              feeder1='TTL Output on AcqSystem1_0 board 0 port 0 value (0x0004).',
              feeder2='TTL Output on AcqSystem1_0 board 0 port 0 value (0x0040).',
              feederoff='TTL Output on AcqSystem1_0 board 0 port 0 value (0x0000).')

# Indexed below by the short names (events['led1'], events['led2'], ...).
events = vdm.load_events(evt_filename, labels)

pos_filename = 'C:/Users/Emily/Desktop/R063-2015-03-20-VT1.nvt'
nvt_data = vdm.load_nvt(pos_filename)
# 'targets' is later indexed as [sample, target] and decoded with extract_xy;
# 'time' is searched with vdm.find_nearest_idx to locate event samples.
targets = nvt_data['targets']
times = nvt_data['time']

In [None]:
def extract_xy(target, info):
    """Extracts x and y from a single Neuralynx target bitfield.

    The 32-bit target stores the x coordinate in bits 0-11 and the y
    coordinate in bits 16-27 (12 bits each). The previous string slices
    ``[20:31]`` and ``[4:15]`` only covered 11 bits and silently dropped the
    least-significant bit of both coordinates, halving every position.

    Parameters
    ----------
    target: int
        One target bitfield (a single element of the nvt targets array).
    info: module or object
        Must expose ``pxl_to_cm``, indexable as ``[0]`` (x) and ``[1]`` (y).

    Returns
    -------
    x: float
        Pixel x coordinate divided by ``info.pxl_to_cm[0]``.
    y: float
        Pixel y coordinate divided by ``info.pxl_to_cm[1]``.

    """
    bitfield = int(target)
    coordinate_mask = 0xFFF  # 12-bit wide coordinate fields

    x_pixel = bitfield & coordinate_mask
    y_pixel = (bitfield >> 16) & coordinate_mask

    x = x_pixel / info.pxl_to_cm[0]
    y = y_pixel / info.pxl_to_cm[1]

    return x, y

In [None]:
# Decode every target bitfield into x/y positions. Slots that hold no target
# keep their initial value of zero.
x = np.zeros(targets.shape)
y = np.zeros(targets.shape)
time = np.zeros(targets.shape)  # NOTE(review): never read by the visible cells
# Iterate record by record (rows are samples, columns are targets). The loops
# used to be transposed, so the zero check below tested the fixed element
# targets[target, target] instead of the bitfield being decoded and could
# discard a whole target column based on one unrelated sample.
for sample in range(targets.shape[0]):
    this_sample = targets[sample, :]
    for target in range(targets.shape[1]):
        # To speed things up we can take advantage of the fact that when
        # the bitfield is equal to zero there is no valid data for that field
        # and remains zero for the rest of the bitfields in the record.
        if this_sample[target] == 0:
            break
        x[sample, target], y[sample, target] = extract_xy(int(this_sample[target]), info)

# Drop the target columns that never held a position in any sample.
col_idx = np.all(x == 0, axis=0) & np.all(y == 0, axis=0)
xs = np.array(x[:, ~col_idx])
ys = np.array(y[:, ~col_idx])

In [None]:
# Merge both LED event streams into one list of (timestamp, label) pairs and
# order it chronologically.
leds = [(event, name) for name in ('led1', 'led2') for event in events[name]]
sorted_leds = sorted(leds)

# Per-sample location of the feeder whose LED was switched on most recently
# (stays at zero before the first LED event).
feeder_x_location = np.zeros(xs.shape[0])
feeder_y_location = np.zeros(ys.shape[0])

feeder1_x, feeder1_y = info.path_pts['feeder1'][0], info.path_pts['feeder1'][1]
feeder2_x, feeder2_y = info.path_pts['feeder2'][0], info.path_pts['feeder2'][1]

last_label = ''

for time, label in sorted_leds:
    # Only a switch from one LED to the other changes the active feeder.
    if label != last_label:
        idx = vdm.find_nearest_idx(times, time)
        if label == 'led1':
            x_location, y_location = feeder1_x, feeder1_y
        else:
            x_location, y_location = feeder2_x, feeder2_y

        # Applies from this sample onward; a later switch overwrites the tail.
        feeder_x_location[idx:] = x_location
        feeder_y_location[idx:] = y_location

        last_label = label

# Start from the first target; later cells overwrite the samples where a
# better estimate is available.
xx = xs[:, 0].copy()
yy = ys[:, 0].copy()
ttime = times

# Samples whose second target slot is empty in x or in y
one_target_idx = ~((xs[:, 1] != 0) & (ys[:, 1] != 0))

In [None]:
# Everything in this cell reasons about the first two targets only, so the
# averaging and the furthest-from-feeder choice are restricted to those two
# columns. Previously both ran over every column of xs/ys, which let empty
# (zero) extra target slots drag the mean towards the origin or be picked as
# the "furthest" target.
pair_xs = xs[:, :2]
pair_ys = ys[:, :2]

target_x_dist = np.abs(pair_xs[:, 1] - pair_xs[:, 0])
target_y_dist = np.abs(pair_ys[:, 1] - pair_ys[:, 0])

# Contaminated samples are using the feeder LED instead of the implant LEDs
contamination_thresh = 5
contaminated_idx = (target_x_dist > contamination_thresh) | (target_y_dist > contamination_thresh)

# Non contaminated implant LED samples with two targets get averaged
idx = ~contaminated_idx & ~one_target_idx
xx[idx] = np.mean(pair_xs[idx], axis=1)
yy[idx] = np.mean(pair_ys[idx], axis=1)

# For contaminated samples, we use the sample that is furthest from the feeder location
feeder_x_dist = np.abs(pair_xs - feeder_x_location[..., np.newaxis])
feeder_y_dist = np.abs(pair_ys - feeder_y_location[..., np.newaxis])

# Manhattan distance of each of the two targets to the active feeder
feeder_dist = feeder_x_dist + feeder_y_dist
furthest_idx = np.argmax(feeder_dist, axis=1)

idx = contaminated_idx & ~one_target_idx
xx[idx] = pair_xs[idx, furthest_idx[idx]]
yy[idx] = pair_ys[idx, furthest_idx[idx]]

In [None]:
# Applying a median filter to the x and y positions
kernel = 9
filtered_x = signal.medfilt(xx, kernel_size=kernel)
filtered_y = signal.medfilt(yy, kernel_size=kernel)

In [None]:
np.sum(np.sqrt(np.diff(filtered_x)**2 + np.diff(filtered_y)**2)>10)

In [None]:
plt.plot(filtered_x, filtered_y, 'g.', ms=2)
idx = np.where(np.append(np.sqrt(np.diff(filtered_x)**2 + np.diff(filtered_y)**2), np.array([0]))>10)[0][2]
plt.plot(filtered_x[idx], filtered_y[idx], 'r.', ms=10)
plt.plot(filtered_x[idx-1], filtered_y[idx-1], 'k.', ms=10)
plt.plot(filtered_x[idx+1], filtered_y[idx+1], 'm.', ms=10)
plt.show()

In [None]:
plt.plot(filtered_x, filtered_y, 'g.', ms=2)
plt.show()

In [None]:
final_x = np.array(xx)
final_y = np.array(yy)
final_time = np.array(ttime)

# Find those indices that have both targets contaminated by the feeder LEDs
# by locating unnatural jumps in the position. Not including jumps that are
# due to jumps in time (from stopping the recording).
time_thresh = 1.
jump_thresh = 50
dist_thresh = 60  # only used by the disabled feeder-distance check below

# Removing a sample can expose a new jump between its neighbours, so repeat
# until no sample is flagged any more.
while True:
    # jumps[i] is the distance travelled from sample i-1 to sample i
    jumps = np.append(np.array([0]), np.sqrt(np.diff(final_x)**2 + np.diff(final_y)**2))
    remove_idx = jumps > jump_thresh
    print(np.sum(remove_idx))

    # time_jumps[i] must describe the same i-1 -> i step as jumps[i], so the
    # padding goes at the front. It used to be appended, which exempted the
    # sample before a recording gap instead of the one right after it.
    time_jumps = np.append(np.array([False], dtype=bool), np.diff(final_time) > time_thresh)
    remove_idx[time_jumps] = False
    print(np.sum(remove_idx))

#     dist_feeder1 = np.sqrt((final_x - feeder1_x)**2 + (final_y - feeder1_y)**2)
#     dist_feeder2 = np.sqrt((final_x - feeder2_x)**2 + (final_y - feeder2_y)**2)
#     dist_feeder = np.minimum(dist_feeder1, dist_feeder2)
#     dist_jumps = dist_feeder > dist_thresh
#     remove_idx[dist_jumps] = False

    if not remove_idx.any():
        break

    keep_idx = ~remove_idx
    final_x = final_x[keep_idx]
    final_y = final_y[keep_idx]
    final_time = final_time[keep_idx]

In [None]:
filtered_x = signal.medfilt(final_x)
filtered_y = signal.medfilt(final_y)

In [None]:
len(final_x)

In [None]:
# corrected.n_samples # 229833

In [None]:
np.sum(np.sqrt(np.diff(final_x)**2 + np.diff(final_y)**2)>100)

In [None]:
np.sum(np.sqrt(np.diff(filtered_x)**2 + np.diff(filtered_y)**2)>100)

In [None]:
plt.hist(np.sqrt(np.diff(filtered_x)**2 + np.diff(filtered_y)**2))
plt.show()

In [None]:
plt.plot(filtered_x, filtered_y, 'g.', ms=2)
plt.show()

In [None]:
def extract_color(target):
    """Extracts the colour and intensity flags from a neuralynx target.

    Parameters
    ----------
    target: int
        One target bitfield; it is rendered as a 32 character, zero padded
        binary string and individual characters are read from it.

    Returns
    -------
    color: dict
        With red, green, blue, raw_red, raw_green, raw_blue, intensity as
        keys, each holding 0 or 1.

    """
    bits = format(target, "032b")

    # (key, position of its flag in the binary string, most significant first)
    flag_positions = (
        ('red', 1),
        ('green', 2),
        ('blue', 3),
        ('raw_red', 17),
        ('raw_green', 18),
        ('raw_blue', 19),
        ('intensity', 16),
    )

    return {name: int(bits[position], 2) for name, position in flag_positions}

In [None]:
n_targets = 2
xs = xs[:, 0:n_targets]

In [None]:
xs.reshape(xs.shape[0]/2, 2)

In [None]:
y = position[:-1]
t = position[1:]
dist = y.distance(t)
dist = np.hstack((dist, np.array([0])))

In [None]:
def find_event_idx(position, events, duration=4):
    """Marks the position samples that start at each event.

    Parameters
    ----------
    position: object with ``n_samples`` and ``time``
    events: iterable
        Event times, looked up in ``position.time`` with
        ``vdm.find_nearest_idx``.
    duration: int
        Number of consecutive samples (not seconds) to mark per event.

    Returns
    -------
    event_position: np.array of bool
        True for the nearest sample to each event and the following
        ``duration - 1`` samples, clipped at the end of the recording.

    """
    n_samples = position.n_samples
    event_position = np.zeros(n_samples, dtype=bool)

    for event_time in events:
        start = vdm.find_nearest_idx(position.time, event_time)
        # Never run past the last sample
        n_marked = min(duration, n_samples - start)
        if n_marked > 0:
            event_position[start:start + n_marked] = True

    return event_position

In [None]:
led1_on = find_event_idx(position, events['led1'])

In [None]:
led2_on = find_event_idx(position, events['led2'])

In [None]:
feeder_radius = 40
feeder1_pt = Point(info.path_pts['feeder1'])
feeder1 = feeder1_pt.buffer(feeder_radius)

feeder2_pt = Point(info.path_pts['feeder2'])
feeder2 = feeder2_pt.buffer(feeder_radius)

In [None]:
def is_in_zone(position, zone):
    """Tests every position sample for membership in a zone.

    Parameters
    ----------
    position: object with ``n_samples``, ``x`` and ``y``
    zone: geometry exposing ``contains(point)``

    Returns
    -------
    within_zone: np.array of bool
        True for the samples whose (x, y) point the zone contains.

    """
    inside = [sample for sample, (px, py) in enumerate(zip(position.x, position.y))
              if zone.contains(Point(px, py))]

    within_zone = np.zeros(position.n_samples, dtype=bool)
    within_zone[np.array(inside, dtype=int)] = True
    return within_zone

In [None]:
within_feeder1 = is_in_zone(position, feeder1)

In [None]:
within_feeder2 = is_in_zone(position, feeder2)

In [None]:
fixed = position[~(led1_on & within_feeder1) & ~(led2_on & within_feeder2)]
plt.plot(fixed.x, fixed.y, 'g.', ms=2)
plt.show()

In [None]:
# Plot to check
plt.plot(corrected.x, corrected.y, 'b.', ms=2)
plt.show()

In [None]:
position.n_samples - fixed.n_samples

In [None]:
position.n_samples, fixed.n_samples

In [None]:
def position_remove(position, remove_idx):
    """Builds a new vdmlab.Position without the given samples.

    Parameters
    ----------
    position: vdmlab.Position
    remove_idx: list
        Sample indices to drop; forwarded unchanged to ``np.delete``.

    Returns
    -------
    filtered_position: vdmlab.Position
        Rows of ``position.data`` and entries of ``position.time`` that were
        not listed in ``remove_idx``.
    """
    remaining_data, remaining_time = (
        np.delete(position.data, remove_idx, axis=0),
        np.delete(position.time, remove_idx),
    )
    return vdm.Position(remaining_data, remaining_time)

In [None]:
np.shape(position.data)

In [None]:
len(remove_idx), len(get_all)

In [None]:
pos = position_remove(position, remove_idx)

In [None]:
plt.plot(pos.x, pos.y, 'g.', ms=2)
plt.show()

In [None]:
len(pos.x), len(position.x)

In [None]:
len(position.x) - len(pos.x)

In [None]:
plt.plot(pos.x, pos.y, 'b.', ms=1)
plt.show()

In [None]:
for i in range(len(position.x)):
    if i not in get_all:
        plt.plot(position.x[i], position.y[i], 'g.', ms=1)
plt.show()

In [None]:
light_remove_idx = []
lights = ['led1', 'led2']
for light in lights:
    for event in events[light]:
        idx = vdm.find_nearest_idx(position.time, event)
        light_remove_idx.append(idx)

In [None]:
len(light_remove_idx)

In [None]:
np.mean(dist)

In [None]:
def remove_teleports(position, events, speed_thresh, min_length):
    """Removes positions above a certain speed threshold

    The recording is split at every sample whose speed reaches
    ``speed_thresh``. A segment is kept when it is long enough and does not
    begin on the sample nearest to an LED event. The membership test used to
    be inverted (``in`` instead of ``not in``), which kept only the segments
    that started exactly on an LED event.

    Parameters
    ----------
    position : vdmlab.Position
    events : dict
        Must provide the LED event times under the keys led1 and led2.
    speed_thresh : int
        Maximum speed to consider natural rat movements. Anything
        above this theshold will not be included in the filtered positions.
    min_length : int
        Minimum length for a sequence to be included in filtered positions.

    Returns
    -------
    filtered_position : vdmlab.Position

    Raises
    ------
    ValueError
        If no segment survives the filtering.

    """
    # Samples nearest to an LED onset; a set keeps the per-segment lookup cheap.
    light_idx = {vdm.find_nearest_idx(position.time, event)
                 for light in ('led1', 'led2')
                 for event in events[light]}

    velocity = np.squeeze(position.speed().data)

    split_idx = np.where(velocity >= speed_thresh)[0]
    # The size check comes first so idx[0] is never evaluated on an empty
    # segment (np.split yields one when the first split index is 0).
    keep_idx = [idx for idx in np.split(np.arange(position.n_samples), split_idx)
                if idx.size >= min_length and idx[0] not in light_idx]

    if len(keep_idx) == 0:
        raise ValueError("resulted in all position samples removed. Adjust min_length or speed_thresh.")

    x = []
    y = []
    time = []
    for idx_sequence in keep_idx:
        # NOTE(review): the slice stops before idx_sequence[-1], so the last
        # sample of every kept segment is dropped -- confirm this is intended.
        first, last = idx_sequence[0], idx_sequence[-1]
        x.extend(position.x[first:last])
        y.extend(position.y[first:last])
        time.extend(position.time[first:last])

    return vdm.Position(np.hstack([np.array(x)[..., np.newaxis],
                                   np.array(y)[..., np.newaxis]]), time)

In [None]:
pos = remove_teleports(position, events, speed_thresh=10, min_length=3)

In [None]:
plt.plot(pos.x, pos.y, 'b.', ms=2)
plt.show()

In [None]:
len(position.x) - len(pos.x)

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString, MultiPoint

import vdmlab as vdm

from load_data import get_pos, get_raw_pos, get_events
from analyze_maze import spikes_by_position
from analyze_plotting import plot_intersects, plot_zone

import sys
# sys.path.append('E:\\code\\python-vdmlab\\projects\\emily_shortcut\\info')
sys.path.append('C:\\Users\\Emily\\Code\\emi_shortcut\\info')
import info.r063d2 as r063d2

In [None]:
output_path = 'C:\\Users\\Emily\\Code\\emi_shortcut\\cache\\matlab\\spike_pos\\'
# output_path = 'E:\\code\\emi_shortcut\\cache\\matlab\\spike_pos\\'
import info.r063d2 as r063d2
info = r063d2

In [None]:
corrected = get_pos(info.pos_mat, info.pxl_to_cm)
raw = get_raw_pos(info.raw_pos_mat, info.pxl_to_cm)
events = get_events(info.event_mat)

In [None]:
np.sum(np.sqrt(np.diff(corrected.x)**2 + np.diff(corrected.y)**2)>10)

In [None]:
len(raw.x) - len(corrected.x)

In [None]:
len(raw.x), len(corrected.x)

In [None]:
# Plot to check
plt.plot(corrected.x, corrected.y, 'b.', ms=2)
plt.show()

In [None]:
plt.hist(np.sqrt(np.diff(corrected.x)**2 + np.diff(corrected.y)**2), 100)
plt.show()

In [None]:
corrected.n_samples # 229833

In [None]:
# Plot to check
plt.plot(raw.x, raw.y, 'b.', ms=1)
plt.show()

In [None]:
len(events['led1']), len(events['led2'])

In [None]:
events.keys()

In [None]:
for event in events['led1']:
    idx = vdm.find_nearest_idx(raw.time, event)
    plt.plot(raw.x[idx], raw.y[idx], 'r.', ms=2)
    plt.plot(raw.x[idx-1], raw.y[idx-1], 'b.', ms=2)
    plt.plot(raw.x[idx-2], raw.y[idx-2], 'c.', ms=2)
plt.show()

In [None]:
for event in events['led2']:
    idx = vdm.find_nearest_idx(raw.time, event)
    plt.plot(raw.x[idx], raw.y[idx], 'r.', ms=2)
    plt.plot(raw.x[idx-1], raw.y[idx-1], 'b.', ms=2)
    plt.plot(raw.x[idx-2], raw.y[idx-2], 'c.', ms=2)
plt.show()

In [None]:
def remove_position(position, events, feeder1, feeder2):
    """Removes position samples recorded inside a feeder zone while an LED was on.

    For every led1 and led2 event, the nearest position sample and (when they
    all exist) the three samples after it are candidates. A candidate is
    removed when its (x, y) point lies inside ``feeder1`` or ``feeder2``.

    Parameters
    ----------
    position: vdmlab.Position
    events: dict
        Must provide the LED event times under the keys led1 and led2.
    feeder1: geometry exposing ``contains(point)``
    feeder2: geometry exposing ``contains(point)``

    Returns
    -------
    filtered_position: vdmlab.Position

    """
    n_following = 3
    max_idx = len(position.x)

    light_on = []
    for led in ('led1', 'led2'):
        for event in events[led]:
            idx = vdm.find_nearest_idx(position.time, event)
            light_on.append(idx)
            # Only extend the window when all following samples exist
            if idx < max_idx - n_following:
                light_on.extend(range(idx + 1, idx + n_following + 1))

    remove_idx = []
    for idx in light_on:
        point = Point(position.x[idx], position.y[idx])
        if feeder1.contains(point) or feeder2.contains(point):
            remove_idx.append(idx)

    # An explicit integer dtype keeps an empty selection usable as an index
    # array; np.array([]) would default to float, which np.delete rejects.
    remove_idx = np.array(remove_idx, dtype=int)

    return position_remove(position, remove_idx)
    

In [None]:
np.diff(raw.x[:10])

In [None]:
feeder_radius = 40
feeder1_pt = Point(info.path_pts['feeder1'])
feeder1 = feeder1_pt.buffer(feeder_radius)

feeder2_pt = Point(info.path_pts['feeder2'])
feeder2 = feeder2_pt.buffer(feeder_radius)

In [None]:
fixed = remove_position(raw, events, feeder1, feeder2)

In [None]:
len(pos.time), len(fixed.time)

In [None]:
plt.plot(raw.x, raw.y, 'b.', ms=1)
plt.plot(feeder1.exterior.xy[0], feeder1.exterior.xy[1], 'r')
plt.plot(feeder2.exterior.xy[0], feeder2.exterior.xy[1], 'g')
plt.show()

In [None]:
raw

In [None]:
# Literally from position_shortcut.m

def light_on(events):
    """Builds the joined epochs that follow every led1 and led2 event.

    Parameters
    ----------
    events: dict
        Must provide the LED event times under the keys led1 and led2.

    Returns
    -------
    epochs: vdmlab.Epoch
        The led1 epochs joined with the led2 epochs.

    """
    def _led_epochs(led):
        # Each epoch begins at the event time; the second argument is
        # (event + 3) - event (presumably a duration -- confirm against
        # vdm.Epoch).
        onsets = events[led]
        offsets = events[led] + 3
        return vdm.Epoch(onsets, offsets - onsets)

    led1_epochs = _led_epochs('led1')
    led2_epochs = _led_epochs('led2')

    return led1_epochs.join(led2_epochs)

In [None]:
light_epochs = light_on(events)

In [None]:
light_epochs.n_epochs

In [None]:
expand_by = 5
feeder1_center = Point(info.path_pts['feeder1'][0], info.path_pts['feeder1'][1])
feeder1 = feeder1_center.buffer(expand_by*2)
feeder2_center = Point(info.path_pts['feeder2'][0], info.path_pts['feeder2'][1])
feeder2 = feeder2_center.buffer(expand_by*2)

In [None]:
plt.plot(raw.x, raw.y, 'g.', ms=1)
plot_zone(feeder1)
plot_zone(feeder2)
plt.show()

In [None]:
def correct_position(position, events, feeder1, feeder2):
    """Removes samples inside a feeder zone around each feederoff event.

    For every feederoff event, the nearest position sample and the four
    samples before it are checked; those whose (x, y) point lies inside
    ``feeder1`` or ``feeder2`` are dropped.

    Fixes relative to the previous version: the lookup now uses the
    ``position`` argument rather than the notebook globals ``raw`` and
    ``pos``; the sample index is recorded instead of the loop offset; the
    selection is a boolean mask (``~`` on integer indices produced negative
    indices, not a complement); events that match nothing no longer raise;
    and the look-back window is clipped at the first sample so negative
    indices do not wrap around to the end of the recording.

    Parameters
    ----------
    position : vdmlab.Position
    events: dict
        With led1, feederoff, type, feeder1, pb2, ledoff, led2, pboff, label, pb1, feeder2 as keys.
        Only feederoff is read here.
    feeder1: geometry exposing ``contains(point)``
    feeder2: geometry exposing ``contains(point)``

    Returns
    -------
    filtered_position : vdmlab.Position

    """
    n_lookback = 5
    keep = np.ones(len(position.x), dtype=bool)

    for event in events['feederoff']:
        idx = vdm.find_nearest_idx(position.time, event)
        first = max(idx - n_lookback + 1, 0)
        for sample in range(first, idx + 1):
            point = Point([position.x[sample], position.y[sample]])
            if feeder1.contains(point) or feeder2.contains(point):
                keep[sample] = False

    # Position objects are indexed with boolean masks elsewhere in this notebook.
    return position[keep]

In [None]:
test = correct_position(raw, events, feeder1, feeder2)

In [None]:
plt.plot(test.x, test.y, 'g.', ms=1)
plt.show()