In [1]:
import numpy as np
import os
import pandas as pd
import tarfile
from scipy import signal
from scipy.stats import zscore
from icecream import ic
import neurodsp.filt as dsp
import plotly.express as px
import plotly.graph_objs as go
import dask.array as da
import dask
from dask.diagnostics import ProgressBar
from dask.distributed import Client, LocalCluster
import multiprocessing
import sys
from src.file_ops import npy_loader, get_probe_signals
from dataclasses import dataclass
from typing import Optional
import numpy as np
import tarfile
import os
from pathlib import Path


In [17]:
# Set the current working directory
cwd = os.chdir(r"C:\Python Work Directory\NMA_Impact_Scholars_Steinmetz")

# @title Data retrieval
data_directory = r'data\spikeAndBehavioralData'

# test_dataset
test_LFP = r"Cori_2016-12-17"

In [18]:
walker = os.walk(os.path.join(os.getcwd(),data_directory))
for root, dirs, files in walker:
    print(root)
    print(dirs)
    for file in files:
        print(file)

C:\Python Work Directory\NMA_Impact_Scholars_Steinmetz\data\spikeAndBehavioralData
[]
Cori_2016-12-14.tar
Cori_2016-12-17.tar
Cori_2016-12-18.tar
Forssmann_2017-11-01.tar
Forssmann_2017-11-02.tar
Forssmann_2017-11-04.tar
Forssmann_2017-11-05.tar
Hench_2017-06-15.tar
Hench_2017-06-16.tar
Hench_2017-06-17.tar
Hench_2017-06-18.tar
Lederberg_2017-12-05.tar
Lederberg_2017-12-06.tar
Lederberg_2017-12-07.tar
Lederberg_2017-12-08.tar
Lederberg_2017-12-09.tar
Lederberg_2017-12-10.tar
Lederberg_2017-12-11.tar
Moniz_2017-05-15.tar
Moniz_2017-05-16.tar
Moniz_2017-05-18.tar
Muller_2017-01-07.tar
Muller_2017-01-08.tar
Muller_2017-01-09.tar
Radnitz_2017-01-08.tar
Radnitz_2017-01-09.tar
Radnitz_2017-01-10.tar
Radnitz_2017-01-11.tar
Radnitz_2017-01-12.tar
Richards_2017-10-29.tar
Richards_2017-10-30.tar
Richards_2017-10-31.tar
Richards_2017-11-01.tar
Richards_2017-11-02.tar
Tatum_2017-12-06.tar
Tatum_2017-12-07.tar
Tatum_2017-12-08.tar
Tatum_2017-12-09.tar
Theiler_2017-10-11.tar


In [19]:
def extract_spikes_data(filename):
    with tarfile.open(filename) as tar:
        spikes = [name for name in tar.getnames() if name.startswith('spikes')]

