Skip to content

Commit

Permalink
Support postgres CREATE FUNCTION (#722)
Browse files Browse the repository at this point in the history
* support basic pg CREATE FUNCTION

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* support function argument

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* fix display and use verify in test

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* support OR REPLACE

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* fix compile error in bigdecimal

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* unify all `CreateFunctionBody` to a structure

Signed-off-by: Runji Wang <wangrunji0408@163.com>

Signed-off-by: Runji Wang <wangrunji0408@163.com>
  • Loading branch information
wangrunji0408 committed Dec 1, 2022
1 parent f621142 commit 5b53df9
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 29 deletions.
150 changes: 143 additions & 7 deletions src/ast/mod.rs
Expand Up @@ -1405,11 +1405,15 @@ pub enum Statement {
/// CREATE FUNCTION
///
/// Hive: https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction
/// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html
CreateFunction {
or_replace: bool,
temporary: bool,
name: ObjectName,
class_name: String,
using: Option<CreateFunctionUsing>,
args: Option<Vec<CreateFunctionArg>>,
return_type: Option<DataType>,
/// Optional parameters.
params: CreateFunctionBody,
},
/// `ASSERT <condition> [AS <message>]`
Assert {
Expand Down Expand Up @@ -1866,19 +1870,26 @@ impl fmt::Display for Statement {
Ok(())
}
Statement::CreateFunction {
or_replace,
temporary,
name,
class_name,
using,
args,
return_type,
params,
} => {
write!(
f,
"CREATE {temp}FUNCTION {name} AS '{class_name}'",
"CREATE {or_replace}{temp}FUNCTION {name}",
temp = if *temporary { "TEMPORARY " } else { "" },
or_replace = if *or_replace { "OR REPLACE " } else { "" },
)?;
if let Some(u) = using {
write!(f, " {}", u)?;
if let Some(args) = args {
write!(f, "({})", display_comma_separated(args))?;
}
if let Some(return_type) = return_type {
write!(f, " RETURNS {}", return_type)?;
}
write!(f, "{params}")?;
Ok(())
}
Statement::CreateView {
Expand Down Expand Up @@ -3679,6 +3690,131 @@ impl fmt::Display for ContextModifier {
}
}

/// Function argument in CREATE FUNCTION.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CreateFunctionArg {
pub mode: Option<ArgMode>,
pub name: Option<Ident>,
pub data_type: DataType,
pub default_expr: Option<Expr>,
}

impl CreateFunctionArg {
/// Returns an unnamed argument.
pub fn unnamed(data_type: DataType) -> Self {
Self {
mode: None,
name: None,
data_type,
default_expr: None,
}
}

/// Returns an argument with name.
pub fn with_name(name: &str, data_type: DataType) -> Self {
Self {
mode: None,
name: Some(name.into()),
data_type,
default_expr: None,
}
}
}

impl fmt::Display for CreateFunctionArg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(mode) = &self.mode {
write!(f, "{} ", mode)?;
}
if let Some(name) = &self.name {
write!(f, "{} ", name)?;
}
write!(f, "{}", self.data_type)?;
if let Some(default_expr) = &self.default_expr {
write!(f, " = {}", default_expr)?;
}
Ok(())
}
}

/// The mode of an argument in CREATE FUNCTION.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArgMode {
In,
Out,
InOut,
}

impl fmt::Display for ArgMode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ArgMode::In => write!(f, "IN"),
ArgMode::Out => write!(f, "OUT"),
ArgMode::InOut => write!(f, "INOUT"),
}
}
}

/// These attributes inform the query optimizer about the behavior of the function.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FunctionBehavior {
Immutable,
Stable,
Volatile,
}

impl fmt::Display for FunctionBehavior {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionBehavior::Immutable => write!(f, "IMMUTABLE"),
FunctionBehavior::Stable => write!(f, "STABLE"),
FunctionBehavior::Volatile => write!(f, "VOLATILE"),
}
}
}

/// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CreateFunctionBody {
/// LANGUAGE lang_name
pub language: Option<Ident>,
/// IMMUTABLE | STABLE | VOLATILE
pub behavior: Option<FunctionBehavior>,
/// AS 'definition'
///
/// Note that Hive's `AS class_name` is also parsed here.
pub as_: Option<String>,
/// RETURN expression
pub return_: Option<Expr>,
/// USING ... (Hive only)
pub using: Option<CreateFunctionUsing>,
}

