# Prediction of Foraging States of *P. pacificus*

This notebook will guide you through the prediction pipeline for foraging behaviours in *Pristionchus pacificus*.<br>
You will already need to have data that was extracted by PharaGlow.<br>

The individual steps of this pipeline are the following:
1. additional feature calculation
2. model and augmentation loading
3. data augmentation as defined by AugmentSelect file
4. prediction
5. visualisation

In [None]:
import os
import time
import tqdm
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch, FancyArrowPatch
import logging
import yaml
import json
import joblib
from sklearn.impute import SimpleImputer
from scipy.stats.contingency import crosstab
import networkx as nx
from matplotlib.lines import Line2D
import umap
import itertools

#home = os.path.expanduser("~")
sys.path.append(os.getcwd())
from functions.load_model import load_tolist
import functions.visualise as vis
import functions.process as proc
from functions.io import setup_logger, makedir
from functions import FeatureEngine
sys.path.append(os.path.expanduser('~'))
from PpaPy.processing.preprocess import addhistory, select_features

from numba import jit
# set invalid (division by zero error) to ignore
np.seterr(invalid='ignore')

In the following section, please specify where your input files are stored and where the results should be saved.

In [None]:
# Root folder of the PharaGlow output that should be predicted.
inpath = f"/gpfs/soma_fs/scratch/src/boeger/data_gueniz/"
#inpath = '/gpfs/soma_fs/gnb/gnb9201.bak/Mariannne Roca/MR_MS_pharaglow'
# True: `inpath` is a parent folder and its entries are selected via `inpath_pattern`;
# False: `inpath` itself is the only input folder.
inpath_with_subfolders = True
# An entry of `inpath` is selected if its name contains at least one of these substrings.
inpath_pattern = ['Exp1_WT_larvae', 'Exp1_WT_OP50'] #exp1
#inpath_pattern = ['Exp2_WT_larvae', 'Exp2_tph1_larvae', 'Exp2_cat2_larvae',]#
#inpath_pattern = ['Exp2_tbh1_larvae',  'Exp2_tdc1_larvae'] #exp2

# Root folder for all results; one subfolder per input folder is created below it.
# NOTE(review): makedir is a project helper; its return value is used as a path below - confirm it returns the created folder.
base_outpath = makedir('/gpfs/soma_fs/scratch/src/boeger/PpaPred_eren')
#base_outpath = makedir('/gpfs/soma_fs/scratch/src/boeger/data_roca')

In [None]:
# Time stamps of this run; `datestr` is used further below to name the log file.
date = time.strftime("%Y%m%d")
datestr = time.strftime("%Y%m%d-%HH%MM")  # e.g. 20240131-14H05M - the trailing "H" and "M" are literal characters
home = os.path.expanduser("~")
#bac_data/ larvae_data/ self_data/  nothing_data/  tbh1_OP50/  tdc1_OP50/   tbh1_larvae/  tdc1_larvae/ nhr40_OP50/ tph1_larvae/ 
#cat2_OP50/ cat2_larvae/ tph1_larvae/ tph1_OP50/
#ser3_larvae
# From here on `inpath` is always a list of input folders.
if inpath_with_subfolders:
    # Keep every entry of `inpath` whose name contains one of the patterns.
    # NOTE(review): os.listdir also returns plain files and gives no ordering guarantee.
    new_inpath = [os.path.join(inpath, sub) for sub in os.listdir(inpath) if any(pat in sub for pat in inpath_pattern)]
    inpath = new_inpath
else:
    inpath = [inpath]

# One output folder per input folder (<base_outpath>/<input folder name>) plus the path of its "_engine" subfolder.
# `out_predicted` starts as an empty list here; later cells rebind it to a single folder path.
outpath, out_engine, out_predicted = [],[],[]
for p in inpath:
    in_folder = os.path.basename(p)
    outpath.append(makedir(os.path.abspath(f"{base_outpath}/{in_folder}"))) # you can also use datestr to specify the outpath folder, like this makedir(os.path.abspath(f"{datestr}_PpaPrediction"))
    out_engine.append(os.path.join(outpath[-1], in_folder+'_engine'))
    #out_predicted.append(os.path.join(outpath[-1], in_folder+'_predicted'))

In [None]:
# Sanity check: show the folder shared by all selected input folders.
# NOTE(review): os.path.commonpath raises ValueError for an empty list, i.e. when no entry matched `inpath_pattern`.
os.path.commonpath(inpath)

In the following section, standard model parameters are set. Change those only if necessary.

In [None]:
# Read the pipeline settings. The context manager closes the file handle again,
# which a bare open() passed straight into safe_load never did.
with open("config.yml", "r") as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# Look-up tables for plotting, taken from the config file.
cluster_color = config['cluster_color']
cluster_group = config['cluster_group_man']
cluster_label = config['cluster_names']
# Legend text "<cluster name>, <group>", keyed by the cluster name. Names and groups
# are paired by position, so both config mappings have to be listed in the same order.
clu_group_label = {
    name: f'{name}, {group}'
    for name, group in zip(cluster_label.values(), cluster_group.values())
}
skip_already = config['settings']['skip_already']

