In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from matplotlib import animation
from IPython.display import HTML

from startup import extract_xy, sort_led_locations, correct_targets, median_filter, remove_jumps_to_feeder
from loading_data import get_data, unzip_nvt_file, zip_nvt_file
from analyze_decode_bytrial import get_trials

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Directory layout relative to the notebook's working directory
thisdir = os.getcwd()
cache_dir = os.path.join(thisdir, "cache")
dataloc = os.path.join(cache_dir, "data")
pickle_filepath = os.path.join(cache_dir, "pickled")
output_filepath = os.path.join(thisdir, "plots", "correcting_position")

In [None]:
import info.r066d2 as info

In [None]:
# Load the session data for `info`; `position` is rebuilt from the raw tracker file below
(events, position, spikes, lfp, lfp_theta) = get_data(info)

In [None]:
filename = os.path.join(dataloc, info.position_filename)
# Per-sample variance of the targets above which a sample is treated as contaminated
variance_thresh = 5.0
# Added to the per-sample std so that samples with a single target (std == 0) are kept
epsilon = 1e-2

In [None]:
# Load the raw Neuralynx tracker records (per-record target bitfields and timestamps)
nvt_data = nept.load_nvt(filename)
targets, times = nvt_data['targets'], nvt_data['time']

In [None]:
# Initialize x, y arrays: one row per tracker record (sample), one column per target
x = np.zeros(targets.shape)
y = np.zeros(targets.shape)

# X and Y are stored in a custom bitfield. See Neuralynx data file format documentation for details.
# Briefly, each record contains up to 50 targets, each stored in 32bit field.
# X field at [20:31] and Y at [4:15].
# TODO: make into a separate function in nept
for sample in range(targets.shape[0]):
    for target in range(targets.shape[1]):
        field = int(targets[sample, target])
        # When the bitfield is equal to zero there is no valid data for that field,
        # and it remains zero for the rest of the bitfields in the record,
        # so stop reading this record.
        if field == 0:
            break
        x[sample, target], y[sample, target] = extract_xy(field)

# Replacing targets with no samples with nan instead of 0
x[x == 0] = np.nan
y[y == 0] = np.nan

# Scale the positions
x = x / info.scale_targets[0]
y = y / info.scale_targets[1]

In [None]:
# Raw y of every target over a short, hand-picked window (each target column is its own series)
fig, ax = plt.subplots()
ax.plot(times[9000:9400], y[9000:9400], "r.")
ax.set_xlabel("time")
ax.set_ylabel("y (all targets)")
plt.show()

In [None]:
# Raw x vs y of every target over the same hand-picked window
fig, ax = plt.subplots()
ax.plot(x[9000:9400], y[9000:9400], "r.")
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()

In [None]:
# Finding which feeder led is on over time
leds = []
leds.extend([(event, 'led1') for event in events['led1']])
leds.extend([(event, 'led2') for event in events['led2']])
sorted_leds = sorted(leds)

# Feeder coordinates per LED label: led1 sits at feeder1, led2 at feeder2.
feeder_xy = {
    'led1': info.path_pts['feeder1'],
    'led2': info.path_pts['feeder2'],
}
# Number of position samples, starting at the one nearest each LED event, attributed to
# that feeder. NOTE(review): 4 is a hand-picked value — confirm against the LED timing.
n_led_samples = 4

# Get an array of feeder locations when that feeder is actively flashing (NaN otherwise)
feeder_x_location = np.full(times.shape[0], np.nan)
feeder_y_location = np.full(times.shape[0], np.nan)

for time, label in sorted_leds:
    idx = nept.find_nearest_idx(times, time)
    x_location = feeder_xy[label][0]
    y_location = feeder_xy[label][1]

    feeder_x_location[idx:idx + n_led_samples] = x_location
    feeder_y_location[idx:idx + n_led_samples] = y_location

In [None]:
# Removing the contaminated samples that are closest to the feeder location
def remove_contaminated(targets, current_feeder):
    """Blank out, in each sample, the target closest to the active feeder.

    Parameters
    ----------
    targets : np.ndarray
        2-D array with one row per sample and one column per target (NaN = no target).
    current_feeder : np.ndarray
        1-D array with one entry per row of `targets`: the active feeder coordinate,
        or NaN when no feeder is active for that sample.

    Returns
    -------
    new_targets : np.ndarray
        A float copy of `targets` in which, for every row with a non-NaN feeder value
        and at least one valid target, the target nearest the feeder is set to NaN.
        The input array is not modified.
    """
    new_targets = np.array(targets, dtype=float)  # copy, so the caller's array is untouched
    for row, feeder in zip(new_targets, current_feeder):
        # np.nanargmin raises ValueError on an all-NaN row, so skip rows with no valid target.
        if np.isnan(feeder) or np.all(np.isnan(row)):
            continue
        row[np.nanargmin(np.abs(row - feeder))] = np.nan  # `row` is a view into new_targets
    return new_targets

