Skip to content

Commit

Permalink
Add Connection::transactional_batch (#1366)
Browse files Browse the repository at this point in the history
* Add Connection::execute_transactional_batch

This commit contains only plumbing.
There are 3 implementations that need to be provided
and they are currently implemented as a `todo!()`.
Next commits will fill in those missing implementations.

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>

* Implement execute_transactional_batch for local connection

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>

* Implement execute_transactional_batch for HRANA connection

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>

* Implement execute_transactional_batch for GRPC connection

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>

---------

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>
  • Loading branch information
haaawk committed May 5, 2024
1 parent 372311a commit 41e17aa
Show file tree
Hide file tree
Showing 9 changed files with 396 additions and 5 deletions.
27 changes: 27 additions & 0 deletions libsql-hrana/src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,33 @@ impl Batch {
replication_index: None,
}
}
pub fn transactional<T: IntoIterator<Item = Stmt>>(stmts: T) -> Self {
let mut steps = Vec::new();
steps.push(BatchStep {
condition: None,
stmt: Stmt::new("BEGIN TRANSACTION", false),
});
let mut count = 0u32;
for (step, stmt) in stmts.into_iter().enumerate() {
count += 1;
let condition = Some(BatchCond::Ok { step: step as u32 });
steps.push(BatchStep { condition, stmt });
}
steps.push(BatchStep {
condition: Some(BatchCond::Ok { step: count }),
stmt: Stmt::new("COMMIT", false),
});
steps.push(BatchStep {
condition: Some(BatchCond::Not {
cond: Box::new(BatchCond::Ok { step: count + 1 }),
}),
stmt: Stmt::new("ROLLBACK", false),
});
Batch {
steps,
replication_index: None,
}
}
}

