Skip to content

Commit

Permalink
minimal test for querying collection
Browse files Browse the repository at this point in the history
  • Loading branch information
generall committed Jun 19, 2024
1 parent 6268291 commit 31842ce
Show file tree
Hide file tree
Showing 11 changed files with 298 additions and 65 deletions.
29 changes: 11 additions & 18 deletions src/builder_ext.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,8 @@
use crate::builder_types::RecommendExample;
use crate::qdrant::{
shard_key, BinaryQuantizationBuilder, ClearPayloadPointsBuilder, ContextExamplePair,
CountPointsBuilder, CreateAliasBuilder, CreateCollectionBuilder,
CreateFieldIndexCollectionBuilder, CreateShardKeyRequestBuilder, DeleteCollectionBuilder,
DeleteFieldIndexCollectionBuilder, DeletePayloadPointsBuilder, DeletePointVectorsBuilder,
DeletePointsBuilder, DeleteShardKey, DeleteShardKeyRequestBuilder, DiscoverBatchPointsBuilder,
DiscoverPoints, DiscoverPointsBuilder, Distance, GetPointsBuilder, LookupLocationBuilder,
OrderByBuilder, PayloadExcludeSelector, PayloadIncludeSelector, PointId, PointStruct,
PointVectors, PointsUpdateOperation, ProductQuantizationBuilder, QuantizationType,
RecommendBatchPointsBuilder, RecommendPointGroups, RecommendPointGroupsBuilder,
RecommendPoints, RecommendPointsBuilder, RenameAliasBuilder, ScalarQuantizationBuilder,
ScrollPointsBuilder, SearchBatchPointsBuilder, SearchPointGroupsBuilder, SearchPoints,
SearchPointsBuilder, SetPayloadPointsBuilder, ShardKey, TextIndexParamsBuilder, TokenizerType,
UpdateBatchPointsBuilder, UpdateCollectionBuilder, UpdateCollectionClusterSetupRequestBuilder,
UpdatePointVectorsBuilder, UpsertPointsBuilder, Value, VectorParamsBuilder, VectorsSelector,
WithLookupBuilder,
};
use std::collections::HashMap;

use crate::builder_types::RecommendExample;
use crate::qdrant::{BinaryQuantizationBuilder, ClearPayloadPointsBuilder, ContextExamplePair, CountPointsBuilder, CreateAliasBuilder, CreateCollectionBuilder, CreateFieldIndexCollectionBuilder, CreateShardKeyRequestBuilder, DeleteCollectionBuilder, DeleteFieldIndexCollectionBuilder, DeletePayloadPointsBuilder, DeletePointsBuilder, DeletePointVectorsBuilder, DeleteShardKey, DeleteShardKeyRequestBuilder, DiscoverBatchPointsBuilder, DiscoverPoints, DiscoverPointsBuilder, Distance, GetPointsBuilder, LookupLocationBuilder, OrderByBuilder, PayloadExcludeSelector, PayloadIncludeSelector, PointId, PointStruct, PointsUpdateOperation, PointVectors, ProductQuantizationBuilder, QuantizationType, QueryPointsBuilder, RecommendBatchPointsBuilder, RecommendPointGroups, RecommendPointGroupsBuilder, RecommendPoints, RecommendPointsBuilder, RenameAliasBuilder, ScalarQuantizationBuilder, ScrollPointsBuilder, SearchBatchPointsBuilder, SearchPointGroupsBuilder, SearchPoints, SearchPointsBuilder, SetPayloadPointsBuilder, shard_key, ShardKey, TextIndexParamsBuilder, TokenizerType, UpdateBatchPointsBuilder, UpdateCollectionBuilder, UpdateCollectionClusterSetupRequestBuilder, UpdatePointVectorsBuilder, UpsertPointsBuilder, Value, VectorParamsBuilder, VectorsSelector, WithLookupBuilder};

