In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from matplotlib import animation
from IPython.display import HTML

from startup import extract_xy, sort_led_locations, correct_targets, median_filter, remove_jumps_to_feeder
from loading_data import get_data, unzip_nvt_file, zip_nvt_file
from analyze_decode_bytrial import get_trials

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# All data and output locations are resolved relative to the notebook's working directory.
thisdir = os.getcwd()
dataloc, pickle_filepath, output_filepath = (
    os.path.join(thisdir, *parts)
    for parts in (("cache", "data"), ("cache", "pickled"), ("plots", "correcting_position"))
)

In [None]:
import info.r063d2 as info

In [None]:
events, position, spikes, lfp, lfp_theta = get_data(info)
xedges, yedges = nept.get_xyedges(position)

# Single phase-3 trial used for the close-up plots and animations below.
trial_idx = 2
trial_epochs = get_trials(events, info.task_times["phase3"])
start, stop = trial_epochs[trial_idx].start, trial_epochs[trial_idx].stop

In [None]:
# Position as returned by get_data, before any of the corrections in this notebook.
plt.plot(position.x, position.y, linestyle="None", marker=".", color="b")
plt.show()

In [None]:
# Per-record target variance (np.nanvar over a record's x or y targets) above
# which the record is treated as contaminated by the feeder LED.
# NOTE(review): units follow the scaled positions -- confirm the value.
variance_thresh = 4

In [None]:
# Load the raw tracking records for this session (events already come from get_data).
position_filename = os.path.join(dataloc, info.position_filename)
nvt_data = nept.load_nvt(position_filename)

In [None]:
# Per-record packed target bitfields and their timestamps.
targets, times = nvt_data['targets'], nvt_data['time']

In [None]:
def animate_trial(x, y, time, start, stop, interval=80, pad_amount=5):
    """Animate the tracked position for a single trial.

    Builds a ``nept.Position`` from ``x``/``y`` and ``time``, restricts it to
    ``[start, stop]`` with ``time_slice``, draws the whole trial path in grey
    and moves a red marker through the samples, one sample per frame.

    Parameters
    ----------
    x, y : array-like
        Position coordinates; presumably 1-D and the same length as ``time``
        -- TODO confirm.
    time : array-like
        Sample times for ``x`` and ``y``.
    start, stop : float
        Trial bounds passed to ``Position.time_slice``.
    interval : int, optional
        Delay between frames in milliseconds (default 80).
    pad_amount : float, optional
        Margin added around the trial's bounding box for the axis limits (default 5).

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Repeating animation with one frame per sample in the trial.

    Raises
    ------
    ValueError
        If no samples fall inside ``[start, stop]``.
    """
    position = nept.Position(np.column_stack((x, y)), time)
    position = position.time_slice(start, stop)

    n_timebins = position.n_samples
    if n_timebins == 0:
        raise ValueError("No position samples between start={} and stop={}".format(start, stop))

    fig, ax = plt.subplots()

    # NaN-aware limits: a single NaN sample would otherwise make set_xlim fail.
    x_lo, x_hi = np.nanmin(position.x), np.nanmax(position.x)
    y_lo, y_hi = np.nanmin(position.y), np.nanmax(position.y)
    ax.set_xlim((np.floor(x_lo) - pad_amount, np.ceil(x_hi) + pad_amount))
    ax.set_ylim((np.floor(y_lo) - pad_amount, np.ceil(y_hi) + pad_amount))

    # Static background: the full path of this trial.
    ax.plot(position.x, position.y, ".", color="#bdbdbd")

    rat_position, = ax.plot([], [], "<", color="r")

    fig.tight_layout()

    def init():
        rat_position.set_data([], [])
        return (rat_position,)

    def animate(i):
        # set_data expects sequences; bare scalars are deprecated in newer Matplotlib.
        rat_position.set_data([position.x[i]], [position.y[i]])
        return (rat_position,)

    anim = animation.FuncAnimation(fig, animate, frames=n_timebins, init_func=init,
                                   interval=interval, blit=False, repeat=True)
    return anim

In [None]:
# Decoded coordinates, one row per record and one column per target slot.
# Slots without a valid target are left at zero by the extraction step.
x, y = np.zeros(targets.shape), np.zeros(targets.shape)

