In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
import emd
import tarfile
import io
import scipy
from scipy import signal
from icecream import ic
import neurodsp.filt as dsp
import seaborn as sns
import plotly.express as px
import plotly.graph_objs as go

# Directories
Set the working directory, and define the paths to the main LFP dataset and to the supplementary (spike and behavioral) dataset.

In [None]:
# Set the current working directory.
# os.chdir() returns None, so the new location is read back with os.getcwd();
# this makes `cwd` hold the actual directory path instead of None.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
os.chdir(r"C:\Python Work Directory\NMA_Impact_Scholars_Steinmetz")
cwd = os.getcwd()

# Access to the Steinmetz LFP dataset
# lfp_dat = r"E:\Steinmetz_Dataset"
lfp_dat  = r"C:\Python Work Directory\NMA_Impact_Scholars_Steinmetz\data\examples"

# @title Data retrieval
# Directory (relative to the working directory) holding the per-session .tar archives
data_directory = r'data\spikeAndBehavioralData'

# test_dataset: session name used to build the archive file names
test_LFP = r"Cori_2016-12-18"

In [None]:
# Show the absolute path of the data directory (working directory + data_directory)
print(os.path.join(os.getcwd(),data_directory))

## Brain Regions of Interest

In [None]:
# Region labels grouped under the two names `hpc` and `pfc`
hpc = "CA1 CA3 DG SUB".split()
pfc = "ACA ILA PL RSP".split()

# Every region of interest, hpc entries first
region_loop = [*hpc, *pfc]

# Region analysed in the cells below
region_select = 'CA1'

## Power spectrum functions

### Defining file iterator (for later use)

In [None]:
# Walk the data directory and list, for each folder, its path, sub-folders and files
walker = os.walk(os.path.join(os.getcwd(), data_directory))
for root, dirs, files in walker:
    # One item per line, same output as three separate print calls
    print(root, dirs, files, sep='\n')

### .npy file loader from tarball

In [None]:
def npy_loader(filename:str, archive=None)-> np.ndarray:
    '''
    Numpy loader function for .npy in tarball (.tar) packages.

    Loading is best-effort: any error is reported with ``print`` and swallowed,
    and ``None`` is returned instead of an array.

    :param filename: str, name of the .npy member inside the archive
    :param archive: optional open ``tarfile.TarFile`` to read from. When omitted,
        the module-level variable ``tar`` is used, so the call must then happen
        inside an active ``with tarfile.open(...) as tar:`` block.
    :return: np.ndarray with the loaded data, or None if loading failed
    '''
    try:
        # Fall back to the global handle for callers that pass only a filename
        source = tar if archive is None else archive

        # extractfile() returns None for members that are not regular files/links
        npy_file = source.extractfile(filename)
        if npy_file is None:
            raise FileNotFoundError(f"Could not find or extract the file: {filename}")

        npy_file_content = npy_file.read()

        # Check file size to confirm it's not empty or corrupted
        if len(npy_file_content) == 0:
            raise ValueError(f"The .npy file '{filename}' is empty or corrupted.")

        # Load .npy file from memory using BytesIO
        return np.load(io.BytesIO(npy_file_content))
    except Exception as e:
        print(f"Error reading .npy file: {e}")
        return None
    

In [None]:
# Archive with the session's metadata: <cwd>/<data_directory>/<test_LFP>.tar
alldata_tar_path = os.path.join(os.getcwd(), data_directory, test_LFP + r".tar")

# The handle has to be called `tar`: npy_loader() looks it up as a global.
with tarfile.open(alldata_tar_path, 'r') as tar:
    member_names = tar.getnames()
    print(member_names)

    # Locate each member by its file-name suffix (first match wins).
    # NOTE(review): the brain-location table is taken as the first .tsv among the
    # first five members only, which depends on the archive's ordering.
    brain_loc_filename = [m for m in member_names[:5] if m.endswith('.tsv')][0]
    probe_desc_filename = [m for m in member_names if m.endswith('rawFilename.tsv')][0]
    probe_filename = [m for m in member_names if m.endswith('channels.probe.npy')][0]
    raw_Row_filename = [m for m in member_names if m.endswith('channels.rawRow.npy')][0]
    site_filename = [m for m in member_names if m.endswith('channels.site.npy')][0]
    site_pos_filename = [m for m in member_names if m.endswith('channels.sitePositions.npy')][0]

    # Tab-separated tables are parsed straight from the archive stream
    brain_loc = pd.read_csv(tar.extractfile(brain_loc_filename), sep='\t')
    probe_desc = pd.read_csv(tar.extractfile(probe_desc_filename), sep='\t')

    # Arrays stored as .npy members
    probe = npy_loader(probe_filename)
    raw_Row = npy_loader(raw_Row_filename)
    site = npy_loader(site_filename)
    site_pos = npy_loader(site_pos_filename)
    
    

In [None]:
# (rows, columns) of the brain-location table read from the archive
brain_loc.shape

In [None]:
# Rows of brain_loc whose allen_ontology label equals region_select
brain_loc.query(f'allen_ontology == "{region_select}"')

