In [None]:
%reload_ext autoreload
%autoreload 2

import functools
# Rebind the built-in print for this notebook so that every call is made with
# flush=True (the output stream is flushed on each print).
print = functools.partial(print, flush=True)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import pickle

import flexiznam as flz
from cottage_analysis.io_module import harp
from cottage_analysis.preprocessing import find_frames
from cottage_analysis.imaging.common import find_frames as find_img_frames
from cottage_analysis.filepath import generate_filepaths
from cottage_analysis.imaging.common import imaging_loggers_formatting as format_loggers
from cottage_analysis.preprocessing import synchronisation
from cottage_analysis.analysis import find_depth_neurons, fit_gaussian_blob, common_utils
from cottage_analysis.stimulus_structure import spheres_tube

In [None]:
# --- Example session --------------------------------------------------------
# Identifiers passed to the analysis functions throughout this notebook.
project = 'hey2_3d-vision_foodres_20220101'
mouse = 'PZAH8.2f'
session = 'S20230126'
protocol = 'SpheresPermTubeReward'

# Recording name and harp message file name for this session.
RECORDING = 'R144331_SpheresPermTubeReward'
MESSAGES = 'harpmessage.bin'

# Flexilims session object for the project.
flexilims_session = flz.get_flexilims_session(project_id=project)

# Kept for reference (disabled): list every recording entry of this protocol.
# all_protocol_recording_entries = generate_filepaths.get_all_recording_entries(
#     project=project,
#     mouse=mouse,
#     session=session,
#     protocol=protocol,
#     flexilims_session=flexilims_session,
# )

# DO NOT RUN THIS FUNCTION (takes about 2 hours): it finds the monitor frames
# from the photodiode signal.
# find_monitor_frames(
#     project=project,
#     mouse=mouse,
#     session=session,
#     protocol=protocol,
#     all_protocol_recording_entries=None,
#     irecording=None,
#     flexilims_session=None,
# )
    

In [None]:
# Generate the synchronisation dataframes for the first recording (irecording=0).
# `vs_df` is built first because it is an input to `generate_trials_df`, which
# returns both the per-trial table and the per-imaging-frame table.
vs_df = synchronisation.generate_vs_df(
    project=project,
    mouse=mouse,
    session=session,
    protocol=protocol,
    irecording=0,
)
trials_df, imaging_df = synchronisation.generate_trials_df(
    project=project,
    mouse=mouse,
    session=session,
    protocol=protocol,
    vs_df=vs_df,
    irecording=0,
)

In [None]:
# Find depth neurons and fit preferred depth.
# The protocol comes from the session configuration (`protocol`) rather than a
# repeated string literal, so this cell stays in sync with the other cells when
# the example session is changed.
neurons_df = find_depth_neurons.find_depth_neurons(
    project=project,
    mouse=mouse,
    session=session,
    protocol=protocol,
    rs_thr=0.2,
)

# NOTE: `neurons_df` is rebound here; the table returned by
# `fit_preferred_depth` replaces the one returned just above.
neurons_df = find_depth_neurons.fit_preferred_depth(
    project=project,
    mouse=mouse,
    session=session,
    protocol=protocol,
    depth_min=0.02,
    depth_max=20,
    batch_num=10,
)

In [None]:
# Display the neuron table produced by the depth-tuning cell above.
neurons_df

In [None]:
# Fit gaussian blob to neuronal activity.
# WARNING: this will also take quite long, so stop it early and check a few cells.
# The protocol comes from the session configuration (`protocol`) instead of a
# repeated string literal, consistent with the synchronisation cell.
neurons_df = fit_gaussian_blob.fit_gaussian_blob(
    project=project,
    mouse=mouse,
    session=session,
    protocol=protocol,
    rs_thr=0.01,
    param_range={'rs_min': 0.005, 'rs_max': 5, 'of_min': 0.03, 'of_max': 3000},
    batch_num=1,
)