In [None]:
# X and Y are stored in a custom bitfield. See Neuralynx data file format documentation for details.
# Briefly, each record (row of `targets`) contains up to 50 targets, each stored in 32bit field.
# X field at [20:31] and Y at [4:15]; extract_xy decodes one field.
# TODO: make into a separate function in nept
for sample in range(targets.shape[0]):
    for target in range(targets.shape[1]):
        this_target = targets[sample, target]
        # When the bitfield is equal to zero there is no valid data for that field
        # and remains zero for the rest of the bitfields in the record, so stop
        # scanning this record (checked per record, not per target column).
        if this_target == 0:
            break
        x[sample, target], y[sample, target] = extract_xy(int(this_target))

In [None]:
# How many targets are taken for each timepoint?
# Count the decoded (non-zero) targets in each record (row of x).
n_targets = np.sum(x > 0., axis=1)

fig, ax = plt.subplots()
# Fold every count above 4 into the last bin so the "4+" bar really means
# "4 or more"; with bins=[1, 2, 3, 4, 5] counts above 5 were silently dropped.
# Records with zero targets still fall below the first edge and are not shown.
ax.hist(np.minimum(n_targets, 4), bins=[1, 2, 3, 4, 5], align="left")
ax.set_xticks([1, 2, 3, 4])
ax.set_xticklabels(["1", "2", "3", "4+"])
ax.set_xlabel("Number of targets")
ax.set_ylabel("Number of samples")
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
fig.tight_layout()
# fig.savefig(os.path.join(output_filepath, "n_targets.png"))
plt.show()

In [None]:
# Raw decoded y coordinate of every target slot over the session.
# Create a fresh figure/axes: the previous `ax` belongs to an earlier figure,
# so styling it would not affect this plot.
fig, ax = plt.subplots()
ax.plot(times, y, "g.", ms=3)
ax.set_xlabel("time")
ax.set_ylabel("y")
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
fig.tight_layout()
# fig.savefig(os.path.join(output_filepath, "raw.png"))
plt.show()

In [None]:
# Empty target slots were left at 0 -- mark them as missing (NaN) instead.
x, y = [np.where(coord == 0, np.nan, coord) for coord in (x, y)]

In [None]:
# Convert raw tracker coordinates using the session's per-axis scale factors.
x, y = x / info.scale_targets[0], y / info.scale_targets[1]

In [None]:
# Drop every target lying more than (std + epsilon) from its record's mean, for x and y.
# The epsilon buffer keeps records with a single target (std == 0) from being discarded.
epsilon = 0.01

targets_x_mean = np.nanmean(x, axis=1, keepdims=True)
targets_x_std = epsilon + np.nanstd(x, axis=1, keepdims=True)
keep_x_idx = np.abs(x - targets_x_mean) < targets_x_std
x[~keep_x_idx] = np.nan

targets_y_mean = np.nanmean(y, axis=1, keepdims=True)
targets_y_std = epsilon + np.nanstd(y, axis=1, keepdims=True)
keep_y_idx = np.abs(y - targets_y_mean) < targets_y_std
y[~keep_y_idx] = np.nan

In [None]:
# Get the feeder locations
# NOTE(review): the cleanup below indexes feeder_x/feeder_y with per-record
# masks, so they presumably hold one feeder coordinate per entry of `times` -- confirm.
feeder_x, feeder_y = sort_led_locations(info, events, times)

In [None]:
# One target is contaminated when the distance between the targets is large,
# so measure each record's spread with a NaN-ignoring variance.
target_x_var, target_y_var = np.nanvar(x, axis=1), np.nanvar(y, axis=1)

In [None]:
# Contaminated samples are using the feeder LED instead of the implant LEDs:
# flag records whose target spread exceeds the threshold.
contaminated_x_idx, contaminated_y_idx = (
    target_x_var > variance_thresh,
    target_y_var > variance_thresh,
)

In [None]:
# Initialize cleaned xy arrays with first target
# xx = np.array(x[:, 0])
# yy = np.array(y[:, 0])

In [None]:
# Removing the contaminated samples that are closest to the feeder location
def remove_closest_to_feeder(coords, feeder, contaminated):
    """Set to NaN, in place, the target nearest the feeder in each contaminated record.

    Parameters
    ----------
    coords : np.ndarray
        2-D (records x targets) coordinate array; modified in place.
    feeder : np.ndarray
        1-D feeder coordinate per record, aligned with the rows of ``coords``.
    contaminated : np.ndarray of bool
        Row mask selecting the records to clean.

    Raises
    ------
    ValueError
        From ``np.nanargmin`` if a selected record has only NaN targets.
    """
    rows = np.flatnonzero(contaminated)
    distance = np.abs(coords[rows] - feeder[rows, np.newaxis])
    closest = np.nanargmin(distance, axis=1)
    coords[rows, closest] = np.nan


