Filter "resharded" points from search, scroll by, count and retrieve …
Browse files Browse the repository at this point in the history
…request results
  • Loading branch information
ffuugoo committed Jun 18, 2024
1 parent 76255ab commit 8739505
Showing 6 changed files with 101 additions and 10 deletions.
6 changes: 5 additions & 1 deletion lib/api/src/grpc/conversions.rs
@@ -939,7 +939,11 @@ fn conditions_helper_to_grpc(conditions: Option<Vec<segment::types::Condition>>)
             if conditions.is_empty() {
                 vec![]
             } else {
-                conditions.into_iter().map(|c| c.into()).collect()
+                conditions
+                    .into_iter()
+                    .filter(|c| !matches!(c, segment::types::Condition::Resharding(_))) // TODO(resharding)!?
+                    .map(|c| c.into())
+                    .collect()
             }
         }
     }
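
The conversion above drops the service-internal `Resharding` condition, since the public gRPC schema has no representation for it. Below is a minimal sketch of the same strip-before-serialization pattern; the `Condition` enum and `strip_internal` helper are illustrative stand-ins, not the real `segment` or gRPC types:

```rust
// Hypothetical mirror of the pattern: internal-only variants are removed
// before a filter crosses the gRPC boundary.
#[derive(Debug)]
enum Condition {
    Field(String),   // representable in the public API
    Resharding(u32), // internal-only, must not leak into gRPC
}

fn strip_internal(conditions: Vec<Condition>) -> Vec<Condition> {
    conditions
        .into_iter()
        .filter(|c| !matches!(c, Condition::Resharding(_)))
        .collect()
}

fn main() {
    let conditions = vec![Condition::Field("city".into()), Condition::Resharding(7)];
    let sent = strip_internal(conditions);
    assert_eq!(sent.len(), 1); // only the Field condition survives
}
```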
41 changes: 38 additions & 3 deletions lib/collection/src/collection/point_ops.rs
@@ -5,7 +5,7 @@ use futures::stream::FuturesUnordered;
 use futures::{future, StreamExt as _, TryFutureExt, TryStreamExt as _};
 use itertools::Itertools;
 use segment::data_types::order_by::{Direction, OrderBy};
-use segment::types::{ShardKey, WithPayload, WithPayloadInterface};
+use segment::types::{Filter, ShardKey, WithPayload, WithPayloadInterface};
 use validator::Validate as _;
 
 use super::Collection;
@@ -207,10 +207,15 @@ impl Collection {
 
     pub async fn scroll_by(
         &self,
-        request: ScrollRequestInternal,
+        mut request: ScrollRequestInternal,
         read_consistency: Option<ReadConsistency>,
         shard_selection: &ShardSelectorInternal,
     ) -> CollectionResult<ScrollResult> {
+        merge_filters(
+            &mut request.filter,
+            self.shards_holder.read().await.resharding_filter(),
+        );
+
         let default_request = ScrollRequestInternal::default();
 
         let id_offset = request.offset;
@@ -333,10 +338,15 @@ impl Collection {
 
     pub async fn count(
         &self,
-        request: CountRequestInternal,
+        mut request: CountRequestInternal,
        read_consistency: Option<ReadConsistency>,
         shard_selection: &ShardSelectorInternal,
     ) -> CollectionResult<CountResult> {
+        merge_filters(
+            &mut request.filter,
+            self.shards_holder.read().await.resharding_filter(),
+        );
+
         let shards_holder = self.shards_holder.read().await;
         let shards = shards_holder.select_shards(shard_selection)?;
 
@@ -374,8 +384,16 @@
             .unwrap_or(&WithPayloadInterface::Bool(false));
         let with_payload = WithPayload::from(with_payload_interface);
         let request = Arc::new(request);
+
+        #[allow(unused_assignments)] // 🤦‍♀️
+        let mut resharding_filter = None;
+
         let all_shard_collection_results = {
             let shard_holder = self.shards_holder.read().await;
+
+            // Get the resharding filter while we hold the shard holder lock
+            resharding_filter = shard_holder.resharding_filter_impl();
+
             let target_shards = shard_holder.select_shards(shard_selection)?;
             let retrieve_futures = target_shards.into_iter().map(|(shard, shard_key)| {
                 let shard_key = shard_key.cloned();
@@ -397,15 +415,32 @@
                     Ok(records)
                 })
             });
 
             future::try_join_all(retrieve_futures).await?
         };
 
+        let mut covered_point_ids = HashSet::new();
         let points = all_shard_collection_results
             .into_iter()
             .flatten()
+            // If resharding is in progress, and the *read* hash-ring is committed, filter out "resharded" points
+            .filter(|point| match &resharding_filter {
+                Some(filter) => filter.check(point.id),
+                None => true,
+            })
+            // Add each point only once, deduplicate point IDs
+            .filter(|point| covered_point_ids.insert(point.id))
             .collect();
 
         Ok(points)
     }
 }
+
+fn merge_filters(filter: &mut Option<Filter>, resharding_filter: Option<Filter>) {
+    if let Some(resharding_filter) = resharding_filter {
+        *filter = Some(match filter.take() {
+            Some(filter) => filter.merge_owned(resharding_filter),
+            None => resharding_filter,
+        });
+    }
+}
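
The new `merge_filters` helper ANDs the resharding filter into whatever filter the request already carries, and leaves the request untouched when no resharding filter is active. A runnable sketch of the three cases, using a simplified stand-in `Filter` rather than the real `segment::types::Filter`:

```rust
// Stand-in filter: merge_owned concatenates clauses, assumed here to mean
// "both sets of conditions must hold".
#[derive(Debug, PartialEq)]
struct Filter(Vec<&'static str>);

impl Filter {
    fn merge_owned(mut self, other: Filter) -> Filter {
        self.0.extend(other.0);
        self
    }
}

fn merge_filters(filter: &mut Option<Filter>, resharding_filter: Option<Filter>) {
    if let Some(resharding_filter) = resharding_filter {
        *filter = Some(match filter.take() {
            Some(filter) => filter.merge_owned(resharding_filter),
            None => resharding_filter,
        });
    }
}

fn main() {
    // Request filter present: the resharding clause is merged in.
    let mut f = Some(Filter(vec!["user"]));
    merge_filters(&mut f, Some(Filter(vec!["resharding"])));
    assert_eq!(f, Some(Filter(vec!["user", "resharding"])));

    // No request filter: the resharding filter applies alone.
    let mut f = None;
    merge_filters(&mut f, Some(Filter(vec!["resharding"])));
    assert_eq!(f, Some(Filter(vec!["resharding"])));

    // No resharding in progress: the request is left untouched.
    let mut f = Some(Filter(vec!["user"]));
    merge_filters(&mut f, None);
    assert_eq!(f, Some(Filter(vec!["user"])));
}
```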
16 changes: 15 additions & 1 deletion lib/collection/src/collection/search.rs
@@ -125,11 +125,25 @@ impl Collection {
 
     async fn do_core_search_batch(
         &self,
-        request: CoreSearchRequestBatch,
+        mut request: CoreSearchRequestBatch,
         read_consistency: Option<ReadConsistency>,
         shard_selection: &ShardSelectorInternal,
         timeout: Option<Duration>,
     ) -> CollectionResult<Vec<Vec<ScoredPoint>>> {
+        if let Some(resharding_filter) = self.shards_holder.read().await.resharding_filter() {
+            for search in &mut request.searches {
+                match &mut search.filter {
+                    Some(filter) => {
+                        *filter = filter.merge(&resharding_filter);
+                    }
+
+                    None => {
+                        search.filter = Some(resharding_filter.clone());
+                    }
+                }
+            }
+        }
+
         let request = Arc::new(request);
 
         let instant = Instant::now();
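
`do_core_search_batch` applies the same merge per request: the resharding filter is merged into every search that already has a filter and cloned into every search that does not. A stand-in sketch of that loop; `Filter` and `Search` here are placeholders, not the real collection types:

```rust
#[derive(Clone, Debug, PartialEq)]
struct Filter(Vec<&'static str>);

impl Filter {
    // Borrowing merge, mirroring the Filter::merge call in the diff above.
    fn merge(&self, other: &Filter) -> Filter {
        let mut merged = self.0.clone();
        merged.extend_from_slice(&other.0);
        Filter(merged)
    }
}

struct Search {
    filter: Option<Filter>,
}

fn main() {
    let resharding_filter = Filter(vec!["resharding"]);
    let mut searches = vec![
        Search { filter: Some(Filter(vec!["user"])) },
        Search { filter: None },
    ];

    // Merge into existing filters; clone only where a search has none yet.
    for search in &mut searches {
        match &mut search.filter {
            Some(filter) => *filter = filter.merge(&resharding_filter),
            None => search.filter = Some(resharding_filter.clone()),
        }
    }

    assert_eq!(searches[0].filter, Some(Filter(vec!["user", "resharding"])));
    assert_eq!(searches[1].filter, Some(resharding_filter));
}
```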
2 changes: 2 additions & 0 deletions lib/collection/src/shards/resharding/mod.rs
@@ -33,6 +33,7 @@ pub struct ReshardState {
     pub peer_id: PeerId,
     pub shard_id: ShardId,
     pub shard_key: Option<ShardKey>,
+    pub filter_read_operations: bool, // TODO(resharding): Add proper resharding state!
 }
 
 impl ReshardState {
@@ -41,6 +42,7 @@ impl ReshardState {
             peer_id,
             shard_id,
             shard_key,
+            filter_read_operations: false,
         }
     }
 
42 changes: 39 additions & 3 deletions lib/collection/src/shards/shard_holder.rs
@@ -8,7 +8,7 @@ use std::time::Duration;
 use common::cpu::CpuBudget;
 use futures::Future;
 use itertools::Itertools;
-use segment::types::ShardKey;
+use segment::types::{Condition, Filter, ShardKey};
 use tar::Builder as TarBuilder;
 use tokio::runtime::Handle;
 use tokio::sync::{broadcast, RwLock};
@@ -18,7 +18,7 @@ use super::resharding::{ReshardKey, ReshardState};
 use super::transfer::transfer_tasks_pool::TransferTasksPool;
 use crate::common::validate_snapshot_archive::validate_open_snapshot_archive;
 use crate::config::{CollectionConfig, ShardingMethod};
-use crate::hash_ring::HashRing;
+use crate::hash_ring::{self, HashRing};
 use crate::operations::shard_selector_internal::ShardSelectorInternal;
 use crate::operations::shared_storage_config::SharedStorageConfig;
 use crate::operations::snapshot_ops::SnapshotDescription;
@@ -236,7 +236,15 @@ impl ShardHolder {
             Ok(())
         })?;
 
-        todo!()
+        self.resharding_state.write(|state| {
+            let Some(state) = state else {
+                unreachable!();
+            };
+
+            state.filter_read_operations = true; // TODO(resharding): Add proper resharding state!
+        })?;
+
+        Ok(())
     }
 
     pub fn commit_write_hashring(&mut self, resharding_key: ReshardKey) -> CollectionResult<()> {
@@ -353,6 +361,34 @@ impl ShardHolder {
         Ok(())
     }
 
+    pub fn resharding_filter(&self) -> Option<Filter> {
+        let filter = self.resharding_filter_impl()?;
+        let filter = Filter::new_must_not(Condition::Resharding(Arc::new(filter)));
+        Some(filter)
+    }
+
+    pub fn resharding_filter_impl(&self) -> Option<hash_ring::Filter> {
+        let state = self.resharding_state.read();
+
+        let Some(state) = state.deref() else {
+            return None;
+        };
+
+        if !state.filter_read_operations {
+            return None;
+        }
+
+        let Some(ring) = self.rings.get(&state.shard_key) else {
+            return None; // TODO(resharding)!?
+        };
+
+        let HashRing::Resharding { new, .. } = ring else {
+            return None; // TODO(resharding)!?
+        };
+
+        Some(hash_ring::Filter::new(new.clone(), state.shard_id))
+    }
+
     pub fn add_shard(
         &mut self,
         shard_id: ShardId,
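
`resharding_filter_impl` only yields a filter while a resharding state exists, read-operation filtering is enabled, and the relevant ring is in its `Resharding` state; the filter it builds pairs the *new* hash ring with the resharding target shard. A simplified sketch of what such a point-level filter is assumed to do — the modulo ring below is a placeholder for the real consistent-hashing `HashRing`, and `check` is assumed to report whether a point id lands on the target shard:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

type ShardId = u32;

// Placeholder ring: real hash rings use consistent hashing so points barely
// move when shards are added; plain modulo is enough to illustrate routing.
struct HashRing {
    shards: Vec<ShardId>,
}

impl HashRing {
    fn get(&self, point_id: u64) -> Option<&ShardId> {
        if self.shards.is_empty() {
            return None;
        }
        let mut hasher = DefaultHasher::new();
        point_id.hash(&mut hasher);
        let idx = (hasher.finish() % self.shards.len() as u64) as usize;
        self.shards.get(idx)
    }
}

struct Filter {
    ring: HashRing,
    shard_id: ShardId,
}

impl Filter {
    fn new(ring: HashRing, shard_id: ShardId) -> Self {
        Self { ring, shard_id }
    }

    // True when the point id maps to the resharding target shard on this ring.
    fn check(&self, point_id: u64) -> bool {
        self.ring.get(point_id) == Some(&self.shard_id)
    }
}

fn main() {
    let filter = Filter::new(HashRing { shards: vec![0, 1, 2, 3] }, 3);
    for point_id in 0u64..8 {
        println!("point {point_id} -> target shard? {}", filter.check(point_id));
    }
}
```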
@@ -1,4 +1,4 @@
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
 
 use common::types::PointOffsetType;
 use serde_json::Value;
@@ -124,7 +124,7 @@ pub fn condition_converter<'a>(
         Condition::Resharding(cond) => {
             let segment_ids: HashSet<_> = id_tracker
                 .iter_external()
-                .filter(|point_id| cond.check(point_id))
+                .filter(|&point_id| cond.check(point_id))
                 .filter_map(|external_id| id_tracker.internal_id(external_id))
                 .collect();
 
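
The one-character fix above switches the closure to a `|&point_id|` pattern, which copies the id out of the reference that `Iterator::filter` supplies, so `cond.check` can take it by value. The same idiom in isolation:

```rust
fn is_even(n: u64) -> bool {
    n % 2 == 0
}

fn main() {
    // filter() hands the closure a &u64; the |&n| pattern dereferences it,
    // so is_even receives n by value.
    let evens: Vec<u64> = (0u64..6).filter(|&n| is_even(n)).collect();
    assert_eq!(evens, vec![0, 2, 4]);
}
```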
