Skip to content

Commit

Permalink
Merge pull request #396 from r-dbi/f-arrow-more
Browse files Browse the repository at this point in the history
  • Loading branch information
krlmlr committed Oct 3, 2022
2 parents 76d976b + 6e6945a commit 2803ead
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 25 deletions.
5 changes: 5 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Suggests:
rprojroot,
RSQLite (>= 1.1-2),
testthat,
vctrs,
xml2
VignetteBuilder:
knitr
Expand Down Expand Up @@ -85,6 +86,8 @@ Collate:
'SQLKeywords_missing.R'
'data-types.R'
'data.R'
'dbAppendStream.R'
'dbAppendStream_DBIConnection.R'
'dbAppendTable.R'
'dbAppendTable_DBIConnection.R'
'dbBegin.R'
Expand All @@ -99,6 +102,8 @@ Collate:
'dbCommit.R'
'dbConnect.R'
'dbConnect_DBIConnector.R'
'dbCreateFromStream.R'
'dbCreateFromStream_DBIConnection.R'
'dbCreateTable.R'
'dbCreateTable_DBIConnection.R'
'dbDataType.R'
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export(ANSI)
export(Id)
export(SQL)
export(SQLKeywords)
export(dbAppendStream)
export(dbAppendTable)
export(dbBegin)
export(dbBind)
Expand All @@ -18,6 +19,7 @@ export(dbClearResult)
export(dbColumnInfo)
export(dbCommit)
export(dbConnect)
export(dbCreateFromStream)
export(dbCreateTable)
export(dbDataType)
export(dbDisconnect)
Expand Down Expand Up @@ -83,11 +85,13 @@ exportClasses(DBIResult)
exportClasses(DBIResultStream)
exportClasses(DBIResultStreamDefault)
exportClasses(SQL)
exportMethods(dbAppendStream)
exportMethods(dbAppendTable)
exportMethods(dbBind)
exportMethods(dbCanConnect)
exportMethods(dbClearResult)
exportMethods(dbConnect)
exportMethods(dbCreateFromStream)
exportMethods(dbCreateTable)
exportMethods(dbDataType)
exportMethods(dbExecute)
Expand Down
6 changes: 6 additions & 0 deletions R/dbAppendStream.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#' @name dbWriteTable
#' @aliases dbAppendStream
#' @export
setGeneric(
  name = "dbAppendStream",
  # S4 generic: the body only hands off to method dispatch. The roxygen tags
  # above document it under the dbWriteTable topic.
  def = function(conn, name, value, ...) standardGeneric("dbAppendStream")
)
23 changes: 23 additions & 0 deletions R/dbAppendStream_DBIConnection.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#' @rdname hidden_aliases
#' @usage NULL
dbAppendStream_DBIConnection <- function(conn, name, value, ...) {
  # Default dbAppendStream() implementation: drains an Arrow record batch
  # reader into an existing table, one batch at a time.
  #
  # conn:  connection handed to dbQuoteIdentifier() and dbAppendTable().
  # name:  target table; quoted once via dbQuoteIdentifier() before use.
  # value: anything arrow::as_record_batch_reader() accepts.
  # ...:   forwarded unchanged to every dbAppendTable() call.
  #
  # Returns TRUE once the reader yields no further batch.
  require_arrow()

  quoted_name <- dbQuoteIdentifier(conn, name)
  reader <- arrow::as_record_batch_reader(value)

  # Whether the reader is still at its first batch or has already been
  # advanced makes no difference: whatever remains is appended. A NULL
  # batch ends the loop.
  batch <- reader$read_next_batch()
  while (!is.null(batch)) {
    dbAppendTable(conn, quoted_name, as.data.frame(batch), ...)
    batch <- reader$read_next_batch()
  }

  TRUE
}
#' @rdname hidden_aliases
#' @export
setMethod(
  f = "dbAppendStream",
  # Registered for the first generic argument (conn) being a DBIConnection.
  signature = signature("DBIConnection"),
  definition = dbAppendStream_DBIConnection
)
2 changes: 1 addition & 1 deletion R/dbBind_DBIResultStream.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' @rdname hidden_aliases
#' @usage NULL
dbBind_DBIResultStream <- function(res, params, ...) {
dbBind(res@result, params = params, ...)
dbBind(res@result, params = as.list(as.data.frame(params)), ...)
}
#' @rdname hidden_aliases
#' @export
Expand Down
6 changes: 6 additions & 0 deletions R/dbCreateFromStream.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#' @name dbWriteTable
#' @aliases dbCreateFromStream
#' @export
setGeneric(
  name = "dbCreateFromStream",
  # S4 generic: the body only hands off to method dispatch. The roxygen tags
  # above document it under the dbWriteTable topic.
  def = function(conn, name, value, ...) standardGeneric("dbCreateFromStream")
)
33 changes: 33 additions & 0 deletions R/dbCreateFromStream_DBIConnection.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#' @rdname hidden_aliases
#' @usage NULL
dbCreateFromStream_DBIConnection <- function(conn, name, value, ..., temporary = FALSE) {
  # Default dbCreateFromStream() implementation: creates a table whose
  # columns are taken from the schema of an Arrow stream. No batch is read
  # here; only the schema is inspected (see get_arrow_ptype()).
  #
  # conn, name: passed through to dbCreateTable().
  # value:      anything arrow::as_record_batch_reader() accepts.
  # ...:        forwarded to dbCreateTable().
  # temporary:  forwarded to dbCreateTable() under the same name.
  #
  # Returns whatever dbCreateTable() returns.
  require_arrow()

  reader <- arrow::as_record_batch_reader(value)

  # Compute the zero-row prototype up front so a schema problem surfaces
  # before dbCreateTable() is entered.
  prototype <- get_arrow_ptype(reader)
  dbCreateTable(conn, name, prototype, ..., temporary = temporary)
}

