## Test SPIM data CEBRA model

- Use CEBRA label contrastive learning on neural data from one fish
    - design model
    - convert SPIM data to usable format 
    - load data
    - fit with label
    - plot embeddings
    - try to predict stimulus presence
    - try to decode stimulus type (left/right spots)
        - create a discrete variable that labels the post-stimulus frames for right and left spots
        - This should inform the decoder to separate embedding states (which should vary between left and right spots)<br/><br/>

In [1]:
import cebra
import numpy as np
import matplotlib.pyplot as plt
import os
import h5py

In [2]:
# # define globals

# list of all data files
dat_files = ['/media/storage/DATA/lfads_export/f1_221027.h5',
             '/media/storage/DATA/lfads_export/f1_221103.h5',
             '/media/storage/DATA/lfads_export/f2_221103.h5',
             '/media/storage/DATA/lfads_export/f3_221103.h5']

global FILEPATH
global TIMESTEPS
global ROIS
global ITERS
global LOAD
global STIM_TYPES
global STIMS
global STIM_MASKS

FILEPATH = dat_files[0]
TIMESTEPS = 15000
ROIS = 10000
ITERS = 15000
LOAD = False
STIM_TYPES = {'left_spot':0, 'right_spot':1,  \
              'open_loop_grating':2, 'closed_loop_grating':3}
STIMS = ['left_spot', 'right_spot']
STIM_MASKS = [['left_spot', 'right_spot']]

In [3]:
# # define model

cebra_time_model = cebra.CEBRA(
    model_architecture='offset10-model',
    device='cuda_if_available',
    conditional='time',
    temperature_mode='auto',
    min_temperature=0.1,
    time_offsets=10,
    max_iterations=ITERS,
    max_adapt_iterations=500,
    batch_size=None,
    learning_rate=1e-4,
    output_dimension=3,
    verbose=True,
    num_hidden_units=32,
    hybrid=False
    )
print(cebra_time_model)

CEBRA(conditional='time', learning_rate=0.0001, max_iterations=15000,
      model_architecture='offset10-model', output_dimension=3,
      temperature_mode='auto', time_offsets=10, verbose=True)


In [26]:
# to place in load data loop

filepath = FILEPATH
with h5py.File(filepath, 'r') as f:

    # get stimulus presentations
    stimuli = f['visuomotor']['presentations']
    stim_type = stimuli['stim_type']
    spot_pres_fr_l = np.where(np.isin(stim_type, 1))[0]
    spot_pres_fr_r = np.where(np.isin(stim_type, 2))[0]
    print(f'Out of a total {stim_type.size} stimulus presentations:\n \
          {spot_pres_fr_l.size} left spots\n \
          {spot_pres_fr_r.size} right spots')
    
    spot_pres_fr = np.column_stack((spot_pres_fr_l, spot_pres_fr_r))

    # neural
    neural = f['rois']['dfof']
    print(f"Full neural dataset shape is: {neural.shape}")

    # truncate neural
    # select first TIMESTEPS timesteps and random ROIS rois
    neural_indexes = np.sort(
                        np.random.choice(
                                    np.arange(neural.shape[1]), size=ROIS, replace=False
                                    )
                        )
    neural = np.array(neural[start:stop, neural_indexes])
    print(f'Truncated neural dataset shape is: {neural.shape}')

    # assert shapes
    assert(neural.shape == (TIMESTEPS, ROIS))
    assert(spot_pres_fr == (spot_pres_fr_l.size, 2))
    
    np.savez(f'{data_folder}{filepath_spot_pres_fr}', spot_pres_fr=spot_pres_fr)
    np.savez(f'{data_folder}{neural}', neural=neural)



print(spot_pres_fr_l)
        

Out of a total 142 stimulus presentations:
           35 left spots
           35 right spots
Full neural dataset shape is: (43350, 93122)


NameError: name 'start' is not defined

In [103]:
a = [1,1,1,1]
b = [4,1,1,4]
np.bitwise_xor(a,b)

