diff --git a/Cargo.toml b/Cargo.toml index 7f9070d5..49703b47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ indexset = { version = "0.12.2", features = ["concurrent", "cdc", "multimap"] } # indexset = { git = "https://github.com/Handy-caT/indexset", branch = "multimap-range-fix", version = "0.12.0", features = ["concurrent", "cdc", "multimap"] } convert_case = "0.6.0" ordered-float = "5.0.0" -serde = { version = "1.0.215", features = ["derive"] } +parking_lot = "0.12.3" prettytable-rs = "^0.10" [dev-dependencies] diff --git a/codegen/src/worktable/generator/queries/in_place.rs b/codegen/src/worktable/generator/queries/in_place.rs new file mode 100644 index 00000000..1a971b78 --- /dev/null +++ b/codegen/src/worktable/generator/queries/in_place.rs @@ -0,0 +1,151 @@ +use crate::name_generator::WorktableNameGenerator; +use crate::worktable::generator::Generator; +use crate::worktable::model::Operation; +use convert_case::{Case, Casing}; +use proc_macro2::{Ident, Span, TokenStream}; +use quote::quote; +use std::collections::HashMap; + +impl Generator { + pub fn gen_query_in_place_impl(&self) -> syn::Result { + let name_generator = WorktableNameGenerator::from_table_name(self.name.to_string()); + let table_ident = name_generator.get_work_table_ident(); + + let custom_in_place = if let Some(q) = &self.queries { + let custom_in_place = self.gen_in_place_queries(q.in_place.clone()); + quote! { + #custom_in_place + } + } else { + quote! {} + }; + + Ok(quote! { + impl #table_ident { + #custom_in_place + } + }) + } + + fn gen_in_place_queries(&self, in_place_queries: HashMap) -> TokenStream { + let defs = in_place_queries + .iter() + .map(|(name, op)| { + let snake_case_name = name + .to_string() + .from_case(Case::Pascal) + .to_case(Case::Snake); + let index = self.columns.indexes.values().find(|idx| idx.field == op.by); + let by_type = self.columns.columns_map.get(&op.by).unwrap(); + if let Some(index) = index { + let _index_name = &index.name; + + if index.is_unique { + todo!() + } else { + todo!() + } + } else if self.columns.primary_keys.len() == 1 { + self.gen_primary_key_in_place(snake_case_name, by_type, &op.columns) + } else { + todo!() + } + }) + .collect::>(); + + quote! { + #(#defs)* + } + } + + fn gen_primary_key_in_place( + &self, + snake_case_name: String, + by_type: &TokenStream, + columns: &[Ident], + ) -> TokenStream { + let name_generator = WorktableNameGenerator::from_table_name(self.name.to_string()); + let lock_type_ident = name_generator.get_lock_type_ident(); + let pk_type = name_generator.get_primary_key_type_ident(); + + let lock_await_ident = + WorktableNameGenerator::get_update_in_place_query_lock_await_ident(&snake_case_name); + let unlock_ident = + WorktableNameGenerator::get_update_in_place_query_unlock_ident(&snake_case_name); + let lock_ident = + WorktableNameGenerator::get_update_in_place_query_lock_ident(&snake_case_name); + + let method_ident = Ident::new( + format!("update_{snake_case_name}_in_place").as_str(), + Span::mixed_site(), + ); + + let types = columns + .iter() + .map(|c| self.columns.columns_map.get(c).unwrap()) + .collect::>(); + let column_types = if types.len() == 1 { + let t = types[0]; + quote! { + &mut <#t as Archive>::Archived + } + } else { + let types = types.iter().map(|t| { + quote! { + &mut <#t as Archive>::Archived + } + }); + quote! { + ( #(#types),* ) + } + }; + let column_fields = if columns.len() == 1 { + let i = &columns[0]; + quote! { + &mut archived.inner.#i + } + } else { + let columns = columns.iter().map(|i| { + quote! 
{ + &mut archived.inner.#i + } + }); + quote! { + ( #(#columns),* ) + } + }; + + quote! { + pub async fn #method_ident( + &self, + mut f: F, + by: #by_type, + ) -> eyre::Result<()> { + let pk: #pk_type = by.into(); + let lock_id = self.0.lock_map.next_id(); + let mut lock = #lock_type_ident::new(lock_id.into()); + lock.#lock_ident(); + let lock = std::sync::Arc::new(lock); + if let Some(lock) = self.0.lock_map.insert(pk.clone(), lock.clone()) { + lock.#lock_await_ident().await; + } + let link = self + .0 + .pk_map + .get(&pk) + .map(|v| v.get().value) + .ok_or(WorkTableError::NotFound)?; + unsafe { + self.0 + .data + .with_mut_ref(link, move |archived| f(#column_fields)) + .map_err(WorkTableError::PagesError)? + }; + lock.#unlock_ident(); + self.0.lock_map.remove_with_lock_check(&pk, lock); + + Ok(()) + } + } + } +} diff --git a/codegen/src/worktable/generator/queries/locks.rs b/codegen/src/worktable/generator/queries/locks.rs index a4b2d9dc..ad04dfd1 100644 --- a/codegen/src/worktable/generator/queries/locks.rs +++ b/codegen/src/worktable/generator/queries/locks.rs @@ -1,117 +1,208 @@ -use crate::name_generator::WorktableNameGenerator; -use crate::worktable::generator::Generator; +use std::collections::HashMap; + use convert_case::{Case, Casing}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; +use crate::name_generator::WorktableNameGenerator; +use crate::worktable::generator::Generator; +use crate::worktable::model::Operation; + +impl WorktableNameGenerator { + pub fn get_update_query_lock_await_ident(snake_case_name: &String) -> Ident { + Ident::new( + format!("lock_await_update_{snake_case_name}").as_str(), + Span::mixed_site(), + ) + } + + pub fn get_update_query_lock_ident(snake_case_name: &String) -> Ident { + Ident::new( + format!("lock_update_{snake_case_name}").as_str(), + Span::mixed_site(), + ) + } + + pub fn get_update_query_unlock_ident(snake_case_name: &String) -> Ident { + Ident::new( + format!("unlock_update_{snake_case_name}").as_str(), + Span::mixed_site(), + ) + } + + pub fn get_update_in_place_query_lock_await_ident(snake_case_name: &String) -> Ident { + Ident::new( + format!("lock_await_update_in_place_{snake_case_name}").as_str(), + Span::mixed_site(), + ) + } + + pub fn get_update_in_place_query_lock_ident(snake_case_name: &String) -> Ident { + Ident::new( + format!("lock_update_in_place_{snake_case_name}").as_str(), + Span::mixed_site(), + ) + } + + pub fn get_update_in_place_query_unlock_ident(snake_case_name: &String) -> Ident { + Ident::new( + format!("unlock_update_in_place_{snake_case_name}").as_str(), + Span::mixed_site(), + ) + } +} + impl Generator { pub fn gen_query_locks_impl(&mut self) -> syn::Result { if let Some(q) = &self.queries { let name_generator = WorktableNameGenerator::from_table_name(self.name.to_string()); let lock_type_ident = name_generator.get_lock_type_ident(); - let fns = q - .updates - .keys() - .map(|name| { - let snake_case_name = name - .to_string() - .from_case(Case::Pascal) - .to_case(Case::Snake); - - let lock_await_ident = Ident::new( - format!("lock_await_{snake_case_name}").as_str(), - Span::mixed_site(), - ); + let update_fns = Self::gen_update_query_locks(&q.updates); + let update_in_place_fns = Self::gen_in_place_update_query_locks(&q.in_place); - let lock_ident = Ident::new( - format!("lock_{snake_case_name}").as_str(), - Span::mixed_site(), - ); + Ok(quote! { + impl #lock_type_ident { + #update_fns + #update_in_place_fns + } + }) + } else { + Ok(quote! 
{}) + } + } - let unlock_ident = Ident::new( - format!("unlock_{snake_case_name}").as_str(), - Span::mixed_site(), + fn gen_in_place_update_query_locks(updates: &HashMap) -> TokenStream { + let fns = updates + .keys() + .map(|name| { + let snake_case_name = name + .to_string() + .from_case(Case::Pascal) + .to_case(Case::Snake); + + let lock_await_ident = + WorktableNameGenerator::get_update_in_place_query_lock_await_ident( + &snake_case_name, ); + let lock_ident = + WorktableNameGenerator::get_update_in_place_query_lock_ident(&snake_case_name); + let unlock_ident = WorktableNameGenerator::get_update_in_place_query_unlock_ident( + &snake_case_name, + ); + + let columns = &updates.get(name).as_ref().expect("exists").columns; + let lock_await_fn = Self::gen_rows_lock_await_fn(columns, lock_await_ident); + let lock_fn = Self::gen_rows_lock_fn(columns, lock_ident); + let unlock_fn = Self::gen_rows_unlock_fn(columns, unlock_ident); - let rows_lock_await = q - .updates - .get(name) - .expect("exists") - .columns - .iter() - .map(|col| { - let col = - Ident::new(format!("{}_lock", col).as_str(), Span::mixed_site()); - quote! { - if let Some(lock) = &self.#col { - futures.push(lock.as_ref()); - } - } - }) - .collect::>(); - - let rows_lock = q - .updates - .get(name) - .expect("exists") - .columns - .iter() - .map(|col| { - let col = - Ident::new(format!("{}_lock", col).as_str(), Span::mixed_site()); - quote! { - if self.#col.is_none() { - self.#col = Some(std::sync::Arc::new(Lock::new())); - } - } - }) - .collect::>(); - - let rows_unlock = q - .updates - .get(name) - .expect("exists") - .columns - .iter() - .map(|col| { - let col = - Ident::new(format!("{}_lock", col).as_str(), Span::mixed_site()); - quote! { - if let Some(#col) = &self.#col { - #col.unlock(); - } - } - }) - .collect::>(); - - quote! { - - pub fn #lock_ident(&mut self) { - #(#rows_lock)* - } - - pub fn #unlock_ident(&self) { - #(#rows_unlock)* - } - - pub async fn #lock_await_ident(&self) { - let mut futures = Vec::new(); - - #(#rows_lock_await)* - futures::future::join_all(futures).await; - } + quote! { + #lock_fn + #lock_await_fn + #unlock_fn + } + }) + .collect::>(); + + quote! { + #(#fns)* + } + } + + fn gen_update_query_locks(updates: &HashMap) -> TokenStream { + let fns = updates + .keys() + .map(|name| { + let snake_case_name = name + .to_string() + .from_case(Case::Pascal) + .to_case(Case::Snake); + + let lock_await_ident = + WorktableNameGenerator::get_update_query_lock_await_ident(&snake_case_name); + let lock_ident = + WorktableNameGenerator::get_update_query_lock_ident(&snake_case_name); + let unlock_ident = + WorktableNameGenerator::get_update_query_unlock_ident(&snake_case_name); + + let columns = &updates.get(name).as_ref().expect("exists").columns; + let lock_await_fn = Self::gen_rows_lock_await_fn(columns, lock_await_ident); + let lock_fn = Self::gen_rows_lock_fn(columns, lock_ident); + let unlock_fn = Self::gen_rows_unlock_fn(columns, unlock_ident); + + quote! { + #lock_fn + #lock_await_fn + #unlock_fn + } + }) + .collect::>(); + + quote! { + #(#fns)* + } + } + + fn gen_rows_unlock_fn(columns: &[Ident], ident: Ident) -> TokenStream { + let inner = columns + .iter() + .map(|col| { + let col = Ident::new(format!("{}_lock", col).as_str(), Span::mixed_site()); + quote! { + if let Some(#col) = &self.#col { + #col.unlock(); } - }) - .collect::>(); + } + }) + .collect::>(); - Ok(quote! { - impl #lock_type_ident { - #(#fns)* + quote! 
{ + pub fn #ident(&self) { + #(#inner)* + } + } + } + + fn gen_rows_lock_fn(columns: &[Ident], ident: Ident) -> TokenStream { + let inner = columns + .iter() + .map(|col| { + let col = Ident::new(format!("{}_lock", col).as_str(), Span::mixed_site()); + quote! { + if self.#col.is_none() { + self.#col = Some(std::sync::Arc::new(Lock::new())); + } } }) - } else { - Ok(quote! {}) + .collect::>(); + + quote! { + pub fn #ident(&mut self) { + #(#inner)* + } + } + } + + fn gen_rows_lock_await_fn(columns: &[Ident], ident: Ident) -> TokenStream { + let inner = columns + .iter() + .map(|col| { + let col = Ident::new(format!("{}_lock", col).as_str(), Span::mixed_site()); + quote! { + if let Some(lock) = &self.#col { + futures.push(lock.as_ref()); + } + } + }) + .collect::>(); + + quote! { + pub async fn #ident(&self) { + let mut futures = Vec::new(); + + #(#inner)* + futures::future::join_all(futures).await; + } } } } diff --git a/codegen/src/worktable/generator/queries/mod.rs b/codegen/src/worktable/generator/queries/mod.rs index d4385753..83f7a274 100644 --- a/codegen/src/worktable/generator/queries/mod.rs +++ b/codegen/src/worktable/generator/queries/mod.rs @@ -1,4 +1,5 @@ mod delete; +mod in_place; mod locks; mod select; pub mod r#type; diff --git a/codegen/src/worktable/generator/queries/update.rs b/codegen/src/worktable/generator/queries/update.rs index 4d8b9bef..8ef341f7 100644 --- a/codegen/src/worktable/generator/queries/update.rs +++ b/codegen/src/worktable/generator/queries/update.rs @@ -368,18 +368,10 @@ impl Generator { ); let query_ident = Ident::new(format!("{name}Query").as_str(), Span::mixed_site()); - let lock_await_ident = Ident::new( - format!("lock_await_{snake_case_name}").as_str(), - Span::mixed_site(), - ); - let unlock_ident = Ident::new( - format!("unlock_{snake_case_name}").as_str(), - Span::mixed_site(), - ); - let lock_ident = Ident::new( - format!("lock_{snake_case_name}").as_str(), - Span::mixed_site(), - ); + let lock_await_ident = + WorktableNameGenerator::get_update_query_lock_await_ident(&snake_case_name); + let unlock_ident = WorktableNameGenerator::get_update_query_unlock_ident(&snake_case_name); + let lock_ident = WorktableNameGenerator::get_update_query_lock_ident(&snake_case_name); let row_updates = idents .iter() @@ -403,8 +395,8 @@ impl Generator { let lock_id = self.0.lock_map.next_id(); let mut lock = #lock_type_ident::new(lock_id.into()); //Creates new LockType with None lock.#lock_ident(); - - self.0.lock_map.insert(pk.clone(), std::sync::Arc::new(lock.clone())); + let lock = std::sync::Arc::new(lock); + self.0.lock_map.insert(pk.clone(), lock.clone()); let mut bytes = rkyv::to_bytes::(&row).map_err(|_| WorkTableError::SerializeError)?; let mut archived_row = unsafe { rkyv::access_unchecked_mut::<<#query_ident as rkyv::Archive>::Archived>(&mut bytes[..]).unseal_unchecked() }; @@ -423,7 +415,7 @@ impl Generator { }).map_err(WorkTableError::PagesError)? 
}; lock.#unlock_ident(); - self.0.lock_map.remove(&pk); + self.0.lock_map.remove_with_lock_check(&pk, lock); #persist_call @@ -452,18 +444,10 @@ impl Generator { let query_ident = Ident::new(format!("{name}Query").as_str(), Span::mixed_site()); let by_ident = Ident::new(format!("{name}By").as_str(), Span::mixed_site()); - let lock_await_ident = Ident::new( - format!("lock_await_{snake_case_name}").as_str(), - Span::mixed_site(), - ); - let lock_ident = Ident::new( - format!("lock_{snake_case_name}").as_str(), - Span::mixed_site(), - ); - let unlock_ident = Ident::new( - format!("unlock_{snake_case_name}").as_str(), - Span::mixed_site(), - ); + let lock_await_ident = + WorktableNameGenerator::get_update_query_lock_await_ident(&snake_case_name); + let unlock_ident = WorktableNameGenerator::get_update_query_unlock_ident(&snake_case_name); + let lock_ident = WorktableNameGenerator::get_update_query_lock_ident(&snake_case_name); let row_updates = idents .iter() @@ -544,7 +528,7 @@ impl Generator { let lock_id = self.0.lock_map.next_id(); let mut lock = #lock_type_ident::new(lock_id.into()); lock.#lock_ident(); - self.0.lock_map.insert(pk.clone(), std::sync::Arc::new(lock.clone())); + self.0.lock_map.insert(pk.clone(), std::sync::Arc::new(lock)); } let mut links_to_unlock = vec![]; @@ -574,7 +558,7 @@ impl Generator { let pk = self.0.data.select(*link)?.get_primary_key(); if let Some(lock) = self.0.lock_map.get(&pk) { lock.#unlock_ident(); - self.0.lock_map.remove(&pk); + self.0.lock_map.remove_with_lock_check(&pk, lock); } } core::result::Result::Ok(()) @@ -602,18 +586,10 @@ impl Generator { let query_ident = Ident::new(format!("{name}Query").as_str(), Span::mixed_site()); let by_ident = Ident::new(format!("{name}By").as_str(), Span::mixed_site()); - let lock_await_ident = Ident::new( - format!("lock_await_{snake_case_name}").as_str(), - Span::mixed_site(), - ); - let lock_ident = Ident::new( - format!("lock_{snake_case_name}").as_str(), - Span::mixed_site(), - ); - let unlock_ident = Ident::new( - format!("unlock_{snake_case_name}").as_str(), - Span::mixed_site(), - ); + let lock_await_ident = + WorktableNameGenerator::get_update_query_lock_await_ident(&snake_case_name); + let unlock_ident = WorktableNameGenerator::get_update_query_unlock_ident(&snake_case_name); + let lock_ident = WorktableNameGenerator::get_update_query_lock_ident(&snake_case_name); let row_updates = idents .iter() @@ -659,7 +635,8 @@ impl Generator { let lock_id = self.0.lock_map.next_id(); let mut lock = #lock_type_ident::new(lock_id.into()); lock.#lock_ident(); - self.0.lock_map.insert(pk.clone(), std::sync::Arc::new(lock.clone())); + let lock = std::sync::Arc::new(lock); + self.0.lock_map.insert(pk.clone(), lock.clone()); #size_check #diff_process @@ -672,7 +649,7 @@ impl Generator { } lock.#unlock_ident(); - self.0.lock_map.remove(&pk); + self.0.lock_map.remove_with_lock_check(&pk, lock); #persist_call diff --git a/codegen/src/worktable/mod.rs b/codegen/src/worktable/mod.rs index 983e45cc..45364582 100644 --- a/codegen/src/worktable/mod.rs +++ b/codegen/src/worktable/mod.rs @@ -58,6 +58,7 @@ pub fn expand(input: TokenStream) -> syn::Result { let query_locks_impls = generator.gen_query_locks_impl()?; let select_impls = generator.gen_query_select_impl()?; let update_impls = generator.gen_query_update_impl()?; + let update_in_place_impls = generator.gen_query_in_place_impl()?; let delete_impls = generator.gen_query_delete_impl()?; let unsized_impl = generator.gen_unsized_impls(); @@ -73,6 +74,7 @@ pub fn expand(input: 
TokenStream) -> syn::Result { #query_locks_impls #select_impls #update_impls + #update_in_place_impls #delete_impls #unsized_impl }) diff --git a/codegen/src/worktable/model/queries.rs b/codegen/src/worktable/model/queries.rs index a744d0f0..fdc8b170 100644 --- a/codegen/src/worktable/model/queries.rs +++ b/codegen/src/worktable/model/queries.rs @@ -8,4 +8,5 @@ use crate::worktable::model::Operation; pub struct Queries { pub updates: HashMap, pub deletes: HashMap, + pub in_place: HashMap, } diff --git a/codegen/src/worktable/parser/queries/in_place.rs b/codegen/src/worktable/parser/queries/in_place.rs new file mode 100644 index 00000000..0bd44fee --- /dev/null +++ b/codegen/src/worktable/parser/queries/in_place.rs @@ -0,0 +1,71 @@ +use std::collections::HashMap; + +use proc_macro2::{Ident, TokenTree}; +use syn::spanned::Spanned; + +use crate::worktable::model::Operation; +use crate::worktable::Parser; + +impl Parser { + pub fn parse_in_place(&mut self) -> syn::Result> { + let ident = self.input_iter.next().ok_or(syn::Error::new( + self.input.span(), + "Expected `in_place` field in declaration", + ))?; + if let TokenTree::Ident(ident) = ident { + if ident.to_string().as_str() != "in_place" { + return Err(syn::Error::new(ident.span(), "Expected `in_place` field")); + } + } else { + return Err(syn::Error::new( + ident.span(), + "Expected field name identifier.", + )); + }; + + self.parse_colon()?; + + let ops = self.input_iter.next().ok_or(syn::Error::new( + self.input.span(), + "Expected operation declarations", + ))?; + if let TokenTree::Group(ops) = ops { + let mut parser = Parser::new(ops.stream()); + parser.parse_operations() + } else { + Err(syn::Error::new( + ops.span(), + "Expected operation declarations", + )) + } + } +} + +#[cfg(test)] +mod tests { + use proc_macro2::{Ident, Span}; + use quote::quote; + + use crate::worktable::Parser; + + #[test] + fn test_update() { + let tokens = quote! { + in_place: { + TestQuery(id) by name, + } + }; + let mut parser = Parser::new(tokens); + let ops = parser.parse_in_place().unwrap(); + + assert_eq!(ops.len(), 1); + let op = ops + .get(&Ident::new("TestQuery", Span::mixed_site())) + .unwrap(); + + assert_eq!(op.name, "TestQuery"); + assert_eq!(op.columns.len(), 1); + assert_eq!(op.columns[0], "id"); + assert_eq!(op.by, "name"); + } +} diff --git a/codegen/src/worktable/parser/queries/mod.rs b/codegen/src/worktable/parser/queries/mod.rs index dba0a358..a35bb8c7 100644 --- a/codegen/src/worktable/parser/queries/mod.rs +++ b/codegen/src/worktable/parser/queries/mod.rs @@ -1,4 +1,5 @@ mod delete; +mod in_place; mod operation; mod select; mod update; @@ -48,6 +49,10 @@ impl Parser { let deletes = parser.parse_deletes()?; queries.deletes = deletes; } + "in_place" => { + let in_place = parser.parse_in_place()?; + queries.in_place = in_place; + } _ => return Err(syn::Error::new(ident.span(), "Unexpected identifier")), } } diff --git a/docs/queries.md b/docs/queries.md new file mode 100644 index 00000000..b43a2097 --- /dev/null +++ b/docs/queries.md @@ -0,0 +1,91 @@ +# Queries + +WorkTable support query definition feature. Users can add custom `update`, `delete` and `update_in_place` queries. + +```rust +worktable!( + name: Something, + columns: { + id: u64 primary_key autoincrement, + name: String + amount: u64, + some_value: i64, + }, + indexes: { + value_idx: value unique, + name_idx: name, + some_value_idx: some_value, + }, + // Queries declaration section. 
+    queries: {
+        // `update` queries
+        update: {
+            AmountById(amount) by id,
+        },
+        // `delete` queries
+        delete: {
+            ByName() by name,
+        },
+        in_place: {
+            SomeValueById(some_value) by id,
+        }
+    }
+);
+```
+
+### `update` queries
+
+`TODO`
+
+### `update_in_place` queries
+
+`update_in_place` queries are special update queries that let you update a field's value
+without selecting the row first. They are useful for counters, for example: because the generated
+query handles locking internally, users don't need to hold explicit locks over the `WorkTable`
+object, so `update_in_place` queries can safely be called from multiple threads simultaneously.
+
+!!! For now only `by {pk_field}` queries are supported !!!
+
+To declare an `update_in_place` query, add an `in_place` section to `queries`. The query definition
+syntax is the same as for `update`: `{YourQueryNameCamelCase}({fields_you_want_to_update}) by {by_field_name}`.
+For example, in the declaration above the `update_in_place` query is declared like this:
+
+```
+in_place: {
+    SomeValueById(some_value) by id,
+}
+```
+
+This generates an `update_some_value_by_id_in_place` method on the `WorkTable` object (name generation
+follows the same logic as for other queries). The method takes two arguments: a closure that receives a
+mutable reference to the field value, and the `by` field value.
+
+```rust
+#[tokio::main]
+async fn main() -> eyre::Result<()> {
+    // Table creation.
+    let table = SomethingWorkTable::default();
+    let row = SomethingRow {
+        // Autoincrement primary key generation.
+        id: table.get_next_pk().into(),
+        name: "SomeName".to_string(),
+        amount: 100,
+        some_value: 0,
+    };
+    let pk = table.insert(row)?;
+    // This updates the `some_value` field by adding 100 to its value.
+    table
+        .update_some_value_by_id_in_place(|some_value| *some_value += 100, pk.0)
+        .await?;
+    let row = table.select(pk)?;
+    assert_eq!(row.some_value, 100);
+
+    Ok(())
+}
+```
+
+You can find tests that cover `update_in_place` queries [here](../tests/worktable/in_place.rs).
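+
+If an `in_place` query lists more than one column, the generated closure receives a tuple of mutable
+references, one per column, in declaration order. The sketch below is hypothetical: the `ValAndCountById`
+query, the `val` and `count` columns, and the surrounding `table`/`pk` bindings are illustrative
+assumptions, not part of the example above.
+
+```rust
+// Hypothetical declaration (not part of the example above):
+//     in_place: {
+//         ValAndCountById(val, count) by id,
+//     }
+// The generated method takes a closure over a tuple of mutable references,
+// one per declared column, over the archived field values.
+table
+    .update_val_and_count_by_id_in_place(
+        |(val, count)| {
+            *val += 1;
+            *count += 1;
+        },
+        pk.0,
+    )
+    .await?;
+```
+
+Because the closure mutates the `rkyv` archived representation of each field in place, `update_in_place`
+is best suited to fixed-size fields such as integers and floats.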
+ +### `delete` queries + +`TODO` \ No newline at end of file diff --git a/src/lock/mod.rs b/src/lock/mod.rs index d405d117..114f27aa 100644 --- a/src/lock/mod.rs +++ b/src/lock/mod.rs @@ -25,12 +25,12 @@ impl Lock { } pub fn unlock(&self) { - self.locked.store(false, Ordering::Relaxed); + self.locked.store(false, Ordering::Release); self.waker.wake() } pub fn lock(&self) { - self.locked.store(true, Ordering::Relaxed); + self.locked.store(true, Ordering::Release); self.waker.wake() } } @@ -40,7 +40,7 @@ impl Future for &Lock { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.as_ref().waker.register(cx.waker()); - if self.locked.load(Ordering::Relaxed) { + if self.locked.load(Ordering::Acquire) { Poll::Pending } else { Poll::Ready(()) diff --git a/src/lock/set.rs b/src/lock/set.rs index 948eb676..89e41911 100644 --- a/src/lock/set.rs +++ b/src/lock/set.rs @@ -1,14 +1,15 @@ +use std::collections::HashMap; use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::Arc; -use lockfree::map::Map; +use parking_lot::RwLock; #[derive(Debug)] pub struct LockMap where PkType: std::hash::Hash + Ord, { - set: Map>>, + set: RwLock>>, next_id: AtomicU16, } @@ -27,21 +28,33 @@ where { pub fn new() -> Self { Self { - set: Map::new(), + set: RwLock::new(HashMap::new()), next_id: AtomicU16::default(), } } - pub fn insert(&self, key: PkType, lock: Arc) { - self.set.insert(key, Some(lock)); + pub fn insert(&self, key: PkType, lock: Arc) -> Option> { + self.set.write().insert(key, lock) } pub fn get(&self, key: &PkType) -> Option> { - self.set.get(key).map(|v| v.val().clone())? + self.set.read().get(key).cloned() } pub fn remove(&self, key: &PkType) { - self.set.remove(key); + self.set.write().remove(key); + } + + pub fn remove_with_lock_check(&self, key: &PkType, lock: Arc) + where + PkType: Clone, + { + let mut set = self.set.write(); + if let Some(l) = set.remove(key) { + if !Arc::ptr_eq(&l, &lock) { + set.insert(key.clone(), l); + } + } } pub fn next_id(&self) -> u16 { diff --git a/tests/worktable/in_place.rs b/tests/worktable/in_place.rs new file mode 100644 index 00000000..cbe2b195 --- /dev/null +++ b/tests/worktable/in_place.rs @@ -0,0 +1,91 @@ +use std::sync::Arc; + +use rkyv::Archive; +use worktable::prelude::*; +use worktable::worktable; + +worktable!( + name: Test, + columns: { + id: u64 primary_key autoincrement, + val: i64, + val1: u64, + val2: i16, + }, + queries: { + in_place: { + ValById(val) by id, + Val2ById(val2) by id, + } + } +); + +#[tokio::test] +async fn test_update_val_by_id() -> eyre::Result<()> { + let table = TestWorkTable::default(); + let row = TestRow { + id: table.get_next_pk().0, + val: 0, + val1: 0, + val2: 0, + }; + let pk = table.insert(row)?; + for _ in 0..10000 { + table + .update_val_by_id_in_place(|val| *val += 1, pk.0) + .await? + } + let row = table.select(pk).unwrap(); + assert_eq!(row.val, 10000); + Ok(()) +} + +#[tokio::test] +async fn test_update_val2_by_id() -> eyre::Result<()> { + let table = TestWorkTable::default(); + let row = TestRow { + id: table.get_next_pk().0, + val: 0, + val1: 0, + val2: 0, + }; + let pk = table.insert(row)?; + for _ in 0..100 { + table + .update_val_2_by_id_in_place(|val| *val += 1, pk.0) + .await? 
+ } + let row = table.select(pk).unwrap(); + assert_eq!(row.val2, 100); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_update_val_by_id_multi_thread() -> eyre::Result<()> { + let table = Arc::new(TestWorkTable::default()); + let row = TestRow { + id: table.get_next_pk().0, + val: 0, + val1: 0, + val2: 0, + }; + let pk = table.insert(row)?; + let shared_table = table.clone(); + let h = tokio::spawn(async move { + for _ in 0..10000 { + shared_table + .update_val_by_id_in_place(|val| *val += 1, pk.0) + .await + .unwrap() + } + }); + for _ in 0..10000 { + table + .update_val_by_id_in_place(|val| *val += 1, pk.0) + .await? + } + h.await?; + let row = table.select(pk).unwrap(); + assert_eq!(row.val, 20000); + Ok(()) +} diff --git a/tests/worktable/mod.rs b/tests/worktable/mod.rs index 827a5934..ba143fdf 100644 --- a/tests/worktable/mod.rs +++ b/tests/worktable/mod.rs @@ -5,6 +5,7 @@ mod config; mod count; mod custom_pk; mod float; +mod in_place; mod index; mod option; mod tuple_primary_key;