Skip to content

Commit

Permalink
Merge pull request #8 from unimelbmdap/optim_settings
Browse files Browse the repository at this point in the history
Fit settings and vignette update
  • Loading branch information
mariadelmarq committed Mar 18, 2024
2 parents e733fbc + 6c1e510 commit dff7ffd
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 43 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Description: This package implements the Linear Approach to Threshold with
License: MIT + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Imports:
dplyr,
ggplot2,
Expand Down
49 changes: 35 additions & 14 deletions R/fit_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,19 @@ fit_data <- function(
# determine reasonable parameter values to start the optimisation
fit_info$start_points <- calc_start_points(data = data, fit_info = fit_info)

# the parameter values are divided by these values internally within
# the optimiser, to put the parameters on similar scales
# see e.g., https://www.r-bloggers.com/2014/01/tuning-optim-with-parscale/
parscale <- abs(fit_info$start_points)

# increase the number of maximum allowable iterations of the optimiser
maxit <- 1000000

# run the optimiser
fit_info$optim_result <- stats::optim(
fit_info$start_points,
objective_function,
control = list(parscale = parscale, maxit = maxit),
data = data,
fit_info = fit_info,
)
Expand Down Expand Up @@ -136,6 +145,7 @@ fit_data <- function(
return(fit_info)
}


#' Evaluate the cumulative distribution function under the model.
#'
#' @param q Vector of quantiles
Expand Down Expand Up @@ -163,6 +173,7 @@ model_cdf <- function(q, later_mu, later_sd, early_sd = NULL) {
return(p)
}


#' Evaluate the probability density function under the model.
#'
#' @param x Vector of quantiles
Expand Down Expand Up @@ -243,24 +254,31 @@ calc_loglike <- function(data, fit_info) {
return(sum(loglike))
}


# Akaike's 'An Information Criterion' (AIC) for a fitted model.
#
# `loglike` is the model log-likelihood and `n_params` is the number of
# fitted parameters; each parameter is penalised by a fixed weight of 2.
# Returns `-2 * loglike + 2 * n_params`.
calc_aic <- function(loglike, n_params) {
  penalty_per_param <- 2
  deviance_term <- -2 * loglike
  deviance_term + penalty_per_param * n_params
}


# parses a vector of parameters into a named list
unpack_params <- function(params, n_a, n_sigma, n_sigma_e) {
# first `n_a` items are the a parameters
a <- params[1:n_a]
# next are the sigma parameters
sigma <- params[(n_a + 1):(n_a + n_sigma)]
# note that the log of sigma is used in the optimiser
log_sigma <- params[(n_a + 1):(n_a + n_sigma)]
sigma <- exp(log_sigma)

labelled_params <- list(a = a, sigma = sigma)

if (n_sigma_e > 0) {
sigma_e <- params[(n_a + n_sigma + 1):length(params)]
# the sigma_e parameter is represented as the log of a multiplier of sigma
log_sigma_e_mult <- params[(n_a + n_sigma + 1):length(params)]
sigma_e_mult <- exp(log_sigma_e_mult)
sigma_e <- sigma * sigma_e_mult
labelled_params$sigma_e <- sigma_e
}

Expand Down Expand Up @@ -293,6 +311,7 @@ convert_a_to_mu_and_k <- function(a, sigma, intercept_form) {
return(list(mu = mu, k = k))
}


# returns the KS statistic given a set of model parameter values
# and the observed data
objective_function <- function(params, data, fit_info) {
Expand Down Expand Up @@ -378,6 +397,8 @@ calc_start_points <- function(data, fit_info) {
dplyr::pull(.data$val)
)

log_sigma_values <- log(sigma_values)

if (fit_info$intercept_form) {
a_values <- mu_values / sigma_values
if (fit_info$n_a == 1) {
Expand All @@ -387,22 +408,20 @@ calc_start_points <- function(data, fit_info) {
a_values <- mu_values
}

start_points <- c(a_values, sigma_values)
start_points <- c(a_values, log_sigma_values)

if (fit_info$with_early_component) {
sigma_e_values <- (
data |>
dplyr::group_by(.data$i_sigma_e) |>
dplyr::summarize(val = stats::sd(.data$promptness) * 3) |>
dplyr::pull(.data$val)
)
# sigma_e is given by exp(log_sigma_e_mult) * sigma
# set each log_sigma_e_mult to log(3), so 3 x sigma
log_sigma_e_mult_values <- log(rep(3, length.out = fit_info$n_sigma_e))

start_points <- c(start_points, sigma_e_values)
start_points <- c(start_points, log_sigma_e_mult_values)
}

return(start_points)
}


# calculates the Kolmogorov-Smirnov statistic
calc_ks_stat <- function(ecdf_p, cdf_p) {
max(abs(ecdf_p - cdf_p))
Expand Down Expand Up @@ -439,10 +458,10 @@ dnorm_with_early <- function(x, later_mu, later_sd, early_sd, log = FALSE) {
exp(-(((x - later_mu)**2) / (2 * later_sd**2)))
* (1 + erf((x - early_mu) / (sqrt(2) * early_sd)))
) / later_sd
+ (
exp(-(((x - early_mu)**2) / (2 * early_sd**2)))
* (1 + erf((x - later_mu) / (sqrt(2) * later_sd)))
) / early_sd
+ (
exp(-(((x - early_mu)**2) / (2 * early_sd**2)))
* (1 + erf((x - later_mu) / (sqrt(2) * later_sd)))
) / early_sd
) / (2 * sqrt(2 * pi))
)

Expand All @@ -453,10 +472,12 @@ dnorm_with_early <- function(x, later_mu, later_sd, early_sd, log = FALSE) {
return(p)
}


# Error function, computed from the standard normal CDF via the identity
# erf(x) = 2 * pnorm(x * sqrt(2)) - 1. Vectorised over `x`.
erf <- function(x) {
  scaled_x <- x * sqrt(2)
  2 * stats::pnorm(scaled_x) - 1
}


# works out how many parameters there are, given the sharing amongst
# the parameters
set_param_counts <- function(fit_info) {
Expand Down
8 changes: 5 additions & 3 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ reciprobit_plot <- function(
}

# Prepare for deprecation in `trans` argument after ggplot 3.5.0
if (packageVersion("ggplot2") < "3.5.0") {
if (utils::packageVersion("ggplot2") < "3.5.0") {
trans_arg <- list(trans = stats::qnorm)
} else {
trans_arg <- list(transform = stats::qnorm)
Expand Down Expand Up @@ -106,8 +106,10 @@ reciprobit_plot <- function(
do.call(
ggplot2::sec_axis,
c(
list(name = "Z-score",
breaks = z_breaks),
list(
name = "Z-score",
breaks = z_breaks
),
trans_arg
)
)
Expand Down
Loading

0 comments on commit dff7ffd

Please sign in to comment.