In [None]:
model_path = config['settings']['model']
# Model version = second "_"-separated token of the model file name, without its extension
# (e.g. "model_v3.pkl" -> "v3"); raises IndexError if the file name contains no underscore.
version = os.path.basename(model_path).split("_")[1].split(".")[0]
ASpath = config['settings']['ASpath']  # path of the augmentation/selection object loaded with joblib below
smooth = config['settings']['fbfill']  # passed to proc.ffill_bfill when the predictions are smoothed
fps = config['settings']['fps']
# Bookkeeping lists; none of the cells in this notebook appends to them.
engine_done = []
prediction_done = []

# One log file per run, written to the root output folder.
logger_out = os.path.join(base_outpath,f"{datestr}_PpaForagingPrediction.log")
logger = setup_logger('logger',filename=logger_out)
logger.info(f"Foraging prediction of Pristionchus pacificus")
logger.info(f"Version of model == {version}, stored at {model_path}\n")
log_inpath = '\n'.join(inpath)
logger.info(f"Files to be predicted stored at:\n{log_inpath}")

## 1. Feature Engineering
In the following section, additional features are calculated.<br>
The engineered data files are saved under the specified outpath/subfolder.<br>
(with subfolder being the inpath folder name postfixed by _engine)

In [None]:
# Run the feature engineering for every input folder; results go to the matching "_engine" folders.
# `XYs` and `CLines` are looked up by file name in the trajectory plots further below.
# NOTE(review): skip_already is hard-coded to False here, so the value read from config.yml is not used for this step.
XYs, CLines  = FeatureEngine.run(inpath, out_engine, logger, return_XYCLine =True, skip_engine = False, skip_already=False)

## 2. Load Model and Augmentation
Here, only the model and augmentation files are loaded.

In [None]:
import pickle
# NOTE(review): joblib.load unpickles arbitrary Python objects - only load model and
# augmentation files that come from a trusted source.
# Passing the path lets joblib open and close the file itself; the handle created by
# open(model_path, 'rb') was never closed.
model = joblib.load(model_path)
augsel = joblib.load(ASpath)
# Mean imputer that replaces NaNs in the feature matrix before prediction.
imp = SimpleImputer(missing_values=np.nan, strategy='mean')

In [None]:
# Inspect the loaded augmentation/selection object.
augsel

In [None]:
# Inspect the "_engine" output folders, one per input folder.
out_engine

In [None]:
# Every file that sits in a folder below base_outpath whose folder name contains
# "engine" and at least one of the selected experiment patterns.
all_engine = [
    os.path.join(root, name)
    for root, _dirs, files in os.walk(base_outpath)
    for name in files
    if 'engine' in os.path.basename(root)
    and any(pat in os.path.basename(root) for pat in inpath_pattern)
]
all_engine

## 3. Prediction

In [None]:
# NOTE(review): this hard-coded value overrides config['settings']['skip_already'].
skip_already = True
for fpath in tqdm.tqdm(all_engine):
    fn = os.path.basename(fpath)
    dir_engine = os.path.dirname(fpath)
    # ".../<name>_engine" -> ".../<name>_predicted"
    out_predicted = makedir(dir_engine[:-len('engine')]+'predicted')
    out_fn = fn.replace('features', 'predicted')

    # Do not predict again if the output file already exists.
    if skip_already and out_fn in os.listdir(out_predicted):
        continue
    # Ignore hidden files and anything that is not a regular file.
    if not fn[0] == '.' and not out_fn in prediction_done and os.path.isfile(fpath):
        d = load_tolist(fpath, droplabelcol=False)[0]
        X = augsel.fit_transform(d)
        col = X.columns
        X = imp.fit_transform(X)

        pred = model.predict(X)
        proba = model.predict_proba(X)
        pred_smooth = proc.ffill_bfill(pred, smooth)
        # The replacement value has to be passed as `nan=`: the second positional
        # argument of np.nan_to_num is `copy`, so nan_to_num(x, -1) turned NaNs into 0
        # (a regular cluster id) instead of the "unknown" label -1.
        pred_smooth = np.nan_to_num(pred_smooth, nan=-1)
        # Mark frames as unknown (-1) where the rolling mean (30 frames) of the
        # highest class probability stays below 0.5.
        proba_max = np.amax(proba, axis=1)
        proba_max_mean = pd.DataFrame(proba_max).rolling(30, min_periods=1).mean().values
        proba_low = np.all(proba_max_mean < .5, axis=1)
        pred_smooth[proba_low] = -1

        # Input data + smoothed prediction + one probability column per class.
        # `proba` is reused here instead of calling predict_proba a second time.
        p_out = pd.concat([d,
                           pd.DataFrame(pred_smooth, columns=['prediction']),
                           pd.DataFrame(proba, columns=[f'proba_{i}' for i in range(proba.shape[1])])],
                          axis=1)

        jsnL = json.loads(p_out.to_json(orient="split"))
        jsnF = json.dumps(jsnL, indent = 4)
        outpath_p = os.path.join(out_predicted,out_fn)
        with open(outpath_p, "w") as outfile:
            outfile.write(jsnF)

        # Label 8 is treated as unexpected: report the file and stop the loop.
        if 8 in pred_smooth:
            print(f'WARNING! unexpected: {fn}')
            break

### Test

In [None]:
# Show the installed networkx version.
nx.__version__

## 4. Prediction
The augmented + predicted data files are saved under the specified outpath/subfolder.<br>
(with subfolder being the inpath folder name postfixed by _predicted)<br>

In the _predicted subfolder, plots of the predicted bouts over time, along with the velocity and pumping rate, are saved as PDF files.

