In [None]:
import os

import matplotlib.pyplot as plt
import nept
import numpy as np
import scipy
import scipy.stats

from loading_data import get_data, extract_xy, median_filter
from utils_maze import get_trials

In [None]:
# All paths are resolved relative to the directory the notebook is run from.
thisdir = os.getcwd()
dataloc, pickle_filepath = (os.path.join(thisdir, "cache", leaf) for leaf in ("data", "pickled"))
output_filepath = os.path.join(thisdir, "plots")

In [None]:
# Recording info module (position file name, target scaling, task times, maze points).
# Change this import to analyse a different recording.
from info import r066d6
info = r066d6

In [None]:
# Load this recording's events and position through the project loader.
# The remaining three return values are not used in this notebook.
# NOTE: `position` is rebuilt from the raw targets later in the notebook, replacing this one.
events, position, _, _, _ = get_data(info)

In [None]:
# Read the raw position records with nept.load_nvt and keep the per-target
# bitfields plus their sample times.
filename = os.path.join(dataloc, info.position_filename)
nvt_data = nept.load_nvt(filename)
targets, times = nvt_data['targets'], nvt_data['time']

In [None]:
# Decode every target of every record into x, y arrays shaped like `targets`
# (rows are records/samples, columns are target slots).
x = np.zeros(targets.shape)
y = np.zeros(targets.shape)

# X and Y are stored in a custom bitfield. See Neuralynx data file format documentation for details.
# Briefly, each record contains up to 50 targets, each stored in 32bit field.
# X field at [20:31] and Y at [4:15].
# TODO: make into a separate function in nept
# Iterate record-major so the early exit applies within one record (row).
for sample in range(targets.shape[0]):
    for target in range(targets.shape[1]):
        field = int(targets[sample, target])
        # A zero bitfield means no valid data for this target, and the remaining
        # target fields in the same record are zero as well, so stop scanning it.
        if field == 0:
            break
        x[sample, target], y[sample, target] = extract_xy(field)

# Replace target slots that were never filled (still zero) with NaN.
# Note: this also blanks any decoded coordinate that is exactly 0.
x[x == 0] = np.nan
y[y == 0] = np.nan

# Scale the positions
x = x / info.scale_targets[0]
y = y / info.scale_targets[1]

In [None]:
# Scaled y of every target slot over time; each column of `y` is drawn as its own series.
plt.plot(times, y, "b.")
plt.xlabel("time")
plt.ylabel("y (scaled)")
plt.title("Raw y position of all targets")
plt.show()

In [None]:
# Most frequent x and y values over all samples and targets (NaNs ignored).
# The "Feeder1" label suggests these are expected to sit on a stationary feeder light — confirm.
# The counts are printed as well, to show how dominant each mode actually is.
print("Feeder1")
xmode, xcount = scipy.stats.mode(x, axis=None, nan_policy='omit')
ymode, ycount = scipy.stats.mode(y, axis=None, nan_policy='omit')
print("xy:", xmode, ymode)
print("counts:", xcount, ycount)

In [None]:
print("Feeder2")
# Runner-up mode per axis: drop every value equal to the first mode, then take the mode again.
x_rest = x[x != xmode]
y_rest = y[y != ymode]
xxmode, xxcount = scipy.stats.mode(x_rest, axis=None, nan_policy='omit')
yymode, yycount = scipy.stats.mode(y_rest, axis=None, nan_policy='omit')
print("xy:", xxmode, yymode)

In [None]:
# Overlay hand-entered reference points on the position returned by get_data.
# NOTE: a later cell reassigns `position` to the cleaned trajectory, so re-running
# this cell after that one plots the cleaned data instead.
plt.plot(position.x, position.y, "y.", ms=2)

# Red: the anchor (cornerx, cornery) plus two points offset from it by 3*35 in y and 4*35 in x.
cornerx = 70
cornery = 20
corner_points = [
    (cornerx, cornery),
    (cornerx, cornery + 35 * 3),
    (cornerx + 35 * 4, cornery),
]

# Cyan: two hand-entered points, presumably feeder locations — TODO confirm.
cyan_points = [(208., 18.66), (174.67, 155.33)]

# Black: further hand-entered reference points (their meaning is not recorded here).
black_points = [
    (74, 25), (74, 115), (175, 125),
    (115, 60),
    (175, 80), (110, 90),
    (65, 160),
    (130, 17.5),
]

# Marker-only styles, so each group can be drawn with a single call.
for points, style in ((corner_points, "r."), (cyan_points, "c."), (black_points, "k.")):
    xs, ys = zip(*points)
    plt.plot(xs, ys, style, ms=10)

plt.xlabel("x")
plt.ylabel("y")
plt.title("Position with reference points")
plt.show()

In [None]:
# Extra samples past the nearest LED-off sample that are still marked as feeder-active.
led_padding = 1

# Which feeder (key into info.path_pts) each LED event label belongs to.
led_feeders = {'led1': 'feeder1', 'led2': 'feeder2'}

# Finding which feeder led is on over time: merge both LEDs' on-events into one
# time-ordered list of (time, label).
leds = [(event, 'led1') for event in events['led1']]
leds.extend((event, 'led2') for event in events['led2'])
sorted_leds = sorted(leds)

# Per position sample: the coordinate of the feeder whose LED is flashing, NaN otherwise.
feeder_x_location = np.full(times.shape[0], np.nan)
feeder_y_location = np.full(times.shape[0], np.nan)

ledoff = events["ledoff"]
off_idx = 0

