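Note: the cells below use the `discrete` label array, so the neural dataset must be loaded first (run the data-loading cell further down) before running them.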

In [None]:
# Build a bare HasGenerator and display the generator it exposes.
# NOTE(review): positional args are presumably (device, seed) -- confirm
# against the cebra version in use. Requires a CUDA device.
a = cebra.distributions.base.HasGenerator("cuda", 0)
a.generator

In [None]:
# Draw 20 uniformly sampled indices from a Discrete distribution built on
# the `discrete` label array (defined in the data-loading cell).
# NOTE(review): seed is the string 'none', not None or an int -- confirm this
# is what the installed cebra version expects.
b = cebra.distributions.discrete.Discrete(discrete, device="cpu", seed="none")
uniform_sample = b.sample_uniform(20)
uniform_sample

In [None]:
# Look up the labels at the uniformly sampled indices and show them sorted.
uniform_labels = discrete[uniform_sample]
np.sort(uniform_labels)

In [None]:
# Sample 20 indices from the prior of a DiscreteUniform distribution built on
# `discrete`, then sample conditionally on the labels at those indices.
# NOTE(review): seed is the string 'none', not None or an int -- confirm this
# is what the installed cebra version expects.
c = cebra.distributions.discrete.DiscreteUniform(discrete, device='cpu', seed='none')
uniform_sample_2 = c.sample_prior(20)
sample_conditional_2 = c.sample_conditional(discrete[uniform_sample_2])
# Display both samples. The original displayed `values_2`, which is never
# defined and raised a NameError.
uniform_sample_2, sample_conditional_2

In [None]:
# Compare the sorted labels of the prior sample with those of the
# conditional sample.
prior_labels = np.sort(discrete[uniform_sample_2])
conditional_labels = np.sort(discrete[sample_conditional_2])
prior_labels, conditional_labels

In [None]:
### load and preprocess data for a single fish ###
# If LOAD_DATA is True, load the pre-saved .npz file data. Otherwise,
# create this data as specified below and save it to .npz.

##  params ##

# variables (local aliases of the notebook-level configuration constants)
stim_types = STIM_TYPES     # dict of all possible stims
stims = STIMS               # stim types chosen for analysis
timesteps = TIMESTEPS
rois = ROIS
stim_length_frames = STIM_LENGTH_FRAMES # used for selecting the second half of stimuli

start, stop = 0, timesteps
load_data = LOAD_DATA
save_data = SAVE_DATA

# paths
filepath = FILEPATH
# Last path component with its final three characters removed
# (presumably the '.h5' extension -- confirm): fish and date only.
filename = filepath.split('/')[-1][:-3]
data_folder = DATA_PATH
data_folder_HDD = '/media/storage/DATA/tom/'
# All output file names are keyed on the last nine characters of `filename`.
filename_stim_pres_frames = f'{filename[-9:]}_stim_pres_frames.npz'
filename_neural_subset = f'{filename[-9:]}_{SIGNAL_TYPE}_subset.npz'
filename_neural_indexes = f'{filename[-9:]}_neural_indexes_all.npz'

# specify loading anatomically unrestricted data or tectal-restricted data
if RESTRICT_TO_TECTAL:
    filename_neural_subset = f'{filename[-9:]}_{SIGNAL_TYPE}_subset_tectal.npz'
    filename_neural_indexes = f'{filename[-9:]}_neural_indexes_tectal.npz'

# if not loading data, but not wanting to overwrite saved data, save as a temp file
if not save_data and not load_data:
    print("Producing temp files...")
    filename_neural = f'{filename[-9:]}_{SIGNAL_TYPE}_TEMPORARY_DELETE.npz'
    filename_neural_subset = f'{filename[-9:]}_{SIGNAL_TYPE}_subset_TEMPORARY_DELETE.npz'
    filename_stim_pres_frames = f'{filename[-9:]}_stim_pres_frames_TEMPORARY_DELETE.npz'
    # The indexes file is also passed to save_data_to_file below, so it must
    # be redirected too; otherwise the saved indexes would be overwritten.
    filename_neural_indexes = f'{filename[-9:]}_neural_indexes_TEMPORARY_DELETE.npz'


print("Accessing data...")

## load data ##
if load_data:

    # Read the neural data and stimulus masks back from the .npz files named
    # above. NOTE(review): the original comment says the helper falls back to
    # the HDD .h5 file if the .npz is missing -- confirm in load_data_from_file.
    print("Loading data...")
    (neural, stim_on_frames) = load_data_from_file(filepath, data_folder, filename_neural_subset,
                                                   filename_stim_pres_frames)


## else generate data ##
else:
    # Everything below runs while the .h5 file is open, because the helpers
    # read from the open file handle `f`.
    with h5py.File(filepath, 'r') as f:

        ## neural ##

        neural_dataset = f['rois'][f'{SIGNAL_TYPE}']
        print(f"Full neural dataset shape is: {neural_dataset.shape}")

        neural, neural_indexes = generate_neural_dataset(f, neural_dataset, rois, timesteps=timesteps)

        ## stimuli ##

        (stim_pres_idx_list,
         stim_on_fr_list,
         stim_on_fr_list_half,
         stim_end_fr_list,
         stim_on_mask_list,
         stim_on_mask_list_half,
         stim_dur_list) = create_stimulus_presentation_masks(f, neural, stims, stim_types,
                                                             stim_length_frames, timesteps)

        # Create a separate continuous "contrast" variable for each stimulus, to
        # train the CEBRA model on alongside the discrete variable. Each
        # presentation becomes a ramp 0, 1, 2, ... starting at its onset frame.
        n_frames = neural.shape[0]
        cont_left_spot = np.zeros(n_frames)
        cont_right_spot = np.zeros(n_frames)
        cont_stimuli = [cont_left_spot, cont_right_spot]
        for i, cont_stimulus in enumerate(cont_stimuli):
            this_stim_on_fr = stim_on_fr_list[i]
            for pres in this_stim_on_fr:
                # Clip the ramp at the end of the recording: a presentation
                # starting within stim_length_frames of the last frame would
                # otherwise raise a shape-mismatch ValueError on assignment.
                ramp_end = min(pres + stim_length_frames, n_frames)
                cont_stimulus[pres:ramp_end] = np.arange(ramp_end - pres)

        ## save data ##
        print("Saving data...")

        # choose which stim_on_mask to use (half or full); one column per stimulus
        if HALF_STIM_MASK:
            stim_on_mask_dataset = np.column_stack(stim_on_mask_list_half[:])
        else:
            stim_on_mask_dataset = np.column_stack(stim_on_mask_list[:])

        # Sanity checks on the generated shapes before writing anything to disk.
        assert stim_on_mask_dataset.shape == (neural.shape[0], len(stims))
        if timesteps:
            assert neural.shape == (timesteps, rois)

        save_data_to_file(stim_on_mask_dataset, neural, neural_indexes,
                          data_folder, filename_stim_pres_frames, filename_neural_subset,
                          filename_neural_indexes)

        ## load data ##
        # Reload what was just saved, so both branches end with `neural` and
        # `stim_on_frames` coming from the same loader.
        print("Loading data...")

        (neural, stim_on_frames) = load_data_from_file(filepath, data_folder, filename_neural_subset,
                                                       filename_stim_pres_frames)

# end else

## final processing ##

# Collapse the two stimulus mask columns into one discrete label vector:
# column 0 (left spot) contributes 1, column 1 (right spot) contributes 2,
# so no stimulus == 0, left spot == 1, right spot == 2.
# (assumes the mask columns hold 0/1 values -- TODO confirm)
left_spot = stim_on_frames[:, 0]
right_spot = stim_on_frames[:, 1] * 2
discrete = left_spot + right_spot

# Split chronologically: the first TRAINING_TEST_SPLIT fraction of the frames
# is the training set, the remainder is the test set.
training_test_split = TRAINING_TEST_SPLIT
split_idx = int(np.round(neural.shape[0] * training_test_split))
neural_train = neural[:split_idx, :]
neural_test = neural[split_idx:, :]
discrete_train = discrete[:split_idx]
discrete_test = discrete[split_idx:]