In [None]:
def transition_plotter(transition_toother, cluster_color, transition_self=None, figsize=(8,6), mut_scale=40, node_size=4000, 
                    other_connectionstyle = "arc3,rad=.15", self_connectionstyle="arc3,rad=0.5", node_alpha = 1, exclude_label = [], clu_group_label=None):
    """Draw a transition matrix as a circular graph with weighted arrows.

    Parameters
    ----------
    transition_toother : 2-D array-like
        Square transition matrix. Its transpose is used as adjacency matrix, so the
        edge i -> j is weighted with ``transition_toother[j, i]``. The input is copied
        and never modified.
    cluster_color : dict
        Colours per cluster. Node colours are taken in dict order (minus
        ``exclude_label``); the colour of an edge leaving node ``c`` is
        ``cluster_color[c-1]``.
    transition_self : 1-D sequence, optional
        Self-transition weight per node. If None, the diagonal of
        ``transition_toother`` is used and the diagonal is zeroed before the graph is
        built. If given, the diagonal of ``transition_toother`` is left as passed in.
    figsize : tuple
        Figure size handed to ``plt.subplots``.
    mut_scale : number
        Factor that converts a weight into an arrow size.
    node_size : number
        Node size handed to networkx.
    other_connectionstyle, self_connectionstyle : str
        Matplotlib connection styles for arrows between nodes / self arrows.
    node_alpha : float or sequence
        Node transparency handed to ``nx.draw_networkx_nodes``.
    exclude_label : sequence
        Keys that are skipped when node colours and node labels are collected.
    clu_group_label : dict, optional
        Node label texts (its values, in dict order); node indices are used if None.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Work on a private copy so the caller's matrix keeps its diagonal.
    transition_toother = np.array(transition_toother)
    if transition_self is None:
        transition_self = transition_toother.diagonal().copy()
        np.fill_diagonal(transition_toother, 0)

    adjacency = np.nan_to_num(np.around(transition_toother.T,3))
    # from_numpy_matrix was removed in networkx 3.0; from_numpy_array builds the same
    # weighted digraph (zero entries produce no edge).
    G = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
    # The circular layout depends only on the node set, so compute it once.
    pos = nx.circular_layout(G)

    weights = list(nx.get_edge_attributes(G,'weight').values())
    arr_out = [source for source, _target in G.edges()]

    color_map = [cluster_color[k] for k in cluster_color if k not in exclude_label]
    # NOTE(review): node c is mapped to cluster key c-1 here, i.e. the keys are
    # presumably consecutive and start at -1 - confirm against the config.
    edge_color = [cluster_color[c-1] for c in arr_out]

    fig, ax = plt.subplots(1, figsize=figsize)
    arrowsize = [w*mut_scale for w in weights]

    if clu_group_label is None:
        labels = dict(zip(range(len(G)),range(len(G))))
    else:
        labels = dict(zip(range(len(G)),  [clu_group_label[k] for k in clu_group_label if k not in exclude_label]))

    nx.draw_networkx_labels(G, pos=pos, ax=ax, labels=labels)
    nx.draw_networkx_nodes(G, pos=pos, ax=ax, node_color = color_map, node_size= node_size, margins=0.1,
                           alpha= node_alpha,
                           edgecolors=color_map)
    nx.draw_networkx_edges(G, pos=pos, ax=ax,
                           arrowsize =arrowsize, connectionstyle=other_connectionstyle, arrowstyle="simple",
                           label=weights, node_size=node_size, edge_color=edge_color)

    ### to self
    # `> 0` is already False for NaN, so no separate NaN test is needed
    # (the former `!= np.nan` comparison was always True).
    edgelist = [i for i in G.nodes() if transition_self[i] > 0]
    selfweights = {i:transition_self[i] for i in edgelist}

    for i in edgelist:
        cor = np.round(pos[i],2)
        # Angle of the node position, measured with arctan2(x, y).
        rad = np.arctan2(*cor)
        rad_s, rad_t = rad-.15, rad+.15
        # Anchor the self arrow 0.2 further out than the node itself.
        vl = np.linalg.norm(cor)+.2
        xy_t = [vl*np.sin(rad_s),vl*np.cos(rad_s)]
        xy_s = [vl*np.sin(rad_t),vl*np.cos(rad_t)]
        # Only the 1st, 3rd and 4th point returned by the helper are used.
        (pt_a, _, pt_c, pt_d) =  vis.SemiCirc_coordinates(xy_s, xy_t, r=0.2)
        # The self arrow is drawn as two curved arrows: pt_a -> pt_d and pt_d -> pt_c.
        arrow0 = FancyArrowPatch(posA=pt_a, posB=pt_d, connectionstyle=self_connectionstyle, arrowstyle="simple", mutation_scale= selfweights[i]*mut_scale, color=color_map[i])
        arrow1 = FancyArrowPatch(posA=pt_d, posB=pt_c, connectionstyle=self_connectionstyle, arrowstyle="simple", mutation_scale= selfweights[i]*mut_scale, color=color_map[i])
        ax.add_artist(arrow0)
        ax.add_artist(arrow1)

    # Size legend: five grey reference arrows for weights 20 % ... 100 %.
    for arr_s in np.linspace(0.2,1,5):
        arrow = FancyArrowPatch((1.6, arr_s), (1.9, arr_s), mutation_scale=arr_s*mut_scale, label = arr_s, color='k', alpha=0.5)
        ax.text(1.95, arr_s-0.03, f"{int(arr_s*100)}%")
        ax.add_patch(arrow)
    ax.set_xlim(-2,2)
    ax.set_ylim(-1.5,1.5)
    ax.axis('off')
    return fig

def ethogram_plotter(d, y, onoff,  smooth, cluster_color, figsize=(20,5), fps=30,xtick_spread=30, d_toplot=['velocity', 'rate'], d_bar_alpha =0.3):
    """Plot the predicted bouts over time together with smoothed data traces.

    The first panel shows the bouts as coloured bars; one further panel per entry of
    ``d_toplot`` shows that column of ``d`` (rolling mean over one second, i.e.
    ``fps`` frames) on top of the same bars.

    Reads the notebook globals ``fn`` (figure title), ``cluster_label`` and
    ``clu_group_label`` (legend texts).

    Parameters
    ----------
    d : pandas.DataFrame
        Per-frame data; must contain the columns listed in ``d_toplot``.
    y : array-like
        Predicted label per frame.
    onoff : dict
        Maps each label (as int) to a list of ``(start, width)`` tuples in frames,
        as accepted by ``Axes.broken_barh``.
    smooth : number
        Smoothing window of the prediction in frames; only used for the panel title.
    cluster_color : dict
        Colour per label.
    figsize : tuple
        Figure size.
    fps : int
        Frames per second; used for the tick labels and the 1-second rolling mean.
    xtick_spread : int
        Distance between major x ticks in seconds.
    d_toplot : list of str
        Columns of ``d`` to draw, one panel each.
    d_bar_alpha : float
        Transparency of the bout bars behind the data traces.

    Returns
    -------
    matplotlib.figure.Figure
    """
    timeinsec = np.array(range(len(d)))/fps
    present = np.unique(y)
    clusters = present.astype(int)
    tick_pos = range(len(timeinsec))[::xtick_spread*fps]
    tick_lab = timeinsec[::xtick_spread*fps].astype(int)

    # One bout panel plus one panel per requested column (was fixed to 3 panels).
    fig, axs = plt.subplots(len(d_toplot)+1, 1, figsize=figsize, constrained_layout=True, squeeze=False)
    axs = axs[:, 0]

    for c in clusters:
        axs[0].broken_barh(onoff[c],(0,1),facecolors = cluster_color[c])
    axs[0].set_xticks(tick_pos)
    axs[0].set_xticklabels(tick_lab)
    axs[0].set_title(f'Cluster prediction (smoothed {smooth/fps} sec).')
    axs[0].xaxis.set_minor_locator(plt.MultipleLocator(5*fps))

    for ax, col in zip(axs[1:], d_toplot):
        lo, hi = np.nanmin(d[col]), np.nanmax(d[col])
        for c in clusters:
            # Bars span the full value range of the trace.
            ax.broken_barh(onoff[c],(lo,hi-lo),facecolors = cluster_color[c], alpha=d_bar_alpha, zorder=0)
        # Window of `fps` frames = the one second announced in the panel title.
        ax.plot(d[col].rolling(fps, min_periods=0).mean(),c='k')
        ax.set_xticks(tick_pos)
        ax.set_xticklabels(tick_lab)
        ax.set_title(f"{col} (smoothed, 1 sec)")
        ax.xaxis.set_minor_locator(plt.MultipleLocator(5*fps))
    axs[-1].set_xlabel('sec')

    # Legend below the last panel (the axes plt.legend addressed implicitly before).
    axs[-1].legend(handles=[Patch(facecolor=cluster_color[i]) for i in clusters],
                   labels=[clu_group_label[k] for k in cluster_label if k in present],
                   ncol=3, loc='upper left',
                   bbox_to_anchor=(0, -0.5))
    fig.suptitle(f'Ethogram of {fn}',fontsize=16)
    return fig
    
def CLtrajectory_plotter(CLine, XY, y, cluster_color, cluster_label, figsize=(10,10)):
    """Plot the centre lines along the trajectory, coloured by predicted label.

    Reads the notebook global ``fn`` for the axes title.

    Parameters
    ----------
    CLine : ndarray
        Centre-line coordinates; indexed as (frame, point, coordinate) below.
    XY : ndarray
        Position per frame; indexed as (frame, coordinate) below.
    y : array-like
        Predicted label per frame.
    cluster_color : dict
        Colour per label.
    cluster_label : dict
        Legend text per label; all of its keys appear in the legend.
    figsize : tuple
        Figure size (previously ignored in favour of a hard-coded (10, 10)).

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=figsize)
    legend_elements = [Line2D([0], [0],color=cluster_color[i], label=cluster_label [i]) for i in cluster_label]
    # Centre the centre lines on their overall mean and shift each frame by its
    # mean-centred position; XY is broadcast over the points of each centre line.
    adjustCL = (CLine-np.nanmean(CLine))+XY[:, None, :]-np.nanmean(XY, axis=0)# fits better than subtracting 50
    for l in np.unique(y).astype(int):
        il = np.where(y == l)[0]
        ax.plot(*adjustCL[il].T, c=cluster_color[l], alpha = 0.1)
    ax.set_title(fn)
    ax.axis('equal')
    ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1,1))
    return fig