In [32]:
@dataclass
class Clusters:
    depths: Optional[np.ndarray] = None
    original_ids: Optional[np.ndarray] = None
    peak_channel: Optional[np.ndarray] = None
    probes: Optional[np.ndarray] = None
    template_waveform_chans: Optional[np.ndarray] = None
    template_waveforms: Optional[np.ndarray] = None
    waveform_duration: Optional[np.ndarray] = None
    phy_annotation: Optional[np.ndarray] = None


    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert the cluster data to a pandas DataFrame.
        For multi-dimensional arrays, only the first dimension is used as the index,
        and the remaining dimensions are stored as array objects in the cells.
        """
        data_dict = {}
        base_length = None

        # Process each attribute
        for attr_name, value in self.__dict__.items():
            if value is not None:
                if len(value.shape) == 1:
                    # 1D arrays can be directly added
                    data_dict[attr_name] = value
                    if base_length is None:
                        base_length = len(value)
                else:
                    # For multi-dimensional arrays, store them as objects
                    # Each row will contain a slice of the array
                    data_dict[attr_name] = [value[i] for i in range(value.shape[0])]
                    if base_length is None:
                        base_length = value.shape[0]

        # Create DataFrame
        df = pd.DataFrame(data_dict)
        return df

    @classmethod
    def from_tar(cls, tar_path: str | Path) -> 'Clusters':
        """
        Load cluster data from a tar file containing numpy arrays.

        Args:
            tar_path: Path to the tar file containing cluster data

        Returns:
            ClusterData instance with loaded arrays

        Raises:
            FileNotFoundError: If tar file doesn't exist
            ValueError: If expected cluster files are missing
        """
        # if not os.path.exists(tar_path):
        #     raise FileNotFoundError(f"Tar file not found: {tar_path}")

        data = cls()

        with tarfile.open(tar_path, 'r') as tar:
            cluster_files = [name for name in tar.getnames() if name.startswith('clusters')]

            # Mapping between file names and dataclass attributes
            file_attr_map = {
                'clusters.depths.npy': 'depths',
                'clusters.originalIDs.npy': 'original_ids',
                'clusters.peakChannel.npy': 'peak_channel',
                'clusters.probes.npy': 'probes',
                'clusters.templateWaveformChans.npy': 'template_waveform_chans',
                'clusters.templateWaveforms.npy': 'template_waveforms',
                'clusters.waveformDuration.npy': 'waveform_duration',
                'clusters._phy_annotation.npy': 'phy_annotation'
            }

            # Extract and load each file
            for file_name, attr_name in file_attr_map.items():
                if file_name not in cluster_files:
                    print(f"Warning: {file_name} not found in tar archive")
                    continue

                # try:
                #     # Extract file to memory and load with numpy
                #     member = tar.extractfile(file_name)
                #     if member is None:
                #         raise ValueError(f"Could not extract {file_name}")

                array_data = npy_loader(tar,file_name)
                setattr(data, attr_name, array_data)

                # except Exception as e:
                #     print(f"Error loading {file_name}: {str(e)}")

        return data

@dataclass
class Trials:
    feedback_type: Optional[np.ndarray] = None
    feedback_times: Optional[np.ndarray] = None
    gocue_times: Optional[np.ndarray] = None
    included: Optional[np.ndarray] = None
    intervals: Optional[np.ndarray] = None
    repNum: Optional[np.ndarray] = None
    response_choice: Optional[np.ndarray] = None
    response_times: Optional[np.ndarray] = None
    contrast_left: Optional[np.ndarray] = None
    constra_right: Optional[np.ndarray] = None
    stimulus_times: Optional[np.ndarray] = None


    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert the Trials data to a pandas DataFrame.
        For multi-dimensional arrays, only the first dimension is used as the index,
        and the remaining dimensions are stored as array objects in the cells.
        """
        data_dict = {}
        base_length = None

        # Process each attribute
        for attr_name, value in self.__dict__.items():
            if value is not None:
                if len(value.shape) == 1:
                    # 1D arrays can be directly added
                    data_dict[attr_name] = value
                    if base_length is None:
                        base_length = len(value)
                else:
                    # For multi-dimensional arrays, store them as objects
                    # Each row will contain a slice of the array
                    data_dict[attr_name] = [value[i] for i in range(value.shape[0])]
                    if base_length is None:
                        base_length = value.shape[0]

        # Create DataFrame
        df = pd.DataFrame(data_dict)
        return df

    @classmethod
    def from_tar(cls, tar_path: str | Path) -> 'Trials':
        """
        Load cluster data from a tar file containing numpy arrays.

        Args:
            tar_path: Path to the tar file containing cluster data

        Returns:
            ClusterData instance with loaded arrays

        Raises:
            FileNotFoundError: If tar file doesn't exist
            ValueError: If expected cluster files are missing
        """
        # if not os.path.exists(tar_path):
        #     raise FileNotFoundError(f"Tar file not found: {tar_path}")

        data = cls()

        with tarfile.open(tar_path, 'r') as tar:
            trial_files = [name for name in tar.getnames() if name.startswith('trials')]

            # Mapping between file names and dataclass attributes
            file_attr_map = {
                'trials.feedbackType.npy': 'feedback_type',
                'trials.feedback_times.npy': 'feedback_times',
                'trials.goCue_times.npy': 'gocue_times',
                'trials.included.npy': 'included',
                'trials.intervals.npy': 'intervals',
                'trials.repNum.npy': 'repNum',
                'trials.response_choice.npy': 'response_choice',
                'trials.response_times.npy': 'response_times',
                'trials.visualStim_contrastLeft.npy': 'contrast_left',
                'trials.visualStim_contrastRight.npy': 'contrast_right',
                'trials.visualStim_times.npy': 'stimulus_times',
            }

            # Extract and load each file
            for file_name, attr_name in file_attr_map.items():
                if file_name not in trial_files:
                    print(f"Warning: {file_name} not found in tar archive")
                    continue

                # try:
                #     # Extract file to memory and load with numpy
                #     member = tar.extractfile(file_name)
                #     if member is None:
                #         raise ValueError(f"Could not extract {file_name}")

                array_data = npy_loader(tar,file_name)
                setattr(data, attr_name, array_data)

                # except Exception as e:
                #     print(f"Error loading {file_name}: {str(e)}")

        return data