# Drop, per sample, the target nearest the active feeder (x and y handled independently)
x, y = (
    remove_contaminated(x, feeder_x_location),
    remove_contaminated(y, feeder_y_location),
)

In [None]:
# Per-sample position: mean over whichever targets are still valid (NaNs ignored)
xx, yy = (np.nanmean(coords, axis=1) for coords in (x, y))
ttimes = times

In [None]:
# Stack into an (n_samples, 2) array of [x, y] columns and wrap it as a nept.Position
position = nept.Position(np.column_stack((xx, yy)), ttimes)

In [None]:
# Averaged position after removing the feeder-contaminated targets
fig, ax = plt.subplots()
ax.plot(position.x, position.y, "k.", ms=1)
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()

In [None]:
# Averaged y position over the whole session
fig, ax = plt.subplots()
ax.plot(position.time, position.y, "k.", ms=1)
ax.set_xlabel("time")
ax.set_ylabel("y")
plt.show()

In [None]:
# Close-up of the averaged y position over a short, hand-picked window
fig, ax = plt.subplots()
ax.plot(position.time[9050:9100], position.y[9050:9100], "k.", ms=5)
ax.set_ylim(20, 70)
ax.set_xlabel("time")
ax.set_ylabel("y")
plt.show()

In [None]:
# Every remaining target (one series per target column)
fig, ax = plt.subplots()
ax.plot(x, y, "b.", ms=1)
ax.set_xlabel("x (all targets)")
ax.set_ylabel("y (all targets)")
plt.show()

In [None]:
# y of every remaining target over the whole session
fig, ax = plt.subplots()
ax.plot(times, y, 'k.', ms=1)
ax.set_xlabel("time")
ax.set_ylabel("y (all targets)")
plt.show()

In [None]:
# Per-sample feeder coordinates while each feeder LED is flashing
feeder_x, feeder_y = sort_led_locations(info, events, times)

# A large spread between the targets of one sample means one of them is contaminated,
# i.e. the tracker picked up the feeder LED instead of the implant LEDs.
target_x_var, target_y_var = (np.nanvar(coords, axis=1) for coords in (x, y))
contaminated_x_idx = target_x_var > variance_thresh
contaminated_y_idx = target_y_var > variance_thresh

# Removing the contaminated samples that are closest to the feeder location
def remove_contaminated_samples(targets, current_feeder, contaminated_idx):
    """Blank out the target closest to the feeder in each contaminated sample.

    Parameters
    ----------
    targets : np.ndarray
        2-D array with one row per sample and one column per target (NaN = no target).
    current_feeder : np.ndarray
        1-D array with one feeder coordinate per row of `targets`.
    contaminated_idx : array_like
        Rows to clean: a boolean mask over the rows or an array of row indices.

    Returns
    -------
    np.ndarray
        A float copy of `targets`. In each selected row, the target nearest to
        `current_feeder` is set to NaN. The input array is not modified.

    Notes
    -----
    Selected rows whose feeder value is NaN, or that have no valid target, are left
    unchanged (``np.nanargmin`` raises on an all-NaN slice).
    NOTE(review): a feeder value of 0 is treated as a real location. If `current_feeder`
    uses 0 to mean "no feeder active", exclude those rows in the caller — confirm.
    """
    cleaned = np.array(targets, dtype=float)  # copy, so the caller's array is not mutated
    # Works for both a boolean row mask and integer row indices.
    rows = np.arange(cleaned.shape[0])[contaminated_idx]
    for row in rows:
        feeder = current_feeder[row]
        values = cleaned[row]  # view into `cleaned`
        if np.isnan(feeder) or np.all(np.isnan(values)):
            continue
        values[np.nanargmin(np.abs(values - feeder))] = np.nan
    return cleaned

# Clean only the samples flagged as contaminated, separately for x and y
x, y = (
    remove_contaminated_samples(x, feeder_x, contaminated_x_idx),
    remove_contaminated_samples(y, feeder_y, contaminated_y_idx),
)

In [None]:
# Per-sample spread of the x targets after the contaminated samples were cleaned
target_x_var = np.nanvar(x, axis=-1)

