In [1]:
import numpy as np

In [266]:
def forward_model_rem(params, N_tot=50, ratio_sd=0.5, w=20, rng=None):
    """Generate simulated old/new recognition responses from the REM model.

    Each of ``N_tot`` items is a vector of ``w`` features drawn from a
    geometric distribution with parameter ``g``. A random subset of the items
    is studied. Each feature of a studied item is stored with probability
    ``u`` (otherwise it is stored as 0). A stored feature is replaced by a
    fresh geometric draw with probability ``1 - c`` (imperfect copy). Every
    item is then probed against all stored traces. The mean likelihood
    ratio ``phi`` over traces gives the response: ``phi < 1`` means "new",
    and anything else means "old".

    Parameters
    ----------
    params : sequence of 3 floats
        ``(g, u, c)``: the geometric feature parameter, the storage
        probability, and the probability of copying a stored feature
        correctly.
    N_tot : int, default 50
        Total number of items (studied + distractor).
    ratio_sd : float, default 0.5
        Fraction of items that are studied. Must lie in [0, 1]. The number
        of studied items is ``floor(N_tot * ratio_sd)``, computed so that
        binary floating-point error cannot drop an item.
    w : int, default 20
        Number of features per item.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. ``None`` (the default) uses the legacy global
        ``np.random`` state, exactly as before. Pass a seeded generator for
        reproducible simulations.

    Returns
    -------
    numpy.ndarray of shape (N_tot, 2), dtype float
        Column 0 is the simulated response (0 = old, 1 = new).
        Column 1 is the true status (0 = studied/old, 1 = distractor/new).

    Raises
    ------
    ValueError
        If ``ratio_sd`` lies outside [0, 1].
    """
    g, u, c = params
    if rng is None:
        # Legacy global stream: keeps the draws identical to earlier runs.
        rng = np.random

    if not 0.0 <= ratio_sd <= 1.0:
        raise ValueError(f"ratio_sd must be in [0, 1], got {ratio_sd}")

    # Round off binary-representation error before truncating. Without this,
    # 100 * 0.29 == 28.999999999999996 would give 28 study items instead of
    # 29. Genuinely fractional products (e.g. 51 * 0.5 = 25.5) are still
    # floored.
    N_s = int(round(N_tot * ratio_sd, 9))

    # Column 0 holds responses and column 1 the true labels (0 = old, 1 = new).
    x = np.zeros((N_tot, 2))

    # --- Feature generation --- #
    W = rng.geometric(g, size=(N_tot, w))
    idx = rng.permutation(N_tot)
    idx_s, idx_d = idx[:N_s], idx[N_s:]
    Ws = W[idx_s, :]

    # --- Memory storage --- #
    # Each feature is attended (stored) with probability u; otherwise it is 0.
    stored = rng.binomial(1, u, size=(N_s, w)).astype(bool)
    Ws_stored = Ws * stored
    # A stored feature is miscopied with probability 1 - c. A miscopied
    # feature is replaced by a fresh geometric value.
    miscopied = rng.binomial(1, 1 - c, size=(N_s, w)).astype(bool) & stored
    Ws_stored = np.where(miscopied, rng.geometric(g, size=(N_s, w)), Ws_stored)

    # --- Retrieval from memory --- #
    # Broadcast every probe (N_tot, 1, w) against every trace (N_s, w).
    # This gives per-feature comparisons of shape (N_tot, N_s, w).
    probes = W[:, np.newaxis, :]
    trace_nonzero = Ws_stored != 0
    # Probe features come from a geometric draw and are always >= 1, so only
    # the trace side needs the non-zero check.
    n_mismatch = np.sum((Ws_stored != probes) & trace_nonzero, axis=-1)
    is_match = (Ws_stored == probes) & trace_nonzero

    # log((1 - c) ** n_mismatch). The term is 0 when n_mismatch == 0, so
    # c == 1 yields 0 instead of -inf * 0 = nan. With c == 1, any mismatch
    # correctly drives the likelihood ratio to 0.
    with np.errstate(divide="ignore"):
        log_1mc = np.log(1 - c)
    log_mismatch = np.where(n_mismatch > 0, log_1mc, 0.0) * n_mismatch

    # Log gain for a match on value V:
    #   log[(c + (1 - c) g (1 - g)^(V-1)) / (g (1 - g)^(V-1))].
    # The gain depends only on the trace value, so it is computed once per
    # trace feature. V is clamped to >= 1 so unused zero entries cannot
    # produce 0 ** -1 warnings when g == 1. Zero entries are discarded by
    # `is_match` anyway.
    v = np.maximum(Ws_stored, 1)
    geo_tail = (1 - g) ** (v - 1)
    log_gain = np.log((c + (1 - c) * g * geo_tail) / (g * geo_tail))
    log_match = np.sum(np.where(is_match, log_gain, 0.0), axis=-1)

    if N_s > 0:
        # Mean likelihood ratio over all stored traces.
        phi = np.mean(np.exp(log_mismatch + log_match), axis=-1)
    else:
        # No traces means no evidence of familiarity. Respond "new" rather
        # than averaging over an empty axis, which gives nan and a warning.
        phi = np.zeros(N_tot)

    # phi < 1 means "new" (1); otherwise the response stays "old" (0).
    x[phi < 1, 0] = 1
    # Distractors are truly new (1); studied items stay old (0).
    x[idx_d, 1] = 1

    return x

In [267]:
%%time
# Time one REM simulation run; params are [g, u, c] with w=10 features per item.
forward_model_rem([0.2, 0.9, 0.3], w=10, ratio_sd=0.5)

Wall time: 1.49 ms


array([[0., 0.],
       [0., 1.],
       [0., 0.],
       [1., 1.],
       [1., 1.],
       [0., 0.],
       [1., 1.],
       [0., 1.],
       [0., 0.],
       [1., 1.],
       [0., 0.],
       [0., 0.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 0.],
       [1., 1.],
       [1., 0.],
       [1., 1.],
       [1., 0.],
       [1., 1.],
       [0., 1.],
       [0., 0.],
       [1., 1.],
       [1., 0.],
       [0., 0.],
       [0., 1.],
       [1., 0.],
       [0., 0.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [0., 0.],
       [0., 0.],
       [1., 0.],
       [1., 1.],
       [0., 0.],
       [0., 0.],
       [1., 0.],
       [0., 0.],
       [1., 1.],
       [0., 0.],
       [1., 1.],
       [1., 0.],
       [0., 1.],
       [1., 0.],
       [0., 0.],
       [1., 1.]])