In [None]:
class NpIntEncoder(json.JSONEncoder):
    """JSON encoder that additionally accepts NumPy integer scalars.

    Instances of ``np.integer`` are written as plain Python ints; every other object
    the standard encoder cannot handle is passed on to ``json.JSONEncoder.default``,
    which raises ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serialisable replacement for ``obj``."""
        return int(obj) if isinstance(obj, np.integer) else super().default(obj)

In [None]:
# Switches for the individual outputs that are written per predicted file.
ethograms = True
summaries = True
transitions = True
trajectories = True

# Every file whose name contains "predicted.json" and that sits in a folder below base_outpath
# whose folder name contains "predicted" and at least one of the selected experiment patterns.
all_predicted = [os.path.join(root, name) for root, dirs, files in os.walk(base_outpath) for name in files if 'predicted' in os.path.basename(root) and any(pat in os.path.basename(root) for pat in inpath_pattern) and 'predicted.json' in name]
#for fn in tqdm.tqdm(os.listdir(out_predicted)):
#    if fn[-13:] == 'predicted.csv' or fn[-14:] == 'predicted.json':

In [None]:
# Number of prediction files that will be visualised.
len(all_predicted)

In [None]:
for fpath in tqdm.tqdm(all_predicted):

    fn = os.path.basename(fpath)
    # Prefix for all output files of this recording ("..._predicted.json" -> "..._").
    fn_out = fn.replace('predicted.json','')
    out_predicted = os.path.dirname(fpath)

    d = load_tolist(os.path.join(out_predicted,fn), droplabelcol=False)[0]
    y_ps = d['prediction'].values
    d['prediction'].to_csv(os.path.join(out_predicted, fn_out+'prediction.csv'), index=False)

    if ethograms:
        onoff = proc.onoff_dict(y_ps, labels =np.unique(y_ps))
        # Plain int keys for the JSON file and the colour look-up in the plot.
        onoff = {int(k):v for k,v in onoff.items()}
        with open(os.path.join(out_predicted, fn_out+'_onoff.json'), "w") as onoff_out:
            json.dump(onoff,onoff_out,cls=NpIntEncoder)
        ethogram_plot = ethogram_plotter(d, y_ps, onoff,  smooth, cluster_color)
        ethogram_plot.savefig(os.path.join(out_predicted, fn_out+'_predictedbouts.pdf'))
        plt.show()
        # Release the figure; otherwise one open figure per file accumulates in memory.
        plt.close(ethogram_plot)

    if summaries:
        idx = pd.IndexSlice
        onoff, dur, transi = proc.onoff_dict(y_ps, labels =np.unique(y_ps), return_duration=True, return_transitions=True)
        # mean/std/count of every data column per predicted cluster
        data_describe = d.groupby(y_ps).describe().T.loc[idx[:, ['mean','std','count']], :].sort_index(level=0).T
        # mean/std/count of the bout durations per cluster ...
        dur_describe = pd.DataFrame(dur, columns=['duration']).groupby(transi).describe().T.loc[idx[:, ['mean','std','count']], :].sort_index(level=0).T
        # ... plus the summed bout duration per cluster divided by the number of frames
        dur_describe['duration','relative'] = pd.DataFrame(dur, columns=['duration']).groupby(transi).apply(lambda cd: cd.sum()/len(d))
        summary = pd.concat([dur_describe, data_describe], axis=1)
        summary.index.name = 'cluster'
        # Flatten the two-level column index to "<column>_<statistic>".
        summary = summary.T.reset_index(drop=True).set_index(summary.T.index.map('_'.join)).T
        summary = summary.set_index(summary.index.astype(int))
        # One row per configured cluster (except -1), in config order.
        summary = summary.reindex([k for k in cluster_label if k != -1])
        summary.to_csv(os.path.join(out_predicted, fn_out+'summary.csv'))

    if transitions:
        # Most frequent label in consecutive, non-overlapping blocks of 30 frames.
        y_ps_transition = pd.DataFrame(y_ps).rolling(30).apply(lambda s: s.mode()[0])[29::30].values.flatten()

        # Counts with rows = label of the following block, columns = label of the preceding block.
        trans_col,fr_transition = crosstab(y_ps_transition[1:],y_ps_transition[:-1],
                                           levels=([k for k in cluster_label],[k for k in cluster_label])
                                          )
        # Normalise every column by its sum; columns without counts become NaN.
        transition_all = fr_transition/fr_transition.sum(axis=0)
        transition_merged = pd.DataFrame(transition_all, columns = trans_col[1], index=trans_col[0])#.fillna(0) #should not fill nan with 0!
        transition_merged.to_csv(os.path.join(out_predicted, fn_out+'transitions.csv'))

        #### TRANSITION PLOT (currently disabled)
        #transition_plot = transition_plotter(transition_all, cluster_color, node_alpha=summary['duration_relative'].fillna(0).tolist())
        #plt.savefig(os.path.join(out_predicted,fn_out+'clustertransitions.pdf'))

    if trajectories:
        # XYs / CLines come from the feature-engineering cell and are keyed by the
        # name of the corresponding "*.json_labeldata.csv" file.
        XY = XYs[fn.replace('_predicted.json','.json_labeldata.csv')]
        CLine = CLines[fn.replace('_predicted.json','.json_labeldata.csv')]

        CLtrajectory_plot = CLtrajectory_plotter(CLine, XY, y_ps, cluster_color, cluster_label, figsize=(10,10),)
        CLtrajectory_plot.savefig(os.path.join(out_predicted, fn_out+'CLtrajectory.pdf'))
        plt.show()
        # Release the figure; otherwise one open figure per file accumulates in memory.
        plt.close(CLtrajectory_plot)