impl VectorParamsBuilder {
pub fn new(size: u64, distance: Distance) -> Self {
let mut builder = Self::empty();
Expand Down Expand Up @@ -469,3 +454,11 @@ impl RenameAliasBuilder {
builder
}
}

impl QueryPointsBuilder {
pub fn new(collection_name: impl Into<String>) -> Self {
let mut builder = Self::empty();
builder.collection_name = Some(collection_name.into());
builder
}
}
1 change: 1 addition & 0 deletions src/grpc_conversions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod extensions;
pub mod primitives;
mod query;

use crate::client::Payload;
use crate::qdrant::payload_index_params::IndexParams;
Expand Down
9 changes: 9 additions & 0 deletions src/grpc_conversions/query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// use crate::qdrant::{Query, query};
//
// impl From<Vec<f32>> for Query {
// fn from(v: Vec<f32>) -> Self {
// Query {
// variant: Some(query::Variant::Nearest()),
// }
// }
// }
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
//! [`Qdrant`](qdrant_client::Qdrant) client:
//! ```
//!# use qdrant_client::prelude::*;
//! use qdrant_client::qdrant_client::config::QdrantConfig;
//!# use qdrant_client::qdrant_client::Qdrant;
//!# use qdrant_client::qdrant_client::errors::QdrantError;
//!# fn establish_connection(url: &str) -> Result<Qdrant, QdrantError> {
//! let mut config = QdrantClientConfig::from_url(url);
//! let mut config = QdrantConfig::from_url(url);
//! config.api_key = std::env::var("QDRANT_API_KEY").ok();
//! Qdrant::new(Some(config))
//!# }
Expand Down Expand Up @@ -135,6 +136,7 @@ mod tests {
};
use crate::qdrant_client::Qdrant;
use std::collections::HashMap;
use crate::qdrant_client::config::QdrantConfig;

#[test]
fn display() {
Expand Down Expand Up @@ -184,7 +186,7 @@ mod tests {

#[tokio::test]
async fn test_qdrant_queries() -> anyhow::Result<()> {
let config = QdrantClientConfig::from_url("http://localhost:6334");
let config = QdrantConfig::from_url("http://localhost:6334");
let client = Qdrant::new(Some(config))?;

let health = client.health_check().await?;
Expand Down
8 changes: 8 additions & 0 deletions src/qdrant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4557,6 +4557,11 @@ pub struct PrefetchQuery {
pub lookup_from: ::core::option::Option<LookupLocation>,
}
#[derive(derive_builder::Builder)]
#[builder(
build_fn(private, name = "build_inner"),
pattern = "owned",
custom_constructor
)]
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct QueryPoints {
Expand All @@ -4566,6 +4571,7 @@ pub struct QueryPoints {
pub collection_name: ::prost::alloc::string::String,
/// Sub-requests to perform first. If present, the query will be performed on the results of the prefetches.
#[prost(message, repeated, tag = "2")]
#[builder(default, setter(into, strip_option), field(vis = "pub(crate)"))]
pub prefetch: ::prost::alloc::vec::Vec<PrefetchQuery>,
/// Query to perform. If missing, returns points ordered by their IDs.
#[prost(message, optional, tag = "3")]
Expand All @@ -4581,6 +4587,7 @@ pub struct QueryPoints {
pub filter: ::core::option::Option<Filter>,
/// Search params for when there is no prefetch.
#[prost(message, optional, tag = "6")]
#[builder(default, setter(into, strip_option), field(vis = "pub(crate)"))]
pub search_params: ::core::option::Option<SearchParams>,
/// Return points with scores better than this threshold.
#[prost(float, optional, tag = "7")]
Expand Down Expand Up @@ -8690,6 +8697,7 @@ builder_type_conversions!(RecommendBatchPoints, RecommendBatchPointsBuilder, tru
builder_type_conversions!(RecommendPointGroups, RecommendPointGroupsBuilder, true);
builder_type_conversions!(DiscoverPoints, DiscoverPointsBuilder, true);
builder_type_conversions!(DiscoverBatchPoints, DiscoverBatchPointsBuilder, true);
builder_type_conversions!(QueryPoints, QueryPointsBuilder, true);
builder_type_conversions!(CountPoints, CountPointsBuilder, true);
builder_type_conversions!(CreateFieldIndexCollection, CreateFieldIndexCollectionBuilder, true);
builder_type_conversions!(DeleteFieldIndexCollection, DeleteFieldIndexCollectionBuilder, true);
Expand Down
4 changes: 2 additions & 2 deletions src/qdrant_client/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ impl Qdrant {
#[cfg(test)]
mod tests {
use super::*;
use crate::client::QdrantClientConfig;
use crate::payload::Payload;
use crate::prelude::Distance;
use crate::qdrant::{
Expand All @@ -222,10 +221,11 @@ mod tests {
};
use std::time::Duration;
use tokio::time::sleep;
use crate::qdrant_client::config::QdrantConfig;

#[tokio::test]
async fn create_collection_and_do_the_search() -> Result<()> {
let config = QdrantClientConfig::from_url("http://localhost:6334");
let config = QdrantConfig::from_url("http://localhost:6334");
let client = Qdrant::new(Some(config))?;

let collection_name = "test2";
Expand Down
184 changes: 184 additions & 0 deletions src/qdrant_client/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
use crate::qdrant_client::Qdrant;
use std::time::Duration;
use crate::prelude::QdrantError;

pub struct QdrantConfig {
pub uri: String,
pub timeout: Duration,
pub connect_timeout: Duration,
pub keep_alive_while_idle: bool,

/// API key or token to use for authorization
pub api_key: Option<String>,
pub compression: Option<CompressionEncoding>,
}

impl QdrantConfig {
pub fn from_url(url: &str) -> Self {
QdrantConfig {
uri: url.to_string(),
..Self::default()
}
}

/// Sets the API key or token
pub fn set_api_key(&mut self, api_key: &str) {
self.api_key = Some(api_key.to_string());
}

pub fn set_timeout(&mut self, timeout: Duration) {
self.timeout = timeout;
}

pub fn set_connect_timeout(&mut self, connect_timeout: Duration) {
self.connect_timeout = connect_timeout;
}

pub fn set_keep_alive_while_idle(&mut self, keep_alive_while_idle: bool) {
self.keep_alive_while_idle = keep_alive_while_idle;
}

pub fn set_compression(&mut self, compression: Option<CompressionEncoding>) {
self.compression = compression;
}

/// set the API key, builder-like. The API key argument can be any of
/// `&str`, `String`, `Option<&str>``, `Option<String>` or `Result<String>`.`
///
/// # Examples:
///
/// A typical use case might be getting the key from an env var:
/// ```rust, no_run
/// use qdrant_client::prelude::*;
///
/// let client = Qdrant::from_url("localhost:6334")
/// .with_api_key(std::env::var("QDRANT_API_KEY"))
/// .build();
/// ```
/// Another possibility might be getting it out of some config
/// ```rust, no_run
/// use qdrant_client::prelude::*;
///# use std::collections::HashMap;
///# let config: HashMap<&str, String> = HashMap::new();
/// let client = QdrantConfig::from_url("localhost:6334")
/// .with_api_key(config.get("api_key"))
/// .build();
/// ```
pub fn with_api_key(mut self, api_key: impl MaybeApiKey) -> Self {
self.api_key = api_key.maybe_key();
self
}

/// Configure the service to keep the connection alive while idle
pub fn keep_alive_while_idle(mut self) -> Self {
self.keep_alive_while_idle = true;
self
}

/// Set the timeout for this client
pub fn with_timeout(mut self, timeout: impl AsTimeout) -> Self {
self.timeout = timeout.timeout();
self
}

/// Set the connect timeout for this client
pub fn with_connect_timeout(mut self, timeout: impl AsTimeout) -> Self {
self.connect_timeout = timeout.timeout();
self
}

/// Set the compression to use for this client
pub fn with_compression(mut self, compression: Option<CompressionEncoding>) -> Self {
self.compression = compression;
self
}

/// Build the Qdrant
pub fn build(self) -> Result<Qdrant, QdrantError> {
Qdrant::new(Some(self))
}
}

impl Default for QdrantConfig {
fn default() -> Self {
Self {
uri: String::from("http://localhost:6334"),
timeout: Duration::from_secs(5),
connect_timeout: Duration::from_secs(5),
keep_alive_while_idle: true,
api_key: None,
compression: None,
}
}
}

/// The type of compression to use for requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionEncoding {
Gzip,
}

impl From<CompressionEncoding> for tonic::codec::CompressionEncoding {
fn from(encoding: CompressionEncoding) -> Self {
match encoding {
CompressionEncoding::Gzip => tonic::codec::CompressionEncoding::Gzip,
}
}
}

pub trait AsTimeout {
fn timeout(self) -> Duration;
}

impl AsTimeout for Duration {
fn timeout(self) -> Duration {
self
}
}

impl AsTimeout for u64 {
fn timeout(self) -> Duration {
Duration::from_secs(self)
}
}

/// Helper thread to allow setting an API key from various types
pub trait MaybeApiKey {
fn maybe_key(self) -> Option<String>;
}

impl MaybeApiKey for &str {
fn maybe_key(self) -> Option<String> {
Some(self.to_string())
}
}

impl MaybeApiKey for String {
fn maybe_key(self) -> Option<String> {
Some(self)
}
}

impl MaybeApiKey for Option<String> {
fn maybe_key(self) -> Option<String> {
self
}
}

impl MaybeApiKey for Option<&String> {
fn maybe_key(self) -> Option<String> {
self.map(ToOwned::to_owned)
}
}

impl MaybeApiKey for Option<&str> {
fn maybe_key(self) -> Option<String> {
self.map(ToOwned::to_owned)
}
}

impl<E: Sized> MaybeApiKey for Result<String, E> {
fn maybe_key(self) -> Option<String> {
self.ok()
}
}
15 changes: 9 additions & 6 deletions src/qdrant_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ pub mod errors;
mod points;
pub mod sharding_keys;
pub mod snapshot;
mod query;
pub mod config;

use crate::channel_pool::ChannelPool;
use crate::qdrant::{qdrant_client, HealthCheckReply, HealthCheckRequest};
Expand All @@ -12,27 +14,28 @@ use tonic::transport::{Channel, Uri};
use tonic::Status;

pub use crate::auth::TokenInterceptor;
pub use crate::config::{AsTimeout, CompressionEncoding, MaybeApiKey, QdrantClientConfig};
pub use crate::config::{AsTimeout, CompressionEncoding, MaybeApiKey};
pub use crate::payload::Payload;
use crate::qdrant_client::config::QdrantConfig;
use crate::qdrant_client::errors::QdrantError;

pub type Result<T> = std::result::Result<T, QdrantError>;

/// A builder type for `QdrantClient`s
pub type QdrantClientBuilder = QdrantClientConfig;
pub type QdrantBuilder = QdrantConfig;

pub struct Qdrant {
pub channel: ChannelPool,
pub cfg: QdrantClientConfig,
pub cfg: QdrantConfig,
}

impl Qdrant {
/// Create a builder to setup the client
pub fn from_url(url: &str) -> QdrantClientBuilder {
QdrantClientBuilder::from_url(url)
pub fn from_url(url: &str) -> QdrantBuilder {
QdrantBuilder::from_url(url)
}

pub fn new(cfg: Option<QdrantClientConfig>) -> Result<Self> {
pub fn new(cfg: Option<QdrantConfig>) -> Result<Self> {
let cfg = cfg.unwrap_or_default();

let channel = ChannelPool::new(
Expand Down
Loading

0 comments on commit 31842ce

Please sign in to comment.