In [None]:
# Regenerate sphere stimuli
def regenerate_stimuli(project, mouse, session, protocol):
    """Regenerate the sphere stimuli at the time of each imaging frame.

    ``common_utils.loop_through_recordings`` is used to run the per-recording
    helper below. For each recording the helper loads the parameter log and
    the pickled synchronisation dataframes, rebuilds the frames with
    ``spheres_tube.regenerate_frames`` and saves them as ``stimuli.npy`` in the
    protocol folder.

    Args:
        project (str): project name.
        mouse (str): mouse name.
        session (str): session name.
        protocol (str): protocol name.

    Returns:
        np.ndarray: the value returned by ``loop_through_recordings`` stacked
            along a new leading axis of length 1 (index it with ``[0]``).
    """

    def regenerate_stimuli_each_recording(
        project, mouse, session, protocols, protocol, irecording, nrecordings,
    ):
        """Rebuild and save the stimulus frames of recording ``irecording``."""
        # Resolve the folders of the recording being processed. The recording
        # index used to be hard-coded to 0, which made every recording reload
        # (and overwrite) the files of the first one.
        (
            rawdata_folder,
            protocol_folder,
            _,
            _,
            _,
        ) = generate_filepaths.generate_file_folders(
            project=project,
            mouse=mouse,
            session=session,
            protocol=protocol,
            all_protocol_recording_entries=None,
            recording_no=irecording,
        )

        # The parameter log names the column "Radius"; regenerate_frames is
        # given it under the name "Depth".
        param_log = pd.read_csv(rawdata_folder / "NewParams.csv").rename(
            columns={"Radius": "Depth"}
        )

        # SECURITY NOTE: pickle.load can execute arbitrary code; these files are
        # expected to be the ones written by this analysis pipeline. Do not
        # point this at untrusted data.
        sync_tables = {}
        for table_name in ("imaging_df", "vs_df", "trials_df"):
            with open(protocol_folder / f"sync/{table_name}.pickle", "rb") as handle:
                sync_tables[table_name] = pickle.load(handle)
        imaging_df = sync_tables["imaging_df"]

        output = spheres_tube.regenerate_frames(
            # imaging frames are the timepoints at which stimuli are reconstructed
            frame_times=imaging_df['harptime_imaging_trigger'].values,
            trials_df=sync_tables["trials_df"],
            vs_df=sync_tables["vs_df"],
            param_logger=param_log,
            time_column="HarpTime",
            resolution=1,
            sphere_size=10,
            azimuth_limits=(-120, 120),
            elevation_limits=(-40, 40),
            verbose=True,
            output_datatype="int16",
            output=None,
        )

        np.save(protocol_folder / 'stimuli.npy', output)

        return output

    outputs = []
    output = common_utils.loop_through_recordings(
        project=project,
        mouse=mouse,
        session=session,
        protocol=protocol,
        func=regenerate_stimuli_each_recording,
    )
    outputs.append(output)
    # Kept as a stack of one element so existing callers can keep using outputs[0].
    outputs = np.stack(outputs)

    return outputs


# Regenerate the stimuli for the example session configured at the top.
outputs = regenerate_stimuli(
    project=project,
    mouse=mouse,
    session=session,
    protocol=protocol,
)
    


In [None]:
# Reload the stimuli saved by `regenerate_stimuli`.
# `protocol_folder` is only a local variable inside that function, so it is
# resolved here (same call, first recording) to let this cell run on a fresh
# kernel instead of raising NameError.
_, protocol_folder, _, _, _ = generate_filepaths.generate_file_folders(
    project=project,
    mouse=mouse,
    session=session,
    protocol=protocol,
    all_protocol_recording_entries=None,
    recording_no=0,
)
output = np.load(protocol_folder / 'stimuli.npy')

In [None]:
# Sum over every element of the reloaded stimulus array (quick sanity check).
output.sum()