In [None]:
# Display the table read from the archive member ending in 'rawFilename.tsv'
probe_desc

In [None]:
# Attach the arrays loaded from the channels.*.npy members as new columns of
# brain_loc. This modifies brain_loc in place.
# NOTE(review): assumes every array has one entry per brain_loc row — TODO confirm
brain_loc['probe'] = probe
brain_loc['site'] = site
# assumes site_pos has exactly two columns (x, y) — TODO confirm
brain_loc[['site_pos_x','site_pos_y']] = site_pos
brain_loc['raw_Row'] = raw_Row

In [None]:
# Same region filter as before; run after the probe/site/position/raw_Row columns were added
brain_loc.query(f'allen_ontology == "{region_select}"') 

## Discovering the Channel Labelling Scheme

In [None]:
# Scatter of the probe-0 rows at their (site_pos_x, site_pos_y) coordinates,
# coloured by the 'site' column. The layout tweaks (centred title, legend
# caption) are chained onto the figure returned by Plotly Express.
fig = px.scatter(
    brain_loc.query('probe == 0'),
    x='site_pos_x',
    y='site_pos_y',
    color='site',
    title='Brain Location Scatter Plot',
    width=1200,  # Equivalent to figsize=(20,10)
    height=600,
).update_layout(
    title_x=0.5,  # Center the title
    legend_title_text='Site',
)

# Render the figure
fig.show()

In [None]:
# Scatter of the probe-0 rows at their (site_pos_x, site_pos_y) coordinates,
# coloured by the allen_ontology label of each row
fig = px.scatter(brain_loc.query('probe == 0'),
                 x='site_pos_x',
                 y='site_pos_y',
                 color='allen_ontology',
                 title='Brain Location Scatter Plot',
                 width=1200,  # Equivalent to figsize=(20,10)
                 height=600)

# Customize the layout if needed
fig.update_layout(
    title_x=0.5,  # Center the title
    # The colours encode allen_ontology here, so the legend is captioned
    # accordingly (it previously read 'Site').
    legend_title_text='Allen ontology region',
    # Add any additional layout customizations here
)

# Show the plot
fig.show()

### Probe Selection

Select the probes that have recording sites in our brain regions of interest.

In [None]:
# Identify the probe(s) that record from region_select.
# probe_select is a boolean mask with one entry per row of probe_desc: True where
# the row's index value is one of the probe ids found for the region.
# np.isin is used instead of `==` so the mask is still well-formed when the
# region is present on more than one probe (or on none).
probe_select = np.isin(
    np.array(probe_desc.index),
    brain_loc.query(f'allen_ontology == "{region_select}"')['probe'].unique(),
)

In [None]:
#TODO: Build a dataset loader that interacts with the online database

# Path to the LFP .tar archive of the test session
tar_path = os.path.join(lfp_dat,test_LFP + r"_lfp.tar")

# Define the parameters based on the documentation
num_channels = 385  # 385 channels as specified
data_type = np.int16  # int16 data type
sampling_rate = 2500  # 2500 Hz sampling rate

# Bytes occupied by one time sample across all channels
bytes_per_sample = num_channels * np.dtype(data_type).itemsize
chunk_size = 1000000  # Number of time samples read per iteration

# Open the .tar file and load the .bin file
with tarfile.open(tar_path, 'r') as tar:
    # Pick the first archive member selected by the probe mask.
    # NOTE(review): assumes the archive members are in the same order (and number)
    # as the rows of probe_desc — confirm.
    bin_file_name = np.array(tar.getnames())[probe_select][0]

    # Open the .bin member as a stream (extractfile returns None for non-files)
    bin_file = tar.extractfile(bin_file_name)
    if bin_file is None:
        raise FileNotFoundError(f"Could not extract '{bin_file_name}' from '{tar_path}'")

    # Number of complete time samples stored in the member; a trailing partial
    # sample (truncated file) is ignored.
    file_size = tar.getmember(bin_file_name).size
    num_samples = file_size // bytes_per_sample

    # Allocate the output once and fill it chunk by chunk. This avoids holding
    # both a list of chunks and their concatenated copy in memory.
    reshaped_data = np.empty((num_channels, num_samples), dtype=data_type)

    samples_loaded = 0
    while samples_loaded < num_samples:
        samples_wanted = min(chunk_size, num_samples - samples_loaded)
        raw_bytes = bin_file.read(samples_wanted * bytes_per_sample)

        # Only whole samples are usable; stop on end of stream
        samples_got = len(raw_bytes) // bytes_per_sample
        if samples_got == 0:
            break

        data_chunk = np.frombuffer(raw_bytes, dtype=data_type, count=samples_got * num_channels)
        # On disk each sample holds all channels consecutively: view the chunk as
        # (samples, channels) and transpose to (channels, samples).
        reshaped_data[:, samples_loaded:samples_loaded + samples_got] = (
            data_chunk.reshape(samples_got, num_channels).T
        )
        samples_loaded += samples_got

