Skip to content

Commit

Permalink
Issue/483 (#484)
Browse files Browse the repository at this point in the history
* Downcast_ref

* fixing unit test
  • Loading branch information
fulmicoton committed Jan 28, 2019
1 parent e99d1a2 commit 6a547b0
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ futures-cpupool = "0.1"
owning_ref = "0.4"
stable_deref_trait = "1.0.0"
rust-stemmers = "1"
downcast = { version="0.9" }
downcast-rs = { version="1.0" }
matches = "0.1"
bitpacking = "0.5"
census = "0.2"
Expand Down
11 changes: 4 additions & 7 deletions src/collector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ See the `custom_collector` example.
*/

use downcast;
use downcast_rs;
use DocId;
use Result;
use Score;
Expand All @@ -111,9 +111,9 @@ pub use self::facet_collector::FacetCollector;

/// `Fruit` is the type for the result of our collection.
/// e.g. `usize` for the `Count` collector.
pub trait Fruit: Send + downcast::Any {}
pub trait Fruit: Send + downcast_rs::Downcast {}

impl<T> Fruit for T where T: Send + downcast::Any {}
impl<T> Fruit for T where T: Send + downcast_rs::Downcast {}

/// Collectors are in charge of collecting and retaining relevant
/// information from the document found and scored by the query.
Expand Down Expand Up @@ -358,10 +358,7 @@ where
}
}

#[allow(missing_docs)]
mod downcast_impl {
downcast!(super::Fruit);
}
impl_downcast!(Fruit);

