diff --git a/core/rs/core/src/changes_vtab_write.rs b/core/rs/core/src/changes_vtab_write.rs index 534cd52a..6f73abc9 100644 --- a/core/rs/core/src/changes_vtab_write.rs +++ b/core/rs/core/src/changes_vtab_write.rs @@ -94,42 +94,16 @@ fn did_cid_win( reset_cached_stmt(col_val_stmt.stmt)?; if ret == 0 && unsafe { (*ext_data).mergeEqualValues == 1 } { // values are the same (ret == 0) and the option to tie break on site_id is true - let col_site_id_stmt_ref = tbl_info.get_col_site_id_stmt(db)?; - let col_site_id_stmt = col_site_id_stmt_ref.as_ref().ok_or(ResultCode::ERROR)?; - - let bind_result = col_site_id_stmt.bind_int64(1, key).and_then(|_| { - col_site_id_stmt.bind_text(2, col_name, sqlite::Destructor::STATIC) - }); - if let Err(rc) = bind_result { - reset_cached_stmt(col_site_id_stmt.stmt)?; - return Err(rc); - } - - match col_site_id_stmt.step() { - Ok(ResultCode::ROW) => { - let local_site_id = col_site_id_stmt.column_blob(0)?; - ret = insert_site_id.cmp(local_site_id) as c_int; - - // reset the stmt after, we're accessing a slice in-memory - reset_cached_stmt(col_site_id_stmt.stmt)?; - } - Ok(ResultCode::DONE) => { - reset_cached_stmt(col_site_id_stmt.stmt)?; - let err = CString::new(format!( - "could not find site_id for previous change, cr-sqlite clock table might be corrupt for tbl {}", - insert_tbl - ))?; - unsafe { *errmsg = err.into_raw() }; - return Err(ResultCode::ERROR); - } - Ok(rc) | Err(rc) => { - reset_cached_stmt(col_site_id_stmt.stmt)?; - let err = - CString::new("Bad return code when selecting local column site_id")?; - unsafe { *errmsg = err.into_raw() }; - return Err(rc); - } - } + let won = did_site_id_win( + db, + insert_tbl, + tbl_info, + key, + col_name, + insert_site_id, + errmsg, + )?; + return Ok(won); } return Ok(ret > 0); } @@ -147,6 +121,51 @@ fn did_cid_win( } } +fn did_site_id_win( + db: *mut sqlite3, + insert_tbl: &str, + tbl_info: &TableInfo, + key: sqlite::int64, + col_name: &str, + insert_site_id: &[u8], + errmsg: *mut *mut c_char, +) -> Result { + let col_site_id_stmt_ref = tbl_info.get_col_site_id_stmt(db)?; + let col_site_id_stmt = col_site_id_stmt_ref.as_ref().ok_or(ResultCode::ERROR)?; + + let bind_result = col_site_id_stmt + .bind_int64(1, key) + .and_then(|_| col_site_id_stmt.bind_text(2, col_name, sqlite::Destructor::STATIC)); + if let Err(rc) = bind_result { + reset_cached_stmt(col_site_id_stmt.stmt)?; + return Err(rc); + } + + match col_site_id_stmt.step() { + Ok(ResultCode::ROW) => { + let local_site_id = col_site_id_stmt.column_blob(0)?; + let ret = insert_site_id.cmp(local_site_id) as c_int; + reset_cached_stmt(col_site_id_stmt.stmt)?; + Ok(ret > 0) + } + Ok(ResultCode::DONE) => { + reset_cached_stmt(col_site_id_stmt.stmt)?; + let err = CString::new(format!( + "could not find site_id for previous change, cr-sqlite clock table might be corrupt for tbl {}", + insert_tbl + ))?; + unsafe { *errmsg = err.into_raw() }; + return Err(ResultCode::ERROR); + } + Ok(rc) | Err(rc) => { + reset_cached_stmt(col_site_id_stmt.stmt)?; + let err = CString::new("Bad return code when selecting local column site_id")?; + unsafe { *errmsg = err.into_raw() }; + return Err(rc); + } + } +} + fn set_winner_clock( db: *mut sqlite3, ext_data: *mut crsql_ExtData, @@ -543,6 +562,31 @@ unsafe fn merge_insert( // We got a delete event but we've already processed a delete at that version. // Just bail. if insert_cl == local_cl { + if unsafe { (*(*tab).pExtData).mergeEqualValues == 1 } + && did_site_id_win( + db, + insert_tbl, + &tbl_info, + key, + insert_col, + insert_site_id, + errmsg, + )? + { + // here we set the same winner for the clock if incoming site_id won + set_winner_clock( + db, + (*tab).pExtData, + &tbl_info, + key, + insert_col, + insert_col_vrsn, + insert_db_vrsn, + insert_site_id, + insert_seq, + insert_ts, + )?; + } return Ok(ResultCode::OK); } // else, it is a delete and the cl is > than ours. Drop the row. diff --git a/py/correctness/tests/test_sync.py b/py/correctness/tests/test_sync.py index 417bef0a..7924625e 100644 --- a/py/correctness/tests/test_sync.py +++ b/py/correctness/tests/test_sync.py @@ -438,6 +438,34 @@ def test_merge_same_w_tie_breaker(): assert (changes1 == changes2 == changes3) + + # test that delete also merge when conflicts are shared + db1.execute("DELETE FROM foo WHERE a = 1;") + db1.commit() + + db2.execute("DELETE FROM foo WHERE a = 1;") + db2.commit() + + sync_left_to_right(db1, db2, 0) + sync_left_to_right(db2, db1, 0) + + changes1 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + changes2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + + assert (changes1 == changes2) + + db3.execute("DELETE FROM foo WHERE a = 1;") + db3.commit() + + for dba, dbb in possibilities: + sync_left_to_right(dba, dbb, 0) + + changes1 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + changes2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + changes3 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + + assert (changes1 == changes2 == changes3) + def test_merge_matching_clocks_lesser_value(): def make_dbs(): db1 = create_basic_db()