In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString, MultiPoint
import os
import numpy as np
import scipy.io as sio
import scipy.signal as signal
import vdmlab as vdm

In [None]:
import info.r063d2 as r063d2
import info.r063d3 as r063d3
# Session info module used by the rest of the notebook (r063d3 is imported as an alternative).
info = r063d2

In [None]:
# pos_filename = 'C:/Users/Emily/Desktop/R063-2015-03-20-VT1.nvt'
# position = vdm.load_position(pos_filename, info.pxl_to_cm)

In [None]:
# Raw Neuralynx files for this session. The directory and session id live in one
# place so switching machine or session only means editing these two constants.
DATA_DIR = 'C:/Users/Emily/Desktop/'
SESSION_ID = 'R063-2015-03-20'

evt_filename = DATA_DIR + SESSION_ID + '-Events.nev'
pos_filename = DATA_DIR + SESSION_ID + '-VT1.nvt'

# Event strings as written by the acquisition system, keyed by the short names
# used to look them up later (e.g. events['led1']).
labels = dict(led1='TTL Output on AcqSystem1_0 board 0 port 2 value (0x0001).',
              led2='TTL Output on AcqSystem1_0 board 0 port 2 value (0x0002).',
              ledoff='TTL Output on AcqSystem1_0 board 0 port 2 value (0x0000).',
              pb1id='TTL Input on AcqSystem1_0 board 0 port 1 value (0x0040).',
              pb2id='TTL Input on AcqSystem1_0 board 0 port 1 value (0x0020).',
              pboff='TTL Input on AcqSystem1_0 board 0 port 1 value (0x0000).',
              feeder1='TTL Output on AcqSystem1_0 board 0 port 0 value (0x0004).',
              feeder2='TTL Output on AcqSystem1_0 board 0 port 0 value (0x0040).',
              feederoff='TTL Output on AcqSystem1_0 board 0 port 0 value (0x0000).')

events = vdm.load_events(evt_filename, labels)

In [None]:
def extract_xy(target, info):
    """Extracts x and y from a Neuralynx video-tracker target. Converts to cm.

    The target is a 32-bit bitfield: x occupies bits 0-11 and y bits 16-27
    (12 bits each). The fields are read with masks. Slicing the 32-character
    binary string as ``[20:31]`` / ``[4:15]`` would keep only 11 of the 12 bits
    and drop the lowest bit of each field. Masking also works when the record
    was read as a signed 32-bit integer with the top bit set.

    NOTE(review): values are about twice those of the old string-slice version;
    check that ``info.pxl_to_cm`` was calibrated on full-resolution pixels.

    Parameters
    ----------
    target: int
        One raw target bitfield.
    info: module
        Must provide ``pxl_to_cm`` as an (x, y) pair; pixel coordinates are
        divided by these to give cm.

    Returns
    -------
    x: float
    y: float

    """
    x = (target & 0xFFF) / info.pxl_to_cm[0]
    y = ((target >> 16) & 0xFFF) / info.pxl_to_cm[1]

    return x, y

