diff --git a/Cargo.lock b/Cargo.lock index 974c2a4..3b20bd0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -528,6 +528,26 @@ dependencies = [ "serde", ] +[[package]] +name = "bincode" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -2873,6 +2893,7 @@ dependencies = [ "anyhow", "async-nats", "bigdecimal", + "bincode 2.0.1", "bs58", "clap 4.5.20", "config", @@ -2887,6 +2908,7 @@ dependencies = [ "serde_json", "solana-account-decoder", "solana-client", + "solana-metrics", "solana-sdk", "tokio", "tokio-stream", @@ -3756,7 +3778,7 @@ checksum = "41d87c6ef8c13eb759fa8d887e12c67afd851799050b6afd501a27726551f52e" dependencies = [ "Inflector", "base64 0.22.1", - "bincode", + "bincode 1.3.3", "bs58", "bv", "lazy_static", @@ -3797,7 +3819,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67169e4f1faabb717ce81b5ca93960da21e3ac5c9b75cb6792f9b3ce38db459f" dependencies = [ "async-trait", - "bincode", + "bincode 1.3.3", "dashmap", "futures", "futures-util", @@ -3839,7 +3861,7 @@ version = "2.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f638e44fb308bdc1ce99eb0fee194b2cb212917b258999cdb4a8b056d48973d4" dependencies = [ - "bincode", + "bincode 1.3.3", "chrono", "serde", "serde_derive", @@ -3854,7 +3876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fd01a4d43b780996970cb3669946b002f71d34e6a26a19bd6d2a74513ecc0aa" dependencies = [ "async-trait", - "bincode", + "bincode 1.3.3", "crossbeam-channel", "futures-util", "indexmap 2.6.0", @@ -3934,7 +3956,7 @@ version = "2.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44bb419eb9293a277982cf14a58772e9b9ab30ff6f9421bc4ac0826d40122760" dependencies = [ - "bincode", + "bincode 1.3.3", "clap 3.2.25", "crossbeam-channel", "log", @@ -3958,7 +3980,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00c4128122787a61d8f94fdaa04cb71b3dbb017d9939ac4d632264c55ec345de" dependencies = [ "ahash", - "bincode", + "bincode 1.3.3", "bv", "caps", "curve25519-dalek 3.2.1", @@ -3989,7 +4011,7 @@ dependencies = [ "ark-ff", "ark-serialize", "base64 0.22.1", - "bincode", + "bincode 1.3.3", "bitflags 2.6.0", "blake3", "borsh 0.10.4", @@ -4031,7 +4053,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "948bfeb10ba38b55a8b2db2de8ccfa8f57b44b6d73c98e8e0de8b10f10ce043b" dependencies = [ "base64 0.22.1", - "bincode", + "bincode 1.3.3", "eager", "enum-iterator", "itertools 0.12.1", @@ -4141,7 +4163,7 @@ checksum = "bd96f6a505a492544ee2459b608af3fe07da6c8ffc0bd842489e836ac2c3fce6" dependencies = [ "async-trait", "base64 0.22.1", - "bincode", + "bincode 1.3.3", "bs58", "indicatif", "log", @@ -4203,7 +4225,7 @@ version = "2.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24dae5bda29858add4df3a6c5eaf71c0d2042ca3317a9fd81d7e9f436278a1fe" dependencies = [ - "bincode", + "bincode 1.3.3", "bitflags 2.6.0", "borsh 1.5.1", "bs58", @@ -4305,7 +4327,7 @@ version = "2.0.13" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "a8c880be4e50ff473b3e82b600162244b6eb28cb5a616dc90ee9232d34998680" dependencies = [ - "bincode", + "bincode 1.3.3", "log", "rayon", "solana-connection-cache", @@ -4321,7 +4343,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e65c01edbca303273e735ae383dde54bd5c5b8a051c51162c0ff886b0939ec6" dependencies = [ "async-trait", - "bincode", + "bincode 1.3.3", "futures-util", "indexmap 2.6.0", "indicatif", @@ -4346,7 +4368,7 @@ checksum = "44727bef1f8c57a6ed9a74761d8b7ddfcf4b4e2237cbcc5dc7f8f59985e07755" dependencies = [ "Inflector", "base64 0.22.1", - "bincode", + "bincode 1.3.3", "lazy_static", "log", "rand 0.8.5", @@ -4362,7 +4384,7 @@ checksum = "d51d9d4a6004708f9563a29aa87fdf9960c1e7420b69dd82e8b817cf8f02430b" dependencies = [ "Inflector", "base64 0.22.1", - "bincode", + "bincode 1.3.3", "borsh 1.5.1", "bs58", "lazy_static", @@ -4441,7 +4463,7 @@ version = "2.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfd8e539a9963c2914ff8426dfe92351a902892aea465cd507e36d638ca0b7d6" dependencies = [ - "bincode", + "bincode 1.3.3", "log", "num-derive 0.4.2", "num-traits", @@ -4463,7 +4485,7 @@ checksum = "a1dd7a8d6843cb3de4c13c2cfec1994519735ea4110b7f36b80b41d57bea1c07" dependencies = [ "aes-gcm-siv", "base64 0.22.1", - "bincode", + "bincode 1.3.3", "bytemuck", "bytemuck_derive", "byteorder", @@ -5347,6 +5369,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + [[package]] name = "uriparse" version = "0.6.4" @@ -5399,6 +5427,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "void" version = "1.0.2" diff --git a/Cargo.toml b/Cargo.toml index 584d9a7..1e330b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ path = "src/bin/websocket_server.rs" async-nats = "0.37.0" anyhow = "1.0.89" bigdecimal = "0.4.0" +bincode = "2.0.1" bs58 = "0.5.1" clap = { version = "4.3", features = ["derive", "env"] } config = "0.14" @@ -34,6 +35,7 @@ serde_json = "1.0.128" solana-account-decoder = "2.0.13" solana-client = "2.0.13" solana-sdk = "2.0.13" +solana-metrics = "2.0.13" tokio = { version = "1.40.0", features = ["full"] } tokio-stream = "0.1.16" tokio-tungstenite = "0.24.0" diff --git a/example.reader.config.toml b/example.reader.config.toml deleted file mode 100644 index 2b48b05..0000000 --- a/example.reader.config.toml +++ /dev/null @@ -1,10 +0,0 @@ -[nats] -url = "nats://localhost:4222" - -[pyth] -http_addr = "https://api2.pythnet.pyth.network" -websocket_addr = "wss://api2.pythnet.pyth.network" -program_key = "FsJ3A3u2vn5cTVofAjvy6y5kwABJAqYWpe4975bi2epH" - -[price_update] -max_slot_difference = 25 diff --git a/example.websocket.config.toml b/example.websocket.config.toml deleted file mode 100644 index 5776d91..0000000 --- a/example.websocket.config.toml +++ /dev/null @@ -1,8 +0,0 @@ -[nats] -url = "nats://localhost:4222" - -[websocket] -addr = 
"0.0.0.0:8080" - -[healthcheck] -addr = "0.0.0.0:8081" diff --git a/src/bin/pyth_reader.rs b/src/bin/pyth_reader.rs index bd7a1d6..8df1841 100644 --- a/src/bin/pyth_reader.rs +++ b/src/bin/pyth_reader.rs @@ -3,19 +3,26 @@ use async_nats::jetstream::{self}; use async_nats::HeaderMap; use clap::Parser; use config::Config; +use futures::future::join_all; use pyth_sdk_solana::state::{load_price_account, PriceStatus, PythnetPriceAccount}; use pyth_stream::utils::setup_jetstream; use serde::{Deserialize, Deserializer, Serialize}; use solana_account_decoder::UiAccountEncoding; use solana_client::nonblocking::pubsub_client::PubsubClient; use solana_client::rpc_config::{RpcAccountInfoConfig, RpcProgramAccountsConfig}; +use solana_client::rpc_response::{Response, RpcKeyedAccount}; +use solana_metrics::datapoint_info; use solana_sdk::account::Account; use solana_sdk::commitment_config::CommitmentConfig; use solana_sdk::pubkey::Pubkey; +use std::collections::HashMap; use std::collections::HashSet; use std::path::PathBuf; use std::str::FromStr; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; use std::time::Instant; +use tokio::sync::Mutex; use tokio::task; use tokio::time::Duration; use tokio_stream::StreamExt; @@ -23,21 +30,32 @@ use tracing::{debug, error, info, warn}; use tracing_subscriber::{fmt, EnvFilter}; use url::Url; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, bincode::Encode, bincode::Decode)] struct PriceUpdate { #[serde(rename = "type")] update_type: String, price_feed: PriceFeed, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, bincode::Encode, bincode::Decode)] +struct PublisherPriceUpdate { + publisher: String, + feed_id: String, + slot: u64, // Add this field + price: String, +} + +type PublisherKey = (String, String); // (feed_id, publisher) +type PublisherBuffer = HashMap; + +#[derive(Debug, Serialize, Deserialize, bincode::Encode, bincode::Decode)] struct PriceFeed { id: String, price: PriceInfo, ema_price: PriceInfo, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, bincode::Encode, bincode::Decode)] struct PriceInfo { price: String, conf: String, @@ -87,14 +105,114 @@ struct Args { config: Option, } +fn get_price_account_from_update( + update: &Response, +) -> Result { + let account: Account = match update.value.account.decode() { + Some(account) => account, + _none => { + warn!("Failed to decode account from update"); + return Err(anyhow::anyhow!("Failed to decode account from update")); + } + }; + + let price_account: PythnetPriceAccount = match load_price_account(&account.data) { + Ok(pyth_account) => *pyth_account, + Err(_) => { + debug!("Not a price account, skipping"); + return Err(anyhow::anyhow!("Not a price account, skipping")); + } + }; + + Ok(price_account) +} + +async fn publish_price_updates( + jetstream: jetstream::Context, + price_account: PythnetPriceAccount, + update: &Response, +) { + let price_update = PriceUpdate { + update_type: "price_update".to_string(), + price_feed: PriceFeed { + id: update.value.pubkey.to_string(), + price: PriceInfo { + price: price_account.agg.price.to_string(), + conf: price_account.agg.conf.to_string(), + expo: price_account.expo, + publish_time: price_account.timestamp, + slot: update.context.slot, + }, + ema_price: PriceInfo { + price: price_account.ema_price.val.to_string(), + conf: price_account.ema_conf.val.to_string(), + expo: price_account.expo, + publish_time: price_account.timestamp, + slot: 
update.context.slot,
+            },
+        },
+    };
+    let price_update_message: Vec<u8> =
+        bincode::encode_to_vec(&price_update, bincode::config::standard()).unwrap();
+    // Create a unique message ID
+    let price_update_message_id = format!(
+        "{}:{}",
+        price_update.price_feed.id, price_update.price_feed.price.slot
+    );
+
+    // Create headers with the Nats-Msg-Id
+    let mut price_update_headers = HeaderMap::new();
+    price_update_headers.insert("Nats-Msg-Id", price_update_message_id.as_str());
+
+    match jetstream
+        .publish_with_headers(
+            "pyth.price.updates",
+            price_update_headers,
+            price_update_message.into(),
+        )
+        .await
+    {
+        Ok(_) => debug!(
+            "Published price update to JetStream with ID: {}",
+            price_update_message_id
+        ),
+        Err(e) => warn!("Failed to publish price update to JetStream: {}", e),
+    }
+}
+
+/**
+ * Process the publisher price updates for a given price account and update
+ * @param price_account: The price account
+ * @param update: The update
+ * @param publisher_buffer: The publisher buffer
+ */
+async fn process_publisher_price_updates(
+    price_account: PythnetPriceAccount,
+    update: &Response<RpcKeyedAccount>,
+    publisher_buffer: &mut PublisherBuffer,
+) {
+    for component in price_account.comp {
+        if component.publisher != Pubkey::default() {
+            let publisher_price_update = PublisherPriceUpdate {
+                feed_id: update.value.pubkey.to_string(),
+                publisher: component.publisher.to_string(),
+                price: price_account.agg.price.to_string(),
+                slot: update.context.slot,
+            };
+
+            let key = (
+                publisher_price_update.feed_id.clone(),
+                publisher_price_update.publisher.clone(),
+            );
+            publisher_buffer.insert(key, publisher_price_update);
+        }
+    }
+}
+
 async fn fetch_price_updates(jetstream: jetstream::Context, config: &AppConfig) -> Result<()> {
     info!("Starting Pyth reader");
+    let mut publisher_buffer: PublisherBuffer = HashMap::new();
     let client = PubsubClient::new(config.pyth.websocket_addr.as_str()).await?;
-    info!(
-        "Connected to Pyth WebSocket at {}",
-        config.pyth.websocket_addr
-    );
-
     let rpc_config = RpcProgramAccountsConfig {
         account_config: RpcAccountInfoConfig {
             commitment: Some(CommitmentConfig::confirmed()),
@@ -107,111 +225,108 @@ async fn fetch_price_updates(jetstream: jetstream::Context, config: &AppConfig)
     };
 
     let (mut notif, _unsub) = client
-        .program_subscribe(&config.pyth.program_key, Some(rpc_config))
+        .program_subscribe(&config.pyth.program_key, Some(rpc_config.clone()))
         .await?;
 
+    info!(
+        "Connected to Pyth WebSocket at {}",
+        config.pyth.websocket_addr
+    );
+
     let mut update_count = 0;
-    let mut unique_price_feeds = HashSet::new();
-    let mut last_report_time = Instant::now();
-
-    while let Some(update) = notif.next().await {
-        debug!("Received price update");
-        let account: Account = match update.value.account.decode() {
-            Some(account) => account,
-            _none => {
-                warn!("Failed to decode account from update");
-                continue;
-            }
-        };
-        let price_account: PythnetPriceAccount = match load_price_account(&account.data) {
-            Ok(pyth_account) => *pyth_account,
-            Err(_) => {
-                debug!("Not a price account, skipping");
-                continue;
+    let mut duration_count = 0;
+    let jetstream_clone = jetstream.clone();
+    let mut msg_id_counter = 0;
+
+    let mut interval = tokio::time::interval(Duration::from_millis(50));
+    let mut last_seen_slot = 0;
+
+    loop {
+        tokio::select! {
+            _ = interval.tick() => {
+                let instant = Instant::now();
+                let updates: Vec<PublisherPriceUpdate> = {
+                    if publisher_buffer.is_empty() {
+                        continue;
+                    }
+                    publisher_buffer.drain().map(|(_, v)| v).collect()
+                };
+
+                // Serialize the batch with bincode
+                let body: Vec<u8> = bincode::encode_to_vec(&updates, bincode::config::standard()).unwrap();
+                // Use a counter-based ID as the Nats-Msg-Id for the batch
+                let msg_id = format!("publisher_batch:{}", msg_id_counter);
+                msg_id_counter += 1;
+                let mut headers = HeaderMap::new();
+                headers.insert("Nats-Msg-Id", msg_id.as_str());
+                // info!("Serialized publisher updates, size: {}, elapsed time: {:?}, publisher buffer size: {:?}", body.len(), instant.elapsed(), publisher_buffer.len());
+                // info!("Publishing {} publisher updates in a batch, total size {}, elapsed time: {:?}", updates.len(), body.len(), instant.elapsed());
+                // info!(
+                //     "Average duration: {:?}",
+                //     duration_count/ update_count
+                // );
+                duration_count = 0;
+                update_count = 0;
+                if let Err(e) = jetstream_clone
+                    .publish_with_headers("pyth.publisher.updates", headers, body.into())
+                    .await
+                {
+                    warn!("Failed to publish batch publisher updates: {}", e);
+                } else {
+                    debug!("Published {} publisher updates in a batch", updates.len());
+                }
             }
-        };
-
-        // We want to send price updates whenever the aggregate changes but sometimes the accounts can change without the aggregate changing
-        if price_account.agg.status == PriceStatus::Trading
-            && (update.context.slot - price_account.agg.pub_slot)
-                <= config.price_update.max_slot_difference
-        {
-            debug!(
-                "Processing valid price update for product: {:?}",
-                price_account.prod
-            );
-
-            let price_update = PriceUpdate {
-                update_type: "price_update".to_string(),
-                price_feed: PriceFeed {
-                    id: update.value.pubkey.to_string(),
-                    price: PriceInfo {
-                        price: price_account.agg.price.to_string(),
-                        conf: price_account.agg.conf.to_string(),
-                        expo: price_account.expo,
-                        publish_time: price_account.timestamp,
-                        slot: update.context.slot, // Add this field
-                    },
-                    ema_price: PriceInfo {
-                        price: price_account.ema_price.val.to_string(),
-                        conf: price_account.ema_conf.val.to_string(),
-                        expo: price_account.expo,
-                        publish_time: price_account.timestamp,
-                        slot: update.context.slot, // Add this field
+            maybe_update = notif.next() => {
+                let start_time = Instant::now();
+
+                let update = match maybe_update {
+                    None => {
+                        let error_msg = "Pythnet network listener stream ended unexpectedly";
+                        error!("{}", error_msg);
+                        break Ok(());
                     },
-                },
-            };
-
-            let message = serde_json::to_string(&price_update)?;
-
-            // Create a unique message ID
-            let message_id = format!(
-                "{}:{}",
-                price_update.price_feed.id, price_update.price_feed.price.slot
-            );
-
-            // Create headers with the Nats-Msg-Id
-            let mut headers = HeaderMap::new();
-            headers.insert("Nats-Msg-Id", message_id.as_str());
-
-            let jetstream_clone = jetstream.clone();
-            task::spawn(async move {
-                match jetstream_clone
-                    .publish_with_headers("pyth.price.updates", headers, message.into())
-                    .await
+                    Some(update) => update
+                };
+
+                debug!("Received price update");
+                let price_account: PythnetPriceAccount = match get_price_account_from_update(&update) {
+                    Ok(account) => account,
+                    _none => {
+                        warn!("Failed to decode account from update");
+                        continue;
+                    }
+                };
+
+                // We want to send price updates whenever the aggregate changes but sometimes the accounts can change without the aggregate changing
+                if price_account.agg.status == PriceStatus::Trading
+                    && (update.context.slot - price_account.agg.pub_slot)
+                        <= config.price_update.max_slot_difference
                 {
-                    Ok(_) => debug!(
-                        "Published price update to JetStream with ID: {}",
-                        message_id
-                    ),
-                    Err(e) => warn!("Failed to publish price update to JetStream: {}", e),
+                    debug!(
+                        "Processing valid price update for product: {:?}",
+                        price_account.prod
+                    );
+
+                    let jetstream_clone = jetstream.clone();
+                    publish_price_updates(jetstream_clone, price_account, &update).await;
+                    process_publisher_price_updates(price_account, &update, &mut publisher_buffer).await;
+
+                    let end_time = Instant::now();
+                    let duration = end_time.duration_since(start_time);
+                    update_count += 1;
+                    duration_count += duration.as_micros();
+                    if update.context.slot > last_seen_slot {
+                        last_seen_slot = update.context.slot;
+                        info!("Processing price update, slot: {}", update.context.slot);
+                    }
+
+                } else {
+                    debug!("Skipping price update due to invalid status or slot difference");
                 }
-            });
-
-            update_count += 1;
-            unique_price_feeds.insert(price_account.prod);
-        } else {
-            debug!("Skipping price update due to invalid status or slot difference");
-        }
-        // Report aggregate information every minute
-        // TODO: add this as metrics
-        if last_report_time.elapsed() >= Duration::from_secs(60) {
-            info!(
-                "Processed {} updates from {} unique price feeds in the last minute",
-                update_count,
-                unique_price_feeds.len()
-            );
-            update_count = 0;
-            unique_price_feeds.clear();
-            last_report_time = Instant::now();
+            }
         }
     }
-
-    // If we exit the loop, it means the stream has ended
-    let error_msg = "Pythnet network listener stream ended unexpectedly";
-    error!("{}", error_msg);
-    Err(anyhow::anyhow!(error_msg))
 }
 
 fn load_config(args: &Args) -> Result<AppConfig> {
diff --git a/src/bin/websocket_server.rs b/src/bin/websocket_server.rs
index 244b492..a5e23c8 100644
--- a/src/bin/websocket_server.rs
+++ b/src/bin/websocket_server.rs
@@ -1,5 +1,6 @@
 use anyhow::{Context, Result};
 use async_nats::jetstream::{self, consumer};
+use async_nats::Client;
 use clap::Parser;
 use config::Config;
 use futures::{SinkExt, StreamExt};
@@ -10,6 +11,7 @@ use hyper::{Request, Response};
 use hyper_util::rt::TokioIo;
 use pyth_stream::utils::setup_jetstream;
 use serde::{Deserialize, Serialize};
+use serde_json::json;
 use std::clone::Clone;
 use std::collections::{HashMap, HashSet};
 use std::net::SocketAddr;
@@ -27,10 +29,14 @@ use tracing_subscriber::{fmt, EnvFilter};
 #[derive(Debug, Deserialize)]
 #[serde(tag = "type")]
 enum ClientMessage {
-    #[serde(rename = "subscribe")]
-    Subscribe { ids: Vec<String> },
-    #[serde(rename = "unsubscribe")]
-    Unsubscribe { ids: Vec<String> },
+    #[serde(rename = "subscribe_price")]
+    SubscribePrice { ids: Vec<String> },
+    #[serde(rename = "unsubscribe_price")]
+    UnsubscribePrice { ids: Vec<String> },
+    #[serde(rename = "subscribe_publisher")]
+    SubscribePublisher { ids: Vec<String> },
+    #[serde(rename = "unsubscribe_publisher")]
+    UnsubscribePublisher { ids: Vec<String> },
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -68,8 +74,16 @@ struct Args {
     #[arg(short, long, env = "WEBSOCKET_SERVER_CONFIG_FILE", value_name = "FILE")]
     config: Option<PathBuf>,
 }
+#[derive(Debug)]
+struct ClientData {
+    price_subscriptions: HashSet<String>,
+    publisher_subscriptions: HashSet<String>,
+    sender: mpsc::UnboundedSender<Message>,
+}
+
+type Clients = Arc<Mutex<HashMap<String, ClientData>>>;
 
-type Clients = Arc<Mutex<HashMap<String, (HashSet<String>, mpsc::UnboundedSender<Message>)>>>;
+// type Clients = Arc<Mutex<HashMap<String, (HashSet<String>, mpsc::UnboundedSender<Message>)>>>;
 
 static NATS_CONNECTED: AtomicBool = AtomicBool::new(false);
 static WS_LISTENER_ACTIVE: AtomicBool = AtomicBool::new(false);
@@ -155,7 +169,14 @@ async fn handle_connection(stream: TcpStream, clients: Clients) -> Result<()> {
 
     {
         let mut clients = clients.lock().await;
-        clients.insert(addr.to_string(), (HashSet::new(), tx));
+        clients.insert(
+            addr.to_string(),
+            ClientData {
+                price_subscriptions: HashSet::new(),
+                publisher_subscriptions: HashSet::new(),
+                sender: tx,
+            },
+        );
     }
 
     let clients_clone = clients.clone();
@@ -175,10 +196,10 @@
         match msg {
             Ok(msg) => {
                 if let Message::Text(text) = msg {
-                    debug!("Received message from client {}: {}", addr, text);
+                    info!("Received message from client {}: {}", addr, text);
                     if let Err(e) = handle_client_message(&addr.to_string(), &text, &clients).await
                     {
-                        error!("Error handling client message: {}", e);
+                        error!("Error handling client message: {}, content: {}", e, text);
                         break;
                     }
                 }
@@ -200,29 +221,39 @@
     let client_msg: ClientMessage = serde_json::from_str(text)?;
 
     let response = match client_msg {
-        ClientMessage::Subscribe { ids, .. } => handle_subscribe(addr, ids, clients).await,
-        ClientMessage::Unsubscribe { ids } => handle_unsubscribe(addr, ids, clients).await,
+        ClientMessage::SubscribePrice { ids, .. } => {
+            handle_subscribe_price(addr, ids, clients).await
+        }
+        ClientMessage::UnsubscribePrice { ids } => {
+            handle_unsubscribe_price(addr, ids, clients).await
+        }
+        ClientMessage::SubscribePublisher { ids } => {
+            handle_subscribe_publisher(addr, ids, clients).await
+        }
+        ClientMessage::UnsubscribePublisher { ids } => {
+            handle_unsubscribe_publisher(addr, ids, clients).await
+        }
     };
 
     let response_json = serde_json::to_string(&response)?;
     let clients = clients.lock().await;
-    if let Some((_, sender)) = clients.get(addr) {
-        if let Err(e) = sender.send(Message::Text(response_json)) {
+    if let Some(client_data) = clients.get(addr) {
+        if let Err(e) = client_data.sender.send(Message::Text(response_json)) {
             error!("Channel send error: {}", e);
         }
     }
     Ok(())
 }
 
-async fn handle_subscribe(addr: &str, ids: Vec<String>, clients: &Clients) -> ServerResponse {
+async fn handle_subscribe_price(addr: &str, ids: Vec<String>, clients: &Clients) -> ServerResponse {
     let mut clients = clients.lock().await;
-    if let Some((subscriptions, _)) = clients.get_mut(addr) {
+    if let Some(client_data) = clients.get_mut(addr) {
         for mut hex_id in ids {
             if !hex_id.starts_with("0x") {
                 hex_id = format!("0x{}", hex_id);
             }
-            subscriptions.insert(hex_id.clone());
-            debug!("Client {} subscribed to feed ID: {}", addr, hex_id);
+            client_data.price_subscriptions.insert(hex_id.clone());
+            info!("Client {} subscribed to feed ID: {}", addr, hex_id);
         }
     } else {
         warn!("Client {} not found in clients map", addr);
@@ -233,11 +264,54 @@
     }
 }
 
-async fn handle_unsubscribe(addr: &str, ids: Vec<String>, clients: &Clients) -> ServerResponse {
+async fn handle_unsubscribe_price(
+    addr: &str,
+    ids: Vec<String>,
+    clients: &Clients,
+) -> ServerResponse {
     let mut clients = clients.lock().await;
-    if let Some((subscriptions, _)) = clients.get_mut(addr) {
+    if let Some(client_data) = clients.get_mut(addr) {
         for id in ids {
-            subscriptions.remove(&id);
+            client_data.price_subscriptions.remove(&id);
+        }
+    }
+    ServerResponse {
+        message_type: "response".to_string(),
+        status: "success".to_string(),
+    }
+}
+
+async fn handle_subscribe_publisher(
+    addr: &str,
+    ids: Vec<String>,
+    clients: &Clients,
+) -> ServerResponse {
+    let mut clients = clients.lock().await;
+    if let Some(client_data) = clients.get_mut(addr) {
+        for publisher_id in ids {
+            client_data
+                .publisher_subscriptions
+                .insert(publisher_id.clone());
+            info!("Client {} subscribed to publisher: {}", addr, publisher_id);
+        }
+    } else {
+        warn!("Client {} not found in clients map", addr);
+    }
+    ServerResponse {
+        message_type: "response".to_string(),
+        status: "success".to_string(),
+    }
+}
+
+async fn handle_unsubscribe_publisher(
+    addr: &str,
+    ids: Vec<String>,
+    clients: &Clients,
+) -> ServerResponse {
+    let mut clients = clients.lock().await;
+    if let Some(client_data) = clients.get_mut(addr) {
+        for id in ids {
+            client_data.publisher_subscriptions.remove(&id);
         }
     }
     ServerResponse {
@@ -253,6 +327,14 @@ struct PriceUpdate {
     price_feed: PriceFeed,
 }
 
+#[derive(Debug, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
+struct PublisherPriceUpdate {
+    publisher: String,
+    feed_id: String,
+    slot: u64,
+    price: String,
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 struct PriceFeed {
     id: String,
@@ -269,7 +351,120 @@ struct PriceInfo {
     slot: u64, // Add this field
 }
 
-async fn handle_nats_messages(jetstream: jetstream::Context, clients: Clients) -> Result<()> {
+async fn handle_nats_publisher_updates_messages(
+    jetstream: &jetstream::Context,
+    clients: &Clients,
+) -> Result<()> {
+    let stream_name = "PYTH_PUBLISHER_UPDATES";
+
+    let consumer_config = consumer::pull::Config {
+        deliver_policy: consumer::DeliverPolicy::All,
+        ack_policy: consumer::AckPolicy::None,
+        ..Default::default()
+    };
+
+    let consumer = jetstream
+        .create_consumer_on_stream(consumer_config, stream_name)
+        .await
+        .context("Failed to create NATS consumer")?;
+
+    info!(stream = %stream_name, "Started handling NATS publisher updates messages");
+
+    loop {
+        let mut messages = consumer
+            .messages()
+            .await
+            .context("Failed to get messages from NATS consumer")?;
+
+        while let Some(msg) = messages.next().await {
+            match msg {
+                Ok(msg) => {
+                    let updates: Vec<PublisherPriceUpdate> =
+                        match bincode::decode_from_slice(&msg.payload, bincode::config::standard())
+                        {
+                            Ok((updates, _)) => updates,
+                            Err(e) => {
+                                warn!(error = %e, "Failed to parse publisher price update batch");
+                                continue;
+                            }
+                        };
+                    info!("Parsed {} publisher updates in batch", updates.len());
+                    // Build per-client payloads while holding the lock,
+                    // but DO NOT send while holding it.
+                    let mut to_send: Vec<(String, mpsc::UnboundedSender<Message>, String)> =
+                        Vec::new();
+                    {
+                        let clients = clients.lock().await;
+
+                        for (client_addr, client_data) in clients.iter() {
+                            // Filter only updates the client cares about
+                            let filtered: Vec<_> = updates
+                                .iter()
+                                .filter(|u| {
+                                    client_data.publisher_subscriptions.contains(&u.publisher)
+                                })
+                                .map(|u| {
+                                    json!({
+                                        "publisher": u.publisher,
+                                        "feed_id": u.feed_id,
+                                        "slot": u.slot,
+                                        "price": u.price,
+                                    })
+                                })
+                                .collect();
+                            if filtered.is_empty() {
+                                continue;
+                            }
+
+                            info!(
+                                "Preparing batch for client {} ({} updates, subs={:?})",
+                                client_addr,
+                                filtered.len(),
+                                client_data.publisher_subscriptions
+                            );
+
+                            let batch_json = serde_json::to_string(&json!({
+                                "type": "publisher_price_update",
+                                "updates": filtered
+                            }))
+                            .unwrap();
+
+                            // Clone the sender so we can drop the lock before sending
+                            to_send.push((
+                                client_addr.clone(),
+                                client_data.sender.clone(),
+                                batch_json,
+                            ));
+                        }
+                    }
+
+                    // Now send the prepared batches
+                    for (client_addr, sender, batch_json) in to_send {
+                        info!(
+                            "Sending {}-byte batch to client {}",
+                            batch_json.len(),
+                            client_addr
+                        );
+                        if let Err(e) = sender.send(Message::Text(batch_json)) {
+                            warn!(client_addr = %client_addr, error = %e, "Failed to send publisher batch");
+                        } else {
+                            info!(client_addr = %client_addr, "Successfully sent publisher batch");
+                        }
+                    }
+                }
+                Err(e) => {
+                    error!(error = %e, "Error receiving message from NATS");
+                    return Err(e.into());
+                }
+            }
+        }
+    }
+}
+
+async fn handle_nats_price_updates_messages(
+    jetstream: &jetstream::Context,
+    clients: &Clients,
+) -> Result<()> {
     let stream_name = "PYTH_PRICE_UPDATES";
 
     let consumer_config = consumer::pull::Config {
@@ -283,7 +478,7 @@
         .await
         .context("Failed to create NATS consumer")?;
 
-    info!(stream = %stream_name, "Started handling NATS messages");
+    info!(stream = %stream_name, "Started handling NATS price updates messages");
 
     loop {
         let mut messages = consumer
@@ -320,11 +515,11 @@
                     let clients = clients.lock().await;
                     debug!("Number of connected clients: {}", clients.len());
 
-                    for (client_addr, (subscriptions, sender)) in clients.iter() {
-                        if subscriptions.contains(&hex_id) {
+                    for (client_addr, client_data) in clients.iter() {
+                        if client_data.price_subscriptions.contains(&hex_id) {
                             debug!("Sending update to client: {}", client_addr);
                             let update_json = serde_json::to_string(&price_update).unwrap();
-                            if let Err(e) = sender.send(Message::Text(update_json)) {
+                            if let Err(e) = client_data.sender.send(Message::Text(update_json)) {
                                 warn!(client_addr = %client_addr, error = %e, "Failed to send price update to client");
                             } else {
                                 debug!(
@@ -362,7 +557,8 @@
     NATS_CONNECTED.store(true, Ordering::SeqCst);
    info!("Connected to NATS server");
 
-    handle_nats_messages(jetstream, clients).await?;
+    // handle_nats_price_updates_messages(&jetstream, &clients).await?;
+    handle_nats_publisher_updates_messages(&jetstream, &clients).await?;
 
     NATS_CONNECTED.store(false, Ordering::SeqCst);
     Ok(())
diff --git a/src/utils.rs b/src/utils.rs
index dec2dcf..1946269 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -6,10 +6,22 @@ use tracing::info;
 
 pub async fn setup_jetstream(nats_client: &async_nats::Client) -> Result<jetstream::Context> {
     let jetstream = jetstream::new(nats_client.clone());
-    let stream_config = stream::Config {
+    let price_updates_stream_config = stream::Config {
         name: "PYTH_PRICE_UPDATES".to_string(),
         subjects: vec!["pyth.price.updates".to_string()],
-        max_bytes: 1024 * 1024 * 4000,
+        max_bytes: 1024 * 1024 * 40000,
+        duplicate_window: Duration::from_secs(60),
+        discard: stream::DiscardPolicy::Old,
+        allow_direct: true,
+        allow_rollup: true,
+        ..Default::default()
+    };
+
+    let publisher_updates_stream_config = stream::Config {
+        name: "PYTH_PUBLISHER_UPDATES".to_string(),
+        subjects: vec!["pyth.publisher.updates".to_string()],
+        max_bytes: 1024 * 1024 * 40000,
+        max_messages: 1000000,
         duplicate_window: Duration::from_secs(60),
         discard: stream::DiscardPolicy::Old,
         allow_direct: true,
@@ -21,15 +33,32 @@
     match jetstream.get_stream("PYTH_PRICE_UPDATES").await {
         Ok(_) => {
            // Stream exists, update its configuration
-            jetstream.update_stream(stream_config).await?;
+            jetstream.update_stream(price_updates_stream_config).await?;
            info!("JetStream stream updated: PYTH_PRICE_UPDATES");
        }
        Err(_) => {
            // Stream doesn't exist, create it
-            jetstream.create_stream(stream_config).await?;
+            jetstream.create_stream(price_updates_stream_config).await?;
            info!("JetStream stream created: PYTH_PRICE_UPDATES");
        }
    }
+    match jetstream.get_stream("PYTH_PUBLISHER_UPDATES").await {
+        Ok(_) => {
+            // Stream exists, update its configuration
+            jetstream
+                .update_stream(publisher_updates_stream_config)
+                .await?;
+            info!("JetStream stream updated: PYTH_PUBLISHER_UPDATES");
+        }
+        Err(_) => {
+            // Stream doesn't exist, create it
+            jetstream
+                .create_stream(publisher_updates_stream_config)
+                .await?;
+            info!("JetStream stream created: PYTH_PUBLISHER_UPDATES");
+        }
+    }
+
     Ok(jetstream)
 }
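
Note on the wire protocol introduced above (derived from the serde rename attributes and the json! payload in this diff; the ids and values shown are placeholders): a client of websocket_server now subscribes with one of four message types instead of the old subscribe/unsubscribe pair:

    {"type": "subscribe_price", "ids": ["0x<feed-id>"]}
    {"type": "unsubscribe_price", "ids": ["0x<feed-id>"]}
    {"type": "subscribe_publisher", "ids": ["<publisher-pubkey>"]}
    {"type": "unsubscribe_publisher", "ids": ["<publisher-pubkey>"]}

and publisher subscribers receive batched JSON messages of the form:

    {"type": "publisher_price_update", "updates": [{"publisher": "<publisher-pubkey>", "feed_id": "<feed-id>", "slot": 123456, "price": "12345678"}]}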
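
For anyone consuming the new PYTH_PUBLISHER_UPDATES stream outside of websocket_server, here is a minimal sketch of a standalone subscriber. It only reuses calls that already appear in this diff (async-nats pull consumers, bincode 2 with the standard config); the binary itself, the hard-coded NATS URL, and the trimmed-down struct are assumptions for illustration, not part of this change.

    use async_nats::jetstream::{self, consumer};
    use futures::StreamExt;

    // Mirrors the PublisherPriceUpdate struct added in this PR; only Decode is needed here.
    #[derive(Debug, bincode::Decode)]
    struct PublisherPriceUpdate {
        publisher: String,
        feed_id: String,
        slot: u64,
        price: String,
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Placeholder URL; in the real services this comes from the [nats] config section.
        let client = async_nats::connect("nats://localhost:4222").await?;
        let jetstream = jetstream::new(client);

        // Same stream name that setup_jetstream() creates/updates.
        let consumer = jetstream
            .create_consumer_on_stream(
                consumer::pull::Config {
                    ack_policy: consumer::AckPolicy::None,
                    ..Default::default()
                },
                "PYTH_PUBLISHER_UPDATES",
            )
            .await?;

        let mut messages = consumer.messages().await?;
        while let Some(Ok(msg)) = messages.next().await {
            // pyth_reader encodes each batch with bincode::encode_to_vec(.., standard()).
            let (updates, _len): (Vec<PublisherPriceUpdate>, usize) =
                bincode::decode_from_slice(&msg.payload, bincode::config::standard())?;
            println!("batch of {} publisher updates", updates.len());
        }
        Ok(())
    }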