In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept
import scipy

from matplotlib import animation
from IPython.display import HTML

from loading_data import get_data, unzip_nvt_file, zip_nvt_file, extract_xy, median_filter, plot_correcting_position
from utils_maze import get_trials

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Paths are resolved relative to the directory the notebook is launched from.
thisdir = os.getcwd()
dataloc = os.path.join(thisdir, 'cache', 'data')
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "correcting_position")

# plt.savefig does not create missing directories; make sure the plot output
# folder exists before any cell below writes into it.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Session metadata module; swap which line is active to analyse another session.
from info import r066d7 as info
# from info import r063d8 as info

In [None]:
# Only the events and the preprocessed position are needed from get_data here.
[events, position, _, _, _] = get_data(info)

In [None]:
# Load the raw video-tracker records for this session straight from the .nvt file.
filename = os.path.join(dataloc, info.position_filename)
nvt_data = nept.load_nvt(filename)
targets, times = nvt_data['targets'], nvt_data['time']

In [None]:
# Initialize x, y arrays: one row per record (sample), one column per target slot.
x = np.zeros(targets.shape)
y = np.zeros(targets.shape)

# X and Y are stored in a custom bitfield. See Neuralynx data file format documentation for details.
# Briefly, each record contains up to 50 targets, each stored in 32bit field.
# X field at [20:31] and Y at [4:15].
# TODO: make into a separate function in nept
for sample in range(targets.shape[0]):
    for target in range(targets.shape[1]):
        bitfield = int(targets[sample, target])
        # A zero bitfield means this slot holds no valid data, and the remaining
        # slots of the same record are zero too, so stop scanning this record.
        # (Previously the check read the wrong element, targets[target, target].)
        if bitfield == 0:
            break
        x[sample, target], y[sample, target] = extract_xy(bitfield)

# Replacing targets with no samples with nan instead of 0
x[x == 0] = np.nan
y[y == 0] = np.nan

# Scale the positions
x /= info.scale_targets
y /= info.scale_targets

In [None]:
# Indices of raw samples recorded after t = 6521.45853 (single-argument np.where == nonzero).
np.nonzero(times > 6521.45853)

In [None]:
# 50-sample window of raw targets to inspect.
s, e = 159055, 159055 + 50
y[s:e]

In [None]:
# Raw y of every target over the inspected window, saved for reference.
plt.plot(times[s:e], y[s:e], color="k", marker=".", linestyle="None")
plt.savefig(os.path.join(output_filepath, "{} phase3trial18_time.png".format(info.session_id)))
plt.show()

In [None]:
# (x, y) of the point highlighted in red on the full-session plot below.
(85.0, 20)

In [None]:
# feeder_x_location is only created by the feeder-LED masking cell further down,
# so guard the lookup; otherwise a fresh-kernel "Run All" stops here with a NameError.
feeder_x_location[s:e] if "feeder_x_location" in globals() else None

In [None]:
# Every raw target over the session, with the inspected point in red; saved to disk.
plt.plot(x, y, "k.", markersize=2)
plt.plot(85., 20., "r.", markersize=10)
plt.savefig(os.path.join(output_filepath, "{} phase3trial18_full.png".format(info.session_id)))

In [None]:
# Raw y of every target across the whole session.
plt.plot(times, y, color="r", marker=".", linestyle="None", markersize=2)
plt.show()

In [None]:
# Indices of preprocessed position samples after t = 6513.185586.
np.nonzero(position.time > 6513.185586)

In [None]:
# Zoom on 40 samples of the position returned by get_data.
plt.plot(position.time[123980:124020], position.y[123980:124020],
         color="b", marker=".", linestyle="None", markersize=3)
plt.show()

In [None]:
# Timestamps of the zoomed samples above.
position.time[slice(123980, 124020)]

In [None]:
# Pad each LED-on window by the median spacing between raw samples.
off_delay = np.median(np.diff(times))
# Max |dx| and |dy| (same units as x and y after scaling) for a target to count as the feeder LED.
dist_thresh = 20

# Onset events of both feeder LEDs as (time, label) pairs, in time order.
leds = [(event, 'led1') for event in events['led1']] + [(event, 'led2') for event in events['led2']]
sorted_leds = sorted(leds)

ledoff = events["ledoff"]

# Per raw sample: location of the feeder whose LED is flashing, NaN when none is.
feeder_x_location = np.full(times.shape[0], np.nan)
feeder_y_location = np.full(times.shape[0], np.nan)