In [None]:
# Distribution of the per-sample x-target variance; the dashed line marks variance_thresh.
fig, ax = plt.subplots()
# Samples with no valid target have NaN variance; leave them out of the histogram.
ax.hist(target_x_var[~np.isnan(target_x_var)], bins=50)
ax.axvline(variance_thresh, color="k", linestyle="--")
ax.set_xlabel("variance of x targets within a sample")
ax.set_ylabel("number of samples")
plt.show()

In [None]:
def sort_led_locations(info, events, times):
    """Combine the led1/led2 events into per-sample feeder coordinates.

    NOTE(review): this definition shadows the `sort_led_locations` imported from
    `startup` for every cell executed after it.

    Parameters
    ----------
    info: module
        Only ``info.path_pts['feeder1']`` and ``info.path_pts['feeder2']`` are read,
        each indexed as ``[0]`` (x) and ``[1]`` (y).
    events: dict of nept.Epochs
        Must provide iterable 'led1' and 'led2' entries.
    times: np.array
        Position sample times.

    Returns
    -------
    feeder_x_location: np.array
    feeder_y_location: np.array
        One entry per element of `times`. The sample nearest each LED event holds the
        matching feeder's coordinate (led1 -> feeder1, led2 -> feeder2; later events
        overwrite earlier ones at the same sample). Every other entry is 0.

    """
    feeder_xy = {
        'led1': (info.path_pts['feeder1'][0], info.path_pts['feeder1'][1]),
        'led2': (info.path_pts['feeder2'][0], info.path_pts['feeder2'][1]),
    }

    # Events from both LEDs in one sorted sequence.
    labelled_events = sorted(
        [(event, 'led1') for event in events['led1']]
        + [(event, 'led2') for event in events['led2']]
    )

    n_samples = times.shape[0]
    feeder_x_location = np.zeros(n_samples)
    feeder_y_location = np.zeros(n_samples)

    for event_time, label in labelled_events:
        nearest = nept.find_nearest_idx(times, event_time)
        feeder_x_location[nearest], feeder_y_location[nearest] = feeder_xy[label]

    return feeder_x_location, feeder_y_location

In [None]:
# Recompute the per-sample active-feeder locations (NaN when no feeder LED is flashing)
leds = []
leds.extend([(event, 'led1') for event in events['led1']])
leds.extend([(event, 'led2') for event in events['led2']])
sorted_leds = sorted(leds)

# Feeder coordinates per LED label: led1 sits at feeder1, led2 at feeder2.
feeder_xy = {
    'led1': info.path_pts['feeder1'],
    'led2': info.path_pts['feeder2'],
}
# Number of position samples, starting at the one nearest each LED event, attributed to
# that feeder. NOTE(review): 4 is a hand-picked value — confirm against the LED timing.
n_led_samples = 4

# Get an array of feeder locations when that feeder is actively flashing
feeder_x_location = np.full(times.shape[0], np.nan)
feeder_y_location = np.full(times.shape[0], np.nan)

for time, label in sorted_leds:
    idx = nept.find_nearest_idx(times, time)
    x_location = feeder_xy[label][0]
    y_location = feeder_xy[label][1]

    feeder_x_location[idx:idx + n_led_samples] = x_location
    feeder_y_location[idx:idx + n_led_samples] = y_location

In [None]:
# Inspect the x targets of one hand-picked sample
sample_idx = 9014
x[sample_idx]

In [None]:
# Active feeder x location at the same hand-picked sample
feeder_sample_idx = 9014
feeder_x_location[feeder_sample_idx]

In [None]:
# Active feeder locations over a short inspection window
inspect_window = slice(8970, 9020)
feeder_x_location[inspect_window], feeder_y_location[inspect_window]

In [None]:
# y of every target around an LED event. Red marks two 4-sample windows, presumably the
# feeder-active windows seen in the cell above — confirm.
fig, ax = plt.subplots()
ax.plot(times[8970:9020], y[8970:9020], 'b.')
ax.plot(times[8970:8974], y[8970:8974], 'r.')
ax.plot(times[8978:8978+4], y[8978:8978+4], 'r.')
ax.set_xlabel("time")
ax.set_ylabel("y (all targets)")
plt.show()