array([5, 0, 0, 5])

In [152]:
a = np.zeros((100000,)).astype(int)
a[[stim_pres_fr_l.astype(int), stim_end_fr_l.astype(int)]] = 1
a = np.bitwise_xor.accumulate(a) | a
a[793:825]


array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 0, 0, 0, 0, 0])

In [115]:
stim_pres_fr_l

array([  798.,  1963.,  2933.,  4218.,  5638.,  6683.,  8263.,  9583.,
       10928., 12078., 13513., 14858., 15888., 17053., 17988., 18998.,
       20083., 21278., 22388., 23538., 24703., 25863., 27303., 28623.,
       29693., 30903., 32078., 33643., 34828., 36173., 37023., 38368.,
       39488., 40788., 42343.])

In [110]:
a = np.array([1,1,1,1])
b = np.array([4,1,1,4])
np.bitwise_xor.accumulate(a)|a

array([1, 1, 1, 1])

In [176]:
stim_pres_idx_list

[array([  0,   4,   7,  11,  16,  20,  25,  29,  34,  38,  43,  48,  51,
         55,  58,  61,  65,  69,  72,  76,  80,  83,  88,  93,  96, 100,
        104, 109, 113, 117, 120, 125, 129, 133, 138])]

In [179]:
filepath = FILEPATH
stim_types = STIM_TYPES
neural = np.zeros((100000,1))
with h5py.File(filepath, 'r') as f:

    # get stimulus presentations
    stimuli = f['visuomotor']['presentations']
    stim_type = stimuli['stim_type'].astype(int)
    stim_on_fr = stimuli['onset_frame'].astype(int)
    stim_end_fr = stimuli['offset_frame'].astype(int)


    (stim_pres_idx_list, stim_on_fr_list,
     stim_end_fr_list, stim_on_mask_list, stim_dur_list)  = [],[],[],[],[]

    for stim in STIMS:

        # convert stim name to stim number
        stim_num = STIM_TYPES[stim]

        # find the presentation indexes for all specified stim types
        stim_pres_idx_list.append(np.where(np.isin(stim_type, stim_num))[0])

        # index stim onset frames with the presentation indexes
        this_stim_on_frame = stim_on_fr[stim_pres_idx_list[stim_num]]
        stim_on_fr_list.append(this_stim_on_frame)

In [196]:
stim_on_mask_list[1][793:825]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [194]:
filepath = FILEPATH
stim_types = STIM_TYPES
neural = np.zeros((100000,1))
with h5py.File(filepath, 'r') as f:

    # get stimulus presentations
    stimuli = f['visuomotor']['presentations']
    stim_type = stimuli['stim_type'].astype(int)
    stim_on_fr = stimuli['onset_frame'].astype(int)
    stim_end_fr = stimuli['offset_frame'].astype(int)


    (stim_pres_idx_list, stim_on_fr_list,
     stim_end_fr_list, stim_on_mask_list, stim_dur_list)  = [],[],[],[],[]

    for stim in STIMS:

        # convert stim name to stim number
        stim_num = STIM_TYPES[stim]  

        # find the presentation indexes for all specified stim types
        # must account for data index starting at 1
        stim_pres_idx_list.append(np.where(np.isin(stim_type, stim_num + 1))[0])

        # index stim onset frames with the presentation indexes
        this_stim_on_frame = stim_on_fr[stim_pres_idx_list[stim_num]]
        stim_on_fr_list.append(this_stim_on_frame)

        # index stim end frames with the presentation indexes
        this_stim_end_frame = stim_end_fr[stim_pres_idx_list[stim_num]]
        stim_end_fr_list.append(this_stim_end_frame)

        # create a boolean mask of stimulus presentation frames (1 == stimulus on, 0 == stimulus off)
        stim_on_mask = np.zeros(neural.shape[0]).astype(int)
        stim_on_mask[[stim_on_fr_list[stim_num], stim_end_fr_list[stim_num]]] = 1
        stim_on_mask = np.bitwise_xor.accumulate(stim_on_mask) | stim_on_mask
        stim_on_mask_list.append(stim_on_mask)

        # find duration (in frames) of each presentation of the stimulus
        # recording rate is 5 Hz
        stim_dur_list.append(this_stim_end_frame - this_stim_on_frame)