off_idx = 0
for start, label in sorted_leds:
    # Advance to the first "ledoff" event at or after this onset.
    while off_idx < len(ledoff) and ledoff[off_idx] < start:
        off_idx += 1

    # No matching "ledoff" remains: ignore this onset and every later one.
    if off_idx >= len(ledoff):
        break

    # NOTE(review): led2 maps to feeder1 and led1 to feeder2 -- confirm this crossing is intended.
    feeder = 'feeder1' if label == 'led2' else 'feeder2'
    x_location, y_location = info.path_pts[feeder][0], info.path_pts[feeder][1]

    led_window = (times >= start - off_delay) & (times < ledoff[off_idx] + off_delay)
    feeder_x_location[led_window] = x_location
    feeder_y_location[led_window] = y_location

# Blank every target within dist_thresh (in both x and y) of the feeder active at that sample.
remove_idx = ((np.abs(x - feeder_x_location[:, np.newaxis]) <= dist_thresh)
              & (np.abs(y - feeder_y_location[:, np.newaxis]) <= dist_thresh))
x[remove_idx] = np.nan
y[remove_idx] = np.nan

# Blank targets near any listed impossible ("error") location.
if 'error' in info.path_pts:
    for error_pt in info.path_pts['error']:
        remove_idx = (np.abs(x - error_pt[0]) <= 20.) & (np.abs(y - error_pt[1]) <= 20.)
        x[remove_idx] = np.nan
        y[remove_idx] = np.nan

In [None]:
# Removing the problem samples that are furthest from the previous location
def remove_based_on_std(original_targets, std_thresh=2):
    """Drop the single most deviant target from samples whose targets disagree.

    For every row whose across-target ``np.nanstd`` exceeds ``std_thresh``, the
    target farthest from the mean of the nearest earlier row that still has data
    is set to NaN. Rows are processed top to bottom, so an earlier row that was
    already cleaned is used in its cleaned form.

    Parameters
    ----------
    original_targets : array_like, shape (n_samples, n_targets)
        One coordinate (x or y) of every tracked target; NaN marks a missing
        target. Not modified.
    std_thresh : float
        Spread above which a sample is treated as a problem sample.

    Returns
    -------
    np.ndarray
        A floating-point copy of ``original_targets`` with the offending targets
        set to NaN. Problem samples with no earlier usable row (e.g. the first
        sample) are left unchanged because there is no reference location.
    """
    targets = np.array(original_targets)
    if not np.issubdtype(targets.dtype, np.floating):
        # NaN cannot be stored in an integer array.
        targets = targets.astype(float)

    stds = np.nanstd(targets, axis=1)

    # find idx where there is a large variation between targets
    problem_samples = np.flatnonzero(stds > std_thresh)

    for i in problem_samples:
        # Mean of the closest earlier row with at least one valid target. Stop at
        # row 0: a negative index would silently wrap to the end of the array.
        previous_mean = np.nan
        for previous_idx in range(i - 1, -1, -1):
            previous_row = targets[previous_idx]
            if not np.all(np.isnan(previous_row)):
                previous_mean = np.nanmean(previous_row)
                break
        if np.isnan(previous_mean):
            continue

        # remove problem target
        idx = np.nanargmax(np.abs(targets[i] - previous_mean))
        targets[i, idx] = np.nan

    return targets

# Clean x and y independently, each against its own previous sample.
x, y = remove_based_on_std(x), remove_based_on_std(y)

In [None]:
# Collapse the surviving targets of each sample into one position (NaN if none survive).
x, y = np.nanmean(x, axis=1), np.nanmean(y, axis=1)


# Interpolating for nan samples
def interpolate(time, array, nan_idx):
    """Fill ``array[nan_idx]`` in place by linear interpolation over ``time``.

    Parameters
    ----------
    time : np.ndarray
        Sample times, same length as ``array``.
    array : np.ndarray
        Float values to fill; modified in place.
    nan_idx : np.ndarray of bool
        True where ``array`` should be replaced by an interpolated value.

    Points outside the span of the valid samples become NaN (``bounds_error=False``).
    With fewer than two valid samples (e.g. an empty or all-NaN trial) no linear
    interpolant can be built, so ``array`` is left untouched instead of raising.
    """
    # Depending on the SciPy version, a bare `import scipy` may not load this subpackage.
    import scipy.interpolate

    if not np.any(nan_idx):
        return
    valid_idx = ~nan_idx
    if np.count_nonzero(valid_idx) < 2:
        return
    f = scipy.interpolate.interp1d(time[valid_idx], array[valid_idx], kind='linear', bounds_error=False)
    array[nan_idx] = f(time[nan_idx])

