In [None]:
%load_ext autoreload
%autoreload 2
%precision 2

In [None]:
# IPython automagic: shows the kernel's current working directory
pwd

In [None]:
import numpy as np
import os,sys; sys.path.append(os.path.expandvars('$CODE_MEMORY_ERRORS'))
from models import Shadmehr_model

# Shadmehr params

# simulate simple markov chain
# NOTE: `simulate` below does not read this module-level variable; it takes
# `change_prob` as its own parameter (default 0.9). Inside `simulate` a
# Binomial draw of 1 KEEPS the current state, so this value is effectively the
# probability of staying in the same state on the next trial.
#change_prob = 0.9  # z in Shadmehr
change_prob = 0.1  # z in Shadmehr
#np.random.seed(5)

def simulate(change_prob=0.9, seed=None, verbose=1, ntrials=100):
    """Simulate a two-state Markov feedback sequence and run the Shadmehr model on it.

    Parameters
    ----------
    change_prob : float
        Per-transition probability that the state is KEPT (a Binomial draw of
        1 means "stay", 0 means "flip"). Despite its name this is a stay
        probability; the notebook comments call it "z in Shadmehr".
    seed : int or None
        If not None, reseeds numpy's global RNG before drawing.
    verbose : int or bool
        If truthy, call ``smodel.print_basic()`` once before processing and
        again after every trial. If falsy, this function prints nothing itself.
    ntrials : int
        Number of trials; must be >= 1 (``ntrials - 1`` transitions are drawn).

    Returns
    -------
    smodel : Shadmehr_model
        The model after ``update`` has been called once per trial.
    feedback_recorded : numpy.ndarray
        Float array of length ``ntrials`` holding 0.0/1.0 states, starting at 0.
    """
    if seed is not None:
        np.random.seed(seed)
    # switches[i] == 1 -> keep the state from trial i to i+1; 0 -> flip it.
    switches = np.random.binomial(1, change_prob, ntrials - 1)
    start_state = 0
    # The state on trial k is start_state XOR (number of flips among the first
    # k transitions), i.e. the parity of the cumulative flip count.
    flips = 1 - switches
    feedback_recorded = np.concatenate(
        ([start_state], (start_state + np.cumsum(flips)) % 2)
    ).astype(float)

    # weights_def is not set in Shadmehr; it needed manual tuning.
    # beta in Shadmehr is 0.05, but it is clearly too large here.
    smodel = Shadmehr_model(ntrials, min_err=-5, max_err=5, weights_def=None,
                            sigma=1., err_sens_def=0.001, beta=0.001,
                            decay=1)
    # Only report the initial state when the caller asked for output.
    if verbose:
        smodel.print_basic()
    for ti in range(ntrials):
        # TODO: update movement somehow
        smodel.update(feedback_recorded[ti])
        if verbose:
            smodel.print_basic()

    return smodel, feedback_recorded
    

def plot(feedback_recorded, smodel, ttladd='', axs=None):
    """Draw five diagnostic panels for a single Shadmehr-model simulation.

    Panels (top to bottom):
      0. perturbation estimate (``smodel.pert_est``) vs. recorded feedback
      1. errors and the per-trial maximum basis weight
      2. every basis weight over trials
      3. error sensitivity (``smodel.err_sens``)
      4. error sensitivity per basis element via ``smodel.err2err_sens``

    Parameters
    ----------
    feedback_recorded : array-like
        Feedback sequence that was fed to the model (as returned by ``simulate``).
    smodel : Shadmehr_model
        Processed model; only trials ``0 .. smodel.prev_trial_ind`` are plotted.
    ttladd : str
        Suffix appended to the first panel's title.
    axs : sequence of 5 matplotlib Axes, optional
        Axes to draw into. If None (default), a new 5x1 figure is created.
    """
    n_done = smodel.prev_trial_ind + 1  # number of trials the model processed
    if axs is None:
        import matplotlib.pyplot as plt
        nrows, ncols = 5, 1
        ww, hh = 10, 2  # per-panel width / height
        _fig, axs = plt.subplots(nrows, ncols, figsize=(ncols * ww, nrows * hh))

    ax = axs[0]
    pert_ests = smodel.pert_est[:n_done]
    # Raw string: '\h' is not a valid Python escape sequence.
    ax.plot(np.arange(len(pert_ests)), pert_ests, label=r'$\hat{x}$')
    ax.plot(np.arange(len(feedback_recorded)), feedback_recorded,
            alpha=0.6, label='$x$')
    ax.set_ylim(-0.1, 1.1)
    ax.set_title('Pert estimation vs perturbation' + ttladd)
    ax.legend()

    ax = axs[1]
    errors = smodel.errors[:n_done]
    ax.plot(np.arange(len(errors)), errors, label='errors')
    wms = np.max(smodel.weights[:n_done], axis=1)
    ax.plot(np.arange(len(errors)), wms, label='max weights')
    ax.set_title('Errors')
    ax.legend()

    ax = axs[2]
    for wi, basis_el in enumerate(smodel.basis_err):
        ws = smodel.weights[:n_done, wi]
        ax.plot(ws, label=f'{basis_el:.2f}')
    ax.legend(loc=(1, 0))
    ax.set_title('All weights')

    ax = axs[3]
    es = smodel.err_sens[:n_done]
    ax.plot(es, label='es')
    ax.legend()
    ax.set_title('Err sens')

    ax = axs[4]
    ws_all = smodel.weights[:n_done]
    for basis_el in smodel.basis_err:
        es = [smodel.err2err_sens(w_row, basis_el) for w_row in ws_all]
        # x-axis spans the processed trials (a global trial count is not
        # available inside this function).
        ax.plot(np.arange(len(es)), es, label=f'{basis_el:.2f}')
    ax.legend(loc=(1, 0))
    ax.set_title('Err sens per basis el')
    
    