In [200]:
# load data for a single fish
# TODO - put this in context of LOAD if-else statement

##  params ##

# variables
stim_types = STIM_TYPES

# paths
filepath = FILEPATH
filename = filepath.split('/')[-1][:-3] # fish and date only
data_folder = 'data/'
filename_spot_pres_fr = f'{filename[-9:]}_spot_pres_fr.npz'
filename_dfof = f'{filename[-9:]}_dfof_stim_decode.npz'


with h5py.File(filepath, 'r') as f:

    ## neural ##

    neural = f['rois']['dfof']
    print(f"Full neural dataset shape is: {neural.shape}")

    # get stimulus presentations
    stimuli = f['visuomotor']['presentations']
    stim_type = stimuli['stim_type'].astype(int)
    stim_on_fr = stimuli['onset_frame'].astype(int)
    stim_end_fr = stimuli['offset_frame'].astype(int)

    # initialise lists for the chosen stimuli
    (stim_pres_idx_list, stim_on_fr_list,
     stim_end_fr_list, stim_on_mask_list, stim_dur_list)  = [],[],[],[],[]

    ## stimuli ##

    # loop through chosen stimuli and find boolean masks for their 'on' frames
    for stim in STIMS:

        # convert stim name to stim number
        stim_num = STIM_TYPES[stim]  

        # find the presentation indexes for all specified stim types
        # must account for data index starting at 1
        stim_pres_idx_list.append(np.where(np.isin(stim_type, stim_num + 1))[0])

        # index stim onset frames with the presentation indexes
        this_stim_on_frame = stim_on_fr[stim_pres_idx_list[stim_num]]
        stim_on_fr_list.append(this_stim_on_frame)

        # index stim end frames with the presentation indexes
        this_stim_end_frame = stim_end_fr[stim_pres_idx_list[stim_num]]
        stim_end_fr_list.append(this_stim_end_frame)

        # create a boolean mask of stimulus presentation frames (1 == stimulus on, 0 == stimulus off)
        stim_on_mask = np.zeros(neural.shape[0]).astype(int)
        stim_on_mask[[stim_on_fr_list[stim_num], stim_end_fr_list[stim_num]]] = 1
        stim_on_mask = np.bitwise_xor.accumulate(stim_on_mask) | stim_on_mask
        stim_on_mask_list.append(stim_on_mask)

        # find duration (in frames) of each presentation of the stimulus
        # recording rate is 5 Hz
        stim_dur_list.append(this_stim_end_frame - this_stim_on_frame)

        # assert shapes
        assert(stim_on_mask_list[0].size == neural.shape[0])

        # save data
        # TODO - save neural and save stim data into separate columns of the same dataset

        # load data
        # TODO - write this    

Full neural dataset shape is: (43350, 93122)


In [39]:
# # load data (single fish)

# paths
filepath = FILEPATH
filename = filepath.split('/')[-1][:-3] # fish and date only
data_folder = 'data/'
filename_spot_pres_fr = f'{filename[-9:]}_spot_pres_fr.npz'
filename_dfof = f'{filename[-9:]}_dfof_stim_decode.npz'

# choose where in dataset to sample
start, stop = 0, 0+TIMESTEPS

# extract eye position and neural data
# do not attempt to load the entire file 
print("Accessing data...")

# load data if it is already saved, and LOAD == True
if LOAD == True:
    try:
        spot_pres_fr = cebra.load_data(f'{data_folder}{filename_spot_pres_fr}', key="spot_pres_fr")
        print(f"{filename_spot_pres_fr}_left loaded.")
        spot_pres_fr = cebra.load_data(f'{data_folder}{filename_spot_pres_fr}', key="spot_pres_fr")
        print(f"{filename_spot_pres_fr}_right loaded.")
        neural = cebra.load_data(f'{data_folder}{filename_dfof}', key="neural")
        print(f"{filename_dfof} loaded.")
    
    except:
        pass
        print("Couldn't load data into CEBRA")

