Skip to content

Commit

Permalink
Allow to include vector into search result (#176)
Browse files Browse the repository at this point in the history
* feat(#50): include vector into search result

allow to specify 'with_vector' parameter in search api to get search results vector data

* test(#50): fix tests

* chore(#50): apply cargo fmt

* chore(#50): update api docs

run tools/generate_openapi_models.sh

Co-authored-by: Daniil Sunyaev <dasforrum@gmail.com>
  • Loading branch information
Daniil and daniilsunyaev committed Jan 3, 2022
1 parent b53c752 commit 1ad529c
Show file tree
Hide file tree
Showing 13 changed files with 107 additions and 22 deletions.
16 changes: 15 additions & 1 deletion docs/redoc/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1530,7 +1530,7 @@
"additionalProperties": {
"$ref": "#/components/schemas/PayloadType"
},
"description": "Payload storage",
"description": "Payload - values assigned to the point",
"nullable": true,
"type": "object"
},
Expand All @@ -1539,6 +1539,15 @@
"format": "float",
"type": "number"
},
"vector": {
"description": "Vector of the point",
"items": {
"format": "float",
"type": "number"
},
"nullable": true,
"type": "array"
},
"version": {
"description": "Point version",
"format": "uint64",
Expand Down Expand Up @@ -1685,6 +1694,11 @@
}
],
"description": "Payload interface"
},
"with_vector": {
"description": "Return point vector with the result. Default: false",
"nullable": true,
"type": "boolean"
}
},
"required": [
Expand Down
1 change: 1 addition & 0 deletions lib/collection/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ impl Collection {
})]),
}),
with_payload: None,
with_vector: None,
params: request.params,
top: request.top,
};
Expand Down
36 changes: 26 additions & 10 deletions lib/collection/src/collection_manager/holders/proxy_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ impl SegmentEntry for ProxySegment {
&self,
vector: &[VectorElementType],
with_payload: &WithPayload,
with_vector: bool,
filter: Option<&Filter>,
top: usize,
params: Option<&SearchParams>,
Expand All @@ -137,22 +138,30 @@ impl SegmentEntry for ProxySegment {
self.wrapped_segment.get().read().search(
vector,
with_payload,
with_vector,
Some(&wrapped_filter),
top,
params,
)?
} else {
self.wrapped_segment
.get()
.read()
.search(vector, with_payload, filter, top, params)?
self.wrapped_segment.get().read().search(
vector,
with_payload,
with_vector,
filter,
top,
params,
)?
};

let mut write_result =
self.write_segment
.get()
.read()
.search(vector, with_payload, filter, top, params)?;
let mut write_result = self.write_segment.get().read().search(
vector,
with_payload,
with_vector,
filter,
top,
params,
)?;

wrapped_result.append(&mut write_result);
Ok(wrapped_result)
Expand Down Expand Up @@ -452,7 +461,14 @@ mod tests {

let query_vector = vec![1.0, 1.0, 1.0, 1.0];
let search_result = proxy_segment
.search(&query_vector, &WithPayload::default(), None, 10, None)
.search(
&query_vector,
&WithPayload::default(),
false,
None,
10,
None,
)
.unwrap();

eprintln!("search_result = {:#?}", search_result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,13 @@ async fn search_in_segment(
.as_ref()
.unwrap_or(&WithPayloadInterface::Bool(false));
let with_payload = WithPayload::from(with_payload_interface);

let with_vector = request.with_vector.unwrap_or(false);

let res = segment.get().read().search(
&request.vector,
&with_payload,
with_vector,
request.filter.as_ref(),
request.top,
request.params.as_ref(),
Expand Down Expand Up @@ -171,6 +175,7 @@ mod tests {
let req = Arc::new(SearchRequest {
vector: query,
with_payload: None,
with_vector: None,
filter: None,
params: None,
top: 5,
Expand Down
2 changes: 2 additions & 0 deletions lib/collection/src/operations/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ pub struct SearchRequest {
pub top: usize,
/// Payload interface
pub with_payload: Option<WithPayloadInterface>,
/// Return point vector with the result. Default: false
pub with_vector: Option<bool>,
}

/// Recommendation request.
Expand Down
11 changes: 5 additions & 6 deletions lib/collection/tests/collection_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async fn test_collection_updater() {
let search_request = SearchRequest {
vector: vec![1.0, 1.0, 1.0, 1.0],
with_payload: None,
with_vector: None,
filter: None,
params: None,
top: 3,
Expand All @@ -76,7 +77,7 @@ async fn test_collection_updater() {
}

#[tokio::test]
async fn test_collection_search_with_payload() {
async fn test_collection_search_with_payload_and_vector() {
let collection_dir = TempDir::new("collection").unwrap();

let collection = simple_collection_fixture(collection_dir.path()).await;
Expand All @@ -103,6 +104,7 @@ async fn test_collection_search_with_payload() {
let search_request = SearchRequest {
vector: vec![1.0, 0.0, 1.0, 1.0],
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: Some(true),
filter: None,
params: None,
top: 3,
Expand All @@ -121,11 +123,8 @@ async fn test_collection_search_with_payload() {
Ok(res) => {
assert_eq!(res.len(), 2);
assert_eq!(res[0].id, 0);
if let Some(payload) = &res[0].payload {
assert_eq!(payload.len(), 1)
} else {
panic!("Payload was expected")
}
assert_eq!(res[0].payload.as_ref().unwrap().len(), 1);
assert_eq!(&res[0].vector, &Some(vec![1.0, 0.0, 1.0, 1.0]));
}
Err(err) => panic!("search failed: {:?}", err),
}
Expand Down
1 change: 1 addition & 0 deletions lib/segment/src/entry/entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pub trait SegmentEntry {
&self,
vector: &[VectorElementType],
with_payload: &WithPayload,
with_vector: bool,
filter: Option<&Filter>,
top: usize,
params: Option<&SearchParams>,
Expand Down
11 changes: 11 additions & 0 deletions lib/segment/src/segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ impl SegmentEntry for Segment {
&self,
vector: &[VectorElementType],
with_payload: &WithPayload,
with_vector: bool,
filter: Option<&Filter>,
top: usize,
params: Option<&SearchParams>,
Expand Down Expand Up @@ -262,11 +263,19 @@ impl SegmentEntry for Segment {
} else {
None
};

let vector = if with_vector {
Some(self.vector(point_id)?)
} else {
None
};

Ok(ScoredPoint {
id: point_id,
version: point_version,
score: scored_point_offset.score,
payload,
vector,
})
})
.collect();
Expand Down Expand Up @@ -678,6 +687,7 @@ mod tests {
.search(
&[1.0, 1.0],
&WithPayload::default(),
false,
Some(&filter_valid),
1,
None,
Expand All @@ -689,6 +699,7 @@ mod tests {
.search(
&[1.0, 1.0],
&WithPayload::default(),
false,
Some(&filter_invalid),
1,
None,
Expand Down
4 changes: 3 additions & 1 deletion lib/segment/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ pub struct ScoredPoint {
pub version: SeqNumberType,
/// Points vector distance to the query vector
pub score: ScoreType,
/// Payload storage
/// Payload - values assigned to the point
pub payload: Option<TheMap<PayloadKeyType, PayloadType>>,
/// Vector of the point
pub vector: Option<Vec<VectorElementType>>,
}

impl Eq for ScoredPoint {}
Expand Down
2 changes: 2 additions & 0 deletions lib/segment/tests/payload_index_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ mod tests {
.search(
&query_vector,
&WithPayload::default(),
false,
Some(&query_filter),
5,
None,
Expand All @@ -156,6 +157,7 @@ mod tests {
.search(
&query_vector,
&WithPayload::default(),
false,
Some(&query_filter),
5,
None,
Expand Down
11 changes: 9 additions & 2 deletions lib/segment/tests/segment_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod tests {
let query_vector = vec![1.0, 1.0, 1.0, 1.0];

let res = segment
.search(&query_vector, &WithPayload::default(), None, 1, None)
.search(&query_vector, &WithPayload::default(), false, None, 1, None)
.unwrap();

let best_match = res.get(0).expect("Non-empty result");
Expand All @@ -35,7 +35,14 @@ mod tests {
};

let res = segment
.search(&query_vector, &WithPayload::default(), Some(&frt), 1, None)
.search(
&query_vector,
&WithPayload::default(),
false,
Some(&frt),
1,
None,
)
.unwrap();

let best_match = res.get(0).expect("Non-empty result");
Expand Down
16 changes: 15 additions & 1 deletion openapi/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -1510,7 +1510,7 @@
"minimum": 0
},
"payload": {
"description": "Payload storage",
"description": "Payload - values assigned to the point",
"type": "object",
"additionalProperties": {
"$ref": "#/components/schemas/PayloadType"
Expand All @@ -1522,6 +1522,15 @@
"type": "number",
"format": "float"
},
"vector": {
"description": "Vector of the point",
"type": "array",
"items": {
"type": "number",
"format": "float"
},
"nullable": true
},
"version": {
"description": "Point version",
"type": "integer",
Expand Down Expand Up @@ -1667,6 +1676,11 @@
"nullable": true
}
]
},
"with_vector": {
"description": "Return point vector with the result. Default: false",
"type": "boolean",
"nullable": true
}
}
},
Expand Down
13 changes: 12 additions & 1 deletion openapi/openapi-merged.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1762,13 +1762,20 @@ components:
payload:
additionalProperties:
$ref: '#/components/schemas/PayloadType'
description: Payload storage
description: Payload - values assigned to the point
nullable: true
type: object
score:
description: Points vector distance to the query vector
format: float
type: number
vector:
description: Vector of the point
items:
format: float
type: number
nullable: true
type: array
version:
description: Point version
format: uint64
Expand Down Expand Up @@ -1865,6 +1872,10 @@ components:
- $ref: '#/components/schemas/WithPayloadInterface'
- nullable: true
description: Payload interface
with_vector:
description: 'Return point vector with the result. Default: false'
nullable: true
type: boolean
required:
- top
- vector
Expand Down

0 comments on commit 1ad529c

Please sign in to comment.