In [None]:
# Scratch cell: rebuild the bout-duration statistics of the file processed last in the loop above
# (relies on `dur`, `transi` and `idx` left over from that loop).
dur_describe = pd.DataFrame(dur, columns=['duration']).groupby(transi).describe().T.loc[idx[:, ['mean','std','count']], :].sort_index(level=0).T
dur_describe

In [None]:
# Add, per cluster, the summed bout duration divided by the number of rows of `d`
# (`d` is the recording left over from the loop above).
dur_describe['duration','relative'] = pd.DataFrame(dur, columns=['duration']).groupby(transi).apply(lambda cd: cd.sum()/len(d))
dur_describe

In [None]:
# Sum of the relative durations over all clusters.
# NOTE(review): presumably expected to be ~1 if the bouts cover the whole recording - confirm.
np.sum(dur_describe['duration','relative'])

In [None]:
import matplotlib.pyplot as plt
import json
# Toy example of the bout format: label -> list of (start, width) tuples.
# This rebinds `onoff`, which the loop above also uses.
onoff = {1: [(0, 99), (200,99)],
         2: [(100,99)]}

# Write the toy example to a JSON file in the current working directory.
with open("sample.json", "w") as outfile: 
    json.dump(onoff, outfile)

In [None]:
# Read the sample file back. The context manager closes the handle even if parsing
# fails, which the manual open()/close() pair did not guarantee.
with open('sample.json') as f:
    data = json.load(f)