else:
    with h5py.File(filepath, 'r') as f:

        # neural
        neural = f['rois']['dfof']
        print(f"Full neural dataset shape is: {neural.shape}")

        
        # get stimulus presentations
        stimuli = f['visuomotor']['presentations']
        stim_type = stimuli['stim_type'].astype(int)

        for stim in STIMS:
            


        # find the presentation indexes with left or right spots
        stim_pres_idx_l = np.where(np.isin(stim_type, 1))[0]    # left spots
        stim_pres_idx_r = np.where(np.isin(stim_type, 2))[0]    # right spots

        # print spot information
        print(f'Out of a total {stim_type.size} stimulus presentations:\n \
        {spot_pres_fr_l.size} left spots\n \
        {spot_pres_fr_r.size} right spots')

        # index stim onset frames with the presentation indexes
        stim_onset_fr = stimuli['onset_frame'].astype(int)
        stim_pres_fr_l = stim_onset_fr[stim_pres_idx_l]
        stim_pres_fr_r = stim_onset_fr[stim_pres_idx_r]

        # index stim end frames with the presentation indexes
        stim_end_fr = stimuli['offset_frame'].astype(int)
        stim_end_fr_l = stim_end_fr[stim_pres_idx_l]
        stim_end_fr_r = stim_end_fr[stim_pres_idx_r]

        # create masks of stim onset/stim end
        # left spot
        stim_on_l = np.zeros(neural.shape[0])
        stim_on_l[[stim_pres_fr_l, stim_end_fr_l]] = 1
        np.bitwise_xor.accumulate(stim_on_l) | stim_on_l
        # right spot
        stim_on_r = np.zeros(neural.shape[0])
        stim_on_r[[stim_pres_fr_r, stim_end_fr_r]] = 1
        np.bitwise_xor.accumulate(stim_on_r) | stim_on_r

        # find duration (in frames) of each presentation
        # (neural recording is at 5Hz)
        stim_dur_l = stim_end_fr_l - stim_pres_fr_l
        stim_dur_r = stim_end_fr_r - stim_pres_fr_r
        
        spot_pres_fr = np.column_stack((spot_pres_fr_l, spot_pres_fr_r))

        # assert shapes
        assert(neural.shape == (TIMESTEPS, ROIS))
        assert(spot_pres_fr.shape == (spot_pres_fr_l.size, 2))

        # save data
        np.savez(f'{data_folder}{filename_spot_pres_fr}', spot_pres_fr=spot_pres_fr)
        np.savez(f'{data_folder}{filename_dfof}', neural=neural)

        # load data
        spot_pres_fr = cebra.load_data(f'{data_folder}{filename_spot_pres_fr}', key="spot_pres_fr")
        print(f"{filename_spot_pres_fr}_left loaded.")
        neural = cebra.load_data(f'{data_folder}{filename_dfof}', key="neural")
        print(f"{filename_dfof} loaded.")
    

print(spot_pres_fr_l)
        

Accessing data...
Out of a total 142 stimulus presentations:
             35 left spots
             35 right spots
Full neural dataset shape is: (43350, 93122)
Truncated neural dataset shape is: (15000, 10000)
f1_221027_spot_pres_fr.npz_left loaded.
f1_221027_dfof_stim_decode.npz loaded.
[  0   4   7  11  16  20  25  29  34  38  43  48  51  55  58  61  65  69
  72  76  80  83  88  93  96 100 104 109 113 117 120 125 129 133 138]


In [None]:
# # fit model

In [7]:
# remember to create an array containing both variables
post_left_spot = np.arange(3)
post_right_spot = np.arange(3)
post_spot = np.column_stack([post_left_spot, post_right_spot])
post_spot.shape

(3, 2)