In [1]:
import pandas as pd
import numpy as np
from embo import InformationBottleneck as IB
import matplotlib.pyplot as plt
import os,ndd,pickle

from IPython.core.interactiveshell import InteractiveShell
# Echo the value of every top-level expression statement in a cell,
# not only the last one (IPython's default is "last_expr").
InteractiveShell.ast_node_interactivity = "all"

In [2]:
def get_windowed_xs(x,z,w=1):
    """Encode sliding windows of x as integers and align them with z.

    Window number i covers x[i:i+w] and is paired with z[i+w]. Each window
    is collapsed to a single integer by weighting its k-th element with
    2**k (the first element of the window is the least significant).
    The codes are only unique per window if x is binary -- assumed here,
    not checked.

    x : sequence exposing .copy() (e.g. ndarray or pandas Series)
    z : sequence supporting slicing and .copy(); same length as x
    w : window length (default 1)

    Returns a tuple (codes, z_aligned): an int array with one code per
    window, and a copy of z[w:].
    """
    observations = np.array(x.copy())
    z_aligned = z[w:].copy()
    bit_weights = 2**np.arange(w)

    # One code per retained element of z.
    codes = []
    for start in range(len(z_aligned)):
        window = observations[start:start+w]
        codes.append(np.inner(window, bit_weights))

    return np.array(codes).astype(int), z_aligned

def mutual_inf_nsb(x, y, ks):
    """Mutual information between x and y using the NSB method, in bits.

    x, y : equal-length sample sequences; they become the two columns of
           the array handed to ndd.
    ks   : forwarded unchanged as the second argument of
           ndd.mutual_information (callers in this file pass one alphabet
           size per variable).
    """
    samples = np.column_stack((x, y))
    mi_nats = ndd.mutual_information(samples, ks)
    # ndd returns nats; log2(e) is the nats -> bits conversion factor.
    nats_to_bits = np.log2(np.e)
    return nats_to_bits*mi_nats

def get_i_past_future(x, y, z, w):
    """Compute i_past and i_future from empirical data.

    x : observations
    y : ground truth
    z : responses (choices)
    w : window size used for i_past

    i_past is the NSB mutual information between the w-long windowed
    observations and the responses aligned with them; i_future is the NSB
    mutual information between the full response and ground-truth series.
    The alphabet sizes passed along (2**w and 2) treat every series as
    binary.

    Returns the tuple (i_past, i_future).
    """
    windowed_obs, aligned_resp = get_windowed_xs(x, z, w)

    past_ks = [2**w,2]
    i_past = mutual_inf_nsb(windowed_obs, aligned_resp, past_ks)

    future_ks = [2,2]
    i_future = mutual_inf_nsb(z, y, future_ks)

    return i_past, i_future

In [3]:
# Load the per-trial table from a CSV in the working directory and display it.
# NOTE(review): the displayed frame carries 'Unnamed: 0' / 'Unnamed: 0.1'
# columns -- presumably row indices written out by earlier to_csv calls;
# TODO confirm and drop them at the source if so.
adat_trials = pd.read_csv("adat_trials_test.csv")
adat_trials

Unnamed: 0.1,Unnamed: 0,Subject,Trial,Jar,Bead,Response,Confidence,Correct,History
0,1000,5d6f,10,0,0,0,0,1,0
1,1001,5d6f,11,0,0,0,1,1,0
2,1002,5d6f,12,0,0,0,1,1,0
3,1003,5d6f,13,0,0,0,1,1,0
4,1004,5d6f,14,0,1,0,1,0,0
...,...,...,...,...,...,...,...,...,...
995,1995,5d6f,1005,1,1,1,0,1,101111111
996,1996,5d6f,1006,1,0,1,0,0,1011111111
997,1997,5d6f,1007,1,0,1,1,0,111111110
998,1998,5d6f,1008,1,0,1,0,0,1111111100


In [4]:
# Per-trial series used by the analyses below.
beads = adat_trials['Bead']
choices = adat_trials['Response']
urns = adat_trials['Jar']

# Responses advanced by one trial: choice1 at position t holds choices at
# position t+1, and the final position is padded with 0 so choice1 stays the
# same length as beads. shift() is used rather than assigning to a
# hard-coded label such as 1000: that only appends when the label does not
# already exist, so on a longer table it would overwrite a real response and
# leave the series one element short.
choice1 = choices.shift(-1, fill_value=0)

In [5]:
# Window of 6 past beads. Suffix 1: values from get_i_past_future (NSB);
# suffix 2: values from embo's InformationBottleneck.get_saturation_point().
w = 6
ip1, if1 = get_i_past_future(beads, urns, choices, w)

ip2 = IB(
    beads, choice1,
    window_size_x=w, window_size_y=1,
).get_saturation_point()
if2 = IB(
    choices, urns,
    window_size_x=1, window_size_y=1,
).get_saturation_point()

print(f'ipast:   1={ip1:.3f}, 2={ip2:.3f}')
print(f'ifuture: 1={if1:.3f}, 2={if2:.3f}')

ipast:   1=0.557, 2=0.610
ifuture: 1=0.331, 2=0.334


In [7]:
# Window of 7 past beads. Suffix 1: values from get_i_past_future (NSB);
# suffix 2: values from embo's InformationBottleneck.get_saturation_point().
# NOTE(review): the recorded output of this cell shows nan for ip2; the cause
# is not visible from this notebook -- TODO investigate before relying on it.
w = 7
ip1, if1 = get_i_past_future(beads, urns, choices, w)

ip2 = IB(
    beads, choice1,
    window_size_x=w, window_size_y=1,
).get_saturation_point()
if2 = IB(
    choices, urns,
    window_size_x=1, window_size_y=1,
).get_saturation_point()

print(f'ipast:   1={ip1:.3f}, 2={ip2:.3f}')
print(f'ifuture: 1={if1:.3f}, 2={if2:.3f}')

ipast:   1=0.598, 2=nan
ifuture: 1=0.331, 2=0.334