In [None]:
# Removing the contaminated samples that are closest to the feeder location
def remove_contaminated(targets, current_feeder):
    """Blank out, in each sample, the target closest to the active feeder.

    Parameters
    ----------
    targets : np.ndarray
        2-D array with one row per sample and one column per target (NaN = no target).
    current_feeder : np.ndarray
        1-D array with one entry per row of `targets`: the active feeder coordinate,
        or NaN when no feeder is active for that sample.

    Returns
    -------
    new_targets : np.ndarray
        A float copy of `targets` in which, for every row with a non-NaN feeder value
        and at least one valid target, the target nearest the feeder is set to NaN.
        The input array is not modified.
    """
    new_targets = np.array(targets, dtype=float)  # copy, so the caller's array is untouched
    for row, feeder in zip(new_targets, current_feeder):
        # np.nanargmin raises ValueError on an all-NaN row, so skip rows with no valid target.
        if np.isnan(feeder) or np.all(np.isnan(row)):
            continue
        row[np.nanargmin(np.abs(row - feeder))] = np.nan  # `row` is a view into new_targets
    return new_targets

# Feeder-cleaned targets for inspection
xxx, yyy = (
    remove_contaminated(x, feeder_x_location),
    remove_contaminated(y, feeder_y_location),
)

In [None]:
# Cleaned x targets around the hand-picked sample
first_row, last_row = 9013, 9019
xxx[first_row:last_row]

In [None]:
# Cleaned y targets around the LED event. Red and cyan mark hand-picked 4-sample windows.
fig, ax = plt.subplots()
ax.plot(times[8970:9020], yyy[8970:9020], 'b.')
ax.plot(times[8970:8974], yyy[8970:8974], 'r.')
ax.plot(times[8978:8978+4], yyy[8978:8978+4], 'r.')
ax.plot(times[9014:9018], yyy[9014:9018], 'c.')
ax.set_ylim(20, 70)
ax.set_xlabel("time")
ax.set_ylabel("y (all targets, cleaned)")
plt.show()

In [None]:
# y of every target over the whole session, before the std-based outlier filter
fig, ax = plt.subplots()
ax.plot(times, y, 'k.', ms=1)
ax.set_xlabel("time")
ax.set_ylabel("y (all targets)")
plt.show()

In [None]:
# Removing the sample that is more than std + buffer from the mean of the targets for both x and y
def _blank_outlier_targets(coords, buffer):
    """In place, set to NaN every target not strictly within (row std + buffer) of its row mean."""
    row_mean = np.nanmean(coords, axis=1, keepdims=True)
    row_tolerance = np.nanstd(coords, axis=1, keepdims=True) + buffer
    within = np.abs(coords - row_mean) < row_tolerance
    coords[~within] = np.nan


for coords in (x, y):
    _blank_outlier_targets(coords, epsilon)

In [None]:
# Per-sample position: mean over the targets that survived the outlier filter
xx, yy = (np.nanmean(coords, axis=1) for coords in (x, y))
ttimes = times

In [None]:
# Remove jumps to feeder location (thresholds presumably in position units — confirm)
xx, yy, ttimes = remove_jumps_to_feeder(
    xx, yy, times, info,
    jump_thresh=2,
    dist_thresh=10,
)

In [None]:
# Smooth the averaged position with the project's median filter
filtered_xy = median_filter(xx, yy)
xx, yy = filtered_xy

In [None]:
# Stack into an (n_samples, 2) array of [x, y] columns and wrap it as a nept.Position
position = nept.Position(np.column_stack((xx, yy)), ttimes)

In [None]:
# Final y position over time, with a minimal frame (no right/top spines)
fig, ax = plt.subplots()
ax.plot(position.time, position.y, "k.", ms=3)
ax.set_xlabel("time")
ax.set_ylabel("y")
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
fig.tight_layout()
plt.show()

In [None]:
# Percentage of the raw tracker records that remain in the final position
percent_samples_kept = position.n_samples / len(targets) * 100
percent_samples_kept

In [None]:
# Final cleaned position in the arena
fig, ax = plt.subplots()
ax.plot(position.x, position.y, "b.", ms=1)
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()

In [None]:
# Final cleaned y position over the whole session
fig, ax = plt.subplots()
ax.plot(position.time, position.y, "b.", ms=1)
ax.set_xlabel("time")
ax.set_ylabel("y")
plt.show()

In [None]:
# Pick one phase-3 trial to inspect
trial_idx = 7
trial_epochs = get_trials(events, info.task_times["phase3"])
selected_trial = trial_epochs[trial_idx]
start, stop = selected_trial.start, selected_trial.stop

In [None]:
# Trajectory during the selected trial
trial = position.time_slice(start, stop)
fig, ax = plt.subplots()
ax.plot(trial.x, trial.y, "k.")
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()