### Sliding window and testing

In [1]:
from numpy.random import choice, rand, randn
import numpy as np
import lea  # probability calculations, see https://pypi.org/project/lea/
from sklearn import linear_model
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from dataclasses import dataclass, field
from typing import Optional
import matplotlib.pyplot as plt
import copy
import contourpy as cp
import pandas as pd
import random

In [3]:

from multisensory_playbook import DetectionTask, DetectionTask_Etrans_test, Trials, LinearClassifier
#from multisensory_playbook import detection_params_search_test, create_sliding_window_features


nb_steps = 1000
pm = 0.2 # smaller doesnt mean sparser signals from prey !!!!!!!!
pe = 0.6 # for first value of E
pn = 0.2 # smaller = less noisy (more zeros)
pc = 0.6 #0.3
pi = 0.3 #0.6 
repeats = 5




[[211  63 112  82 242  53 117  53  67]
 [ 10  88   8  79 638  90  15  62  10]
 [ 10  78   6  82 657  77   6  74  10]
 [  8  90  10  90 624  77  15  80   6]
 [  8  67  11  86 642  89  17  72   8]] 9
[[410. 358. 231.   0.   0.   0.   0.   0.   0.]
 [104. 787. 108.   0.   0.   0.   0.   0.   0.]
 [ 98. 808.  93.   0.   0.   0.   0.   0.   0.]
 [113. 793.  93.   0.   0.   0.   0.   0.   0.]
 [111. 781. 107.   0.   0.   0.   0.   0.   0.]] 9


### Borrow code from multisensory_playbook for debugging Trial.counts()

In [793]:
@dataclass
class Task:
    @property
    def random_variables(self):
        return NotImplemented

    def generate_trials(self, repeats, time_steps, random_seed=None):
        if random_seed is not None:
            random.seed(random_seed)  # Set the random seed if provided
        # random variables
        rv = self.random_variables
        M = rv["M"]
        A = rv["A"]
        V = rv["V"]
        # cache calculated joint distribution
        joint_dists = {}
        for m in [-1, 0, 1]:
            if lea.P(M == m) == 0:
                continue
            joint_dists[m] = lea.joint(A, V).given(M == m).calc()
        # generate true target values
        arr_M = np.array(M.random(repeats))
        steps = np.array(
            [joint_dists[m].random(time_steps) for m in arr_M]
        )  # steps has shape (repeats, time_steps, 2)
        if time_steps == 0:
            # print(steps.shape)
            return Trials(
                repeats=repeats,
                time_steps=time_steps,
                task=self,
                M=arr_M,
                A=steps[:, None],
                V=steps[:, None],
            )
        else:
            return Trials(
                repeats=repeats,
                time_steps=time_steps,
                task=self,
                M=arr_M,
                A=steps[:, :, 0],
                V=steps[:, :, 1],
            )

    @property
    def baseline(self):
        if not hasattr(self, "_baseline"):
            M = self.random_variables["M"]
            self._baseline = max([lea.P(M == m) for m in [-1, 0, 1]])
        return self._baseline

    def baseline_reward(self, reward):
        M = self.random_variables["M"]
        probs = np.array([lea.P(M == m) for m in [-1, 0, 1]])
        expected_rewards = np.einsum("m,mg->g", probs, reward)
        return np.max(expected_rewards)


@dataclass
class DetectionTask(Task):
    pm: float
    pe: float
    pn: float
    pc: float
    pi: float

    @property
    def random_variables(self, random_seed=None):
        if hasattr(self, "_random_vars"):
            return self._random_vars
        if random_seed is not None:
            random.seed(random_seed)  # Set the random seed if provided
    
        target = lea.pmf({-1: self.pm * 0.5, 1: self.pm * 0.5, 0: 1 - self.pm})
        emit_if_target = lea.event(self.pe)
        emit_if_no_target = lea.event(0.0)
        emit = target.switch(
            {-1: emit_if_target, 1: emit_if_target, 0: emit_if_no_target}
        )
        signal_dist = {
            (-1, True): lea.pmf({-1: self.pc, +1: self.pi, 0: 1 - self.pc - self.pi}),
            (+1, True): lea.pmf({+1: self.pc, -1: self.pi, 0: 1 - self.pc - self.pi}),
            (0, True): lea.pmf({-1: 0, +1: 0, 0: 1.0}),  # cannot happen
            (-1, False): lea.pmf({-1: self.pn * 0.5, 1: self.pn * 0.5, 0: 1 - self.pn}),
            (0, False): lea.pmf({-1: self.pn * 0.5, 1: self.pn * 0.5, 0: 1 - self.pn}),
            (+1, False): lea.pmf({-1: self.pn * 0.5, 1: self.pn * 0.5, 0: 1 - self.pn}),
        }
        signal = lea.joint(target, emit).switch(signal_dist)
        signal_A, signal_V = signal.clone(n=2, shared=(target, emit))
        self._random_vars = {"M": target, "E": emit, "A": signal_A, "V": signal_V}
        return self._random_vars
    
