Skip to content

Commit

Permalink
Allow tests to include anonymous identifiers in filter (#6378)
Browse files Browse the repository at this point in the history
## What is the goal of this PR?

For the purpose of testing reasoner, we need to be able to return answers to anonymous variables, for queries constructed using backend TypeQL patterns. This is almost doable, except that our filters only allow named identifiers. This change allows us to pass anonymous identifiers in the filter, and calling the `executeTraversal`/`executeReasoner` methods directly, which are now public.

## What are the changes implemented in this PR?

* expose to methods to call traversal/reasoning directly on `Reasoner` class
* allow passing anonymous identifiers in the query filter, which allows tests to receive anonymous variables' answers. This avoids the need to rewrite or deanonimize variables.
  • Loading branch information
flyingsilverfin committed Jun 18, 2021
1 parent 38c676f commit afa09bd
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 46 deletions.
22 changes: 12 additions & 10 deletions reasoner/Reasoner.java
Expand Up @@ -115,12 +115,14 @@ public FunctionalIterator<ConceptMap> execute(Disjunction disjunction, TypeQLMat
throw TypeDBException.of(UNSATISFIABLE_PATTERN, disjunction, causes);
}

if (mayReason(disjunction, context)) return executeReasoner(disjunction, modifiers, context);
if (mayReason(disjunction, context)) return executeReasoner(disjunction, filter(modifiers.filter()), context);
else return executeTraversal(disjunction, context, filter(modifiers.filter()));
}