In [None]:
def spike_triggered_average_by_depth(
    trials_df,
    reconstructed_frames,
    frame_times,
    frame_rate=15,
    delays=None,
    spk_per_frame=None,
    verbose=True,
):
    """Spike-weighted sum of the reconstructed frames, split by depth.

    Delays, in seconds, are applied to the stimulus sequence. A delay of -0.1
    pairs each spike with the stimulus shown 100 ms before it. Delays are
    converted to a whole number of frames with ``int(delay * frame_rate)``.

    The result is not normalised by the number of spikes yet (see the TODO at
    the end of the loop); ``nspks`` is returned so this can be done afterwards.

    Args:
        trials_df (pd.DataFrame): stimulus structure with each row as a trial.
            The columns ``depth``, ``harptime_stim_start`` and
            ``harptime_stim_stop`` are used.
        reconstructed_frames (np.array): n frames x n elev x n azim array of
            stimuli.
        frame_times (np.array): sorted time of each frame, on the same clock as
            ``trials_df.harptime_stim_start``.
        frame_rate (float): frame rate used to convert delays to frames. 144 for
            monitor frames, 15 for imaging frames.
        delays (np.array): array of delays in seconds. If None, a single delay
            of 0 is used.
        spk_per_frame (np.array): spike for each frame, used to weight the
            frames. If None, every frame gets a weight of 1.
        verbose (bool): print progress.

    Returns:
        full_sta (np.array): n depth x n delay x n elev x n azim weighted sum.
        nspks (np.array): n depth vector of summed spikes.
        depths (np.array): ordered depths corresponding to first sta dimension.
        delays (np.array): ordered delays corresponding to second sta dimension.
    """
    if delays is None:
        delays = [0]
    if spk_per_frame is None:
        spk_per_frame = np.ones(reconstructed_frames.shape[0])

    depths = np.sort(trials_df.depth.unique())
    frame_shape = reconstructed_frames.shape[1:]
    full_sta = np.zeros((len(depths), len(delays), *frame_shape))
    nspks = np.zeros(len(depths))

    # Highest frame index that is valid for both the time vector and the
    # stimulus array. Clipping to len(frame_times) (one past the end) would let
    # a positive delay index out of bounds.
    last_frame = min(len(frame_times), reconstructed_frames.shape[0]) - 1

    for idepth, depth in enumerate(depths):
        if verbose:
            print(f"... doing depth {depth*100} cm")
        depth_df = trials_df[trials_df.depth == depth]
        # find the frames shown during the trials at this depth
        starts = frame_times.searchsorted(depth_df.harptime_stim_start)
        ends = frame_times.searchsorted(depth_df.harptime_stim_stop)
        ends = ends[: len(starts)]
        frame_index = np.hstack(
            [np.arange(s, e, dtype=int) for s, e in zip(starts, ends)]
        )
        # keep non-shifted spikes for all delays, so that valid frames are
        # looked up only once per depth
        spk_per_frame_at_depth = spk_per_frame[frame_index]
        nspks[idepth] = np.sum(spk_per_frame_at_depth)
        valid_frames = spk_per_frame_at_depth != 0
        weights = spk_per_frame_at_depth[valid_frames]
        for idelay, delay in enumerate(delays):
            if verbose:
                print(f"... ... doing delay {delay * 1000} ms")
            shift = int(delay * frame_rate)
            # shift the stimulus, staying inside the available frames
            shifted_frames = np.clip(frame_index[valid_frames] + shift, 0, last_frame)
            stims = reconstructed_frames[shifted_frames].reshape(
                len(shifted_frames), -1
            )
            sta = np.dot(stims.T, weights)
            full_sta[idepth, idelay] = sta.reshape(frame_shape)
            # TODO: add normalized STA

    return full_sta, nspks, depths, delays


# Inputs for one example ROI.
iroi = 3

reconstructed_frames = outputs[0]
frame_times = imaging_df['harptime_imaging_trigger'].values
frame_rate = 15
delays = np.array([-1, 0, 1])
spk_per_frame = np.stack(imaging_df['dffs'].values)[:, iroi]
verbose = True

full_sta, nspks, depths, delays = spike_triggered_average_by_depth(
    trials_df=trials_df,
    reconstructed_frames=reconstructed_frames,
    frame_times=frame_times,
    frame_rate=frame_rate,
    delays=delays,
    spk_per_frame=spk_per_frame,
    verbose=verbose,
)


In [None]:
# Spike-triggered average of the reconstructed frames by depth (flat script).
# NOTE(review): this cell repeats the computation done in the previous cell;
# consider keeping only one copy.
iroi = 3

