In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
from notebooks.imports import *
from pathlib import Path
from scipy.ndimage import gaussian_filter1d

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
from config import dir_config, ephys_config

# Resolve the compiled-data and processed-data directories named in the project config.
compiled_dir, processed_dir = (Path(folder) for folder in (dir_config.data.compiled, dir_config.data.processed))

### Extract spike time series aligned to trial events for every unit

### utils

In [25]:
# Function to align and convolve spike trains
def get_aligned_spike_trains(
	cluster_spike_time,
	timestamps,
	trial_info,
	alignment_settings,
	alignment_buffer,
	sampling_rate=30,
	sigma=10,
	duration=30000,
	start_time_ms=-500,
	end_time_ms=1000,
	start_alignment_event="fixation_onset",
	end_alignment_event="response_onset",
):
	"""Bin and smooth one unit's spikes over each whole trial.

	For every trial the window runs from ``start_alignment_event + start_time_ms``
	to ``end_alignment_event + end_time_ms``. Spikes inside the window are placed
	into bins whose width is ``sampling_rate`` timestamp units, and the binned
	train is smoothed with a Gaussian kernel.

	Args:
		cluster_spike_time: 1-D array of spike times for one unit, in the same
			units as ``timestamps``.
		timestamps: DataFrame indexed like ``trial_info`` with one column per
			event ("fixation_onset", "target_onset", "stimulus_onset",
			"go_onset", "response_onset").
		trial_info: DataFrame with a ``reaction_time`` column; trials whose
			reaction time is NaN are left as all-NaN rows.
		alignment_settings: Unused; kept so existing callers keep working.
		alignment_buffer: Unused; kept so existing callers keep working.
		sampling_rate: Timestamp units per millisecond (the default of 30
			presumably corresponds to a 30 kHz recording - confirm).
		sigma: Standard deviation of the Gaussian kernel, in bins.
		duration: Number of bins allocated per trial; longer trials are truncated.
		start_time_ms: Window start relative to ``start_alignment_event`` (ms).
		end_time_ms: Window end relative to ``end_alignment_event`` (ms).
		start_alignment_event: ``timestamps`` column anchoring the window start.
		end_alignment_event: ``timestamps`` column anchoring the window end.

	Returns:
		dict with
		- "events_relative_timing": {event name: float array (n_trials,)} giving
		  each event's bin index relative to the window start (NaN if unknown),
		- "spike_trains": float32 array (n_trials, duration), 0/1 inside the
		  trial window and NaN outside it,
		- "convolved_spike_trains": the smoothed trains multiplied by 1000
		  (per-bin -> per-second when bins are 1 ms wide), NaN outside the window.
	"""
	event_names = ("fixation_onset", "target_onset", "stimulus_onset", "go_onset", "response_onset")
	n_trials = len(trial_info)

	# NaN marks "outside the trial window" and "trial not analysed".
	spike_trains = np.full((n_trials, duration), np.nan, dtype=np.float32)
	convolved_spike_trains = np.full((n_trials, duration), np.nan, dtype=np.float32)
	events_relative_timing = {name: np.full(n_trials, np.nan) for name in event_names}

	cluster_spike_time = np.asarray(cluster_spike_time)

	# Iterate through trials
	for idx_trial, trial_num in enumerate(trial_info.index):
		# Trials without a reaction time stay all-NaN.
		if np.isnan(trial_info.reaction_time[trial_num]):
			continue

		start_timestamp = timestamps.loc[trial_num, start_alignment_event] + start_time_ms * sampling_rate
		end_timestamp = timestamps.loc[trial_num, end_alignment_event] + end_time_ms * sampling_rate

		# Without a usable window nothing can be binned; leave the trial as NaN.
		if not (np.isfinite(start_timestamp) and np.isfinite(end_timestamp)) or end_timestamp < start_timestamp:
			continue

		# Spikes are ceil-binned, so a spike at the very end of the window maps to
		# index ceil(window_ms); the window therefore needs ceil(window_ms) + 1 bins.
		# Clamp to the allocated buffer so over-long trials are truncated instead of
		# raising IndexError.
		n_bins = int(np.ceil((end_timestamp - start_timestamp) / sampling_rate)) + 1
		n_bins = min(n_bins, duration)

		# fill with 0 from start to end
		spike_trains[idx_trial, :n_bins] = 0

		# Filter spike times from start_timestamp to end_timestamp
		in_window = (cluster_spike_time >= start_timestamp) & (cluster_spike_time <= end_timestamp)
		spike_idx = np.ceil((cluster_spike_time[in_window] - start_timestamp) / sampling_rate).astype(int)
		spike_idx = spike_idx[spike_idx < n_bins]
		spike_trains[idx_trial, spike_idx] = 1

		# Convolve only the valid window: filtering the full row would let the NaN
		# tail bleed into the last 3 * sigma valid bins.
		convolved_spike_trains[idx_trial, :n_bins] = gaussian_filter1d(spike_trains[idx_trial, :n_bins], sigma=sigma, truncate=3)

		# fill in event timing relative to start_timestamp; no integer cast so that a
		# missing (NaN) event stays NaN instead of becoming an arbitrary integer.
		for event_name, relative_timing in events_relative_timing.items():
			relative_timing[idx_trial] = np.ceil((timestamps.loc[trial_num, event_name] - start_timestamp) / sampling_rate)

	# Store results
	return {
		"events_relative_timing": events_relative_timing,
		"spike_trains": spike_trains,
		"convolved_spike_trains": convolved_spike_trains * 1000,
	}

