Skip to content
This repository has been archived by the owner on Sep 29, 2021. It is now read-only.

Commit

Permalink
Move master endpoint validation logic to Endpoint
Browse files Browse the repository at this point in the history
from DefaultRequestDispatcher.

Remove VersionCommandTest.testVersionWithFailedConnection().
Endpoint validation happens before RequestDispatcher is called.

Instead of showing the user the error message "Unable to connect to
master," we print a more helpful error message about what kind of master
endpoints are valid to stderr.
  • Loading branch information
davidxia committed Nov 17, 2015
1 parent 14c0d03 commit d6a6b4f
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 30 deletions.
Expand Up @@ -20,7 +20,6 @@
import com.google.common.base.Joiner;
import com.google.common.base.Supplier;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.google.common.collect.Queues;
import com.google.common.util.concurrent.ListenableFuture;
Expand Down Expand Up @@ -77,9 +76,6 @@ class DefaultRequestDispatcher implements RequestDispatcher {

private static final long RETRY_TIMEOUT_MILLIS = SECONDS.toMillis(60);
private static final long HTTP_TIMEOUT_MILLIS = SECONDS.toMillis(10);
private static final List<String> VALID_PROTOCOLS = ImmutableList.of("http", "https");
private static final String VALID_PROTOCOLS_STR =
String.format("[%s]", Joiner.on("|").join(VALID_PROTOCOLS));

private final Iterator<Endpoint> endpointIterator;
private final ListeningExecutorService executorService;
Expand Down Expand Up @@ -174,23 +170,16 @@ private HttpURLConnection connect(final URI uri, final String method, final byte
final URI endpointUri = endpoint.getUri();
final String fullpath = endpointUri.getPath() + uri.getPath();

final String scheme = endpointUri.getScheme();
final String host = endpointUri.getHost();
final int port = endpointUri.getPort();
if (!VALID_PROTOCOLS.contains(scheme) || host == null || port == -1) {
throw new HeliosException(String.format(
"Master endpoints must be of the form \"%s://heliosmaster.domain.net:<port>\"",
VALID_PROTOCOLS_STR));
}
final String uriScheme = endpointUri.getScheme();

final URI ipUri = new URI(
scheme, endpointUri.getUserInfo(), endpoint.getIp().getHostAddress(),
port, fullpath, uri.getQuery(), null);
uriScheme, endpointUri.getUserInfo(), endpoint.getIp().getHostAddress(),
endpointUri.getPort(), fullpath, uri.getQuery(), null);

AgentProxy agentProxy = null;
Deque<Identity> identities = Queues.newArrayDeque();
try {
if (scheme.equals("https")) {
if (uriScheme.equals("https")) {
agentProxy = AgentProxies.newInstance();
for (final Identity identity : agentProxy.list()) {
if (identity.getPublicKey().getAlgorithm().equals("RSA")) {
Expand Down
Expand Up @@ -28,6 +28,9 @@ public interface Endpoint {

/**
* Returns the {@link URI} of the endpoint.
* A valid URI for Helios must have a scheme that's either http or https,
* a hostname, and a port. I.e. it must be of the form http(s)://heliosmaster.domain.net:port.
* It's up to the implementation to enforce this.
* @return URI
*/
URI getUri();
Expand Down
Expand Up @@ -17,6 +17,7 @@

package com.spotify.helios.client;

import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
Expand All @@ -31,6 +32,8 @@
import java.net.UnknownHostException;
import java.util.List;

import static com.google.common.base.Preconditions.checkNotNull;

/**
* A class that provides static factory methods for {@link Endpoint}.
*/
Expand Down Expand Up @@ -100,12 +103,25 @@ static List<Endpoint> of(final List<URI> uris, final DnsResolver dnsResolver) {

private static class DefaultEndpoint implements Endpoint {

private static final List<String> VALID_PROTOCOLS = ImmutableList.of("http", "https");
private static final String VALID_PROTOCOLS_STR =
String.format("[%s]", Joiner.on("|").join(VALID_PROTOCOLS));

private final InetAddress ip;
private final URI uri;

DefaultEndpoint(final URI uri, final InetAddress ip) {
this.uri = uri;
this.ip = ip;
this.uri = checkNotNull(uri);
this.ip = checkNotNull(ip);

final String scheme = this.uri.getScheme();
final String host = this.uri.getHost();
final int port = this.uri.getPort();
if (!VALID_PROTOCOLS.contains(scheme) || host == null || port == -1) {
throw new IllegalArgumentException(String.format(
"Master endpoints must be of the form \"%s://heliosmaster.domain.net:<port>\"",
VALID_PROTOCOLS_STR));
}
}

@Override
Expand Down
Expand Up @@ -24,7 +24,9 @@

import org.apache.http.conn.DnsResolver;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.net.InetAddress;
import java.net.URI;
Expand All @@ -50,10 +52,13 @@ public class EndpointsTest {
private static URI uri2;
private static List<URI> uris;

@Rule
public final ExpectedException exception = ExpectedException.none();

@Before
public void setup() throws Exception {
uri1 = new URI("http://example.com");
uri2 = new URI("https://example.net");
uri1 = new URI("http://example.com:80");
uri2 = new URI("https://example.net:8080");
uris = ImmutableList.of(uri1, uri2);
}

Expand Down Expand Up @@ -104,4 +109,20 @@ public void testUnableToResolve() throws Exception {

assertThat(endpoints.size(), equalTo(0));
}

@Test
public void testInvalidUri_NoScheme() throws Exception {
final DnsResolver resolver = mock(DnsResolver.class);
when(resolver.resolve("example.com")).thenReturn(IPS_1);
exception.expect(IllegalArgumentException.class);
Endpoints.of(ImmutableList.of(new URI(null, "example.com", null, null)), resolver);
}

@Test
public void testInvalidUri_NoPort() throws Exception {
final DnsResolver resolver = mock(DnsResolver.class);
when(resolver.resolve("example.com")).thenReturn(IPS_1);
exception.expect(IllegalArgumentException.class);
Endpoints.of(ImmutableList.of(new URI("http", "example.com", null, null)), resolver);
}
}
Expand Up @@ -47,17 +47,6 @@ public void testJsonVersion() throws Exception {
assertEquals("wrong master version", POM_VERSION, version.getMasterVersion());
}

@Test
public void testVersionWithFailedConnection() throws Exception {
startDefaultMaster();
// If we fail to connect to master, we should still get the correct client version, and a nice
// error message instead of master version. Specify bogus endpoint to make this happen.
final VersionResponse version = getVersion("version", "--json", "-z", "-1");
assertEquals("wrong client version", POM_VERSION, version.getClientVersion());
assertEquals("wrong master version", "Unable to connect to master",
version.getMasterVersion());
}

@Test
public void testVersionWithServerError() throws Exception {
startDefaultMaster();
Expand Down

0 comments on commit d6a6b4f

Please sign in to comment.