data

In [None]:
# Draw the toy bouts as horizontal bars, one colour per label.
color = {1:'blue',2:'red'}
for c in [1,2]:
    plt.broken_barh(onoff[c],(0,1),facecolors = color[c])

In [None]:
# Map file name -> path for every prediction file in `out_predicted`
# (`out_predicted` is left over from the loop above, i.e. the folder processed last).
loc = {}
for fn in os.listdir(out_predicted):
    if "predicted.json" in fn:
        loc[fn]= os.path.join(out_predicted, fn)

# NOTE(review): load_tolist is a project helper; it is given a dict here and its result is iterated as DataFrames below - confirm.
data_batch = load_tolist(loc, droplabelcol=False)
# All recordings of the batch stacked row-wise.
data_batch_concat = pd.concat([d for d in data_batch], axis=0)

In [None]:
# Predicted labels: pooled over the whole batch, and as one Series per recording.
y_batch_concat = data_batch_concat['prediction']
y_batch = [d['prediction'] for d in data_batch]
# Rebinds the global `fn`, which ethogram_plotter and CLtrajectory_plotter read for their titles.
fn = out_predicted+'_batch'

In [None]:
def onoff_dict(arr_raw, labels = range(-1,4), return_duration=False, return_transitions = False, return_all = False, treatasone=True):
    """Split one or several label sequences into bouts (runs of identical labels).

    Parameters
    ----------
    arr_raw : 1-D array / pandas Series, or list of those
        One label per frame. A list is always read as several recordings.
    labels : iterable
        Unused; kept for backward compatibility.
    return_duration, return_transitions, return_all : bool
        Select what is returned, see below.
    treatasone : bool
        If True, the recordings are treated as one long recording: bout starts are
        shifted by the total length of all preceding recordings and the per-recording
        results are concatenated. If False, lists with one array per recording are
        returned and nothing is shifted.

    Returns
    -------
    arr_onoff : dict
        label -> list of ``(start, width)`` tuples in frames (the format expected by
        ``Axes.broken_barh``).
    arr_dur : bout lengths in frames (with ``return_duration`` / ``return_all``)
    arr_transi : label of every bout (with ``return_transitions`` / ``return_all``)
    arr_onset, arr_onnext : first frame of every bout and first frame after it
        (only with ``return_all``)

    Fixes compared to the previous version: a final bout of length 1 no longer gets
    the label of the bout before it, the recording offset is added to bout starts
    only (not to bout widths), offsets are cumulative for more than two recordings,
    and sequences of length 0 or 1 no longer raise IndexError.
    """
    if not isinstance(arr_raw, list):
        arr_raw = [arr_raw]

    arr_onoff = {}
    arr_transi, arr_onset, arr_onnext, arr_dur = [], [], [], []
    offset = 0  # number of frames in all preceding recordings (stays 0 unless treatasone)
    for a in arr_raw:
        if isinstance(a, (pd.Series, pd.DataFrame)):
            a = a.values
        a = np.ravel(a)
        n = len(a)

        if n:
            # Index of the last frame of every bout except the final one.
            change = np.flatnonzero(a[:-1] != a[1:])
            # Label of each bout; the final bout is labelled by the last frame itself.
            transi = np.append(a[change], a[-1])
            onset = np.concatenate([[0], change + 1])
            onnext = np.append(onset[1:], n)
        else:
            transi = a[:0]
            onset = np.array([], dtype=int)
            onnext = np.array([], dtype=int)
        dur = onnext - onset

        for b in np.unique(a):
            b_idx = np.where(transi == b)
            # Only the start is shifted; the width of a bout is independent of the offset.
            b_onoff = list(zip(onset[b_idx] + offset, dur[b_idx]))
            arr_onoff[b] = arr_onoff.get(b, []) + b_onoff

        arr_transi.append(transi)
        arr_onset.append(onset + offset)
        arr_onnext.append(onnext + offset)
        arr_dur.append(dur)

        if treatasone:
            offset += n

    if treatasone:
        def _cat(parts):
            # np.concatenate refuses an empty list of arrays.
            return np.concatenate(parts) if parts else np.array([], dtype=int)
        arr_dur = _cat(arr_dur)
        arr_transi = _cat(arr_transi)
        arr_onset = _cat(arr_onset)
        arr_onnext = _cat(arr_onnext)

    if return_all:
        return arr_onoff, arr_dur, arr_transi, arr_onset, arr_onnext
    if return_duration and return_transitions:
        return arr_onoff, arr_dur, arr_transi
    if return_duration:
        return arr_onoff, arr_dur
    if return_transitions:
        return arr_onoff, arr_transi
    return arr_onoff

