diff --git a/client/src/main/java/io/split/client/impressions/ImpressionCounter.java b/client/src/main/java/io/split/client/impressions/ImpressionCounter.java new file mode 100644 index 000000000..09bfdb26b --- /dev/null +++ b/client/src/main/java/io/split/client/impressions/ImpressionCounter.java @@ -0,0 +1,47 @@ +package io.split.client.impressions; + +import java.util.HashMap; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +public class ImpressionCounter { + + private static final long TIME_INTERVAL_MS = 3600L * 1000L; + + private final ConcurrentHashMap _counts; + + public ImpressionCounter() { + _counts = new ConcurrentHashMap<>(); + } + + public void inc(String featureName, long timeFrame, int amount) { + String key = makeKey(featureName, timeFrame); + AtomicInteger count = _counts.get(key); + if (Objects.isNull(count)) { + count = new AtomicInteger(); + AtomicInteger old = _counts.putIfAbsent(key, count); + if (!Objects.isNull(old)) { // Some other thread won the race, use that AtomicInteger instead + count = old; + } + } + count.addAndGet(amount); + } + + public HashMap popAll() { + HashMap toReturn = new HashMap<>(); + for (String key : _counts.keySet()) { + AtomicInteger curr = _counts.remove(key); + toReturn.put(key ,curr.get()); + } + return toReturn; + } + + static String makeKey(String featureName, long timeFrame) { + return String.join("::", featureName, String.valueOf(truncateTimeframe(timeFrame))); + } + + static long truncateTimeframe(long timestampInMs) { + return timestampInMs - (timestampInMs % TIME_INTERVAL_MS); + } +} diff --git a/client/src/main/java/io/split/client/impressions/ImpressionObserver.java b/client/src/main/java/io/split/client/impressions/ImpressionObserver.java index b85a91da7..38164a445 100644 --- a/client/src/main/java/io/split/client/impressions/ImpressionObserver.java +++ b/client/src/main/java/io/split/client/impressions/ImpressionObserver.java @@ -3,7 +3,8 @@ import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import io.split.client.dtos.KeyImpression; -import org.apache.http.annotation.NotThreadSafe; + +import java.util.Objects; /* According to guava's docs (https://guava.dev/releases/18.0/api/docs/com/google/common/annotations/Beta.html), @@ -13,7 +14,6 @@ */ @SuppressWarnings("UnstableApiUsage") -@NotThreadSafe public class ImpressionObserver { private final Cache _cache; @@ -33,6 +33,6 @@ public Long testAndSet(KeyImpression impression) { Long hash = ImpressionHasher.process(impression); Long previous = _cache.getIfPresent(hash); _cache.put(hash, impression.time); - return previous; + return (Objects.isNull(previous)) ? null : Math.min(previous, impression.time); } } \ No newline at end of file diff --git a/client/src/test/java/io/split/client/impressions/ImpressionCounterTest.java b/client/src/test/java/io/split/client/impressions/ImpressionCounterTest.java new file mode 100644 index 000000000..a25c4e78e --- /dev/null +++ b/client/src/test/java/io/split/client/impressions/ImpressionCounterTest.java @@ -0,0 +1,117 @@ +package io.split.client.impressions; + +import org.junit.Test; + +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; + +public class ImpressionCounterTest { + + private long makeTimestamp(int year, int month, int day, int hour, int minute, int second) { + return ZonedDateTime.of(year, month, day, hour, minute, second, 0, ZoneId.of("UTC")).toInstant().toEpochMilli(); + } + + @Test + public void testTruncateTimeFrame() { + assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(2020, 9, 2, 10, 53, 12)), + is(equalTo(makeTimestamp(2020, 9, 2, 10, 0, 0)))); + assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(2020, 9, 2, 10, 0, 0)), + is(equalTo(makeTimestamp(2020, 9, 2, 10, 0, 0)))); + assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(2020, 9, 2, 10, 53, 0 )), + is(equalTo(makeTimestamp(2020, 9, 2, 10, 0, 0)))); + assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(2020, 9, 2, 10, 0, 12)), + is(equalTo(makeTimestamp(2020, 9, 2, 10, 0, 0)))); + assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(1970, 1, 1, 0, 0, 0)), + is(equalTo(makeTimestamp(1970, 1, 1, 0, 0, 0)))); + } + + @Test + public void testMakeKey() { + long targetTZ = makeTimestamp(2020, 9, 2, 10, 0, 0); + assertThat(ImpressionCounter.makeKey("someFeature", makeTimestamp(2020, 9, 2, 10, 5, 23)), + is(equalTo("someFeature::" + targetTZ))); + assertThat(ImpressionCounter.makeKey("", makeTimestamp(2020, 9, 2, 10, 5, 23)), + is(equalTo("::" + targetTZ))); + assertThat(ImpressionCounter.makeKey(null, makeTimestamp(2020, 9, 2, 10, 5, 23)), + is(equalTo("null::" + targetTZ))); + assertThat(ImpressionCounter.makeKey(null, 0L), is(equalTo("null::0"))); + } + + @Test + public void testBasicUsage() { + final ImpressionCounter counter = new ImpressionCounter(); + final long timestamp = makeTimestamp(2020, 9, 2, 10, 10, 12); + counter.inc("feature1", timestamp, 1); + counter.inc("feature1", timestamp + 1, 1); + counter.inc("feature1", timestamp + 2, 1); + counter.inc("feature2", timestamp + 3, 2); + counter.inc("feature2", timestamp + 4, 2); + Map counted = counter.popAll(); + assertThat(counted.size(), is(equalTo(2))); + assertThat(counted.get(ImpressionCounter.makeKey("feature1", timestamp)), is(equalTo(3))); + assertThat(counted.get(ImpressionCounter.makeKey("feature2", timestamp)), is(equalTo(4))); + assertThat(counter.popAll().size(), is(equalTo(0))); + + final long nextHourTimestamp = makeTimestamp(2020, 9, 2, 11, 10, 12); + counter.inc("feature1", timestamp, 1); + counter.inc("feature1", timestamp + 1, 1); + counter.inc("feature1", timestamp + 2, 1); + counter.inc("feature2", timestamp + 3, 2); + counter.inc("feature2", timestamp + 4, 2); + counter.inc("feature1", nextHourTimestamp, 1); + counter.inc("feature1", nextHourTimestamp + 1, 1); + counter.inc("feature1", nextHourTimestamp + 2, 1); + counter.inc("feature2", nextHourTimestamp + 3, 2); + counter.inc("feature2", nextHourTimestamp + 4, 2); + counted = counter.popAll(); + assertThat(counted.size(), is(equalTo(4))); + assertThat(counted.get(ImpressionCounter.makeKey("feature1", timestamp)), is(equalTo(3))); + assertThat(counted.get(ImpressionCounter.makeKey("feature2", timestamp)), is(equalTo(4))); + assertThat(counted.get(ImpressionCounter.makeKey("feature1", nextHourTimestamp)), is(equalTo(3))); + assertThat(counted.get(ImpressionCounter.makeKey("feature2", nextHourTimestamp)), is(equalTo(4))); + assertThat(counter.popAll().size(), is(equalTo(0))); + } + + @Test + public void manyConcurrentCalls() throws InterruptedException { + final int iterations = 10000000; + final long timestamp = makeTimestamp(2020, 9, 2, 10, 10, 12); + final long nextHourTimestamp = makeTimestamp(2020, 9, 2, 11, 10, 12); + ImpressionCounter counter = new ImpressionCounter(); + Thread t1 = new Thread(() -> { + int times = iterations; + while (times-- > 0) { + counter.inc("feature1", timestamp, 1); + counter.inc("feature2", timestamp, 1); + counter.inc("feature1", nextHourTimestamp, 2); + counter.inc("feature2", nextHourTimestamp, 2); + } + }); + Thread t2 = new Thread(() -> { + int times = iterations; + while (times-- > 0) { + counter.inc("feature1", timestamp, 2); + counter.inc("feature2", timestamp, 2); + counter.inc("feature1", nextHourTimestamp, 1); + counter.inc("feature2", nextHourTimestamp, 1); + } + }); + + t1.setDaemon(true); t2.setDaemon(true); + t1.start(); t2.start(); + t1.join(); t2.join(); + + HashMap counted = counter.popAll(); + assertThat(counted.size(), is(equalTo(4))); + assertThat(counted.get(ImpressionCounter.makeKey("feature1", timestamp)), is(equalTo(iterations * 3))); + assertThat(counted.get(ImpressionCounter.makeKey("feature2", timestamp)), is(equalTo(iterations * 3))); + assertThat(counted.get(ImpressionCounter.makeKey("feature1", nextHourTimestamp)), is(equalTo(iterations * 3))); + assertThat(counted.get(ImpressionCounter.makeKey("feature2", nextHourTimestamp)), is(equalTo(iterations * 3))); + } +} diff --git a/client/src/test/java/io/split/client/impressions/ImpressionObserverTest.java b/client/src/test/java/io/split/client/impressions/ImpressionObserverTest.java index 7ef190391..599d32797 100644 --- a/client/src/test/java/io/split/client/impressions/ImpressionObserverTest.java +++ b/client/src/test/java/io/split/client/impressions/ImpressionObserverTest.java @@ -2,18 +2,22 @@ import com.google.common.base.Strings; import io.split.client.dtos.KeyImpression; -// import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; +import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.Matchers.lessThan; -import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.core.AnyOf.anyOf; +import static org.hamcrest.core.Is.is; +import static org.hamcrest.core.IsEqual.equalTo; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; @@ -24,6 +28,7 @@ public class ImpressionObserverTest { // We allow the cache implementation to have a 0.01% drift in size when elements change, given that it's internal // structure/references might vary, and the ObjectSizeCalculator is not 100% accurate private static final double SIZE_DELTA = 0.01; + private final Random _rand = new Random(); private List generateKeyImpressions(long count) { ArrayList imps = new ArrayList<>(); @@ -91,7 +96,43 @@ public void testMemoryUsageStopsWhenCacheIsFull() throws Exception { long sizeAfterSecondPopulation = (long) getObjectSize.invoke(null, observer); - assertThat((double) (sizeAfterSecondPopulation - sizeAfterInitialPopulation), lessThan (SIZE_DELTA * sizeAfterInitialPopulation)); + assertThat((double) (sizeAfterSecondPopulation - sizeAfterInitialPopulation), lessThan(SIZE_DELTA * sizeAfterInitialPopulation)); + } + + + private void caller(ImpressionObserver o, int count, ConcurrentLinkedQueue imps) { + while (count-- > 0) { + KeyImpression k = new KeyImpression(); + k.keyName = "key_" + _rand.nextInt(100); + k.feature = "feature_" + _rand.nextInt(10); + k.label = "label" + _rand.nextInt(5); + k.treatment = _rand.nextBoolean() ? "on" : "off"; + k.changeNumber = 1234567L; + k.time = System.currentTimeMillis(); + k.pt = o.testAndSet(k); + imps.offer(k); + } + } + + @Test + public void testConcurrencyVsAccuracy() throws InterruptedException { + ImpressionObserver observer = new ImpressionObserver(500000); + ConcurrentLinkedQueue imps = new ConcurrentLinkedQueue<>(); + Thread t1 = new Thread(() -> caller(observer, 1000000, imps)); + Thread t2 = new Thread(() -> caller(observer, 1000000, imps)); + Thread t3 = new Thread(() -> caller(observer, 1000000, imps)); + Thread t4 = new Thread(() -> caller(observer, 1000000, imps)); + Thread t5 = new Thread(() -> caller(observer, 1000000, imps)); + + // start the 5 threads an wait for them to finish. + t1.setDaemon(true); t2.setDaemon(true); t3.setDaemon(true); t4.setDaemon(true); t5.setDaemon(true); + t1.start(); t2.start(); t3.start(); t4.start(); t5.start(); + t1.join(); t2.join(); t3.join(); t4.join(); t5.join(); + + assertThat(imps.size(), is(equalTo(5000000))); + for (KeyImpression i : imps) { + assertThat(i.pt, is(anyOf(nullValue(), lessThanOrEqualTo(i.time)))); + } } }