Skip to content

Commit

Permalink
TD model page and bugfix for the update of last compound
Browse files Browse the repository at this point in the history
  • Loading branch information
victor-navarro committed Apr 9, 2024
1 parent 509fdcd commit 28a0ce6
Show file tree
Hide file tree
Showing 14 changed files with 314 additions and 91 deletions.
14 changes: 7 additions & 7 deletions R/ANCCR.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,19 @@ ANCCR <- function(
experience[timestep, "time"] -
experience[timestep - 1, "time"]
)
# Update elegibility trace
# Update eligibility trace
e_ij[, timestep] <- e_ij[, timestep - 1] *
gammas[timestep]^(
experience[timestep, "time"] -
experience[timestep - 1, "time"]
)
# Set elegibility trace and anccrs
# Set eligibility trace and anccrs
m_ij[, , timestep] <- m_ij[, , timestep - 1]
anccrs[absents, , timestep] <- anccrs[absents, , timestep - 1]
}
# Delta reset
delta[event, timestep] <- 1
# Increment elegibility trace for the event that occurred by + 1
# Increment eligibility trace for the event that occurred by + 1
e_ij[event, timestep] <- e_ij[event, timestep] + 1
# Update predecessor representation
m_ij[, event, timestep] <- m_ij[, event, timestep] + alphat *
Expand Down Expand Up @@ -234,7 +234,7 @@ ANCCR <- function(
}
}
}
# Update sample elegibility trace
# Update sample eligibility trace
if (timestep < nt) {
# Time to sample baseline b/t events
# VN: The function below is about 100 times faster than the original
Expand All @@ -260,14 +260,14 @@ ANCCR <- function(
}
nextt <- timestep + 1
}
# Update alpha of sample elegibility trace
# Update alpha of sample eligibility trace
alphat <- .anccr_get_alpha(
denom = numsampling + 1,
parameters = parameters,
timestep = timestep
)

