From 59efcdcf0cdf25e7e8b6b7e6b4a90bbd347958e0 Mon Sep 17 00:00:00 2001 From: Germano Rizzo Date: Mon, 9 Oct 2023 19:24:30 +0200 Subject: [PATCH] #1: Implement positional parameters for SQL --- README.md | 1 + src/commons.rs | 14 +++++ src/logic.rs | 55 +++++++++++++++--- tests/go_test.go | 145 +++++++++++++++++++++++++++++++++-------------- tests/structs.go | 14 ++--- 5 files changed, 172 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index add8a19..02c3bfa 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ Obtaining an answer of: - Directly call `sqliterg` on a database (as above), many options available using a YAML companion file; - [**In-memory DBs**](https://docs.sqliterg.dev/documentation/running#file-based-and-in-memory) are supported; - Serving of [**multiple databases**](https://docs.sqliterg.dev/documentation/configuration-file) in the same server instance; +- Named or positional parameters in SQL are supported; - [**Batching**](https://docs.sqliterg.dev/documentation/requests#batch-parameter-values-for-a-statement) of multiple value sets for a single statement; - All queries of a call are executed in a [**transaction**](https://docs.sqliterg.dev/documentation/requests); - For each query/statement, specify if a failure should rollback the whole transaction, or the failure is [**limited**](https://docs.sqliterg.dev/documentation/errors#managed-errors) to that query; diff --git a/src/commons.rs b/src/commons.rs index 7249eb8..503970b 100644 --- a/src/commons.rs +++ b/src/commons.rs @@ -179,3 +179,17 @@ impl From)>> for NamedParamsContain Self(src) } } + +pub struct PositionalParamsContainer(Vec>); + +impl PositionalParamsContainer { + pub fn slice(&self) -> Vec<&dyn rusqlite::types::ToSql> { + self.0.iter().map(|el| (el.borrow())).collect() + } +} + +impl From>> for PositionalParamsContainer { + fn from(src: Vec>) -> Self { + Self(src) + } +} diff --git a/src/logic.rs b/src/logic.rs index 36260c6..e2bc908 100644 --- a/src/logic.rs +++ b/src/logic.rs @@ -22,7 +22,7 @@ use serde_json::{json, Map as JsonMap, Value as JsonValue}; use crate::{ auth::process_auth, - commons::{check_stored_stmt, NamedParamsContainer}, + commons::{check_stored_stmt, NamedParamsContainer, PositionalParamsContainer}, db_config::{AuthMode, DbConfig}, main_config::Db, req_res::{self, Response, ResponseItem}, @@ -56,6 +56,21 @@ fn calc_named_params(params: &JsonMap) -> NamedParamsContaine NamedParamsContainer::from(named_params) } +fn calc_positional_params(params: &Vec) -> PositionalParamsContainer { + let mut ret_params: Vec> = Vec::new(); + + for v in params { + let val: Box = if v.is_string() { + Box::new(v.as_str().unwrap().to_owned()) + } else { + Box::new(v.to_owned()) + }; + ret_params.push(val); + } + + PositionalParamsContainer::from(ret_params) +} + #[allow(clippy::type_complexity)] fn do_query( tx: &Transaction, @@ -70,8 +85,17 @@ fn do_query( .collect(); let mut rows = match values { Some(p) => { - let map = p.as_object().unwrap(); - stmt.query(calc_named_params(map).slice().as_slice())? + // FIXME this code is repeated three times; I wish I could make a common + // function but, no matter how I try, it seems not to be possible + if p.is_object() { + let map = p.as_object().unwrap(); + stmt.query(calc_named_params(map).slice().as_slice())? + } else if p.is_array() { + let array = p.as_array().unwrap(); + stmt.query(calc_positional_params(array).slice().as_slice())? + } else { + return Err(eyre!("Values are neither positional nor named".to_string())); + } } None => stmt.query([])?, }; @@ -107,16 +131,33 @@ fn do_statement( let changed_rows = tx.execute(sql, [])?; (None, Some(changed_rows), None) } else if values.is_some() { - let map = values.as_ref().unwrap().as_object().unwrap(); - let changed_rows = tx.execute(sql, calc_named_params(map).slice().as_slice())?; + let p = values.as_ref().unwrap(); + let changed_rows = if p.is_object() { + let map = p.as_object().unwrap(); + tx.execute(sql, calc_named_params(map).slice().as_slice())? + } else if p.is_array() { + let array = p.as_array().unwrap(); + tx.execute(sql, calc_positional_params(array).slice().as_slice())? + } else { + return Err(eyre!("Values are neither positional nor named".to_string())); + }; + (None, Some(changed_rows), None) } else { // values_batch.is_some() let mut stmt = tx.prepare(sql)?; let mut ret = vec![]; for p in values_batch.as_ref().unwrap() { - let map = p.as_object().unwrap(); - let changed_rows = stmt.execute(calc_named_params(map).slice().as_slice())?; + let changed_rows = if p.is_object() { + let map = p.as_object().unwrap(); + stmt.execute(calc_named_params(map).slice().as_slice())? + } else if p.is_array() { + let array = p.as_array().unwrap(); + stmt.execute(calc_positional_params(array).slice().as_slice())? + } else { + return Err(eyre!("Values are neither positional nor named".to_string())); + }; + ret.push(changed_rows); } (None, None, Some(ret)) diff --git a/tests/go_test.go b/tests/go_test.go index 5180975..3413073 100644 --- a/tests/go_test.go +++ b/tests/go_test.go @@ -94,7 +94,7 @@ func setupTest(t *testing.T, cfg *db, printOutput bool, argv ...string) func(boo } } -func mkRaw(mapp map[string]interface{}) map[string]json.RawMessage { +func mkNamedParams(mapp map[string]interface{}) map[string]json.RawMessage { ret := make(map[string]json.RawMessage) for k, v := range mapp { bytes, _ := json.Marshal(v) @@ -103,6 +103,15 @@ func mkRaw(mapp map[string]interface{}) map[string]json.RawMessage { return ret } +func mkPositionalParams(arr []interface{}) []json.RawMessage { + ret := make([]json.RawMessage, len(arr)) + for i, v := range arr { + bytes, _ := json.Marshal(v) + ret[i] = bytes + } + return ret +} + func call(t *testing.T, url string, req request) (int, string, response) { reqbytes, err := json.Marshal(req) require.NoError(t, err) @@ -303,7 +312,7 @@ func TestTx(t *testing.T) { }, { Statement: "INSERT INTO T1 (ID, VAL) VALUES (:ID, :VAL)", - Values: mkRaw(map[string]interface{}{ + Values: mkNamedParams(map[string]interface{}{ "ID": 2, "VAL": "TWO", }), @@ -311,18 +320,18 @@ func TestTx(t *testing.T) { { Statement: "INSERT INTO T1 (ID, VAL) VALUES (:ID, :VAL)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{ + mkNamedParams(map[string]interface{}{ "ID": 3, "VAL": "THREE", }), - mkRaw(map[string]interface{}{ + mkNamedParams(map[string]interface{}{ "ID": 4, "VAL": "FOUR", })}, }, { Query: "SELECT * FROM T1 WHERE ID > :ID", - Values: mkRaw(map[string]interface{}{ + Values: mkNamedParams(map[string]interface{}{ "ID": 0, }), }, @@ -348,6 +357,60 @@ func TestTx(t *testing.T) { require.Equal(t, 4, len(res.Results[6].ResultSet)) } +func TestTxPositional(t *testing.T) { + defer setupTest(t, nil, false, "--db", "env/test.db")(true) + req := request{ + Transaction: []requestItem{ + { + Statement: "CREATE TABLE T1 (ID INT PRIMARY KEY, VAL TEXT NOT NULL)", + }, + { + Statement: "INSERT INTO T1 (ID, VAL) VALUES (1, 'ONE')", + }, + { + Statement: "INSERT INTO T1 (ID, VAL) VALUES (1, 'TWO')", + NoFail: true, + }, + { + Query: "SELECT * FROM T1 WHERE ID = 1", + }, + { + Statement: "INSERT INTO T1 (ID, VAL) VALUES (?, ?)", + Values: mkPositionalParams([]interface{}{2, "TWO"}), + }, + { + Statement: "INSERT INTO T1 (ID, VAL) VALUES (?, ?)", + ValuesBatch: [][]json.RawMessage{ + mkPositionalParams([]interface{}{3, "THREE"}), + mkPositionalParams([]interface{}{4, "FOUR"}), + }, + }, + { + Query: "SELECT * FROM T1 WHERE ID > ?", + Values: mkPositionalParams([]interface{}{0}), + }, + }, + } + + code, _, res := call(t, "http://localhost:12321/test", req) + + require.Equal(t, http.StatusOK, code) + + require.True(t, res.Results[1].Success) + require.False(t, res.Results[2].Success) + require.True(t, res.Results[3].Success) + require.True(t, res.Results[4].Success) + require.True(t, res.Results[5].Success) + require.True(t, res.Results[6].Success) + + require.Equal(t, 1, *res.Results[1].RowsUpdated) + require.Equal(t, "ONE", res.Results[3].ResultSet[0]["VAL"]) + require.Equal(t, 1, *res.Results[4].RowsUpdated) + require.Equal(t, 2, len(res.Results[5].RowsUpdatedBatch)) + require.Equal(t, 1, res.Results[5].RowsUpdatedBatch[0]) + require.Equal(t, 4, len(res.Results[6].ResultSet)) +} + func TestTxRollback(t *testing.T) { defer setupTest(t, nil, false, "--db", "env/test.db")(true) req := request{ @@ -446,7 +509,7 @@ func TestConcurrent(t *testing.T) { }, { Statement: "INSERT INTO T1 (ID, VAL) VALUES (:ID, :VAL)", - Values: mkRaw(map[string]interface{}{ + Values: mkNamedParams(map[string]interface{}{ "ID": 2, "VAL": "TWO", }), @@ -454,18 +517,18 @@ func TestConcurrent(t *testing.T) { { Statement: "INSERT INTO T1 (ID, VAL) VALUES (:ID, :VAL)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{ + mkNamedParams(map[string]interface{}{ "ID": 3, "VAL": "THREE", }), - mkRaw(map[string]interface{}{ + mkNamedParams(map[string]interface{}{ "ID": 4, "VAL": "FOUR", })}, }, { Query: "SELECT * FROM T1 WHERE ID > :ID", - Values: mkRaw(map[string]interface{}{ + Values: mkNamedParams(map[string]interface{}{ "ID": 0, }), }, @@ -620,7 +683,7 @@ func TestTxMem(t *testing.T) { }, { Statement: "INSERT INTO T1 (ID, VAL) VALUES (:ID, :VAL)", - Values: mkRaw(map[string]interface{}{ + Values: mkNamedParams(map[string]interface{}{ "ID": 2, "VAL": "TWO", }), @@ -628,18 +691,18 @@ func TestTxMem(t *testing.T) { { Statement: "INSERT INTO T1 (ID, VAL) VALUES (:ID, :VAL)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{ + mkNamedParams(map[string]interface{}{ "ID": 3, "VAL": "THREE", }), - mkRaw(map[string]interface{}{ + mkNamedParams(map[string]interface{}{ "ID": 4, "VAL": "FOUR", })}, }, { Query: "SELECT * FROM T1 WHERE ID > :ID", - Values: mkRaw(map[string]interface{}{ + Values: mkNamedParams(map[string]interface{}{ "ID": 0, }), }, @@ -1498,32 +1561,32 @@ func TestProfilerPayloadOnFile(t *testing.T) { }, { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", - Values: mkRaw(map[string]interface{}{"id": 0, "val": "zero"}), + Values: mkNamedParams(map[string]interface{}{"id": 0, "val": "zero"}), }, { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 1, "val": "uno"}), - mkRaw(map[string]interface{}{"id": 2, "val": "due"}), + mkNamedParams(map[string]interface{}{"id": 1, "val": "uno"}), + mkNamedParams(map[string]interface{}{"id": 2, "val": "due"}), }, }, { NoFail: true, Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val, 1)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 1, "val": "uno"}), - mkRaw(map[string]interface{}{"id": 2, "val": "due"}), + mkNamedParams(map[string]interface{}{"id": 1, "val": "uno"}), + mkNamedParams(map[string]interface{}{"id": 2, "val": "due"}), }, }, { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 3, "val": "tre"}), + mkNamedParams(map[string]interface{}{"id": 3, "val": "tre"}), }, }, { Query: "SELECT * FROM TBL WHERE ID=:id", - Values: mkRaw(map[string]interface{}{"id": 1}), + Values: mkNamedParams(map[string]interface{}{"id": 1}), }, { Statement: "DELETE FROM TBL", @@ -1613,32 +1676,32 @@ func TestProfilerPayloadOnMem(t *testing.T) { }, { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", - Values: mkRaw(map[string]interface{}{"id": 0, "val": "zero"}), + Values: mkNamedParams(map[string]interface{}{"id": 0, "val": "zero"}), }, { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 1, "val": "uno"}), - mkRaw(map[string]interface{}{"id": 2, "val": "due"}), + mkNamedParams(map[string]interface{}{"id": 1, "val": "uno"}), + mkNamedParams(map[string]interface{}{"id": 2, "val": "due"}), }, }, { NoFail: true, Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val, 1)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 1, "val": "uno"}), - mkRaw(map[string]interface{}{"id": 2, "val": "due"}), + mkNamedParams(map[string]interface{}{"id": 1, "val": "uno"}), + mkNamedParams(map[string]interface{}{"id": 2, "val": "due"}), }, }, { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 3, "val": "tre"}), + mkNamedParams(map[string]interface{}{"id": 3, "val": "tre"}), }, }, { Query: "SELECT * FROM TBL WHERE ID=:id", - Values: mkRaw(map[string]interface{}{"id": 1}), + Values: mkNamedParams(map[string]interface{}{"id": 1}), }, { Statement: "DELETE FROM TBL", @@ -1729,32 +1792,32 @@ func TestJournalMode(t *testing.T) { }, { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", - Values: mkRaw(map[string]interface{}{"id": 0, "val": "zero"}), + Values: mkNamedParams(map[string]interface{}{"id": 0, "val": "zero"}), }, { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 1, "val": "uno"}), - mkRaw(map[string]interface{}{"id": 2, "val": "due"}), + mkNamedParams(map[string]interface{}{"id": 1, "val": "uno"}), + mkNamedParams(map[string]interface{}{"id": 2, "val": "due"}), }, }, { NoFail: true, Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val, 1)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 1, "val": "uno"}), - mkRaw(map[string]interface{}{"id": 2, "val": "due"}), + mkNamedParams(map[string]interface{}{"id": 1, "val": "uno"}), + mkNamedParams(map[string]interface{}{"id": 2, "val": "due"}), }, }, { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 3, "val": "tre"}), + mkNamedParams(map[string]interface{}{"id": 3, "val": "tre"}), }, }, { Query: "SELECT * FROM TBL WHERE ID=:id", - Values: mkRaw(map[string]interface{}{"id": 1}), + Values: mkNamedParams(map[string]interface{}{"id": 1}), }, { Statement: "DELETE FROM TBL", @@ -2398,10 +2461,10 @@ func TestBothValueAndBatchFail(t *testing.T) { Transaction: []requestItem{ { Statement: "INSERT INTO TBL (ID, VAL) VALUES (:id, :val)", - Values: mkRaw(map[string]interface{}{"id": 0, "val": "zero"}), + Values: mkNamedParams(map[string]interface{}{"id": 0, "val": "zero"}), ValuesBatch: []map[string]json.RawMessage{ - mkRaw(map[string]interface{}{"id": 1, "val": "uno"}), - mkRaw(map[string]interface{}{"id": 2, "val": "due"}), + mkNamedParams(map[string]interface{}{"id": 1, "val": "uno"}), + mkNamedParams(map[string]interface{}{"id": 2, "val": "due"}), }, }, }, @@ -2533,7 +2596,7 @@ func TestReturnedString(t *testing.T) { Transaction: []requestItem{ { Statement: "INSERT INTO TBL (ID, VAL) VALUES (1, :val)", - Values: mkRaw(map[string]interface{}{"val": ciao}), + Values: mkNamedParams(map[string]interface{}{"val": ciao}), }, { Query: "SELECT VAL FROM TBL WHERE ID = 1", }, @@ -2570,7 +2633,7 @@ func TestReturnedBigInteger(t *testing.T) { Transaction: []requestItem{ { Statement: "INSERT INTO TBL VALUES(:VAL)", - Values: mkRaw(map[string]interface{}{ + Values: mkNamedParams(map[string]interface{}{ "VAL": test, }), }, @@ -2608,7 +2671,7 @@ func TestReturnedFloat(t *testing.T) { Transaction: []requestItem{ { Statement: "INSERT INTO TBL VALUES(:VAL)", - Values: mkRaw(map[string]interface{}{ + Values: mkNamedParams(map[string]interface{}{ "VAL": test, }), }, @@ -2644,7 +2707,7 @@ func TestReturnedBool(t *testing.T) { Transaction: []requestItem{ { Statement: "INSERT INTO TBL VALUES(:VAL)", - Values: mkRaw(map[string]interface{}{ + Values: mkNamedParams(map[string]interface{}{ "VAL": true, }), }, diff --git a/tests/structs.go b/tests/structs.go index 30fdf37..3b0f7aa 100644 --- a/tests/structs.go +++ b/tests/structs.go @@ -16,10 +16,6 @@ package main -import ( - "encoding/json" -) - // These are for parsing the config file (from YAML) type credentialsCfg struct { @@ -86,11 +82,11 @@ type credentials struct { } type requestItem struct { - Query string `json:"query,omitempty"` - Statement string `json:"statement,omitempty"` - NoFail bool `json:"noFail,omitempty"` - Values map[string]json.RawMessage `json:"values,omitempty"` - ValuesBatch []map[string]json.RawMessage `json:"valuesBatch,omitempty"` + Query string `json:"query,omitempty"` + Statement string `json:"statement,omitempty"` + NoFail bool `json:"noFail,omitempty"` + Values interface{} `json:"values,omitempty"` + ValuesBatch interface{} `json:"valuesBatch,omitempty"` } type request struct {