Skip to content

Commit

Permalink
feat: support multi ordered primary key (#633)
Browse files Browse the repository at this point in the history
* feat: support multi ordered primary key

Signed-off-by: Shmiwy <wyf000219@126.com>

* feat: support multi ordered primary key

Signed-off-by: Shmiwy <wyf000219@126.com>

* feat: support multi ordered primary key

Signed-off-by: Shmiwy <wyf000219@126.com>
  • Loading branch information
Shmiwy committed Apr 21, 2022
1 parent d28389a commit fc5a837
Show file tree
Hide file tree
Showing 15 changed files with 334 additions and 107 deletions.
194 changes: 166 additions & 28 deletions src/binder/statement/create_table.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// Copyright 2022 RisingLight Project Authors. Licensed under Apache-2.0.

use sqlparser::ast::TableConstraint;

use super::*;
Expand All @@ -14,6 +13,7 @@ pub struct BoundCreateTable {
pub schema_id: SchemaId,
pub table_name: String,
pub columns: Vec<ColumnCatalog>,
pub ordered_pk_ids: Vec<ColumnId>,
}

impl Binder {
Expand Down Expand Up @@ -45,46 +45,110 @@ impl Binder {
}
}

// record primary key names declared by "primary key" keywords
// if a sql like " create table t (a int, b int, c int, d int, primary key(a, b)); "
// then after the for loop, the extra_pk_name will contain "a" and "b"
let mut extra_pk_name: HashSet<String> = HashSet::new();
for constraint in constraints {
match constraint {
TableConstraint::Unique {
is_primary,
columns,
..
} if *is_primary => columns.iter().for_each(|indent| {
extra_pk_name.insert(indent.value.clone());
}),

_ => todo!(),
}
let mut ordered_pk_ids = Binder::ordered_pks_from_columns(columns);
let has_pk_from_column = !ordered_pk_ids.is_empty();

if ordered_pk_ids.len() > 1 {
// multi primary key should be declared by "primary key(c1, c2...)" syntax
return Err(BindError::NotSupportedTSQL);
}

let pks_name_from_constraints = Binder::pks_name_from_constraints(constraints);
if has_pk_from_column && !pks_name_from_constraints.is_empty() {
// can't get primary key both from "primary key(c1, c2...)" syntax and
// column's option
return Err(BindError::NotSupportedTSQL);
} else if !has_pk_from_column {
ordered_pk_ids =
Binder::ordered_pks_from_constraint(&pks_name_from_constraints, columns);
}

let columns = columns
let mut columns: Vec<ColumnCatalog> = columns
.iter()
.enumerate()
.map(|(idx, col)| {
let mut col = ColumnCatalog::from(col);
if extra_pk_name.contains(col.name()) && !col.is_primary() {
col.set_primary(true);
}
col.set_id(idx as ColumnId);
col
})
.collect();

// // TODO: when remove `is_primary` filed in `ColumnDesc`,
// // Remove this line and change `columns` above to immut.
for &index in &ordered_pk_ids {
columns[index as usize].set_primary(true);
}

Ok(BoundCreateTable {
database_id: db.id(),
schema_id: schema.id(),
table_name: table_name.into(),
columns,
ordered_pk_ids,
})
}
_ => panic!("mismatched statement type"),
}
}

/// get primary keys' id in declared order。
/// we use index in columns vector as column id
fn ordered_pks_from_columns(columns: &[ColumnDef]) -> Vec<ColumnId> {
let mut ordered_pks = Vec::new();

for (index, col_def) in columns.iter().enumerate() {
for option_def in &col_def.options {
let is_primary_ = if let ColumnOption::Unique { is_primary } = option_def.option {
is_primary
} else {
false
};
if is_primary_ {
ordered_pks.push(index as ColumnId);
}
}
}
ordered_pks
}

/// We have used `pks_name_from_constraints` to get the primary keys' name sorted by declaration
/// order in "primary key(c1, c2..)" syntax. Now we transfer the name to id to get the sorted
/// ID
fn ordered_pks_from_constraint(pks_name: &[String], columns: &[ColumnDef]) -> Vec<ColumnId> {
let mut ordered_pks = vec![0; pks_name.len()];
let mut pos_in_ordered_pk = HashMap::new(); // used to get pos from column name
pks_name.iter().enumerate().for_each(|(pos, name)| {
pos_in_ordered_pk.insert(name, pos);
});

columns.iter().enumerate().for_each(|(index, colum_desc)| {
let column_name = &colum_desc.name.value;
if pos_in_ordered_pk.contains_key(column_name) {
let id = index as ColumnId;
let pos = *(pos_in_ordered_pk.get(column_name).unwrap());
ordered_pks[pos] = id;
}
});
ordered_pks
}
/// get the primary keys' name sorted by declaration order in "primary key(c1, c2..)" syntax.
fn pks_name_from_constraints(constraints: &[TableConstraint]) -> Vec<String> {
let mut pks_name_from_constraints = vec![];

for constraint in constraints {
match constraint {
TableConstraint::Unique {
is_primary,
columns,
..
} if *is_primary => columns.iter().for_each(|ident| {
pks_name_from_constraints.push(ident.value.clone());
}),
_ => continue,
}
}
pks_name_from_constraints
}
}