In [None]:
def load_shortcut_position(info, pos_filename, events, contamination_thresh=5, kernel_size=7):
    """Loads and corrects shortcut position.

    Decodes the raw Neuralynx video targets and keeps the two non-empty target
    columns. Where the two targets are far apart, the one furthest from the
    currently flashing feeder is kept; otherwise they are averaged. The result
    is median-filtered.

    Parameters
    ----------
    info: module
        Must provide ``pxl_to_cm`` (used by ``extract_xy``) and
        ``path_pts['feeder1']`` / ``path_pts['feeder2']`` indexable as (x, y).
    pos_filename: str
        Path to the Neuralynx ``.nvt`` file.
    events: dict
        Must contain ``'led1'`` and ``'led2'`` event times.
    contamination_thresh: float, optional
        The two targets count as contaminated when they differ by more than
        this in x or in y. Default 5.
    kernel_size: int, optional
        Median-filter kernel size (must be odd for ``scipy.signal.medfilt``).
        Default 7.

    Returns
    -------
    position: vdm.Position
        Filtered (x, y) at the ``.nvt`` sample times.

    Raises
    ------
    ValueError
        If the data does not contain exactly two non-empty target columns.

    """
    nvt_data = vdm.load_nvt(pos_filename)
    targets = nvt_data['targets']
    times = nvt_data['time']

    # X and Y are stored in a custom bitfield (see the Neuralynx data file format
    # documentation). Each record contains up to 50 targets, each a 32-bit field.
    # A field equal to zero holds no valid data, so only non-zero fields are
    # decoded; all other entries stay 0.
    x = np.zeros(targets.shape)
    y = np.zeros(targets.shape)
    for sample, target in zip(*np.nonzero(targets)):
        x[sample, target], y[sample, target] = extract_xy(int(targets[sample, target]), info)

    # Remove columns with no target data
    empty_cols = np.all(x == 0, axis=0) & np.all(y == 0, axis=0)
    xs = x[:, ~empty_cols]
    ys = y[:, ~empty_cols]

    # This correction method assumes we are working with two targets
    # (eg. subtracts the two targets, averages over two targets, etc.)
    if xs.shape[1] != 2 or ys.shape[1] != 2:
        raise ValueError("must have two targets for x and y")

    # LED events from both feeders as one time-sorted list of (time, label)
    sorted_leds = sorted([(event, 'led1') for event in events['led1']] +
                         [(event, 'led2') for event in events['led2']])

    # Location of the feeder whose LED is active, for every sample.
    # Samples before the first LED event keep (0, 0).
    feeder_x_location = np.zeros(xs.shape[0])
    feeder_y_location = np.zeros(ys.shape[0])

    feeder1_x, feeder1_y = info.path_pts['feeder1'][0], info.path_pts['feeder1'][1]
    feeder2_x, feeder2_y = info.path_pts['feeder2'][0], info.path_pts['feeder2'][1]

    last_label = ''
    for time, label in sorted_leds:
        # Repeated events from the same LED do not change the active feeder
        if label == last_label:
            continue
        idx = vdm.find_nearest_idx(times, time)
        if label == 'led1':
            feeder_x_location[idx:] = feeder1_x
            feeder_y_location[idx:] = feeder1_y
        else:
            feeder_x_location[idx:] = feeder2_x
            feeder_y_location[idx:] = feeder2_y
        last_label = label

    # Start from the first target
    xx = np.array(xs[:, 0])
    yy = np.array(ys[:, 0])

    # Samples where only one target was available
    one_target_idx = (xs[:, 1] == 0) | (ys[:, 1] == 0)

    # One target is contaminated (tracking the feeder LED instead of the implant
    # LEDs) when the two targets are far apart
    target_x_dist = np.abs(xs[:, 1] - xs[:, 0])
    target_y_dist = np.abs(ys[:, 1] - ys[:, 0])
    contaminated_idx = (target_x_dist > contamination_thresh) | (target_y_dist > contamination_thresh)

    # Non contaminated implant LED samples with two targets get averaged
    idx = ~contaminated_idx & ~one_target_idx
    xx[idx] = np.mean(xs[idx], axis=1)
    yy[idx] = np.mean(ys[idx], axis=1)

    # For contaminated samples, keep the target furthest (|dx| + |dy|) from the active feeder
    feeder_dist = (np.abs(xs - feeder_x_location[:, np.newaxis]) +
                   np.abs(ys - feeder_y_location[:, np.newaxis]))
    furthest_idx = np.argmax(feeder_dist, axis=1)

    idx = contaminated_idx & ~one_target_idx
    xx[idx] = xs[idx, furthest_idx[idx]]
    yy[idx] = ys[idx, furthest_idx[idx]]

    # Apply a median filter to the x and y positions
    filtered_x = signal.medfilt(xx, kernel_size=kernel_size)
    filtered_y = signal.medfilt(yy, kernel_size=kernel_size)

    # (n_samples, 2) array of [x, y] columns
    return vdm.Position(np.column_stack((filtered_x, filtered_y)), times)

In [None]:
# Decode and correct the raw video-tracker position for the selected session
position = load_shortcut_position(info=info, pos_filename=pos_filename, events=events)

In [None]:
# Number of consecutive-sample steps longer than 10
(np.sqrt(np.square(np.diff(position.x)) + np.square(np.diff(position.y))) > 10).sum()

In [None]:
# Highlight steps longer than 10 between consecutive samples.
# Step i goes from sample i to sample i+1; the trailing 0 pad makes the mask
# the same length as the samples.
step_length = np.sqrt(np.diff(position.x)**2 + np.diff(position.y)**2)
jump_idx = np.where(np.append(step_length, np.array([0])) > 10)[0]
# Skip index 0 for the "earlier" marker so idx - 1 cannot wrap to the last sample
prev_idx = jump_idx[jump_idx > 0] - 1