def plotMeanES(smodels, ttladd='', ax=None):
    """Plot the across-model mean error sensitivity with SEM error bars.

    Parameters
    ----------
    smodels : sequence of Shadmehr_model
        Models whose ``err_sens[:prev_trial_ind + 1]`` traces are averaged.
        The traces are stacked row-wise, so all models must have processed
        the same number of trials.
    ttladd : str
        Suffix appended to the axes title.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to ``plt.gca()``.

    Raises
    ------
    ValueError
        If ``smodels`` is empty.
    """
    # Imported here so the function does not depend on a later notebook cell
    # having executed ``import scipy.stats``.
    import scipy.stats

    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()

    traces = [m.err_sens[:m.prev_trial_ind + 1] for m in smodels]
    if not traces:
        raise ValueError('plotMeanES needs at least one model')
    ess = np.vstack(traces)  # one row per model
    es = ess.mean(0)
    # Standard error of the mean across models (scipy default ddof=1).
    ess_errbar = scipy.stats.sem(ess, 0)
    ax.errorbar(np.arange(len(es)), es, yerr=ess_errbar, label='es')
    ax.legend()
    ax.set_title(f'Err sens mean {ttladd}')

    ax.axhline(y=0, ls=':', c='r')

    # TODO: error sensitivity drops too abruptly after an unexpected outcome.
    # Presumably: the more certain the model gets, the larger the weights,
    # and therefore the more abrupt the change.

In [None]:
import scipy.stats

In [None]:
# One example run per environment, each drawn as a 5-panel figure:
# stay probability 0.9 (slowly switching) and 0.1 (rapidly switching).
smodel, feedback_recorded = simulate(verbose=0, change_prob=0.9)
plot(feedback_recorded, smodel)

smodel, feedback_recorded = simulate(verbose=0, change_prob=0.1)
plot(feedback_recorded, smodel)

In [None]:
# `plt` is otherwise only imported inside plot()/plotMeanES(), so this cell
# would raise NameError on a fresh kernel without this import.
import matplotlib.pyplot as plt

# Mean error sensitivity over N simulated runs, one panel per environment.
# In `simulate`, `change_prob` is the probability that the state is kept
# from one trial to the next.
change_probs = [0.9, 0.1]

np.random.seed(2)  # simulate() draws from numpy's global RNG when seed=None
N = 16  # number of simulated runs averaged per environment

nc, nr = 2, 1  # subplot grid: rows, columns
ww, hh = 5, 3  # per-panel width / height
fig, axs = plt.subplots(nc, nr, figsize=(nr * ww, nc * hh))

for axi, change_prob in enumerate(change_probs):
    smodels = []
    for nt in range(N):
        smodel, feedback_recorded = simulate(change_prob=change_prob, verbose=0)
        smodels.append(smodel)

    plotMeanES(smodels, ttladd=f' zeta={change_prob:.1f}', ax=axs[axi])
plt.tight_layout()

In [None]:
# `__file__` is not defined inside a plain Jupyter kernel, so the bare name
# raises NameError and breaks Restart & Run All; look it up defensively
# (evaluates to None, which displays nothing, when undefined).
globals().get('__file__')

In [None]:
from datetime import datetime  as dt 


In [None]:
# Debugging aid: prints the built-in help for re.match (not used by the analysis)
import re; help(re.match)

In [None]:
# Display the most recently assigned `feedback_recorded`; when cells run in
# order this is the last simulation of the zeta loop above.
feedback_recorded

In [None]:
# import pymc as pm

# # Assume 10 trials and 5 successes out of those trials
# # Change these numbers to see how the posterior plot changes
# trials = 10; successes = 5

# # Set up model context
# with pm.Model() as coin_flip_model:
#     # Probability p of success we want to estimate
#     # and assign Beta prior
#     p = pm.Beta("p", alpha=1, beta=1)
    
#     # Define likelihood
#     obs = pm.Binomial("obs", p=p, n=trials,
#         observed=successes,
#     )

#     # Hit Inference Button
#     idata = pm.sample()

In [None]:
# Sanity run of the model on a constant feedback sequence.
ntrials = 10
#feedback_recorded = np.random.uniform(size=ntrials)
feedback_recorded = np.full(ntrials, 10.0)

# Symmetric error range around zero: the largest |feedback| widened by 10%.
dif = np.max(np.abs(feedback_recorded))
mn = -dif - dif * 0.1
mx = dif + dif * 0.1
smodel = Shadmehr_model(ntrials, min_err=mn, max_err=mx)
#smodel = Shadmehr_model(ntrials)

# processing: report the model state before and after every update
smodel.print_basic()
for ti in range(ntrials):
    # TODO: update movement somehow
    smodel.update(feedback_recorded[ti])
    smodel.print_basic()

### Q: what should we simulate?
### Q: how did Shadmehr supply _perturbations_ to his model (instead of just feedback)?
### Q: what are the correct initial conditions?