Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package io.split.client.impressions;

import java.util.HashMap;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class ImpressionCounter {

private static final long TIME_INTERVAL_MS = 3600L * 1000L;

private final ConcurrentHashMap<String, AtomicInteger> _counts;

public ImpressionCounter() {
_counts = new ConcurrentHashMap<>();
}

public void inc(String featureName, long timeFrame, int amount) {
String key = makeKey(featureName, timeFrame);
AtomicInteger count = _counts.get(key);
if (Objects.isNull(count)) {
count = new AtomicInteger();
AtomicInteger old = _counts.putIfAbsent(key, count);
if (!Objects.isNull(old)) { // Some other thread won the race, use that AtomicInteger instead
count = old;
}
}
count.addAndGet(amount);
}

public HashMap<String, Integer> popAll() {
HashMap<String, Integer> toReturn = new HashMap<>();
for (String key : _counts.keySet()) {
AtomicInteger curr = _counts.remove(key);
toReturn.put(key ,curr.get());
}
return toReturn;
}

static String makeKey(String featureName, long timeFrame) {
return String.join("::", featureName, String.valueOf(truncateTimeframe(timeFrame)));
}

static long truncateTimeframe(long timestampInMs) {
return timestampInMs - (timestampInMs % TIME_INTERVAL_MS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.split.client.dtos.KeyImpression;
import org.apache.http.annotation.NotThreadSafe;

import java.util.Objects;

/*
According to guava's docs (https://guava.dev/releases/18.0/api/docs/com/google/common/annotations/Beta.html),
Expand All @@ -13,7 +14,6 @@
*/

@SuppressWarnings("UnstableApiUsage")
@NotThreadSafe
public class ImpressionObserver {

private final Cache<Long, Long> _cache;
Expand All @@ -33,6 +33,6 @@ public Long testAndSet(KeyImpression impression) {
Long hash = ImpressionHasher.process(impression);
Long previous = _cache.getIfPresent(hash);
_cache.put(hash, impression.time);
return previous;
return (Objects.isNull(previous)) ? null : Math.min(previous, impression.time);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package io.split.client.impressions;

import org.junit.Test;

import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.IsEqual.equalTo;

public class ImpressionCounterTest {

private long makeTimestamp(int year, int month, int day, int hour, int minute, int second) {
return ZonedDateTime.of(year, month, day, hour, minute, second, 0, ZoneId.of("UTC")).toInstant().toEpochMilli();
}

@Test
public void testTruncateTimeFrame() {
assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(2020, 9, 2, 10, 53, 12)),
is(equalTo(makeTimestamp(2020, 9, 2, 10, 0, 0))));
assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(2020, 9, 2, 10, 0, 0)),
is(equalTo(makeTimestamp(2020, 9, 2, 10, 0, 0))));
assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(2020, 9, 2, 10, 53, 0 )),
is(equalTo(makeTimestamp(2020, 9, 2, 10, 0, 0))));
assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(2020, 9, 2, 10, 0, 12)),
is(equalTo(makeTimestamp(2020, 9, 2, 10, 0, 0))));
assertThat(ImpressionCounter.truncateTimeframe(makeTimestamp(1970, 1, 1, 0, 0, 0)),
is(equalTo(makeTimestamp(1970, 1, 1, 0, 0, 0))));
}

@Test
public void testMakeKey() {
long targetTZ = makeTimestamp(2020, 9, 2, 10, 0, 0);
assertThat(ImpressionCounter.makeKey("someFeature", makeTimestamp(2020, 9, 2, 10, 5, 23)),
is(equalTo("someFeature::" + targetTZ)));
assertThat(ImpressionCounter.makeKey("", makeTimestamp(2020, 9, 2, 10, 5, 23)),
is(equalTo("::" + targetTZ)));
assertThat(ImpressionCounter.makeKey(null, makeTimestamp(2020, 9, 2, 10, 5, 23)),
is(equalTo("null::" + targetTZ)));
assertThat(ImpressionCounter.makeKey(null, 0L), is(equalTo("null::0")));
}

@Test
public void testBasicUsage() {
final ImpressionCounter counter = new ImpressionCounter();
final long timestamp = makeTimestamp(2020, 9, 2, 10, 10, 12);
counter.inc("feature1", timestamp, 1);
counter.inc("feature1", timestamp + 1, 1);
counter.inc("feature1", timestamp + 2, 1);
counter.inc("feature2", timestamp + 3, 2);
counter.inc("feature2", timestamp + 4, 2);
Map<String, Integer> counted = counter.popAll();
assertThat(counted.size(), is(equalTo(2)));
assertThat(counted.get(ImpressionCounter.makeKey("feature1", timestamp)), is(equalTo(3)));
assertThat(counted.get(ImpressionCounter.makeKey("feature2", timestamp)), is(equalTo(4)));
assertThat(counter.popAll().size(), is(equalTo(0)));

final long nextHourTimestamp = makeTimestamp(2020, 9, 2, 11, 10, 12);
counter.inc("feature1", timestamp, 1);
counter.inc("feature1", timestamp + 1, 1);
counter.inc("feature1", timestamp + 2, 1);
counter.inc("feature2", timestamp + 3, 2);
counter.inc("feature2", timestamp + 4, 2);
counter.inc("feature1", nextHourTimestamp, 1);
counter.inc("feature1", nextHourTimestamp + 1, 1);
counter.inc("feature1", nextHourTimestamp + 2, 1);
counter.inc("feature2", nextHourTimestamp + 3, 2);
counter.inc("feature2", nextHourTimestamp + 4, 2);
counted = counter.popAll();
assertThat(counted.size(), is(equalTo(4)));
assertThat(counted.get(ImpressionCounter.makeKey("feature1", timestamp)), is(equalTo(3)));
assertThat(counted.get(ImpressionCounter.makeKey("feature2", timestamp)), is(equalTo(4)));
assertThat(counted.get(ImpressionCounter.makeKey("feature1", nextHourTimestamp)), is(equalTo(3)));
assertThat(counted.get(ImpressionCounter.makeKey("feature2", nextHourTimestamp)), is(equalTo(4)));
assertThat(counter.popAll().size(), is(equalTo(0)));
}

