Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add arguments to $profile() #429

Merged
merged 20 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ Suggests:
bit64,
callr,
data.table,
ggplot2,
knitr,
lubridate,
nanoarrow,
nycflights13,
patrick,
pillar,
rlang,
rmarkdown,
testthat (>= 3.0.0),
tibble,
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
- New method `$write_csv()` for `DataFrame` (#414).
- New method `$sink_csv()` for `LazyFrame` (#432).
- New method `$dt$time()` to extract the time from a `datetime` variable (#428).
- Method `$profile()` gains optimization arguments and plot-related arguments (#429).

# polars 0.8.1

Expand Down
112 changes: 107 additions & 5 deletions R/lazyframe__lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -1063,11 +1063,11 @@ LazyFrame_sort = function(
#' table. They must have the same length.
#' @param strategy Strategy for where to find match:
#' * "backward" (default): search for the last row in the right table whose `on`
#' key is less than or equal to the left’s key.
#' key is less than or equal to the left key.
#' * "forward": search for the first row in the right table whose `on` key is
#' greater than or equal to the left’s key.
#' greater than or equal to the left key.
#' * "nearest": search for the last row in the right table whose value is nearest
#' to the left’s key. String keys are not currently supported for a nearest
#' to the left key. String keys are not currently supported for a nearest
#' search.
#' @param tolerance
#' Numeric tolerance. By setting this the join will only be done if the near
Expand Down Expand Up @@ -1357,6 +1357,12 @@ LazyFrame_fetch = function(
#' @description This will run the query and return a list containing the
#' materialized DataFrame and a DataFrame that contains profiling information
#' of each node that is executed.
#'
#' @inheritParams LazyFrame_collect
#' @param show_plot Show a Gantt chart of the profiling result
#' @param truncate_nodes Truncate the label lengths in the Gantt chart to this
#' number of characters. If `0` (default), do not truncate.
#'
#' @details The units of the timings are microseconds.
#'
#' @keywords LazyFrame
Expand Down Expand Up @@ -1397,8 +1403,104 @@ LazyFrame_fetch = function(
#' group_by("Species", maintain_order = TRUE)$
#' agg(pl$col(pl$Float64)$apply(r_func))$
#' profile()
LazyFrame_profile = function() {
.pr$LazyFrame$profile(self) |> unwrap("in $profile()")
LazyFrame_profile = function(
type_coercion = TRUE,
predicate_pushdown = TRUE,
projection_pushdown = TRUE,
simplify_expression = TRUE,
slice_pushdown = TRUE,
comm_subplan_elim = TRUE,
comm_subexpr_elim = TRUE,
streaming = FALSE,
no_optimization = FALSE,
inherit_optimization = FALSE,
collect_in_background = FALSE,
show_plot = FALSE,
truncate_nodes = 0) {

if (isTRUE(no_optimization)) {
predicate_pushdown = FALSE
projection_pushdown = FALSE
slice_pushdown = FALSE
comm_subplan_elim = FALSE
comm_subexpr_elim = FALSE
}

if (isTRUE(streaming)) {
comm_subplan_elim = FALSE
}

lf = self

if (isFALSE(inherit_optimization)) {
lf = self$set_optimization_toggle(
type_coercion,
predicate_pushdown,
projection_pushdown,
simplify_expression,
slice_pushdown,
comm_subplan_elim,
comm_subexpr_elim,
streaming
) |> unwrap("in $profile():")
}

out = lf |>
.pr$LazyFrame$profile() |>
unwrap("in $profile()")

if (isTRUE(show_plot) && requireNamespace("ggplot2", quietly = TRUE)) {
eitsupi marked this conversation as resolved.
Show resolved Hide resolved
timings = out$profile$to_data_frame()
timings$node = factor(timings$node, levels = unique(timings$node))
total_timing = max(timings$end)
if (total_timing > 10000000) {
unit = "s"
total_timing = paste0(total_timing/1000000, "s")
timings$start = timings$start / 1000000
timings$end = timings$end / 1000000
} else if (total_timing > 10000) {
unit = "ms"
total_timing = paste0(total_timing/1000, "ms")
timings$start = timings$start / 1000
timings$end = timings$end / 1000
} else {
unit = "\U00B5s"
total_timing = paste0(total_timing, "\U00B5s")
}

# for some reason, there's an error if I use rlang::.data directly in aes()
.data = rlang::.data
sorhawell marked this conversation as resolved.
Show resolved Hide resolved

plot = ggplot2::ggplot(
timings,
ggplot2::aes(x = .data[["start"]], xend = .data[["end"]],
y = .data[["node"]], yend = .data[["node"]])) +
ggplot2::geom_segment(linewidth = 6) +
ggplot2::xlab(
paste0("Node duration in ", unit, ". Total duration: ", total_timing)
) +
ggplot2::ylab(NULL) +
ggplot2::theme(
axis.text = ggplot2::element_text(size = 12)
)

if (truncate_nodes > 0) {
plot = plot +
ggplot2::scale_y_discrete(
labels = rev(paste0(strtrim(timings$node, truncate_nodes), "...")),
limits = rev
)
} else {
plot = plot +
ggplot2::scale_y_discrete(
limits = rev
)
}

print(plot)
}

out
}

#' @title Explode columns containing a list of values
Expand Down
6 changes: 3 additions & 3 deletions man/DataFrame_join_asof.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/LazyFrame_join_asof.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 56 additions & 1 deletion man/LazyFrame_profile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.