From 0e955fcc1136e1e3b409cb2bfbb82fc775b7c6af Mon Sep 17 00:00:00 2001 From: zensgit <77236085+zensgit@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:25:50 +0800 Subject: [PATCH 1/2] api: enforce single default ledger, JWT env secret, export_stream feature scaffolding --- .env.example | 7 ++- README.md | 10 +++ jive-api/Cargo.toml | 2 + .../028_add_unique_default_ledger_index.sql | 14 +++++ jive-api/src/auth.rs | 24 ++++++-- jive-api/src/handlers/transactions.rs | 61 ++++++++++++++++--- .../integration/family_default_ledger_test.rs | 16 ++++- 7 files changed, 120 insertions(+), 14 deletions(-) create mode 100644 jive-api/migrations/028_add_unique_default_ledger_index.sql diff --git a/.env.example b/.env.example index 9208ba4f..186c7dd7 100644 --- a/.env.example +++ b/.env.example @@ -22,4 +22,9 @@ HTTP_PORT=80 HTTPS_PORT=443 # Environment -ENVIRONMENT=development \ No newline at end of file +ENVIRONMENT=development +# ===================== +# Authentication / JWT +# ===================== +# Set a strong random secret in real deployments. If unset, a dev fallback is used (unsafe for production). +JWT_SECRET=please-change-me-in-production diff --git a/README.md b/README.md index f270fc26..9c9da3c7 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,16 @@ curl -s http://localhost:8012/health make db-dev-down ``` +### JWT 密钥配置 + +环境变量 `JWT_SECRET` 用于签发与验证访问令牌: + +```bash +export JWT_SECRET=$(openssl rand -hex 32) +``` + +未设置时(或留空)API 会在开发 / 测试自动使用一个不安全的占位并打印警告,不可在生产依赖该默认值。 + ### 超级管理员默认密码说明 仓库历史存在两个默认密码基线: diff --git a/jive-api/Cargo.toml b/jive-api/Cargo.toml index 3c86e941..f19f80ae 100644 --- a/jive-api/Cargo.toml +++ b/jive-api/Cargo.toml @@ -76,6 +76,8 @@ demo_endpoints = [] # Enable to use jive-core export service paths # When core_export is enabled, also enable jive-core's db feature so CSV helpers are available. core_export = ["dep:jive-core"] +# Stream CSV export incrementally instead of buffering whole response +export_stream = [] [dev-dependencies] tokio-test = "0.4" diff --git a/jive-api/migrations/028_add_unique_default_ledger_index.sql b/jive-api/migrations/028_add_unique_default_ledger_index.sql new file mode 100644 index 00000000..12078ccf --- /dev/null +++ b/jive-api/migrations/028_add_unique_default_ledger_index.sql @@ -0,0 +1,14 @@ +-- 028_add_unique_default_ledger_index.sql +-- Enforce at most one default ledger per family. +-- Safe to run multiple times (IF NOT EXISTS guard). + +CREATE UNIQUE INDEX IF NOT EXISTS idx_ledgers_one_default + ON ledgers(family_id) + WHERE is_default = true; + +-- Rationale: +-- Business rule: each family must have a single canonical default ledger used for +-- category and transaction fallbacks. Prior logic relies on code discipline; this +-- index guarantees integrity at the database layer and prevents race conditions +-- where two concurrent creations might both mark default. + diff --git a/jive-api/src/auth.rs b/jive-api/src/auth.rs index 66873b12..8a834c09 100644 --- a/jive-api/src/auth.rs +++ b/jive-api/src/auth.rs @@ -12,8 +12,24 @@ use serde::{Deserialize, Serialize}; use std::fmt::Display; use uuid::Uuid; -/// JWT密钥(实际生产中应该从环境变量读取) -const JWT_SECRET: &str = "your-secret-key-change-this-in-production"; +/// 获取 JWT 密钥(优先环境变量 JWT_SECRET;未设置时使用不安全占位并在非测试模式下警告) +fn jwt_secret() -> &'static str { + // Use once_cell to cache environment lookup + use std::sync::OnceLock; + static SECRET: OnceLock = OnceLock::new(); + SECRET.get_or_init(|| { + match std::env::var("JWT_SECRET") { + Ok(v) if !v.trim().is_empty() => v, + _ => { + // Fallback (dev/test only). In production this should be set; emit warning. + if !cfg!(test) { + eprintln!("WARNING: JWT_SECRET not set; using insecure default key"); + } + "insecure-dev-jwt-secret-change-me".to_string() + } + } + }) +} /// JWT Claims #[derive(Debug, Serialize, Deserialize, Clone)] @@ -51,7 +67,7 @@ impl Claims { let token = encode( &Header::default(), self, - &EncodingKey::from_secret(JWT_SECRET.as_ref()), + &EncodingKey::from_secret(jwt_secret().as_bytes()), ) .map_err(|_| AuthError::TokenCreation)?; @@ -62,7 +78,7 @@ impl Claims { pub fn from_token(token: &str) -> Result { let token_data = decode::( token, - &DecodingKey::from_secret(JWT_SECRET.as_ref()), + &DecodingKey::from_secret(jwt_secret().as_bytes()), &Validation::default(), ) .map_err(|_| AuthError::InvalidToken)?; diff --git a/jive-api/src/handlers/transactions.rs b/jive-api/src/handlers/transactions.rs index b95b887b..390b5696 100644 --- a/jive-api/src/handlers/transactions.rs +++ b/jive-api/src/handlers/transactions.rs @@ -446,13 +446,60 @@ pub async fn export_transactions_csv_stream( } query.push(" ORDER BY t.transaction_date DESC, t.id DESC"); - // Execute fully and build CSV body (simple, reliable) - let rows_all = query - .build() - .fetch_all(&pool) - .await - .map_err(|e| ApiError::DatabaseError(format!("查询交易失败: {}", e)))?; - // Build response body bytes depending on feature flag + // When export_stream feature enabled, stream rows instead of buffering entire CSV + #[cfg(feature = "export_stream")] + { + use futures::{StreamExt}; + use tokio_stream::wrappers::ReceiverStream; + use tokio::sync::mpsc; + let include_header = q.include_header.unwrap_or(true); + let (tx, rx) = mpsc::channel::>(8); + let built = query.build(); + let pool_clone = pool.clone(); + tokio::spawn(async move { + let mut stream = built.fetch_many(&pool_clone); + // Header + if include_header { + if tx.send(Ok(bytes::Bytes::from_static(b"Date,Description,Amount,Category,Account,Payee,Type\n"))).await.is_err() { return; } + } + while let Some(item) = stream.next().await { + match item { + Ok(sqlx::Either::Right(row)) => { + use sqlx::Row; + let date: NaiveDate = row.get("transaction_date"); + let desc: String = row.try_get::("description").unwrap_or_default(); + let amount: Decimal = row.get("amount"); + let category: Option = row.try_get::("category_name").ok().filter(|s| !s.is_empty()); + let account_id: Uuid = row.get("account_id"); + let payee: Option = row.try_get::("payee_name").ok().filter(|s| !s.is_empty()); + let ttype: String = row.get("transaction_type"); + let line = format!("{},{},{},{},{},{},{}\n", + date, + csv_escape_cell(desc, ','), + amount, + csv_escape_cell(category.clone().unwrap_or_default(), ','), + account_id, + csv_escape_cell(payee.clone().unwrap_or_default(), ','), + csv_escape_cell(ttype, ',')); + if tx.send(Ok(bytes::Bytes::from(line))).await.is_err() { return; } + } + Ok(sqlx::Either::Left(_)) => { /* ignore query result count */ } + Err(e) => { let _ = tx.send(Err(ApiError::DatabaseError(e.to_string()))).await; return; } + } + } + }); + let byte_stream = ReceiverStream::new(rx).map(|r| match r { Ok(b)=>Ok::<_,ApiError>(b), Err(e)=>Err(e) }); + let body = Body::from_stream(byte_stream.map(|res| res.map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "stream error")))); + // Build headers & return early (skip buffered path below) + let mut headers_map = header::HeaderMap::new(); + headers_map.insert(header::CONTENT_TYPE, "text/csv; charset=utf-8".parse().unwrap()); + let filename = format!("transactions_export_{}.csv", Utc::now().format("%Y%m%d%H%M%S")); + headers_map.insert(header::CONTENT_DISPOSITION, format!("attachment; filename=\"{}\"", filename).parse().unwrap()); + return Ok((headers_map, body)); + } + + // Execute fully and build CSV body when streaming disabled + let rows_all = query.build().fetch_all(&pool).await.map_err(|e| ApiError::DatabaseError(format!("查询交易失败: {}", e)))?; #[cfg(feature = "core_export")] let body_bytes: Vec = { let include_header = q.include_header.unwrap_or(true); diff --git a/jive-api/tests/integration/family_default_ledger_test.rs b/jive-api/tests/integration/family_default_ledger_test.rs index 80322850..8eccbe6a 100644 --- a/jive-api/tests/integration/family_default_ledger_test.rs +++ b/jive-api/tests/integration/family_default_ledger_test.rs @@ -18,7 +18,7 @@ mod tests { let user_id = uc.user_id; let family_id = uc.current_family_id.expect("family id"); - // Query ledger(s) + // Query ledger(s) – should be exactly one default #[derive(sqlx::FromRow, Debug)] struct LedgerRow { id: uuid::Uuid, family_id: uuid::Uuid, is_default: Option, created_by: Option, name: String } let ledgers = sqlx::query_as::<_, LedgerRow>( @@ -34,6 +34,19 @@ mod tests { assert_eq!(ledger.created_by.unwrap(), user_id, "created_by should be owner user_id"); assert_eq!(ledger.name, "默认账本"); + // Attempt to manually insert a second default ledger to ensure DB uniqueness is enforced + let second_id = uuid::Uuid::new_v4(); + let dup = sqlx::query( + "INSERT INTO ledgers (id, family_id, name, currency, created_by, is_default, created_at, updated_at) VALUES ($1,$2,$3,'CNY',$4,true,NOW(),NOW())" + ) + .bind(second_id) + .bind(family_id) + .bind("竞争默认账本") + .bind(user_id) + .execute(&pool) + .await; + assert!(dup.is_err(), "second default ledger insertion should fail due to unique index"); + // Also ensure service context can fetch families list for sanity let fam_service = FamilyService::new(pool.clone()); let families = fam_service.get_user_families(user_id).await.expect("user families"); @@ -46,4 +59,3 @@ mod tests { .ok(); } } - From 8a449f1d30bbbfaeee6c72b0fc1f9ddd5ec7d3bc Mon Sep 17 00:00:00 2001 From: zensgit <77236085+zensgit@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:33:34 +0800 Subject: [PATCH 2/2] fix: apply rustfmt formatting to transactions.rs --- jive-api/src/handlers/transactions.rs | 74 +++++++++++++++++++++------ 1 file changed, 58 insertions(+), 16 deletions(-) diff --git a/jive-api/src/handlers/transactions.rs b/jive-api/src/handlers/transactions.rs index 390b5696..90229ba3 100644 --- a/jive-api/src/handlers/transactions.rs +++ b/jive-api/src/handlers/transactions.rs @@ -449,9 +449,9 @@ pub async fn export_transactions_csv_stream( // When export_stream feature enabled, stream rows instead of buffering entire CSV #[cfg(feature = "export_stream")] { - use futures::{StreamExt}; - use tokio_stream::wrappers::ReceiverStream; + use futures::StreamExt; use tokio::sync::mpsc; + use tokio_stream::wrappers::ReceiverStream; let include_header = q.include_header.unwrap_or(true); let (tx, rx) = mpsc::channel::>(8); let built = query.build(); @@ -460,46 +460,88 @@ pub async fn export_transactions_csv_stream( let mut stream = built.fetch_many(&pool_clone); // Header if include_header { - if tx.send(Ok(bytes::Bytes::from_static(b"Date,Description,Amount,Category,Account,Payee,Type\n"))).await.is_err() { return; } + if tx + .send(Ok(bytes::Bytes::from_static( + b"Date,Description,Amount,Category,Account,Payee,Type\n", + ))) + .await + .is_err() + { + return; + } } while let Some(item) = stream.next().await { match item { Ok(sqlx::Either::Right(row)) => { use sqlx::Row; let date: NaiveDate = row.get("transaction_date"); - let desc: String = row.try_get::("description").unwrap_or_default(); + let desc: String = + row.try_get::("description").unwrap_or_default(); let amount: Decimal = row.get("amount"); - let category: Option = row.try_get::("category_name").ok().filter(|s| !s.is_empty()); + let category: Option = row + .try_get::("category_name") + .ok() + .filter(|s| !s.is_empty()); let account_id: Uuid = row.get("account_id"); - let payee: Option = row.try_get::("payee_name").ok().filter(|s| !s.is_empty()); + let payee: Option = row + .try_get::("payee_name") + .ok() + .filter(|s| !s.is_empty()); let ttype: String = row.get("transaction_type"); - let line = format!("{},{},{},{},{},{},{}\n", + let line = format!( + "{},{},{},{},{},{},{}\n", date, csv_escape_cell(desc, ','), amount, csv_escape_cell(category.clone().unwrap_or_default(), ','), account_id, csv_escape_cell(payee.clone().unwrap_or_default(), ','), - csv_escape_cell(ttype, ',')); - if tx.send(Ok(bytes::Bytes::from(line))).await.is_err() { return; } + csv_escape_cell(ttype, ',') + ); + if tx.send(Ok(bytes::Bytes::from(line))).await.is_err() { + return; + } } Ok(sqlx::Either::Left(_)) => { /* ignore query result count */ } - Err(e) => { let _ = tx.send(Err(ApiError::DatabaseError(e.to_string()))).await; return; } + Err(e) => { + let _ = tx.send(Err(ApiError::DatabaseError(e.to_string()))).await; + return; + } } } }); - let byte_stream = ReceiverStream::new(rx).map(|r| match r { Ok(b)=>Ok::<_,ApiError>(b), Err(e)=>Err(e) }); - let body = Body::from_stream(byte_stream.map(|res| res.map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "stream error")))); + let byte_stream = ReceiverStream::new(rx).map(|r| match r { + Ok(b) => Ok::<_, ApiError>(b), + Err(e) => Err(e), + }); + let body = Body::from_stream(byte_stream.map(|res| { + res.map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "stream error")) + })); // Build headers & return early (skip buffered path below) let mut headers_map = header::HeaderMap::new(); - headers_map.insert(header::CONTENT_TYPE, "text/csv; charset=utf-8".parse().unwrap()); - let filename = format!("transactions_export_{}.csv", Utc::now().format("%Y%m%d%H%M%S")); - headers_map.insert(header::CONTENT_DISPOSITION, format!("attachment; filename=\"{}\"", filename).parse().unwrap()); + headers_map.insert( + header::CONTENT_TYPE, + "text/csv; charset=utf-8".parse().unwrap(), + ); + let filename = format!( + "transactions_export_{}.csv", + Utc::now().format("%Y%m%d%H%M%S") + ); + headers_map.insert( + header::CONTENT_DISPOSITION, + format!("attachment; filename=\"{}\"", filename) + .parse() + .unwrap(), + ); return Ok((headers_map, body)); } // Execute fully and build CSV body when streaming disabled - let rows_all = query.build().fetch_all(&pool).await.map_err(|e| ApiError::DatabaseError(format!("查询交易失败: {}", e)))?; + let rows_all = query + .build() + .fetch_all(&pool) + .await + .map_err(|e| ApiError::DatabaseError(format!("查询交易失败: {}", e)))?; #[cfg(feature = "core_export")] let body_bytes: Vec = { let include_header = q.include_header.unwrap_or(true);