In [10]:
import os
import time
import tqdm
import numpy as np
import pandas as pd
import json
import itertools
import sys

# Add the home directory to the import path so that the `PpaPy` package
# (presumably checked out directly under ~) can be imported below.
sys.path.append(os.path.expanduser('~'))
from PpaPy.helper.io import makedir

import warnings
# np.nanmean warns "Mean of empty slice" when a selection is empty or all-NaN
# (it then returns NaN); silence only that specific warning.
warnings.filterwarnings(action='ignore', message='Mean of empty slice')

In [11]:
class NanConverter(json.JSONEncoder):
    """JSON encoder that writes float NaN as ``null`` instead of the non-standard ``NaN`` token.

    NaNs are replaced recursively inside dicts, lists and tuples before encoding.
    The replacement happens in :meth:`iterencode`, which both ``json.dumps``
    (through :meth:`encode`) and ``json.dump`` (directly) go through.
    """

    def nan2None(self, obj):
        """Return ``obj`` with every float NaN replaced by ``None``.

        Dicts are rebuilt with converted values (keys untouched); lists and
        tuples become lists (JSON encodes both as arrays anyway). Python floats
        and numpy floating scalars are checked for NaN; everything else is
        returned unchanged.
        """
        if isinstance(obj, dict):
            return {k: self.nan2None(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [self.nan2None(v) for v in obj]
        if isinstance(obj, (float, np.floating)) and np.isnan(obj):
            return None
        return obj

    def encode(self, obj, *args, **kwargs):
        # JSONEncoder.encode hands every non-str object to self.iterencode,
        # so the NaN replacement below runs exactly once on this path too.
        return super().encode(obj, *args, **kwargs)

    def iterencode(self, obj, *args, **kwargs):
        # Overridden (not just encode) because json.dump calls iterencode directly.
        return super().iterencode(self.nan2None(obj), *args, **kwargs)

In [12]:
def load_data_from_keys(json_file, key):
    """Load a JSON file and collect the sub-trees stored under ``key``.

    The loaded document is pruned recursively:
      * a dict that contains ``key`` is reduced to ``{key: value}``;
      * any other dict, and any list, is walked element by element;
      * scalars are kept unchanged.
    The pruned top-level dict is then flattened two levels deep into
    ``{(inner_key, outer_key): value}``.

    For a file shaped like ``{entry_id: {key: ..., ...}}`` this yields
    ``{(key, entry_id): ...}``. A top-level entry that lacks ``key`` is kept
    whole, so its own fields show up as ``(field, entry_id)`` pairs.

    Raises AttributeError if the pruned top level, or one of its values, is not a dict.
    """
    def _prune(node):
        # Lists: prune each element, keep order.
        if isinstance(node, list):
            return [_prune(item) for item in node]
        # Scalars (and anything non-container) pass through.
        if not isinstance(node, dict):
            return node
        # A dict holding the key is cut down to just that key.
        if key in node:
            return {key: node[key]}
        return {name: _prune(child) for name, child in node.items()}

    with open(json_file, 'r') as handle:
        raw = json.load(handle)
    pruned = _prune(raw)

    flattened = {}
    for outer_key, inner in pruned.items():
        for inner_key, value in inner.items():
            flattened[(inner_key, outer_key)] = value
    return flattened

Please specify where your input files are stored and where the output data should be saved in the following cell.

In [15]:
# Run timestamp, e.g. "20240131-14H05M" (the H and M following %H / %M are literal letters).
datestr = time.strftime("%Y%m%d-%HH%MM")

# Input roots used so far; the active one is selected explicitly by key
# (previously the alternatives were plain reassignments that silently overwrote each other).
INPATH_OPTIONS = {
    'eren': "/gpfs/soma_fs/scratch/src/boeger/PpaPred_eren_35727184",
    'roca': "/gpfs/soma_fs/scratch/src/boeger/PpaPred_roca_35727184",
}
inpath = INPATH_OPTIONS['roca']

# Results go to a folder named after the input root. `makedir` comes from
# PpaPy.helper.io; presumably it creates the folder if missing and returns its path — verify.
outpath = makedir(f"/gpfs/soma_fs/home/boeger/PpaPred/{os.path.basename(inpath)}")

# Each entry is a substring used to find one group folder below `inpath`;
# the batch cell writes one `<entry>_batch.json` per entry. Select the active preset by key.
INPATH_PATTERN_PRESETS = {
    'exp1_wt': ['Exp1_WT_larvae', 'Exp1_WT_OP50'],
    'exp2': ['Exp2_WT_larvae', 'Exp2_tbh1_larvae', 'Exp2_tph1_larvae', 'Exp2_tdc1_larvae',
             'Exp2_cat2_larvae'],
    'exp2_wt': ['Exp2_WT_larvae'],
    'exp3': ['Exp3_WT_larvae', 'Exp3_octr1_larvae', 'Exp3_ser3_larvae', 'Exp3_ser6_larvae',
             'Exp3_tyra2_larvae', 'Exp3_ser2_larvae', 'Exp3_lgc55_larvae', 'Exp3_tyra3_larvae',
             'Exp3_tdc1_larvae', 'Exp3_tbh1_larvae', 'Exp3_tbh1tdc1_larvae'],
    'exp3_suppl': ['Exp3_suppl_ser2tyra2tyra3_larvae'],
    'L_ids': ['L147', 'L157', 'L176', 'L118', 'L119', 'L156'],
}
inpath_pattern = INPATH_PATTERN_PRESETS['L_ids']

# NOTE: not read by the batch cell below, which always descends into subfolders via os.walk.
inpath_with_subfolders = True

# True: rebuild each `<entry>_batch.json` from scratch; False: merge new recordings into an existing file.
overwrite = True

In [16]:
def _find_pattern_dir(root_dir, data_str):
    """Return the first folder below `root_dir` (os.walk order) whose name contains `data_str`.

    Raises FileNotFoundError with a readable message (instead of a bare IndexError).
    """
    for root, dirs, _files in os.walk(root_dir):
        for name in dirs:
            if data_str in name:
                return os.path.join(root, name)
    raise FileNotFoundError(f"No folder containing '{data_str}' found below {root_dir}")


def _match_by_id(paths, file_id, kind):
    """Return the file in `paths` that belongs to recording `file_id`.

    Prefers an exact stem match (basename minus its last '_'-separated token, the
    same rule used to derive `file_id` from the prediction file name), so that id
    'rec_1' does not pick up 'rec_10_summary.csv'. Falls back to the first path
    containing `file_id` as a substring (the previous rule) when no stem matches.
    """
    for path in paths:
        if '_'.join(os.path.basename(path).split('_')[:-1]) == file_id:
            return path
    for path in paths:
        if file_id in path:
            return path
    raise FileNotFoundError(f"No {kind} file found for '{file_id}'")


def _transition_dict(trans):
    """Flatten a transition table into {"<column>, <row>": value}, column by column.

    Key strings reproduce the former str(tuple)-based formatting (repr of each
    label joined by ', '); values follow trans.values.T.flatten() order, i.e.
    all rows of the first column first.
    """
    pairs = itertools.product(trans.columns.astype(int), trans.index)
    return {f"{col!r}, {row!r}": val for (col, row), val in zip(pairs, trans.values.T.flatten())}


def _summarize_recording(fn, fpath, file_id, loc_summ, loc_trans):
    """Build the batch-JSON entry for one recording from its prediction, summary and transition files."""
    data = pd.read_json(fpath, orient='split')
    y = data['prediction']
    # Labels are summarised over -1..7 below (reindex); flag files that contain label 8.
    if (y == 8).any():
        tqdm.tqdm.write(f"Warning: {fn}")

    # 'proba_<label>' columns -> MultiIndex ('proba', '<label>').
    proba = data.filter(regex='proba')
    proba.columns = proba.columns.str.split('_', expand=True)
    # Mean probability of each class over the frames predicted as that class.
    # float() turns the label string into a number without eval() on file-derived text.
    mean_probas = {cl: np.nanmean(proba.loc[:, ('proba', cl)][y == float(cl)])
                   for cl in proba.columns.levels[1]}

    summ = pd.read_csv(_match_by_id(loc_summ, file_id, 'summary'))

    trans = pd.read_csv(_match_by_id(loc_trans, file_id, 'transitions'), index_col=0)
    trans = trans.mask(trans == 0)  # zeros -> NaN for now, until processing in FeedingPrediction is fixed

    # Per-class mean velocity / rate; reindex so every label -1..7 is present (NaN if absent).
    data_mean = data[['velocity', 'rate', 'prediction']].groupby('prediction').mean().reindex(range(-1, 8))

    return {'count': summ.duration_count.fillna(0).to_dict(),
            'mean duration': summ.duration_mean.to_dict(),
            'rel time in': summ.duration_relative.fillna(0).to_dict(),
            'mean velocity': data_mean.velocity.to_dict(),
            'mean rate': data_mean.rate.to_dict(),
            'mean transitions': _transition_dict(trans),
            'mean prediction probability': mean_probas,
            'ethogram': y.to_list()}


for data_str in inpath_pattern:
    ### I/O ################################################
    pattern_dir = _find_pattern_dir(inpath, data_str)
    all_files = [os.path.join(root, name) for root, _dirs, files in os.walk(pattern_dir)
                 for name in files if name.endswith(('json', 'csv'))]
    loc_all = {os.path.basename(f): f for f in all_files if 'prediction.json' in f}
    loc_summ = [f for f in all_files if 'summary.csv' in f]
    loc_trans = [f for f in all_files if 'transitions.csv' in f]

    if not loc_all:
        print(f"No prediction files found for '{data_str}' in {pattern_dir}; skipped.")
        continue

    # Resolve the output folder per pattern without reassigning the `outpath` config,
    # and from file *directories* so a single file does not yield a file path as folder.
    out_dir = outpath if outpath is not None else os.path.commonpath(
        [os.path.dirname(f) for f in loc_all.values()])
    JsonOut = os.path.join(out_dir, f'{data_str}_batch.json')

    ### Load and save to batch json ##################################
    # Decide once per pattern whether to start fresh or merge into the existing file.
    # `overwrite` is only read here, so every pattern sees the configured value.
    if os.path.isfile(JsonOut) and not overwrite:
        with open(JsonOut, 'r') as jsonfile:
            batch = json.load(jsonfile)
    else:
        batch = {}

    for fn, fpath in tqdm.tqdm(loc_all.items()):
        file_id = '_'.join(fn.split('_')[:-1])
        batch[file_id] = _summarize_recording(fn, fpath, file_id, loc_summ, loc_trans)
        # Rewritten after every recording so progress survives an interrupted run.
        with open(JsonOut, 'w') as outfile:
            outfile.write(json.dumps(batch, indent=4, cls=NanConverter))

100%|██████████| 3/3 [00:00<00:00,  6.41it/s]
100%|██████████| 27/27 [00:06<00:00,  4.38it/s]
100%|██████████| 13/13 [00:03<00:00,  4.01it/s]
100%|██████████| 47/47 [00:12<00:00,  3.84it/s]
100%|██████████| 16/16 [00:05<00:00,  2.78it/s]
100%|██████████| 9/9 [00:02<00:00,  4.00it/s]
