In [None]:
import warnings
warnings.filterwarnings('ignore')

try:
    %cd openscope_databook
    from databook_utils.dandi_utils import dandi_download_open
    %cd ..
except:
    !git clone https://github.com/AllenInstitute/openscope_databook.git
    %cd openscope_databook
    %pip install -e .
        
import gc
import math
import os
from math import sqrt

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import pandas as pd
import quantities as pq
import scipy.signal
from dandi import dandiapi
from matplotlib.gridspec import GridSpec
from pynwb import NWBHDF5IO
from scipy import interpolate
from scipy.ndimage import gaussian_filter
from scipy.signal import welch

import icsd

%matplotlib inline

sub_sessions = [("sub-619296", "ses-1187930705"), ("sub-620333", "ses-1188137866"), ("sub-620334", "ses-1189887297"), ("sub-625545", "ses-1182865981"), ("sub-625554", "ses-1181330601"), ("sub-625555", "ses-1183070926"), ("sub-630506", "ses-1192952695"), ("sub-631510", "ses-1196157974"), ("sub-631570", "ses-1194857009"), ("sub-633229", "ses-1199247593"), ("sub-637484", "ses-1208667752")]

#read nwb file
session_num = 0
filepath = f"../material/00248_v240130/{sub_sessions[session_num][0]}/{sub_sessions[session_num][0]}_{sub_sessions[session_num][1]}_ogen.nwb"
stim_io = NWBHDF5IO(filepath, mode="r", load_namespaces=True)
stim_nwb = stim_io.read()


In [None]:
# Select which probe's LFP file to load
probe_num = 2

# Stimulus IDs: the position in this list is the frame index used by the extraction loop,
# and the ID itself is used as that frame's plot title
stim_num = [
    110000, 110101, 110105, 110106, 110107, 110109, 110110, 110111, 110506, 110511, 111105,
    111109, 111201, 111299, 111301, 111302, 111303, 111304, 111305, 111306, 111307, 111308,
]
# presumably (frame number, stimulus-type label) pairs; not referenced in the code shown here
frame_stimtype = [(0, '0'), (3, 'IC1'), (7, 'IC2'), (4, 'LC1'), (6, 'LC2'), (8, 'IRE1'), (9, 'IRE2')]

lfp_filepath = "../material/00248_v240130/{0}/{0}_{1}_probe-{2}_ecephys.nwb".format(*sub_sessions[session_num], probe_num)
lfp_io = NWBHDF5IO(lfp_filepath, mode="r", load_namespaces=True)
lfp_nwb = lfp_io.read()
lfp = lfp_nwb.acquisition[f"probe_{probe_num}_lfp_data"]

In [None]:
# List every LFP channel: index, location label, and the vertical position / y coordinate
# divided by 1000 (presumably um -> mm; confirm against the NWB electrode metadata)
for ch in range(len(lfp.electrodes)):
    location = lfp.electrodes['location'][ch]
    vertical = round(lfp.electrodes['probe_vertical_position'][ch] / 1000, 3)
    y_coord = round(lfp.electrodes['y'][ch] / 1000, 3)
    print(ch, location, vertical, y_coord)

In [None]:
# Running this cell fills stim_windows with the LFP cut around each stimulus onset, indexed as (frame #, trial #, time, channel).
def get_all_stim_times(stim_nwb, frame_num, stim_table_name="ICwcfg1_presentations"):
    """Return the onset times of the presentations of ``frame_num`` in a stimulus table.

    Only rows between the first and last non-zero ``frame`` entry (inclusive) are considered,
    and of those only even-indexed rows.

    Parameters
    ----------
    stim_nwb : NWB file object
        Must expose ``intervals[stim_table_name]`` with ``frame`` and ``start_time`` columns.
    frame_num : int
        Frame value to select (compared against the ``frame`` column converted to int).
    stim_table_name : str, optional
        Name of the intervals table; defaults to ``"ICwcfg1_presentations"``.

    Returns
    -------
    list of float
        ``start_time`` of every selected presentation, in table order.

    Raises
    ------
    ValueError
        If the table contains no non-zero frame.
    """
    stim_table = stim_nwb.intervals[stim_table_name]

    # Read each column once instead of materialising a DataFrame for every table row.
    frames = np.asarray(stim_table.frame[:])
    onsets = np.asarray(stim_table.start_time[:])

    nzframe_idxs = np.nonzero(frames)[0]
    if len(nzframe_idxs) == 0:
        raise ValueError(f"No non-zero frames in stimulus table '{stim_table_name}'")
    start_idx = int(nzframe_idxs[0])
    end_idx = int(nzframe_idxs[-1])

    row_idxs = np.arange(len(frames))
    selected = (
        (row_idxs % 2 == 0)  # NOTE(review): only even rows are kept; confirm this matches the table layout
        & (frames.astype(int) == frame_num)
        & (row_idxs >= start_idx)
        & (row_idxs <= end_idx)
    )
    all_stim_times = [float(t) for t in onsets[selected]]
    print(len(all_stim_times))
    return all_stim_times


def get_timestamps_data(lfp, all_stim_times):
    """Keep the stimulus times inside the LFP recording and slice the LFP to that period.

    The period runs from the first to the last LFP timestamp.

    Parameters
    ----------
    lfp : LFP series
        Object exposing ``timestamps`` and ``data``, both sliceable along the sample axis.
        Timestamps are assumed sorted ascending.
    all_stim_times : iterable of float
        Candidate stimulus onset times, on the same clock as ``lfp.timestamps``.

    Returns
    -------
    lfp_timestamps : np.ndarray
        Timestamps from the period start up to (excluding) the first sample at or after the
        period end.
    lfp_data : array
        ``lfp.data`` sliced with the same bounds.
    stim_times : np.ndarray
        Onsets lying within ``[period_start, period_end]``.

    Raises
    ------
    ValueError
        If no onset lies within the recording, or the period bounds cannot be located.
    """
    # Read the timestamps once; iterating an on-disk dataset element by element is very slow.
    timestamps = np.asarray(lfp.timestamps[:])
    if len(timestamps) == 0:
        raise ValueError("Period bounds not found within lfp data")

    period_start = timestamps[0]
    period_end = timestamps[-1]
    # filter stim_timestamps to just timestamps within period
    stim_times = np.array([ts for ts in all_stim_times if period_start <= ts <= period_end])
    if len(stim_times) == 0:
        raise ValueError("There are no stimulus timestamps in that period")

    # First index whose timestamp is >= each bound. An explicit search (rather than a truthiness
    # test on the index) keeps a valid start index of 0 instead of skipping the first sample.
    period_start_idx = int(np.searchsorted(timestamps, period_start, side="left"))
    period_end_idx = int(np.searchsorted(timestamps, period_end, side="left"))
    if period_start_idx >= len(timestamps) or period_end_idx >= len(timestamps):
        raise ValueError("Period bounds not found within lfp data")

    # get slice of LFP data corresponding to the period bounds (end exclusive)
    lfp_timestamps = timestamps[period_start_idx:period_end_idx]
    lfp_data = lfp.data[period_start_idx:period_end_idx]
    return lfp_timestamps, lfp_data, stim_times

def get_interp_lfp(interp_hz, lfp_timestamps, lfp_data):
    """Resample LFP onto a uniform time grid by nearest-neighbour lookup.

    Parameters
    ----------
    interp_hz : float
        Target sampling rate in Hz. The grid starts at ``lfp_timestamps[0]``, steps by
        ``1 / interp_hz`` and stops before ``lfp_timestamps[-1]``.
    lfp_timestamps : array-like, shape (n_samples,)
        Original sample times.
    lfp_data : array, shape (n_samples, n_channels)
        Original LFP, one column per channel.

    Returns
    -------
    np.ndarray, shape (n_grid_points, n_channels)
    """
    grid = np.arange(lfp_timestamps[0], lfp_timestamps[-1], step=(1/interp_hz))

    def _resample_channel(col):
        # Build one interpolator per channel at a time to keep peak memory low.
        nearest = interpolate.interp1d(lfp_timestamps, lfp_data[:, col], axis=0, kind="nearest", fill_value="extrapolate")
        return nearest(grid)

    resampled = [_resample_channel(col) for col in range(lfp_data.shape[1])]
    # stacked as (channel, time); flip to (time, channel)
    return np.transpose(resampled)

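# To add smoothing: uncomment the function below and the matching call in the for-loop at the bottom of this cell, then tune bin_size.
# NOTE(review): np.mean / np.sum below are called without axis=, so each output row is one scalar over all channels; axis=0 is likely needed to smooth each channel separately.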
# def lfp_smoothing(lfp_data, bin_size) :
#     S = np.zeros(lfp_data.shape)
#     for t in range(lfp_data.shape[0]) :
#         if t< bin_size :
#             S[t] = np.mean(lfp_data[:t+1])
#         else :
#             S[t] = np.sum(lfp_data[t-bin_size:t])/bin_size
#     return S

def get_windows(window_start_time, window_end_time, interp_hz, stim_times, lfp_timestamps, interp_lfp) :
    """Cut a fixed-length window of resampled LFP around each stimulus onset.

    Parameters
    ----------
    window_start_time : float
        Window start relative to the onset, in seconds; must be <= 0.
    window_end_time : float
        Window end relative to the onset, in seconds; must be > 0.
    interp_hz : float
        Sampling rate of ``interp_lfp`` in Hz.
    stim_times : iterable of float
        Stimulus onset times, on the same clock as ``lfp_timestamps``.
    lfp_timestamps : array-like
        Only ``lfp_timestamps[0]`` is used, as the time of ``interp_lfp[0]``.
    interp_lfp : np.ndarray
        Resampled LFP indexed by sample along axis 0.

    Returns
    -------
    np.ndarray
        Shape ``(n_windows, window_length, ...)``. Onsets whose window would fall outside
        ``interp_lfp`` are skipped and reported with a "bound out" print.

    Raises
    ------
    ValueError
        For invalid window bounds, or if no window fits inside the data.
    """
    # validate window bounds
    if window_start_time > 0:
        raise ValueError("start time must be non-positive number")
    if window_end_time <= 0:
        raise ValueError("end time must be positive number")

    # round() rather than int(): e.g. 0.29 * 100 == 28.999999999999996 would otherwise lose a sample
    window_length = round((window_end_time - window_start_time) * interp_hz)
    origin = lfp_timestamps[0]

    windows = []
    for stim_ts in stim_times:
        # floor() rather than int(): int() truncates toward zero, so a window starting up to one
        # sample before the data (e.g. -0.5) would map to index 0 and pass the bounds check.
        start_idx = math.floor((stim_ts + window_start_time - origin) * interp_hz)
        end_idx = start_idx + window_length

        # bounds checking
        if start_idx < 0 or end_idx > len(interp_lfp):
            print("bound out", start_idx, end_idx)
            continue

        windows.append(interp_lfp[start_idx:end_idx])

    if len(windows) == 0:
        raise ValueError("There are no windows for these timestamps")

    return np.array(windows)

interp_hz = 1000
stim_window_start_time = -0.0  # window start relative to image onset (s); adjust with the next line
stim_window_end_time = 0.4     # window end relative to image onset (s)
stim_windows = []  # per frame: array of shape (trial #, time, channel)

# get_timestamps_data bounds its LFP slice by the LFP's own first/last timestamps, so the slice
# (and therefore its interpolation) is identical for every frame: interpolate once and reuse it.
# NOTE(review): if get_timestamps_data ever starts depending on the stimulus times, drop this caching.
interp_lfp_temp = None
for frame_num in range(len(stim_num)):
    print(frame_num)
    all_stim_times_temp = get_all_stim_times(stim_nwb, frame_num)
    lfp_timestamps_temp, lfp_data_temp, stim_times_temp = get_timestamps_data(lfp, all_stim_times_temp)
    if interp_lfp_temp is None:
        interp_lfp_temp = get_interp_lfp(interp_hz, lfp_timestamps_temp, lfp_data_temp)
        #interp_lfp_temp = lfp_smoothing(interp_lfp_temp, 25)
    # release the raw slice now so two copies are never alive during the next iteration's read
    del lfp_data_temp
    stim_windows.append(get_windows(stim_window_start_time, stim_window_end_time, interp_hz, stim_times_temp, lfp_timestamps_temp, interp_lfp_temp))

print("gc")
del stim_nwb
stim_io.close()
del stim_io
# the resampled LFP is no longer needed once the windows have been cut
del interp_lfp_temp
gc.collect()

In [None]:
# Save each frame's LFP windows (adjust the output path for your machine)
lfp_npy_dir = f'../material/LFP_npy_data/{sub_sessions[session_num][0]}_{sub_sessions[session_num][1]}'
os.makedirs(lfp_npy_dir, exist_ok=True)
for frame_num in range(len(stim_num)):
    # np.save appends the ".npy" extension that the load cell expects
    np.save(f'{lfp_npy_dir}/frame_{frame_num}', stim_windows[frame_num])

In [None]:
# Load the per-frame LFP windows saved earlier (adjust the path for your machine)
stim_windows = [
    np.load(f'../material/LFP_npy_data/{sub_sessions[session_num][0]}_{sub_sessions[session_num][1]}/frame_{frame_num}.npy')
    for frame_num in range(len(stim_num))
]

## Plot LFP traces

In [None]:
# Range of channels drawn in the plots: start_channel inclusive, end_channel exclusive
start_channel, end_channel = 0, 84
n_channels = end_channel - start_channel  # number of traces per figure

In [None]:
# Trial-averaged LFP for every frame: shape (frame, time, channel)
stim_average_trace = np.array([np.average(stim_windows[frame_num], axis=0) for frame_num in range(len(stim_num))])

# Settings shared by every figure
amp_res = 0.000026  # vertical offset between consecutive channel traces; tune to control overlap
# Window samples are spaced 1 / interp_hz apart starting at stim_window_start_time, so the last
# sample sits one step before stim_window_end_time; endpoint=False keeps that spacing
# (the default endpoint=True would stretch the axis by one sample).
xaxis = np.linspace(stim_window_start_time, stim_window_end_time, stim_average_trace.shape[1], endpoint=False)
colors = plt.cm.viridis(np.linspace(0, 1, n_channels))
# colors[i] belongs to channel start_channel + i, so the colour bar spans the first to the last plotted channel
norm = mpl.colors.Normalize(vmin=start_channel, vmax=end_channel - 1)

#frame_num = 3  # to plot a single image: delete the for-loop below, uncomment this line and choose frame_num
for frame_num in range(len(stim_num)):
    fig, ax = plt.subplots(figsize=(10, n_channels / 3))

    for i, channel in enumerate(range(start_channel, end_channel)):
        ax.plot(xaxis, stim_average_trace[frame_num, :, channel] + i * amp_res, color=colors[i])

    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=plt.cm.viridis), location="left", pad=0.01, anchor=(0, 1), shrink=0.3, ax=ax, label='Depth (channel #)')
    ax.yaxis.set_ticks([])
    ax.set_xlabel("time (s)")
    ax.set_ylabel("LFP")
    ax.set_title(f"LFP Traces Shown By Depth, {stim_num[frame_num]}")
    ax.axvline(0, 0, 1, color='black', linestyle='solid', linewidth=0.5)  # stimulus onset
    plt.show()