In [26]:
# Load neuron metadata (columns used below: neuron_id, session_id, cluster)
neuron_metadata = pd.read_csv(Path(compiled_dir, "neuron_metadata.csv"), index_col=None)
# NOTE(review): these alignment-event keys are never filled below, so the saved
# dictionary mixes empty event entries with per-neuron entries - confirm whether
# downstream readers rely on them before removing.
ephys_neuron_wise = {event["alignment_event"]: {} for event in ephys_config.alignment_settings_GP}

# Main loop for each neuron.
# Read the three columns row by row instead of indexing with `neuron - 1`, which
# was only correct when neuron_id is exactly 1..N in row order.
for neuron, session_name, cluster_id in zip(neuron_metadata.neuron_id, neuron_metadata.session_id, neuron_metadata.cluster):
	session_dir = Path(compiled_dir, session_name)

	# Load required data
	timestamps_path = session_dir / f"{session_name}_timestamps.csv"
	trial_info_path = session_dir / f"{session_name}_trial.csv"
	spike_times_path = session_dir / "spike_times.npy"
	spike_clusters_path = session_dir / "spike_clusters.npy"
	spike_times_mat_path = session_dir / "spike_times.mat"
	spike_clusters_mat_path = session_dir / "spike_clusters.mat"

	if not (timestamps_path.is_file() and trial_info_path.is_file()):
		print(f"Missing files for session: {session_name}")
		continue

	timestamps = pd.read_csv(timestamps_path, index_col=None)
	trial_info = pd.read_csv(trial_info_path, index_col=None)

	# Load spike data: prefer the .npy pair, fall back to the .mat pair
	if spike_times_path.is_file() and spike_clusters_path.is_file():
		spike_times = np.load(spike_times_path)
		spike_clusters = np.load(spike_clusters_path)
	elif spike_times_mat_path.is_file() and spike_clusters_mat_path.is_file():
		spike_times = scipy.io.loadmat(spike_times_mat_path)["spike_times"].ravel()
		spike_clusters = scipy.io.loadmat(spike_clusters_mat_path)["spike_clusters"].ravel()
	else:
		print(f"Spike times and clusters not found in {session_name} for neuron {neuron}")
		continue

	# Filter spike times for the current cluster
	cluster_spike_time = spike_times[spike_clusters == cluster_id]
	# Keep only task_type == 1 trials; trials with a NaN reaction time are kept here
	# and come back from get_aligned_spike_trains as all-NaN rows.
	GP_trial_info = trial_info[trial_info.task_type == 1]

	# Get aligned and convolved spike trains
	results = get_aligned_spike_trains(cluster_spike_time, timestamps, GP_trial_info, ephys_config.alignment_settings_GP, ephys_config.alignment_buffer)

	# Save results
	ephys_neuron_wise[neuron] = {
		"spike_trains": results["spike_trains"],
		"convolved_spike_trains": results["convolved_spike_trains"],
		"event_relative_timing": results["events_relative_timing"],
		"trial_number": GP_trial_info.trial_number,
	}

In [27]:
import pickle

# Serialize the per-neuron results dictionary into the processed-data directory.
output_path = processed_dir / "ephys_neuron_wise_whole_trial.pkl"
with output_path.open("wb") as handle:
	pickle.dump(ephys_neuron_wise, handle, protocol=pickle.HIGHEST_PROTOCOL)