reconstructed_frames = outputs[0]
frame_times = imaging_df['harptime_imaging_trigger'].values
frame_rate = 15
delays = np.array([-1, 0, 1])
spk_per_frame = np.stack(imaging_df['dffs'].values)[:, iroi]
verbose = True

# Delays, in seconds, are applied to the stimulus sequence. A delay of -0.1
# pairs each spike with the stimulus shown 100 ms before it.
#
# Inputs:
#     trials_df (pd.DataFrame): stimulus structure with each row as a trial.
#     reconstructed_frames (np.array): n frames x n elev x n azim array of stimuli
#     frame_times (np.array): time of each frame, on the same clock as
#         trials_df.harptime_stim_start
#     frame_rate (float): frame rate to calculate delays. 144 for monitor
#         frames, 15 for imaging frames.
#     delays (np.array): array of delays in seconds
#     spk_per_frame (np.array): spike for each frame, used to weight the frames.
#         If None, every frame gets a weight of 1.
#     verbose (bool): print progress
#
# Results:
#     full_sta (np.array): n depth x n delay x n elev x n azim weighted sum
#         (not normalised yet)
#     nspks (np.array): n depth vector of summed spikes
#     depths (np.array): ordered depths corresponding to first sta dimension
#     delays (np.array): ordered delays corresponding to second sta dimension
if delays is None:
    delays = [0]
if spk_per_frame is None:
    spk_per_frame = np.ones(reconstructed_frames.shape[0])

depths = np.sort(trials_df.depth.unique())
full_sta = np.zeros((len(depths), len(delays), *reconstructed_frames.shape[1:]))
nspks = np.zeros(len(depths))

# Highest frame index valid for both the time vector and the stimulus array.
# The clip below used len(frame_times) as its upper bound, which is one past
# the last frame and could index out of bounds for positive delays.
last_frame = min(len(frame_times), reconstructed_frames.shape[0]) - 1

for idepth, depth in enumerate(depths):
    if verbose:
        print(f"... doing depth {depth*100} cm")
    depth_df = trials_df[trials_df.depth == depth]
    # find frames at this depth
    starts = frame_times.searchsorted(depth_df.harptime_stim_start)
    ends = frame_times.searchsorted(depth_df.harptime_stim_stop)
    ends = ends[: len(starts)]
    frame_index = np.hstack(
        [np.arange(s, e, dtype=int) for s, e in zip(starts, ends)]
    )
    # keep non-shifted spikes for all delays,
    # so that valid frames are looked up only once
    spk_per_frame_at_depth = spk_per_frame[frame_index]
    nspks[idepth] = np.sum(spk_per_frame_at_depth)
    valid_frames = spk_per_frame_at_depth != 0
    for idelay, delay in enumerate(delays):
        if verbose:
            print(f"... ... doing delay {delay * 1000} ms")
        shift = int(delay * frame_rate)
        # shift the stim, staying inside the available frames
        shifted_frames = np.clip(
            frame_index[valid_frames] + shift, 0, last_frame
        )
        stims = reconstructed_frames[shifted_frames].reshape(
            len(shifted_frames), -1
        )
        sta = np.dot(stims.T, spk_per_frame_at_depth[valid_frames])
        sta = sta.reshape(reconstructed_frames.shape[1:])
        full_sta[idepth, idelay] = sta
        # TODO: add normalized STA


In [79]:
# Regenerate the stimuli through the `sta` module of cottage_analysis
# (presumably the packaged counterpart of the notebook-level
# `regenerate_stimuli` defined earlier; confirm).
# NOTE(review): the name `sta` is also assigned an array in the
# spike-triggered-average computation earlier in this notebook, so this import
# rebinds it. The recorded run of this cell ended with a KeyboardInterrupt.
from cottage_analysis.analysis import sta
outputs = sta.regenerate_stimuli(project, mouse, session, protocol)

  and should_run_async(code)


---------Process protocol 1/1---------


  sess_children = get_session_children(
  recording_entries, recording_path = get_recording_entries(
             The last recording that matches this protocoll will be returned.
  warn(


KeyboardInterrupt: 

In [78]:
# Sum over every element of `outputs` (quick sanity check of the regenerated stimuli).
np.sum(outputs)

  and should_run_async(code)


11795