Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions jive-api/src/handlers/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,62 +124,76 @@ pub async fn register(
.map_err(|_| ApiError::InternalServerError)?
.to_string();

// 创建用户
// 创建用户与家庭的 ID
let user_id = Uuid::new_v4();
let family_id = Uuid::new_v4(); // 为新用户创建默认家庭
let family_id = Uuid::new_v4();

// 开始事务
let mut tx = pool.begin().await
.map_err(|e| ApiError::DatabaseError(e.to_string()))?;

// 创建家庭
sqlx::query(
r#"
INSERT INTO families (id, name, created_at, updated_at)
VALUES ($1, $2, NOW(), NOW())
"#
)
.bind(family_id)
.bind(format!("{}'s Family", req.name))
.execute(&mut *tx)
.await
.map_err(|e| ApiError::DatabaseError(e.to_string()))?;

// 创建用户(将 name 写入 name 与 full_name,便于后续使用)
// 先创建用户(避免 families.owner_id 外键约束失败)
tracing::info!(target: "auth_register", user_id = %user_id, family_id = %family_id, email = %final_email, "Creating user then family with owner_id");
sqlx::query(
r#"
INSERT INTO users (
id, email, username, full_name, password_hash, current_family_id,
status, email_verified, created_at, updated_at
id, email, username, name, full_name, password_hash,
is_active, email_verified, created_at, updated_at
) VALUES (
$1, $2, $3, $4, $5, $6, 'active', false, NOW(), NOW()
$1, $2, $3, $4, $5, $6,
true, false, NOW(), NOW()
)
"#
)
.bind(user_id)
.bind(&final_email)
.bind(&username_opt)
.bind(&req.name)
.bind(&req.name)
.bind(password_hash)
.execute(&mut *tx)
.await
.map_err(|e| ApiError::DatabaseError(e.to_string()))?;

// 再创建家庭(带 owner_id)
tracing::info!(target: "auth_register", user_id = %user_id, family_id = %family_id, "Inserting family with owner_id in register");
sqlx::query(
r#"
INSERT INTO families (id, name, owner_id, created_at, updated_at)
VALUES ($1, $2, $3, NOW(), NOW())
"#
)
.bind(family_id)
.bind(format!("{}'s Family", req.name))
.bind(user_id)
.execute(&mut *tx)
.await
.map_err(|e| ApiError::DatabaseError(e.to_string()))?;

// 创建默认账本
// 创建默认账本(标记 is_default,记录创建者)
let ledger_id = Uuid::new_v4();
sqlx::query(
r#"
INSERT INTO ledgers (id, family_id, name, currency, created_at, updated_at)
VALUES ($1, $2, '默认账本', 'CNY', NOW(), NOW())
INSERT INTO ledgers (id, family_id, name, currency, created_by, is_default, created_at, updated_at)
VALUES ($1, $2, '默认账本', 'CNY', $3, true, NOW(), NOW())
"#
)
.bind(ledger_id)
.bind(family_id)
.bind(user_id)
.execute(&mut *tx)
.await
.map_err(|e| ApiError::DatabaseError(e.to_string()))?;

// 绑定用户的当前家庭并提交事务
tracing::info!(target: "auth_register", user_id = %user_id, family_id = %family_id, "Binding current_family_id and committing");
sqlx::query("UPDATE users SET current_family_id = $1 WHERE id = $2")
.bind(family_id)
.bind(user_id)
.execute(&mut *tx)
.await
.map_err(|e| ApiError::DatabaseError(e.to_string()))?;

// 提交事务
tx.commit().await
.map_err(|e| ApiError::DatabaseError(e.to_string()))?;
Expand Down
13 changes: 9 additions & 4 deletions jive-api/src/handlers/enhanced_profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ pub async fn register_with_preferences(
.map_err(|e| ApiError::DatabaseError(e.to_string()))?;

// Create family with user's preferences
tracing::info!(target: "enhanced_register", user_id = %user_id, name = %req.name, "Creating family via FamilyService (owner_id)");
let family_service = FamilyService::new(pool.clone());
let family_request = CreateFamilyRequest {
name: Some(format!("{}的家庭", req.name)),
Expand All @@ -155,12 +156,16 @@ pub async fn register_with_preferences(
locale: Some(req.language.clone()),
};

let family = family_service
.create_family(user_id, family_request)
.await
.map_err(|_e| ApiError::InternalServerError)?;
let family = match family_service.create_family(user_id, family_request).await {
Ok(f) => f,
Err(e) => {
tracing::error!(target: "enhanced_register", error=?e, user_id=%user_id, "create_family failed");
return Err(ApiError::InternalServerError);
}
};

// Update user's current family
tracing::info!(target: "enhanced_register", user_id = %user_id, family_id = %family.id, "Binding current_family_id after enhanced register");
sqlx::query("UPDATE users SET current_family_id = $1 WHERE id = $2")
.bind(family.id)
.bind(user_id)
Expand Down
2 changes: 1 addition & 1 deletion jive-api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.route("/api/v1/rules/execute", post(execute_rules))

// 认证 API
.route("/api/v1/auth/register", post(auth_handlers::register_with_family))
.route("/api/v1/auth/register", post(auth_handlers::register))
.route("/api/v1/auth/login", post(auth_handlers::login))
.route("/api/v1/auth/refresh", post(auth_handlers::refresh_token))
.route("/api/v1/auth/user", get(auth_handlers::get_current_user))
Expand Down
8 changes: 5 additions & 3 deletions jive-api/src/services/family_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,20 @@ impl FamilyService {
};

// Create family
tracing::info!(target: "family_service", user_id = %user_id, name = %family_name, "Inserting family with owner_id");
let family_id = Uuid::new_v4();
let invite_code = Family::generate_invite_code();

let family = sqlx::query_as::<_, Family>(
r#"
INSERT INTO families (id, name, currency, timezone, locale, invite_code, member_count, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, 1, $7, $8)
INSERT INTO families (id, name, owner_id, currency, timezone, locale, invite_code, member_count, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, 1, $8, $9)
RETURNING *
"#
)
.bind(family_id)
.bind(&family_name)
.bind(user_id)
.bind(request.currency.as_deref().unwrap_or("CNY"))
.bind(request.timezone.as_deref().unwrap_or("Asia/Shanghai"))
.bind(request.locale.as_deref().unwrap_or("zh-CN"))
Expand Down Expand Up @@ -103,7 +105,7 @@ impl FamilyService {
// Create default ledger
sqlx::query(
r#"
INSERT INTO ledgers (id, family_id, name, currency, owner_id, is_default, created_at, updated_at)
INSERT INTO ledgers (id, family_id, name, currency, created_by, is_default, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, true, $6, $7)
"#
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#[cfg(test)]
mod tests {
use axum::{Router, routing::{post, get}};
use http::{Request, header, StatusCode};
use hyper::Body;
use tower::ServiceExt; // for oneshot
use serde_json::json;
use uuid::Uuid;

use jive_money_api::handlers::{enhanced_profile::register_with_preferences, transactions::export_transactions_csv_stream};
use crate::fixtures::create_test_pool;

#[tokio::test]
async fn register_enhanced_route_creates_family_and_allows_export() {
let pool = create_test_pool().await;

let app = Router::new()
.route("/api/v1/auth/register-enhanced", post(register_with_preferences))
.route("/api/v1/transactions/export.csv", get(export_transactions_csv_stream))
.with_state(pool.clone());

let email = format!("enh_{}@example.com", Uuid::new_v4());
let body = json!({
"email": email,
"password": "EnhE2e123!",
"name": "EnhE2E",
"country": "CN",
"currency": "CNY",
"language": "zh-CN",
"timezone": "Asia/Shanghai",
"date_format": "YYYY-MM-DD"
});

let req = Request::builder()
.method("POST")
.uri("/api/v1/auth/register-enhanced")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(body.to_string()))
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK, "register-enhanced should return 200");
let bytes = hyper::body::to_bytes(resp.into_body()).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
let token = v.pointer("/data/token").and_then(|x| x.as_str()).unwrap_or("");
assert!(!token.is_empty(), "token should be present");

let req2 = Request::builder()
.method("GET")
.uri("/api/v1/transactions/export.csv?include_header=true")
.header(header::AUTHORIZATION, format!("Bearer {}", token))
.body(Body::empty())
.unwrap();
let resp2 = app.clone().oneshot(req2).await.unwrap();
assert_eq!(resp2.status(), StatusCode::OK);
let body_bytes = hyper::body::to_bytes(resp2.into_body()).await.unwrap();
assert!(body_bytes.starts_with(b"Date,Description"), "CSV header missing or incorrect");
}
}

93 changes: 93 additions & 0 deletions jive-api/tests/integration/auth_register_route_e2e_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#[cfg(test)]
mod tests {
use axum::{Router, routing::{post, get}};
use http::{Request, header, StatusCode};
use hyper::Body;
use tower::ServiceExt; // for oneshot
use serde_json::json;
use uuid::Uuid;

use jive_money_api::handlers::{auth, transactions::export_transactions_csv_stream};
use crate::fixtures::create_test_pool;

#[tokio::test]
async fn register_route_creates_family_and_default_ledger_and_allows_export() {
let pool = create_test_pool().await;

// Build minimal router for the two endpoints under test
let app = Router::new()
.route("/api/v1/auth/register", post(auth::register))
.route("/api/v1/transactions/export.csv", get(export_transactions_csv_stream))
.with_state(pool.clone());

// Unique username-style email (no @) to exercise username path as well
let uname = format!("route_e2e_{}", Uuid::new_v4());
let body = json!({
"email": uname,
"password": "RouteE2e123!",
"name": "RouteE2E"
});
let req = Request::builder()
.method("POST")
.uri("/api/v1/auth/register")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(body.to_string()))
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK, "register should return 200");

let bytes = hyper::body::to_bytes(resp.into_body()).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
let token = v.get("token").and_then(|x| x.as_str()).unwrap_or("");
assert!(!token.is_empty(), "token should be present in register response");

// Find created user_id from response and assert family/ledger rows
let user_id: Uuid = serde_json::from_value(v.get("user_id").cloned().unwrap()).unwrap();

// families.owner_id must equal user_id
let fam_row: Option<(Uuid, Uuid)> = sqlx::query_as(
"SELECT id, owner_id FROM families WHERE owner_id = $1 ORDER BY created_at DESC LIMIT 1"
)
.bind(user_id)
.fetch_optional(&pool)
.await
.expect("query families");
let (family_id, owner_id) = fam_row.expect("family created");
assert_eq!(owner_id, user_id, "families.owner_id should equal user_id");

// default ledger exists with created_by = user_id and is_default = true
#[derive(sqlx::FromRow, Debug)]
struct LedgerRow { id: Uuid, is_default: Option<bool>, created_by: Option<Uuid> }
let ledgers: Vec<LedgerRow> = sqlx::query_as(
"SELECT id, is_default, created_by FROM ledgers WHERE family_id = $1"
)
.bind(family_id)
.fetch_all(&pool)
.await
.expect("query ledgers");
assert_eq!(ledgers.len(), 1, "exactly one default ledger expected");
let l = &ledgers[0];
assert_eq!(l.is_default.unwrap_or(false), true, "ledger should be default");
assert_eq!(l.created_by.unwrap(), user_id, "ledger.created_by should equal user_id");

// Now call export.csv using the token; expect header-only CSV
let req2 = Request::builder()
.method("GET")
.uri("/api/v1/transactions/export.csv?include_header=true")
.header(header::AUTHORIZATION, format!("Bearer {}", token))
.body(Body::empty())
.unwrap();
let resp2 = app.clone().oneshot(req2).await.unwrap();
assert_eq!(resp2.status(), StatusCode::OK, "export.csv should be 200");
let body_bytes = hyper::body::to_bytes(resp2.into_body()).await.unwrap();
let head = String::from_utf8_lossy(&body_bytes);
assert!(head.starts_with("Date,Description"), "CSV header missing or incorrect");

// Cleanup user rows (cascade should remove memberships/related rows)
let _ = sqlx::query("DELETE FROM users WHERE id = $1")
.bind(user_id)
.execute(&pool)
.await;
}
}

1 change: 1 addition & 0 deletions jive-api/tests/integration/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mod family_flow_test;
mod transactions_export_test;
mod auth_register_route_e2e_test;
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class AccountAdapter extends TypeAdapter<Account> {
..writeByte(6)
..write(obj.description)
..writeByte(7)
..write(obj.color == null ? null : obj.color!.toARGB32())
..write(obj.color?.toARGB32())
..writeByte(8)
..write(obj.isDefault)
..writeByte(9)
Expand Down
Loading