In [1]:
import sys
# Make the PGM3_correct `source` and `utilities` folders importable.
# Both an absolute location and a location relative to the current working
# directory are appended, so whichever exists on this machine is picked up.
sys.path.append('/home/thijs/repos/dnp-code/PGM3_correct/source/')
sys.path.append('/home/thijs/repos/dnp-code/PGM3_correct/utilities/')
sys.path.append('PGM3_correct/source/') # relative path to the PGM source folder
sys.path.append('PGM3_correct/utilities/') # relative path to the PGM utilities folder

In [2]:
import numpy as np
import h5py, os
import pandas as pd
import scipy.sparse
import rbm as rbm
from fishualizer_utilities import Zecording

def export_weights_for_fishualizer(weights, recording=None, 
                                   labeled_cells_only=True,
                                   path_weights='/home/thijs/',
                                   filename_weights='weights_RBM-name.h5',
                                   save_h5=True,
                                   n_zbrain_regions=294):
    '''Export a weight matrix to a file that can be loaded in the Fishualizer.

    The weights are returned (and optionally stored) as a DataFrame with one
    row per cell of the recording and one column per hidden unit (HU).

    Parameters:
    ---------
        weights: 2D array, shape (neurons, HUs)
            Weight matrix to export. Must have more rows than columns
            (this is asserted), e.g. pass ``rbm.weights.T`` when the model
            stores its weights as HUs x neurons.
        recording: Zecording class
            Zecording object (from Fishualizer) that contains the zebrafish
            data recording on which the model has been trained. Only its
            ``n_cells`` and ``labels`` attributes are read, and only when
            labeled_cells_only is True.
        labeled_cells_only: bool
            If the model has been trained on the Zbrain Atlas-labeled cells
            only; then the weights are scattered back into a matrix covering
            all cells, with zero weights for the non-labeled cells.
        path_weights: str
            folder the h5 file is written to
        filename_weights: str
            file name; '.h5' is appended when it is missing
        save_h5: bool
            If True, write the DataFrame to
            ``os.path.join(path_weights, filename_weights)`` under key 'all'.
        n_zbrain_regions: int
            Number of leading columns of ``recording.labels`` that are
            inspected to decide whether a cell is labeled (default 294).

    Returns:
    ---------
        df_weights: pd.DataFrame
            One column per HU, named 'hu_000', 'hu_001', ...; one row per cell.

    Raises:
    ---------
        AssertionError
            If weights is not a 2D neurons x HUs matrix, if recording is
            missing while labeled_cells_only is True, or if the number of
            rows of weights differs from the number of labeled cells.
    '''

    if not filename_weights.endswith('.h5'):
        filename_weights = filename_weights + '.h5'

    ## Check the weight matrix orientation:
    assert weights.ndim == 2
    assert weights.shape[0] > weights.shape[1], 'Weights should be neurons x HUs'
    print(weights.shape)

    ## Take care of unlabeled cells if needed
    if labeled_cells_only:
        # Identity check: `!= None` would dispatch to the object's own __ne__.
        assert recording is not None, 'the zebrafish recording is needed to account for labeled cells only weights'
        n_cells = recording.n_cells
        label_columns = np.arange(n_zbrain_regions)
        # Row indices of cells with at least one non-zero entry in the label columns
        selected_neurons = np.unique(scipy.sparse.find(recording.labels[:, label_columns])[0])  # cells with zbrain label
        assert weights.shape[0] == len(selected_neurons), \
            'Number of weight rows must equal the number of labelled cells'  # shape is neurons x HUs
        print(f'n cells: {n_cells}, n labelled cells: {len(selected_neurons)}')
        full_weights = np.zeros((n_cells, weights.shape[1]), dtype='float32')  # matrix for all cells (including non-labeled)
        full_weights[selected_neurons, :] = weights  # non-labelled neurons keep w=0 for all HUs
    else:
        print('Assuming that all cells were used for RBM training')
        full_weights = weights.copy()

    ## Export to h5 via pd DataFrame; each column = one weight vector
    df_weights = pd.DataFrame({'hu_' + str(ii).zfill(3):
                                np.squeeze(full_weights[:, ii]) for ii in range(full_weights.shape[1])})
    if save_h5:
        df_weights.to_hdf(os.path.join(path_weights, filename_weights), key='all')  # store as h5 file

    ## To view the weights in the Fishualizer: 
    ## Launch the Fishualizer
    ## Load the main data set (= rec) using File -> Open data
    ## Load the saved weights (= df_weights) using File -> Add static data -> Choose Dataset

    return df_weights

    
    

  # NOTE(review): the four statements below sat here at a 2-space indent that
  # matches no enclosing block, so this cell could not be parsed. They also use
  # names (s1, s2, s3, V, sum_weights) that are not defined anywhere in the
  # visible notebook. They look like a pasted fragment from another file and
  # are commented out so the cell runs; delete them if they are not needed.
  # dmean_v_dw = np.dot(s1.T, V)
  # dvar_e_dw = np.dot(s2.T, V)
  # tmp3 = np.dot(s3.T, V)
  # mean_V = np.dot(weights, V) / sum_weights