@dataclass
class Spikes:
    amps: Optional[np.ndarray] = None
    clusters: Optional[np.ndarray] = None
    depths: Optional[np.ndarray] = None
    times: Optional[np.ndarray] = None

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert the Trials data to a pandas DataFrame.
        For multi-dimensional arrays, only the first dimension is used as the index,
        and the remaining dimensions are stored as array objects in the cells.
        """
        data_dict = {}
        base_length = None

        # Process each attribute
        for attr_name, value in self.__dict__.items():
            if value is not None:
                if len(value.shape) == 1:
                    # 1D arrays can be directly added
                    data_dict[attr_name] = value
                    if base_length is None:
                        base_length = len(value)
                else:
                    # For multi-dimensional arrays, store them as objects
                    # Each row will contain a slice of the array
                    data_dict[attr_name] = [value[i] for i in range(value.shape[0])]
                    if base_length is None:
                        base_length = value.shape[0]

        # Create DataFrame
        df = pd.DataFrame(data_dict)
        return df

    @classmethod
    def from_tar(cls, tar_path: str | Path) -> 'Spikes':
        """
        Load cluster data from a tar file containing numpy arrays.

        Args:
            tar_path: Path to the tar file containing cluster data

        Returns:
            ClusterData instance with loaded arrays

        Raises:
            FileNotFoundError: If tar file doesn't exist
            ValueError: If expected cluster files are missing
        """
        # if not os.path.exists(tar_path):
        #     raise FileNotFoundError(f"Tar file not found: {tar_path}")

        data = cls()

        with tarfile.open(tar_path, 'r') as tar:
            trial_files = [name for name in tar.getnames() if name.startswith('spikes')]

            # Mapping between file names and dataclass attributes
            file_attr_map = {
                'spikes.amps.npy': 'amps',
                'spikes.clusters.npy': 'clusters',
                'spikes.depths.npy': 'depths',
                'spikes.times.npy': 'times',
            }

            # Extract and load each file
            for file_name, attr_name in file_attr_map.items():
                if file_name not in trial_files:
                    print(f"Warning: {file_name} not found in tar archive")
                    continue

                # try:
                #     # Extract file to memory and load with numpy
                #     member = tar.extractfile(file_name)
                #     if member is None:
                #         raise ValueError(f"Could not extract {file_name}")

                array_data = npy_loader(tar,file_name)
                setattr(data, attr_name, array_data)

                # except Exception as e:
                #     print(f"Error loading {file_name}: {str(e)}")

        return data

In [33]:
alldata_tar_path = os.path.join(os.getcwd(),data_directory,test_LFP + r".tar")
with tarfile.open(alldata_tar_path, 'r') as tar:
    print(type(tar))
    # print(tar.getnames())

    clusters = [name for name in tar.getnames() if name.startswith('clusters')]
    spikes = [name for name in tar.getnames() if name.startswith('spikes')]
    trials = [name for name in tar.getnames() if name.startswith('trials')]
    print(clusters)
    print(spikes)
    print(trials)

    for spike in spikes:
        print(npy_loader(tar,spike).shape)

    for cluster in clusters:
        print(npy_loader(tar,cluster).shape)

    for trial in trials:
        print(npy_loader(tar,trial).shape)

cluster_data = Clusters.from_tar(alldata_tar_path)
trials = Trials.from_tar(alldata_tar_path)

<class 'tarfile.TarFile'>
['clusters.depths.npy', 'clusters.originalIDs.npy', 'clusters.peakChannel.npy', 'clusters.probes.npy', 'clusters.templateWaveformChans.npy', 'clusters.templateWaveforms.npy', 'clusters.waveformDuration.npy', 'clusters._phy_annotation.npy']
['spikes.amps.npy', 'spikes.clusters.npy', 'spikes.depths.npy', 'spikes.times.npy']
['trials.feedbackType.npy', 'trials.feedback_times.npy', 'trials.goCue_times.npy', 'trials.included.npy', 'trials.intervals.npy', 'trials.repNum.npy', 'trials.response_choice.npy', 'trials.response_times.npy', 'trials.visualStim_contrastLeft.npy', 'trials.visualStim_contrastRight.npy', 'trials.visualStim_times.npy']
(10379618, 1)
(10379618, 1)
(10379618, 1)
(10379618, 1)
(1146, 1)
(1146, 1)
(1146, 1)
(1146, 1)
(1146, 50)
(1146, 82, 50)
(1146, 1)
(1146, 1)
(251, 1)
(251, 1)
(251, 1)
(251, 1)
(251, 2)
(251, 1)
(251, 1)
(251, 1)
(251, 1)
(251, 1)
(251, 1)


In [27]:
cluster_df = cluster_data.to_dataframe().query('phy_annotation != 1.0')
cluster_df.phy_annotation.value_counts()

phy_annotation
[2.0]    1069
[3.0]       1
Name: count, dtype: int64

In [28]:
cluster_df.head()

Unnamed: 0,depths,original_ids,peak_channel,probes,template_waveform_chans,template_waveforms,waveform_duration,phy_annotation
0,[2094.0924379360854],[1],[205.0],[0.0],"[204.0, 203.0, 202.0, 201.0, 206.0, 205.0, 199...","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",[28.0],[2.0]
1,[3174.3952154674303],[2],[310.0],[0.0],"[309.0, 307.0, 311.0, 308.0, 313.0, 305.0, 304...","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",[17.0],[2.0]
2,[171.90866042710923],[3],[18.0],[0.0],"[17.0, 15.0, 13.0, 19.0, 16.0, 21.0, 12.0, 11....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",[23.0],[2.0]
3,[482.96850438792995],[4],[49.0],[0.0],"[48.0, 46.0, 44.0, 50.0, 47.0, 43.0, 42.0, 45....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",[18.0],[2.0]
5,[2178.9536964067097],[6],[212.0],[0.0],"[211.0, 209.0, 214.0, 210.0, 212.0, 213.0, 207...","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",[25.0],[2.0]


In [30]:
cluster_df.template_waveforms.iloc[0].T.shape

(50, 82)

In [31]:
trial_df = trials.to_dataframe()
trial_df.head()

(251, 11)


In [35]:
spike_data = Spikes.from_tar(alldata_tar_path)
spike_df = spike_data.to_dataframe()
spike_df.head()

Unnamed: 0,amps,clusters,depths,times
0,[348.2527631004607],[73],[2893.285],[0.0034333333333333334]
1,[172.98296191495376],[176],[2327.7686],[0.007033333333333333]
2,[351.8927727973601],[254],[2219.6008],[0.007566666666666667]
3,[495.4610743675204],[21],[2159.3933],[0.008166666666666666]
4,[92.05229598127468],[69],[3095.128],[0.0087]


In [36]:
print(spike_df.shape)

(10379618, 4)