fig, ax = plt.subplots()
ax.plot(position.x, position.y, 'g.', ms=2, label='all samples')
ax.plot(position.x[jump_idx], position.y[jump_idx], 'r.', ms=10, label='sample before jump')
ax.plot(position.x[prev_idx], position.y[prev_idx], 'k.', ms=10, label='one sample earlier')
ax.plot(position.x[jump_idx + 1], position.y[jump_idx + 1], 'm.', ms=10, label='sample after jump')
ax.set(title='Steps longer than 10 cm', xlabel='x (cm)', ylabel='y (cm)')
ax.legend()
plt.show()

In [None]:
# Corrected trajectory in the arena
fig, ax = plt.subplots()
ax.plot(position.x, position.y, 'g.', ms=2)
ax.set(title='Corrected position', xlabel='x (cm)', ylabel='y (cm)')
plt.show()

In [None]:
# y position across the session
fig, ax = plt.subplots()
ax.plot(position.time, position.y, 'b.', ms=2)
ax.set(title='Corrected y position over time', xlabel='time', ylabel='y (cm)')
plt.show()

In [None]:
# x position across the session
fig, ax = plt.subplots()
ax.plot(position.time, position.x, 'k.', ms=2)
ax.set(title='Corrected x position over time', xlabel='time', ylabel='x (cm)')
plt.show()

In [None]:
# Start from the corrected position loaded above. The intermediate arrays of
# load_shortcut_position are local to that function and not available here.
# NOTE(review): position is already median-filtered (kernel 7); confirm that the
# jump removal below is meant to run on the filtered trace.
final_x = np.array(position.x)
final_y = np.array(position.y)
final_time = np.array(position.time)

# Find those indices that have both targets contaminated by the feeder LEDs
# by locating unnatural jumps in the position. Not including jumps that are
# due to jumps in time (from stopping the recording).
time_thresh = 1.
jump_thresh = 50
dist_thresh = 60  # only used by the disabled feeder-distance check below

while True:
    # jumps[i] is the distance from sample i-1 to sample i (jumps[0] = 0)
    jumps = np.append(np.array([0]), np.sqrt(np.diff(final_x)**2 + np.diff(final_y)**2))
    remove_idx = jumps > jump_thresh

    # Pad at the front like `jumps`, so time_jumps[i] flags a gap between
    # sample i-1 and sample i; those jumps are kept.
    time_jumps = np.append(np.array([False], dtype=bool), np.diff(final_time) > time_thresh)
    remove_idx[time_jumps] = False

    # Disabled: keep jumps that land far from both feeders.
    # NOTE(review): feeder1_x etc. would need to be defined here (e.g. from info.path_pts).
#     dist_feeder1 = np.sqrt((final_x - feeder1_x)**2 + (final_y - feeder1_y)**2)
#     dist_feeder2 = np.sqrt((final_x - feeder2_x)**2 + (final_y - feeder2_y)**2)
#     dist_feeder = np.minimum(dist_feeder1, dist_feeder2)
#     dist_jumps = dist_feeder > dist_thresh
#     remove_idx[dist_jumps] = False

    n_remove = np.sum(remove_idx)
    print(n_remove)
    # Each pass removes at least one sample, and sample 0 is never removed, so this terminates
    if n_remove == 0:
        break

    final_x = final_x[~remove_idx]
    final_y = final_y[~remove_idx]
    final_time = final_time[~remove_idx]

In [None]:
# 3-sample median filter (scipy's default kernel size) on the jump-cleaned traces
filtered_x, filtered_y = (signal.medfilt(trace, kernel_size=3) for trace in (final_x, final_y))

In [None]:
# Samples remaining after jump removal
final_x.shape[0]

In [None]:
# Remaining steps longer than 100 before median filtering
(np.sqrt(np.square(np.diff(final_x)) + np.square(np.diff(final_y))) > 100).sum()

In [None]:
# Remaining steps longer than 100 after median filtering
(np.sqrt(np.square(np.diff(filtered_x)) + np.square(np.diff(filtered_y))) > 100).sum()

In [None]:
# Distribution of consecutive-sample step lengths after cleaning and filtering
fig, ax = plt.subplots()
ax.hist(np.sqrt(np.diff(filtered_x)**2 + np.diff(filtered_y)**2))
ax.set(title='Step length (jump-cleaned, median-filtered)', xlabel='step length (cm)', ylabel='count')
plt.show()