# Example of how to use the function above:

In [3]:
# Folder that holds the recording file(s); the file names in `data_sets`
# are appended to it when the recordings are loaded.
base_path = '/home/thijs/Desktop/zf_rbm_essentials/'
# Alternative location: '/media/thijs/hooghoudt/Zebrafish_data/spontaneous_data_guillaume/'

# Data-set name -> h5 file name.
# (Previously used file for this data set: '20180912_Run01Tset=.h5')
data_sets = {
    '20180912-Run01': '20180912_Run01_spontaneous_rbm2.h5',
}

# Identifier of the test segments; it is interpolated into the pickle path below.
test_segs = '267'
# HARD SET TO 10TH PERCENTILE OF 20180912-RUN01 (TEST SEGS 267)
train_inds_path = f'/home/thijs/repos/dnp-code/train_test_inds/20180912-Run01/train_test_inds__test_segs_{test_segs}_nseg10.pkl'


In [4]:
import pickle

## Load data: one Zecording per entry of data_sets
recordings = {}
for data_set, data_path in data_sets.items():
    recordings[data_set] = Zecording(path=base_path + data_path, kwargs={'ignorelags': True,
                                              'forceinterpolation': False,
                                              'ignoreunknowndata': True,# 'parent': self,
                                              'loadram': True})  # load data
rec = recordings[next(iter(data_sets))]  # work with the first data set
# rec = recordings['2019-03-26(Run09)']
print(rec)

# Region name -> label column indices used to select neurons
regions = {#'rh1': np.array([218]), 'rhall': np.array([113]),
          'wb': np.arange(294)}
selected_neurons = {}
n_sel_cells = {}
train_data = {}
test_data = {}
full_data = {}

# Load the dictionary with train/test time indices; the context manager
# closes the file handle once the pickle has been read.
with open(train_inds_path, 'rb') as train_inds_file:
    dict_tt_inds = pickle.load(train_inds_file)
train_inds = dict_tt_inds['train_inds']
test_inds = dict_tt_inds['test_inds']
print(f'len test inds {len(test_inds)}')

# The slicing below indexes rows by neuron and columns by time index; this
# check does not depend on the region, so it is done once before the loop.
assert rec.spikes.shape[0] > rec.spikes.shape[1]
for ir in regions:
    # Row indices of cells with a non-zero label in any column of this region
    selected_neurons[ir] = np.unique(scipy.sparse.find(rec.labels[:, regions[ir]])[0])
    n_sel_cells[ir] = len(selected_neurons[ir])
    # Select the neurons once and reuse the result for the train/test splits
    full_data[ir] = rec.spikes[selected_neurons[ir], :]
    train_data[ir] = full_data[ir][:, train_inds]
    test_data[ir] = full_data[ir][:, test_inds]

