In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import nept
import scipy

from loading_data import get_data, extract_xy, median_filter
from utils_maze import get_trials

In [None]:
# Every location is resolved relative to the directory the notebook is run from.
thisdir = os.getcwd()
cache_dir = os.path.join(thisdir, "cache")
dataloc = os.path.join(cache_dir, "data")
pickle_filepath = os.path.join(cache_dir, "pickled")
output_filepath = os.path.join(thisdir, "plots", "correcting_position")

In [None]:
# Session metadata module; provides the filenames, event labels, task_times,
# path_pts and scale_targets used throughout this notebook under the name `info`.
from info import r066d6
info = r066d6

In [None]:
events_path = os.path.join(dataloc, info.event_filename)
events = nept.load_events(events_path, info.event_labels)

In [None]:
# Load raw position from file: the packed per-record target bitfields and their timestamps.
filename = os.path.join(dataloc, info.position_filename)
nvt_data = nept.load_nvt(filename)
targets, times = nvt_data['targets'], nvt_data['time']

In [None]:
# Decode the packed target bitfields into per-target x, y coordinates.
# X and Y are stored in a custom bitfield. See Neuralynx data file format documentation for details.
# Briefly, each record contains up to 50 targets, each stored in a 32-bit field,
# with X at bits [20:31] and Y at bits [4:15]; `extract_xy` unpacks one field.
# TODO: make into a separate function in nept
x = np.zeros(targets.shape)
y = np.zeros(targets.shape)

# Rows of `targets` are records (aligned with `times`); columns are target slots.
for sample in range(targets.shape[0]):
    for target in range(targets.shape[1]):
        bitfield = targets[sample, target]
        # When the bitfield is zero there is no valid data for that field, and it
        # remains zero for the rest of the fields in this record, so move on to
        # the next record. (Previously the loops were transposed and this check
        # read the diagonal element, which could discard whole target columns.)
        if bitfield == 0:
            break
        x[sample, target], y[sample, target] = extract_xy(int(bitfield))

# Target slots that were never filled are still 0; mark them as missing.
# NOTE(review): this also discards any genuine coordinate of exactly 0.
x[x == 0] = np.nan
y[y == 0] = np.nan

# Scale the positions
x /= info.scale_targets
y /= info.scale_targets

In [None]:
# Every decoded target of every record, in the x-y plane.
plt.gca().plot(x, y, "b.")
plt.show()

In [None]:
# y coordinate of every target over time, before any cleaning.
plt.gca().plot(times, y, "b.")
plt.show()

In [None]:
# Finding times when a feeder LED is active
led_padding = 1

# Onset times of both feeder LEDs merged into one chronologically sorted list.
ledon = sorted(np.concatenate((events['led1'], events['led2']), axis=0))
ledoff = events["ledoff"]

In [None]:
# Boolean mask over `times`: True from each LED "on" event up to and including the
# first "off" event at or after it. Pairing each "on" with its next "off" (instead of
# zipping the two lists) keeps intervals aligned when an "off" event is missing, and
# an empty event list yields an all-False mask instead of a column_stack error.
ledoff_times = np.sort(np.asarray(ledoff))
led_idx = np.zeros(times.shape, dtype=bool)
for on_time in ledon:
    off_pos = np.searchsorted(ledoff_times, on_time, side="left")
    if off_pos == len(ledoff_times):
        # No "off" follows this "on" (e.g. the recording ended with the LED lit).
        continue
    led_idx |= (times >= on_time) & (times <= ledoff_times[off_pos])

In [None]:
# LED on/off times packed as epoch bounds.
# NOTE(review): pairs the i-th "on" with the i-th "off"; assumes equal, aligned counts — confirm.
led_bounds = [ledon, ledoff]
evt = nept.Epoch(led_bounds)

In [None]:
# Spot-check a hand-picked window of the LED-active mask.
led_idx[slice(9000, 10000)]