In [None]:
# Trajectory after jump removal and median filtering
fig, ax = plt.subplots()
ax.plot(filtered_x, filtered_y, 'g.', ms=2)
ax.set(title='Jump-cleaned, median-filtered position', xlabel='x (cm)', ylabel='y (cm)')
plt.show()

In [None]:
def extract_color(target):
    """Extracts the colour and intensity flags from a Neuralynx video-tracker target.

    Bits are read with shifts and masks, so a signed 32-bit read with the top
    bit set (a negative int) decodes the same as its unsigned value.

    Parameters
    ----------
    target: int
        One raw 32-bit target bitfield.

    Returns
    -------
    color: dict
        0/1 int flags keyed by 'red', 'green', 'blue' (bits 30, 29, 28),
        'raw_red', 'raw_green', 'raw_blue' (bits 14, 13, 12) and
        'intensity' (bit 15), in that key order.

    """
    bit_positions = (('red', 30), ('green', 29), ('blue', 28),
                     ('raw_red', 14), ('raw_green', 13), ('raw_blue', 12),
                     ('intensity', 15))
    return {name: int((target >> bit) & 1) for name, bit in bit_positions}

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString, MultiPoint

import vdmlab as vdm

from load_data import get_pos, get_raw_pos, get_events
from analyze_maze import spikes_by_position
from analyze_plotting import plot_intersects, plot_zone

import sys
# sys.path.append('E:\\code\\python-vdmlab\\projects\\emily_shortcut\\info')
# NOTE(review): this runs after `from load_data import ...` above, so it cannot affect
# those imports. It appends the `info` directory itself, which makes its modules
# importable as top-level names (e.g. `r063d2`). `import info.r063d2` below instead
# needs the *parent* directory on sys.path (presumably the working directory) — confirm.
sys.path.append('C:\\Users\\Emily\\Code\\emi_shortcut\\info')
import info.r063d2 as r063d2

In [None]:
import info.r063d2 as r063d2
info = r063d2

# Cache directory for spike/position output (not used by the cells shown here);
# the commented path is the equivalent location on another machine.
# output_path = 'E:\\code\\emi_shortcut\\cache\\matlab\\spike_pos\\'
output_path = 'C:\\Users\\Emily\\Code\\emi_shortcut\\cache\\matlab\\spike_pos\\'

In [None]:
# Load corrected and raw position plus events for the selected session, presumably
# from MATLAB .mat exports (judging by the *_mat attribute names) — confirm.
# Note: this rebinds `events`, replacing the events loaded from the .nev file earlier.
corrected = get_pos(info.pos_mat, info.pxl_to_cm)
raw = get_raw_pos(info.raw_pos_mat, info.pxl_to_cm)
events = get_events(info.event_mat)

In [None]:
# Steps longer than 10 in the MATLAB-corrected position
(np.sqrt(np.square(np.diff(corrected.x)) + np.square(np.diff(corrected.y))) > 10).sum()

In [None]:
# Difference in sample count between the raw and corrected positions
n_raw, n_corrected = len(raw.x), len(corrected.x)
n_raw - n_corrected

In [None]:
# Sample counts as (raw, corrected)
tuple(len(pos.x) for pos in (raw, corrected))

In [None]:
# Plot to check
# Plot to check the corrected trajectory
fig, ax = plt.subplots()
ax.plot(corrected.x, corrected.y, 'b.', ms=2)
ax.set(title='Corrected position (loaded via get_pos)', xlabel='x (cm)', ylabel='y (cm)')
plt.show()

In [None]:
# Distribution of consecutive-sample step lengths in the corrected position
fig, ax = plt.subplots()
ax.hist(np.sqrt(np.diff(corrected.x)**2 + np.diff(corrected.y)**2), 100)
ax.set(title='Step length (corrected position)', xlabel='step length (cm)', ylabel='count')
plt.show()

In [None]:
corrected.n_samples  # 229833 is the value recorded previously — TODO(review): confirm for the current data

In [None]:
# Plot to check
# Plot to check the raw (uncorrected) trajectory
fig, ax = plt.subplots()
ax.plot(raw.x, raw.y, 'b.', ms=1)
ax.set(title='Raw position (loaded via get_raw_pos)', xlabel='x (cm)', ylabel='y (cm)')
plt.show()