Skip to content

Commit

Permalink
Add fallback to RESP2 upon NOPROTO response #2455
Browse files Browse the repository at this point in the history
  • Loading branch information
mp911de committed Jul 25, 2023
1 parent 20e6720 commit 3e92091
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 10 deletions.
27 changes: 17 additions & 10 deletions src/main/java/io/lettuce/core/RedisHandshake.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private CompletionStage<?> tryHandshakeResp3(Channel channel) {
}

if (throwable != null) {
if (isUnknownCommand(throwable)) {
if (isUnknownCommand(throwable) || isNoProto(throwable)) {
try {
fallbackToResp2(channel, handshake);
} catch (Exception e) {
Expand All @@ -115,6 +115,7 @@ private CompletionStage<?> tryHandshakeResp3(Channel channel) {
handshake.completeExceptionally(throwable);
}
} else {
onHelloResponse(settings);
handshake.complete(null);
}
});
Expand Down Expand Up @@ -145,19 +146,20 @@ private CompletableFuture<?> initializeResp2(Channel channel) {
}

private CompletionStage<Void> initializeResp3(Channel channel) {
return initiateHandshakeResp3(channel, connectionState.getCredentialsProvider()).thenAccept(this::onHelloResponse);
}

return initiateHandshakeResp3(channel, connectionState.getCredentialsProvider()).thenAccept(response -> {
private void onHelloResponse(Map<String, Object> response) {

Long id = (Long) response.get("id");
String mode = (String) response.get("mode");
String version = (String) response.get("version");
String role = (String) response.get("role");
Long id = (Long) response.get("id");
String mode = (String) response.get("mode");
String version = (String) response.get("version");
String role = (String) response.get("role");

negotiatedProtocolVersion = ProtocolVersion.RESP3;
negotiatedProtocolVersion = ProtocolVersion.RESP3;

connectionState.setHandshakeResponse(
new ConnectionState.HandshakeResponse(negotiatedProtocolVersion, id, version, mode, role));
});
connectionState.setHandshakeResponse(
new ConnectionState.HandshakeResponse(negotiatedProtocolVersion, id, version, mode, role));
}

/**
Expand Down Expand Up @@ -272,4 +274,9 @@ private static boolean isUnknownCommand(Throwable error) {
&& ((error.getMessage().startsWith("ERR") && error.getMessage().contains("unknown")));
}

private static boolean isNoProto(Throwable error) {
return error instanceof RedisException && LettuceStrings.isNotEmpty(error.getMessage())
&& error.getMessage().startsWith("NOPROTO");
}

}
104 changes: 104 additions & 0 deletions src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.lettuce.core;

import java.nio.ByteBuffer;
import java.util.Map;

import io.lettuce.core.output.CommandOutput;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.ProtocolVersion;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.*;

/**
* Unit tests for {@link RedisHandshake}.
*
* @author Mark Paluch
*/
class RedisHandshakeUnitTests {

@Test
void handshakeWithResp3ShouldPass() {

EmbeddedChannel channel = new EmbeddedChannel(true, false);

ConnectionState state = new ConnectionState();
state.setCredentialsProvider(new StaticCredentialsProvider("foo", "bar".toCharArray()));
RedisHandshake handshake = new RedisHandshake(ProtocolVersion.RESP3, false, state);
handshake.initialize(channel);

AsyncCommand<String, String, Map<String, String>> hello = channel.readOutbound();
helloResponse(hello.getOutput());
hello.complete();

assertThat(state.getNegotiatedProtocolVersion()).isEqualTo(ProtocolVersion.RESP3);
}

@Test
void handshakeWithDiscoveryShouldPass() {

EmbeddedChannel channel = new EmbeddedChannel(true, false);

ConnectionState state = new ConnectionState();
state.setCredentialsProvider(new StaticCredentialsProvider("foo", "bar".toCharArray()));
RedisHandshake handshake = new RedisHandshake(null, false, state);
handshake.initialize(channel);

AsyncCommand<String, String, Map<String, String>> hello = channel.readOutbound();
helloResponse(hello.getOutput());
hello.complete();

assertThat(state.getNegotiatedProtocolVersion()).isEqualTo(ProtocolVersion.RESP3);
}

@Test
void handshakeWithDiscoveryShouldDowngrade() {

EmbeddedChannel channel = new EmbeddedChannel(true, false);

ConnectionState state = new ConnectionState();
state.setCredentialsProvider(new StaticCredentialsProvider(null, null));
RedisHandshake handshake = new RedisHandshake(null, false, state);
handshake.initialize(channel);

AsyncCommand<String, String, Map<String, String>> hello = channel.readOutbound();
hello.getOutput().setError("NOPROTO");
hello.completeExceptionally(new RedisException("NOPROTO"));
hello.complete();

assertThat(state.getNegotiatedProtocolVersion()).isEqualTo(ProtocolVersion.RESP2);
}

private static void helloResponse(CommandOutput<String, String, Map<String, String>> output) {

output.multi(8);
output.set(ByteBuffer.wrap("id".getBytes()));
output.set(1);

output.set(ByteBuffer.wrap("mode".getBytes()));
output.set(ByteBuffer.wrap("master".getBytes()));

output.set(ByteBuffer.wrap("role".getBytes()));
output.set(ByteBuffer.wrap("master".getBytes()));

output.set(ByteBuffer.wrap("version".getBytes()));
output.set(ByteBuffer.wrap("1.2.3".getBytes()));
}

}

0 comments on commit 3e92091

Please sign in to comment.