impl fmt::Display for CreateFunctionBody {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(language) = &self.language {
write!(f, " LANGUAGE {language}")?;
}
if let Some(behavior) = &self.behavior {
write!(f, " {behavior}")?;
}
if let Some(definition) = &self.as_ {
write!(f, " AS '{definition}'")?;
}
if let Some(expr) = &self.return_ {
write!(f, " RETURN {expr}")?;
}
if let Some(using) = &self.using {
write!(f, " {using}")?;
}
Ok(())
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CreateFunctionUsing {
Expand Down
3 changes: 3 additions & 0 deletions src/keywords.rs
Expand Up @@ -284,6 +284,7 @@ define_keywords!(
IF,
IGNORE,
ILIKE,
IMMUTABLE,
IN,
INCREMENT,
INDEX,
Expand Down Expand Up @@ -518,6 +519,7 @@ define_keywords!(
SQLSTATE,
SQLWARNING,
SQRT,
STABLE,
START,
STATIC,
STATISTICS,
Expand Down Expand Up @@ -604,6 +606,7 @@ define_keywords!(
VERSIONING,
VIEW,
VIRTUAL,
VOLATILE,
WEEK,
WHEN,
WHENEVER,
Expand Down
130 changes: 118 additions & 12 deletions src/parser.rs
Expand Up @@ -2026,9 +2026,11 @@ impl<'a> Parser<'a> {
self.parse_create_view(or_replace)
} else if self.parse_keyword(Keyword::EXTERNAL) {
self.parse_create_external_table(or_replace)
} else if self.parse_keyword(Keyword::FUNCTION) {
self.parse_create_function(or_replace, temporary)
} else if or_replace {
self.expected(
"[EXTERNAL] TABLE or [MATERIALIZED] VIEW after CREATE OR REPLACE",
"[EXTERNAL] TABLE or [MATERIALIZED] VIEW or FUNCTION after CREATE OR REPLACE",
self.peek_token(),
)
} else if self.parse_keyword(Keyword::INDEX) {
Expand All @@ -2041,8 +2043,6 @@ impl<'a> Parser<'a> {
self.parse_create_schema()
} else if self.parse_keyword(Keyword::DATABASE) {
self.parse_create_database()
} else if dialect_of!(self is HiveDialect) && self.parse_keyword(Keyword::FUNCTION) {
self.parse_create_function(temporary)
} else if self.parse_keyword(Keyword::ROLE) {
self.parse_create_role()
} else if self.parse_keyword(Keyword::SEQUENCE) {
Expand Down Expand Up @@ -2253,20 +2253,126 @@ impl<'a> Parser<'a> {
}
}

pub fn parse_create_function(&mut self, temporary: bool) -> Result<Statement, ParserError> {
let name = self.parse_object_name()?;
self.expect_keyword(Keyword::AS)?;
let class_name = self.parse_literal_string()?;
let using = self.parse_optional_create_function_using()?;
pub fn parse_create_function(
&mut self,
or_replace: bool,
temporary: bool,
) -> Result<Statement, ParserError> {
if dialect_of!(self is HiveDialect) {
let name = self.parse_object_name()?;
self.expect_keyword(Keyword::AS)?;
let class_name = self.parse_literal_string()?;
let params = CreateFunctionBody {
as_: Some(class_name),
using: self.parse_optional_create_function_using()?,
..Default::default()
};

Ok(Statement::CreateFunction {
temporary,
Ok(Statement::CreateFunction {
or_replace,
temporary,
name,
args: None,
return_type: None,
params,
})
} else if dialect_of!(self is PostgreSqlDialect) {
let name = self.parse_object_name()?;
self.expect_token(&Token::LParen)?;
let args = self.parse_comma_separated(Parser::parse_create_function_arg)?;
self.expect_token(&Token::RParen)?;

let return_type = if self.parse_keyword(Keyword::RETURNS) {
Some(self.parse_data_type()?)
} else {
None
};

let params = self.parse_create_function_body()?;

Ok(Statement::CreateFunction {
or_replace,
temporary,
name,
args: Some(args),
return_type,
params,
})
} else {
self.prev_token();
self.expected("an object type after CREATE", self.peek_token())
}
}

fn parse_create_function_arg(&mut self) -> Result<CreateFunctionArg, ParserError> {
let mode = if self.parse_keyword(Keyword::IN) {
Some(ArgMode::In)
} else if self.parse_keyword(Keyword::OUT) {
Some(ArgMode::Out)
} else if self.parse_keyword(Keyword::INOUT) {
Some(ArgMode::InOut)
} else {
None
};

// parse: [ argname ] argtype
let mut name = None;
let mut data_type = self.parse_data_type()?;
if let DataType::Custom(n, _) = &data_type {
// the first token is actually a name
name = Some(n.0[0].clone());
data_type = self.parse_data_type()?;
}

let default_expr = if self.parse_keyword(Keyword::DEFAULT) || self.consume_token(&Token::Eq)
{
Some(self.parse_expr()?)
} else {
None
};
Ok(CreateFunctionArg {
mode,
name,
class_name,
using,
data_type,
default_expr,
})
}

fn parse_create_function_body(&mut self) -> Result<CreateFunctionBody, ParserError> {
let mut body = CreateFunctionBody::default();
loop {
fn ensure_not_set<T>(field: &Option<T>, name: &str) -> Result<(), ParserError> {
if field.is_some() {
return Err(ParserError::ParserError(format!(
"{name} specified more than once",
)));
}
Ok(())
}
if self.parse_keyword(Keyword::AS) {
ensure_not_set(&body.as_, "AS")?;
body.as_ = Some(self.parse_literal_string()?);
} else if self.parse_keyword(Keyword::LANGUAGE) {
ensure_not_set(&body.language, "LANGUAGE")?;
body.language = Some(self.parse_identifier()?);
} else if self.parse_keyword(Keyword::IMMUTABLE) {
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
body.behavior = Some(FunctionBehavior::Immutable);
} else if self.parse_keyword(Keyword::STABLE) {
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
body.behavior = Some(FunctionBehavior::Stable);
} else if self.parse_keyword(Keyword::VOLATILE) {
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
body.behavior = Some(FunctionBehavior::Volatile);
} else if self.parse_keyword(Keyword::RETURN) {
ensure_not_set(&body.return_, "RETURN")?;
body.return_ = Some(self.parse_expr()?);
} else {
return Ok(body);
}
}
}

pub fn parse_create_external_table(
&mut self,
or_replace: bool,
Expand Down

0 comments on commit 5b53df9

Please sign in to comment.