In [1]:
import numpy as np
from scipy.stats import norm


What is the relationship between `get_trial_likelihood` in `ddm.py` and `get_model_log_likelihood` in `ddm_mla.py`?

In [6]:
decay = 0.002  # decay = 0 means barriers are constant.
bias = 0
barrier = 1
RT = 2000
timeStep = 10 #ms
approxStateStep = .1
bias = 0
nonDecisionTime = 0
d = .002
sigma = .02
valueLeft = .7
valueRight = .5

In [5]:
numTimeSteps = RT // timeStep
numTimeSteps

200

In [7]:
barrierUp = barrier * np.ones(numTimeSteps)
barrierDown = -barrier * np.ones(numTimeSteps)
for t in range(1, numTimeSteps):
    barrierUp[t] = barrier / (1 + (decay * t))
    barrierDown[t] = -barrier / (1 + (decay * t))
barrierUp

array([1.        , 0.99800399, 0.99601594, 0.99403579, 0.99206349,
       0.99009901, 0.98814229, 0.98619329, 0.98425197, 0.98231827,
       0.98039216, 0.97847358, 0.9765625 , 0.97465887, 0.97276265,
       0.97087379, 0.96899225, 0.96711799, 0.96525097, 0.96339114,
       0.96153846, 0.9596929 , 0.95785441, 0.95602294, 0.95419847,
       0.95238095, 0.95057034, 0.9487666 , 0.9469697 , 0.94517958,
       0.94339623, 0.94161959, 0.93984962, 0.9380863 , 0.93632959,
       0.93457944, 0.93283582, 0.9310987 , 0.92936803, 0.92764378,
       0.92592593, 0.92421442, 0.92250923, 0.92081031, 0.91911765,
       0.91743119, 0.91575092, 0.91407678, 0.91240876, 0.91074681,
       0.90909091, 0.90744102, 0.9057971 , 0.90415913, 0.90252708,
       0.9009009 , 0.89928058, 0.89766607, 0.89605735, 0.89445438,
       0.89285714, 0.8912656 , 0.88967972, 0.88809947, 0.88652482,
       0.88495575, 0.88339223, 0.88183422, 0.88028169, 0.87873462,
       0.87719298, 0.87565674, 0.87412587, 0.87260035, 0.87108

In [3]:
halfNumStateBins = np.ceil(1 / .1)
halfNumStateBins

10.0

In [8]:
stateStep = 1 / (halfNumStateBins + 0.5)
stateStep

0.09523809523809523

In [9]:
# The vertical axis is divided into states.
states = np.arange(-1 + (stateStep / 2),
                   1 - (stateStep / 2) + stateStep,
                   stateStep)
states

array([-0.95238095, -0.85714286, -0.76190476, -0.66666667, -0.57142857,
       -0.47619048, -0.38095238, -0.28571429, -0.19047619, -0.0952381 ,
        0.        ,  0.0952381 ,  0.19047619,  0.28571429,  0.38095238,
        0.47619048,  0.57142857,  0.66666667,  0.76190476,  0.85714286,
        0.95238095])

In [10]:
biasState = np.argmin(np.absolute(states - bias))

biasState

10

In [12]:
# Initial probability for all states is zero, except the bias state,
# for which the initial probability is one.
prStates = np.zeros((states.size, numTimeSteps))
prStates[biasState,0] = 1

# The probability of crossing each barrier over the time of the trial.
probUpCrossing = np.zeros(numTimeSteps)
probDownCrossing = np.zeros(numTimeSteps)

prStates[biasState]

array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [14]:
changeMatrix = np.subtract(states.reshape(states.size, 1), states)
changeUp = np.subtract(barrierUp, states.reshape(states.size, 1))
changeDown = np.subtract(barrierDown, states.reshape(states.size, 1))
changeMatrix[5]

array([ 0.47619048,  0.38095238,  0.28571429,  0.19047619,  0.0952381 ,
        0.        , -0.0952381 , -0.19047619, -0.28571429, -0.38095238,
       -0.47619048, -0.57142857, -0.66666667, -0.76190476, -0.85714286,
       -0.95238095, -1.04761905, -1.14285714, -1.23809524, -1.33333333,
       -1.42857143])

In [18]:
changeDown[0]

array([-0.04761905, -0.04562304, -0.04363498, -0.04165483, -0.03968254,
       -0.03771806, -0.03576134, -0.03381234, -0.03187102, -0.02993732,
       -0.0280112 , -0.02609263, -0.02418155, -0.02227792, -0.02038169,
       -0.01849283, -0.0166113 , -0.01473704, -0.01287001, -0.01101018,
       -0.00915751, -0.00731195, -0.00547345, -0.00364199, -0.00181752,
        0.        ,  0.00181061,  0.00361435,  0.00541126,  0.00720137,
        0.00898473,  0.01076137,  0.01253133,  0.01429465,  0.01605136,
        0.01780151,  0.01954513,  0.02128226,  0.02301292,  0.02473717,
        0.02645503,  0.02816653,  0.02987173,  0.03157064,  0.03326331,
        0.03494976,  0.03663004,  0.03830417,  0.03997219,  0.04163414,
        0.04329004,  0.04493994,  0.04658385,  0.04822182,  0.04985388,
        0.05148005,  0.05310038,  0.05471488,  0.0563236 ,  0.05792657,
        0.05952381,  0.06111536,  0.06270124,  0.06428149,  0.06585613,
        0.0674252 ,  0.06898873,  0.07054674,  0.07209926,  0.07

In [20]:
mean = .0004
sigma = .02
norm.pdf(changeMatrix, mean, sigma)[1]

array([2.61324948e-004, 1.99431250e+001, 2.16002170e-004, 3.32028667e-019,
       7.24345246e-044, 2.24268465e-078, 9.85469919e-123, 6.14569634e-177,
       5.43940490e-241, 6.83257240e-315, 0.00000000e+000, 0.00000000e+000,
       0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
       0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
       0.00000000e+000])

In [23]:
mean = .0004
sigma = .02
time= 1

prStatesNew = (stateStep *
               np.dot(norm.pdf(changeMatrix, mean, sigma),
               prStates[:,time-1]))
prStatesNew[(states >= barrierUp[time]) |
            (states <= barrierDown[time])] = 0

prStatesNew

array([0.00000000e+000, 0.00000000e+000, 6.50721179e-316, 5.18038562e-242,
       5.85304413e-178, 9.38542780e-124, 2.13589014e-079, 6.89852615e-045,
       3.16217778e-020, 2.05716352e-005, 1.89934524e+000, 2.48880903e-005,
       4.62840975e-020, 1.22158753e-044, 4.57583068e-079, 2.43258411e-123,
       1.83534769e-177, 1.96526626e-241, 2.98659966e-315, 0.00000000e+000,
       0.00000000e+000])

In [24]:

tempUpCross = np.dot(prStates[:,time-1], (1 - norm.cdf(changeUp[:, time], mean, sigma)))
tempDownCross = np.dot(prStates[:,time-1],norm.cdf(changeDown[:, time], mean, sigma))

tempUpCross

0.0

In [25]:
tempDownCross

0.0

In [27]:
# Renormalize to cope with numerical approximations.
sumIn = np.sum(prStates[:,time-1])
sumCurrent = np.sum(prStatesNew) + tempUpCross + tempDownCross
prStatesNew = prStatesNew * sumIn / sumCurrent
tempUpCross = tempUpCross * sumIn / sumCurrent
tempDownCross = tempDownCross * sumIn / sumCurrent

In [28]:
sumIn

1.0

In [29]:
sumCurrent

1.89939069745728

In [30]:
prStatesNew

array([0.00000000e+000, 0.00000000e+000, 3.42594699e-316, 2.72739338e-242,
       3.08153775e-178, 4.94128344e-124, 1.12451332e-079, 3.63196796e-045,
       1.66483798e-020, 1.08306497e-005, 9.99976066e-001, 1.31031969e-005,
       2.43678657e-020, 6.43147053e-045, 2.40910450e-079, 1.28071813e-123,
       9.66282341e-178, 1.03468247e-241, 1.57239880e-315, 0.00000000e+000,
       0.00000000e+000])

In [31]:
tempUpCross

0.0

In [32]:
tempDownCross

0.0

In [34]:
probDownCrossing[-1] = 2

In [35]:
probDownCrossing

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2.])