In [None]:
idx = pd.IndexSlice
# Bouts over the whole batch (recordings are treated as one long recording).
onoff, dur, transi = onoff_dict(y_batch, labels =np.unique(y_batch_concat), return_duration=True, return_transitions=True)
# mean/std/count of every data column per predicted cluster
data_describe = data_batch_concat.groupby(y_batch_concat).describe().T.loc[idx[:, ['mean','std','count']], :].sort_index(level=0).T
# mean/std/count of the bout durations per cluster, plus the share of all frames spent in each cluster
dur_describe = pd.DataFrame(dur, columns=['duration']).groupby(transi).describe().T.loc[idx[:, ['mean','std','count']], :].sort_index(level=0).T
dur_describe['duration','relative'] = pd.DataFrame(dur, columns=['duration']).groupby(transi).apply(lambda cd: cd.sum()/len(y_batch_concat))
summary = pd.concat([dur_describe, data_describe], axis=1)
summary.index.name = 'cluster'
# Flatten the two-level column index to "<column>_<statistic>".
summary = summary.T.reset_index(drop=True).set_index(summary.T.index.map('_'.join)).T
summary = summary.set_index(summary.index.astype(int))
# `outpath` is a list of folders, so os.path.basename(outpath) raised a TypeError.
# The batch is named after the folder that contains the "_predicted" directory instead.
batch_name = os.path.basename(os.path.dirname(out_predicted))
# NOTE(review): index=False drops the cluster ids from the file, unlike the per-file summary - confirm this is intended.
summary.to_csv(os.path.join(out_predicted, batch_name+'_batch_summary.csv'), index=False)

In [None]:
# Inspect the label of every bout in the pooled batch.
transi

In [None]:
# Sum the block-to-block transition counts over all recordings of the batch.
fr_transition = None
trans_col = None
for i,d in enumerate(data_batch):
    # Most frequent label in consecutive, non-overlapping blocks of 30 frames.
    frame = d['prediction'].rolling(30).apply(lambda s: s.mode()[0])[29::30].values.flatten()
    # Counts with rows = label of the following block, columns = label of the preceding block (-1 is left out).
    trans_col_,fr_transition_ = crosstab(frame[1:], frame[:-1], levels=([k for k in cluster_label if k != -1],[k for k in cluster_label if k != -1]))
    # TODO: alternatively read the per-file transitions.csv; it would have to contain counts
    # (not normalised values) to be summable. The former pd.read_csv(d) call received a
    # DataFrame instead of a path and could not work.
    if fr_transition is None:
        # The first recording initialises the sum; it was previously added to itself as well.
        fr_transition = fr_transition_.copy()
        trans_col = trans_col_
    elif all(np.array_equal(new, ref) for new, ref in zip(trans_col_, trans_col)):
        # Compare the level arrays element-wise; `==` on tuples of arrays is ambiguous.
        fr_transition += fr_transition_
    else:
        print('WARNING')
    #fr_transition/fr_transition.sum(axis=0)

#othersum_axis0 = fr_transition.sum(axis=0)-fr_transition.diagonal()
#transition_toother = fr_transition/othersum_axis0
#transition_self = fr_transition.diagonal()/(othersum_axis0+fr_transition.diagonal())
#np.fill_diagonal(transition_toother, 0)

In [None]:
# Normalise every column of the summed counts by its column sum; columns without counts become NaN.
transition_all = fr_transition/fr_transition.sum(axis=0)
#diag_idx = np.diag_indices(len(transition_toother))
#transition_all[diag_idx] = transition_self
#transition_all = pd.DataFrame(transition_all, columns = trans_col[1], index=trans_col[0])
transition_csv = pd.DataFrame(transition_all, columns = trans_col[1], index=trans_col[0]).fillna(0)
# `outpath` is a list of folders, so os.path.basename(outpath) raised a TypeError.
# The batch is named after the folder that contains the "_predicted" directory instead.
batch_name = os.path.basename(os.path.dirname(out_predicted))
# NOTE(review): index=False drops the row labels from the file, unlike the per-file transitions - confirm this is intended.
transition_csv.to_csv(os.path.join(out_predicted, batch_name+'_batch_transitions.csv'), index=False)