# Drop the unfilled tail if the stream ended before the declared member size
if samples_loaded < num_samples:
    reshaped_data = reshaped_data[:, :samples_loaded]

# At this point, reshaped_data contains the LFP data in shape (385, num_total_samples)

In [None]:
# Debug output (icecream): shape of the loaded array
ic(reshaped_data.shape)

## Synchronization Signal Channel
Plotting channel 385 (the last row of the data array) shows that it carries the synchronization signal: the time events that mark when a stimulus is presented.

In [None]:
sampling_rate = 2500

# Duration of the recording in seconds
total_time = reshaped_data.shape[1]/sampling_rate

# Time stamp of every sample: sample k is taken at k / sampling_rate seconds.
# (np.linspace(0, total_time, N) would include the end point and stretch the
# spacing to total_time / (N - 1) instead of 1 / sampling_rate.)
time_points = np.arange(reshaped_data.shape[1]) / sampling_rate
time_points_ms = time_points*1000

# Plot the first 1,000,000 samples of the last row (the synchronization channel).
# NOTE(review): the y axis is labelled in μV, but the data are the raw int16
# values from the .bin file — confirm whether a scaling factor applies.
sync_signal_fig = px.line(
    x=time_points_ms[:1000000],
    y=reshaped_data[-1,:1000000],
    labels={'x': 'Time (ms)', 'y': 'Amplitude (μV)'},
    title='Synchronization Signal Time Series'
)

sync_signal_fig.show()

### Plot of a random CA1 channel

In [None]:
# Row of reshaped_data to display. The same value drives both the indexing and
# the title so they cannot disagree (the row was previously indexed as -233,
# i.e. counted from the end, while the title said "Channel 233").
# NOTE(review): confirm against brain_loc that this row lies in CA1.
ca1_channel = 233

# First 1,000,000 samples of the selected channel
CA1_signal_fig = px.line(
    x=time_points_ms[:1000000],
    y=reshaped_data[ca1_channel,:1000000],
    labels={'x': 'Time (ms)', 'y': 'Amplitude (μV)'},
    title=f'CA1, Channel {ca1_channel} Signal Time Series'
)

CA1_signal_fig.show()

### Power Spectrum of CA1


In [None]:
# Rows of reshaped_data for the selected region: the distinct raw_Row values of
# the brain_loc rows labelled region_select are used as row indices.
# NOTE(review): brain_loc is not filtered by probe here, while reshaped_data was
# read from a single archive member — confirm the raw_Row values all belong to it.
select_channels = reshaped_data[
    brain_loc.loc[brain_loc['allen_ontology'] == region_select, 'raw_Row'].unique()
]

In [None]:
# Welch power spectrum of every selected channel (computed along the last axis,
# scipy's default). The sampling frequency comes from `sampling_rate` rather
# than a second hard-coded 2500, so the two cannot drift apart.
freqs, pspec = signal.welch(
    x=select_channels,
    fs=sampling_rate,
    scaling='spectrum',
    nperseg=4*1024,  # samples per segment
)

In [None]:
colors = px.colors.sequential.Viridis

# Number of spectra to draw. Kept under its own name so the recording-level
# `num_channels` (385) defined by the loading cell is not overwritten.
num_traces = len(pspec)

# One position in [0, 1] per trace, mapped below onto the colour scale
color_indices = np.linspace(0, 1, num_traces)

pspec_fig = go.Figure()

# Create the figure: one line per channel spectrum
for i, psd in enumerate(pspec):
    pspec_fig.add_trace(go.Scatter(
                            x=freqs,
                            y=psd,
                            mode='lines',
                            line=dict(color=colors[int(color_indices[i] * (len(colors) - 1))]),
    ))

# Customize layout; the title follows region_select instead of a fixed "CA1"
pspec_fig.update_layout(
    title=f'Power Spectrum of {region_select} channels',
    width=1200,
    height=600,
    xaxis_title='Frequency (Hz)',
    yaxis_title='Power',
    yaxis_type='log'
)

# Display the plot
pspec_fig.show()

In [None]:
# Retrieves the reference to subtract from the signal
def CAR_filter(signal, mode ='mean'):
    """
    Compute a reference trace by reducing `signal` along axis 0.

    Note that the reference itself is returned; nothing is subtracted here.

    :param signal: array with a ``shape`` attribute; reduced over its first axis
    :param mode: 'mean' or 'median' selects the reduction
    :return: ``np.mean(signal, axis=0)`` or ``np.median(signal, axis=0)``; for
        any other mode, zeros of shape ``(signal.shape[0], 1)``
    """
    # Returned unchanged when `mode` names no known reduction
    fallback = np.zeros((signal.shape[0], 1))

    reducer = {'mean': np.mean, 'median': np.median}.get(mode)
    return fallback if reducer is None else reducer(signal, axis=0)

In [None]:
# Median across rows at each sample; the last row is left out of the reference
avg_ref = CAR_filter(reshaped_data[:-1], mode='median')

### Power spectrum after selecting the best HPC channel