impl From<&ColumnDef> for ColumnCatalog {
Expand Down Expand Up @@ -124,10 +188,15 @@ mod tests {
let catalog = Arc::new(RootCatalog::new());
let mut binder = Binder::new(catalog.clone());
let sql = "
create table t1 (v1 int not null, v2 int);
create table t1 (v1 int not null, v2 int);
create table t2 (a int not null, a int not null);
create table t3 (v1 int not null);
create table t4 (a int not null, b int not null, c int, primary key(a,b));";
create table t4 (a int not null, b int not null, c int, primary key(a, b));
create table t5 (a int not null, b int not null, c int, primary key(b, a));
create table t6 (a int primary key, b int not null, c int not null, primary key(b, c));
create table t7 (a int primary key, b int);
create table t8 (a int not null, b int, primary key(a));";

let stmts = parse(sql).unwrap();

assert_eq!(
Expand All @@ -139,13 +208,14 @@ mod tests {
columns: vec![
ColumnCatalog::new(
0,
DataTypeKind::Int(None).not_null().to_column("v1".into())
DataTypeKind::Int(None).not_null().to_column("v1".into()),
),
ColumnCatalog::new(
1,
DataTypeKind::Int(None).nullable().to_column("v2".into())
DataTypeKind::Int(None).nullable().to_column("v2".into()),
),
],
ordered_pk_ids: vec![],
}
);

Expand All @@ -156,7 +226,9 @@ mod tests {

let database = catalog.get_database_by_id(0).unwrap();
let schema = database.get_schema_by_id(0).unwrap();
schema.add_table("t3".into(), vec![], false).unwrap();
schema
.add_table("t3".into(), vec![], false, vec![])
.unwrap();
assert_eq!(
binder.bind_create_table(&stmts[2]),
Err(BindError::DuplicatedTable("t3".into()))
Expand All @@ -181,11 +253,77 @@ mod tests {
.not_null()
.to_column_primary_key("b".into()),
),
ColumnCatalog::new(2, DataTypeKind::Int(None).nullable().to_column("c".into())),
],
ordered_pk_ids: vec![0, 1],
}
);

assert_eq!(
binder.bind_create_table(&stmts[4]).unwrap(),
BoundCreateTable {
database_id: 0,
schema_id: 0,
table_name: "t5".into(),
columns: vec![
ColumnCatalog::new(
2,
DataTypeKind::Int(None).nullable().to_column("c".into(),),
0,
DataTypeKind::Int(None)
.not_null()
.to_column_primary_key("a".into()),
),
ColumnCatalog::new(
1,
DataTypeKind::Int(None)
.not_null()
.to_column_primary_key("b".into()),
),
ColumnCatalog::new(2, DataTypeKind::Int(None).nullable().to_column("c".into())),
],
ordered_pk_ids: vec![1, 0],
}
);

assert_eq!(
binder.bind_create_table(&stmts[5]),
Err(BindError::NotSupportedTSQL)
);

assert_eq!(
binder.bind_create_table(&stmts[6]).unwrap(),
BoundCreateTable {
database_id: 0,
schema_id: 0,
table_name: "t7".into(),
columns: vec![
ColumnCatalog::new(
0,
DataTypeKind::Int(None)
.nullable()
.to_column_primary_key("a".into()),
),
ColumnCatalog::new(1, DataTypeKind::Int(None).nullable().to_column("b".into())),
],
ordered_pk_ids: vec![0],
}
);

assert_eq!(
binder.bind_create_table(&stmts[7]).unwrap(),
BoundCreateTable {
database_id: 0,
schema_id: 0,
table_name: "t8".into(),
columns: vec![
ColumnCatalog::new(
0,
DataTypeKind::Int(None)
.not_null()
.to_column_primary_key("a".into()),
),
ColumnCatalog::new(1, DataTypeKind::Int(None).nullable().to_column("b".into())),
],
ordered_pk_ids: vec![0],
}
);
}
Expand Down
4 changes: 3 additions & 1 deletion src/binder/statement/drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ mod tests {

let database = catalog.get_database_by_id(0).unwrap();
let schema = database.get_schema_by_id(0).unwrap();
schema.add_table("mytable".into(), vec![], false).unwrap();
schema
.add_table("mytable".into(), vec![], false, vec![])
.unwrap();

let stmts = parse("drop table mytable").unwrap();
assert_eq!(
Expand Down
1 change: 1 addition & 0 deletions src/binder/statement/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ mod tests {
ColumnCatalog::new(1, DataTypeKind::Int(None).not_null().to_column("b".into())),
],
false,
vec![],
)
.unwrap();

Expand Down
1 change: 1 addition & 0 deletions src/catalog/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ impl DatabaseCatalog {
.to_column("github_id".into()),
)],
false,
vec![],
)
.unwrap();
}
Expand Down
4 changes: 3 additions & 1 deletion src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use super::{CatalogError, ColumnCatalog, TableCatalog};
use crate::types::{SchemaId, TableId};
use crate::types::{ColumnId, SchemaId, TableId};

/// The catalog of a schema.
pub struct SchemaCatalog {
Expand Down Expand Up @@ -37,6 +37,7 @@ impl SchemaCatalog {
name: String,
columns: Vec<ColumnCatalog>,
is_materialized_view: bool,
ordered_pk_ids: Vec<ColumnId>,
) -> Result<TableId, CatalogError> {
let mut inner = self.inner.lock().unwrap();
if inner.table_idxs.contains_key(&name) {
Expand All @@ -49,6 +50,7 @@ impl SchemaCatalog {
name.clone(),
columns,
is_materialized_view,
ordered_pk_ids,
));
inner.table_idxs.insert(name, table_id);
inner.tables.insert(table_id, table_catalog);
Expand Down
Loading

0 comments on commit fc5a837

Please sign in to comment.