Skip to content

Commit

Permalink
Merge #33
Browse files Browse the repository at this point in the history
33: batch query r=penberg a=MarinPostma

Add support for batch queries

This was more work than I expected, because of the different protocols we support and different modes of operations. Queries are always batched now (singular queries are in batches on 1).

If any statement in the batch does not parse, then the whole batch is rejected.

fix #18 

Co-authored-by: ad hoc <postma.marin@protonmail.com>
  • Loading branch information
bors[bot] and MarinPostma committed Jan 14, 2023
2 parents 57847a4 + 4352723 commit 15c75e7
Show file tree
Hide file tree
Showing 14 changed files with 493 additions and 270 deletions.
3 changes: 1 addition & 2 deletions libsql-server/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions libsql-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ members = [
mvfs = { git = "https://github.com/MarinPostma/mvsqlite", branch = "use-cchar" }
mwal = { git = "https://github.com/MarinPostma/mvsqlite", branch = "use-cchar" }

[patch.crates-io]
sqlite3-parser = { git = "https://github.com/MarinPostma/lemon-rs.git", rev = "d3a6365" }
32 changes: 20 additions & 12 deletions libsql-server/server/proto/proxy.proto
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
syntax = "proto3";
package proxy;

message SimpleQuery {
string q = 1;
// Uuid
bytes clientId = 2;
}
message Queries {
repeated string queries = 1;
// Uuid
bytes clientId = 2;
}

message QueryResult {
optional Error error = 1;
optional ResultRows rows = 2;
enum Result {
Ok = 0;
Err = 1;
oneof row_result {
Error error = 1;
ResultRows row = 2;
}
Result result = 3;
}

message Error {
Expand Down Expand Up @@ -62,7 +59,18 @@ message DisconnectMessage {

message Ack {}

message ExecuteResults {
repeated QueryResult results = 1;
enum State {
Init = 0;
Invalid = 1;
Txn = 2;
}
/// State after executing the queries
State state = 2;
}

service Proxy {
rpc Query(SimpleQuery) returns (QueryResult) {}
rpc Execute(Queries) returns (ExecuteResults) {}
rpc Disconnect(DisconnectMessage) returns (Ack) {}
}
130 changes: 92 additions & 38 deletions libsql-server/server/src/database/libsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,22 @@ use tracing::warn;
use crate::libsql::wal_hook::WalHook;
use crate::libsql::WalConnection;
use crate::query::{
Column, ErrorCode, QueryError, QueryResponse, QueryResult, ResultSet, Row, Value,
Column, ErrorCode, Queries, Query, QueryError, QueryResponse, QueryResult, ResultSet, Row,
Value,
};
use crate::query_analysis::{State, Statement};

use super::{Database, TXN_TIMEOUT_SECS};

/// Internal message used to communicate between the database thread and the `LibSqlDb` handle.
struct Message {
queries: Queries,
resp: oneshot::Sender<(Vec<QueryResult>, State)>,
}

#[derive(Clone)]
pub struct LibSqlDb {
sender: crossbeam::channel::Sender<(Statement, Vec<Value>, oneshot::Sender<QueryResult>)>,
sender: crossbeam::channel::Sender<Message>,
}

fn execute_query(conn: &rusqlite::Connection, stmt: &Statement, params: Vec<Value>) -> QueryResult {
Expand All @@ -39,9 +46,11 @@ fn execute_query(conn: &rusqlite::Connection, stmt: &Statement, params: Vec<Valu
.flatten(),
})
.collect::<Vec<_>>();

let mut qresult = prepared.query(params_from_iter(
params.into_iter().map(rusqlite::types::Value::from),
))?;

while let Some(row) = qresult.next()? {
let mut values = vec![];
for (i, _) in columns.iter().enumerate() {
Expand All @@ -53,6 +62,61 @@ fn execute_query(conn: &rusqlite::Connection, stmt: &Statement, params: Vec<Valu
Ok(QueryResponse::ResultSet(ResultSet { columns, rows }))
}

struct ConnectionState {
state: State,
timeout_deadline: Option<Instant>,
}

impl ConnectionState {
fn initial() -> Self {
Self {
state: State::Init,
timeout_deadline: None,
}
}

fn deadline(&self) -> Option<Instant> {
self.timeout_deadline
}

fn reset(&mut self) {
self.state.reset();
self.timeout_deadline.take();
}

fn step(&mut self, stmt: &Statement) {
let old_state = self.state;

self.state.step(stmt.kind);

match (old_state, self.state) {
(State::Init, State::Txn) => {
self.timeout_deadline
.replace(Instant::now() + Duration::from_secs(TXN_TIMEOUT_SECS));
}
(State::Txn, State::Init) => self.reset(),
(_, State::Invalid) => panic!("invalid state"),
_ => (),
}
}
}

fn handle_query(
conn: &rusqlite::Connection,
query: Query,
state: &mut ConnectionState,
) -> QueryResult {
let result = execute_query(conn, &query.stmt, query.params);

// We drive the connection state on success. This is how we keep track of whether
// a transaction timeouts
if result.is_ok() {
state.step(&query.stmt)
}

result
}

fn rollback(conn: &rusqlite::Connection) {
conn.execute("rollback transaction;", ())
.expect("failed to rollback");
Expand Down Expand Up @@ -138,9 +202,7 @@ impl LibSqlDb {
>,
wal_hook: impl WalHook + Send + Clone + 'static,
) -> anyhow::Result<Self> {
let (sender, receiver) =
crossbeam::channel::unbounded::<(Statement, Vec<Value>, oneshot::Sender<QueryResult>)>(
);
let (sender, receiver) = crossbeam::channel::unbounded::<Message>();

tokio::task::spawn_blocking(move || {
let conn = open_db(
Expand All @@ -150,19 +212,18 @@ impl LibSqlDb {
wal_hook,
)
.unwrap();
let mut state = State::Start;
let mut timeout_deadline = None;

let mut state = ConnectionState::initial();
let mut timedout = false;
loop {
let (stmt, params, sender) = match timeout_deadline {
let Message { queries, resp } = match state.deadline() {
Some(deadline) => match receiver.recv_deadline(deadline) {
Ok(msg) => msg,
Err(RecvTimeoutError::Timeout) => {
warn!("transaction timed out");
rollback(&conn);
timeout_deadline = None;
timedout = true;
state = State::Start;
state.reset();
continue;
}
Err(RecvTimeoutError::Disconnected) => break,
Expand All @@ -174,30 +235,23 @@ impl LibSqlDb {
};

if !timedout {
let old_state = state;
let result = execute_query(&conn, &stmt, params);
if result.is_ok() {
state.step(stmt.kind);
match (old_state, state) {
(State::Start, State::TxnOpened) => {
timeout_deadline.replace(
Instant::now() + Duration::from_secs(TXN_TIMEOUT_SECS),
);
}
(State::TxnOpened, State::TxnClosed) => {
timeout_deadline.take();
state.reset();
}
(_, State::Invalid) => panic!("invalid state"),
_ => (),
}
let mut results = Vec::with_capacity(queries.len());
for query in queries {
let result = handle_query(&conn, query, &mut state);
results.push(result);
}
ok_or_exit!(sender.send(result));
ok_or_exit!(resp.send((results, state.state)));
} else {
ok_or_exit!(sender.send(Err(QueryError::new(
ErrorCode::TxTimeout,
"transaction timedout",
))));
// fail all the queries in the batch with timeout error
let errors = (0..queries.len())
.map(|_| {
Err(QueryError::new(
ErrorCode::TxTimeout,
"transaction timedout",
))
})
.collect();
ok_or_exit!(resp.send((errors, state.state)));
timedout = false;
}
}
Expand All @@ -209,11 +263,11 @@ impl LibSqlDb {

#[async_trait::async_trait]
impl Database for LibSqlDb {
async fn execute(&self, query: Statement, params: Vec<Value>) -> QueryResult {
let (sender, receiver) = oneshot::channel();
let _ = self.sender.send((query, params, sender));
receiver
.await
.map_err(|e| QueryError::new(ErrorCode::Internal, e.to_string()))?
async fn execute(&self, queries: Queries) -> anyhow::Result<(Vec<QueryResult>, State)> {
let (resp, receiver) = oneshot::channel();
let msg = Message { queries, resp };
let _ = self.sender.send(msg);

Ok(receiver.await?)
}
}
8 changes: 5 additions & 3 deletions libsql-server/server/src/database/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::query::{QueryResult, Value};
use crate::query_analysis::Statement;
use crate::query::{Queries, QueryResult};
use crate::query_analysis::State;

pub mod libsql;
pub mod service;
Expand All @@ -9,5 +9,7 @@ const TXN_TIMEOUT_SECS: u64 = 5;

#[async_trait::async_trait]
pub trait Database {
async fn execute(&self, query: Statement, params: Vec<Value>) -> QueryResult;
/// Executes a batch of queries, and return the a vec of results corresponding to the queries,
/// and the state the database is in after the call to execute.
async fn execute(&self, queries: Queries) -> anyhow::Result<(Vec<QueryResult>, State)>;
}
22 changes: 7 additions & 15 deletions libsql-server/server/src/database/service.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::future::ready;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
Expand All @@ -7,8 +6,7 @@ use futures::Future;
use tower::Service;

use super::Database;
use crate::query::{ErrorCode, Query, QueryError, QueryResponse, QueryResult, ResultSet};
use crate::query_analysis::Statement;
use crate::query::{Queries, QueryResult};
pub trait DbFactory: Send + Sync + 'static {
type Future: Future<Output = anyhow::Result<Self::Db>> + Send;
type Db: Database + Send + Sync;
Expand Down Expand Up @@ -76,24 +74,18 @@ impl<DB> Drop for DbService<DB> {
}
}

impl<DB: Database + 'static + Send + Sync> Service<Query> for DbService<DB> {
type Response = QueryResponse;
type Error = QueryError;
type Future = Pin<Box<dyn Future<Output = QueryResult> + Send>>;
impl<DB: Database + 'static + Send + Sync> Service<Queries> for DbService<DB> {
type Response = Vec<QueryResult>;
type Error = anyhow::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// need to implement backpressure: one req at a time.
Ok(()).into()
}

fn call(&mut self, query: Query) -> Self::Future {
fn call(&mut self, queries: Queries) -> Self::Future {
let db = self.db.clone();
match query {
Query::SimpleQuery(stmts, params) => match Statement::parse(stmts) {
Ok(None) => Box::pin(ready(Ok(QueryResponse::ResultSet(ResultSet::empty())))),
Ok(Some(stmt)) => Box::pin(async move { db.execute(stmt, params).await }),
Err(e) => Box::pin(ready(Err(QueryError::new(ErrorCode::SQLError, e)))),
},
}
Box::pin(async move { Ok(db.execute(queries).await?.0) })
}
}
Loading

0 comments on commit 15c75e7

Please sign in to comment.