Allow reading multiple files with spark_read_ #2118

Merged
merged 4 commits into from Aug 30, 2019
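
With this change the spark_read_*() readers accept a character vector of paths, or a glob, instead of a single path. A minimal usage sketch (the connection settings and the data/batch_* directories are hypothetical):

library(sparklyr)

sc <- spark_connect(master = "local")

# read two parquet directories into a single Spark DataFrame
batches <- spark_read_parquet(sc, "batches", c("data/batch_1", "data/batch_2"))

# globs keep working as before
batches_glob <- spark_read_parquet(sc, "batches_glob", "data/batch_*")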
30 changes: 22 additions & 8 deletions R/data_interface.R
@@ -239,10 +239,12 @@ spark_read_parquet <- function(sc,
columns = NULL,
schema = NULL,
...) {
c(name, path) %<-% spark_read_compat_param(sc, name, path)
params <- spark_read_compat_param(sc, name, path)
name <- params[1L]
path <- params[-1L]
if (overwrite) spark_remove_table_if_exists(sc, name)

df <- spark_data_read_generic(sc, list(spark_normalize_path(path)), "parquet", options, columns, schema)
df <- spark_data_read_generic(sc, as.list(spark_normalize_path(path)), "parquet", options, columns, schema)
spark_partition_register_df(sc, df, name, repartition, memory)
}

@@ -321,10 +323,12 @@ spark_read_json <- function(sc,
overwrite = TRUE,
columns = NULL,
...) {
c(name, path) %<-% spark_read_compat_param(sc, name, path)
params <- spark_read_compat_param(sc, name, path)
name <- params[1L]
path <- params[-1L]
if (overwrite) spark_remove_table_if_exists(sc, name)

df <- spark_data_read_generic(sc, spark_normalize_path(path), "json", options, columns)
df <- spark_data_read_generic(sc, as.list(spark_normalize_path(path)), "json", options, columns)
spark_partition_register_df(sc, df, name, repartition, memory)
}

@@ -782,12 +786,15 @@ spark_read_text <- function(sc,
options = list(),
whole = FALSE,
...) {
c(name, path) %<-% spark_read_compat_param(sc, name, path)
params <- spark_read_compat_param(sc, name, path)
name <- params[1L]
path <- params[-1L]
if (overwrite) spark_remove_table_if_exists(sc, name)

columns = list(line = "character")

if (identical(whole, TRUE)) {
if (length(path) != 1L) stop("spark_read_text is only supported with a path of length 1 when whole = TRUE.")
path_field <- invoke_static(sc, "sparklyr.SQLUtils", "createStructField", "path", "character", TRUE)
contents_field <- invoke_static(sc, "sparklyr.SQLUtils", "createStructField", "contents", "character", TRUE)
schema <- invoke_static(sc, "sparklyr.SQLUtils", "createStructType", list(path_field, contents_field))
@@ -796,7 +803,7 @@ spark_read_text <- function(sc,
df <- invoke(hive_context(sc), "createDataFrame", rdd, schema)
}
else {
df <- spark_data_read_generic(sc, list(spark_normalize_path(path)), "text", options, columns)
df <- spark_data_read_generic(sc, as.list(spark_normalize_path(path)), "text", options, columns)
}

spark_partition_register_df(sc, df, name, repartition, memory)
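
For context, whole = TRUE reads each file into a single (path, contents) row, which is why a single path is still required there, while the default line-by-line mode now accepts several paths. A short sketch with hypothetical directories:

# one row per line, several inputs allowed
lines <- spark_read_text(sc, "lines", c("data/batch_1", "data/batch_2"))

# one row per file; with whole = TRUE the path must have length 1
docs <- spark_read_text(sc, "docs", "data/batch_1", whole = TRUE)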
@@ -868,10 +875,17 @@ spark_read_orc <- function(sc,
columns = NULL,
schema = NULL,
...) {
c(name, path) %<-% spark_read_compat_param(sc, name, path)
params <- spark_read_compat_param(sc, name, path)
name <- params[1L]
path <- params[-1L]

if (length(path) != 1L && (spark_version(sc) < "2.0.0")) {
stop("spark_read_orc is only suppored with path of length 1 for spark versions < 2.0.0")
}

if (overwrite) spark_remove_table_if_exists(sc, name)

df <- spark_data_read_generic(sc, list(spark_normalize_path(path)), "orc", options, columns, schema)
df <- spark_data_read_generic(sc, as.list(spark_normalize_path(path)), "orc", options, columns, schema)
spark_partition_register_df(sc, df, name, repartition, memory)
}

6 changes: 5 additions & 1 deletion R/utils.R
@@ -148,7 +148,7 @@ spark_sanitize_names <- function(names, config) {
# that this will take care of path.expand ("~") as well as converting
# relative paths to absolute (necessary since the path will be read by
# another process that has a different current working directory)
spark_normalize_path <- function(path) {
spark_normalize_single_path <- function(path) {
# don't normalize paths that are urls
if (grepl("[a-zA-Z]+://", path)) {
path
@@ -158,6 +158,10 @@ spark_normalize_path <- function(path) {
}
}

spark_normalize_path <- function(paths) {
unname(sapply(paths, spark_normalize_single_path))
}

stopf <- function(fmt, ..., call. = TRUE, domain = NULL) {
stop(simpleError(
sprintf(fmt, ...),
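
The renaming above keeps the single-path logic in spark_normalize_single_path() and wraps it so a whole vector of paths can be normalized at once. A small illustration of the expected behaviour, assuming a local filesystem path and a hypothetical HDFS URL:

# local, relative, or "~" paths are expanded to absolute paths;
# URL-style paths (e.g. hdfs://, s3a://) are returned unchanged
spark_normalize_path(c("~/data/batch_1", "hdfs://namenode/data/batch_2"))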
73 changes: 73 additions & 0 deletions tests/testthat/test-read-write-multiple.R
@@ -0,0 +1,73 @@
context("read-write-multiple")

sc <- testthat_spark_connection()

test_readwrite <- function(sc, writer, reader, name = "testtable", ...) {
path <- file.path(dirname(sc$output_file), c("batch_1", "batch_2"))
path_glob <- file.path(dirname(sc$output_file), "batch*")
on.exit(unlink(path, recursive = TRUE, force = TRUE), add = TRUE)

writer(sdf_copy_to(sc, data.frame(line = as.character(1L:3L))), path[1L])
writer(sdf_copy_to(sc, data.frame(line = as.character(4L:6L))), path[2L])

if (is.element("whole", names(list(...))) && isTRUE(list(...)$whole)) {
res_1 <- reader(sc, name, path[1L], ...) %>% collect() %>% pull(contents) %>% strsplit("\n") %>% unlist() %>% sort()
res_2 <- reader(sc, name, path[2L], ...) %>% collect() %>% pull(contents) %>% strsplit("\n") %>% unlist() %>% sort()
res_3 <- reader(sc, name, path, ...) %>% collect() %>% pull(contents) %>% strsplit("\n") %>% unlist() %>% sort()
res_4 <- reader(sc, name, path_glob, ...) %>% collect() %>% pull(contents) %>% strsplit("\n") %>% unlist() %>% sort()
} else {
res_1 <- reader(sc, name, path[1L], ...) %>% collect() %>% pull(line) %>% sort()
res_2 <- reader(sc, name, path[2L], ...) %>% collect() %>% pull(line) %>% sort()
res_3 <- reader(sc, name, path, ...) %>% collect() %>% pull(line) %>% sort()
res_4 <- reader(sc, name, path_glob, ...) %>% collect() %>% pull(line) %>% sort()
}

list(
all(res_1 == as.character(1:3)),
all(res_2 == as.character(4:6)),
all(res_3 == as.character(1:6)),
all(res_4 == as.character(1:6))
)
}

test_that(
"spark_read_parquet() reads multiple parquet files",
expect_equal(
test_readwrite(sc = sc, writer = spark_write_parquet, reader = spark_read_parquet),
list(TRUE, TRUE, TRUE, TRUE)
)
)

test_that(
"spark_read_orc() reads multiple orc files", {
test_requires_version("2.0.0")
expect_equal(
test_readwrite(sc = sc, writer = spark_write_orc, reader = spark_read_orc),
list(TRUE, TRUE, TRUE, TRUE)
)
}
)

test_that(
"spark_read_json() reads multiple json files",
expect_equal(
test_readwrite(sc = sc, writer = spark_write_json, reader = spark_read_json),
list(TRUE, TRUE, TRUE, TRUE)
)
)

test_that(
"spark_read_text() reads multiple text files",
expect_equal(
test_readwrite(sc = sc, writer = spark_write_text, reader = spark_read_text),
list(TRUE, TRUE, TRUE, TRUE)
)
)

test_that(
"spark_read_text() throws a useful error for multiple files with whole=TRUE",
expect_error(
test_readwrite(sc = sc, writer = spark_write_text, reader = spark_read_text, whole = TRUE),
"spark_read_text is only suppored with path of length 1 if whole=TRUE"
)
)