Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(frontend): refactor extended query mode #8919

Merged
merged 3 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/frontend/src/binder/bind_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::error::{Result, RwError};
use pgwire::types::{Format, FormatIterator};
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::types::ScalarImpl;

use super::statement::RewriteExprsRecursive;
Expand Down Expand Up @@ -85,8 +85,10 @@ impl BoundStatement {
param_formats: Vec<Format>,
) -> Result<BoundStatement> {
let mut rewriter = ParamRewriter {
param_formats: FormatIterator::new(&param_formats, params.len())
.map_err(ErrorCode::BindError)?
.collect(),
params,
param_formats,
error: None,
};

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::catalog::TableId;
use crate::expr::ExprImpl;
use crate::user::UserId;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BoundDelete {
/// Id of the table to perform deleting.
pub table_id: TableId,
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::catalog::TableId;
use crate::expr::{ExprImpl, InputRef};
use crate::user::UserId;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BoundInsert {
/// Id of the table to perform inserting.
pub table_id: TableId,
Expand Down
7 changes: 7 additions & 0 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,13 @@ impl Binder {
Self::new_inner(session, true, vec![])
}

pub fn new_for_stream_with_param_types(
session: &SessionImpl,
param_types: Vec<DataType>,
) -> Binder {
Self::new_inner(session, true, param_types)
}

/// Bind a [`Statement`].
pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
self.bind_statement(stmt)
Expand Down
23 changes: 22 additions & 1 deletion src/frontend/src/binder/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::catalog::Field;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_sqlparser::ast::Statement;

Expand All @@ -20,14 +21,34 @@ use super::update::BoundUpdate;
use crate::binder::{Binder, BoundInsert, BoundQuery};
use crate::expr::ExprRewriter;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum BoundStatement {
Insert(Box<BoundInsert>),
Delete(Box<BoundDelete>),
Update(Box<BoundUpdate>),
Query(Box<BoundQuery>),
}

impl BoundStatement {
pub fn output_fields(&self) -> Vec<Field> {
match self {
BoundStatement::Insert(i) => i.returning_schema.as_ref().map_or(
vec![Field::unnamed(risingwave_common::types::DataType::Int64)],
|s| s.fields().into(),
),
BoundStatement::Delete(d) => d.returning_schema.as_ref().map_or(
vec![Field::unnamed(risingwave_common::types::DataType::Int64)],
|s| s.fields().into(),
),
BoundStatement::Update(u) => u.returning_schema.as_ref().map_or(
vec![Field::unnamed(risingwave_common::types::DataType::Int64)],
|s| s.fields().into(),
),
BoundStatement::Query(q) => q.schema().fields().into(),
}
}
}

impl Binder {
pub(super) fn bind_statement(&mut self, stmt: Statement) -> Result<BoundStatement> {
match stmt {
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::catalog::TableId;
use crate::expr::{Expr as _, ExprImpl};
use crate::user::UserId;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BoundUpdate {
/// Id of the table to perform updating.
pub table_id: TableId,
Expand Down
118 changes: 84 additions & 34 deletions src/frontend/src/handler/extended_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,33 @@ use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::ast::{Query, Statement};

use super::{query, HandlerArgs, RwPgResponse};
use super::{handle, query, HandlerArgs, RwPgResponse};
use crate::binder::BoundStatement;
use crate::session::SessionImpl;

pub struct PrepareStatement {
#[derive(Clone)]
pub enum PrepareStatement {
Prepared(PreparedResult),
PureStatement(Statement),
}
Comment on lines +28 to +31
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify in which circumstances we will use PureStatement and add some comments? IIUC, it seems to be used by Create statements?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refine them in another PR.


#[derive(Clone)]
pub struct PreparedResult {
pub statement: Statement,
pub bound_statement: BoundStatement,
pub param_types: Vec<DataType>,
}

pub struct Portal {
#[derive(Clone)]
pub enum Portal {
Portal(PortalResult),
PureStatement(Statement),
}

#[derive(Clone)]
pub struct PortalResult {
pub statement: Statement,
pub bound_statement: BoundStatement,
pub result_formats: Vec<Format>,
Expand All @@ -44,16 +58,38 @@ pub fn handle_parse(
session.clear_cancel_query_flag();
let str_sql = stmt.to_string();
let handler_args = HandlerArgs::new(session, &stmt, &str_sql)?;
match stmt {
match &stmt {
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_parse(handler_args, stmt, specific_param_types),
_ => Err(ErrorCode::NotSupported(
format!("Can't support {} in extended query mode now", str_sql,),
"".to_string(),
)
.into()),
Statement::CreateView {
query,
..
} => {
if have_parameter_in_query(query) {
return Err(ErrorCode::NotImplemented(
"CREATE VIEW with parameters".to_string(),
None.into(),
)
.into());
}
Ok(PrepareStatement::PureStatement(stmt))
}
Statement::CreateTable {
query,
..
} => {
if let Some(query) = query && have_parameter_in_query(query) {
Err(ErrorCode::NotImplemented(
"CREATE TABLE AS SELECT with parameters".to_string(),
None.into(),
).into())
} else {
Ok(PrepareStatement::PureStatement(stmt))
}
}
_ => Ok(PrepareStatement::PureStatement(stmt)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about Create Sink?

}
}

Expand All @@ -63,32 +99,46 @@ pub fn handle_bind(
param_formats: Vec<Format>,
result_formats: Vec<Format>,
) -> Result<Portal> {
let PrepareStatement {
statement,
bound_statement,
..
} = prepare_statement;
let bound_statement = bound_statement.bind_parameter(params, param_formats)?;
Ok(Portal {
statement,
bound_statement,
result_formats,
})
match prepare_statement {
PrepareStatement::Prepared(prepared_result) => {
let PreparedResult {
statement,
bound_statement,
..
} = prepared_result;
let bound_statement = bound_statement.bind_parameter(params, param_formats)?;
Ok(Portal::Portal(PortalResult {
statement,
bound_statement,
result_formats,
}))
}
PrepareStatement::PureStatement(stmt) => Ok(Portal::PureStatement(stmt)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add some assertion? e.g. params.is_empty()

}
}

pub async fn handle_execute(session: Arc<SessionImpl>, portal: Portal) -> Result<RwPgResponse> {
session.clear_cancel_query_flag();
let str_sql = portal.statement.to_string();
let handler_args = HandlerArgs::new(session, &portal.statement, &str_sql)?;
match &portal.statement {
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_execute(handler_args, portal).await,
_ => Err(ErrorCode::NotSupported(
format!("Can't support {} in extended query mode now", str_sql,),
"".to_string(),
)
.into()),
match portal {
Portal::Portal(portal) => {
session.clear_cancel_query_flag();
let str_sql = portal.statement.to_string();
let handler_args = HandlerArgs::new(session, &portal.statement, &str_sql)?;
match &portal.statement {
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_execute(handler_args, portal).await,
_ => unreachable!(),
}
}
Portal::PureStatement(stmt) => {
let sql = stmt.to_string();
handle(session, stmt, &sql, vec![]).await
}
}
}

/// A quick way to check if a query contains parameters.
fn have_parameter_in_query(query: &Query) -> bool {
query.to_string().contains("$1")
}
91 changes: 53 additions & 38 deletions src/frontend/src/handler/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ use risingwave_common::session_config::QueryMode;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{SetExpr, Statement};

use super::extended_handle::{Portal, PrepareStatement};
use super::extended_handle::{PortalResult, PrepareStatement, PreparedResult};
use super::{PgResponseStream, RwPgResponse};
use crate::binder::Binder;
use crate::binder::{Binder, BoundStatement};
use crate::catalog::TableId;
use crate::handler::flush::do_flush;
use crate::handler::privilege::resolve_privileges;
Expand Down Expand Up @@ -368,6 +368,8 @@ pub async fn local_execute(
Ok(execution.stream_rows())
}

// TODO: Following code have redundant code with `handle_query`, we may need to refactor them in
// future.
pub fn handle_parse(
handler_args: HandlerArgs,
statement: Statement,
Expand All @@ -382,15 +384,58 @@ pub fn handle_parse(

let param_types = binder.export_param_types()?;

Ok(PrepareStatement {
Ok(PrepareStatement::Prepared(PreparedResult {
statement,
bound_statement,
param_types,
})
}))
}

pub async fn handle_execute(handler_args: HandlerArgs, portal: Portal) -> Result<RwPgResponse> {
let Portal {
pub fn gen_batch_query_plan_for_bound(
session: &SessionImpl,
context: OptimizerContextRef,
stmt: Statement,
bound: BoundStatement,
) -> Result<(PlanRef, QueryMode, Schema)> {
let must_dist = must_run_in_distributed_mode(&stmt)?;

let mut planner = Planner::new(context);

let mut logical = planner.plan(bound)?;
let schema = logical.schema();
let batch_plan = logical.gen_batch_plan()?;

let must_local = must_run_in_local_mode(batch_plan.clone());

let query_mode = match (must_dist, must_local) {
(true, true) => {
return Err(ErrorCode::InternalError(
"the query is forced to both local and distributed mode by optimizer".to_owned(),
)
.into())
}
(true, false) => QueryMode::Distributed,
(false, true) => QueryMode::Local,
(false, false) => match session.config().get_query_mode() {
QueryMode::Auto => determine_query_mode(batch_plan.clone()),
QueryMode::Local => QueryMode::Local,
QueryMode::Distributed => QueryMode::Distributed,
},
};

let physical = match query_mode {
QueryMode::Auto => unreachable!(),
QueryMode::Local => logical.gen_batch_local_plan(batch_plan)?,
QueryMode::Distributed => logical.gen_batch_distributed_plan(batch_plan)?,
};
Ok((physical, query_mode, schema))
}

pub async fn handle_execute(
handler_args: HandlerArgs,
portal: PortalResult,
) -> Result<RwPgResponse> {
let PortalResult {
statement,
bound_statement,
result_formats,
Expand All @@ -407,38 +452,8 @@ pub async fn handle_execute(handler_args: HandlerArgs, portal: Portal) -> Result
let (plan_fragmenter, query_mode, output_schema) = {
let context = OptimizerContext::from_handler_args(handler_args);

let must_dist = must_run_in_distributed_mode(&statement)?;

let mut planner = Planner::new(context.into());

let mut logical = planner.plan(bound_statement)?;
let schema = logical.schema();
let batch_plan = logical.gen_batch_plan()?;

let must_local = must_run_in_local_mode(batch_plan.clone());

let query_mode = match (must_dist, must_local) {
(true, true) => {
return Err(ErrorCode::InternalError(
"the query is forced to both local and distributed mode by optimizer"
.to_owned(),
)
.into())
}
(true, false) => QueryMode::Distributed,
(false, true) => QueryMode::Local,
(false, false) => match session.config().get_query_mode() {
QueryMode::Auto => determine_query_mode(batch_plan.clone()),
QueryMode::Local => QueryMode::Local,
QueryMode::Distributed => QueryMode::Distributed,
},
};

let physical = match query_mode {
QueryMode::Auto => unreachable!(),
QueryMode::Local => logical.gen_batch_local_plan(batch_plan)?,
QueryMode::Distributed => logical.gen_batch_distributed_plan(batch_plan)?,
};
let (physical, query_mode, schema) =
gen_batch_query_plan_for_bound(&session, context.into(), statement, bound_statement)?;

let context = physical.plan_base().ctx.clone();
tracing::trace!(
Expand Down
Loading