From 1ebfa03a6e6b12f3024d968f93a65ada2740bd01 Mon Sep 17 00:00:00 2001 From: Flavian Desverne Date: Mon, 11 Mar 2024 10:34:38 +0100 Subject: [PATCH] improve virtual selection iterations do avoid allocations --- libs/prisma-value/src/lib.rs | 4 +- .../src/query_builder/read_query_builder.rs | 8 +- .../src/column_metadata.rs | 26 +- .../src/database/operations/coerce.rs | 94 +++--- .../src/database/operations/read.rs | 9 +- .../src/database/operations/update.rs | 2 +- .../src/database/operations/upsert.rs | 2 +- .../src/database/operations/write.rs | 6 +- .../src/nested_aggregations.rs | 6 +- .../src/query_builder/select/lateral.rs | 2 +- .../src/query_builder/select/mod.rs | 16 +- .../sql-query-connector/src/query_ext.rs | 5 +- .../connectors/sql-query-connector/src/row.rs | 33 +- .../interpreter/query_interpreters/read.rs | 28 +- .../src/query_graph_builder/read/utils.rs | 4 +- query-engine/core/src/response_ir/internal.rs | 171 +++------- query-engine/core/src/response_ir/json_ext.rs | 8 +- query-engine/core/src/result_ast/mod.rs | 16 +- .../query-structure/src/field_selection.rs | 299 +++++++++++++++--- query-engine/schema/src/output_types.rs | 8 + 20 files changed, 431 insertions(+), 316 deletions(-) diff --git a/libs/prisma-value/src/lib.rs b/libs/prisma-value/src/lib.rs index bb76fc6c4f83..18bc68796afc 100644 --- a/libs/prisma-value/src/lib.rs +++ b/libs/prisma-value/src/lib.rs @@ -149,7 +149,7 @@ pub fn parse_datetime(datetime: &str) -> chrono::ParseResult f64 { +pub fn stringify_decimal(decimal: &BigDecimal) -> f64 { decimal.to_string().parse::().unwrap() } @@ -262,7 +262,7 @@ fn serialize_decimal(decimal: &BigDecimal, serializer: S) -> Result().unwrap().serialize(serializer) + stringify_decimal(decimal).serialize(serializer) } fn deserialize_decimal<'de, D>(deserializer: D) -> Result diff --git a/query-engine/connectors/mongodb-query-connector/src/query_builder/read_query_builder.rs b/query-engine/connectors/mongodb-query-connector/src/query_builder/read_query_builder.rs index 27185de5c917..fd4900e620a1 100644 --- a/query-engine/connectors/mongodb-query-connector/src/query_builder/read_query_builder.rs +++ b/query-engine/connectors/mongodb-query-connector/src/query_builder/read_query_builder.rs @@ -361,14 +361,14 @@ impl MongoReadQueryBuilder { ) -> crate::Result { for aggr in virtual_selections { let join = match aggr { - VirtualSelection::RelationCount(rf, filter) => { - let filter = filter - .as_ref() + VirtualSelection::RelationCount(x) => { + let filter = x + .filter() .map(|f| MongoFilterVisitor::new(FilterPrefix::default(), false).visit(f.clone())) .transpose()?; JoinStage { - source: rf.clone(), + source: x.field().clone(), alias: Some(aggr.db_alias()), nested: vec![], filter, diff --git a/query-engine/connectors/sql-query-connector/src/column_metadata.rs b/query-engine/connectors/sql-query-connector/src/column_metadata.rs index 83b9882f91de..76cb04b938e8 100644 --- a/query-engine/connectors/sql-query-connector/src/column_metadata.rs +++ b/query-engine/connectors/sql-query-connector/src/column_metadata.rs @@ -1,10 +1,12 @@ -use query_structure::{FieldArity, FieldSelection, RelationSelection, SelectedField, TypeIdentifier}; +use query_structure::{ + FieldArity, FieldSelection, GroupedSelectedField, GroupedVirtualSelection, RelationSelection, TypeIdentifier, +}; -#[derive(Clone, Debug, Copy)] +#[derive(Clone, Debug)] pub enum MetadataFieldKind<'a> { Scalar, Relation(&'a RelationSelection), - Virtual, + Virtual(GroupedVirtualSelection<'a>), } /// Helps dealing with column value conversion and possible error resolution. @@ -47,8 +49,8 @@ impl<'a> ColumnMetadata<'a> { self.arity } - pub(crate) fn kind(&self) -> MetadataFieldKind { - self.kind + pub(crate) fn kind(&self) -> &MetadataFieldKind<'_> { + &self.kind } } @@ -69,7 +71,7 @@ where .collect() } -pub(crate) fn create_from_selection<'a, T>( +pub(crate) fn create_from_selection_for_json<'a, T>( selection: &'a FieldSelection, field_names: &'a [T], ) -> Vec> @@ -77,15 +79,15 @@ where T: AsRef, { selection - .selections() + .grouped_selections() .zip(field_names.iter()) .map(|(field, name)| { - let (type_identifier, arity) = field.type_identifier_with_arity().unwrap(); + let (type_identifier, arity) = field.type_identifier_with_arity_for_json(); + let kind = match field { - SelectedField::Scalar(_) => MetadataFieldKind::Scalar, - SelectedField::Relation(rs) => MetadataFieldKind::Relation(rs), - SelectedField::Virtual(_) => MetadataFieldKind::Virtual, - SelectedField::Composite(_) => unreachable!(), + GroupedSelectedField::Scalar(_) => MetadataFieldKind::Scalar, + GroupedSelectedField::Relation(rs) => MetadataFieldKind::Relation(rs), + GroupedSelectedField::Virtual(vs) => MetadataFieldKind::Virtual(vs), }; ColumnMetadata::new(type_identifier, arity, kind).set_name(name.as_ref()) diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs b/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs index c8ec5bcb6207..a470957f93ee 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs @@ -5,13 +5,6 @@ use std::{io, str::FromStr}; use crate::{query_arguments_ext::QueryArgumentsExt, SqlError}; -#[inline] -fn fields_to_serialize(rs: &RelationSelection) -> impl Iterator { - rs.result_fields - .iter() - .filter_map(|field_name| rs.selections.iter().find(|f| f.prisma_name().as_ref() == field_name)) -} - pub(crate) fn coerce_json_relation_to_pv( mut value: serde_json::Value, rs: &RelationSelection, @@ -48,25 +41,28 @@ fn internal_coerce_json_relation_to_pv(value: &mut serde_json::Value, rs: &Relat serde_json::Value::Object(obj) => { let mut new_obj = serde_json::Map::with_capacity(obj.len()); - for field in fields_to_serialize(rs) { - let (field_name, mut obj_val) = obj.remove_entry(field.prisma_name().as_ref()).unwrap(); - + for field in rs.grouped_fields_to_serialize() { match field { - SelectedField::Scalar(sf) => { + GroupedSelectedField::Scalar(sf) => { + let (field_name, mut obj_val) = obj.remove_entry(sf.name()).unwrap(); + coerce_json_scalar_to_pv(&mut obj_val, sf)?; + + new_obj.insert(field_name, obj_val); } - SelectedField::Relation(nested_rs) => { + GroupedSelectedField::Relation(nested_rs) => { + let (field_name, mut obj_val) = obj.remove_entry(nested_rs.field.name()).unwrap(); + internal_coerce_json_relation_to_pv(&mut obj_val, nested_rs)?; + + new_obj.insert(field_name, obj_val); } - SelectedField::Virtual(_) => { - todo!() - // let coerced_value = coerce_json_virtual_field_to_pv(&key, value)?; - // map.push((key, coerced_value)); + GroupedSelectedField::Virtual(vs) => { + let (field_name, obj_val) = obj.remove_entry(vs.serialized_name().0).unwrap(); + + new_obj.insert(field_name, reorder_virtuals_group(obj_val, &vs)); } - _ => unreachable!(), } - - new_obj.insert(field_name, obj_val); } *obj = new_obj; @@ -89,18 +85,17 @@ fn coerce_json_scalar_to_pv(value: &mut serde_json::Value, sf: &ScalarField) -> *value = serde_json::Value::Array(vec![]); } } - serde_json::Value::Number(n) => match sf.type_identifier() { + serde_json::Value::Number(ref n) => match sf.type_identifier() { TypeIdentifier::Decimal => { - let bd = n - .as_f64() - .and_then(BigDecimal::from_f64) - .map(|bd| bd.normalized()) - .ok_or_else(|| { - build_conversion_error(sf, &format!("Number({n})"), &format!("{:?}", sf.type_identifier())) - })?; + let bd = parse_json_f64(n, sf)?; *value = serde_json::Value::String(bd.normalized().to_string()); } + TypeIdentifier::Float => { + let bd = parse_json_f64(n, sf)?; + + *value = serde_json::Value::Number(Number::from_f64(stringify_decimal(&bd)).unwrap()); + } TypeIdentifier::Boolean => { let err = || build_conversion_error(sf, &format!("Number({n})"), &format!("{:?}", sf.type_identifier())); @@ -183,31 +178,34 @@ fn coerce_json_scalar_to_pv(value: &mut serde_json::Value, sf: &ScalarField) -> Ok(()) } -fn coerce_json_virtual_field_to_pv(key: &str, value: serde_json::Value) -> crate::Result { - match value { - serde_json::Value::Object(obj) => { - let values: crate::Result> = obj - .into_iter() - .map(|(key, value)| coerce_json_virtual_field_to_pv(&key, value).map(|value| (key, value))) - .collect(); - Ok(PrismaValue::Object(values?)) - } +pub fn reorder_virtuals_group(val: serde_json::Value, vs: &GroupedVirtualSelection) -> serde_json::Value { + match val { + serde_json::Value::Object(mut obj) => { + let mut new_obj = serde_json::Map::with_capacity(obj.len()); + + match vs { + GroupedVirtualSelection::RelationCounts(rcs) => { + for rc in rcs { + let (field_name, obj_val) = obj.remove_entry(rc.field().name()).unwrap(); + + new_obj.insert(field_name, obj_val); + } + } + } - serde_json::Value::Number(num) => num - .as_i64() - .ok_or_else(|| { - build_generic_conversion_error(format!( - "Unexpected numeric value {num} for virtual field '{key}': only integers are supported" - )) - }) - .map(PrismaValue::Int), - - _ => Err(build_generic_conversion_error(format!( - "Field '{key}' is not a model field and doesn't have a supported type for a virtual field" - ))), + new_obj.into() + } + _ => val, } } +fn parse_json_f64(n: &Number, sf: &Zipper) -> crate::Result { + n.as_f64() + .and_then(BigDecimal::from_f64) + .map(|bd| bd.normalized()) + .ok_or_else(|| build_conversion_error(sf, &format!("Number({n})"), &format!("{:?}", sf.type_identifier()))) +} + fn build_conversion_error(sf: &ScalarField, from: &str, to: &str) -> SqlError { let container_name = sf.container().name(); let field_name = sf.name(); diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs index b1df8e9be86a..de8043ae82ed 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs @@ -33,8 +33,8 @@ pub(crate) async fn get_single_record_joins( ctx: &Context<'_>, ) -> crate::Result> { let selected_fields = selected_fields.to_virtuals_last(); - let field_names: Vec<_> = selected_fields.prisma_names_grouping_virtuals().collect(); - let meta = column_metadata::create_from_selection(&selected_fields, &field_names); + let field_names: Vec<_> = selected_fields.grouped_prisma_names(); + let meta = column_metadata::create_from_selection_for_json(&selected_fields, &field_names); let query = query_builder::select::SelectBuilder::build( QueryArguments::from((model.clone(), filter.clone())), @@ -117,8 +117,9 @@ pub(crate) async fn get_many_records_joins( ctx: &Context<'_>, ) -> crate::Result { let selected_fields = selected_fields.to_virtuals_last(); - let field_names: Vec<_> = selected_fields.prisma_names_grouping_virtuals().collect(); - let meta = column_metadata::create_from_selection(&selected_fields, &field_names); + let field_names: Vec<_> = selected_fields.grouped_prisma_names(); + let meta = column_metadata::create_from_selection_for_json(&selected_fields, &field_names); + // dbg!(&meta); let mut records = ManyRecords::new(field_names.clone()); diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/update.rs b/query-engine/connectors/sql-query-connector/src/database/operations/update.rs index e80f3e608385..5107d086fc77 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/update.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/update.rs @@ -147,7 +147,7 @@ fn process_result_row( meta: &[ColumnMetadata], selected_fields: &ModelProjection, ) -> crate::Result { - let sql_row = row.to_sql_row(meta, &mut std::time::Duration::ZERO)?; + let sql_row = row.to_sql_row(meta)?; let prisma_row = selected_fields.scalar_fields().zip(sql_row.values).collect_vec(); Ok(SelectionResult::new(prisma_row)) diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/upsert.rs b/query-engine/connectors/sql-query-connector/src/database/operations/upsert.rs index bd546a6b0741..f086e4c60798 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/upsert.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/upsert.rs @@ -33,7 +33,7 @@ pub(crate) async fn native_upsert( let result_set = conn.query(query).await?; let row = result_set.into_single()?; - let record = Record::from(row.to_sql_row(&meta, &mut std::time::Duration::ZERO)?); + let record = Record::from(row.to_sql_row(&meta)?); Ok(SingleRecord { record, field_names }) } diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index fc1d62732978..dd27c35fb087 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -167,7 +167,7 @@ pub(crate) async fn create_record( let field_names: Vec<_> = selected_fields.db_names().collect(); let idents = ModelProjection::from(&selected_fields).type_identifiers_with_arities(); let meta = column_metadata::create(&field_names, &idents); - let sql_row = row.to_sql_row(&meta, &mut std::time::Duration::ZERO)?; + let sql_row = row.to_sql_row(&meta)?; let record = Record::from(sql_row); Ok(SingleRecord { record, field_names }) @@ -273,7 +273,7 @@ pub(crate) async fn create_records_returning( for insert in inserts { let result_set = conn.query(insert.into()).await?; for result_row in result_set { - let sql_row = result_row.to_sql_row(&meta, &mut std::time::Duration::ZERO)?; + let sql_row = result_row.to_sql_row(&meta)?; let record = Record::from(sql_row); records.push(record); } @@ -448,7 +448,7 @@ pub(crate) async fn delete_record( let field_db_names: Vec<_> = selected_fields.db_names().collect(); let types_and_arities = selected_fields.type_identifiers_with_arities(); let meta = column_metadata::create(&field_db_names, &types_and_arities); - let sql_row = result_row.to_sql_row(&meta, &mut std::time::Duration::ZERO)?; + let sql_row = result_row.to_sql_row(&meta)?; let record = Record::from(sql_row); Ok(SingleRecord { diff --git a/query-engine/connectors/sql-query-connector/src/nested_aggregations.rs b/query-engine/connectors/sql-query-connector/src/nested_aggregations.rs index 9a8312153e1c..68bd996564d3 100644 --- a/query-engine/connectors/sql-query-connector/src/nested_aggregations.rs +++ b/query-engine/connectors/sql-query-connector/src/nested_aggregations.rs @@ -22,13 +22,13 @@ pub(crate) fn build<'a>( for (index, selection) in virtual_selections.into_iter().enumerate() { match selection { - VirtualSelection::RelationCount(rf, filter) => { + VirtualSelection::RelationCount(rc) => { let join_alias = format!("aggr_selection_{index}"); let aggregator_alias = selection.db_alias(); let join = compute_aggr_join( - rf, + rc.field(), AggregationType::Count, - filter.clone(), + rc.filter().cloned(), aggregator_alias.as_str(), join_alias.as_str(), None, diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/select/lateral.rs b/query-engine/connectors/sql-query-connector/src/query_builder/select/lateral.rs index 2098cd016691..e2e158a5b992 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/select/lateral.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/select/lateral.rs @@ -21,7 +21,7 @@ enum VirtualSelectionKey { impl From<&VirtualSelection> for VirtualSelectionKey { fn from(vs: &VirtualSelection) -> Self { match vs { - VirtualSelection::RelationCount(rf, _) => Self::RelationCount(rf.clone()), + VirtualSelection::RelationCount(rc) => Self::RelationCount(rc.field().clone()), } } } diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs b/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs index 5aec46423b3b..787ddaa10ced 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs @@ -295,11 +295,11 @@ pub(crate) trait JoinSelectBuilder { ctx: &Context<'_>, ) -> Select<'static> { match vs { - VirtualSelection::RelationCount(rf, filter) => { - if rf.relation().is_many_to_many() { - self.build_relation_count_query_m2m(vs.db_alias(), rf, filter, parent_alias, ctx) + VirtualSelection::RelationCount(rc) => { + if rc.field().relation().is_many_to_many() { + self.build_relation_count_query_m2m(vs.db_alias(), rc.field(), rc.filter(), parent_alias, ctx) } else { - self.build_relation_count_query(vs.db_alias(), rf, filter, parent_alias, ctx) + self.build_relation_count_query(vs.db_alias(), rc.field(), rc.filter(), parent_alias, ctx) } } } @@ -333,7 +333,7 @@ pub(crate) trait JoinSelectBuilder { &mut self, selection_name: impl Into>, rf: &RelationField, - filter: &Option, + filter: Option<&Filter>, parent_alias: Alias, ctx: &Context<'_>, ) -> Select<'a> { @@ -347,7 +347,7 @@ pub(crate) trait JoinSelectBuilder { let select = Select::from_table(related_table) .value(count(asterisk()).alias(selection_name)) .with_join_conditions(rf, parent_alias, related_table_alias, ctx) - .with_filters(filter.clone(), Some(related_table_alias), ctx); + .with_filters(filter.cloned(), Some(related_table_alias), ctx); select } @@ -356,7 +356,7 @@ pub(crate) trait JoinSelectBuilder { &mut self, selection_name: impl Into>, rf: &RelationField, - filter: &Option, + filter: Option<&Filter>, parent_alias: Alias, ctx: &Context<'_>, ) -> Select<'a> { @@ -395,7 +395,7 @@ pub(crate) trait JoinSelectBuilder { .value(count(asterisk()).alias(selection_name)) .left_join(m2m_join_data) .and_where(aggregation_join_conditions) - .with_filters(filter.clone(), Some(related_table_alias), ctx); + .with_filters(filter.cloned(), Some(related_table_alias), ctx); select } diff --git a/query-engine/connectors/sql-query-connector/src/query_ext.rs b/query-engine/connectors/sql-query-connector/src/query_ext.rs index 43e238a3c7a3..587fd1c6309c 100644 --- a/query-engine/connectors/sql-query-connector/src/query_ext.rs +++ b/query-engine/connectors/sql-query-connector/src/query_ext.rs @@ -41,14 +41,11 @@ impl QueryExt for Q { let mut sql_rows = Vec::with_capacity(result_set.len()); - let mut dur = std::time::Duration::ZERO; - let now = std::time::Instant::now(); for row in result_set { - sql_rows.push(row.to_sql_row(idents, &mut dur)?); + sql_rows.push(row.to_sql_row(idents)?); } - println!("coerce_json_relation: {:.2?}", dur); println!("to_row: {:.2?}", now.elapsed()); Ok(sql_rows) diff --git a/query-engine/connectors/sql-query-connector/src/row.rs b/query-engine/connectors/sql-query-connector/src/row.rs index 61270dad4495..397d8a897cc0 100644 --- a/query-engine/connectors/sql-query-connector/src/row.rs +++ b/query-engine/connectors/sql-query-connector/src/row.rs @@ -1,6 +1,9 @@ use crate::{ - column_metadata::ColumnMetadata, database::operations::coerce::coerce_json_relation_to_pv, error::SqlError, - value::to_prisma_value, MetadataFieldKind, + column_metadata::ColumnMetadata, + database::operations::coerce::{coerce_json_relation_to_pv, reorder_virtuals_group}, + error::SqlError, + value::to_prisma_value, + MetadataFieldKind, }; use bigdecimal::{BigDecimal, FromPrimitive, ToPrimitive}; use chrono::{DateTime, NaiveDate, Utc}; @@ -83,11 +86,11 @@ pub(crate) trait ToSqlRow { /// Conversion from a database specific row to an allocated `SqlRow`. To /// help deciding the right types, the provided `ColumnMetadata`s should map /// to the returned columns in the right order. - fn to_sql_row(self, meta: &[ColumnMetadata<'_>], dur: &mut std::time::Duration) -> crate::Result; + fn to_sql_row(self, meta: &[ColumnMetadata<'_>]) -> crate::Result; } impl ToSqlRow for ResultRow { - fn to_sql_row(self, meta: &[ColumnMetadata<'_>], dur: &mut std::time::Duration) -> crate::Result { + fn to_sql_row(self, meta: &[ColumnMetadata<'_>]) -> crate::Result { let mut row = SqlRow::default(); let row_width = meta.len(); @@ -100,7 +103,7 @@ impl ToSqlRow for ResultRow { ValueType::Array(None) => Ok(PrismaValue::List(Vec::new())), ValueType::Array(Some(l)) => l .into_iter() - .map(|val| row_value_to_prisma_value(val, &meta[i], dur)) + .map(|val| row_value_to_prisma_value(val, &meta[i])) .collect::>>() .map(PrismaValue::List), _ => { @@ -111,7 +114,7 @@ impl ToSqlRow for ResultRow { return Err(SqlError::ConversionError(error.into())); } }, - _ => row_value_to_prisma_value(p_value, &meta[i], dur), + _ => row_value_to_prisma_value(p_value, &meta[i]), }?; row.values.push(pv); @@ -121,11 +124,7 @@ impl ToSqlRow for ResultRow { } } -fn row_value_to_prisma_value( - p_value: Value, - meta: &ColumnMetadata<'_>, - dur: &mut std::time::Duration, -) -> Result { +fn row_value_to_prisma_value(p_value: Value, meta: &ColumnMetadata<'_>) -> Result { let create_error = |value: &Value| { let message = match meta.name() { Some(name) => { @@ -172,18 +171,16 @@ fn row_value_to_prisma_value( MetadataFieldKind::Relation(rs) => match p_value.typed { value if value.is_null() => PrismaValue::Null, ValueType::Json(Some(json)) => { - let now = std::time::Instant::now(); - let json = coerce_json_relation_to_pv(json, rs)?; - let res = PrismaValue::new_json(json); - - *dur += now.elapsed(); - res + PrismaValue::new_json(json) } _ => return Err(create_error(&p_value)), }, - MetadataFieldKind::Virtual => todo!(), + MetadataFieldKind::Virtual(vs) => match p_value.typed { + ValueType::Json(Some(json)) => PrismaValue::new_json(reorder_virtuals_group(json, vs)), + _ => return Err(create_error(&p_value)), + }, }, TypeIdentifier::UUID => match p_value.typed { value if value.is_null() => PrismaValue::Null, diff --git a/query-engine/core/src/interpreter/query_interpreters/read.rs b/query-engine/core/src/interpreter/query_interpreters/read.rs index 89747d33dbe3..fce11318ff13 100644 --- a/query-engine/core/src/interpreter/query_interpreters/read.rs +++ b/query-engine/core/src/interpreter/query_interpreters/read.rs @@ -2,7 +2,7 @@ use super::{inmemory_record_processor::InMemoryRecordProcessor, *}; use crate::{interpreter::InterpretationResult, query_ast::*, result_ast::*}; use connector::{error::ConnectorError, ConnectionLike}; use futures::future::{BoxFuture, FutureExt}; -use query_structure::{ManyRecords, RelationLoadStrategy, RelationSelection}; +use query_structure::{ManyRecords, RelationLoadStrategy}; use user_facing_errors::KnownError; pub(crate) fn execute<'conn>( @@ -62,11 +62,9 @@ fn read_one( Ok(RecordSelectionWithRelations { name: query.name, - model, - fields: query.selection_order, - virtuals: query.selected_fields.virtuals_owned(), + fields: query.selected_fields.into_virtuals_last(), + selection_order: query.selection_order, records, - nested: build_relation_record_selection(query.selected_fields.relations()), } .into()) } @@ -172,11 +170,9 @@ fn read_many_by_joins( } else { Ok(RecordSelectionWithRelations { name: query.name, - fields: query.selection_order, - virtuals: query.selected_fields.virtuals_owned(), + fields: query.selected_fields.into_virtuals_last(), + selection_order: query.selection_order, records: result, - nested: build_relation_record_selection(query.selected_fields.relations()), - model: query.model, } .into()) } @@ -185,20 +181,6 @@ fn read_many_by_joins( fut.boxed() } -fn build_relation_record_selection<'a>( - selections: impl Iterator, -) -> Vec { - selections - .map(|rq| RelationRecordSelection { - name: rq.field.name().to_owned(), - fields: rq.result_fields.clone(), - virtuals: rq.virtuals().cloned().collect(), - model: rq.field.related_model(), - nested: build_relation_record_selection(rq.relations()), - }) - .collect() -} - /// Queries related records for a set of parent IDs. fn read_related<'conn>( tx: &'conn mut dyn ConnectionLike, diff --git a/query-engine/core/src/query_graph_builder/read/utils.rs b/query-engine/core/src/query_graph_builder/read/utils.rs index 1e1a035934bb..57afb8a98ca8 100644 --- a/query-engine/core/src/query_graph_builder/read/utils.rs +++ b/query-engine/core/src/query_graph_builder/read/utils.rs @@ -184,7 +184,9 @@ fn extract_relation_count_selections( .map(|where_arg| extract_filter(where_arg.value.try_into()?, rf.related_model())) .transpose()?; - Ok(SelectedField::Virtual(VirtualSelection::RelationCount(rf, filter))) + Ok(SelectedField::Virtual(VirtualSelection::RelationCount( + RelationCountSelection::new(rf, filter), + ))) }) .collect() } diff --git a/query-engine/core/src/response_ir/internal.rs b/query-engine/core/src/response_ir/internal.rs index ecf181c1e844..3eeb635319d6 100644 --- a/query-engine/core/src/response_ir/internal.rs +++ b/query-engine/core/src/response_ir/internal.rs @@ -2,16 +2,15 @@ use super::*; use self::json_ext::{JsonValue, JsonValueExt}; use crate::{ - constants::custom_types, - protocol::EngineProtocol, - result_ast::{RecordSelectionWithRelations, RelationRecordSelection}, - CoreError, QueryResult, RecordAggregations, RecordSelection, + constants::custom_types, protocol::EngineProtocol, result_ast::RecordSelectionWithRelations, CoreError, + QueryResult, RecordAggregations, RecordSelection, }; use connector::AggregationResult; use indexmap::IndexMap; use query_structure::{ - CompositeFieldRef, Field, Model, PrismaValue, SelectionResult, TypeIdentifier, VirtualSelection, + CompositeFieldRef, Field, GroupedSelectedField, PrismaValue, RelationSelection, SelectionResult, TypeIdentifier, + VirtualSelection, }; use schema::{ constants::{aggregations::*, output_fields::*}, @@ -49,6 +48,7 @@ pub(crate) fn serialize_internal( is_list: bool, query_schema: &QuerySchema, ) -> crate::Result { + // dbg!(&result); match result { QueryResult::RecordSelection(Some(rs)) => { serialize_record_selection(*rs, field, field.field_type(), is_list, query_schema) @@ -312,31 +312,13 @@ fn finalize_objects( } } -enum SerializedFieldWithRelations<'a, 'b> { - Model(Field, &'a OutputField<'b>), - VirtualsGroup(&'a str, Vec<&'a VirtualSelection>), -} - -impl<'a, 'b> SerializedFieldWithRelations<'a, 'b> { - fn name(&self) -> &str { - match self { - Self::Model(f, _) => f.name(), - Self::VirtualsGroup(name, _) => name, - } - } -} - -// TODO: Handle errors properly fn serialize_objects_with_relation( result: RecordSelectionWithRelations, typ: &ObjectType<'_>, ) -> crate::Result { let mut object_mapping = UncheckedItemsWithParents::with_capacity(result.records.records.len()); - let fields = - collect_serialized_fields_with_relations(typ, &result.model, &result.virtuals, &result.records.field_names); - - let selected_db_field_names: HashSet = result.fields.clone().into_iter().collect(); + let result_fields: HashSet = result.selection_order.clone().into_iter().collect(); for record in result.records.records.into_iter() { if !object_mapping.contains_key(&record.parent_id) { @@ -346,161 +328,88 @@ fn serialize_objects_with_relation( let values = record.values; let mut object = HashMap::with_capacity(values.len()); - for (val, field) in values.into_iter().zip(fields.iter()) { - // Skip fields that aren't part of the selection set - if !selected_db_field_names.contains(field.name()) { + for (val, field) in values.into_iter().zip(result.fields.grouped_selections()) { + if !result_fields.contains(field.prisma_name()) { continue; } match field { - SerializedFieldWithRelations::Model(Field::Scalar(_), out_field) - if !out_field.field_type().is_object() => - { - object.insert(field.name().to_owned(), serialize_scalar(out_field, val)?); + GroupedSelectedField::Scalar(sf) => { + let out_field = typ.find_field(sf.name()).unwrap(); + + object.insert(sf.name().to_owned(), serialize_scalar(out_field, val)?); } - SerializedFieldWithRelations::Model(Field::Relation(_), out_field) - if out_field.field_type().is_list() => - { - let inner_typ = out_field.field_type.as_object_type().unwrap(); - let rrs = result.nested.iter().find(|rrs| rrs.name == field.name()).unwrap(); + GroupedSelectedField::Relation(rs) if rs.field.is_list() => { let mut values = val.into_json().unwrap().try_into_value().unwrap().into_list().unwrap(); for val in values.iter_mut() { - discriminate_relation_selection(rrs, val, inner_typ)?; + serialize_relation_selection(rs, val)?; } - object.insert(field.name().to_owned(), Item::Json(serde_json::Value::Array(values))); + object.insert(rs.field.name().to_owned(), Item::Json(serde_json::Value::Array(values))); } - SerializedFieldWithRelations::Model(Field::Relation(_), out_field) => { - let inner_typ = out_field.field_type.as_object_type().unwrap(); - let rrs = result.nested.iter().find(|rrs| rrs.name == field.name()).unwrap(); - + GroupedSelectedField::Relation(rs) => { if val.is_null() { - object.insert(field.name().to_owned(), Item::Json(serde_json::Value::Null)); + object.insert(rs.field.name().to_owned(), Item::Json(serde_json::Value::Null)); } else { let mut val = val.into_json().unwrap().try_into_value().unwrap(); - discriminate_relation_selection(rrs, &mut val, inner_typ)?; + serialize_relation_selection(rs, &mut val)?; - object.insert(field.name().to_owned(), Item::Json(val)); + object.insert(rs.field.name().to_owned(), Item::Json(val)); } } - SerializedFieldWithRelations::VirtualsGroup(group_name, virtuals) => { - object.insert(group_name.to_string(), serialize_virtuals_group(val, virtuals)?); + GroupedSelectedField::Virtual(out_field) => { + object.insert(out_field.serialized_name().0.to_owned(), Item::Value(val)); } - - _ => panic!("unexpected field"), } } - let map = reorder_object_with_selection_order(&result.fields, object); + // TODO: Remove this once we don't mess up virtuals ordering in the query builder. + let map = reorder_object_with_selection_order(&result.selection_order, object); - let result = Item::Map(map); - - object_mapping.get_mut(&record.parent_id).unwrap().push(result); + object_mapping.get_mut(&record.parent_id).unwrap().push(Item::Map(map)); } Ok(object_mapping) } -fn discriminate_relation_selection( - rrs: &RelationRecordSelection, - value: &mut serde_json::Value, - typ: &ObjectType<'_>, -) -> crate::Result<()> { - let fields = collect_serialized_fields_with_relations(typ, &rrs.model, &rrs.virtuals, &rrs.fields); +fn serialize_relation_selection(rs: &RelationSelection, value: &mut serde_json::Value) -> crate::Result<()> { + if value.is_null() { + return Ok(()); + } - let value_obj = value.into_object_mut().unwrap(); + let value_obj = value.as_object_mut().unwrap(); - for field in fields { - let value = value_obj.get_mut(field.name()).unwrap(); + for field in rs.grouped_fields_to_serialize() { + let value = value_obj.get_mut(field.prisma_name()).unwrap(); match field { - SerializedFieldWithRelations::Model(Field::Scalar(sf), out_field) - if !out_field.field_type().is_object() => - { + GroupedSelectedField::Scalar(sf) => { *value = discriminate_json_value(value.clone(), sf.type_identifier()); } - SerializedFieldWithRelations::Model(Field::Relation(_), out_field) if out_field.field_type().is_list() => { - let inner_typ = out_field.field_type.as_object_type().unwrap(); - let inner_rrs = rrs.nested.iter().find(|rrs| rrs.name == field.name()).unwrap(); - let values = value.into_list_mut().unwrap(); + GroupedSelectedField::Relation(rs) if rs.field.is_list() => { + let values = value.as_list_mut().unwrap(); for value in values { - discriminate_relation_selection(inner_rrs, value, inner_typ)?; + serialize_relation_selection(rs, value)?; } } - SerializedFieldWithRelations::Model(Field::Relation(_), out_field) => { - let inner_typ = out_field.field_type.as_object_type().unwrap(); - let inner_rrs = rrs.nested.iter().find(|rrs| rrs.name == field.name()).unwrap(); - - discriminate_relation_selection(inner_rrs, value, inner_typ)? - } - - SerializedFieldWithRelations::VirtualsGroup(_group_name, _virtuals) => { - todo!() - // map.insert(group_name.to_string(), serialize_virtuals_group(value, &virtuals)?); - } + GroupedSelectedField::Relation(rs) => serialize_relation_selection(rs, value)?, - _ => (), + // Nested virtual fields are already handled by the connector. + GroupedSelectedField::Virtual(_) => (), } } Ok(()) } -fn collect_serialized_fields_with_relations<'a, 'b>( - object_type: &'a ObjectType<'b>, - model: &Model, - virtuals: &'a [VirtualSelection], - db_field_names: &'a [String], -) -> Vec> { - db_field_names - .iter() - .map(|name| { - model - .fields() - .all() - .find(|field| field.name() == name) - .and_then(|field| { - object_type - .find_field(field.name()) - .map(|out_field| SerializedFieldWithRelations::Model(field, out_field)) - }) - .unwrap_or_else(|| { - let matching_virtuals = virtuals.iter().filter(|vs| vs.serialized_name().0 == name).collect(); - SerializedFieldWithRelations::VirtualsGroup(name.as_str(), matching_virtuals) - }) - }) - .collect() -} - -fn serialize_virtuals_group(obj_value: PrismaValue, virtuals: &[&VirtualSelection]) -> crate::Result { - let mut db_object: HashMap = HashMap::from_iter(obj_value.into_object().unwrap()); - let mut out_object = Map::new(); - - // We have to reorder the object fields according to selection even if the query - // builder respects the initial order because JSONB does not preserve order. - for vs in virtuals { - let (group_name, nested_name) = vs.serialized_name(); - - let value = db_object.remove(nested_name).ok_or_else(|| { - CoreError::SerializationError(format!( - "Expected virtual field {nested_name} not found in {group_name} object" - )) - })?; - - out_object.insert(nested_name.into(), Item::Value(vs.coerce_value(value)?)); - } - - Ok(Item::Map(out_object)) -} - enum SerializedField<'a, 'b> { Model(Field, &'a OutputField<'b>), Virtual(&'a VirtualSelection), @@ -865,12 +774,12 @@ fn convert_prisma_value_json_protocol( Ok(item_value) } -fn discriminate_json_value(value: JsonValue, st: TypeIdentifier) -> JsonValue { +fn discriminate_json_value(value: JsonValue, typ: TypeIdentifier) -> JsonValue { if crate::executor::get_engine_protocol().is_graphql() { return value; } - match (st, value) { + match (typ, value) { (TypeIdentifier::Json, x) => custom_types::make_json_object(custom_types::JSON, x), (TypeIdentifier::DateTime, x) => custom_types::make_json_object(custom_types::DATETIME, x), (TypeIdentifier::Decimal, x) => custom_types::make_json_object(custom_types::DECIMAL, x), diff --git a/query-engine/core/src/response_ir/json_ext.rs b/query-engine/core/src/response_ir/json_ext.rs index 165cea3101ae..c19b3de5a895 100644 --- a/query-engine/core/src/response_ir/json_ext.rs +++ b/query-engine/core/src/response_ir/json_ext.rs @@ -3,10 +3,10 @@ pub(crate) type JsonValue = serde_json::Value; pub(crate) trait JsonValueExt { fn into_object(self) -> Option; - fn into_object_mut(&mut self) -> Option<&mut JsonObject>; + fn as_object_mut(&mut self) -> Option<&mut JsonObject>; fn into_list(self) -> Option>; - fn into_list_mut(&mut self) -> Option<&mut Vec>; + fn as_list_mut(&mut self) -> Option<&mut Vec>; } impl JsonValueExt for JsonValue { @@ -17,7 +17,7 @@ impl JsonValueExt for JsonValue { } } - fn into_object_mut(&mut self) -> Option<&mut JsonObject> { + fn as_object_mut(&mut self) -> Option<&mut JsonObject> { match self { JsonValue::Object(obj) => Some(obj), _ => None, @@ -31,7 +31,7 @@ impl JsonValueExt for JsonValue { } } - fn into_list_mut(&mut self) -> Option<&mut Vec> { + fn as_list_mut(&mut self) -> Option<&mut Vec> { match self { JsonValue::Array(arr) => Some(arr), _ => None, diff --git a/query-engine/core/src/result_ast/mod.rs b/query-engine/core/src/result_ast/mod.rs index e450b7213774..897782b8ab17 100644 --- a/query-engine/core/src/result_ast/mod.rs +++ b/query-engine/core/src/result_ast/mod.rs @@ -1,5 +1,5 @@ use connector::AggregationRow; -use query_structure::{ManyRecords, Model, SelectionResult, VirtualSelection}; +use query_structure::{FieldSelection, ManyRecords, Model, SelectionResult, VirtualSelection}; #[derive(Debug, Clone)] pub(crate) enum QueryResult { @@ -17,21 +17,13 @@ pub struct RecordSelectionWithRelations { /// Name of the query. pub(crate) name: String, - /// Holds an ordered list of selected field names for each contained record. - pub(crate) fields: Vec, + pub(crate) selection_order: Vec, - /// Holds the list of virtual selections included in the query result. - /// TODO: in the future it should be covered by [`RecordSelection::fields`] by storing ordered - /// `Vec` or `FieldSelection` instead of `Vec`. - pub(crate) virtuals: Vec, + /// Holds an ordered list of selected field names for each contained record. + pub(crate) fields: FieldSelection, /// Selection results pub(crate) records: ManyRecords, - - pub(crate) nested: Vec, - - /// The model of the contained records. - pub(crate) model: Model, } impl From for QueryResult { diff --git a/query-engine/query-structure/src/field_selection.rs b/query-engine/query-structure/src/field_selection.rs index 1edc73accc3e..012ed7536b76 100644 --- a/query-engine/query-structure/src/field_selection.rs +++ b/query-engine/query-structure/src/field_selection.rs @@ -43,6 +43,14 @@ impl FieldSelection { self.selections.iter() } + pub fn grouped_selections(&self) -> impl Iterator> { + self.selections.iter().to_grouped_virtuals() + } + + pub fn grouped_prisma_names(&self) -> Vec { + self.grouped_selections().map(|f| f.prisma_name().to_owned()).collect() + } + pub fn scalars(&self) -> impl Iterator + '_ { self.selections().filter_map(SelectedField::as_scalar) } @@ -51,6 +59,17 @@ impl FieldSelection { self.selections().filter_map(SelectedField::as_virtual) } + pub fn grouped_virtuals(&self) -> impl Iterator> { + self.grouped_selections().filter_map(GroupedSelectedField::into_virtual) + } + + pub fn extract_from_grouped<'a>(&'a self, field_names: &[String]) -> Vec> { + field_names + .iter() + .filter_map(move |name| self.grouped_selections().find(|field| field.prisma_name() == name)) + .collect() + } + pub fn virtuals_owned(&self) -> Vec { self.virtuals().cloned().collect() } @@ -212,17 +231,7 @@ impl FieldSelection { /// Returns type identifiers and arities, treating all virtual fields as separate fields. pub fn type_identifiers_with_arities(&self) -> Vec<(TypeIdentifier, FieldArity)> { self.selections() - .filter_map(SelectedField::type_identifier_with_arity) - .collect() - } - - /// Returns type identifiers and arities, grouping the virtual fields so that the type - /// identifier and arity is returned for the whole object containing multiple virtual fields - /// and not each of those fields separately. This represents the selection in joined queries - /// that use JSON objects for relations and relation aggregations. - pub fn type_identifiers_with_arities_grouping_virtuals(&self) -> Vec<(TypeIdentifier, FieldArity)> { - self.selections_with_virtual_group_heads() - .filter_map(|vs| vs.type_identifier_with_arity_grouping_virtuals()) + .map(SelectedField::type_identifier_with_arity) .collect() } @@ -258,7 +267,7 @@ pub enum SelectedField { Virtual(VirtualSelection), } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct RelationSelection { pub field: RelationField, pub args: QueryArguments, @@ -268,6 +277,89 @@ pub struct RelationSelection { pub selections: Vec, } +impl RelationSelection { + pub fn new( + field: RelationField, + args: QueryArguments, + result_fields: Vec, + selections: Vec, + ) -> Self { + Self { + field, + args, + result_fields, + selections, + } + } + + pub fn type_identifier_with_arity(&self) -> (TypeIdentifier, FieldArity) { + if self.field.is_list() { + (TypeIdentifier::Json, FieldArity::Required) + } else { + (TypeIdentifier::Json, self.field.arity()) + } + } +} + +impl std::fmt::Debug for RelationSelection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RelationSelection") + .field("field", &self.field) + .field("selections", &self.selections) + .finish() + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum GroupedSelectedField<'a> { + Scalar(&'a ScalarFieldRef), + Relation(&'a RelationSelection), + Virtual(GroupedVirtualSelection<'a>), +} + +impl<'a> GroupedSelectedField<'a> { + pub fn prisma_name(&self) -> &'a str { + match self { + GroupedSelectedField::Scalar(sf) => sf.name(), + GroupedSelectedField::Relation(rf) => rf.field.name(), + GroupedSelectedField::Virtual(vs) => vs.serialized_name().0, + } + } + + pub fn type_identifier_with_arity_for_json(&self) -> (TypeIdentifier, FieldArity) { + match self { + GroupedSelectedField::Scalar(sf) => sf.type_identifier_with_arity(), + GroupedSelectedField::Relation(rs) => rs.type_identifier_with_arity(), + GroupedSelectedField::Virtual(vs) => vs.type_identifier_with_arity_for_json(), + } + } + + pub fn into_virtual(self) -> Option> { + if let Self::Virtual(v) = self { + Some(v) + } else { + None + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum GroupedVirtualSelection<'a> { + RelationCounts(Vec<&'a RelationCountSelection>), +} + +impl<'a> GroupedVirtualSelection<'a> { + pub fn serialized_name(&self) -> (&'static str, &str) { + match self { + Self::RelationCounts(x) => x[0].serialized_name(), + } + } + + pub fn type_identifier_with_arity_for_json(&self) -> (TypeIdentifier, FieldArity) { + (TypeIdentifier::Json, FieldArity::Required) + } +} + impl RelationSelection { pub fn scalars(&self) -> impl Iterator { self.selections.iter().filter_map(|selection| match selection { @@ -287,6 +379,22 @@ impl RelationSelection { self.selections.iter().filter_map(SelectedField::as_virtual) } + pub fn virtuals_grouped(&self) -> impl Iterator> { + self.selections + .iter() + .to_grouped_virtuals() + .filter_map(GroupedSelectedField::into_virtual) + } + + pub fn grouped_selections(&self) -> impl Iterator> { + self.selections.iter().to_grouped_virtuals() + } + + pub fn grouped_fields_to_serialize(&self) -> impl Iterator> { + self.grouped_selections() + .filter(|field| self.result_fields.iter().any(|name| name == field.prisma_name())) + } + pub fn related_model(&self) -> Model { self.field.related_model() } @@ -294,13 +402,53 @@ impl RelationSelection { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum VirtualSelection { - RelationCount(RelationFieldRef, Option), + RelationCount(RelationCountSelection), +} + +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct RelationCountSelection { + field: RelationFieldRef, + filter: Option, +} + +impl std::fmt::Debug for RelationCountSelection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RelationCountSelection") + .field("field", &self.field) + .finish() + } +} + +impl RelationCountSelection { + pub fn new(field: RelationFieldRef, filter: Option) -> Self { + Self { field, filter } + } + + pub fn db_alias(&self) -> String { + format!("_aggr_count_{}", self.field.name()) + } + + pub fn serialized_name(&self) -> (&'static str, &str) { + ("_count", self.field.name()) + } + + pub fn model(&self) -> Model { + self.field.model() + } + + pub fn field(&self) -> &RelationFieldRef { + &self.field + } + + pub fn filter(&self) -> Option<&Filter> { + self.filter.as_ref() + } } impl VirtualSelection { pub fn db_alias(&self) -> String { match self { - Self::RelationCount(rf, _) => format!("_aggr_count_{}", rf.name()), + Self::RelationCount(x) => x.db_alias(), } } @@ -308,19 +456,19 @@ impl VirtualSelection { match self { // TODO: we can't use UNDERSCORE_COUNT here because it would require a circular // dependency between `schema` and `query-structure` crates. - Self::RelationCount(rf, _) => ("_count", rf.name()), + Self::RelationCount(x) => x.serialized_name(), } } pub fn model(&self) -> Model { match self { - Self::RelationCount(rf, _) => rf.model(), + Self::RelationCount(x) => x.model(), } } pub fn coerce_value(&self, value: PrismaValue) -> crate::Result { match self { - Self::RelationCount(_, _) => match value { + Self::RelationCount(_) => match value { PrismaValue::Null => Ok(PrismaValue::Int(0)), _ => value.coerce(TypeIdentifier::Int), }, @@ -329,25 +477,37 @@ impl VirtualSelection { pub fn field(&self) -> Field { match self { - Self::RelationCount(rf, _) => rf.clone().into(), + Self::RelationCount(x) => x.field.clone().into(), } } pub fn type_identifier_with_arity(&self) -> (TypeIdentifier, FieldArity) { match self { - Self::RelationCount(_, _) => (TypeIdentifier::Int, FieldArity::Required), + Self::RelationCount(_) => (TypeIdentifier::Int, FieldArity::Required), + } + } + + pub fn type_identifier_with_arity_for_json(&self) -> (TypeIdentifier, FieldArity) { + match self { + Self::RelationCount(_) => (TypeIdentifier::Json, FieldArity::Required), } } pub fn relation_field(&self) -> &RelationField { match self { - VirtualSelection::RelationCount(rf, _) => rf, + VirtualSelection::RelationCount(x) => &x.field, } } pub fn filter(&self) -> Option<&Filter> { match self { - VirtualSelection::RelationCount(_, filter) => filter.as_ref(), + VirtualSelection::RelationCount(x) => x.filter.as_ref(), + } + } + + pub fn as_relation_count(&self) -> Option<&RelationCountSelection> { + match self { + Self::RelationCount(v) => Some(v), } } } @@ -399,31 +559,27 @@ impl SelectedField { } } - /// Returns the type identifier and arity of this field, unless it is a composite field, in - /// which case [`None`] is returned. - pub fn type_identifier_with_arity(&self) -> Option<(TypeIdentifier, FieldArity)> { + /// Returns the type identifier and arity of this field. + pub fn type_identifier_with_arity(&self) -> (TypeIdentifier, FieldArity) { match self { - SelectedField::Scalar(sf) => Some(sf.type_identifier_with_arity()), - SelectedField::Relation(rf) if rf.field.is_list() => Some((TypeIdentifier::Json, FieldArity::Required)), - SelectedField::Relation(rf) => Some((TypeIdentifier::Json, rf.field.arity())), - SelectedField::Composite(_) => None, - SelectedField::Virtual(vs) => Some(vs.type_identifier_with_arity()), + SelectedField::Scalar(sf) => sf.type_identifier_with_arity(), + SelectedField::Relation(rs) => rs.type_identifier_with_arity(), + SelectedField::Virtual(vs) => vs.type_identifier_with_arity(), + SelectedField::Composite(_) => unreachable!(), } } - /// Returns the type identifier and arity of this field, unless it is a composite field, in - /// which case [`None`] is returned. + /// Returns the type identifier and arity of this field when it is queries as JSON object. /// /// In the case of virtual fields that are wrapped into objects in Prisma queries /// (specifically, relation aggregations), the returned information refers not to the current /// field itself but to the whole object that contains this field. This is used by the queries /// with relation JOINs because they use JSON objects to reprsent both relations and relation - /// aggregations, so individual virtual fields that correspond to those relation aggregations - /// don't exist as separate values in the result of the query. - pub fn type_identifier_with_arity_grouping_virtuals(&self) -> Option<(TypeIdentifier, FieldArity)> { + /// aggregations, so individual virtual fields that correspond to those relation aggreg + pub fn type_identifier_with_arity_for_json(&self) -> (TypeIdentifier, FieldArity) { match self { - SelectedField::Virtual(_) => Some((TypeIdentifier::Json, FieldArity::Required)), - _ => self.type_identifier_with_arity(), + SelectedField::Virtual(vs) => vs.type_identifier_with_arity_for_json(), + x => x.type_identifier_with_arity(), } } @@ -471,6 +627,11 @@ impl SelectedField { pub fn is_scalar(&self) -> bool { matches!(self, Self::Scalar(..)) } + + /// Returns `true` if the selected field is [`Virtual`]. + pub fn is_virtual(&self) -> bool { + matches!(self, Self::Virtual(..)) + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -597,3 +758,69 @@ impl IntoIterator for FieldSelection { self.selections.into_iter() } } + +pub struct GroupedSelectionIterator<'a, I> +where + I: Iterator, +{ + inner: std::iter::Peekable, +} + +impl<'a, I> GroupedSelectionIterator<'a, I> +where + I: Iterator, +{ + fn new(inner: std::iter::Peekable) -> Self { + GroupedSelectionIterator { inner } + } +} + +impl<'a, I> Iterator for GroupedSelectionIterator<'a, I> +where + I: Iterator, +{ + type Item = GroupedSelectedField<'a>; + + fn next(&mut self) -> Option { + match self.inner.next() { + Some(SelectedField::Scalar(scalar_field_ref)) => Some(GroupedSelectedField::Scalar(scalar_field_ref)), + Some(SelectedField::Relation(relation_selection)) => { + Some(GroupedSelectedField::Relation(relation_selection)) + } + Some(SelectedField::Virtual(vs)) => { + let group_iter = self.inner.peeking_take_while(|field| field.is_virtual()); + + let virtual_group = match vs { + VirtualSelection::RelationCount(x) => GroupedVirtualSelection::RelationCounts( + std::iter::once(x) + .chain( + group_iter + .filter_map(|sf| sf.as_virtual().and_then(VirtualSelection::as_relation_count)), + ) + .collect(), + ), + }; + + Some(GroupedSelectedField::Virtual(virtual_group)) + } + _ => None, + } + } +} + +trait ToGrouped<'a, I> +where + I: Iterator, +{ + /// TODO: document fn + fn to_grouped_virtuals(self) -> GroupedSelectionIterator<'a, I>; +} + +impl<'a, I> ToGrouped<'a, I> for I +where + I: Iterator, +{ + fn to_grouped_virtuals(self) -> GroupedSelectionIterator<'a, I> { + GroupedSelectionIterator::new(self.peekable()) + } +} diff --git a/query-engine/schema/src/output_types.rs b/query-engine/schema/src/output_types.rs index 32956d01d50b..19eef1388f79 100644 --- a/query-engine/schema/src/output_types.rs +++ b/query-engine/schema/src/output_types.rs @@ -86,6 +86,14 @@ impl<'a> OutputType<'a> { } } + pub fn as_scalar(&self) -> Option<&ScalarType> { + if let InnerOutputType::Scalar(v) = &self.inner { + Some(v) + } else { + None + } + } + pub fn is_list(&self) -> bool { self.is_list }