Skip to content

Commit

Permalink
Push SessionState into FileFormat (apache#4349)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Dec 21, 2022
1 parent 975ff15 commit 7864b83
Show file tree
Hide file tree
Showing 12 changed files with 234 additions and 175 deletions.
57 changes: 37 additions & 20 deletions datafusion/core/src/datasource/file_format/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use object_store::{GetResult, ObjectMeta, ObjectStore};
use super::FileFormat;
use crate::avro_to_arrow::read_avro_schema_from_reader;
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::file_format::{AvroExec, FileScanConfig};
use crate::physical_plan::ExecutionPlan;
Expand All @@ -47,6 +48,7 @@ impl FileFormat for AvroFormat {

async fn infer_schema(
&self,
_ctx: &SessionState,
store: &Arc<dyn ObjectStore>,
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
Expand All @@ -68,6 +70,7 @@ impl FileFormat for AvroFormat {

async fn infer_stats(
&self,
_ctx: &SessionState,
_store: &Arc<dyn ObjectStore>,
_table_schema: SchemaRef,
_object: &ObjectMeta,
Expand All @@ -77,6 +80,7 @@ impl FileFormat for AvroFormat {

async fn create_physical_plan(
&self,
_ctx: &SessionState,
conf: FileScanConfig,
_filters: &[Expr],
) -> Result<Arc<dyn ExecutionPlan>> {
Expand All @@ -101,10 +105,11 @@ mod tests {
#[tokio::test]
async fn read_small_batches() -> Result<()> {
let config = SessionConfig::new().with_batch_size(2);
let ctx = SessionContext::with_config(config);
let session_ctx = SessionContext::with_config(config);
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = None;
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&ctx, "alltypes_plain.avro", projection, None).await?;
let stream = exec.execute(0, task_ctx)?;

let tt_batches = stream
Expand All @@ -124,9 +129,10 @@ mod tests {
#[tokio::test]
async fn read_limit() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = None;
let exec = get_exec("alltypes_plain.avro", projection, Some(1)).await?;
let exec = get_exec(&ctx, "alltypes_plain.avro", projection, Some(1)).await?;
let batches = collect(exec, task_ctx).await?;
assert_eq!(1, batches.len());
assert_eq!(11, batches[0].num_columns());
Expand All @@ -138,9 +144,10 @@ mod tests {
#[tokio::test]
async fn read_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = None;
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&ctx, "alltypes_plain.avro", projection, None).await?;

let x: Vec<String> = exec
.schema()
Expand Down Expand Up @@ -190,9 +197,10 @@ mod tests {
#[tokio::test]
async fn read_bool_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = Some(vec![1]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&ctx, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -216,9 +224,10 @@ mod tests {
#[tokio::test]
async fn read_i32_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = Some(vec![0]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&ctx, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -239,9 +248,10 @@ mod tests {
#[tokio::test]
async fn read_i96_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = Some(vec![10]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&ctx, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -262,9 +272,10 @@ mod tests {
#[tokio::test]
async fn read_f32_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = Some(vec![6]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&ctx, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -288,9 +299,10 @@ mod tests {
#[tokio::test]
async fn read_f64_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = Some(vec![7]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&ctx, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -314,9 +326,10 @@ mod tests {
#[tokio::test]
async fn read_binary_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = Some(vec![9]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&ctx, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -338,14 +351,15 @@ mod tests {
}

async fn get_exec(
ctx: &SessionState,
file_name: &str,
projection: Option<Vec<usize>>,
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let testdata = crate::test_util::arrow_test_data();
let store_root = format!("{}/avro", testdata);
let format = AvroFormat {};
scan_format(&format, &store_root, file_name, projection, limit).await
scan_format(ctx, &format, &store_root, file_name, projection, limit).await
}
}

Expand All @@ -356,13 +370,16 @@ mod tests {

use super::super::test_util::scan_format;
use crate::error::DataFusionError;
use crate::prelude::SessionContext;

#[tokio::test]
async fn test() -> Result<()> {
let session_ctx = SessionContext::new();
let ctx = session_ctx.state();
let format = AvroFormat {};
let testdata = crate::test_util::arrow_test_data();
let filename = "avro/alltypes_plain.avro";
let result = scan_format(&format, &testdata, filename, None, None).await;
let result = scan_format(&ctx, &format, &testdata, filename, None, None).await;
assert!(matches!(
result,
Err(DataFusionError::NotImplemented(msg))
Expand Down
25 changes: 18 additions & 7 deletions datafusion/core/src/datasource/file_format/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use super::FileFormat;
use crate::datasource::file_format::file_type::FileCompressionType;
use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
use crate::physical_plan::ExecutionPlan;
Expand Down Expand Up @@ -113,6 +114,7 @@ impl FileFormat for CsvFormat {

async fn infer_schema(
&self,
_ctx: &SessionState,
store: &Arc<dyn ObjectStore>,
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
Expand Down Expand Up @@ -150,6 +152,7 @@ impl FileFormat for CsvFormat {

async fn infer_stats(
&self,
_ctx: &SessionState,
_store: &Arc<dyn ObjectStore>,
_table_schema: SchemaRef,
_object: &ObjectMeta,
Expand All @@ -159,6 +162,7 @@ impl FileFormat for CsvFormat {

async fn create_physical_plan(
&self,
_ctx: &SessionState,
conf: FileScanConfig,
_filters: &[Expr],
) -> Result<Arc<dyn ExecutionPlan>> {
Expand All @@ -184,11 +188,12 @@ mod tests {
#[tokio::test]
async fn read_small_batches() -> Result<()> {
let config = SessionConfig::new().with_batch_size(2);
let ctx = SessionContext::with_config(config);
let session_ctx = SessionContext::with_config(config);
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
// skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work)
let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]);
let exec = get_exec("aggregate_test_100.csv", projection, None).await?;
let task_ctx = ctx.task_ctx();
let exec = get_exec(&ctx, "aggregate_test_100.csv", projection, None).await?;
let stream = exec.execute(0, task_ctx)?;

let tt_batches: i32 = stream
Expand All @@ -212,9 +217,10 @@ mod tests {
#[tokio::test]
async fn read_limit() -> Result<()> {
let session_ctx = SessionContext::new();
let ctx = session_ctx.state();
let task_ctx = session_ctx.task_ctx();
let projection = Some(vec![0, 1, 2, 3]);
let exec = get_exec("aggregate_test_100.csv", projection, Some(1)).await?;
let exec = get_exec(&ctx, "aggregate_test_100.csv", projection, Some(1)).await?;
let batches = collect(exec, task_ctx).await?;
assert_eq!(1, batches.len());
assert_eq!(4, batches[0].num_columns());
Expand All @@ -225,8 +231,11 @@ mod tests {

#[tokio::test]
async fn infer_schema() -> Result<()> {
let session_ctx = SessionContext::new();
let ctx = session_ctx.state();

let projection = None;
let exec = get_exec("aggregate_test_100.csv", projection, None).await?;
let exec = get_exec(&ctx, "aggregate_test_100.csv", projection, None).await?;

let x: Vec<String> = exec
.schema()
Expand Down Expand Up @@ -259,9 +268,10 @@ mod tests {
#[tokio::test]
async fn read_char_column() -> Result<()> {
let session_ctx = SessionContext::new();
let ctx = session_ctx.state();
let task_ctx = session_ctx.task_ctx();
let projection = Some(vec![0]);
let exec = get_exec("aggregate_test_100.csv", projection, None).await?;
let exec = get_exec(&ctx, "aggregate_test_100.csv", projection, None).await?;

let batches = collect(exec, task_ctx).await.expect("Collect batches");

Expand All @@ -281,12 +291,13 @@ mod tests {
}

async fn get_exec(
ctx: &SessionState,
file_name: &str,
projection: Option<Vec<usize>>,
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let root = format!("{}/csv", crate::test_util::arrow_test_data());
let format = CsvFormat::default();
scan_format(&format, &root, file_name, projection, limit).await
scan_format(ctx, &format, &root, file_name, projection, limit).await
}
}
32 changes: 22 additions & 10 deletions datafusion/core/src/datasource/file_format/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use super::FileScanConfig;
use crate::datasource::file_format::file_type::FileCompressionType;
use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::file_format::NdJsonExec;
use crate::physical_plan::ExecutionPlan;
Expand Down Expand Up @@ -86,6 +87,7 @@ impl FileFormat for JsonFormat {

async fn infer_schema(
&self,
_ctx: &SessionState,
store: &Arc<dyn ObjectStore>,
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
Expand Down Expand Up @@ -129,6 +131,7 @@ impl FileFormat for JsonFormat {

async fn infer_stats(
&self,
_ctx: &SessionState,
_store: &Arc<dyn ObjectStore>,
_table_schema: SchemaRef,
_object: &ObjectMeta,
Expand All @@ -138,6 +141,7 @@ impl FileFormat for JsonFormat {

async fn create_physical_plan(
&self,
_ctx: &SessionState,
conf: FileScanConfig,
_filters: &[Expr],
) -> Result<Arc<dyn ExecutionPlan>> {
Expand All @@ -161,10 +165,11 @@ mod tests {
#[tokio::test]
async fn read_small_batches() -> Result<()> {
let config = SessionConfig::new().with_batch_size(2);
let ctx = SessionContext::with_config(config);
let projection = None;
let exec = get_exec(projection, None).await?;
let session_ctx = SessionContext::with_config(config);
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = None;
let exec = get_exec(&ctx, projection, None).await?;
let stream = exec.execute(0, task_ctx)?;

let tt_batches: i32 = stream
Expand All @@ -188,9 +193,10 @@ mod tests {
#[tokio::test]
async fn read_limit() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = None;
let exec = get_exec(projection, Some(1)).await?;
let exec = get_exec(&ctx, projection, Some(1)).await?;
let batches = collect(exec, task_ctx).await?;
assert_eq!(1, batches.len());
assert_eq!(4, batches[0].num_columns());
Expand All @@ -202,7 +208,9 @@ mod tests {
#[tokio::test]
async fn infer_schema() -> Result<()> {
let projection = None;
let exec = get_exec(projection, None).await?;
let session_ctx = SessionContext::new();
let ctx = session_ctx.state();
let exec = get_exec(&ctx, projection, None).await?;

let x: Vec<String> = exec
.schema()
Expand All @@ -218,9 +226,10 @@ mod tests {
#[tokio::test]
async fn read_int_column() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let ctx = session_ctx.state();
let task_ctx = ctx.task_ctx();
let projection = Some(vec![0]);
let exec = get_exec(projection, None).await?;
let exec = get_exec(&ctx, projection, None).await?;

let batches = collect(exec, task_ctx).await.expect("Collect batches");

Expand All @@ -243,22 +252,25 @@ mod tests {
}

async fn get_exec(
ctx: &SessionState,
projection: Option<Vec<usize>>,
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let filename = "tests/jsons/2.json";
let format = JsonFormat::default();
scan_format(&format, ".", filename, projection, limit).await
scan_format(ctx, &format, ".", filename, projection, limit).await
}

#[tokio::test]
async fn infer_schema_with_limit() {
let session = SessionContext::new();
let ctx = session.state();
let store = Arc::new(LocalFileSystem::new()) as _;
let filename = "tests/jsons/schema_infer_limit.json";
let format = JsonFormat::default().with_schema_infer_max_rec(Some(3));

let file_schema = format
.infer_schema(&store, &[local_unpartitioned_file(filename)])
.infer_schema(&ctx, &store, &[local_unpartitioned_file(filename)])
.await
.expect("Schema inference");

Expand Down
Loading

0 comments on commit 7864b83

Please sign in to comment.