In [None]:
led_padding = 1
# Finding which feeder LED is on over time
# TODO: Optimize for speed
# Tag each LED onset with the LED that produced it, then process them chronologically.
leds = []
leds.extend((event, 'led1') for event in events['led1'])
leds.extend((event, 'led2') for event in events['led2'])
sorted_leds = sorted(leds)

# Get an array of feeder locations when that feeder is actively flashing (NaN otherwise)
feeder_x_location = np.full(times.shape[0], np.nan)
feeder_y_location = np.full(times.shape[0], np.nan)

ledoff = events["ledoff"]
off_idx = 0

for time, label in sorted_leds:
    # NOTE(review): led1 maps to feeder2 and led2 to feeder1 — kept as-is; confirm against the rig.
    feeder = info.path_pts['feeder2'] if label == 'led1' else info.path_pts['feeder1']
    x_location, y_location = feeder[0], feeder[1]

    # Advance to the first "off" event at or after this "on" event. The bounds check
    # must come first: otherwise ledoff[off_idx] raises IndexError once every "off"
    # event has been consumed.
    while off_idx < len(ledoff) and ledoff[off_idx] < time:
        off_idx += 1

    # Discount this and any later events when the last "off" is missing
    if off_idx >= len(ledoff):
        break

    start = nept.find_nearest_idx(times, time)
    stop = nept.find_nearest_idx(times, ledoff[off_idx])
    # led_padding extends the interval past `stop` (1 makes the slice include it).
    feeder_x_location[start:stop + led_padding] = x_location
    feeder_y_location[start:stop + led_padding] = y_location




In [None]:
# Removing the contaminated samples that are closest to the feeder location
def remove_feeder_contamination(targets, current_feeder, dist_to_feeder=20):
    """Return a copy of ``targets`` with entries near the active feeder set to NaN.

    ``targets`` is 2-D (one row per sample, one column per target slot) and
    ``current_feeder`` holds one feeder coordinate per row (NaN when no feeder is
    active). An entry is removed when its absolute distance to that row's feeder
    coordinate is below ``dist_to_feeder``; rows with a NaN feeder are left alone,
    because comparisons with NaN are False. The input is not modified.
    """
    cleaned = np.array(targets)
    feeder_column = current_feeder[:, np.newaxis]  # broadcast across every target slot
    near_feeder = np.abs(cleaned - feeder_column) < dist_to_feeder
    cleaned[near_feeder] = np.nan
    return cleaned
    
# Apply the feeder filter independently to each axis.
x, y = (
    remove_feeder_contamination(x, feeder_x_location),
    remove_feeder_contamination(y, feeder_y_location),
)

In [None]:
# y over time after removing feeder-LED contamination.
plt.gca().plot(times, y, "b.")
plt.show()

In [None]:
# Removing the problem samples that are furthest from the previous location
def _previous_valid_mean(targets, i):
    """Return the target mean of the nearest row before ``i`` with any non-NaN value, or None."""
    for previous_idx in range(i - 1, -1, -1):
        row = targets[previous_idx]
        if not np.all(np.isnan(row)):
            return np.nanmean(row)
    return None


def remove_based_on_std(original_targets, std_thresh=2):
    """Remove the most outlying target from samples whose targets disagree.

    Parameters
    ----------
    original_targets : array_like
        2-D array, one row per sample and one column per target slot (NaN = missing).
    std_thresh : float
        Rows whose NaN-ignoring standard deviation exceeds this are "problem" samples.

    Returns
    -------
    numpy.ndarray
        A copy in which, for each problem sample, the single target farthest from the
        mean of the nearest earlier sample with valid data is set to NaN. Standard
        deviations are computed once up front; removals in earlier rows do affect the
        reference mean used for later rows. A problem sample with no earlier valid
        sample is left unchanged. (Previously the backward search ran into negative
        indices and wrapped around to the end of the recording.)
    """
    targets = np.array(original_targets)
    stds = np.nanstd(targets, axis=1)

    # find idx where there is a large variation between targets
    problem_samples = np.flatnonzero(stds > std_thresh)

    for i in problem_samples:
        previous_mean = _previous_valid_mean(targets, i)
        if previous_mean is None:
            continue
        # remove problem target (a problem row always has at least one non-NaN value)
        idx = np.nanargmax(np.abs(targets[i] - previous_mean))
        targets[i, idx] = np.nan

    return targets

