In [6]:
import os
import time
import tqdm
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch, FancyArrowPatch
from matplotlib import colors
import logging
import yaml
import json
import joblib
from sklearn.impute import SimpleImputer
from scipy.stats.contingency import crosstab
import networkx as nx
from matplotlib.lines import Line2D
import umap
import itertools
import scipy
from scipy.cluster.hierarchy import linkage, dendrogram
import seaborn as sns

sys.path.append(os.getcwd())
from functions.load_model import load_tolist
import functions.visualise as vis
import functions.process as proc
from functions.io import setup_logger, makedir
from functions import FeatureEngine
from numba import jit

import warnings
warnings.filterwarnings(action='ignore', message='Mean of empty slice')



In [7]:
class NanConverter(json.JSONEncoder):
    """JSON encoder that serialises float NaN values as ``null``.

    The replacement is applied in ``iterencode`` so that it takes effect for
    both ``json.dumps(..., cls=NanConverter)`` and
    ``json.dump(..., fp, cls=NanConverter)``; the latter never calls ``encode``.
    """

    def nan2None(self, obj):
        """Return a copy of ``obj`` in which every float NaN is replaced by None.

        Dicts, lists and tuples are traversed recursively; any other object is
        returned unchanged.
        """
        if isinstance(obj, dict):
            return {k: self.nan2None(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self.nan2None(v) for v in obj]
        if isinstance(obj, tuple):
            # Tuples are written as JSON arrays, so NaN inside them must be replaced as well.
            return tuple(self.nan2None(v) for v in obj)
        if isinstance(obj, float) and np.isnan(obj):
            return None
        return obj

    def encode(self, obj, *args, **kwargs):
        """Return the JSON string for ``obj`` with NaN written as ``null``."""
        # JSONEncoder.encode() delegates to self.iterencode(), which performs the
        # NaN replacement; converting here as well would traverse the data twice.
        return super().encode(obj, *args, **kwargs)

    def iterencode(self, obj, *args, **kwargs):
        """Yield JSON chunks for ``obj`` with NaN written as ``null``."""
        return super().iterencode(self.nan2None(obj), *args, **kwargs)

In [8]:
def load_data_from_keys(json_file, key):
    """Load a JSON file and flatten it to ``{(inner_key, outer_key): values}``.

    The parsed document is first pruned: every dict that contains ``key`` is
    reduced to ``{key: <its value>}``; dicts without it and lists are descended
    into, and all other values are kept as they are. The two outermost levels
    of the pruned result are then flattened into tuple keys.

    Parameters
    ----------
    json_file : path of the JSON file to read.
    key : dictionary key used for pruning.

    Returns
    -------
    dict mapping ``(inner_key, outer_key)`` to the corresponding values.
    """
    def _prune(node):
        # Lists are pruned element by element, scalars are passed through.
        if isinstance(node, list):
            return [_prune(item) for item in node]
        if not isinstance(node, dict):
            return node
        # A dict holding the requested key is cut down to that single entry.
        if key in node:
            return {key: node[key]}
        return {name: _prune(child) for name, child in node.items()}

    with open(json_file, 'r') as handle:
        payload = json.load(handle)

    flattened = {}
    for outer_key, inner in _prune(payload).items():
        for inner_key, values in inner.items():
            flattened[(inner_key, outer_key)] = values
    return flattened

In the following cell, specify where your input files are stored (`inpath`) and where the batch output should be saved (`outpath`).

In [142]:
# --- User configuration ---------------------------------------------------
# Timestamp of the form YYYYMMDD-hhHmmM (a literal 'H' / 'M' follows the hour / minute).
datestr = time.strftime("%Y%m%d-%HH%MM")

# Root directory that is walked recursively for prediction outputs.
inpath = "/gpfs/soma_fs/scratch/src/boeger/PpaPred_eren"
#inpath = '/gpfs/soma_fs/scratch/src/boeger/data_roca'

# Directory that receives the '<pattern>_batch.json' files.
# If None, the batching loop falls back to the common path of the input files.
outpath = "/gpfs/soma_fs/home/boeger/PpaPred/data_eren"

# Folder-name patterns to batch; one JSON file is written per pattern.
# Keep exactly one assignment active; the others are alternatives.
#inpath_pattern = ['larvae_data', 'bac_data', 'tph1_larvae','ser3_larvae']
#inpath_pattern = ['tbh1_larvae', 'tdc1_larvae','ser6_larvae']
#inpath_pattern = ['octr1_larvae', 'cat2_larvae']
inpath_pattern = ['WT_larvae', 'WT_OP50']
#inpath_pattern = ['Exp2_WT_larvae',	'Exp2_tph1_larvae', 'Exp2_cat2_larvae','Exp2_tbh1_larvae',  'Exp2_tdc1_larvae']

# NOTE(review): not read by the batching loop; presumably used by later cells - confirm.
inpath_with_subfolders = True

# NOTE(review): not read by the batching loop; the commented list looks like an
# alternative class ordering - confirm against the cells that use it.
WT_ordering = False#[1., 0., 2., 6., 8., 3., 4., 7., 5.]

# True: an existing batch file is replaced. False: new recordings are merged into it.
overwrite = True

In [143]:
# Collect the per-recording prediction outputs of every dataset pattern into one
# '<pattern>_batch.json' file. The configuration variables (`outpath`, `overwrite`)
# are only read here, never modified, so the cell can be re-run safely.
for data_str in inpath_pattern:
    ### I/O ################################################
    # All files located in a folder whose name contains both 'predicted' and the pattern.
    all_files = [
        os.path.join(root, name)
        for root, dirs, files in os.walk(inpath)
        for name in files
        if 'predicted' in os.path.basename(root) and data_str in os.path.basename(root)
    ]
    loc_all = {os.path.basename(f): f for f in all_files if 'predicted.json' in f}
    loc_summ = [f for f in all_files if 'summary.csv' in f]
    loc_trans = [f for f in all_files if 'transitions.csv' in f]
    #loc_onoff = [f for f in all_files if 'onoff.json' in f]

    if not loc_all:
        print(f"Warning: no 'predicted.json' files found for '{data_str}' in {inpath}; skipping.")
        continue

    # Resolve the output folder per dataset without touching the configured `outpath`.
    batch_dir = outpath if outpath is not None else os.path.commonpath(loc_all.values())
    JsonOut = os.path.join(batch_dir, f'{data_str}_batch.json')

    # Start from the existing batch file unless it is missing or has to be overwritten.
    if os.path.isfile(JsonOut) and not overwrite:
        with open(JsonOut, "r") as jsonfile:
            batch = json.load(jsonfile)
    else:
        batch = {}

    ### Load and collect into batch ##################################
    for fn, fpath in tqdm.tqdm(loc_all.items()):
        # Recording identifier: file name without its last '_'-separated part.
        recording_id = '_'.join(fn.split('_')[:-1])
        data = pd.read_json(fpath, orient='split')
        y = data['prediction']
        if 8 in np.unique(y):
            print(f"Warning: {fn}")

        # Split the 'proba_<class>' column names into a two-level column index.
        proba = data.filter(regex='proba')
        proba.columns = proba.columns.str.split('_', expand=True)
        # Mean probability of each class over the frames predicted as that class.
        # The class label is parsed with float() rather than eval(); labels are
        # expected to be numeric strings.
        mean_probas = {
            cl: np.nanmean(proba.loc[:, ('proba', cl)][y == float(cl)])
            for cl in proba.columns.levels[1]
        }

        # Companion files are matched by the recording id being part of their path.
        summ_matches = [l for l in loc_summ if recording_id in l]
        trans_matches = [l for l in loc_trans if recording_id in l]
        if not summ_matches or not trans_matches:
            raise FileNotFoundError(
                f"Missing summary.csv or transitions.csv for '{recording_id}' ({fpath})"
            )
        summ_ = pd.read_csv(summ_matches[0])

        fr_transition_ = pd.read_csv(trans_matches[0], index_col=0)
        fr_transition_ = fr_transition_.mask(fr_transition_ == 0)  # for now until processing in FeedingPrediction is fixed
        # Keys have the form '<column>, <row>'; values are read column by column
        # (hence the transpose) to match the order of the (column, row) pairs.
        transition_pairs = itertools.product(fr_transition_.columns.astype(int), fr_transition_.index)
        fr_transition_tuple = {
            f"{col!r}, {row!r}": value
            for (col, row), value in zip(transition_pairs, fr_transition_.values.T.flatten())
        }

        # Per-class means, reindexed to the fixed class range -1..7.
        data_mean = data[['velocity', 'rate', 'prediction']].groupby('prediction').mean().reindex(range(-1,8))

        # prep of json file structure
        etho = {recording_id:{'count':summ_.duration_count.fillna(0).to_dict(),
                              'mean duration':summ_.duration_mean.to_dict(),
                              'rel time in': summ_.duration_relative.fillna(0).to_dict(),
                              'mean velocity': data_mean.velocity.to_dict(),
                              'mean rate': data_mean.rate.to_dict(),
                              'mean transitions':fr_transition_tuple,
                              'mean prediction probability': mean_probas,
                              'ethogram':y.to_list()}}
        batch.update(etho)

    # Write the batch once per dataset instead of re-reading and re-writing the
    # growing file for every recording.
    jsnF = json.dumps(batch, indent = 4, cls=NanConverter)
    with open(JsonOut, "w") as outfile:
        outfile.write(jsnF)

100%|██████████| 360/360 [08:17<00:00,  1.38s/it]
100%|██████████| 40/40 [00:15<00:00,  2.60it/s]


In [139]:
# Scratch cell: repeats the zero -> NaN replacement that the batching loop already
# applies, acting in place on the `fr_transition_` frame left in the kernel by that loop.
fr_transition_[fr_transition_==0] = np.nan

In [140]:
# Display the transition table currently held in the kernel
# (presumably from the most recent iteration of the batching loop).
fr_transition_

Unnamed: 0,-1,0,1,2,3,4,5,6,7
-1,,,,,,,,,
0,,0.571429,,0.066667,,,,,
1,,,0.987654,0.044444,,,,,
2,,0.428571,0.006173,0.888889,,,0.333333,,
3,,,,,,,,,
4,,,,,,,,,
5,,,0.006173,,,,0.666667,,
6,,,,,,,,,
7,,,,,,,,,


In [125]:
# Inspection only: positions of the entries of `fr_transition_` that are exactly 0.
np.where(fr_transition_==0)

(array([0, 4, 5, 7, 8]), array([3, 3, 3, 3, 3]))