@Test
public void manyConcurrentCalls() throws InterruptedException {
final int iterations = 10000000;
final long timestamp = makeTimestamp(2020, 9, 2, 10, 10, 12);
final long nextHourTimestamp = makeTimestamp(2020, 9, 2, 11, 10, 12);
ImpressionCounter counter = new ImpressionCounter();
Thread t1 = new Thread(() -> {
int times = iterations;
while (times-- > 0) {
counter.inc("feature1", timestamp, 1);
counter.inc("feature2", timestamp, 1);
counter.inc("feature1", nextHourTimestamp, 2);
counter.inc("feature2", nextHourTimestamp, 2);
}
});
Thread t2 = new Thread(() -> {
int times = iterations;
while (times-- > 0) {
counter.inc("feature1", timestamp, 2);
counter.inc("feature2", timestamp, 2);
counter.inc("feature1", nextHourTimestamp, 1);
counter.inc("feature2", nextHourTimestamp, 1);
}
});

t1.setDaemon(true); t2.setDaemon(true);
t1.start(); t2.start();
t1.join(); t2.join();

HashMap<String, Integer> counted = counter.popAll();
assertThat(counted.size(), is(equalTo(4)));
assertThat(counted.get(ImpressionCounter.makeKey("feature1", timestamp)), is(equalTo(iterations * 3)));
assertThat(counted.get(ImpressionCounter.makeKey("feature2", timestamp)), is(equalTo(iterations * 3)));
assertThat(counted.get(ImpressionCounter.makeKey("feature1", nextHourTimestamp)), is(equalTo(iterations * 3)));
assertThat(counted.get(ImpressionCounter.makeKey("feature2", nextHourTimestamp)), is(equalTo(iterations * 3)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@

import com.google.common.base.Strings;
import io.split.client.dtos.KeyImpression;
// import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentLinkedQueue;

import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.core.AnyOf.anyOf;
import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;

Expand All @@ -24,6 +28,7 @@ public class ImpressionObserverTest {
// We allow the cache implementation to have a 0.01% drift in size when elements change, given that it's internal
// structure/references might vary, and the ObjectSizeCalculator is not 100% accurate
private static final double SIZE_DELTA = 0.01;
private final Random _rand = new Random();

private List<KeyImpression> generateKeyImpressions(long count) {
ArrayList<KeyImpression> imps = new ArrayList<>();
Expand Down Expand Up @@ -91,7 +96,43 @@ public void testMemoryUsageStopsWhenCacheIsFull() throws Exception {

long sizeAfterSecondPopulation = (long) getObjectSize.invoke(null, observer);

assertThat((double) (sizeAfterSecondPopulation - sizeAfterInitialPopulation), lessThan (SIZE_DELTA * sizeAfterInitialPopulation));
assertThat((double) (sizeAfterSecondPopulation - sizeAfterInitialPopulation), lessThan(SIZE_DELTA * sizeAfterInitialPopulation));
}


private void caller(ImpressionObserver o, int count, ConcurrentLinkedQueue<KeyImpression> imps) {

while (count-- > 0) {
KeyImpression k = new KeyImpression();
k.keyName = "key_" + _rand.nextInt(100);
k.feature = "feature_" + _rand.nextInt(10);
k.label = "label" + _rand.nextInt(5);
k.treatment = _rand.nextBoolean() ? "on" : "off";
k.changeNumber = 1234567L;
k.time = System.currentTimeMillis();
k.pt = o.testAndSet(k);
imps.offer(k);
}
}

@Test
public void testConcurrencyVsAccuracy() throws InterruptedException {
ImpressionObserver observer = new ImpressionObserver(500000);
ConcurrentLinkedQueue<KeyImpression> imps = new ConcurrentLinkedQueue<>();
Thread t1 = new Thread(() -> caller(observer, 1000000, imps));
Thread t2 = new Thread(() -> caller(observer, 1000000, imps));
Thread t3 = new Thread(() -> caller(observer, 1000000, imps));
Thread t4 = new Thread(() -> caller(observer, 1000000, imps));
Thread t5 = new Thread(() -> caller(observer, 1000000, imps));

// start the 5 threads an wait for them to finish.
t1.setDaemon(true); t2.setDaemon(true); t3.setDaemon(true); t4.setDaemon(true); t5.setDaemon(true);
t1.start(); t2.start(); t3.start(); t4.start(); t5.start();
t1.join(); t2.join(); t3.join(); t4.join(); t5.join();

assertThat(imps.size(), is(equalTo(5000000)));
for (KeyImpression i : imps) {
assertThat(i.pt, is(anyOf(nullValue(), lessThanOrEqualTo(i.time))));
}
}
}