# Demo Notebook

This notebook shows how to convert a normal static .h5 file into a FileTree dataset and add information to it, such as cell positions, match IDs, and so forth. This data preparation is needed whenever the Gaussian readout with cortex coordinates is used.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import datajoint as dj

#import dataport
#from dataport.bcm import experiment, xmatch, stack
from neuralpredictors.data.datasets import FileTreeDataset
from pathlib import Path
import numpy as np
import h5py
from tqdm import tqdm
import os

In [None]:
# Bind the `sinzlab_houston_data` database schema to a module-like object
# named `experiment`; the cells below query its tables (e.g. experiment.Scan).
experiment = dj.create_virtual_module('experiment', 'sinzlab_houston_data')

In [None]:
# Restrict the Scan table to animal 22564; as the last expression of the cell
# the result is displayed as the cell output.
experiment.Scan()&"animal_id=22564"

In [None]:
# Example of restricting to several scans at once: a list with one key
# dictionary per scan.
restriction = [
    {'animal_id': 23555, 'session': 22, 'scan_idx': 3},
    {'animal_id': 23555, 'session': 23, 'scan_idx': 1},
    {'animal_id': 23555, 'session': 18, 'scan_idx': 1},
]

In [None]:
# This walkthrough converts a single dataset only, so the restriction is one
# key dictionary.
restriction = {'animal_id': 22564, 'session': 3, 'scan_idx': 12}

In [None]:
# Path pattern of a preprocessed scan, without file extension. The
# placeholders are filled from the primary key of a scan.
template = '/data/mouse/toliaslab/static/static{animal_id}-{session}-{scan_idx}-preproc0'

# Primary keys of the scans selected by `restriction`
# (the underlying query is: experiment.Scan() & restriction).
scan_keys = (experiment.Scan() & restriction).fetch('KEY')

# One source .h5 path per selected scan; shown as the cell output.
datasets = [template.format(**scan_key) + '.h5' for scan_key in scan_keys]
datasets

# Count trials

In [None]:
# Print every source file together with the shape of its `images` dataset.
for datafile in datasets:
    # Open read-only explicitly: inspecting the file must never modify it, and
    # the default mode of h5py.File is not the same across h5py versions.
    with h5py.File(datafile, 'r') as fid:
        print(datafile, fid['images'].shape)
        

# Export Data

In [None]:
# Alternative kept for reference: convert every .h5 file of a local folder.
# for datapath in Path('data/').glob('*.h5'):
#     FileTreeDataset.initialize_from(datapath)

# Initialize a file-tree dataset from each selected .h5 file. The same options
# (include_behavior=True, overwrite=True) are passed for every file.
export_options = dict(include_behavior=True, overwrite=True)
for datafile in datasets:
    FileTreeDataset.initialize_from(datafile, **export_options)


# Zip data

In [None]:
# Open the file-tree dataset of every selected scan and call zip() on it.
for key in (experiment.Scan() & restriction).proj():
    # The dataset directory is the template path followed by a slash.
    dataset_dir = template.format(**key) + '/'
    dat = FileTreeDataset(dataset_dir, 'images', 'responses')
    dat.zip()

# Link inputs and targets

In [None]:
# For every selected scan, add the links ('responses', 'targets') and
# ('images', 'inputs') to its file-tree dataset, then print the dataset.
for key in (experiment.Scan() & restriction).fetch('KEY'):
    dataset_dir = template.format(**key) + '/'
    print(dataset_dir)
    dat = FileTreeDataset(dataset_dir, 'images', 'responses')
    for link_args in (('responses', 'targets'), ('images', 'inputs')):
        dat.add_link(*link_args)
    print(dat)
    
    

# Augment with Cell positions

In [None]:
# Attach the um_x / um_y / um_z values of ScanSet.UnitInfo to each dataset as
# the neuron meta attribute `cell_motor_coordinates`.
for key in (experiment.Scan() & restriction).fetch('KEY'):
    print(key)
    dataset_dir = template.format(**key) + '/'
    dat = FileTreeDataset(dataset_dir, 'images', 'responses')

    # Identifying attributes plus the three coordinates of the units of this scan.
    unit_info = experiment.ScanSet.UnitInfo & key
    ai, se, si, ui, x, y, z = unit_info.fetch(
        'animal_id', 'session', 'scan_idx', 'unit_id', 'um_x', 'um_y', 'um_z'
    )

    # Place x, y and z next to each other as the columns of one array.
    p = np.c_[x, y, z]
    dat.add_neuron_meta('cell_motor_coordinates', ai, se, si, ui, p)

    
    

In [None]:
# Display the `cell_motor_coordinates` neuron attribute of `dat`, i.e. of the
# dataset handled in the last iteration of the loop above.
dat.neurons.cell_motor_coordinates

# Augment with cell matching from StackSet