# Work on copies so x, y and times keep their pre-interpolation values.
xx, yy, ttimes = np.array(x), np.array(y), np.array(times)

maze_phases = ["phase1", "phase2", "phase3"]
# Interpolate across NaNs only within maze-phase trials; other samples keep their NaNs.
for phase in (key for key in info.task_times.keys() if key in maze_phases):
    trial_epochs = get_trials(events, info.task_times[phase])
    for start, stop in zip(trial_epochs.starts, trial_epochs.stops):
        in_trial = (times >= start) & (times < stop)
        this_x, this_y, this_times = x[in_trial], y[in_trial], times[in_trial]

        # A sample is missing when either coordinate is missing.
        nan_idx = np.isnan(this_x) | np.isnan(this_y)
        interpolate(this_times, this_x, nan_idx)
        interpolate(this_times, this_y, nan_idx)

        xx[in_trial], yy[in_trial] = this_x, this_y

# Drop every sample that still lacks an x or y value.
keep = ~(np.isnan(xx) | np.isnan(yy))
xx, yy, ttimes = xx[keep], yy[keep], ttimes[keep]

# Apply a median filter (kernel=11).
xx, yy = median_filter(xx, yy, kernel=11)

# NOTE(review): assumes median_filter returns 1-D arrays, giving an (n_samples, 2) data array.
position = nept.Position(np.column_stack((xx, yy)), ttimes)

plt.plot(position.x, position.y, "k.", ms=2)
plt.show()

In [None]:
# Cleaned trajectory with the inspected reference points overlaid in red.
plt.plot(position.x, position.y, "k.", ms=2)
plt.plot([135., 85., 50.], [6., 20., 15.], "r.", ms=10)
plt.show()

In [None]:
# Cleaned y over the 6513.19-6590 s window.
rr = position.time_slice(6513.185586, 6590.)
plt.plot(rr.time, rr.y, color="k", marker=".", linestyle="None")
plt.show()

In [None]:
# x of the sliced samples whose y is below 80.
rr.x[np.less(rr.y, 80)]

In [None]:
# Cleaned y across the whole session.
plt.plot(position.time, position.y, color="k", marker=".", linestyle="None", markersize=2)
plt.show()

In [None]:
# Percentage of raw records that survive cleaning as position samples.
(position.n_samples / targets.shape[0]) * 100

In [None]:
def plot_trials(info, position, events, filepath, phases=("phase1", "phase2", "phase3")):
    """Plot the x/y path of every trial in the given phases, one figure per trial.

    Parameters
    ----------
    info : module
        Session info; ``info.task_times[phase]`` and ``info.session_id`` are read.
    position : nept.Position
        Position that is sliced to each trial's [start, stop] window.
    events : dict
        Passed through to ``get_trials`` to find trial epochs.
    filepath : str or None
        Directory to save the figures into (created if missing). When None,
        the figures are shown inline instead.
    phases : iterable of str
        Keys of ``info.task_times`` to plot; defaults to the three maze phases.
    """
    if filepath:
        # savefig does not create missing directories.
        os.makedirs(filepath, exist_ok=True)

    for phase in phases:
        trial_epochs = get_trials(events, info.task_times[phase])
        for trial_idx in range(trial_epochs.n_epochs):
            start = trial_epochs[trial_idx].start
            stop = trial_epochs[trial_idx].stop

            trial = position.time_slice(start, stop)
            print(start, stop)

            title = info.session_id + " " + phase + " trial" + str(trial_idx)
            # An explicit figure per trial, closed afterwards so figures do not accumulate.
            fig, ax = plt.subplots()
            ax.plot(trial.x, trial.y, "k.")
            ax.set_title(title)
            if filepath is not None:
                fig.savefig(os.path.join(filepath, title))
            else:
                plt.show()
            plt.close(fig)

In [None]:
# Show every maze-phase trial inline (no files written).
plot_trials(info=info, position=position, events=events, filepath=None)

In [None]:
# Summary figure from loading_data.plot_correcting_position, given the cleaned
# position and the raw targets (presumably a raw-vs-corrected comparison -- confirm).
plot_correcting_position(info, position, targets, events)