Skip to content

Commit

Permalink
Initial support for client-side caching (#3658)
Browse files Browse the repository at this point in the history
  • Loading branch information
sazzad16 committed Dec 28, 2023
1 parent 77d52ab commit da9c463
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 2 deletions.
58 changes: 58 additions & 0 deletions src/main/java/redis/clients/jedis/ClientSideCache.java
@@ -0,0 +1,58 @@
package redis.clients.jedis;

import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import redis.clients.jedis.exceptions.JedisException;
import redis.clients.jedis.util.SafeEncoder;

public class ClientSideCache {

private final Map<ByteBuffer, Object> cache = new HashMap<>();

protected ClientSideCache() {
}

protected void invalidateKeys(List list) {
if (list == null) {
cache.clear();
return;
}

list.forEach(this::invalidateKey);
}

private void invalidateKey(Object key) {
if (key instanceof byte[]) {
cache.remove(convertKey((byte[]) key));
} else {
throw new JedisException("" + key.getClass().getSimpleName() + " is not supported. Value: " + String.valueOf(key));
}
}

protected void setKey(Object key, Object value) {
cache.put(getMapKey(key), value);
}

protected <T> T getValue(Object key) {
return (T) getMapValue(key);
}

private Object getMapValue(Object key) {
return cache.get(getMapKey(key));
}

private ByteBuffer getMapKey(Object key) {
if (key instanceof byte[]) {
return convertKey((byte[]) key);
} else {
return convertKey(SafeEncoder.encode(String.valueOf(key)));
}
}

private static ByteBuffer convertKey(byte[] b) {
return ByteBuffer.wrap(b);
}
}
21 changes: 20 additions & 1 deletion src/main/java/redis/clients/jedis/Connection.java
Expand Up @@ -34,6 +34,7 @@ public class Connection implements Closeable {
private Socket socket;
private RedisOutputStream outputStream;
private RedisInputStream inputStream;
private ClientSideCache clientSideCache;
private int soTimeout = 0;
private int infiniteSoTimeout = 0;
private boolean broken = false;
Expand Down Expand Up @@ -121,6 +122,10 @@ public void rollbackTimeout() {
}
}

final void setClientSideCache(ClientSideCache clientSideCache) {
this.clientSideCache = clientSideCache;
}

public Object executeCommand(final ProtocolCommand cmd) {
return executeCommand(new CommandArguments(cmd));
}
Expand Down Expand Up @@ -347,9 +352,10 @@ protected Object readProtocolWithCheckingBroken() {
}

try {
Protocol.readPushes(inputStream, clientSideCache);
return Protocol.read(inputStream);
// Object read = Protocol.read(inputStream);
// System.out.println(SafeEncoder.encodeObject(read));
// System.out.println("REPLY: " + SafeEncoder.encodeObject(read));
// return read;
} catch (JedisConnectionException exc) {
broken = true;
Expand All @@ -370,6 +376,19 @@ public List<Object> getMany(final int count) {
return responses;
}

protected void readPushesWithCheckingBroken() {
if (broken) {
throw new JedisConnectionException("Attempting to read pushes from a broken connection");
}

try {
Protocol.readPushes(inputStream, clientSideCache);
} catch (JedisConnectionException exc) {
broken = true;
throw exc;
}
}

/**
* Check if the client name libname, libver, characters are legal
* @param info the name
Expand Down
45 changes: 45 additions & 0 deletions src/main/java/redis/clients/jedis/JedisClientSideCache.java
@@ -0,0 +1,45 @@
package redis.clients.jedis;

import redis.clients.jedis.exceptions.JedisException;

public class JedisClientSideCache extends Jedis {

private final ClientSideCache cache;

public JedisClientSideCache(final HostAndPort hostPort, final JedisClientConfig config) {
this(hostPort, config, new ClientSideCache());
}

public JedisClientSideCache(final HostAndPort hostPort, final JedisClientConfig config,
ClientSideCache cache) {
super(hostPort, config);
if (config.getRedisProtocol() != RedisProtocol.RESP3) {
throw new JedisException("Client side caching is only supported with RESP3.");
}

this.cache = cache;
this.connection.setClientSideCache(cache);
clientTrackingOn();
}

private void clientTrackingOn() {
String reply = connection.executeCommand(new CommandObject<>(
new CommandArguments(Protocol.Command.CLIENT).add("TRACKING").add("ON").add("BCAST"),
BuilderFactory.STRING));
if (!"OK".equals(reply)) {
throw new JedisException("Could not enable client tracking. Reply: " + reply);
}
}

@Override
public String get(String key) {
connection.readPushesWithCheckingBroken();
String cachedValue = cache.getValue(key);
if (cachedValue != null) return cachedValue;

String value = super.get(key);
if (value != null) cache.setKey(key, value);
return value;
}

}
27 changes: 26 additions & 1 deletion src/main/java/redis/clients/jedis/Protocol.java
Expand Up @@ -4,8 +4,10 @@
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

import redis.clients.jedis.exceptions.*;
import redis.clients.jedis.args.Rawable;
Expand Down Expand Up @@ -57,6 +59,8 @@ public final class Protocol {
private static final String WRONGPASS_PREFIX = "WRONGPASS";
private static final String NOPERM_PREFIX = "NOPERM";

private static final byte[] INVALIDATE_BYTES = SafeEncoder.encode("invalidate");

private Protocol() {
throw new InstantiationError("Must not instantiate this class");
}
Expand Down Expand Up @@ -133,7 +137,7 @@ private static String[] parseTargetHostAndSlot(String clusterRedirectResponse) {

private static Object process(final RedisInputStream is) {
final byte b = is.readByte();
//System.out.println((char) b);
//System.out.println("BYTE: " + (char) b);
switch (b) {
case PLUS_BYTE:
return is.readLineBytes();
Expand Down Expand Up @@ -167,6 +171,15 @@ private static Object process(final RedisInputStream is) {
}
}

private static void processPush(final RedisInputStream is, ClientSideCache cache) {
List<Object> list = processMultiBulkReply(is);
//System.out.println("PUSH: " + SafeEncoder.encodeObject(list));
if (list.size() == 2 && list.get(0) instanceof byte[]
&& Arrays.equals(INVALIDATE_BYTES, (byte[]) list.get(0))) {
cache.invalidateKeys((List) list.get(1));
}
}

private static byte[] processBulkReply(final RedisInputStream is) {
final int len = is.readIntCrLf();
if (len == -1) {
Expand All @@ -193,11 +206,13 @@ private static byte[] processBulkReply(final RedisInputStream is) {
private static List<Object> processMultiBulkReply(final RedisInputStream is) {
// private static List<Object> processMultiBulkReply(final int num, final RedisInputStream is) {
final int num = is.readIntCrLf();
//System.out.println("MULTI BULK: " + num);
if (num == -1) return null;
final List<Object> ret = new ArrayList<>(num);
for (int i = 0; i < num; i++) {
try {
ret.add(process(is));
//System.out.println("MULTI >> " + (i+1) + ": " + SafeEncoder.encodeObject(ret.get(i)));
} catch (JedisDataException e) {
ret.add(e);
}
Expand All @@ -221,6 +236,16 @@ public static Object read(final RedisInputStream is) {
return process(is);
}

static void readPushes(final RedisInputStream is, final ClientSideCache cache) {
if (cache != null) {
//System.out.println("PEEK: " + is.peekByte());
while (Objects.equals(GREATER_THAN_BYTE, is.peekByte())) {
is.readByte();
processPush(is, cache);
}
}
}

public static final byte[] toByteArray(final boolean value) {
return value ? BYTES_TRUE : BYTES_FALSE;
}
Expand Down
19 changes: 19 additions & 0 deletions src/main/java/redis/clients/jedis/util/RedisInputStream.java
Expand Up @@ -43,6 +43,11 @@ public RedisInputStream(InputStream in) {
this(in, INPUT_BUFFER_SIZE);
}

public Byte peekByte() {
ensureFillSafe();
return buf[count];
}

public byte readByte() throws JedisConnectionException {
ensureFill();
return buf[count++];
Expand Down Expand Up @@ -252,4 +257,18 @@ private void ensureFill() throws JedisConnectionException {
}
}
}

private void ensureFillSafe() {
if (count >= limit) {
try {
limit = in.read(buf);
count = 0;
if (limit == -1) {
throw new JedisConnectionException("Unexpected end of stream.");
}
} catch (IOException e) {
// do nothing
}
}
}
}
70 changes: 70 additions & 0 deletions src/test/java/redis/clients/jedis/JedisClientSideCacheTest.java
@@ -0,0 +1,70 @@
package redis.clients.jedis;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.InOrder;
import org.mockito.Mockito;

public class JedisClientSideCacheTest {

protected static final HostAndPort hnp = HostAndPorts.getRedisServers().get(1);

protected Jedis jedis;

@Before
public void setUp() throws Exception {
jedis = new Jedis(hnp, DefaultJedisClientConfig.builder().timeoutMillis(500).password("foobared").build());
jedis.flushAll();
}

@After
public void tearDown() throws Exception {
jedis.close();
}

private static final JedisClientConfig configForCache = DefaultJedisClientConfig.builder()
.resp3().socketTimeoutMillis(20).password("foobared").build();

@Test
public void simple() {
try (JedisClientSideCache jCache = new JedisClientSideCache(hnp, configForCache)) {
jedis.set("foo", "bar");
assertEquals("bar", jCache.get("foo"));
jedis.del("foo");
assertNull(jCache.get("foo"));
}
}

@Test
public void simpleMock() {
ClientSideCache cache = Mockito.mock(ClientSideCache.class);
try (JedisClientSideCache jCache = new JedisClientSideCache(hnp, configForCache, cache)) {
jedis.set("foo", "bar");
assertEquals("bar", jCache.get("foo"));
jedis.del("foo");
assertNull(jCache.get("foo"));
}

InOrder inOrder = Mockito.inOrder(cache);
inOrder.verify(cache).invalidateKeys(Mockito.notNull());
inOrder.verify(cache).getValue("foo");
inOrder.verify(cache).setKey("foo", "bar");
inOrder.verify(cache).invalidateKeys(Mockito.notNull());
inOrder.verify(cache).getValue("foo");
inOrder.verifyNoMoreInteractions();
}

@Test
public void flushall() {
try (JedisClientSideCache jCache = new JedisClientSideCache(hnp, configForCache)) {
jedis.set("foo", "bar");
assertEquals("bar", jCache.get("foo"));
jedis.flushAll();
assertNull(jCache.get("foo"));
}
}
}

0 comments on commit da9c463

Please sign in to comment.