private Set<Identifier.Variable.Name> filter(List<UnboundVariable> typeQLVars) {
return iterate(typeQLVars).map(v -> v.reference().asName()).map(Identifier.Variable::of).toSet();
private Set<Identifier.Variable.Retrievable> filter(List<UnboundVariable> typeQLVars) {
Set<Identifier.Variable.Retrievable> names = new HashSet<>();
iterate(typeQLVars).map(v -> v.reference().asName()).map(Identifier.Variable::of).forEachRemaining(names::add);
return names;
}

private Set<Conjunction> incoherentConjunctions(Disjunction disjunction) {
Expand All @@ -135,16 +137,16 @@ private Set<Conjunction> incoherentConjunctions(Disjunction disjunction) {
return causes;
}

private FunctionalIterator<ConceptMap> executeReasoner(Disjunction disjunction, TypeQLMatch.Modifiers modifiers,
public FunctionalIterator<ConceptMap> executeReasoner(Disjunction disjunction, Set<Identifier.Variable.Retrievable> filter,
Context.Query context) {
ReasonerProducer producer = disjunction.conjunctions().size() == 1
? new ReasonerProducer(disjunction.conjunctions().get(0), modifiers, context.options(), resolverRegistry, explainablesManager)
: new ReasonerProducer(disjunction, modifiers, context.options(), resolverRegistry, explainablesManager);
? new ReasonerProducer(disjunction.conjunctions().get(0), filter, context.options(), resolverRegistry, explainablesManager)
: new ReasonerProducer(disjunction, filter, context.options(), resolverRegistry, explainablesManager);
return produce(producer, context.producer(), async1());
}

private FunctionalIterator<ConceptMap> executeTraversal(Disjunction disjunction, Context.Query context,
Set<Identifier.Variable.Name> filter) {
public FunctionalIterator<ConceptMap> executeTraversal(Disjunction disjunction, Context.Query context,
Set<Identifier.Variable.Retrievable> filter) {
FunctionalIterator<ConceptMap> answers;
FunctionalIterator<Conjunction> conjs = iterate(disjunction.conjunctions());
if (!context.options().parallel()) answers = conjs.flatMap(conj -> iterator(conj, filter, context));
Expand All @@ -153,7 +155,7 @@ private FunctionalIterator<ConceptMap> executeTraversal(Disjunction disjunction,
return answers;
}

private Producer<ConceptMap> producer(Conjunction conjunction, Set<Identifier.Variable.Name> filter,
private Producer<ConceptMap> producer(Conjunction conjunction, Set<Identifier.Variable.Retrievable> filter,
Context.Query context) {
if (conjunction.negations().isEmpty()) {
return traversalEng.producer(
Expand All @@ -176,7 +178,7 @@ private FunctionalIterator<ConceptMap> iterator(Conjunction conjunction, Concept
return iterator(bound(conjunction, bounds), set(), defaultContext);
}

private FunctionalIterator<ConceptMap> iterator(Conjunction conjunction, Set<Identifier.Variable.Name> filter,
private FunctionalIterator<ConceptMap> iterator(Conjunction conjunction, Set<Identifier.Variable.Retrievable> filter,
Context.Query context) {
if (!conjunction.isCoherent()) return Iterators.empty();
if (conjunction.negations().isEmpty()) {
Expand Down
12 changes: 4 additions & 8 deletions reasoner/ReasonerProducer.java
Expand Up @@ -69,7 +69,7 @@ public class ReasonerProducer implements Producer<ConceptMap> {
private boolean sentReiterationRequests;

// TODO: this class should not be a Producer, it implements a different async processing mechanism
public ReasonerProducer(Conjunction conjunction, TypeQLMatch.Modifiers modifiers, Options.Query options,
public ReasonerProducer(Conjunction conjunction, Set<Identifier.Variable.Retrievable> filter, Options.Query options,
ResolverRegistry resolverRegistry, ExplainablesManager explainablesManager) {
this.options = options;
this.resolverRegistry = resolverRegistry;
Expand All @@ -82,15 +82,15 @@ public ReasonerProducer(Conjunction conjunction, TypeQLMatch.Modifiers modifiers
this.rootResolver = this.resolverRegistry.root(conjunction, this::requestAnswered, this::requestFailed, this::exception);
this.computeSize = options.parallel() ? Executors.PARALLELISATION_FACTOR * 2 : 1;
assert computeSize > 0;
Root<?, ?> downstream = InitialImpl.create(filter(modifiers.filter()), new ConceptMap(), this.rootResolver, options.explain()).toDownstream();
Root<?, ?> downstream = InitialImpl.create(filter, new ConceptMap(), this.rootResolver, options.explain()).toDownstream();
this.resolveRequest = Request.create(rootResolver, downstream);
this.reiterationRequest = ReiterationQuery.Request.create(rootResolver, this::receiveReiterationResponse);
this.sentReiterationRequests = false;
this.requiresReiteration = false;
if (options.traceInference()) ResolutionTracer.initialise(options.logsDir());
}

public ReasonerProducer(Disjunction disjunction, TypeQLMatch.Modifiers modifiers, Options.Query options,
public ReasonerProducer(Disjunction disjunction, Set<Identifier.Variable.Retrievable> filter, Options.Query options,
ResolverRegistry resolverRegistry, ExplainablesManager explainablesManager) {
this.options = options;
this.resolverRegistry = resolverRegistry;
Expand All @@ -103,7 +103,7 @@ public ReasonerProducer(Disjunction disjunction, TypeQLMatch.Modifiers modifiers
this.rootResolver = this.resolverRegistry.root(disjunction, this::requestAnswered, this::requestFailed, this::exception);
this.computeSize = options.parallel() ? Executors.PARALLELISATION_FACTOR * 2 : 1;
assert computeSize > 0;
Root<?, ?> downstream = InitialImpl.create(filter(modifiers.filter()), new ConceptMap(), this.rootResolver, options.explain()).toDownstream();
Root<?, ?> downstream = InitialImpl.create(filter, new ConceptMap(), this.rootResolver, options.explain()).toDownstream();
this.resolveRequest = Request.create(rootResolver, downstream);
this.reiterationRequest = ReiterationQuery.Request.create(rootResolver, this::receiveReiterationResponse);
this.sentReiterationRequests = false;
Expand All @@ -128,10 +128,6 @@ public synchronized void produce(Queue<ConceptMap> queue, int request, Executor
public void recycle() {
}

private Set<Identifier.Variable.Name> filter(List<UnboundVariable> filter) {
return iterate(filter).map(v -> Identifier.Variable.of(v.reference().asName())).toSet();
}

// note: root resolver calls this single-threaded, so is thread safe
private void requestAnswered(Finished answer) {
if (options.traceInference()) ResolutionTracer.get().finish();
Expand Down
2 changes: 1 addition & 1 deletion reasoner/resolution/answer/AnswerState.java
Expand Up @@ -74,7 +74,7 @@ interface Top extends AnswerState {

interface Match extends Top, Explainable {

Set<Identifier.Variable.Name> getFilter();
Set<Identifier.Variable.Retrievable> filter();

@Override
default boolean isMatch() {
Expand Down
32 changes: 16 additions & 16 deletions reasoner/resolution/answer/AnswerStateImpl.java
Expand Up @@ -80,21 +80,21 @@ public static abstract class TopImpl extends AnswerStateImpl implements Top {

public static abstract class MatchImpl extends TopImpl implements Match {

private final Set<Identifier.Variable.Name> getFilter;
private final Set<Identifier.Variable.Retrievable> filter;
private final boolean explainable;
private final int hash;

MatchImpl(Set<Identifier.Variable.Name> getFilter, ConceptMap conceptMap,
MatchImpl(Set<Identifier.Variable.Retrievable> getFilter, ConceptMap conceptMap,
Actor.Driver<? extends Resolver<?>> root, boolean explainable) {
super(conceptMap, root);
this.getFilter = getFilter;
this.filter = getFilter;
this.explainable = explainable;
this.hash = Objects.hash(root(), conceptMap(), getFilter(), explainable());
this.hash = Objects.hash(root(), conceptMap(), filter(), explainable());
}

@Override
public Set<Identifier.Variable.Name> getFilter() {
return getFilter;
public Set<Identifier.Variable.Retrievable> filter() {
return filter;
}

@Override
Expand All @@ -109,7 +109,7 @@ public boolean equals(Object o) {
AnswerStateImpl.TopImpl.MatchImpl that = (AnswerStateImpl.TopImpl.MatchImpl) o;
return Objects.equals(root(), that.root()) &&
Objects.equals(conceptMap(), that.conceptMap()) &&
Objects.equals(getFilter(), that.getFilter()) &&
Objects.equals(filter(), that.filter()) &&
explainable() == that.explainable();
}

Expand All @@ -120,12 +120,12 @@ public int hashCode() {

public static class InitialImpl extends MatchImpl implements Initial {

private InitialImpl(Set<Identifier.Variable.Name> getFilter, ConceptMap conceptMap,
private InitialImpl(Set<Identifier.Variable.Retrievable> getFilter, ConceptMap conceptMap,
Actor.Driver<? extends Resolver<?>> root, boolean explainable) {
super(getFilter, conceptMap, root, explainable);
}

public static InitialImpl create(Set<Identifier.Variable.Name> getFilter, ConceptMap conceptMap,
public static InitialImpl create(Set<Identifier.Variable.Retrievable> getFilter, ConceptMap conceptMap,
Actor.Driver<? extends Resolver<?>> root, boolean explainable) {
return new InitialImpl(getFilter, conceptMap, root, explainable);
}
Expand All @@ -138,24 +138,24 @@ public PartialImpl.CompoundImpl.RootImpl.MatchImpl toDownstream() {
@Override
public FinishedImpl finish(ConceptMap conceptMap) {
ConceptMap answer = conceptMap;
if (!explainable()) answer = conceptMap.filter(getFilter());
return FinishedImpl.create(getFilter(), answer, root(), explainable());
if (!explainable()) answer = conceptMap.filter(filter());
return FinishedImpl.create(filter(), answer, root(), explainable());
}

}

public static class FinishedImpl extends MatchImpl implements Finished {

private FinishedImpl(Set<Identifier.Variable.Name> getFilter, ConceptMap conceptMap,
private FinishedImpl(Set<Identifier.Variable.Retrievable> filter, ConceptMap conceptMap,
Actor.Driver<? extends Resolver<?>> root, boolean explainable) {
super(getFilter, conceptMap, root, explainable);
super(filter, conceptMap, root, explainable);
}

public static FinishedImpl create(Set<Identifier.Variable.Name> getFilter, ConceptMap conceptMap,
public static FinishedImpl create(Set<Identifier.Variable.Retrievable> filter, ConceptMap conceptMap,
Actor.Driver<? extends Resolver<?>> root, boolean explainable) {
ConceptMap initialAns = conceptMap;
if (!explainable) initialAns = conceptMap.filter(getFilter);
return new FinishedImpl(getFilter, initialAns, root, explainable);
if (!explainable) initialAns = conceptMap.filter(filter);
return new FinishedImpl(filter, initialAns, root, explainable);
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions reasoner/resolution/answer/AnswerStateTest.java
Expand Up @@ -43,7 +43,7 @@ public void test_root_empty_mapped_to_downstream_and_back() {
Map<Identifier.Variable.Retrievable, Identifier.Variable.Retrievable> mapping = new HashMap<>();
mapping.put(Identifier.Variable.name("a"), Identifier.Variable.name("x"));
mapping.put(Identifier.Variable.name("b"), Identifier.Variable.name("y"));
Set<Identifier.Variable.Name> filter = set(Identifier.Variable.name("a"), Identifier.Variable.name("b"));
Set<Identifier.Variable.Retrievable> filter = set(Identifier.Variable.name("a"), Identifier.Variable.name("b"));
Concludable.Match<?> mapped = InitialImpl.create(filter, new ConceptMap(), null, false).toDownstream().toDownstream(Mapping.of(mapping), null);
assertTrue(mapped.conceptMap().concepts().isEmpty());

Expand All @@ -64,7 +64,7 @@ public void test_root_partially_mapped_to_downstream_and_back() {
mapping.put(Identifier.Variable.name("b"), Identifier.Variable.name("y"));
Map<Identifier.Variable.Retrievable, Concept> concepts = new HashMap<>();
concepts.put(Identifier.Variable.name("a"), new MockConcept(0));
Set<Identifier.Variable.Name> filter = set(Identifier.Variable.name("a"), Identifier.Variable.name("b"));
Set<Identifier.Variable.Retrievable> filter = set(Identifier.Variable.name("a"), Identifier.Variable.name("b"));
Concludable.Match<?> mapped = InitialImpl.create(filter, new ConceptMap(), null, false).toDownstream()
.with(new ConceptMap(concepts))
.toDownstream(Mapping.of(mapping), null);
Expand Down Expand Up @@ -92,7 +92,7 @@ public void test_root_with_unmapped_elements() {
Map<Identifier.Variable.Retrievable, Concept> concepts = new HashMap<>();
concepts.put(Identifier.Variable.name("a"), new MockConcept(0));
concepts.put(Identifier.Variable.name("c"), new MockConcept(2));
Set<Identifier.Variable.Name> filter = set(Identifier.Variable.name("a"), Identifier.Variable.name("b"));
Set<Identifier.Variable.Retrievable> filter = set(Identifier.Variable.name("a"), Identifier.Variable.name("b"));
Concludable.Match<?> mapped = InitialImpl.create(filter, new ConceptMap(), null, false).toDownstream()
.with(new ConceptMap(concepts))
.toDownstream(Mapping.of(mapping), null);
Expand Down
7 changes: 4 additions & 3 deletions test/integration/reasoner/resolution/ReiterationTest.java
Expand Up @@ -119,8 +119,9 @@ public void test_first_iteration_exhausts_and_second_iteration_recurses_infinite
try (RocksSession session = dataSession()) {
try (RocksTransaction transaction = singleThreadElgTransaction(session)) {
Conjunction conjunction = resolvedConjunction("{ $y isa Y; }", transaction.logic());
Set<Identifier.Variable.Name> filter = iterate(conjunction.variables()).map(Variable::id).filter(Identifier::isName)
.map(Identifier.Variable::asName).toSet();
Set<Identifier.Variable.Retrievable> filter = new HashSet<>();
iterate(conjunction.variables()).map(Variable::id).filter(Identifier::isName)
.map(Identifier.Variable::asName).forEachRemaining(filter::add);
ResolverRegistry registry = transaction.reasoner().resolverRegistry();
LinkedBlockingQueue<Match.Finished> responses = new LinkedBlockingQueue<>();
LinkedBlockingQueue<Integer> failed = new LinkedBlockingQueue<>();
Expand Down Expand Up @@ -179,7 +180,7 @@ public void test_first_iteration_exhausts_and_second_iteration_recurses_infinite
}
}

private void sendRootRequest(Actor.Driver<RootResolver.Conjunction> root, Set<Identifier.Variable.Name> filter, int iteration) {
private void sendRootRequest(Actor.Driver<RootResolver.Conjunction> root, Set<Identifier.Variable.Retrievable> filter, int iteration) {
Root.Match downstream = InitialImpl.create(filter, new ConceptMap(), root, true).toDownstream();
root.execute(actor -> actor.receiveRequest(
Request.create(root, downstream), iteration)
Expand Down
12 changes: 7 additions & 5 deletions test/integration/reasoner/resolution/ResolutionTest.java
Expand Up @@ -46,6 +46,7 @@
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -169,7 +170,7 @@ public void test_disjunction_no_rules() throws InterruptedException {
}
try (RocksSession session = dataSession()) {
try (RocksTransaction transaction = singleThreadElgTransaction(session)) {
Set<Identifier.Variable.Name> filter = set(Identifier.Variable.name("t"),
Set<Identifier.Variable.Retrievable> filter = set(Identifier.Variable.name("t"),
Identifier.Variable.name("p1"),
Identifier.Variable.name("p2"));
Disjunction disjunction = resolvedDisjunction("{ $t(twin1: $p1, twin2: $p2) isa twins; { $p1 has age 24; } or { $p1 has age 26; }; }", transaction.logic());
Expand Down Expand Up @@ -447,7 +448,7 @@ private RocksTransaction singleThreadElgTransaction(RocksSession session) {
}

private void createRootAndAssertResponses(RocksTransaction transaction, Disjunction disjunction,
Set<Identifier.Variable.Name> filter, long answerCount,
Set<Identifier.Variable.Retrievable> filter, long answerCount,
long explainableAnswers) throws InterruptedException {
ResolverRegistry registry = transaction.reasoner().resolverRegistry();
LinkedBlockingQueue<Match.Finished> responses = new LinkedBlockingQueue<>();
Expand All @@ -467,8 +468,9 @@ private void createRootAndAssertResponses(RocksTransaction transaction, Conjunct
ResolverRegistry registry = transaction.reasoner().resolverRegistry();
LinkedBlockingQueue<Match.Finished> responses = new LinkedBlockingQueue<>();
AtomicLong doneReceived = new AtomicLong(0L);
Set<Identifier.Variable.Name> filter = iterate(conjunction.variables()).map(Variable::id)
.filter(Identifier::isName).map(Identifier.Variable::asName).toSet();
Set<Identifier.Variable.Retrievable> filter = new HashSet<>();
iterate(conjunction.variables()).map(Variable::id).filter(Identifier::isName).map(Identifier.Variable::asName)
.forEachRemaining(filter::add);
Actor.Driver<RootResolver.Conjunction> root;
try {
root = registry.root(conjunction, responses::add, iterDone -> doneReceived.incrementAndGet(), (throwable) -> fail());
Expand All @@ -479,7 +481,7 @@ private void createRootAndAssertResponses(RocksTransaction transaction, Conjunct
assertResponses(root, filter, responses, doneReceived, answerCount, explainableAnswers);
}

private void assertResponses(Actor.Driver<? extends Resolver<?>> root, Set<Identifier.Variable.Name> filter,
private void assertResponses(Actor.Driver<? extends Resolver<?>> root, Set<Identifier.Variable.Retrievable> filter,
LinkedBlockingQueue<Match.Finished> responses, AtomicLong doneReceived,
long answerCount, long explainableAnswers) throws InterruptedException {
long startTime = System.currentTimeMillis();
Expand Down

0 comments on commit afa09bd

Please sign in to comment.