In [27]:
import numpy as np
from scipy.stats import norm


What is the relationship between `get_trial_likelihood` in `ddm.py` and `get_model_log_likelihood` in `ddm_mla.py`?

In [76]:
decay = 0.002  # decay = 0 means barriers are constant.
bias = 0
barrier = 1
RT = 2000
timeStep = 10 #ms
approxStateStep = .1
bias = 0
nonDecisionTime = 0
d = .002
sigma = .02
valueLeft = .7
valueRight = .5

In [77]:
numTimeSteps = RT // timeStep
numTimeSteps

200

In [78]:
barrierUp = barrier * np.ones(numTimeSteps)
barrierDown = -barrier * np.ones(numTimeSteps)
for t in range(1, numTimeSteps):
    barrierUp[t] = barrier / (1 + (decay * t))
    barrierDown[t] = -barrier / (1 + (decay * t))
barrierUp

array([1.        , 0.99800399, 0.99601594, 0.99403579, 0.99206349,
       0.99009901, 0.98814229, 0.98619329, 0.98425197, 0.98231827,
       0.98039216, 0.97847358, 0.9765625 , 0.97465887, 0.97276265,
       0.97087379, 0.96899225, 0.96711799, 0.96525097, 0.96339114,
       0.96153846, 0.9596929 , 0.95785441, 0.95602294, 0.95419847,
       0.95238095, 0.95057034, 0.9487666 , 0.9469697 , 0.94517958,
       0.94339623, 0.94161959, 0.93984962, 0.9380863 , 0.93632959,
       0.93457944, 0.93283582, 0.9310987 , 0.92936803, 0.92764378,
       0.92592593, 0.92421442, 0.92250923, 0.92081031, 0.91911765,
       0.91743119, 0.91575092, 0.91407678, 0.91240876, 0.91074681,
       0.90909091, 0.90744102, 0.9057971 , 0.90415913, 0.90252708,
       0.9009009 , 0.89928058, 0.89766607, 0.89605735, 0.89445438,
       0.89285714, 0.8912656 , 0.88967972, 0.88809947, 0.88652482,
       0.88495575, 0.88339223, 0.88183422, 0.88028169, 0.87873462,
       0.87719298, 0.87565674, 0.87412587, 0.87260035, 0.87108

In [79]:
halfNumStateBins = np.ceil(1 / .1)
halfNumStateBins

10.0

In [80]:
stateStep = 1 / (halfNumStateBins + 0.5)
stateStep

0.09523809523809523

In [81]:
# The vertical axis is divided into states.
states = np.arange(-1 + (stateStep / 2),
                   1 - (stateStep / 2) + stateStep,
                   stateStep)
states

array([-0.95238095, -0.85714286, -0.76190476, -0.66666667, -0.57142857,
       -0.47619048, -0.38095238, -0.28571429, -0.19047619, -0.0952381 ,
        0.        ,  0.0952381 ,  0.19047619,  0.28571429,  0.38095238,
        0.47619048,  0.57142857,  0.66666667,  0.76190476,  0.85714286,
        0.95238095])

In [82]:
biasState = np.argmin(np.absolute(states - bias))

biasState

10

In [83]:
# Initial probability for all states is zero, except the bias state,
# for which the initial probability is one.
prStates = np.zeros((states.size, numTimeSteps))
prStates[biasState,0] = 1

# The probability of crossing each barrier over the time of the trial.
probUpCrossing = np.zeros(numTimeSteps)
probDownCrossing = np.zeros(numTimeSteps)

prStates[biasState]

array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [84]:
changeMatrix = np.subtract(states.reshape(states.size, 1), states)
changeUp = np.subtract(barrierUp, states.reshape(states.size, 1))
changeDown = np.subtract(barrierDown, states.reshape(states.size, 1))
changeMatrix[5]

array([ 0.47619048,  0.38095238,  0.28571429,  0.19047619,  0.0952381 ,
        0.        , -0.0952381 , -0.19047619, -0.28571429, -0.38095238,
       -0.47619048, -0.57142857, -0.66666667, -0.76190476, -0.85714286,
       -0.95238095, -1.04761905, -1.14285714, -1.23809524, -1.33333333,
       -1.42857143])

In [88]:
changeDown[5]

array([-0.52380952, -0.52181352, -0.51982546, -0.51784531, -0.51587302,
       -0.51390853, -0.51195182, -0.51000282, -0.50806149, -0.50612779,
       -0.50420168, -0.50228311, -0.50037202, -0.49846839, -0.49657217,
       -0.49468331, -0.49280177, -0.49092751, -0.48906049, -0.48720066,
       -0.48534799, -0.48350242, -0.48166393, -0.47983247, -0.478008  ,
       -0.47619048, -0.47437987, -0.47257613, -0.47077922, -0.46898911,
       -0.46720575, -0.46542911, -0.46365915, -0.46189583, -0.46013911,
       -0.45838896, -0.45664534, -0.45490822, -0.45317755, -0.45145331,
       -0.44973545, -0.44802394, -0.44631875, -0.44461984, -0.44292717,
       -0.44124072, -0.43956044, -0.43788631, -0.43621828, -0.43455634,
       -0.43290043, -0.43125054, -0.42960663, -0.42796866, -0.4263366 ,
       -0.42471042, -0.4230901 , -0.42147559, -0.41986687, -0.41826391,
       -0.41666667, -0.41507512, -0.41348924, -0.41190899, -0.41033435,
       -0.40876528, -0.40720175, -0.40564374, -0.40409121, -0.40

In [None]:
mean = .0004
sigma = .02
norm.pdf(changeMatrix, mean, sigma)[1]

In [None]:
mean = .0004
sigma = .02
time= 1

prStatesNew = (stateStep *
               np.dot(norm.pdf(changeMatrix, mean, sigma),
               prStates[:,time-1]))
prStatesNew[(states >= barrierUp[time]) |
            (states <= barrierDown[time])] = 0

prStatesNew

In [None]:

tempUpCross = np.dot(prStates[:,time-1], (1 - norm.cdf(changeUp[:, time], mean, sigma)))
tempDownCross = np.dot(prStates[:,time-1],norm.cdf(changeDown[:, time], mean, sigma))

tempUpCross

In [None]:
tempDownCross

In [None]:
# Renormalize to cope with numerical approximations.
sumIn = np.sum(prStates[:,time-1])
sumCurrent = np.sum(prStatesNew) + tempUpCross + tempDownCross
prStatesNew = prStatesNew * sumIn / sumCurrent
tempUpCross = tempUpCross * sumIn / sumCurrent
tempDownCross = tempDownCross * sumIn / sumCurrent

In [None]:
sumIn

In [None]:
sumCurrent

In [None]:
prStatesNew

In [None]:
tempUpCross

In [None]:
tempDownCross

In [None]:
probDownCrossing[-1] = 2

In [None]:
probDownCrossing

In [1]:
from ddm import DDMTrial, DDM

In [15]:
m1 = DDM(d = 0.02, sigma = 0.01)

In [16]:
m1_trial = m1.simulate_trial(valueLeft=.8, valueRight=.1)

In [17]:
m1_trial.choice

-1

In [18]:
m1_trial.RT

630

In [19]:
out = m1.get_trial_likelihood(m1_trial)
out['likelihood']

2.1662862883122818e-132

In [21]:
m2 = DDM(d = 0.1, sigma = 0.1)
out2 = m2.get_trial_likelihood(m1_trial)
out2['likelihood']

1.4346374210445043e-06

In [24]:
out2['likelihood'] > out['likelihood']

True

Generate data using one DDM

In [70]:
valueLefts = np.random.uniform(size = 100)
valueRights = np.random.uniform(size = 100)
m1_trials = []

m1 = DDM(d = 0.08, sigma = 0.01)

for valueLeft, valueRight in zip(valueLefts, valueRights): 
    m1_trials.append(m1.simulate_trial(valueLeft=valueLeft, valueRight=valueRight))


Get trial likelihoods for the correct and incorrect DDMs

In [71]:
m1_likelihoods = []
m2_likelihoods = []
m2 = DDM(d = 0.001, sigma = 0.1)

for trial in m1_trials:
    m1_likelihoods.append(m1.get_trial_likelihood(trial)['likelihood'])
    m2_likelihoods.append(m2.get_trial_likelihood(trial)['likelihood'])

Is the sum of the likelihood using the correct model higher than those of the incorrect model?

In [72]:
sum(m1_likelihoods)> sum(m2_likelihoods)

False

In [73]:
max(m1_likelihoods)

0.020180915157772296

In [74]:
max(m2_likelihoods)

0.004419192028457729

In [75]:
np.round(valueLefts - valueRights, 2)

array([-0.58, -0.23, -0.18,  0.23,  0.48, -0.49,  0.17, -0.54, -0.94,
       -0.32,  0.6 ,  0.46, -0.5 ,  0.31,  0.09,  0.26,  0.02, -0.67,
       -0.53, -0.37,  0.44,  0.5 ,  0.1 ,  0.55,  0.26,  0.48,  0.08,
       -0.  , -0.33, -0.7 , -0.82,  0.32,  0.43, -0.83, -0.03, -0.05,
       -0.2 , -0.  , -0.55, -0.04,  0.42,  0.2 ,  0.27, -0.27,  0.39,
        0.24,  0.1 , -0.61, -0.24,  0.05, -0.47, -0.16,  0.28,  0.51,
       -0.31,  0.38, -0.23,  0.07,  0.31,  0.57,  0.77, -0.65, -0.7 ,
        0.38,  0.2 ,  0.77, -0.27, -0.43, -0.6 ,  0.68, -0.19,  0.4 ,
       -0.45, -0.25, -0.54, -0.05,  0.39, -0.08, -0.27,  0.09,  0.17,
       -0.28, -0.09,  0.02, -0.22,  0.23,  0.33,  0.2 , -0.25, -0.47,
        0.1 ,  0.5 ,  0.14, -0.34, -0.23, -0.52,  0.01,  0.36,  0.35,
       -0.46])