# Update average sample elegibility trace
# Update average sample eligibility trace
# Name: Baseline predecessor representation
m_i[, timestep + 1] <- m_i[, timestep] + parameters$k * alphat *
(e_i[, timestep + 1] - m_i[, timestep])
Expand Down Expand Up @@ -315,7 +315,7 @@ ANCCR <- function(
threes <- threes[c("m_ij", "ncs", "anccrs", "cws", "das", "qs", "ps")]

names(twos) <- c(
"ij_elegibilities", "i_elegibilities",
"ij_eligibilities", "i_eligibilities",
"i_base_rate"
)
names(threes) <- c(
Expand Down
158 changes: 111 additions & 47 deletions R/TD.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,21 @@
#' as returned by `make_experiment`
#' @param mapping A named list specifying trial and stimulus mapping,
#' as returned by `make_experiment`
#' @param debug_t Whether to invoke a `browser` at
#' the end of a trial equal to debug_t.
#' @param debug_ti Whether to invoke a `browser` at
#' the end of a timestep within a trial equal to debug_ti.
#' @param ... Additional named arguments
#' @return A list with raw results
#' @note This model is in a highly experimental state. Use with caution.
#' @noRd

TD <- function(
parameters, timings, experience,
mapping, debug_t = -1,
debug_ti = -1, ...) {
mapping, ...) {
total_trials <- length(unique(experience$trial))
fsnames <- mapping$unique_functional_stimuli

# join betas
betas <- cbind(parameters$betas_off, parameters$betas_on)
colnames(betas) <- c("off", "on")

# get maximum trial duration
max_tsteps <- max(experience$b_to)

Expand All @@ -33,22 +32,15 @@ TD <- function(
dimnames = list(fsnames, (1:(max_tsteps)) * timings$time_resolution)
)

# array for elegibilities
# array for eligibilities
es <- array(0,
dim = c(total_trials, length(fsnames), max_tsteps),
dimnames = list(
NULL,
fsnames, (1:(max_tsteps)) * timings$time_resolution
)
)
# array for associations (weights)
ws <- array(0,
dim = c(total_trials, length(fsnames), length(fsnames), max_tsteps),
dimnames = list(
NULL, fsnames, fsnames,
(1:(max_tsteps)) * timings$time_resolution
)
)
ws <- es <- vector("list", length = total_trials)

# array for values
vs <- array(0,
Expand All @@ -59,22 +51,59 @@ TD <- function(
)
)

# now the smaller arrays that will get modified over trials
w <- dd <- ws[1, , , ] # associations and their deltas
v <- d <- vs[1, , ] # values and pooled deltas
e <- es[1, , ]
# calculate stimulus durations
s_steps <- sapply(fsnames, function(stim) {
if (stim %in% experience$stimulus) {
with(
experience[experience$stimulus == stim, ],
max(b_to) - min(b_from) + 1
)
}
}, simplify = FALSE)


# now the smaller arrays that will get modified over trials (csc-based)
w <- e <- list()
for (stim in fsnames) {
if (is.null(s_steps[[stim]])) {
# backup in case stim not in experience
s_steps[[stim]] <- max(unlist(s_steps))
}
w[[stim]] <- array(0,
dim = c(length(fsnames), s_steps[[stim]]),
dimnames = list(
fsnames,
seq_len(s_steps[[stim]]) * timings$time_resolution
)
)
e[[stim]] <- array(0,
dim = c(s_steps[[stim]]),
dimnames = list(seq_len(s_steps[[stim]]) * timings$time_resolution)
)
}

v <- d <- vs[1, , ] # values and pooled deltas (time-based)

for (tn in seq_len(total_trials)) {
# get trial data
tdat <- experience[experience$trial == tn, ]
# get trial onehot matrix of active components
omat <- .onehot_mat(base_onehot, tdat$stimulus, tdat$b_from, tdat$b_to)
# save association matrix
ws[tn, , , ] <- w
ws[[tn]] <- w

for (ti in seq_len(max_tsteps)) {
# calculate value expectations for this timestep
v[, ti] <- t(w[, , ti]) %*% omat[, ti]
# build weight matrix
present <- fsnames[which(omat[, ti] > 0)]
if (length(present)) {
tw <- sapply(present, function(stim) {
w[[stim]][, sum(omat[stim, 1:ti])]
})
v[, ti] <- tw %*% omat[present, ti, drop = FALSE]
} else {
v[, ti][] <- 0
}
if (!any(tdat$is_test)) {
if (ti == 1) {
# special treatment of traces and deltas for the first timestep
Expand All @@ -85,41 +114,76 @@ TD <- function(
experience[experience$trial == (tn - 1), ],
max(rtime)
)) / timings$time_resolution
e <- e * (parameters$sigma *
parameters$gamma)^decay_steps
for (stim in fsnames) {
e[[stim]] <- e[[stim]] *
(parameters$sigma * parameters$gamma)^decay_steps
}
}
# delta only depends on current prediction
d[, ti] <- (omat[, ti] * parameters$lambdas) +
(parameters$gamma * v[, ti])
d[, ti] <- (parameters$gamma * v[, ti])
} else {
# delta depends on current and previous prediction
d[, ti] <- (omat[, ti] * parameters$lambdas) +
(parameters$gamma * v[, ti]) -
d[, ti] <- (parameters$gamma * v[, ti]) -
v[, ti - 1]
# decay elegibilities by 1 timestep
e <- e * parameters$sigma *
parameters$gamma
}
# compute update
rates <- parameters$alphas * e
dd[] <- sapply(seq_len(max_tsteps), function(i) {
x <- rates[, i] %*% t(d[, ti, drop = FALSE])
# add events to error term
d[, ti] <- d[, ti] + (omat[, ti] * parameters$lambdas)

# rates of learning
rates <- sapply(fsnames, function(stim) {
e[[stim]] * parameters$alphas[stim]
}, simplify = FALSE)
# compute updates
# trial betas
tbetas <- sapply(fsnames, function(i) betas[i, omat[i, ti] + 1])
dd <- sapply(fsnames, function(stim) {
dhold <- (d[, ti, drop = FALSE] * tbetas) %*% rates[[stim]]
# zero-out self-associations
diag(x) <- 0
x
})
# apply update
w <- w + dd
# add maximal trace of what just happened
e[, ti] <- omat[, ti]
#
if (ti == debug_ti) browser() # nocov
dhold[stim, ][] <- 0
dhold
}, simplify = FALSE)

# apply updates and elegibility traces
for (stim in fsnames) {
w[[stim]][] <- w[[stim]] + dd[[stim]]
# decay eligibilities by 1 timestep
e[[stim]] <- e[[stim]] * parameters$sigma * parameters$gamma
# Add event to
e[[stim]][sum(omat[stim, 1:ti])][] <-
e[[stim]][sum(omat[stim, 1:ti])][] +
omat[, ti][stim]
}
}
}
vs[tn, , ] <- v
es[tn, , ] <- e

if (tn == debug_t) browser() # nocov
# Update the last step separately with a "ghost" step
# delta only depends on the last prediction
gd <- -v[, ti, drop = FALSE]
# decay elegibilities
for (stim in fsnames) {
e[[stim]] <- e[[stim]] * parameters$sigma * parameters$gamma
}

# rates of learning
rates <- sapply(fsnames, function(stim) {
e[[stim]] * parameters$alphas[stim]
}, simplify = FALSE)
# compute updates
# trial betas
tbetas[] <- betas[, 1] # all off
dd <- sapply(fsnames, function(stim) {
dhold <- (gd * tbetas) %*% rates[[stim]]
# zero-out self-associations
dhold[stim, ][] <- 0
dhold
}, simplify = FALSE)

# apply updates
for (stim in fsnames) {
w[[stim]][] <- w[[stim]] + dd[[stim]]
}
vs[tn, , ] <- v
es[[tn]] <- e
}
list(associations = ws, values = vs, elegibilities = es)
list(associations = ws, values = vs, eligibilities = es)
}
4 changes: 2 additions & 2 deletions R/information_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,12 @@ model_outputs <- function(model = NULL) {
"ANCCR" = c(
"action_values", "anccrs",
"causal_weights", "dopamines",
"ij_elegibilities", "i_elegibilities",
"ij_eligibilities", "i_eligibilities",
"i_base_rate", "ij_base_rate",
"net_contingencies", "probabilities",
"representation_contingencies"
),
"TD" = c("values", "associations", "elegibilities"),
"TD" = c("values", "associations", "eligibilities"),
"RAND" = c("associations", "responses")
)
if (is.null(model)) {
Expand Down
Loading

0 comments on commit 28a0ce6

Please sign in to comment.