baseline with shape (5553, 54334) is not recognized, so it cannot be loaded.
drifts with shape (5553, 2) is not recognized, so it cannot be loaded.
inferredspikes with shape (5553, 54334) is not recognized, so it cannot be loaded.
ljpcoordinates with shape (3, 54334) is not recognized, so it cannot be loaded.
segmentation with shape (30, 598, 1280) is not recognized, so it cannot be loaded.
temporalmean with shape (30, 598, 1280) is not recognized, so it cannot be loaded.
rawsignal with shape (5553, 54334) is not recognized, so it cannot be loaded.
trace with shape (5553, 1) is not recognized, so it cannot be loaded.
metadata with shape (1, 1) is not recognized, so it cannot be loaded.


Recording from /home/thijs/Desktop/zf_rbm_essentials/20180912_Run01_spontaneous_rbm2.h5
len test inds 1665


### Export 1 RBM:

In [5]:
sys.path.append('/home/thijs/repos/zf-rbm/figure_notebooks')
import swap_sign_RBM as ssrbm
rbm_path = '/home/thijs/Desktop/zf_rbm_essentials/RBM3_20180912-Run01-spontaneous-rbm2_wb_test-segs-267-nseg10_M200_l1-2e-02_duration208093s_timestamp2020-05-16-0844.data'
# HUs that should be swapped; passed to swap_sign_RBM as `assert_hu_inds`
# (presumably checked against the HUs it actually flips — confirm in swap_sign_RBM).
hu_assert = [1,  11,  12,  19,  30,  36,  37,  38,  41,  43,  52,  55,  67, 68,  70,  72,  88,  94,  95,  99, 100, 107, 111, 117, 118, 120, 124, 128, 133, 136, 138, 140, 151, 152, 167, 170, 171, 172, 175, 177, 181, 186, 188, 191, 198]
# Use a context manager so the file handle is closed after unpickling.
# Only unpickle files from a trusted source: pickle can execute arbitrary code.
with open(rbm_path, 'rb') as rbm_file:
    tmp_RBM = pickle.load(rbm_file)
RBM = ssrbm.swap_sign_RBM(RBM=tmp_RBM, verbose=2, assert_hu_inds=hu_assert)

45/200 HU weights are flipped
Flipped HUs are: (array([  1,  11,  12,  19,  30,  36,  37,  38,  41,  43,  52,  55,  67,
        68,  70,  72,  88,  94,  95,  99, 100, 107, 111, 117, 118, 120,
       124, 128, 133, 136, 138, 140, 151, 152, 167, 170, 171, 172, 175,
       177, 181, 186, 188, 191, 198]),)


In [12]:
# Build the Fishualizer weight table for the sign-swapped RBM.
# save_h5=False: only the DataFrame is returned, nothing is written to disk.
tmp_df = export_weights_for_fishualizer(
    weights=RBM.weights.T,
    recording=rec,
    labeled_cells_only=True,
    path_weights='/home/thijs/',
    filename_weights='weights_RBM-name2',
    save_h5=False,
)

(52518, 200)
n cells: 54334, n labelled cells: 52518


In [45]:
# Largest value across all hidden-unit columns for the cell at row position 200
# (presumably a quick spot check of the exported table).
tmp_df.iloc[200].max()

0.42704743

### Export multiple RBMs:

In [5]:
# NOTE(review): the path append and the `ssrbm` import repeat the ones in the
# "Export 1 RBM" cell; they are harmless but redundant when both cells are run.
sys.path.append('/home/thijs/repos/zf-rbm/figure_notebooks')
import swap_sign_RBM as ssrbm
import analysis_functions as af

# Collection of RBM file paths; the next cell filters it by substring, so the
# entries are presumably path strings — confirm against analysis_functions.
all_rbm_paths = af.rbm_paths_used_for_sweep()