# Builds a zero-row data frame ("prototype") with one column per field of
# `value$schema`, named after the schema's field names.
#
# value: an object exposing a non-NULL `$schema` with `$fields` and `$names`
#        (an Arrow record batch reader in the caller visible in this file).
#
# Stops if the schema is NULL.
get_arrow_ptype <- function(value) {
  schema <- value$schema
  stopifnot(!is.null(schema))

  # For a single field: ask Arrow for an array of the field's type without
  # supplying any input arrays, then convert it to an R vector. If the
  # conversion fails, fall back to an unspecified vctrs vector instead of
  # aborting. Only the conversion is guarded; errors from concat_arrays()
  # still propagate.
  empty_column <- function(field) {
    typed_array <- arrow::concat_arrays(type = field$type)
    tryCatch(
      as.vector(typed_array),
      error = function(...) vctrs::unspecified()
    )
  }

  named_fields <- stats::setNames(schema$fields, schema$names)
  columns <- lapply(named_fields, empty_column)

  vctrs::new_data_frame(columns, n = 0L)
}

#' @rdname hidden_aliases
#' @export
setMethod(
  f = "dbCreateFromStream",
  # Registered for the first generic argument (conn) being a DBIConnection.
  signature = signature("DBIConnection"),
  definition = dbCreateFromStream_DBIConnection
)
18 changes: 7 additions & 11 deletions R/dbWriteStream_DBIConnection.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' @rdname hidden_aliases
#' @usage NULL
dbWriteStream_DBIConnection <- function(conn, name, value, append = FALSE, overwrite = FALSE, ...) {
dbWriteStream_DBIConnection <- function(conn, name, value, append = FALSE, overwrite = FALSE, ..., temporary = FALSE) {
require_arrow()

name <- dbQuoteIdentifier(conn, name)
Expand All @@ -11,20 +11,16 @@ dbWriteStream_DBIConnection <- function(conn, name, value, append = FALSE, overw
stop("overwrite and append cannot both be TRUE")
}

if (overwrite || !append) {
# Create table *and* append first batch if needed
dbWriteTable(conn, name, as.data.frame(value$read_next_batch()), ..., append = append, overwrite = overwrite)
if (overwrite && dbExistsTable(conn, name)) {
dbRemoveTable(conn, name)
}

while (TRUE) {
# Append next batch (starting with the first or second, doesn't matter)
tmp <- value$read_next_batch()
if (is.null(tmp)) {
break
}
dbAppendTable(conn, name, as.data.frame(tmp), ...)
if (overwrite || !append) {
dbCreateFromStream(conn, name, value, temporary = temporary)
}

dbAppendStream(conn, name, value)

TRUE
}
#' @rdname hidden_aliases
Expand Down
9 changes: 8 additions & 1 deletion man/dbWriteTable.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 21 additions & 4 deletions man/hidden_aliases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 18 additions & 2 deletions tests/testthat/test-arrow.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ test_that("write arrow to sqlite", {

tbl <- arrow::as_arrow_table(data)

res <- dbWriteStream(con, "data_tbl", tbl)
dbCreateFromStream(con, "data_tbl", tbl)
dbAppendStream(con, "data_tbl", tbl)

expect_equal(
dbReadTable(con, "data_tbl"),
as.data.frame(tbl)
)

res <- dbWriteStream(con, "data_tbl", tbl, overwrite = TRUE)

expect_equal(
dbReadTable(con, "data_tbl"),
Expand All @@ -30,8 +38,9 @@ test_that("write arrow to sqlite", {
tbl
)

stream <- dbGetStream(con, "SELECT COUNT(*) FROM data_tbl")
expect_equal(
as.data.frame(dbGetStream(con, "SELECT COUNT(*) FROM data_tbl")$read_table())[[1]],
as.data.frame(stream$read_table())[[1]],
nrow(tbl)
)

Expand All @@ -41,4 +50,11 @@ test_that("write arrow to sqlite", {
nrow(tbl)
)
dbClearResult(res)

# Implicit test for dbBind()
stream <- dbGetStream(con, "SELECT * FROM data_tbl WHERE a < $a", params = tbl["a"])
expect_equal(
as.data.frame(stream$read_table()),
as.data.frame(data[c(1, 1:2), ], row.names = 1:3)
)
})
21 changes: 15 additions & 6 deletions vignettes/DBI-arrow.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,10 @@ stream$read_next_batch()

## Prepared queries

```{r eval = FALSE}
result <- dbGetStream(con, "SELECT COUNT(*) FROM tbl WHERE a < ?")
stream <- dbStream(result, params = ...)
stream$read_next_batch()
stream <- dbStream(result, params = ...)
stream$read_next_batch()
```{r}
# Values for the `$a` placeholder are supplied as an Arrow table.
in_stream <- arrow::as_arrow_table(data.frame(a = 1:4))
# The argument is spelled `params` in full, matching dbBind() and the package
# tests. The abbreviated `param` depends on partial argument matching, which
# R does not apply to arguments declared after `...`.
stream <- dbGetStream(
  con,
  "SELECT * FROM tbl WHERE a < $a",
  params = in_stream
)
as.data.frame(stream)
```

## Writing data
Expand All @@ -117,6 +115,17 @@ dbWriteStream(con, "tbl_new", stream)
dbReadTable(con, "tbl_new")
```

## Appending data

```{r}
# First part of the data: rows with a < 3.
stream <- dbGetStream(con, "SELECT * FROM tbl WHERE a < 3")
# Create the (still empty) target table from the stream ...
dbCreateFromStream(con, "tbl_split", stream)
# ... then write the stream's rows into it.
dbAppendStream(con, "tbl_split", stream)
# Second part: the remaining rows are appended to the same table.
stream <- dbGetStream(con, "SELECT * FROM tbl WHERE a >= 3")
dbAppendStream(con, "tbl_split", stream)
# Read the table back to inspect the combined result.
dbReadTable(con, "tbl_split")
```

As usual, do not forget to disconnect from the database when done.

```{r}
Expand Down

0 comments on commit 2803ead

Please sign in to comment.