@dataclass
class Trials:
    repeats: int
    time_steps: int
    M: np.ndarray
    A: np.ndarray
    V: np.ndarray
    task: Task

    """
    # Has been re-implemented to incorporate sliding window functionality
    def counts(self, pairs=True):
        A = self.A
        V = self.V
        if self.time_steps == 0:
            return np.zeros((self.repeats, 6 + 3 * pairs))
        if pairs:
            AV = (A + 1) + 3 * (V + 1)  # shape (repeats, time_steps)
            C = np.apply_along_axis(np.bincount, 1, AV, minlength=9)  # (repeats, 9)
        else:
            CA = np.apply_along_axis(np.bincount, 1, A + 1, minlength=3)  # (repeats, 3)
            CV = np.apply_along_axis(np.bincount, 1, V + 1, minlength=3)  # (repeats, 3)
            C = np.concatenate((CA, CV), axis=1)
        return C
    """
    # Re-implementation of counts to incorporate sliding window functionality 
    def counts(self, windowsize=3, pairs=1): # pairs = 2 implemented!
        A = self.A
        V = self.V
        #print(type(A))

        def calculate_state(draw_sequence):
            # Mapping for the states to digits
            state_to_digit = {-1: 0, 0: 1, 1: 2} # To change -1 to 1
            
            # Convert the draw sequence to a base-3 number
            base_3_number = 0
            for draw in draw_sequence:
                base_3_number = base_3_number * 3 + state_to_digit[draw]
            #print(draw_sequence, base_3_number)

            # The state is the base-3 number
            return base_3_number

        def apply_state(row):
            # Convert row to list and pass it to the calculate_state function
            #print(row.tolist())
            return calculate_state(row.tolist())
        
        if self.time_steps == 0:
            return np.zeros((self.repeats, 6 + 3 * pairs))
        
        if pairs == 0:
            CA = np.apply_along_axis(np.bincount, 1, A + 1, minlength=3)  # (repeats, 3)
            CV = np.apply_along_axis(np.bincount, 1, V + 1, minlength=3)  # (repeats, 3)
            C = np.concatenate((CA, CV), axis=1)

        elif pairs == 1:
            AV = (A + 1) + 3 * (V + 1)  # shape (repeats, time_steps)
            C = np.apply_along_axis(np.bincount, 1, AV, minlength=9)  # (repeats, 9)     

        elif pairs == 2: # consider windows with n number of consecutive AV-pairs
            max_state = 3**(2*windowsize) # 3**(2n) 
            C = np.zeros((self.repeats, max_state))
            for trialnum in range(self.repeats):
                #print(trialnum)
                _A = A[trialnum]
                _V = V[trialnum]
                #print(A.shape)
                df = pd.DataFrame()
                df['A'], df['V'] = _A, _V
                
                if windowsize == 2:
                    df['A-1'], df['V-1'] =  df['A'].shift(1), df['V'].shift(1) # Shifting column down one step
                if windowsize == 3:
                    df['A-1'], df['V-1'] =  df['A'].shift(1), df['V'].shift(1) # Shifting column down one step
                    df['A-2'], df['V-2'] =  df['A'].shift(2), df['V'].shift(2) # Shifting column down one step (window size is 3)
                df = df.dropna()
                #print(df)
                #return df
                
                # Apply the function to each row and store the result in a new column 'state'
                df['state'] = df.apply(apply_state, axis=1)
                # Calculate value counts
                state_counts = df['state'].value_counts()
                

                # Generate a range of numbers representing all possible states
                # Adjust the range based on your specific needs (max_state + 1)
                
                all_possible_states = range(0, max_state)  # Replace max_state with your actual maximum state value

                # Reindex the value counts to include all possible states
                # Fill missing values (states with 0 occurrences) with 0
                state_counts = state_counts.reindex(all_possible_states, fill_value=0)
                #return state_counts
                
                #state_counts = state_counts.values.reshape(1,-1)
                C[trialnum,:] = state_counts
                
                #return state_counts

        return C

In [794]:
task = DetectionTask(pm=pm, pe=pe, pn=pn, pc=pc, pi=pi)
trials = task.generate_trials(time_steps=200, repeats=1)


In [807]:

C0 = (trials.counts(pairs=0)) 
C1 = (trials.counts(pairs=1))
C2 = (trials.counts(windowsize=3, pairs=2))


In [808]:
print(C1.shape)
print(C2.shape)

(1, 9)
(1, 729)


In [809]:
print(np.sum(C1[0]))
print(np.sum(C2))

200
198.0


In [744]:
C2.shape

(100, 9)

In [810]:
# Version without sliding window generator, try to add this functionality within Trials.counts()