Bad key "text.kerning_factor" on line 4 in
/home/thijs/.conda/envs/py37/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
http://github.com/matplotlib/matplotlib/blob/master/matplotlibrc.template
or from the matplotlib source distribution
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_Gram=True, verbose=0,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps

In [18]:
# Substrings that the file name of the wanted RBM must contain
m_cond_str = 'M200_'
l_cond_str = 'l1-2e-02'
list_rbm_sel = [x for x in all_rbm_paths if m_cond_str in x and l_cond_str in x]
assert len(list_rbm_sel) == 1, list_rbm_sel  # exactly one RBM must match both conditions
rbm_sel_path = list_rbm_sel[0]
# Weights go into a 'weights/' sub-folder next to the RBM file
weight_path = os.path.join(os.path.dirname(rbm_sel_path), 'weights/')
# Drop the '.data' extension as a suffix. str.rstrip('.data') would instead
# strip any trailing run of the characters '.', 'd', 'a', 't' and could eat
# into the file name itself.
rbm_basename = os.path.basename(rbm_sel_path)
if rbm_basename.endswith('.data'):
    rbm_basename = rbm_basename[:-len('.data')]
weights_fn = 'weights_' + rbm_basename + '_signswapped' + '.h5'
print(weight_path)

/media/thijs/hooghoudt/new_sweep_april20/RBM_sweep_combined/weights/


In [19]:
rbm_path = rbm_sel_path
# No list of expected flipped HUs is supplied for this RBM (assert_hu_inds=None).
# Use a context manager so the file handle is closed after unpickling.
# Only unpickle files from a trusted source: pickle can execute arbitrary code.
with open(rbm_path, 'rb') as rbm_file:
    tmp_RBM = pickle.load(rbm_file)
RBM = ssrbm.swap_sign_RBM(RBM=tmp_RBM, verbose=2, assert_hu_inds=None)
# Export the sign-swapped weights (transposed to neurons x HUs) and store them as h5
tmp_df = export_weights_for_fishualizer(weights=RBM.weights.T, recording=rec, 
                                   labeled_cells_only=True,
                                   path_weights=weight_path,
                                   filename_weights=weights_fn,
                                   save_h5=True)

69/100 HU weights are flipped
Flipped HUs are: (array([ 0,  1,  2,  4,  5,  6,  8,  9, 10, 11, 12, 13, 15, 17, 18, 19, 20,
       22, 24, 27, 28, 30, 31, 32, 34, 36, 37, 38, 39, 41, 42, 43, 44, 45,
       46, 47, 51, 52, 53, 55, 57, 58, 59, 60, 61, 63, 65, 66, 67, 68, 69,
       70, 71, 73, 77, 79, 81, 82, 84, 86, 87, 88, 89, 90, 92, 93, 94, 97,
       99]),)
(52518, 100)
n cells: 54334, n labelled cells: 52518


### Export VAE weights:

In [20]:
# Drop the references held by these names before the VAE results are loaded.
tmp_df = RBM = env = None

In [21]:
vae_results_path = '/home/thijs/Google Drive/projects/ZF RBM; M_Internship/VAE/VAE_results_2.data'
# Use a context manager so the file handle is closed after unpickling.
# Only unpickle files from a trusted source: pickle can execute arbitrary code.
with open(vae_results_path, 'rb') as vae_file:
    env = pickle.load(vae_file)
print(env['W'].shape)
# Only env['W'] is used below; these two entries are cleared
# (presumably to free memory — they are not needed for the export).
env['spikes_train'] = None 
env['reconstruction_spikes'] = None

(300, 52518)


In [22]:
# Export the VAE weight matrix (transposed to neurons x latent units) and store it as h5.
tmp_df = export_weights_for_fishualizer(
    weights=env['W'].T,
    recording=rec,
    labeled_cells_only=True,
    path_weights='/home/thijs/Google Drive/projects/ZF RBM; M_Internship/VAE/',
    filename_weights='weights_VAE_M300_l1e-2',
    save_h5=True,
)

(52518, 300)
n cells: 54334, n labelled cells: 52518
