Skip to content

Commit

Permalink
TD working fine; ready to push to main
Browse files Browse the repository at this point in the history
  • Loading branch information
victor-navarro committed Apr 4, 2024
1 parent 7ab45ab commit 380c2d3
Show file tree
Hide file tree
Showing 64 changed files with 1,019 additions and 733 deletions.
3 changes: 2 additions & 1 deletion .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
^revdep$
^cran-comments\.md$
^CRAN-SUBMISSION$
^Dockerfile
^Dockerfile
^sketch\.R
19 changes: 14 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ Title: Canonical Associative Learning Models and their Representations
Version: 0.6.2
Authors@R:
person("Victor", "Navarro", , "navarrov@cardiff.ac.uk", role = c("aut", "cre", "cph"))
Description: Implementations of canonical associative learning models, with tools to run experiment simulations, estimate model parameters, and compare model representations. Experiments and results are represented using S4 classes and methods.
Description: Implementations of canonical associative learning models,
with tools to run experiment simulations, estimate model parameters,
and compare model representations. Experiments and results are
represented using S4 classes and methods.
License: GPL (>= 3)
URL: https://github.com/victor-navarro/calmr,
https://victornavarro.org/calmr/
Expand All @@ -23,13 +26,15 @@ Imports:
patchwork,
progressr,
rlang,
stats
stats,
tools,
utils
Suggests:
DiagrammeR,
knitr,
rmarkdown,
spelling,
testthat (>= 3.0.0),
DiagrammeR
testthat (>= 3.0.0)
VignetteBuilder:
knitr
Config/testthat/edition: 3
Expand All @@ -45,15 +50,17 @@ Collate:
'PKH1982.R'
'RAND.R'
'ANCCR.R'
'rsa_functions.R'
'RW1972.R'
'SM2007.R'
'TD.R'
'rsa_functions.R'
'compare_models.R'
'data.R'
'fit_helpers.R'
'fit_model.R'
'model_parsers.R'
'model_plots.R'
'plotting_functions.R'
'model_graphs.R'
'model_support_functions.R'
'parse_design.R'
Expand All @@ -67,8 +74,10 @@ Collate:
'get_design.R'
'heidi_helpers.R'
'anccr_helpers.R'
'td_helpers.R'
'calmr_verbosity.R'
'parallel_helpers.R'
'maps.R'
'class_model.R'
'class_design.R'
'class_result.R'
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export(get_graph_opts)
export(get_model)
export(get_optimizer_opts)
export(get_parameters)
export(get_plot_opts)
export(get_timings)
export(make_experiment)
export(model_outputs)
Expand All @@ -18,6 +19,8 @@ export(parse_design)
export(patch_graphs)
export(patch_plots)
export(phase_parser)
export(plot_common_scale)
export(plot_targetted_complex_trials)
export(plot_targetted_tbins)
export(plot_targetted_trials)
export(plot_targetted_typed_trials)
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# calmr 0.6.2
* Aggregation of ANCCR data now ignores time; time entries are averaged.
* Added the Temporal Difference model under the name "TD". The model is in an experimental state.
* Experiments for time-based models now require a separate list to construct time-based experiences. See `get_timings()`.
* Added `experiences<-`, `timings`, `timings<-` methods for `CalmrExperiment` class.
* Revamped plotting functions and parsing functions.
* Revamped output names for all models to make them more intelligible.
* Fixed a bug related to the aggregation of pools in HDI2020 and HD2022.
* Consolidated some man pages.

# calmr 0.6.1
* Added `outputs` argument to `run_experiment()`, `parse()`, and `aggregate()`, allowing the user to parse/aggregate only some model outputs.
Expand Down
5 changes: 2 additions & 3 deletions R/ANCCR.R
Original file line number Diff line number Diff line change
Expand Up @@ -311,19 +311,18 @@ ANCCR <- function(
simplify = FALSE
)
# bundle prc and src
psrcs <- list(PRC = threes$prc, SRC = threes$src)
threes <- threes[c("m_ij", "ncs", "anccrs", "cws", "das", "qs", "ps")]

names(twos) <- c(
"ij_elegibilities", "i_elegibilities",
"i_base_rate"
)
names(threes) <- c(
"ij_base_rate", "prcs", "srcs", "net_contingencies",
"ij_base_rate", "net_contingencies",
"anccrs", "causal_weights", "dopamines", "action_values",
"probabilities"
)