#[cfg(test)]
pub mod tests;
10 changes: 4 additions & 6 deletions src/collector/multi_collector.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::Collector;
use super::SegmentCollector;
use collector::Fruit;
use downcast::Downcast;
use std::marker::PhantomData;
use DocId;
use Result;
Expand Down Expand Up @@ -37,11 +36,10 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
let typed_fruit: Vec<TCollector::Fruit> = children
.into_iter()
.map(|untyped_fruit| {
Downcast::<TCollector::Fruit>::downcast(untyped_fruit)
untyped_fruit.downcast::<TCollector::Fruit>()
.map(|boxed_but_typed| *boxed_but_typed)
.map_err(|e| {
let err_msg = format!("Failed to cast child collector fruit. {:?}", e);
TantivyError::InvalidArgument(err_msg)
.map_err(|_| {
TantivyError::InvalidArgument("Failed to cast child fruit.".to_string())
})
})
.collect::<Result<_>>()?;
Expand Down Expand Up @@ -89,7 +87,7 @@ pub struct FruitHandle<TFruit: Fruit> {
impl<TFruit: Fruit> FruitHandle<TFruit> {
pub fn extract(self, fruits: &mut MultiFruit) -> TFruit {
let boxed_fruit = fruits.sub_fruits[self.pos].take().expect("");
*Downcast::<TFruit>::downcast(boxed_fruit).expect("Failed")
*boxed_fruit.downcast::<TFruit>().map_err(|_| ()).expect("Failed to downcast collector fruit.")
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ extern crate maplit;
extern crate test;

#[macro_use]
extern crate downcast;
extern crate downcast_rs;

#[macro_use]
extern crate fail;
Expand Down
8 changes: 3 additions & 5 deletions src/query/boolean_query/boolean_weight.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::SegmentReader;
use downcast::Downcast;
use downcast_rs::Downcast;
use query::intersect_scorers;
use query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner};
use query::term_query::TermScorer;
Expand All @@ -10,7 +10,6 @@ use query::RequiredOptionalScorer;
use query::Scorer;
use query::Union;
use query::Weight;
use std::borrow::Borrow;
use std::collections::HashMap;
use Result;

Expand All @@ -25,13 +24,12 @@ where

{
let is_all_term_queries = scorers.iter().all(|scorer| {
let scorer_ref: &Scorer = scorer.borrow();
Downcast::<TermScorer>::is_type(scorer_ref)
scorer.is::<TermScorer>()
});
if is_all_term_queries {
let scorers: Vec<TermScorer> = scorers
.into_iter()
.map(|scorer| *Downcast::<TermScorer>::downcast(scorer).unwrap())
.map(|scorer| *(scorer.downcast::<TermScorer>().map_err(|_| ()).unwrap() ))
.collect();
let scorer: Box<Scorer> = Box::new(Union::<TermScorer, TScoreCombiner>::from(scorers));
return scorer;
Expand Down
18 changes: 8 additions & 10 deletions src/query/boolean_query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod tests {

use super::*;
use collector::tests::TestCollector;
use downcast::Downcast;
use downcast_rs::Downcast;
use query::score_combiner::SumWithCoordsCombiner;
use query::term_query::TermScorer;
use query::Intersection;
Expand Down Expand Up @@ -72,7 +72,7 @@ mod tests {
let searcher = index.searcher();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
assert!(Downcast::<TermScorer>::is_type(&*scorer));
assert!(scorer.is::<TermScorer>());
}

#[test]
Expand All @@ -84,13 +84,13 @@ mod tests {
let query = query_parser.parse_query("+a +b +c").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
assert!(Downcast::<Intersection<TermScorer>>::is_type(&*scorer));
assert!(scorer.is::<Intersection<TermScorer>>());
}
{
let query = query_parser.parse_query("+a +(b c)").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
assert!(Downcast::<Intersection<Box<Scorer>>>::is_type(&*scorer));
assert!(scorer.is::<Intersection<Box<Scorer>>>());
}
}

Expand All @@ -103,18 +103,16 @@ mod tests {
let query = query_parser.parse_query("+a b").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
assert!(Downcast::<
RequiredOptionalScorer<Box<Scorer>, Box<Scorer>, SumWithCoordsCombiner>,
>::is_type(&*scorer));
assert!(scorer.is::<RequiredOptionalScorer<Box<Scorer>, Box<Scorer>, SumWithCoordsCombiner>>());
}
{
let query = query_parser.parse_query("+a b").unwrap();
let weight = query.weight(&searcher, false).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
println!("{:?}", scorer.type_name());
assert!(Downcast::<TermScorer>::is_type(&*scorer));
assert!(scorer.is::<TermScorer>());
}
}
}


#[test]
pub fn test_boolean_query() {
Expand Down
11 changes: 4 additions & 7 deletions src/query/intersection.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use docset::{DocSet, SkipResult};
use downcast::Downcast;
use query::term_query::TermScorer;
use query::EmptyScorer;
use query::Scorer;
use std::borrow::Borrow;
use DocId;
use Score;
use query::term_query::TermScorer;

/// Returns the intersection scorer.
///
Expand All @@ -27,12 +25,11 @@ pub fn intersect_scorers(mut scorers: Vec<Box<Scorer>>) -> Box<Scorer> {
(Some(left), Some(right)) => {
{
let all_term_scorers = [&left, &right].iter().all(|&scorer| {
let scorer_ref: &Scorer = <Box<Scorer> as Borrow<Scorer>>::borrow(scorer);
Downcast::<TermScorer>::is_type(scorer_ref)
scorer.is::<TermScorer>()
});
if all_term_scorers {
let left = *Downcast::<TermScorer>::downcast(left).unwrap();
let right = *Downcast::<TermScorer>::downcast(right).unwrap();
let left = *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap());
let right = *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap());
return Box::new(Intersection {
left,
right,
Expand Down
9 changes: 3 additions & 6 deletions src/query/query.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::Weight;
use core::searcher::Searcher;
use downcast;
use downcast_rs;
use std::collections::BTreeSet;
use std::fmt;
use Result;
Expand Down Expand Up @@ -39,7 +39,7 @@ use Term;
///
/// When implementing a new type of `Query`, it is normal to implement a
/// dedicated `Query`, `Weight` and `Scorer`.
pub trait Query: QueryClone + downcast::Any + fmt::Debug {
pub trait Query: QueryClone + downcast_rs::Downcast + fmt::Debug {
/// Create the weight associated to a query.
///
/// If scoring is not required, setting `scoring_enabled` to `false`
Expand Down Expand Up @@ -96,7 +96,4 @@ impl QueryClone for Box<Query> {
}
}

#[allow(missing_docs)]
mod downcast_impl {
downcast!(super::Query);
}
impl_downcast!(Query);
10 changes: 4 additions & 6 deletions src/query/scorer.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use common::BitSet;
use docset::{DocSet, SkipResult};
use downcast;
use downcast_rs;
use std::ops::DerefMut;
use DocId;
use Score;

/// Scored set of documents matching a query within a specific segment.
///
/// See [`Query`](./trait.Query.html).
pub trait Scorer: downcast::Any + DocSet + 'static {
pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
/// Returns the score.
///
/// This method will perform a bit of computation and is not cached.
Expand All @@ -23,10 +23,8 @@ pub trait Scorer: downcast::Any + DocSet + 'static {
}
}

#[allow(missing_docs)]
mod downcast_impl {
downcast!(super::Scorer);
}
impl_downcast!(Scorer);


impl Scorer for Box<Scorer> {
fn score(&mut self) -> Score {
Expand Down

0 comments on commit 6a547b0

Please sign in to comment.