diff --git a/R/plot.R b/R/plot.R index 13dfac2..4dd601a 100755 --- a/R/plot.R +++ b/R/plot.R @@ -174,7 +174,7 @@ plot_gam_splines <- function( #' the uncertainty on the estimation of smoothness parameters. #' @param out_dir Directory in which to save plots #' -#' @return A dataframe of spline differences at each node +#' @return A data frame of spline differences at each node #' @export #' #' @examples @@ -285,12 +285,14 @@ spline_diff <- function(gam_model, #' @param df Data frame. #' @param metrics Name(s) of the metrics to plot per figure, character vector. #' By default, will be all diffusion metrics in the provided data frame. +#' @param node_col Column name in the provided data frame with the node ID, +#' character. #' @param bundles Name(s) of the tract bundles to plot per facet, character #' vector. By default, will be all tract bundles in the provided data #' frame. -#' @param bundles_col Name of the column in the provided data frame with the -#' tract bundles. -#' @param group_col Name of the column in the data frame to group by as a color, +#' @param bundles_col Column names in the provided data frame with the tract +#' bundles, character. +#' @param group_col Column name in the data frame to group by as a color, #' character. By default, no grouping variable is provided. #' @param line_func Line function that provides the line positioning. See #' \link[ggplot2]{stat_summary} for more information. @@ -300,11 +302,14 @@ spline_diff <- function(gam_model, #' @param ribbon_alpha Ribbon alpha level. #' @param n_groups Number of groups to split a numeric grouping variable. #' @param pal_name Grouping color palette name, character. Default is colorblind. +#' @param save_fig Boolean. If TRUE, saves figures in `out_dir`. #' @param out_dir Output directory of saved plots. #' @param figsize Figure size. A numeric vector of (width, height) in inches. #' #' @return List of plot handles corresponding to the specified metrics. #' +#' @importFrom stats family +#' @importFrom rlang .data #' @export #' #' @examples @@ -332,6 +337,7 @@ spline_diff <- function(gam_model, plot_tract_profiles <- function ( df, metrics = NULL, + node_col = "nodeID", bundles = NULL, bundles_col = "tractID", group_col = NULL, @@ -341,6 +347,7 @@ plot_tract_profiles <- function ( ribbon_alpha = 0.25, n_groups = 3, pal_name = "colorblind", + save_fig = FALSE, out_dir = getwd(), figsize = c(8, 11.5) ) { @@ -369,10 +376,16 @@ plot_tract_profiles <- function ( "median_hilow" = ggplot2::median_hilow ) + # bind local variables for plotting + metric <- nodes <- tracts <- value <- NULL + # prepare data.frame for plotting plot_df <- df %>% tidyr::pivot_longer(cols = tidyselect::all_of(metrics), names_to = "metric") %>% - dplyr::rename(tracts = tidyselect::all_of(bundles_col)) %>% + dplyr::rename( + nodes = tidyselect::all_of(node_col), + tracts = tidyselect::all_of(bundles_col) + ) %>% dplyr::filter(tracts %in% bundles, metric %in% metrics) # factorized grouping variable, split into groups if numeric @@ -397,8 +410,9 @@ plot_tract_profiles <- function ( # create current metric figure handle plot_handle <- plot_df %>% dplyr::filter(metric == curr_metric) %>% - ggplot2::ggplot(ggplot2::aes(x = nodeID, y = value, group = .data[[group_col]], - color = .data[[group_col]], fill = .data[[group_col]])) + + ggplot2::ggplot(ggplot2::aes(x = nodes, y = value, + group = .data[[group_col]], color = .data[[group_col]], + fill = .data[[group_col]])) + ggplot2::stat_summary( color = NA, geom = "ribbon", fun.data = ribbon_func, alpha = ribbon_alpha) + ggplot2::stat_summary( @@ -415,17 +429,20 @@ plot_tract_profiles <- function ( plot_handle <- plot_handle + theme(legend.position = "none") } - # save tract profile figure - plot_fname <- paste0("tract-profile_by-", group_col, "_", - stringr::str_replace_all(curr_metric, "_", "-"), ".png") - ggplot2::ggsave( - filename = file.path(out_dir, plot_fname), - plot = plot_handle, - width = figsize[1], - height = figsize[2], - units = "in", - device = "png" - ) + # save tract profile figure if specified + if (save_fig) { + group_name <- ifelse(group_col == "_group", "", paste0("by-", group_col, "_")) + plot_fname <- paste0("tract-profile_", group_name, + stringr::str_replace_all(curr_metric, "_", "-"), ".png") + ggplot2::ggsave( + filename = file.path(out_dir, plot_fname), + plot = plot_handle, + width = figsize[1], + height = figsize[2], + units = "in", + device = "png" + ) + } # collect plot handles by metric plot_handles <- c(plot_handles, list(plot_handle))