Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docker/dev/rivet-engine/config.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
"postgres": {
"url": "postgresql://postgres:postgres@postgres:5432/rivet_engine"
},
"memory": {
"channel": "default"
},
"cache": {
"driver": "in_memory"
},
Expand Down
65 changes: 54 additions & 11 deletions packages/common/universaldb/src/driver/postgres/database.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use std::sync::{Arc, Mutex};
use std::{
sync::{Arc, Mutex},
time::Duration,
};

use anyhow::{Context, Result};
use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime};
use tokio::task::JoinHandle;
use tokio_postgres::NoTls;

use crate::{
Expand All @@ -14,9 +18,13 @@

use super::transaction::PostgresTransactionDriver;

const TXN_TIMEOUT: Duration = Duration::from_secs(5);
const GC_INTERVAL: Duration = Duration::from_secs(5);

pub struct PostgresDatabaseDriver {
pool: Arc<Pool>,
max_retries: Arc<Mutex<i32>>,
gc_handle: JoinHandle<()>,
}

impl PostgresDatabaseDriver {
Expand Down Expand Up @@ -53,7 +61,7 @@
.context("failed to create btree_gist extension")?;

conn.execute(
"CREATE SEQUENCE IF NOT EXISTS global_version_seq START WITH 1 INCREMENT BY 1 MINVALUE 1",
"CREATE UNLOGGED SEQUENCE IF NOT EXISTS global_version_seq START WITH 1 INCREMENT BY 1 MINVALUE 1",
&[],
)
.await
Expand Down Expand Up @@ -123,12 +131,39 @@
.await
.context("failed to create conflict_ranges table")?;

// Connection is automatically returned to the pool when dropped
drop(conn);
// Create index on ts column for efficient garbage collection
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_conflict_ranges_ts ON conflict_ranges (ts)",
&[],
)
.await
.context("failed to create index on conflict_ranges ts column")?;

let gc_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(GC_INTERVAL);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

loop {
interval.tick().await;

// NOTE: Transactions have a max limit of 5 seconds, we delete after 10 seconds for extra padding
// Delete old conflict ranges
if let Err(err) = conn
.execute(
"DELETE FROM conflict_ranges where ts < now() - interval '10 seconds'",
&[],
)
.await
{
tracing::error!(?err, "failed postgres gc task");
}
}
});

Ok(PostgresDatabaseDriver {
pool: Arc::new(pool),
max_retries: Arc::new(Mutex::new(100)),
gc_handle,
})
}
}
Expand All @@ -155,13 +190,15 @@
retryable.maybe_committed = maybe_committed;

// Execute transaction
let error = match closure(retryable.clone()).await {
Ok(res) => match retryable.inner.driver.commit_ref().await {
Ok(_) => return Ok(res),
Err(e) => e,
},
Err(e) => e,
};
let error =
match tokio::time::timeout(TXN_TIMEOUT, closure(retryable.clone())).await {
Ok(Ok(res)) => match retryable.inner.driver.commit_ref().await {
Ok(_) => return Ok(res),
Err(e) => e,
},
Ok(Err(e)) => e,
Err(e) => anyhow::Error::from(DatabaseError::TransactionTooOld),

Check failure on line 200 in packages/common/universaldb/src/driver/postgres/database.rs

View workflow job for this annotation

GitHub Actions / Check

unused variable: `e`

Check warning on line 200 in packages/common/universaldb/src/driver/postgres/database.rs

View workflow job for this annotation

GitHub Actions / Test

unused variable: `e`
};

let chain = error
.chain()
Expand Down Expand Up @@ -196,3 +233,9 @@
}
}
}

