Skip to content

Commit

Permalink
Fix function score query
Browse files Browse the repository at this point in the history
  • Loading branch information
buinauskas committed Jun 19, 2023
1 parent 4a71715 commit 912a2cf
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 20 deletions.
131 changes: 114 additions & 17 deletions src/search/queries/compound/function_score_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ use crate::util::*;
/// # use elasticsearch_dsl::queries::*;
/// # use elasticsearch_dsl::queries::params::*;
/// # let query =
/// Query::function_score(Query::term("test", 1))
/// .function(RandomScore::new())
/// Query::function_score()
/// .query(Query::term("test", 1))
/// .function(RandomScore::new().filter(Query::term("test", 1)).weight(2.0))
/// .function(Weight::new(2.0))
/// .max_boost(2.2)
/// .min_score(2.3)
Expand All @@ -28,7 +29,8 @@ use crate::util::*;
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(remote = "Self")]
pub struct FunctionScoreQuery {
query: Box<Query>,
#[serde(skip_serializing_if = "ShouldSkip::should_skip")]
query: Option<Box<Query>>,

#[serde(skip_serializing_if = "ShouldSkip::should_skip")]
functions: Vec<Function>,
Expand All @@ -54,12 +56,9 @@ pub struct FunctionScoreQuery {

impl Query {
/// Creates an instance of [`FunctionScoreQuery`]
pub fn function_score<T>(query: T) -> FunctionScoreQuery
where
T: Into<Query>,
{
pub fn function_score() -> FunctionScoreQuery {
FunctionScoreQuery {
query: Box::new(query.into()),
query: None,
functions: Default::default(),
max_boost: None,
min_score: None,
Expand All @@ -72,6 +71,15 @@ impl Query {
}

impl FunctionScoreQuery {
/// Base function score query
pub fn query<T>(mut self, query: T) -> Self
where
T: Into<Option<Query>>,
{
self.query = query.into().map(Box::new);
self
}

/// Push function to the list
pub fn function<T>(mut self, function: T) -> Self
where
Expand Down Expand Up @@ -139,16 +147,9 @@ mod tests {
#[test]
fn serialization() {
assert_serialize_query(
Query::function_score(Query::term("test", 1)).function(RandomScore::new()),
Query::function_score().function(RandomScore::new()),
json!({
"function_score": {
"query": {
"term": {
"test": {
"value": 1
}
}
},
"functions": [
{
"random_score": {}
Expand All @@ -159,7 +160,8 @@ mod tests {
);

assert_serialize_query(
Query::function_score(Query::term("test", 1))
Query::function_score()
.query(Query::term("test", 1))
.function(RandomScore::new())
.function(Weight::new(2.0))
.max_boost(2.2)
Expand Down Expand Up @@ -195,4 +197,99 @@ mod tests {
}),
);
}

#[test]
fn issue_24() {
let json = json!({
"function_score": {
"boost_mode": "replace",
"functions": [
{
"filter": { "term": { "type": "stop" } },
"field_value_factor": {
"field": "weight",
"factor": 1.0,
"missing": 1.0
},
"weight": 1.0
},
{
"filter": { "term": { "type": "address" } },
"filter": { "term": { "type": "addr" } },
"field_value_factor": {
"field": "weight",
"factor": 1.0,
"missing": 1.0
},
"weight": 1.0
},
{
"filter": { "term": { "type": "admin" } },
"field_value_factor": {
"field": "weight",
"factor": 1.0,
"missing": 1.0
},
"weight": 1.0
},
{
"filter": { "term": { "type": "poi" } },
"field_value_factor": {
"field": "weight",
"factor": 1.0,
"missing": 1.0
},
"weight": 1.0
},
{
"filter": { "term": { "type": "street" } },
"field_value_factor": {
"field": "weight",
"factor": 1.0,
"missing": 1.0
},
"weight": 1.0
}
]
}
});

let _ = Query::function_score()
.boost_mode(FunctionBoostMode::Replace)
.function(
FieldValueFactor::new("weight")
.factor(1.0)
.missing(1.0)
.weight(1.0)
.filter(Query::term("type", "stop")),
)
.function(
FieldValueFactor::new("weight")
.factor(1.0)
.missing(1.0)
.weight(1.0)
.filter(Query::terms("type", ["address", "addr"])),
)
.function(
FieldValueFactor::new("weight")
.factor(1.0)
.missing(1.0)
.weight(1.0)
.filter(Query::term("type", "admin")),
)
.function(
FieldValueFactor::new("weight")
.factor(1.0)
.missing(1.0)
.weight(1.0)
.filter(Query::term("type", "poi")),
)
.function(
FieldValueFactor::new("weight")
.factor(1.0)
.missing(1.0)
.weight(1.0)
.filter(Query::term("type", "street")),
);
}
}
109 changes: 106 additions & 3 deletions src/search/queries/params/function_score_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,29 @@ impl Function {
///
/// This can sometimes be desired since boost value set on specific queries gets normalized, while
/// for this score function it does not
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Serialize)]
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct Weight {
weight: f32,
#[serde(skip_serializing_if = "ShouldSkip::should_skip")]
filter: Option<Query>,
}

impl Weight {
/// Creates an instance of [Weight](Weight)
pub fn new(weight: f32) -> Self {
Self { weight }
Self {
weight,
filter: None,
}
}

/// Add function filter
pub fn filter<T>(mut self, filter: T) -> Self
where
T: Into<Option<Query>>,
{
self.filter = filter.into();
self
}
}

Expand All @@ -202,6 +216,12 @@ impl Weight {
#[derive(Debug, Default, Clone, PartialEq, Serialize)]
pub struct RandomScore {
random_score: RandomScoreInner,

#[serde(skip_serializing_if = "ShouldSkip::should_skip")]
filter: Option<Query>,

#[serde(skip_serializing_if = "ShouldSkip::should_skip")]
weight: Option<f32>,
}

#[derive(Debug, Default, Clone, PartialEq, Serialize)]
Expand All @@ -219,6 +239,26 @@ impl RandomScore {
Default::default()
}

/// Add function filter
pub fn filter<T>(mut self, filter: T) -> Self
where
T: Into<Option<Query>>,
{
self.filter = filter.into();
self
}

/// The `weight` score allows you to multiply the score by the provided `weight`. This can sometimes be desired
/// since boost value set on specific queries gets normalized, while for this score function it does not.
/// The number value is of type float.
pub fn weight<T>(mut self, weight: T) -> Self
where
T: num_traits::AsPrimitive<f32>,
{
self.weight = Some(weight.as_());
self
}

/// Sets seed value
pub fn seed<T>(mut self, seed: T) -> Self
where
Expand Down Expand Up @@ -262,6 +302,12 @@ impl RandomScore {
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct FieldValueFactor {
field_value_factor: FieldValueFactorInner,

#[serde(skip_serializing_if = "ShouldSkip::should_skip")]
filter: Option<Query>,

#[serde(skip_serializing_if = "ShouldSkip::should_skip")]
weight: Option<f32>,
}

#[derive(Debug, Clone, PartialEq, Serialize)]
Expand Down Expand Up @@ -293,9 +339,31 @@ impl FieldValueFactor {
modifier: None,
missing: None,
},
filter: None,
weight: None,
}
}

/// Add function filter
pub fn filter<T>(mut self, filter: T) -> Self
where
T: Into<Option<Query>>,
{
self.filter = filter.into();
self
}

/// The `weight` score allows you to multiply the score by the provided `weight`. This can sometimes be desired
/// since boost value set on specific queries gets normalized, while for this score function it does not.
/// The number value is of type float.
pub fn weight<T>(mut self, weight: T) -> Self
where
T: num_traits::AsPrimitive<f32>,
{
self.weight = Some(weight.as_());
self
}

/// Factor to multiply the field value with
pub fn factor(mut self, factor: f32) -> Self {
self.field_value_factor.factor = Some(factor);
Expand Down Expand Up @@ -403,7 +471,12 @@ impl_origin_for_numbers![i8, i16, i32, i64, u8, u16, u32, u64, f32, f64];
#[derive(Debug, Clone, PartialEq)]
pub struct Decay<T: Origin> {
function: DecayFunction,

inner: DecayFieldInner<T>,

filter: Option<Query>,

weight: Option<f32>,
}
#[derive(Debug, Clone, PartialEq)]
struct DecayFieldInner<T: Origin> {
Expand Down Expand Up @@ -458,9 +531,31 @@ where
decay: None,
},
},
filter: None,
weight: None,
}
}

/// Add function filter
pub fn filter<T>(mut self, filter: T) -> Self
where
T: Into<Option<Query>>,
{
self.filter = filter.into();
self
}

/// The `weight` score allows you to multiply the score by the provided `weight`. This can sometimes be desired
/// since boost value set on specific queries gets normalized, while for this score function it does not.
/// The number value is of type float.
pub fn weight<T>(mut self, weight: T) -> Self
where
T: num_traits::AsPrimitive<f32>,
{
self.weight = Some(weight.as_());
self
}

/// If an `offset` is defined, the decay function will only compute the decay function for
/// documents with a distance greater than the defined `offset`.
///
Expand All @@ -483,10 +578,18 @@ impl<T: Origin> Serialize for Decay<T> {
where
S: Serializer,
{
let mut map = serializer.serialize_map(Some(1))?;
let mut map = serializer.serialize_map(Some(3))?;

map.serialize_entry(&self.function, &self.inner)?;

if let Some(filter) = &self.filter {
map.serialize_entry("filter", filter)?;
}

if let Some(weight) = &self.weight {
map.serialize_entry("weight", weight)?;
}

map.end()
}
}
Expand Down

0 comments on commit 912a2cf

Please sign in to comment.