In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from matplotlib import animation
from IPython.display import HTML

from startup import extract_xy, sort_led_locations, correct_targets, median_filter, remove_jumps_to_feeder
from loading_data import get_data, unzip_nvt_file, zip_nvt_file
from analyze_decode_bytrial import get_trials

import warnings
# Silences every warning for the rest of the kernel session.
# NOTE(review): this presumably targets numpy's RuntimeWarnings from nan-reductions
# (np.nanvar / np.nanmean / np.nanstd over rows that are entirely NaN, used further
# down), but it also hides unrelated warnings — consider narrowing it with
# category=RuntimeWarning.
warnings.filterwarnings("ignore")

In [None]:
# All paths are resolved relative to the notebook's working directory.
thisdir = os.getcwd()
cache_dir = os.path.join(thisdir, "cache")
dataloc = os.path.join(cache_dir, "data")             # raw data files
pickle_filepath = os.path.join(cache_dir, "pickled")  # pickled intermediates
output_filepath = os.path.join(thisdir, "plots", "correcting_position")  # figures

In [None]:
import info.r066d2 as info

In [None]:
# Load the session data for `info`. `position` is rebuilt from the raw tracker
# targets further down; spikes, lfp and lfp_theta are not used in the cells shown here.
events, position, spikes, lfp, lfp_theta = get_data(info)

In [None]:
# Tuning constants for the target-cleaning steps below.
variance_thresh = 3.0  # across-target variance above which a sample counts as contaminated
epsilon = 1e-2         # buffer added to the per-sample std when rejecting outlying targets

# Raw video-tracker file for this session
filename = os.path.join(dataloc, info.position_filename)

In [None]:
# Load the raw tracker records: the per-record target bitfields and their timestamps.
nvt_data = nept.load_nvt(filename)
times, targets = nvt_data['time'], nvt_data['targets']

In [None]:
# Initialize x, y arrays (one entry per record and target slot)
x = np.zeros(targets.shape)
y = np.zeros(targets.shape)

# X and Y are stored in a custom bitfield. See Neuralynx data file format documentation for details.
# Briefly, each record contains up to 50 targets, each stored in 32bit field.
# X field at [20:31] and Y at [4:15]; extract_xy unpacks both from a single bitfield.
# TODO: make into a separate function in nept
#
# Walk record by record (rows of `targets`), then across the target slots of that
# record. The zero test must look at the *current* bitfield, and the early exit is
# per record: a zero bitfield holds no valid data and every later slot in the same
# record is zero too, so the rest of that record can be skipped.
n_samples, n_targets = targets.shape
for sample in range(n_samples):
    for target in range(n_targets):
        bitfield = int(targets[sample, target])
        if bitfield == 0:
            break
        x[sample, target], y[sample, target] = extract_xy(bitfield)

# Slots that were never filled are still 0; mark them as missing (nan) instead
x[x == 0] = np.nan
y[y == 0] = np.nan

# Scale the positions
x = x / info.scale_targets[0]
y = y / info.scale_targets[1]

In [None]:
# Every decoded target (one series per target slot), before any cleaning
fig, ax = plt.subplots()
ax.plot(x, y, "b.", ms=1)
ax.set(xlabel="x", ylabel="y", title="Raw target positions (scaled)")
plt.show()

In [None]:
# Get the feeder locations (indexed per sample below)
feeder_x, feeder_y = sort_led_locations(info, events, times)


def _drop_target_nearest_feeder(coords, feeder_coords, thresh):
    """Blank out, in place, the target closest to the feeder in contaminated samples.

    A sample (row of ``coords``) is contaminated when the variance across its
    targets exceeds ``thresh``, i.e. one target sits far from the others because
    it is the feeder LED rather than the implant LEDs. For each such sample, the
    single target nearest that sample's entry in ``feeder_coords`` is set to NaN,
    and the edited rows are written back into ``coords``.
    """
    contaminated = np.nanvar(coords, axis=1) > thresh
    rows = coords[contaminated]  # boolean indexing -> copy, written back below
    for row, feeder in zip(rows, feeder_coords[contaminated]):
        row[np.nanargmin(np.abs(row - feeder))] = np.nan
    coords[contaminated] = rows


# x and y are screened independently, each against its own feeder coordinate
_drop_target_nearest_feeder(x, feeder_x, variance_thresh)
_drop_target_nearest_feeder(y, feeder_y, variance_thresh)

In [None]:
def _blank_outlying_targets(coords, buffer):
    """NaN-out, in place, targets not strictly within (std + buffer) of their sample's mean.

    The mean and std are taken across targets (axis 1) of each sample, ignoring
    NaNs. Entries that are already NaN stay NaN.
    """
    centre = np.nanmean(coords, axis=1, keepdims=True)
    tolerance = np.nanstd(coords, axis=1, keepdims=True) + buffer
    within = np.abs(coords - centre) < tolerance
    coords[~within] = np.nan


# Removing the sample that is more than std + buffer from the mean of the targets for both x and y
_blank_outlying_targets(x, epsilon)
_blank_outlying_targets(y, epsilon)

In [None]:
# Collapse the surviving targets into one position per sample (NaN-aware mean across targets).
xx, yy = (np.nanmean(coords, axis=1) for coords in (x, y))
# Default timestamps are the unfiltered ones; the jump-removal step reassigns them.
ttimes = times

In [None]:
# Remove jumps to feeder location; thresholds are forwarded unchanged to the helper.
jump_removal_params = {"jump_thresh": 10, "dist_thresh": 5}
xx, yy, ttimes = remove_jumps_to_feeder(xx, yy, times, info, **jump_removal_params)

In [None]:
# Apply a median filter to the averaged x and y traces.
# NOTE(review): window size and edge handling come from median_filter's defaults in
# `startup`; check them there before relying on a particular smoothing behaviour.
xx, yy = median_filter(xx, yy)

In [None]:
# Pair the cleaned traces column-wise (x in column 0, y in column 1) for nept.Position.
# Assumes xx and yy are equal-length 1-D arrays.
position_xy = np.column_stack((xx, yy))
position = nept.Position(position_xy, ttimes)

In [None]:
# Cleaned y position over time
fig, ax = plt.subplots()
ax.plot(position.time, position.y, "k.", ms=3)
ax.set_xlabel("time")
ax.set_ylabel("y")
# Minimal frame: hide the right/top spines and keep ticks on the left/bottom only
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")
fig.tight_layout()
plt.show()

In [None]:
# Percentage of raw tracker records that survive the cleaning steps
fraction_kept = position.n_samples / targets.shape[0]
fraction_kept * 100

In [None]:
# Cleaned 2-D trajectory for the whole session
fig, ax = plt.subplots()
ax.plot(position.x, position.y, "b.", ms=1)
ax.set(xlabel="x", ylabel="y", title="Corrected position")
plt.show()

In [None]:
# Pick a single phase-3 trial to inspect
trial_idx = 7
trial_epochs = get_trials(events, info.task_times["phase3"])
selected_trial = trial_epochs[trial_idx]
start, stop = selected_trial.start, selected_trial.stop

In [None]:
# Cleaned trajectory restricted to the selected trial's time window
trial = position.time_slice(start, stop)
fig, ax = plt.subplots()
ax.plot(trial.x, trial.y, "k.")
ax.set(xlabel="x", ylabel="y", title="Position during trial {}".format(trial_idx))
plt.show()