In [None]:
# Transition graph of the batch; node transparency = share of frames spent in each cluster.
# NOTE(review): node_alpha needs one value per node; dur_describe only has rows for clusters that occur - confirm the lengths match.
transition_plot = transition_plotter(transition_all, cluster_color, node_alpha=dur_describe['duration','relative'].tolist())
# NOTE(review): `in_folder` is left over from the folder set-up loop (last input folder).
plt.text(1.5, -1, f'{in_folder}\nN = {len(data_batch)}', fontsize=12)
# `outpath` is a list of folders, so os.path.basename(outpath) raised a TypeError.
# The batch is named after the folder that contains the "_predicted" directory instead.
batch_name = os.path.basename(os.path.dirname(out_predicted))
plt.savefig(os.path.join(out_predicted, batch_name+'_batch_transitions.pdf'))
plt.show()

In [None]:
# Heat map of the column-normalised batch transition counts.
plt.imshow(fr_transition/fr_transition.sum(axis=0))

In [None]:
import numpy as np

class MarkovChain:
    """Simple discrete Markov chain that samples state sequences."""

    def __init__(self, transition_matrix, states):
        """
        Set up the chain.

        Parameters
        ----------
        transition_matrix: 2-D array
            Row ``i`` holds the probabilities of moving from state ``i`` to
            every state; it is passed to ``np.random.choice`` as ``p``.

        states: 1-D sequence of hashable values
            The states of the chain, in the same order as the rows and
            columns of ``transition_matrix``.
        """
        self.transition_matrix = np.atleast_2d(transition_matrix)
        self.states = states
        # state -> row/column index, and the reverse look-up
        self.index_dict = {state: pos for pos, state in enumerate(self.states)}
        self.state_dict = dict(enumerate(self.states))

    def next_state(self, current_state):
        """
        Draw the state that follows ``current_state``.

        Parameters
        ----------
        current_state: hashable
            One of ``states``; selects the row of the transition matrix
            that is used as probability vector.
        """
        row = self.transition_matrix[self.index_dict[current_state], :]
        return np.random.choice(self.states, p=row)

    def generate_states(self, current_state, no=10):
        """
        Draw a sequence of consecutive future states.

        Parameters
        ----------
        current_state: hashable
            The state the sequence starts from (not part of the result).

        no: int
            The number of future states to generate.

        Returns
        -------
        numpy array with ``no`` states.
        """
        trajectory = []
        state = current_state
        for _ in range(no):
            state = self.next_state(state)
            trajectory.append(state)
        return np.array(trajectory)

In [None]:
# Column-normalised batch transition counts; NaNs (columns whose sum is zero) are replaced by 0.
transition_matrix = np.nan_to_num(fr_transition/fr_transition.sum(axis=0))
np.round(transition_matrix,3)

In [None]:
# Simulate a state sequence from the batch transition matrix.
# The matrix is transposed so that each row holds the probabilities of leaving one state,
# which is the layout MarkovChain.next_state indexes.
# NOTE(review): the state list is hard-coded to six states and has to match the size of the matrix - confirm.
# NOTE(review): np.random is not seeded, so the simulated sequence differs between runs.
markov_chain = MarkovChain(transition_matrix=transition_matrix.T, 
                           states=[0,1,2,3,4,5])

predicted = (markov_chain.generate_states(current_state=4, no=200))
plt.plot(predicted)

In [None]:
def ethogram_only_plotter(y, onoff, cluster_color, fn, figsize=(20,2), fps=30,xtick_spread=100,):
    """Plot only the bout bars of a label sequence (no data traces).

    Reads the notebook globals ``cluster_label`` and ``clu_group_label`` for the
    legend texts.

    Parameters
    ----------
    y : array-like
        Label per frame.
    onoff : dict
        Maps each label (as int) to a list of ``(start, width)`` tuples in frames,
        as accepted by ``Axes.broken_barh``.
    cluster_color : dict
        Colour per label.
    fn : str
        Name shown in the figure title.
    figsize : tuple
        Figure size.
    fps : int
        Frames per second, used for the tick labels.
    xtick_spread : int
        Distance between x ticks in seconds.

    Returns
    -------
    matplotlib.figure.Figure
    """
    timeinsec = np.array(range(len(y)))/fps
    present = np.unique(y)
    clusters = present.astype(int)

    fig, ax = plt.subplots(1, figsize=figsize,constrained_layout=True)

    for c in clusters:
        ax.broken_barh(onoff[c],(0,1),facecolors = cluster_color[c])
    ax.set_xticks(range(len(timeinsec))[::xtick_spread*fps])
    ax.set_xticklabels(timeinsec[::xtick_spread*fps].astype(int))
    # Title typo ("preditcion") corrected.
    ax.set_title('Cluster prediction')
    ax.set_xlabel('sec')

    # Explicit axes instead of the implicit "current axes" of plt.legend.
    ax.legend(handles=[Patch(facecolor=cluster_color[i]) for i in clusters],
              labels=[clu_group_label[k] for k in cluster_label if k in present],
              ncol=3, loc='upper left',
              bbox_to_anchor=(0, -0.5))
    fig.suptitle(f'Ethogram of {fn}',fontsize=16)
    return fig

In [None]:
# Bouts of the simulated sequence, drawn as an ethogram (fps=1: one simulated step per plotted "second").
# NOTE(review): `y_ps` is left over from the per-file loop above; whether proc.onoff_dict uses `labels` is not visible here.
onoff_predicted = proc.onoff_dict(predicted, labels =np.unique(y_ps))
ethogram_predicted = ethogram_only_plotter(predicted, onoff_predicted, cluster_color, fn = 'nhr40 predicted', fps =1, xtick_spread=25)