impl FromIterator<Stmt> for Batch {
Expand Down
8 changes: 8 additions & 0 deletions libsql/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub(crate) trait Conn {

async fn execute_batch(&self, sql: &str) -> Result<()>;

async fn execute_transactional_batch(&self, sql: &str) -> Result<()>;

async fn prepare(&self, sql: &str) -> Result<Statement>;

async fn transaction(&self, tx_behavior: TransactionBehavior) -> Result<Transaction>;
Expand Down Expand Up @@ -57,6 +59,12 @@ impl Connection {
self.conn.execute_batch(sql).await
}

/// Execute a batch set of statements atomically in a transaction.
pub async fn execute_transactional_batch(&self, sql: &str) -> Result<()> {
tracing::trace!("executing batch transactional `{}`", sql);
self.conn.execute_transactional_batch(sql).await
}

/// Execute sql query provided some type that implements [`IntoParams`] returning
/// on success the [`Rows`].
///
Expand Down
2 changes: 2 additions & 0 deletions libsql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub enum Error {
InvalidParserState(String),
#[error("TLS error: {0}")]
InvalidTlsConfiguration(std::io::Error),
#[error("Transactional batch error: {0}")]
TransactionalBatchError(String),
}

#[cfg(feature = "hrana")]
Expand Down
23 changes: 22 additions & 1 deletion libsql/src/hrana/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::hrana::{bind_params, unwrap_err, HranaError, HttpSend, Result};
use crate::params::Params;
use crate::transaction::Tx;
use crate::util::ConnectorService;
use crate::{Rows, Statement};
use crate::{Error, Rows, Statement};
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::{Stream, TryStreamExt};
Expand Down Expand Up @@ -121,6 +121,10 @@ impl Conn for HttpConnection<HttpSender> {
self.current_stream().execute_batch(sql).await
}

async fn execute_transactional_batch(&self, sql: &str) -> crate::Result<()> {
self.current_stream().execute_transactional_batch(sql).await
}

async fn prepare(&self, sql: &str) -> crate::Result<Statement> {
let stream = self.current_stream().clone();
let stmt = crate::hrana::Statement::new(stream, sql.to_string(), true)?;
Expand Down Expand Up @@ -273,6 +277,23 @@ impl Conn for HranaStream<HttpSender> {
unwrap_err(res)
}

async fn execute_transactional_batch(&self, sql: &str) -> crate::Result<()> {
let mut stmts = Vec::new();
let parse = crate::parser::Statement::parse(sql);
for s in parse {
let s = s?;
if s.kind == crate::parser::StmtKind::TxnBegin || s.kind == crate::parser::StmtKind::TxnBeginReadOnly || s.kind == crate::parser::StmtKind::TxnEnd {
return Err(Error::TransactionalBatchError("Transactions forbidden inside transactional batch".to_string()));
}
stmts.push(Stmt::new(s.stmt, false));
}
let res = self
.batch_inner(Batch::transactional(stmts), true)
.await
.map_err(|e| crate::Error::Hrana(e.into()))?;
unwrap_err(res)
}

async fn prepare(&self, sql: &str) -> crate::Result<Statement> {
let stmt = crate::hrana::Statement::new(self.clone(), sql.to_string(), true)?;
Ok(Statement {
Expand Down
56 changes: 56 additions & 0 deletions libsql/src/local/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,62 @@ impl Connection {
Ok(())
}

fn execute_transactional_batch_inner<S>(&self, sql: S) -> Result<()>
where
S: Into<String>,
{
let sql = sql.into();
let mut sql = sql.as_str();
while !sql.is_empty() {
let stmt = self.prepare(sql)?;

let tail = stmt.tail();
let stmt_sql = if tail == 0 || tail >= sql.len() {
sql
} else {
&sql[..tail]
};
let prefix_count = stmt_sql
.chars()
.take_while(|c| c.is_whitespace())
.count();
let stmt_sql = &stmt_sql[prefix_count..];
if stmt_sql.starts_with("BEGIN") || stmt_sql.starts_with("COMMIT") || stmt_sql.starts_with("ROLLBACK") || stmt_sql.starts_with("END") {
return Err(Error::TransactionalBatchError("Transactions forbidden inside transactional batch".to_string()));
}

if !stmt.inner.raw_stmt.is_null() {
stmt.step()?;
}

if tail == 0 || tail >= sql.len() {
break;
}

sql = &sql[tail..];
}

Ok(())
}

pub fn execute_transactional_batch<S>(&self, sql: S) -> Result<()>
where
S: Into<String>,
{
self.execute("BEGIN TRANSACTION", Params::None)?;

match self.execute_transactional_batch_inner(sql) {
Ok(_) => {
self.execute("COMMIT", Params::None)?;
Ok(())
}
Err(e) => {
self.execute("ROLLBACK", Params::None)?;
Err(e)
}
}
}

/// Execute the SQL statement synchronously.
///
/// If you execute a SQL query statement (e.g. `SELECT` statement) that
Expand Down
4 changes: 4 additions & 0 deletions libsql/src/local/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ impl Conn for LibsqlConnection {
self.conn.execute_batch(sql)
}

async fn execute_transactional_batch(&self, sql: &str) -> Result<()> {
self.conn.execute_transactional_batch(sql)
}

async fn prepare(&self, sql: &str) -> Result<Statement> {
let sql = sql.to_string();

Expand Down
135 changes: 131 additions & 4 deletions libsql/src/replication/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
use std::str::FromStr;
use std::sync::Arc;

use libsql_replication::rpc::proxy::{
describe_result, query_result::RowResult, DescribeResult, ExecuteResults, ResultRows,
State as RemoteState,
};
use libsql_replication::rpc::proxy::{describe_result, query_result::RowResult, DescribeResult, ExecuteResults, ResultRows, State as RemoteState, Step, Query, Cond, OkCond, NotCond, Positional};
use parking_lot::Mutex;

use crate::parser;
Expand Down Expand Up @@ -207,6 +204,34 @@ impl RemoteConnection {
Ok(res)
}

pub(self) async fn execute_steps_remote(
&self,
steps: Vec<Step>,
) -> Result<ExecuteResults> {
let Some(ref writer) = self.writer else {
return Err(Error::Misuse(
"Cannot delegate write in local replica mode.".into(),
));
};
let res = writer
.execute_steps(steps)
.await
.map_err(|e| Error::WriteDelegation(e.into()))?;

{
let mut inner = self.inner.lock();
inner.state = RemoteState::try_from(res.state)
.expect("Invalid state enum")
.into();
}

if let Some(replicator) = writer.replicator() {
replicator.sync_oneshot().await?;
}

Ok(res)
}

pub(self) async fn describe(&self, stmt: impl Into<String>) -> Result<DescribeResult> {
let Some(ref writer) = self.writer else {
return Err(Error::Misuse(
Expand Down Expand Up @@ -321,6 +346,108 @@ impl Conn for RemoteConnection {
Ok(())
}

async fn execute_transactional_batch(&self, sql: &str) -> Result<()> {
let mut stmts = Vec::new();
let parse = crate::parser::Statement::parse(sql);
for s in parse {
let s = s?;
if s.kind == StmtKind::TxnBegin || s.kind == StmtKind::TxnBeginReadOnly || s.kind == StmtKind::TxnEnd {
return Err(Error::TransactionalBatchError("Transactions forbidden inside transactional batch".to_string()));
}
stmts.push(s);
}

if self.should_execute_local(&stmts[..])? {
self.local.execute_transactional_batch(sql).await?;

if !self.maybe_execute_rollback().await? {
return Ok(());
}
}

let mut steps = Vec::with_capacity(stmts.len() + 3);
steps.push(Step {
query: Some(Query {
stmt: "BEGIN TRANSACTION".to_string(),
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
..Default::default()
}),
..Default::default()
});
let count = stmts.len() as i64;
for (idx, stmt) in stmts.into_iter().enumerate() {
let step = Step {
cond: Some(Cond {
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Ok(OkCond {
step: idx as i64,
..Default::default()
})),
}),
query: Some(Query {
stmt: stmt.stmt,
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
..Default::default()
}),
..Default::default()
};
steps.push(step);
}
steps.push(Step {
cond: Some(Cond {
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Ok(OkCond {
step: count,
..Default::default()
})),
..Default::default()
}),
query: Some(Query {
stmt: "COMMIT".to_string(),
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
..Default::default()
}),
..Default::default()
});
steps.push(Step {
cond: Some(Cond {
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Not(Box::new(NotCond {
cond: Some(Box::new(Cond{
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Ok(OkCond {
step: count + 1,
..Default::default()
})),
..Default::default()
})),
..Default::default()
}))),
..Default::default()
}),
query: Some(Query {
stmt: "ROLLBACK".to_string(),
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
..Default::default()
}),
..Default::default()
});

let res = self.execute_steps_remote(steps).await?;

for result in res.results {
match result.row_result {
Some(RowResult::Row(row)) => self.update_state(&row),
Some(RowResult::Error(e)) => {
return Err(Error::RemoteSqliteFailure(
e.code,
e.extended_code,
e.message,
))
}
None => panic!("unexpected empty result row"),
};
}

Ok(())
}

async fn prepare(&self, sql: &str) -> Result<Statement> {
let stmt = RemoteStatement::prepare(self.clone(), sql).await?;

Expand Down
7 changes: 7 additions & 0 deletions libsql/src/replication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ impl Writer {
})
.collect();

self.execute_steps(steps).await
}

pub(crate) async fn execute_steps(
&self,
steps: Vec<Step>,
) -> anyhow::Result<ExecuteResults> {
self.client
.execute_program(ProgramReq {
client_id: self.client.client_id(),
Expand Down
Loading

0 comments on commit 41e17aa

Please sign in to comment.