psrcs <- threes[c("prcs", "srcs")]

c(twos, threes, list(representation_contingencies = psrcs))
}
11 changes: 4 additions & 7 deletions R/HD2022.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ HD2022 <- function(v = NULL, # nolint: object_name_linter.
# compute combV for all stimuli
combV <- .combV(
v = v, pre_func = fstims,
post_func = fsnames,
db_trial = t
post_func = fsnames
)

# compute chainV for all stimuli with a similarity rule
Expand All @@ -61,8 +60,7 @@ HD2022 <- function(v = NULL, # nolint: object_name_linter.
v = v,
pre_nomi = nstims,
pre_func = fstims,
post_func = fsnames,
db_trial = t
post_func = fsnames
)

# identify absent stimuli and calculate their "retrieved" salience
Expand All @@ -72,12 +70,11 @@ HD2022 <- function(v = NULL, # nolint: object_name_linter.
pre_nomi = nstims,
pre_func = fstims,
fsnames = fsnames,
nomi2func = mapping$nomi2func,
db_trial = t
nomi2func = mapping$nomi2func
)

# Distribute R
r <- .distR(ralphas, combV, chainV, t)
r <- .distR(ralphas, combV, chainV)

# save data
vs[t, , ] <- v
Expand Down
14 changes: 6 additions & 8 deletions R/HDI2020.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,16 @@ HDI2020 <- function(v = NULL, # nolint: object_name_linter.

# compute combV for all stimuli
combV <- .combV(
v = v, pre_func = fstims,
post_func = fsnames,
db_trial = t
v = v,
pre_func = fstims,
post_func = fsnames
)

# compute chainV for all stimuli without a similarity rule
chainV <- .chainV(
v = v,
pre_func = fstims,
post_func = fsnames,
db_trial = t
post_func = fsnames
)

# identify absent stimuli and calculate their "retrieved" salience
Expand All @@ -65,12 +64,11 @@ HDI2020 <- function(v = NULL, # nolint: object_name_linter.
pre_nomi = nstims,
pre_func = fstims,
fsnames = fsnames,
nomi2func = mapping$nomi2func,
db_trial = t
nomi2func = mapping$nomi2func
)

# Distribute R
r <- .distR(ralphas, combV, chainV, t)
r <- .distR(ralphas, combV, chainV)

# save data
vs[t, , ] <- v
Expand Down
2 changes: 1 addition & 1 deletion R/PKH1982.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ PKH1982 <- function(
}
}
results <- list(
associations = list(evs = evs, ivs = ivs),
associations = list(EV = evs, IV = ivs),
associabilities = as,
responses = rs
)
Expand Down
12 changes: 6 additions & 6 deletions R/TD.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
#' as returned by `make_experiment`
#' @param mapping A named list specifying trial and stimulus mapping,
#' as returned by `make_experiment`
#' @param debug Logical specifying whether to print debug information.
#' @param debug_t Whether to invoke a `browser` at
#' the end of a timestep equal to debug_t.
#' the end of a trial equal to debug_t.
#' @param debug_ti Whether to invoke a `browser` at
#' the end of a timestep within a trial equal to debug_ti.
#' @param ... Additional named arguments
#' @return A list with raw results
#' @note This model is in a highly experimental state. Use with caution.
#' @noRd

TD <- function(
parameters, timings, experience,
mapping, debug = FALSE, debug_t = -1,
mapping, debug_t = -1,
debug_ti = -1, ...) {
total_trials <- length(unique(experience$trial))
fsnames <- mapping$unique_functional_stimuli
Expand Down Expand Up @@ -112,14 +113,13 @@ TD <- function(
# add maximal trace of what just happened
e[, ti] <- omat[, ti]
#
if (ti == debug_ti) browser()
if (ti == debug_ti) browser() # nocov
}
}
vs[tn, , ] <- v
es[tn, , ] <- e

if (tn == debug_t) browser()
if (debug) message(tn)
if (tn == debug_t) browser() # nocov
}
list(associations = ws, values = vs, elegibilities = es)
}
21 changes: 12 additions & 9 deletions R/class_experiment.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ methods::setClass(
#' @param ... Extra arguments passed to [calmr_model_graph()]
#' and [calmr_model_plot()].
#' @name CalmrExperiment-methods
#' @seealso [plotting_functions],[calmr_model_plot]
#' @seealso [plotting_functions()],[calmr_model_plot()],[calmr_model_graph()]
NULL
#> NULL

Expand Down Expand Up @@ -169,8 +169,10 @@ methods::setMethod(
function(x, value) {
stopifnot(
"value must be either 1 or length(experiences(x))" =
length(value) == 1 || length(value) == experiences(x)
length(value) == 1 || length(value) == length(experiences(x))
)
x@experiences <- value
x
}
)
#' @noRd
Expand Down Expand Up @@ -338,8 +340,8 @@ setMethod(
#' @noRd
setGeneric("graph", function(x, ...) standardGeneric("graph")) # nocov
#' @rdname CalmrExperiment-methods
#' @return `graph()` returns a list of 'ggplot' plot objects.
#' @aliases graph
#' @return `graph()` returns a list of 'ggplot' plot objects.
#' @export
setMethod("graph", "CalmrExperiment", function(x, ...) {
if (is.null(x@results@aggregated_results)) {
Expand All @@ -354,9 +356,10 @@ setMethod("graph", "CalmrExperiment", function(x, ...) {
assoc_output <- .model_associations(m)
odat <- res[[assoc_output]]
weights <- odat[odat$model == m, ]
if (assoc_output == c("eivs")) {
evs <- weights[weights$type == "evs", ]
ivs <- weights[weights$type == "ivs", ]
if (x@model == "PKH1982") {
browser()
evs <- weights[weights$type == "EV", ]
ivs <- weights[weights$type == "IV", ]
weights <- evs
weights$value <- weights$value - ivs$value
}
Expand All @@ -376,8 +379,8 @@ setMethod("graph", "CalmrExperiment", function(x, ...) {
#' @noRd
methods::setGeneric(
"timings",
function(x) standardGeneric("timings")
) # nocov
function(x) standardGeneric("timings") # nocov
)
#' @noRd
methods::setGeneric(
"timings<-",
Expand All @@ -397,7 +400,7 @@ methods::setMethod(
#' @aliases timings<-
#' @export
methods::setMethod("timings<-", "CalmrExperiment", function(x, value) {
.assert_timings(timings(x), value)
.assert_timings(value, design(x))
x@timings <- value
x
})
8 changes: 4 additions & 4 deletions R/fit_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
#' pars <- get_parameters(df, model = "RW1972")
#' pars$alphas["US"] <- 0.9
#' exper <- make_experiment(df, parameters = pars, model = "RW1972")
#' res <- run_experiment(exper, outputs = "rs")
#' rs <- results(res)$rs$value
#' res <- run_experiment(exper, outputs = "responses")
#' responses <- results(res)$responses$value
#'
#' # define model function
#' model_fun <- function(p, ex) {
#' np <- parameters(ex)
#' np[[1]]$alphas[] <- p
#' parameters(ex) <- np
#' results(run_experiment(ex))$rs$value
#' results(run_experiment(ex))$responses$value
#' }
#'
#' # Get optimizer options
Expand All @@ -38,7 +38,7 @@
#' )
#' optim_opts$initial_pars[] <- rep(.6, 2)
#'
#' fit_model(rs, model_fun, optim_opts,
#' fit_model(responses, model_fun, optim_opts,
#' ex = exper, method = "L-BFGS-B",
#' control = list(maxit = 1)
#' )
Expand Down
19 changes: 1 addition & 18 deletions R/get_timings.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,6 @@ get_timings <- function(design) {
.default_global_timings <- function() {
list(
"use_exponential" = TRUE,
"time_resolution" = 1.0
"time_resolution" = 0.5
)
}

# Returns whether a parameter is a trial parameter
.is_trial_parameter <- function(parameter) {
trial_pars <- list(
"post_trial_delay",
"mean_ITI", "max_ITI"
)
parameter %in% trial_pars
}

# Returns wheter a parameter is a transition parameter
.is_trans_parameter <- function(parameter) {
trans_pars <- list(
"transition_delay"
)
parameter %in% trans_pars
}
Loading

0 comments on commit 380c2d3

Please sign in to comment.