In [None]:
# Attach the StackSet match information to each dataset: the matched unit id
# (`multi_match_id`) and the matched unit's x/y/z values
# (`multi_unit_stack_coordinates`).
for key in (experiment.Scan() & restriction).proj():
    dataset_dir = template.format(**key) + '/'
    dat = FileTreeDataset(dataset_dir, 'images', 'responses')

    # The query below is restricted and fetched with `scan_session` rather
    # than `session`, so that entry of the key is renamed (after the path
    # above has been built from the original key).
    key['scan_session'] = key.pop('session')
    matched_units = experiment.StackSet.Match() * experiment.StackSet.Unit() & key
    ai, se, si, ui, match_id, mx, my, mz = matched_units.fetch(
        'animal_id', 'scan_session', 'scan_idx', 'unit_id',
        'munit_id', 'munit_x', 'munit_y', 'munit_z',
    )

    # fill_missing=-1 is passed for the id attribute.
    dat.add_neuron_meta('multi_match_id', ai, se, si, ui, match_id, fill_missing=-1)

    # fill_missing=NaN is passed for the coordinate attribute.
    munit_coordinates = np.c_[mx, my, mz]
    dat.add_neuron_meta('multi_unit_stack_coordinates', ai, se, si, ui, munit_coordinates, fill_missing=np.nan)



In [None]:
# Display the `multi_unit_stack_coordinates` neuron attribute of `dat`, the
# dataset handled in the last iteration of the matching loop above.
dat.neurons.multi_unit_stack_coordinates


In [None]:
# Display the `multi_match_id` neuron attribute of the same dataset `dat`.
dat.neurons.multi_match_id

# Correct the color channels 

In [None]:
# `datafile` still holds the path of the last source .h5 file from the loops
# above. The images of the exported dataset are .npy files in
# <dataset dir>/data/images/, where <dataset dir> is that path without the
# `.h5` extension. Globbing the .h5 path itself matches nothing, so the image
# folder is derived from it here.
# Example of such a folder: Path('data/static22845-18-5-preproc0/data/images/')
p = Path(datafile).with_suffix('') / 'data' / 'images'
if not p.is_dir():
    # Fail loudly instead of silently correcting zero files.
    raise FileNotFoundError(f'image folder not found: {p}')

for filename in tqdm(p.glob('*.npy')):
    img = np.load(filename)
    # Only images whose first axis has length 1 are rewritten: a zero-valued
    # copy is appended along that axis, so the original data stays at index 0.
    if img.shape[0] == 1:
        img = np.concatenate((img, 0 * img))
        np.save(filename, img)

In [None]:
# Same correction for the images of static22845-18-8-preproc0, except that
# the zero-valued copy is placed first, so the original data ends up at
# index 1 of the first axis.
p = Path('data/static22845-18-8-preproc0/data/images/')
for filename in tqdm(p.glob('*.npy')):
    img = np.load(filename)
    if img.shape[0] != 1:
        # The first axis does not have length 1: leave the file untouched.
        continue
    img = np.concatenate((0 * img, img))
    np.save(filename, img)

# Augment with multi-cell matching (don't use this for now)

In [None]:
# NOTE(review): `xmatch` is not defined in this notebook as written. The only
# import that mentions it (`from dataport.bcm import experiment, xmatch, stack`)
# is commented out in the import cell, so this cell raises a NameError unless
# that import is enabled.
xmatch.MatchingParameters()

In [None]:
# NOTE(review): this cell needs `xmatch`, whose import is commented out in the
# import cell of this notebook.
# Attach the `match_id` values found for 'match_params=1' to the dataset of
# every scan of animal 22564, as the neuron attribute `multi_match_id`.
for key in (experiment.Scan() & 'animal_id=22564').proj():
    dataset_dir = template.format(**key) + '/'
    dat = FileTreeDataset(dataset_dir, 'images', 'responses')

    members = xmatch.UnitMatching.Match() * xmatch.neuro_data.StaticMultiDataset.Member()
    rel = members & key & 'match_params=1'
    ai, se, si, ui, match_id = rel.fetch(
        'animal_id', 'session', 'scan_idx', 'unit_id', 'match_id'
    )

    # fill_missing=-1 is passed for the id attribute.
    dat.add_neuron_meta('multi_match_id', ai, se, si, ui, match_id, fill_missing=-1)


In [None]:
# Count the entries of `multi_match_id` that are negative. The attribute was
# added with fill_missing=-1 above, so this is presumably the number of
# neurons without a match -- TODO confirm.
(dat.neurons.multi_match_id < 0).sum()

In [None]:
# Display the `change_log` attribute of `dat`; presumably it lists the
# modifications made to the dataset by the cells above -- TODO confirm.
dat.change_log

# Lengths of datasets

In [None]:
# Print the length of the dataset of every scan of animal 22564. Unlike the
# cells above, the datasets are read from the relative `data/` folder.
local_template = 'data/static{animal_id}-{session}-{scan_idx}-preproc0/'
for key in (experiment.Scan() & 'animal_id=22564').proj():
    dat = FileTreeDataset(local_template.format(**key), 'images', 'responses')
    print(len(dat))