Skip to content

Commit

Permalink
Merge 46653e1 into a9e3fb0
Browse files Browse the repository at this point in the history
  • Loading branch information
richardstartin committed Jul 17, 2019
2 parents a9e3fb0 + 46653e1 commit 94a5b12
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,20 @@ public ClassifierBuilder(Schema<Key, Input> registry) {
/**
 * Builds an immutable classifier from the supplied matching constraints.
 * <p>
 * The number of constraints picks the mask implementation: {@code TinyMask} when the count is
 * below {@code TinyMask.MAX_CAPACITY}, {@code SmallMask} when it is below
 * {@code SmallMask.MAX_CAPACITY}, and {@code HugeMask} otherwise.
 *
 * @param constraints the constraints the classifier is built from; the list size is used as the
 *                    maximum priority passed to the private {@code build} overload
 * @return an {@code ImmutableClassifier} wrapping the masked classifier for the chosen mask type
 */
// NOTE(review): this listing is a rendered diff hunk. The calls that pass
// FACTORY.contiguous(maxPriority) appear to be the pre-commit lines and the three-argument
// calls their replacements - confirm against the repository before treating this as compilable.
public ImmutableClassifier<Input, Classification> build(List<MatchingConstraint<Key, Classification>> constraints) {
int maxPriority = constraints.size();
return maxPriority < TinyMask.MAX_CAPACITY
? new ImmutableClassifier<>(build(constraints, TinyMask.FACTORY.contiguous(maxPriority), TinyMask.FACTORY, maxPriority))
? new ImmutableClassifier<>(build(constraints, TinyMask.FACTORY, maxPriority))
: maxPriority < SmallMask.MAX_CAPACITY
? new ImmutableClassifier<>(build(constraints, SmallMask.FACTORY.contiguous(maxPriority), SmallMask.FACTORY, maxPriority))
: new ImmutableClassifier<>(build(constraints, HugeMask.FACTORY.contiguous(maxPriority), HugeMask.FACTORY, maxPriority));
? new ImmutableClassifier<>(build(constraints, SmallMask.FACTORY, maxPriority))
: new ImmutableClassifier<>(build(constraints, HugeMask.FACTORY, maxPriority));
}

/**
 * Registers every constraint in priority order and assembles the masked classifier.
 * <p>
 * The constraints are sorted by {@code order(rd.getPriority())} and each one is handed to
 * {@code addMatchingConstraint} together with the next value of a sequence counting up from 0.
 *
 * @param specs       the matching constraints to register
 * @param maskFactory factory for masks of the chosen type
 * @param max         value forwarded to {@code addMatchingConstraint} and, in the
 *                    {@code maskFactory.contiguous(max)} form, used to create the classifier's mask
 * @return a {@code MaskedClassifier} built from the collected classifications, the frozen
 *         matchers and a mask
 */
// NOTE(review): this listing is a rendered diff hunk. The "MaskType mask" parameter and the
// return statement that passes "mask" appear to be the pre-commit lines; the return statement
// using maskFactory.contiguous(max) appears to be the replacement - confirm against the repository.
private <MaskType extends Mask<MaskType>>
MaskedClassifier<MaskType, Input, Classification> build(List<MatchingConstraint<Key, Classification>> specs,
MaskType mask,
MaskFactory<MaskType> maskFactory,
int max) {
// unbounded counter 0, 1, 2, ... consumed once per constraint in sorted order
PrimitiveIterator.OfInt sequence = IntStream.iterate(0, i -> i + 1).iterator();
specs.stream().sorted(Comparator.comparingInt(rd -> order(rd.getPriority())))
.forEach(rule -> addMatchingConstraint(rule, sequence.nextInt(), maskFactory, max));
return new MaskedClassifier<>((Classification[])classifications.toArray(), freezeMatchers(), mask);
return new MaskedClassifier<>((Classification[])classifications.toArray(), freezeMatchers(), maskFactory.contiguous(max));
}

private <MaskType extends Mask<MaskType>>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package uk.co.openkappa.bitrules;

import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.UUID;
import java.util.*;

import static java.util.Objects.requireNonNull;

Expand All @@ -22,7 +19,7 @@ public static <K, C> Builder<K, C> anonymous() {
public static class Builder<K, C> {

private final String id;
private Map<K, Constraint> constraints = new TreeMap<>();
private Map<K, Constraint> constraints = new HashMap<>();
private int priority;
private C classification;

Expand Down
12 changes: 12 additions & 0 deletions src/main/java/uk/co/openkappa/bitrules/masks/HugeMask.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ public void remove(int id) {

/**
 * Intersects this mask with {@code other} and returns the result as a separate mask.
 *
 * @param other the mask to intersect with
 * @return {@code FACTORY.empty()} when {@code other} is empty, otherwise a new mask wrapping
 *         {@code RoaringBitmap.and} of the two bitmaps
 */
@Override
public HugeMask and(HugeMask other) {
  // an empty operand forces an empty intersection, so the bitmap work is skipped
  return other.isEmpty()
      ? FACTORY.empty()
      : new HugeMask(RoaringBitmap.and(bitmap, other.bitmap));
}

Expand All @@ -43,17 +46,26 @@ public HugeMask andNot(HugeMask other) {

/**
 * Computes the union of this mask and {@code other}.
 *
 * @param other the mask to union with
 * @return this very instance (not a copy) when {@code other} is empty, otherwise a new mask
 *         wrapping {@code RoaringBitmap.or} of the two bitmaps
 */
@Override
public HugeMask or(HugeMask other) {
  // a union with nothing changes nothing; the receiver itself is handed back
  return other.isEmpty()
      ? this
      : new HugeMask(RoaringBitmap.or(bitmap, other.bitmap));
}

/**
 * Intersects this mask with {@code other}, storing the result in this mask.
 * <p>
 * The receiver is always the object that is mutated and returned, including when
 * {@code other} is empty. Callers that ignore the return value therefore still observe
 * the intersection.
 *
 * @param other the mask to intersect with
 * @return this mask, after the intersection has been applied
 */
@Override
public HugeMask inPlaceAnd(HugeMask other) {
  if (other.isEmpty()) {
    // shortcut: intersecting with nothing leaves nothing, but it must still happen in place
    bitmap.clear();
    return this;
  }
  bitmap.and(other.bitmap);
  return this;
}

/**
 * Merges {@code other} into this mask.
 *
 * @param other the mask whose bits are added to this one
 * @return this mask
 */
@Override
public HugeMask inPlaceOr(HugeMask other) {
  // only touch the bitmap when there is something to merge
  if (!other.isEmpty()) {
    bitmap.or(other.bitmap);
  }
  return this;
}
Expand Down
14 changes: 13 additions & 1 deletion src/main/java/uk/co/openkappa/bitrules/masks/SmallMask.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ public void remove(int id) {

/**
 * Intersects this mask with {@code other} and returns the result as a separate mask.
 *
 * @param other the mask to intersect with
 * @return {@code FACTORY.empty()} when {@code other} is empty, otherwise a new mask wrapping
 *         {@code container.and(other.container)}
 */
@Override
public SmallMask and(SmallMask other) {
  // an empty operand forces an empty intersection, so the container work is skipped
  return other.isEmpty()
      ? FACTORY.empty()
      : new SmallMask(container.and(other.container));
}

Expand All @@ -44,17 +47,26 @@ public SmallMask andNot(SmallMask other) {

/**
 * Intersects this mask with {@code other}, storing the result in this mask.
 * <p>
 * The receiver is always the object that is updated and returned, including when
 * {@code other} is empty. Callers that ignore the return value therefore still observe
 * the intersection.
 *
 * @param other the mask to intersect with
 * @return this mask, after the intersection has been applied
 */
@Override
public SmallMask inPlaceAnd(SmallMask other) {
  if (other.isEmpty()) {
    // shortcut: adopt a fresh empty container so the receiver itself becomes empty
    this.container = FACTORY.empty().container;
    return this;
  }
  this.container = container.iand(other.container);
  return this;
}

/**
 * Computes the union of this mask and {@code other}.
 *
 * @param other the mask to union with
 * @return this very instance (not a copy) when {@code other} is empty, otherwise a new mask
 *         wrapping {@code container.or(other.container)}
 */
@Override
public SmallMask or(SmallMask other) {
  // a union with nothing changes nothing; the receiver itself is handed back
  return other.isEmpty()
      ? this
      : new SmallMask(container.or(other.container));
}

/**
 * Merges {@code other} into this mask.
 *
 * @param other the mask whose bits are added to this one
 * @return this mask
 */
@Override
public SmallMask inPlaceOr(SmallMask other) {
  // only replace the container when there is something to merge
  if (!other.isEmpty()) {
    this.container = container.ior(other.container);
  }
  return this;
}
Expand Down Expand Up @@ -113,7 +125,7 @@ private static final class Factory implements MaskFactory<SmallMask> {

// Creates a SmallMask backed by a newly allocated ArrayContainer.
// NOTE(review): this listing is a rendered diff hunk; the no-argument constructor call appears to
// be the pre-commit line and ArrayContainer(0) its replacement. The int argument is presumably
// an initial capacity - confirm against the RoaringBitmap API.
@Override
public SmallMask empty() {
return new SmallMask(new ArrayContainer());
return new SmallMask(new ArrayContainer(0));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ static <MaskType extends Mask<MaskType>> float avgCardinality(MaskType[] masks)
static <Node> float avgCardinality(Collection<Node> nodes, ToDoubleFunction<Node> selectivity) {
float avg = 0;
int count = 0;
for (var node : nodes) {
for (Node node : nodes) {
avg += selectivity.applyAsDouble(node);
++count;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public void addConstraint(Constraint constraint, int priority) {
break;
case EQ:
GenericEqualityNode<String, MaskType> literal = (GenericEqualityNode<String, MaskType>) nodes.computeIfAbsent(EQ,
o -> new GenericEqualityNode<>(mapSupplier.get(), empty, wildcards, PerfectHashMap::wrap));
o -> new GenericEqualityNode<>(mapSupplier.get(), empty, wildcards));
literal.add(constraint.getValue(), priority);
break;
default:
Expand Down Expand Up @@ -111,7 +111,7 @@ public ClassificationNode<String, MaskType> optimise() {
}
});
});
return new PrefixNode<>(empty, PerfectHashMap.wrap(map), longest);
return new PrefixNode<>(empty, map, longest);
}

public void add(String prefix, int id) {
Expand Down
50 changes: 46 additions & 4 deletions src/test/java/uk/co/openkappa/bitrules/LargeClassifierTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
import org.junit.jupiter.api.Test;
import uk.co.openkappa.bitrules.schema.Schema;

import java.util.HashMap;
import java.util.Map;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static java.util.stream.Collectors.toList;
import static org.junit.jupiter.api.Assertions.assertEquals;


public class LargeClassifierTest {

@Test
public void testLargeClassifier() throws InterruptedException {
Thread.sleep(10000);
public void testLargeClassifier() {
Classifier<int[], String> classifier = ImmutableClassifier.
<Integer, int[], String>builder(Schema.<Integer, int[]>create()
.withAttribute(0, extract(0))
Expand All @@ -33,7 +34,7 @@ public void testLargeClassifier() throws InterruptedException {
.classification("SEGMENT" + i)
.build())

.collect(Collectors.toList())
.collect(toList())
);
int[] vector = new int[]{5, 5, 5, 5, 5};
String classification = classifier.classification(vector).orElseThrow(RuntimeException::new);
Expand All @@ -44,5 +45,46 @@ public void testLargeClassifier() throws InterruptedException {
/**
 * Creates a function that reads one position of an {@code int[]}.
 *
 * @param feature index of the array element the returned function reads
 * @return a function mapping an array to its element at {@code feature}
 */
private static ToIntFunction<int[]> extract(int feature) {
  return vector -> {
    return vector[feature];
  };
}


/**
 * Builds a classifier from 50,000 equality-only constraints over six string attributes,
 * classifies the same message one million times, prints the measured throughput and checks
 * that the message is classified as "SEGMENT90".
 */
@Test
public void testLargeDiscreteClassifier() {
  final int iterations = 1_000_000;

  Classifier<Map<String, Object>, String> classifier = ImmutableClassifier.
      <String, Map<String, Object>, String>builder(Schema.<String, Map<String, Object>>create()
          .withStringAttribute("attr1", (Map<String, Object> m) -> (String) m.get("attr1"))
          .withStringAttribute("attr2", (Map<String, Object> m) -> (String) m.get("attr2"))
          .withStringAttribute("attr3", (Map<String, Object> m) -> (String) m.get("attr3"))
          .withStringAttribute("attr4", (Map<String, Object> m) -> (String) m.get("attr4"))
          .withStringAttribute("attr5", (Map<String, Object> m) -> (String) m.get("attr5"))
          .withStringAttribute("attr6", (Map<String, Object> m) -> (String) m.get("attr6"))
      ).build(IntStream.range(0, 50000)
          .mapToObj(id -> MatchingConstraint.<String, String>anonymous()
              .eq("attr1", "value" + (id / 10000))
              .eq("attr2", "value" + (id / 1000))
              .eq("attr3", "value" + (id / 500))
              .eq("attr4", "value" + (id / 250))
              .eq("attr5", "value" + (id / 100))
              .eq("attr6", "value" + (id / 10))
              .classification("SEGMENT" + id).build()
          ).collect(toList()));

  // attr1..attr5 all carry "value0"; only attr6 differs
  Map<String, Object> message = new HashMap<>();
  for (int attr = 1; attr <= 5; ++attr) {
    message.put("attr" + attr, "value0");
  }
  message.put("attr6", "value9");

  String classification = null;
  long begin = System.nanoTime();
  for (int round = 0; round < iterations; ++round) {
    classification = classifier.classification(message).orElseThrow(RuntimeException::new);
  }
  long finish = System.nanoTime();

  // elapsed nanoseconds -> milliseconds, then operations per millisecond
  double elapsedMillis = (finish - begin) / 1e6;
  System.out.println(iterations / elapsedMillis + "ops/ms");
  assertEquals("SEGMENT90", classification);
}
}

0 comments on commit 94a5b12

Please sign in to comment.