# Expand the detection params search to include handling of sliding window features
def detection_params_search_test(p_ranges, nb_trials, nb_steps, tasktype='DetectionTask', trans_prob=None, window_size=2, random_seed=None):
    #classifier_type = LinearClassifier #MAPClassifier ### TEST
    #print('test')
    # Sample task parameters
    while True:
        
        p = {
            k: rand() * (upper - lower) + lower
            for k, (lower, upper) in p_ranges.items()
        }
        # Add debugging prints for parameter values
        #print("Sampled parameters:", p)

        if p["pc"] <= 0.5 * p["pn"]:
            continue
        if p["pi"] >= p["pc"]:
            continue
        if p["pi"] >= 0.5 * p["pn"]:
            continue
        if p["pi"] + p["pc"] <= p["pn"]:
            continue
        if p["pc"] + p["pi"] > 1.0:
            continue
        break
    #print(p)
    # Generate trials
    if tasktype == 'DetectionTask':
        task = DetectionTask(**p)
        full_trials = task.generate_trials(nb_trials, nb_steps, random_seed=random_seed)
        #print("Generated trials with DetectionTask", full_trials) ### TEST
        
    elif tasktype == 'DetectionTask_Etrans_test':
        p['nb_repeats'] = nb_trials
        p['nb_steps'] = nb_steps
        if trans_prob:
            p['trans_prob'] = trans_prob[0] 
        task = DetectionTask_Etrans_test(**p, random_seed=random_seed) 
        full_trials = task.generate_trials
        #print("Generated trials with DetectionTask_Etrans_test", full_trials) ### TEST
        # Reset p
        keys_to_remove = {'nb_repeats', 'nb_steps', 'trans_prob'}
        p = dict(filter(lambda item: item[0] not in keys_to_remove, p.items()))
    
    # Train-test trials :Generate test data separately
    training_size = nb_trials
    testing_size = nb_trials
    training_trials = Trials(
        repeats=training_size,
        time_steps=nb_steps,
        M=full_trials.M,
        A=full_trials.A,
        V=full_trials.V,
        task=task
    )

    testing_trials = Trials(
        repeats=testing_size,
        time_steps=nb_steps,
        M=full_trials.M,
        A=full_trials.A,
        V=full_trials.V,
        task=task
    )

   
    # Calculate accuracy
    accs_tmp = []

    for pairs in [0, 1, 2]:
        #print(pairs)
        # Check if there is only one class in training data. If yes, skip 
        unique_classes_train = np.unique(training_trials.M)
        unique_classes_test = np.unique(testing_trials.M)
        if len(unique_classes_train) == 1:
            #print(f"Skipping training and testing: Only one class ({unique_classes_train[0]}) in training data.")
            #print(f"Params: {p}")
            #accs_tmp.append([])  # Append None or a default value to indicate skipping
            print('train classes = 1')
            return [0.0, 0.0, 0.0], np.array(list(p.values())) * 0.0 # function exits here
            
        if len(unique_classes_test) == 1:
            #print(f"Skipping training and testing: Only one class ({unique_classes_test[0]}) in test data.")
            #print(f"Params: {p}")
            
            return [0.0, 0.0, 0.0], np.array(list(p.values())) * 0.0

        classifier = LinearClassifier(task, pairs=pairs)

        # Train and test the classifier using sliding window features
        trained_classifier = classifier.train(training_trials)
        res = trained_classifier.test(testing_trials)
        accs_tmp.append(res.accuracy)

    # Filter for accuracy
    _, a = np.unique(full_trials.M, return_counts=True)  # majority class classifier
    a = a.max() / a.sum()
    w = 1 - a
    c = (1 + a) / 2
    #print(c - w / 2 * 0.05, c + w / 2 * 0.05) ### TEST
    #return accs_tmp, np.array(list(p.values())) ### TEST: ignore the condition of return for testing 
    
    if (max(accs_tmp) > (c - w / 2 * 0.75)) & (min(accs_tmp) < (c + w / 2 * 0.75)): # 0.75
        return accs_tmp, np.array(list(p.values()))
    else:
        #print('else')
        return [0.0, 0.0, 0.0], np.array(list(p.values())) * 0.0 # three elements in returned list since now we have 3 values for pairs



In [811]:
p_ranges = {
    "pm": (0.0, 1.0),  # p of motion
    "pe": (0.0, 1.0),  # p of emitting given there is motion
    "pc": (0.0, 1.0),  # p correct direction when emitting
    "pn": (0.0, 1.0),  # p not neutral when not emitting
    "pi": (0.0, 0.5),  # p incorrect when emitting
}
p_labels = ["$p_m$", "$p_e$", "$p_c$", "$p_n$", "$p_i$"]

nb_trials = 100# original: 10000
nb_steps = 100 # original: 90
search_size = 10000 # original: 10000
# probability of transitioning from e_current to e_next
trans_prob = {
    0: [0.5, 0.5],  # Probabilities for transitioning from E=0
    1: [0.5, 0.5],  # Probabilities for transitioning from E=1
}
random_seed = 42

In [813]:


for i in range(1, 10):
    print(i, detection_params_search_test(p_ranges, nb_trials, nb_steps, 'DetectionTask', (trans_prob, ), random_seed==random_seed))

UnboundLocalError: cannot access local variable 'full_trials' where it is not associated with a value

In [452]:
accs

[0.0, 0.0, 0.0]

In [420]:
C2.shape

(1, 82)