Skip to content

Commit 628b757

Browse files
yhyyzdoctatortot
andcommitted
feat: apply PR #6, #9, #10 from doctatortot
- test(api): install rustls CryptoProvider in setup_app (PR #6) Fixes ~20 api test failures from missing TLS provider - fix(retrieve): replace min-max RRF normalization with scale-and-clamp (PR #9) Preserves absolute quality signal; weak matches no longer inflate to 1.0 RRF_SCALE=61 so ideal dual-leg rank-1 maps to ~1.0. Closes #7 - feat(api): expose memory_type on create and update (PR #10) POST /v1/memories accepts memory_type (pinned|insight|session) PUT /v1/memories/:id can change memory_type Default remains 'pinned' for backwards compat Co-authored-by: doctatortot <doctatortot@users.noreply.github.com>
1 parent 508e9e6 commit 628b757

3 files changed

Lines changed: 224 additions & 46 deletions

File tree

omem-server/src/api/handlers/memory.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pub struct CreateMemoryBody {
3737
#[serde(default)]
3838
pub tags: Option<Vec<String>>,
3939
pub source: Option<String>,
40+
pub memory_type: Option<String>,
4041
}
4142

4243
#[derive(Deserialize)]
@@ -95,6 +96,7 @@ pub struct UpdateMemoryBody {
9596
pub content: Option<String>,
9697
pub tags: Option<Vec<String>>,
9798
pub state: Option<String>,
99+
pub memory_type: Option<String>,
98100
}
99101

100102
#[derive(Serialize)]
@@ -204,10 +206,15 @@ pub async fn create_memory(
204206
return Err(OmemError::Validation("content cannot be empty".to_string()));
205207
}
206208

209+
let memory_type = match body.memory_type {
210+
Some(s) => s.parse().map_err(OmemError::Validation)?,
211+
None => MemoryType::Pinned,
212+
};
213+
207214
let mut memory = Memory::new(
208215
&content,
209216
Category::Preferences,
210-
MemoryType::Pinned,
217+
memory_type,
211218
&auth.tenant_id,
212219
);
213220
memory.tags = body.tags.unwrap_or_default();
@@ -553,6 +560,12 @@ pub async fn update_memory(
553560
.map_err(|e: String| OmemError::Validation(e))?;
554561
}
555562

563+
if let Some(memory_type_str) = body.memory_type {
564+
memory.memory_type = memory_type_str
565+
.parse()
566+
.map_err(|e: String| OmemError::Validation(e))?;
567+
}
568+
556569
memory.updated_at = chrono::Utc::now().to_rfc3339();
557570

558571
let vector = if need_reembed {

omem-server/src/api/mod.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,16 @@ mod tests {
5151
}
5252
}
5353

54+
fn install_crypto_provider() {
55+
use std::sync::Once;
56+
static INIT: Once = Once::new();
57+
INIT.call_once(|| {
58+
let _ = rustls::crypto::ring::default_provider().install_default();
59+
});
60+
}
61+
5462
async fn setup_app() -> (axum::Router, tempfile::TempDir) {
63+
install_crypto_provider();
5564
let dir = tempfile::TempDir::new().expect("temp dir");
5665
let uri = dir.path().to_str().expect("path");
5766

@@ -481,6 +490,153 @@ mod tests {
481490
assert_eq!(json["tags"][0], "new-tag");
482491
}
483492

493+
#[tokio::test]
494+
async fn test_create_memory_with_type() {
495+
let (app, _dir) = setup_app().await;
496+
let api_key = create_test_tenant(&app).await;
497+
498+
let create_resp = app
499+
.clone()
500+
.oneshot(
501+
Request::builder()
502+
.method("POST")
503+
.uri("/v1/memories")
504+
.header("content-type", "application/json")
505+
.header("x-api-key", &api_key)
506+
.body(Body::from(
507+
r#"{"content":"an insight","memory_type":"insight"}"#,
508+
))
509+
.expect("request"),
510+
)
511+
.await
512+
.expect("response");
513+
let bytes = create_resp
514+
.into_body()
515+
.collect()
516+
.await
517+
.expect("body")
518+
.to_bytes();
519+
let created: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
520+
assert_eq!(created["memory_type"], "insight");
521+
522+
let default_resp = app
523+
.clone()
524+
.oneshot(
525+
Request::builder()
526+
.method("POST")
527+
.uri("/v1/memories")
528+
.header("content-type", "application/json")
529+
.header("x-api-key", &api_key)
530+
.body(Body::from(r#"{"content":"default"}"#))
531+
.expect("request"),
532+
)
533+
.await
534+
.expect("response");
535+
let bytes = default_resp
536+
.into_body()
537+
.collect()
538+
.await
539+
.expect("body")
540+
.to_bytes();
541+
let default_created: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
542+
assert_eq!(default_created["memory_type"], "pinned");
543+
}
544+
545+
#[tokio::test]
546+
async fn test_update_memory_type() {
547+
let (app, _dir) = setup_app().await;
548+
let api_key = create_test_tenant(&app).await;
549+
550+
let create_resp = app
551+
.clone()
552+
.oneshot(
553+
Request::builder()
554+
.method("POST")
555+
.uri("/v1/memories")
556+
.header("content-type", "application/json")
557+
.header("x-api-key", &api_key)
558+
.body(Body::from(r#"{"content":"originally pinned"}"#))
559+
.expect("request"),
560+
)
561+
.await
562+
.expect("response");
563+
let bytes = create_resp
564+
.into_body()
565+
.collect()
566+
.await
567+
.expect("body")
568+
.to_bytes();
569+
let created: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
570+
let memory_id = created["id"].as_str().expect("id");
571+
assert_eq!(created["memory_type"], "pinned");
572+
573+
let update_resp = app
574+
.clone()
575+
.oneshot(
576+
Request::builder()
577+
.method("PUT")
578+
.uri(format!("/v1/memories/{memory_id}"))
579+
.header("content-type", "application/json")
580+
.header("x-api-key", &api_key)
581+
.body(Body::from(r#"{"memory_type":"insight"}"#))
582+
.expect("request"),
583+
)
584+
.await
585+
.expect("response");
586+
assert_eq!(update_resp.status(), StatusCode::OK);
587+
let bytes = update_resp
588+
.into_body()
589+
.collect()
590+
.await
591+
.expect("body")
592+
.to_bytes();
593+
let json: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
594+
assert_eq!(json["memory_type"], "insight");
595+
}
596+
597+
#[tokio::test]
598+
async fn test_update_memory_type_invalid() {
599+
let (app, _dir) = setup_app().await;
600+
let api_key = create_test_tenant(&app).await;
601+
602+
let create_resp = app
603+
.clone()
604+
.oneshot(
605+
Request::builder()
606+
.method("POST")
607+
.uri("/v1/memories")
608+
.header("content-type", "application/json")
609+
.header("x-api-key", &api_key)
610+
.body(Body::from(r#"{"content":"test"}"#))
611+
.expect("request"),
612+
)
613+
.await
614+
.expect("response");
615+
let bytes = create_resp
616+
.into_body()
617+
.collect()
618+
.await
619+
.expect("body")
620+
.to_bytes();
621+
let created: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
622+
let memory_id = created["id"].as_str().expect("id");
623+
624+
let update_resp = app
625+
.clone()
626+
.oneshot(
627+
Request::builder()
628+
.method("PUT")
629+
.uri(format!("/v1/memories/{memory_id}"))
630+
.header("content-type", "application/json")
631+
.header("x-api-key", &api_key)
632+
.body(Body::from(r#"{"memory_type":"bogus"}"#))
633+
.expect("request"),
634+
)
635+
.await
636+
.expect("response");
637+
assert_eq!(update_resp.status(), StatusCode::BAD_REQUEST);
638+
}
639+
484640
#[tokio::test]
485641
async fn test_search_memories() {
486642
let (app, _dir) = setup_app().await;

omem-server/src/retrieve/pipeline.rs

Lines changed: 54 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -300,36 +300,18 @@ impl RetrievalPipeline {
300300
(fused, stage)
301301
}
302302

303-
/// Normalize RRF scores to [0, 1] range so downstream thresholds (min_score, hard_cutoff) work correctly.
304-
/// RRF raw scores are tiny (max ~0.033 for K=60 with 2 legs), but thresholds expect [0, 1].
305-
/// - Multiple results: min-max normalization (best=1.0, worst=0.0)
306-
/// - Single result: scale by RRF_SCALE (40.0) and clamp to [0, 1]
303+
/// Normalize RRF scores into [0, 1] while preserving absolute quality signal.
304+
/// Raw RRF scores are tiny (~1/(K+1) for ideal dual-leg rank-1 match with K=60),
305+
/// but downstream thresholds (min_score, hard_cutoff) expect [0, 1]. We scale by
306+
/// RRF_SCALE so the best-possible hybrid match maps to ~1.0 and clamp; everything
307+
/// weaker stays proportionally smaller.
307308
fn stage_rrf_normalize(mut entries: Vec<FusionEntry>) -> (Vec<FusionEntry>, StageTrace) {
308-
const RRF_SCALE: f32 = 40.0;
309+
const RRF_SCALE: f32 = 61.0;
309310
let stage_start = Instant::now();
310311
let input_count = entries.len();
311312

312-
if entries.len() > 1 {
313-
let max_score = entries
314-
.iter()
315-
.map(|e| e.rrf_score)
316-
.fold(f32::NEG_INFINITY, f32::max);
317-
let min_score = entries
318-
.iter()
319-
.map(|e| e.rrf_score)
320-
.fold(f32::INFINITY, f32::min);
321-
let range = max_score - min_score;
322-
if range > 0.0 {
323-
for entry in &mut entries {
324-
entry.rrf_score = (entry.rrf_score - min_score) / range;
325-
}
326-
} else if max_score > 0.0 {
327-
for entry in &mut entries {
328-
entry.rrf_score = 1.0;
329-
}
330-
}
331-
} else if entries.len() == 1 {
332-
entries[0].rrf_score = (entries[0].rrf_score * RRF_SCALE).min(1.0);
313+
for entry in &mut entries {
314+
entry.rrf_score = (entry.rrf_score * RRF_SCALE).clamp(0.0, 1.0);
333315
}
334316

335317
let score_range = fusion_score_range(&entries);
@@ -1160,48 +1142,64 @@ mod tests {
11601142

11611143
#[test]
11621144
fn test_rrf_normalize_multiple_results() {
1145+
let ideal = 1.0 / 61.0;
11631146
let entries = vec![
1164-
make_entry("best", 0.033),
1165-
make_entry("mid", 0.020),
1166-
make_entry("worst", 0.010),
1147+
make_entry("best", ideal),
1148+
make_entry("mid", ideal * 0.5),
1149+
make_entry("worst", ideal * 0.25),
11671150
];
11681151

11691152
let (result, stage) = RetrievalPipeline::stage_rrf_normalize(entries);
11701153
assert_eq!(stage.name, "rrf_normalize");
11711154
assert_eq!(result.len(), 3);
11721155

11731156
let best = result.iter().find(|e| e.memory.content == "best").unwrap();
1174-
let worst = result.iter().find(|e| e.memory.content == "worst").unwrap();
11751157
let mid = result.iter().find(|e| e.memory.content == "mid").unwrap();
1158+
let worst = result.iter().find(|e| e.memory.content == "worst").unwrap();
11761159

11771160
assert!(
1178-
(best.rrf_score - 1.0).abs() < 1e-6,
1179-
"best should be 1.0, got {}",
1161+
(best.rrf_score - 1.0).abs() < 1e-4,
1162+
"ideal RRF should map to ~1.0, got {}",
11801163
best.rrf_score
11811164
);
11821165
assert!(
1183-
(worst.rrf_score - 0.0).abs() < 1e-6,
1184-
"worst should be 0.0, got {}",
1166+
(mid.rrf_score - 0.5).abs() < 1e-4,
1167+
"half-ideal RRF should map to ~0.5, got {}",
1168+
mid.rrf_score
1169+
);
1170+
assert!(
1171+
(worst.rrf_score - 0.25).abs() < 1e-4,
1172+
"quarter-ideal RRF should map to ~0.25, got {}",
11851173
worst.rrf_score
11861174
);
1175+
assert!(best.rrf_score > mid.rrf_score && mid.rrf_score > worst.rrf_score);
1176+
}
1177+
1178+
#[test]
1179+
fn test_rrf_normalize_weak_top_not_inflated() {
1180+
let entries = vec![
1181+
make_entry("weak-top", 0.003),
1182+
make_entry("weak-mid", 0.002),
1183+
make_entry("weak-bot", 0.001),
1184+
];
1185+
1186+
let (result, _) = RetrievalPipeline::stage_rrf_normalize(entries);
1187+
let top = result.iter().find(|e| e.memory.content == "weak-top").unwrap();
11871188
assert!(
1188-
mid.rrf_score > 0.0 && mid.rrf_score < 1.0,
1189-
"mid should be between 0 and 1, got {}",
1190-
mid.rrf_score
1189+
top.rrf_score < 0.25,
1190+
"weak top result should stay below 0.25, got {}",
1191+
top.rrf_score
11911192
);
11921193
}
11931194

11941195
#[test]
11951196
fn test_rrf_normalize_single_result() {
1196-
let entries = vec![make_entry("only", 0.016)];
1197+
let entries = vec![make_entry("only", 1.0 / 61.0)];
11971198

11981199
let (result, _) = RetrievalPipeline::stage_rrf_normalize(entries);
11991200
assert_eq!(result.len(), 1);
12001201
let score = result[0].rrf_score;
1201-
assert!(
1202-
(score - 0.64).abs() < 1e-4,
1203-
"0.016 * 40 = 0.64, got {score}"
1204-
);
1202+
assert!((score - 1.0).abs() < 1e-4, "1/61 * 61 = 1.0, got {score}");
12051203
}
12061204

12071205
#[test]
@@ -1218,11 +1216,22 @@ mod tests {
12181216

12191217
#[test]
12201218
fn test_rrf_normalize_equal_scores() {
1221-
let entries = vec![make_entry("a", 0.016), make_entry("b", 0.016)];
1219+
let entries = vec![make_entry("a", 1.0 / 61.0), make_entry("b", 1.0 / 61.0)];
1220+
1221+
let (result, _) = RetrievalPipeline::stage_rrf_normalize(entries);
1222+
assert!((result[0].rrf_score - 1.0).abs() < 1e-4);
1223+
assert!((result[1].rrf_score - 1.0).abs() < 1e-4);
1224+
}
1225+
1226+
#[test]
1227+
fn test_rrf_normalize_equal_weak_scores() {
1228+
let entries = vec![make_entry("a", 0.005), make_entry("b", 0.005)];
12221229

12231230
let (result, _) = RetrievalPipeline::stage_rrf_normalize(entries);
1224-
assert!((result[0].rrf_score - 1.0).abs() < 1e-6);
1225-
assert!((result[1].rrf_score - 1.0).abs() < 1e-6);
1231+
let expected = 0.005_f32 * 61.0;
1232+
assert!((result[0].rrf_score - expected).abs() < 1e-4);
1233+
assert!((result[1].rrf_score - expected).abs() < 1e-4);
1234+
assert!(result[0].rrf_score < 0.5);
12261235
}
12271236

12281237
#[test]

0 commit comments

Comments
 (0)