impl Drop for PostgresDatabaseDriver {
fn drop(&mut self) {
self.gc_handle.abort();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,20 @@
tx: Transaction<'_>,
start_version: i64,
operations: Vec<Operation>,
conflict_ranges: Vec<(Vec<u8>, Vec<u8>, ConflictRangeType)>,
mut conflict_ranges: Vec<(Vec<u8>, Vec<u8>, ConflictRangeType)>,

Check failure on line 333 in packages/common/universaldb/src/driver/postgres/transaction_task.rs

View workflow job for this annotation

GitHub Actions / Check

variable does not need to be mutable

Check warning on line 333 in packages/common/universaldb/src/driver/postgres/transaction_task.rs

View workflow job for this annotation

GitHub Actions / Test

variable does not need to be mutable
) -> Result<()> {
let commit_version = tx
.query_one("SELECT nextval('global_version_seq')", &[])
.await
// // Defer all constraint checks until commit
// tx.execute("SET CONSTRAINTS ALL DEFERRED", &[])
// .await
// .map_err(map_postgres_error)?;

let (_, _, version_res) = tokio::join!(
tx.execute("SET LOCAL lock_timeout = '0'", &[],),
tx.execute("SET LOCAL deadlock_timeout = '10ms'", &[],),
tx.query_one("SELECT nextval('global_version_seq')", &[]),
);

let commit_version = version_res
.context("failed to get postgres txn commit_version")?
.get::<_, i64>(0);

Expand All @@ -355,7 +364,7 @@

let query = "
INSERT INTO conflict_ranges (range_data, conflict_type, start_version, commit_version)
SELECT
SELECT
bytearange(begin_key, end_key, '[)'),
conflict_type::range_type,
$4,
Expand All @@ -377,13 +386,22 @@
.await
.map_err(map_postgres_error)?;

// TODO: Parallelize
for op in operations {
match op {
Operation::Set { key, value } => {
// TODO: versionstamps need to be calculated on the sql side, not in rust
let value = substitute_versionstamp_if_incomplete(value.clone(), 0);

// // Poor man's upsert, you cant use ON CONFLICT with deferred constraints
// let query = "WITH updated AS (
// UPDATE kv
// SET value = $2
// WHERE key = $1
// RETURNING 1
// )
// INSERT INTO kv (key, value)
// SELECT $1, $2
// WHERE NOT EXISTS (SELECT 1 FROM updated)";
let query = "INSERT INTO kv (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $2";
let stmt = tx.prepare_cached(query).await.map_err(map_postgres_error)?;

Expand Down Expand Up @@ -435,6 +453,16 @@

// Store the result
if let Some(new_value) = new_value {
// // Poor man's upsert, you cant use ON CONFLICT with deferred constraints
// let update_query = "WITH updated AS (
// UPDATE kv
// SET value = $2
// WHERE key = $1
// RETURNING 1
// )
// INSERT INTO kv (key, value)
// SELECT $1, $2
// WHERE NOT EXISTS (SELECT 1 FROM updated)";
let update_query = "INSERT INTO kv (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $2";
let stmt = tx
.prepare_cached(update_query)
Expand Down
2 changes: 1 addition & 1 deletion packages/common/universaldb/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ pub enum MutationType {
/// Performs an atomic ``compare and clear`` operation. If the existing value in the database is equal to the given value, then given key is cleared.
CompareAndClear,
}
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum ConflictRangeType {
/// Used to add a read conflict range
Expand Down
11 changes: 11 additions & 0 deletions packages/common/universaldb/src/tx_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@
},
}

impl Operation {
pub fn sorting_key(&self) -> &[u8] {

Check failure on line 38 in packages/common/universaldb/src/tx_ops.rs

View workflow job for this annotation

GitHub Actions / Check

method `sorting_key` is never used

Check warning on line 38 in packages/common/universaldb/src/tx_ops.rs

View workflow job for this annotation

GitHub Actions / Test

method `sorting_key` is never used
match self {
Operation::Set { key, .. } => key,
Operation::Clear { key } => key,
Operation::ClearRange { begin, .. } => begin,
Operation::AtomicOp { key, .. } => key,
}
}
}

#[derive(Debug, Clone)]
pub enum GetOutput {
Value(Vec<u8>),
Expand Down
2 changes: 1 addition & 1 deletion packages/core/guard/core/src/websocket_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl WebSocketHandleInner {
let mut state = self.state.lock().await;
match &mut *state {
WebSocketState::Unaccepted { .. } | WebSocketState::Accepting => {
bail!("websocket has not been accepted")
bail!("websocket has not been accepted");
}
WebSocketState::Split { ws_tx } => {
ws_tx.send(message).await?;
Expand Down
16 changes: 8 additions & 8 deletions scripts/tests/actor_e2e.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env tsx

import { RIVET_ENDPOINT, createActor, destroyActor } from "./utils";
import { RIVET_ENDPOINT, RIVET_TOKEN, createActor, destroyActor } from "./utils";

async function main() {
try {
Expand Down Expand Up @@ -31,9 +31,7 @@ async function main() {

console.log("Actor ping response:", pingResult);

// Test WebSocket connection
console.log("Testing WebSocket connection to actor...");
// await testWebSocket(actorResponse.actor.actor_id);
await testWebSocket(actorResponse.actor.actor_id);

console.log("Destroying actor...");
await destroyActor("default", actorResponse.actor.actor_id);
Expand All @@ -49,6 +47,8 @@ async function main() {
}

function testWebSocket(actorId: string): Promise<void> {
console.log("Testing WebSocket connection to actor...");

return new Promise((resolve, reject) => {
// Parse the RIVET_ENDPOINT to get WebSocket URL
const wsEndpoint = RIVET_ENDPOINT.replace("http://", "ws://").replace(
Expand All @@ -59,7 +59,7 @@ function testWebSocket(actorId: string): Promise<void> {

console.log(`Connecting WebSocket to: ${wsUrl}`);

const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`];
const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`, `rivet_token.${RIVET_TOKEN}`];
const ws = new WebSocket(wsUrl, protocols);

let pingReceived = false;
Expand All @@ -81,9 +81,9 @@ function testWebSocket(actorId: string): Promise<void> {
ws.send("ping");
});

ws.addEventListener("message", (data) => {
const message = data.toString();
console.log(`WebSocket received raw data:`, data);
ws.addEventListener("message", (ev) => {
const message = ev.data.toString();
console.log(`WebSocket received raw data:`, ev.data);
console.log(`WebSocket received message: "${message}"`);

if (
Expand Down
12 changes: 6 additions & 6 deletions scripts/tests/spam_actors.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env tsx

import { RIVET_ENDPOINT, createActor, destroyActor } from "./utils";
import { RIVET_ENDPOINT, RIVET_TOKEN, createActor, destroyActor } from "./utils";

const ACTORS = parseInt(process.argv[2]) || 15;

Expand Down Expand Up @@ -44,7 +44,7 @@ async function testActor(i: number) {

console.log(`Actor ${i} ping response:`, pingResult);

// await testWebSocket(actorResponse.actor.actor_id);
await testWebSocket(actorResponse.actor.actor_id);

console.log(`Destroying actor ${i}...`);
await destroyActor("default", actorResponse.actor.actor_id);
Expand All @@ -66,7 +66,7 @@ function testWebSocket(actorId: string): Promise<void> {

console.log(`Connecting WebSocket to: ${wsUrl}`);

const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`];
const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`, `rivet_token.${RIVET_TOKEN}`];
const ws = new WebSocket(wsUrl, protocols);

let pingReceived = false;
Expand All @@ -88,9 +88,9 @@ function testWebSocket(actorId: string): Promise<void> {
ws.send("ping");
});

ws.addEventListener("message", (data) => {
const message = data.toString();
console.log(`WebSocket received raw data:`, data);
ws.addEventListener("message", (ev) => {
const message = ev.data.toString();
console.log(`WebSocket received raw data:`, ev.data);
console.log(`WebSocket received message: "${message}"`);

if (
Expand Down
2 changes: 1 addition & 1 deletion scripts/tests/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export const RIVET_ENDPOINT =
process.env.RIVET_ENDPOINT ?? "http://localhost:6420";
const RIVET_TOKEN = process.env.RIVET_TOKEN ?? "dev";
export const RIVET_TOKEN = process.env.RIVET_TOKEN ?? "dev";

export async function createActor(
namespaceName: string,
Expand Down
Loading