# Drop the single most outlying target in each high-variance sample, then average
# the remaining targets into one x (and one y) value per sample.
x = np.nanmean(remove_based_on_std(x), axis=1)
y = np.nanmean(remove_based_on_std(y), axis=1)


def interpolate(time, array, nan_idx):
    """Fill ``array[nan_idx]`` in place by linear interpolation over ``time``.

    Parameters
    ----------
    time : numpy.ndarray
        Sample times, same length as ``array``.
    array : numpy.ndarray
        Values to fill; modified in place. Nothing is returned.
    nan_idx : numpy.ndarray of bool
        Mask of samples to replace; the remaining samples are the interpolation knots.

    Points outside the range of the valid samples stay NaN (``bounds_error=False``).
    When there is nothing to fill, or fewer than two valid samples to interpolate
    from, ``array`` is left unchanged instead of letting ``interp1d`` raise.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not guarantee
    # ``scipy.interpolate`` is loaded on older SciPy versions.
    from scipy.interpolate import interp1d

    valid = ~nan_idx
    if not nan_idx.any() or np.count_nonzero(valid) < 2:
        return
    f = interp1d(time[valid], array[valid], kind='linear', bounds_error=False)
    array[nan_idx] = f(time[nan_idx])

# Interpolate positions to replace nans during maze phases.
# xx/yy start as copies of the per-sample positions; only samples inside maze-phase
# trials are filled in, and everything else keeps its NaNs.
xx = np.array(x)
yy = np.array(y)
ttimes = np.array(times)

maze_phases = ["phase1", "phase2", "phase3"]
phases_present = [name for name in info.task_times.keys() if name in maze_phases]
for phase_name in phases_present:
    trial_epochs = get_trials(events, info.task_times[phase_name])
    for trial_start, trial_stop in zip(trial_epochs.starts, trial_epochs.stops):
        in_trial = (times >= trial_start) & (times < trial_stop)
        trial_times = times[in_trial]
        trial_x = x[in_trial]
        trial_y = y[in_trial]

        # A sample is missing if either coordinate is NaN, so x and y stay aligned.
        missing = np.isnan(trial_x) | np.isnan(trial_y)

        # Boolean-mask indexing returned copies; interpolate() fills each copy in
        # place, and the result is written back into the full-length array.
        for trial_values, full_values in ((trial_x, xx), (trial_y, yy)):
            interpolate(trial_times, trial_values, missing)
            full_values[in_trial] = trial_values

# Keep only samples where both coordinates are known.
keep = ~(np.isnan(xx) | np.isnan(yy))
xx, yy, ttimes = xx[keep], yy[keep], ttimes[keep]

# Apply a median filter (kernel=11)
xx, yy = median_filter(xx, yy, kernel=11)

# Stack x and y as the two columns of an (n_samples, 2) array for nept.Position.
position = nept.Position(np.column_stack((xx, yy)), ttimes)

plt.gca().plot(position.x, position.y, "k.", ms=2)
plt.show()

In [None]:
# Determining the percent of recorded samples that survive cleaning
n_samples_kept = len(xx)
n_samples_recorded = len(targets)
n_samples_kept / n_samples_recorded * 100

In [None]:
# Cleaned y position over time.
plt.gca().plot(position.time, position.y, "k.", ms=2)
plt.show()

In [None]:
# Shape of the per-sample x array (before NaN removal).
np.shape(x)

In [None]:
def plot_trials(info, position, events, filepath, phases=("phase1", "phase2", "phase3")):
    """Plot the position samples of every trial in each maze phase, one figure per trial.

    Parameters
    ----------
    info : module
        Session info; must provide ``session_id`` and ``task_times`` (mapping a phase
        name to the epoch passed to ``get_trials``).
    position : nept.Position
        Position data; each trial is taken with ``time_slice(start, stop)``.
    events : dict
        Events as loaded by ``nept.load_events``; passed to ``get_trials``.
    filepath : str or None
        Directory in which to save each figure, named by its title. If None, the
        figures are shown instead.
    phases : iterable of str, optional
        Phases to plot. Phases missing from ``info.task_times`` are skipped, matching
        the maze-phase filter used during interpolation.
    """
    for phase in phases:
        if phase not in info.task_times.keys():
            continue
        trial_epochs = get_trials(events, info.task_times[phase])
        for trial_idx in range(trial_epochs.n_epochs):
            start = trial_epochs[trial_idx].start
            stop = trial_epochs[trial_idx].stop
            trial = position.time_slice(start, stop)

            title = info.session_id + " " + phase + " trial" + str(trial_idx)
            # A fresh figure per trial, so nothing is drawn onto a stale current figure.
            fig, ax = plt.subplots()
            try:
                ax.plot(trial.x, trial.y, "k.")
                ax.set_title(title)
                if filepath is not None:
                    fig.savefig(os.path.join(filepath, title))
                else:
                    plt.show()
            finally:
                # Release the figure even if saving fails, so long loops don't leak memory.
                plt.close(fig)

In [None]:
# Show (rather than save) every maze-phase trial.
plot_trials(info, position, events, None)

In [None]:
# Indices i where the jump from position.time[i] to position.time[i + 1] exceeds 15 time units.
w = np.flatnonzero(np.diff(position.time) > 15.)

In [None]:
# Inspect the timestamps surrounding one hand-picked gap.
i = 7
gap_idx = w[i]
position.time[gap_idx - 2:gap_idx + 2]

In [None]:
# The ten timestamps starting at the selected gap.
after_gap = slice(w[i], w[i] + 10)
position.time[after_gap]

In [None]:
# Mask of samples that come right after a gap longer than 10 time units (first sample is never one).
idx = np.concatenate(([False], np.diff(position.time) > 10.))

In [None]:
# All samples in yellow, post-gap samples in black, plus a hand-placed red marker
# whose significance is not recorded here.
ax = plt.gca()
ax.plot(position.time, position.y, "y.", ms=4)
ax.plot(position.time[idx], position.y[idx], "k.", ms=4)
ax.plot(3122.85, 40, "r.", ms=10)
plt.show()

In [None]:
# Maze phase examined in the cells below.
phase = "phase2"

In [None]:
# Restrict the cleaned position to the selected phase.
phase_epoch = info.task_times[phase]
start, stop = phase_epoch.start, phase_epoch.stop

sliced_position = position.time_slice(start, stop)

In [None]:
# Cleaned x-y trajectory for the selected phase.
plt.gca().plot(sliced_position.x, sliced_position.y, "g.")
plt.show()

In [None]:
# y over time for the first 5000 samples of the phase, with a hand-placed red marker.
first_n = 5000
ax = plt.gca()
ax.plot(sliced_position.time[:first_n], sliced_position.y[:first_n], "g.")
ax.plot(3155, 65, "r.", ms=10)
plt.show()

In [None]:
# One green x-y plot per trial of the selected phase.
trial_epochs = get_trials(events, info.task_times[phase])
for idx in range(trial_epochs.n_epochs):
    this_epoch = trial_epochs[idx]
    start, stop = this_epoch.start, this_epoch.stop

    trial = position.time_slice(start, stop)
    plt.gca().plot(trial.x, trial.y, "g.")
    plt.show()