# NOTE(review): x and y are cleaned independently, so the target removed from x
# may differ from the one removed from y in the same record -- confirm intended.
remove_closest_to_feeder(x, feeder_x, contaminated_x_idx)
remove_closest_to_feeder(y, feeder_y, contaminated_y_idx)

In [None]:
# Collapse each record's surviving targets into a single position by averaging them.
xx, yy = np.nanmean(x, axis=1), np.nanmean(y, axis=1)

In [None]:
# Averaged x position over the session.
plt.plot(times, xx, linestyle="None", marker=".", color="r", markersize=3)
plt.show()

In [None]:
# Zoom in on a 600-sample stretch of the averaged x position.
idx = 26000
window = slice(idx, idx + 600)
plt.plot(times[window], xx[window], "r.", ms=3)
plt.show()

In [None]:
# Remove jumps to feeder location
# Returns the filtered x, y and their matching timestamps (ttimes); plots below
# pair ttimes with nojump_x/nojump_y.
# NOTE(review): threshold meanings live in remove_jumps_to_feeder -- presumably
# jump_thresh is a per-sample step size and dist_thresh a distance to the feeder; confirm.
nojump_x, nojump_y, ttimes = remove_jumps_to_feeder(xx, yy, times, info, jump_thresh=10, dist_thresh=6.)

In [None]:
# Number of samples dropped by the jump removal (assumes remove_jumps_to_feeder
# drops samples rather than masking them -- confirm).
print("total removed:", len(xx) - len(nojump_x))

In [None]:
# y position after removing jumps to the feeder.
plt.plot(ttimes, nojump_y, linestyle="None", marker=".", color="b", markersize=3)
plt.show()

In [None]:
# Zoom in on a 5000-sample stretch of the jump-filtered x position.
idx = 80000
window = slice(idx, idx + 5000)
plt.plot(ttimes[window], nojump_x[window], "r.", ms=3)
plt.show()
plt.show()

In [None]:
# Position within the example trial after removing feeder jumps.
# `ttimes` belongs to the jump-filtered arrays; `xx`/`yy` still hold every
# sample, so they must not be paired with `ttimes`.
position = nept.Position(np.hstack(np.array([nojump_x, nojump_y])[..., np.newaxis]), ttimes)
position = position.time_slice(start, stop)

plt.plot(position.x, position.y, "k.")
plt.show()

In [None]:
# Animate the example trial using the jump-filtered (not yet median-filtered) positions.
anim = animate_trial(x=nojump_x, y=nojump_y, time=ttimes, start=start, stop=stop)
HTML(anim.to_html5_video())

In [None]:
# Apply a median filter
# Smooths the jump-filtered coordinates; window size and edge handling are set
# inside median_filter. Note this rebinds xx/yy, replacing the earlier averages.
xx, yy = median_filter(nojump_x, nojump_y)

In [None]:
# y position after median filtering.
plt.plot(ttimes, yy, linestyle="None", marker=".", color="k", markersize=3)
plt.show()

In [None]:
# Which of the first samples did the median filter leave unchanged?
# Exact, NaN-aware comparison: plain `==` reports positions as changed when
# both arrays hold NaN there.
np.isclose(yy[:10], nojump_y[:10], rtol=0, atol=0, equal_nan=True)

In [None]:
# Construct a nept.Position object from the filtered coordinates, one (x, y) row per sample.
position = nept.Position(np.column_stack((xx, yy)), ttimes)

In [None]:
# writer = animation.writers['ffmpeg'](fps=18)
# dpi = 600
# anim.save(os.path.join(output_filepath, "updated_corrected-position-animation.mp4"), writer=writer, dpi=dpi)

# plt.close()

In [None]:
# Animate the example trial using the final, median-filtered position.
anim = animate_trial(x=position.x, y=position.y, time=position.time, start=start, stop=stop)
HTML(anim.to_html5_video())

In [None]:
# Final corrected x position over the session.
plt.plot(position.time, position.x, linestyle="None", marker=".", color="k", markersize=3)
plt.show()

In [None]:
# Final corrected y position over the session.
# Create a fresh figure/axes: the module-level `ax` belongs to an earlier
# figure, so styling it would not affect this plot.
fig, ax = plt.subplots()
ax.plot(position.time, position.y, "b.", ms=3)
ax.set_xlabel("time")
ax.set_ylabel("y")
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
fig.tight_layout()
# fig.savefig(os.path.join(output_filepath, "updated_corrected-position.png"))
plt.show()

In [None]:
# Number of samples remaining in the final corrected position.
position.n_samples