for time, label in sorted_leds:
    feeder_xy = info.path_pts[led_feeders[label]]
    x_location, y_location = feeder_xy[0], feeder_xy[1]

    # Advance to the first LED-off event that is not earlier than this LED-on event.
    # The bound check must come first: reading ledoff[off_idx] once off_idx reaches
    # len(ledoff) would raise IndexError.
    while off_idx < len(ledoff) and ledoff[off_idx] < time:
        off_idx += 1

    # No off event recorded after this on-event: discard it (and any later on-events).
    if off_idx >= len(ledoff):
        break

    start = nept.find_nearest_idx(times, time)
    stop = nept.find_nearest_idx(times, ledoff[off_idx])
    feeder_x_location[start:stop + led_padding] = x_location
    feeder_y_location[start:stop + led_padding] = y_location


# Blank out target values that sit near the currently active feeder.
def remove_feeder_contamination(original_targets, current_feeder, dist_to_feeder=40):
    """Return a copy of ``original_targets`` with values close to the active feeder set to NaN.

    Operates on one coordinate (x or y) at a time.

    Parameters
    ----------
    original_targets : array_like, 2-D
        One row per sample, one column per target slot. Not modified.
    current_feeder : sequence of float
        Feeder coordinate per sample (paired with rows via zip), NaN where no feeder is active.
    dist_to_feeder : float
        Values whose absolute difference from the feeder coordinate is strictly
        below this are replaced with NaN.

    Returns
    -------
    numpy.ndarray
        The cleaned copy.
    """
    cleaned = np.array(original_targets)
    # Iterating a 2-D array yields row views, so masking a row edits `cleaned` directly.
    for row_values, feeder_pos in zip(cleaned, current_feeder):
        if np.isnan(feeder_pos):
            continue
        near_feeder = np.abs(row_values - feeder_pos) < dist_to_feeder
        row_values[near_feeder] = np.nan
    return cleaned

# Blank targets near the active feeder, one coordinate at a time.
# NOTE(review): x and y are masked independently, so a target can be dropped in x but
# kept in y (and vice versa). The later per-sample means may then average different
# targets for x and y — confirm this is intended.
x, y = (
    remove_feeder_contamination(x, feeder_x_location),
    remove_feeder_contamination(y, feeder_y_location),
)


# Removing the problem samples that are furthest from the previous location
def remove_based_on_std(original_targets, std_thresh=2):
    """Drop the most outlying target from samples whose targets disagree too much.

    For every row whose across-target ``nanstd`` is strictly greater than
    ``std_thresh``, the single target furthest from the mean of the nearest
    earlier row that still has valid data is set to NaN.

    Parameters
    ----------
    original_targets : array_like, 2-D
        One row per sample, one column per target slot, NaN where missing. Not modified.
    std_thresh : float
        Spread threshold above which a row is treated as a problem sample.

    Returns
    -------
    numpy.ndarray
        Cleaned copy of ``original_targets``. A problem sample with no earlier
        valid row to compare against is left unchanged.
    """
    targets = np.array(original_targets)
    stds = np.nanstd(targets, axis=1)

    # find idx where there is a large variation between targets
    # (NaN stds compare False, so all-NaN rows are never selected)
    problem_samples = np.flatnonzero(stds > std_thresh)

    for i in problem_samples:
        # Walk back to the nearest earlier row with at least one valid target.
        # Rows are processed in ascending order, so that row is already cleaned.
        previous_idx = i - 1
        while previous_idx >= 0 and np.all(np.isnan(targets[previous_idx])):
            previous_idx -= 1

        # Without a lower bound, a negative index would silently wrap to the end
        # of the recording; with no earlier reference, leave this sample alone.
        if previous_idx < 0:
            continue

        previous_mean = np.nanmean(targets[previous_idx])

        # remove problem target: the one furthest from the previous position
        idx = np.nanargmax(np.abs(targets[i] - previous_mean))
        targets[i, idx] = np.nan

    return targets

# Drop, per sample, the target furthest from the previous position when targets disagree.
# Results go to new names so that re-running this cell does not strip further
# targets from an already-cleaned x/y.
x_std_cleaned = remove_based_on_std(x)
y_std_cleaned = remove_based_on_std(y)

# Calculating the mean of the remaining targets: one x/y estimate per sample.
xx = np.nanmean(x_std_cleaned, axis=1)
yy = np.nanmean(y_std_cleaned, axis=1)

# Keep only samples where both coordinates are defined.
valid = ~(np.isnan(xx) | np.isnan(yy))
xx = xx[valid]
yy = yy[valid]
ttimes = times[valid]

# Apply the project's median filter (kernel=11; presumably a window in samples — confirm).
xx, yy = median_filter(xx, yy, kernel=11)

# Stack into an (n_samples, 2) array of [x, y] columns.
position = nept.Position(np.column_stack([xx, yy]), ttimes)

plt.plot(position.x, position.y, "k.", ms=2)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Cleaned position")
plt.show()

In [None]:
# Cleaned y coordinate over time.
plt.plot(position.time, position.y, "k.", ms=2)
plt.xlabel("time")
plt.ylabel("y")
plt.title("Cleaned y position over time")
plt.show()

In [None]:
# Trajectory of a single trial from the "phase3" task period.
trial_epochs = get_trials(events, info.task_times["phase3"])
trial_idx = 1  # index into trial_epochs of the trial to inspect
# Distinct names: `start`/`stop` already hold sample indices in the LED cell.
trial_start = trial_epochs[trial_idx].start
trial_stop = trial_epochs[trial_idx].stop

trial = position.time_slice(trial_start, trial_stop)
plt.plot(trial.x, trial.y, "k.")
plt.xlabel("x")
plt.ylabel("y")
plt.title("phase3, trial {}".format(trial_idx))
plt.show()