Skip to content
This repository has been archived by the owner on Nov 22, 2023. It is now read-only.

Commit

Permalink
Block bad requests sooner with preconditions
Browse files Browse the repository at this point in the history
Adds precondition checks to KeywhizClient, to block bad requests sooner.
Move base64'ing content into KeywhizClient instead of CLI.
  • Loading branch information
Justin Cummins committed Jun 10, 2015
1 parent 69bc717 commit 554221b
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 60 deletions.
5 changes: 5 additions & 0 deletions api/pom.xml
Expand Up @@ -53,6 +53,11 @@
<artifactId>dropwizard-testing</artifactId> <artifactId>dropwizard-testing</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>javax.inject</groupId>
<artifactId>javax.inject</artifactId>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>com.squareup.keywhiz</groupId> <groupId>com.squareup.keywhiz</groupId>
<artifactId>keywhiz-testing</artifactId> <artifactId>keywhiz-testing</artifactId>
Expand Down
31 changes: 31 additions & 0 deletions api/src/test/java/keywhiz/api/CreateClientRequestTest.java
Expand Up @@ -16,17 +16,48 @@


package keywhiz.api; package keywhiz.api;


import io.dropwizard.testing.junit.ResourceTestRule;
import javax.validation.ConstraintViolationException;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.ProcessingException;
import org.junit.ClassRule;
import org.junit.Test; import org.junit.Test;


import static javax.ws.rs.client.Entity.entity;
import static keywhiz.testing.JsonHelpers.fromJson; import static keywhiz.testing.JsonHelpers.fromJson;
import static keywhiz.testing.JsonHelpers.jsonFixture; import static keywhiz.testing.JsonHelpers.jsonFixture;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.catchThrowable;


public class CreateClientRequestTest { public class CreateClientRequestTest {
@ClassRule public static final ResourceTestRule resources = ResourceTestRule.builder()
.addResource(new Resource())
.build();

@Test public void deserializesCorrectly() throws Exception { @Test public void deserializesCorrectly() throws Exception {
CreateClientRequest createClientRequest = new CreateClientRequest("client-name"); CreateClientRequest createClientRequest = new CreateClientRequest("client-name");
assertThat(fromJson( assertThat(fromJson(
jsonFixture("fixtures/createClientRequest.json"), CreateClientRequest.class)) jsonFixture("fixtures/createClientRequest.json"), CreateClientRequest.class))
.isEqualTo(createClientRequest); .isEqualTo(createClientRequest);
} }

@Test public void emptyNameFailsValidation() throws Exception {
CreateClientRequest createClientRequest = new CreateClientRequest("");
Throwable exception = catchThrowable(() ->
resources.client().target("/").request()
.post(entity(createClientRequest, "application/json")));

assertThat(exception)
.isInstanceOf(ProcessingException.class)
.hasCauseInstanceOf(ConstraintViolationException.class);
}

@Path("/") private static class Resource {
@POST @Consumes("application/json") public String method(@Valid CreateClientRequest request) {
throw new UnsupportedOperationException();
}
}
} }
16 changes: 9 additions & 7 deletions cli/src/main/java/keywhiz/cli/commands/AddAction.java
Expand Up @@ -18,14 +18,12 @@


import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.text.ParseException; import java.text.ParseException;
import java.util.Base64;
import java.util.List; import java.util.List;
import keywhiz.api.SecretDetailResponse; import keywhiz.api.SecretDetailResponse;
import keywhiz.api.model.Group; import keywhiz.api.model.Group;
Expand Down Expand Up @@ -104,7 +102,7 @@ public AddAction(AddActionConfig addActionConfig, KeywhizClient client, ObjectMa
throw Throwables.propagate(e); throw Throwables.propagate(e);
} }
String secretName = parts[0]; String secretName = parts[0];
String content = readSecretContent(); byte[] content = readSecretContent();
ImmutableMap<String, String> metadata = getMetadata(); ImmutableMap<String, String> metadata = getMetadata();


String version = getVersion(parts); String version = getVersion(parts);
Expand Down Expand Up @@ -135,7 +133,7 @@ public AddAction(AddActionConfig addActionConfig, KeywhizClient client, ObjectMa
} }
} }


private void createAndAssignSecret(String secretName, String content, boolean useVersion, private void createAndAssignSecret(String secretName, byte[] content, boolean useVersion,
String version, ImmutableMap<String, String> metadata) { String version, ImmutableMap<String, String> metadata) {
try { try {
SecretDetailResponse secretResponse = keywhizClient.createSecret(secretName, "", content, SecretDetailResponse secretResponse = keywhizClient.createSecret(secretName, "", content,
Expand Down Expand Up @@ -178,11 +176,15 @@ private String getVersion(String[] parts) {
return version; return version;
} }


@VisibleForTesting String readSecretContent() { private byte[] readSecretContent() {
try { try {
return Base64.getEncoder().encodeToString(ByteStreams.toByteArray(stream)); byte[] content = ByteStreams.toByteArray(stream);
if (content.length == 0) {
throw new RuntimeException("Secret content empty!");
}
return content;
} catch (IOException e) { } catch (IOException e) {
throw new AssertionError(e); throw Throwables.propagate(e);
} }
} }


Expand Down
68 changes: 32 additions & 36 deletions cli/src/test/java/keywhiz/cli/commands/AddActionTest.java
Expand Up @@ -72,43 +72,39 @@ public void addCallsAddForGroup() throws Exception {
addActionConfig.addType = Arrays.asList("group"); addActionConfig.addType = Arrays.asList("group");
addActionConfig.name = group.getName(); addActionConfig.name = group.getName();


when(keywhizClient.getGroupByName(group.getName())).thenThrow( when(keywhizClient.getGroupByName(group.getName())).thenThrow(new NotFoundException());
new NotFoundException());


addAction.run(); addAction.run();
verify(keywhizClient) verify(keywhizClient).createGroup(addActionConfig.name, null);
.createGroup(addActionConfig.name, null);
} }


@Test @Test
public void addCallsAddForSecret() throws Exception { public void addCallsAddForSecret() throws Exception {
addActionConfig.addType = Arrays.asList("secret"); addActionConfig.addType = Arrays.asList("secret");
addActionConfig.name = secret.getDisplayName(); addActionConfig.name = secret.getDisplayName();


addAction.stream = new ByteArrayInputStream(base64Decoder.decode(secret.getSecret())); byte[] content = base64Decoder.decode(secret.getSecret());
addAction.stream = new ByteArrayInputStream(content);
when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion())) when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion()))
.thenThrow(new NotFoundException()); // Call checks for existence. .thenThrow(new NotFoundException()); // Call checks for existence.


when(keywhizClient.createSecret(secret.getName(), "", secret.getSecret(), when(keywhizClient.createSecret(secret.getName(), "", content, true, secret.getMetadata()))
true, secret.getMetadata()))
.thenReturn(secretDetailResponse); .thenReturn(secretDetailResponse);


addAction.run(); addAction.run();
verify(keywhizClient, times(1)).createSecret(secret.getName(), "", secret.getSecret(), verify(keywhizClient, times(1))
true, secret.getMetadata()); .createSecret(secret.getName(), "", content, true, secret.getMetadata());
} }


@Test @Test
public void addCallsAddForClient() throws Exception { public void addCallsAddForClient() throws Exception {
addActionConfig.addType = Arrays.asList("client"); addActionConfig.addType = Arrays.asList("client");
addActionConfig.name = client.getName(); addActionConfig.name = client.getName();


when(keywhizClient.getClientByName(client.getName())).thenThrow( when(keywhizClient.getClientByName(client.getName())).thenThrow(new NotFoundException());
new NotFoundException());


addAction.run(); addAction.run();
verify(keywhizClient) verify(keywhizClient).createClient(addActionConfig.name);
.createClient(addActionConfig.name);
} }


@Test @Test
Expand All @@ -117,14 +113,15 @@ public void addSecretCanAssignGroup() throws Exception {
addActionConfig.name = secret.getDisplayName(); addActionConfig.name = secret.getDisplayName();
addActionConfig.group = group.getName(); addActionConfig.group = group.getName();


addAction.stream = new ByteArrayInputStream(base64Decoder.decode(secret.getSecret())); byte[] content = base64Decoder.decode(secret.getSecret());
addAction.stream = new ByteArrayInputStream(content);
when(keywhizClient.getGroupByName(group.getName())).thenReturn(group); when(keywhizClient.getGroupByName(group.getName())).thenReturn(group);


when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion())) when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion()))
.thenThrow(new NotFoundException()); // Call checks for existence. .thenThrow(new NotFoundException()); // Call checks for existence.


when(keywhizClient.createSecret(secret.getName(), "", secret.getSecret(), when(keywhizClient.createSecret(secret.getName(), "", content, true, secret.getMetadata()))
true, secret.getMetadata())).thenReturn(secretDetailResponse); .thenReturn(secretDetailResponse);


addAction.run(); addAction.run();
verify(keywhizClient).grantSecretToGroupByIds((int) secret.getId(), (int) group.getId()); verify(keywhizClient).grantSecretToGroupByIds((int) secret.getId(), (int) group.getId());
Expand All @@ -135,38 +132,37 @@ public void addCreatesWithoutVersionByDefault() throws Exception {
addActionConfig.addType = Arrays.asList("secret"); addActionConfig.addType = Arrays.asList("secret");
addActionConfig.name = secret.getName(); // Name without version addActionConfig.name = secret.getName(); // Name without version


addAction.stream = new ByteArrayInputStream(base64Decoder.decode(secret.getSecret())); byte[] content = base64Decoder.decode(secret.getSecret());
addAction.stream = new ByteArrayInputStream(content);


when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), "")) when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), ""))
.thenThrow(new NotFoundException()); // Call checks for existence. .thenThrow(new NotFoundException()); // Call checks for existence.


when(keywhizClient.createSecret(secret.getName(), "", secret.getSecret(), when(keywhizClient.createSecret(secret.getName(), "", content, false, secret.getMetadata()))
false, secret.getMetadata()))
.thenReturn(secretDetailResponse); .thenReturn(secretDetailResponse);


addAction.run(); addAction.run();


verify(keywhizClient, never()).createSecret(secret.getName(), "", secret.getSecret(), verify(keywhizClient, never()).createSecret(secret.getName(), "", content, true, secret.getMetadata());
true, secret.getMetadata());
} }


@Test @Test
public void addCreatesVersionedSecretWhenVersionInName() throws Exception { public void addCreatesVersionedSecretWhenVersionInName() throws Exception {
addActionConfig.addType = Arrays.asList("secret"); addActionConfig.addType = Arrays.asList("secret");
addActionConfig.name = secret.getDisplayName(); // Name includes version, e.g. newSecret..df97a addActionConfig.name = secret.getDisplayName(); // Name includes version, e.g. newSecret..df97a


addAction.stream = new ByteArrayInputStream(base64Decoder.decode(secret.getSecret())); byte[] content = base64Decoder.decode(secret.getSecret());
addAction.stream = new ByteArrayInputStream(content);
when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion())) when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion()))
.thenThrow(new NotFoundException()); // Call checks for existence. .thenThrow(new NotFoundException()); // Call checks for existence.


when(keywhizClient.createSecret(secret.getName(), "", secret.getSecret(), when(keywhizClient.createSecret(secret.getName(), "", content, true, secret.getMetadata()))
true, secret.getMetadata()))
.thenReturn(secretDetailResponse); .thenReturn(secretDetailResponse);


addAction.run(); addAction.run();


verify(keywhizClient, times(1)).createSecret(secret.getName(), "", secret.getSecret(), verify(keywhizClient, times(1))
true, secret.getMetadata()); .createSecret(secret.getName(), "", content, true, secret.getMetadata());
} }


@Test @Test
Expand All @@ -175,18 +171,19 @@ public void addCanCreateWithVersion() throws Exception {
addActionConfig.name = secret.getDisplayName(); addActionConfig.name = secret.getDisplayName();
addActionConfig.withVersion = true; addActionConfig.withVersion = true;


addAction.stream = new ByteArrayInputStream(base64Decoder.decode(secret.getSecret())); byte[] content = base64Decoder.decode(secret.getSecret());
addAction.stream = new ByteArrayInputStream(content);
when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion())) when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion()))
.thenThrow(new NotFoundException()); // Call checks for existence. .thenThrow(new NotFoundException()); // Call checks for existence.


when(keywhizClient.createSecret(secret.getName(), "", secret.getSecret(), when(keywhizClient.createSecret(secret.getName(), "", content, addActionConfig.withVersion,
addActionConfig.withVersion, secret.getMetadata())) secret.getMetadata()))
.thenReturn(secretDetailResponse); .thenReturn(secretDetailResponse);


addAction.run(); addAction.run();


verify(keywhizClient, times(1)).createSecret(secret.getName(), "", secret.getSecret(), verify(keywhizClient, times(1))
true, secret.getMetadata()); .createSecret(secret.getName(), "", content, true, secret.getMetadata());
} }


@Test @Test
Expand All @@ -195,20 +192,19 @@ public void addWithMetadata() throws Exception {
addActionConfig.name = secret.getDisplayName(); addActionConfig.name = secret.getDisplayName();
addActionConfig.json = "{\"owner\":\"example-name\", \"group\":\"example-group\"}"; addActionConfig.json = "{\"owner\":\"example-name\", \"group\":\"example-group\"}";


addAction.stream = new ByteArrayInputStream(base64Decoder.decode(secret.getSecret())); byte[] content = base64Decoder.decode(secret.getSecret());
addAction.stream = new ByteArrayInputStream(content);
when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion())) when(keywhizClient.getSanitizedSecretByNameAndVersion(secret.getName(), secret.getVersion()))
.thenThrow(new NotFoundException()); // Call checks for existence. .thenThrow(new NotFoundException()); // Call checks for existence.


ImmutableMap<String,String> expected = ImmutableMap.of("owner", "example-name", "group", "example-group"); ImmutableMap<String,String> expected = ImmutableMap.of("owner", "example-name", "group", "example-group");


when(keywhizClient.createSecret(secret.getName(), "", secret.getSecret(), when(keywhizClient.createSecret(secret.getName(), "", content, true, expected))
true, expected))
.thenReturn(secretDetailResponse); .thenReturn(secretDetailResponse);


addAction.run(); addAction.run();


verify(keywhizClient, times(1)).createSecret(secret.getName(), "", secret.getSecret(), verify(keywhizClient, times(1)).createSecret(secret.getName(), "", content, true, expected);
true, expected);
} }


@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
Expand Down
27 changes: 22 additions & 5 deletions client/src/main/java/keywhiz/client/KeywhizClient.java
Expand Up @@ -27,6 +27,7 @@
import com.squareup.okhttp.RequestBody; import com.squareup.okhttp.RequestBody;
import com.squareup.okhttp.Response; import com.squareup.okhttp.Response;
import java.io.IOException; import java.io.IOException;
import java.util.Base64;
import java.util.List; import java.util.List;
import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.HttpHeaders;
import keywhiz.api.ClientDetailResponse; import keywhiz.api.ClientDetailResponse;
Expand All @@ -41,6 +42,7 @@
import keywhiz.api.model.SanitizedSecret; import keywhiz.api.model.SanitizedSecret;
import org.apache.http.HttpStatus; import org.apache.http.HttpStatus;


import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static java.lang.String.format; import static java.lang.String.format;


Expand Down Expand Up @@ -123,6 +125,7 @@ public List<Group> allGroups() throws IOException {
} }


public GroupDetailResponse createGroup(String name, String description) throws IOException { public GroupDetailResponse createGroup(String name, String description) throws IOException {
checkArgument(!name.isEmpty());
String response = httpPost(baseUrl.resolve("/admin/groups"), new CreateGroupRequest(name, description)); String response = httpPost(baseUrl.resolve("/admin/groups"), new CreateGroupRequest(name, description));
return mapper.readValue(response, GroupDetailResponse.class); return mapper.readValue(response, GroupDetailResponse.class);
} }
Expand All @@ -141,9 +144,13 @@ public List<SanitizedSecret> allSecrets() throws IOException {
return mapper.readValue(response, new TypeReference<List<SanitizedSecret>>() {}); return mapper.readValue(response, new TypeReference<List<SanitizedSecret>>() {});
} }


public SecretDetailResponse createSecret(String name, String description, String content, boolean withVersion, public SecretDetailResponse createSecret(String name, String description, byte[] content, boolean withVersion,
ImmutableMap<String, String> metadata) throws IOException { ImmutableMap<String, String> metadata) throws IOException {
CreateSecretRequest request = new CreateSecretRequest(name, description, content, checkArgument(!name.isEmpty());
checkArgument(content.length > 0, "Content must not be empty");

String b64Content = Base64.getEncoder().encodeToString(content);
CreateSecretRequest request = new CreateSecretRequest(name, description, b64Content,
withVersion, metadata); withVersion, metadata);
String response = httpPost(baseUrl.resolve("/admin/secrets"), request); String response = httpPost(baseUrl.resolve("/admin/secrets"), request);
return mapper.readValue(response, SecretDetailResponse.class); return mapper.readValue(response, SecretDetailResponse.class);
Expand All @@ -159,13 +166,17 @@ public void deleteSecretWithId(int secretId) throws IOException {
} }


public <T> List<SanitizedSecret> generateSecrets(String generatorName, T params) throws IOException { public <T> List<SanitizedSecret> generateSecrets(String generatorName, T params) throws IOException {
String response = httpPost(baseUrl.resolve(format("/admin/secrets/generators/%s", generatorName)), checkArgument(!generatorName.isEmpty());
String response = httpPost(baseUrl.resolve(
format("/admin/secrets/generators/%s", generatorName)),
params); params);
return mapper.readValue(response, new TypeReference<List<SanitizedSecret>>() {}); return mapper.readValue(response, new TypeReference<List<SanitizedSecret>>() {});
} }


public <T> List<SanitizedSecret> batchGenerateSecrets(String generatorName, List<T> params) throws IOException { public <T> List<SanitizedSecret> batchGenerateSecrets(String generatorName, List<T> params) throws IOException {
String response = httpPost(baseUrl.resolve(format("/admin/secrets/generators/%s/batch", generatorName)), checkArgument(!generatorName.isEmpty());
String response = httpPost(baseUrl.resolve(
format("/admin/secrets/generators/%s/batch", generatorName)),
params); params);
return mapper.readValue(response, new TypeReference<List<SanitizedSecret>>() {}); return mapper.readValue(response, new TypeReference<List<SanitizedSecret>>() {});
} }
Expand All @@ -176,6 +187,7 @@ public List<Client> allClients() throws IOException {
} }


public ClientDetailResponse createClient(String name) throws IOException { public ClientDetailResponse createClient(String name) throws IOException {
checkArgument(!name.isEmpty());
String response = httpPost(baseUrl.resolve("/admin/clients"), new CreateClientRequest(name)); String response = httpPost(baseUrl.resolve("/admin/clients"), new CreateClientRequest(name));
return mapper.readValue(response, ClientDetailResponse.class); return mapper.readValue(response, ClientDetailResponse.class);
} }
Expand Down Expand Up @@ -206,21 +218,26 @@ public void revokeSecretFromGroupByIds(int secretId, int groupId) throws IOExcep
} }


public Client getClientByName(String name) throws IOException { public Client getClientByName(String name) throws IOException {
checkArgument(!name.isEmpty());
String response = httpGet(baseUrl.resolve(format("/admin/clients?name=%s", name))); String response = httpGet(baseUrl.resolve(format("/admin/clients?name=%s", name)));
return mapper.readValue(response, Client.class); return mapper.readValue(response, Client.class);
} }


public Group getGroupByName(String name) throws IOException { public Group getGroupByName(String name) throws IOException {
checkArgument(!name.isEmpty());
String response = httpGet(baseUrl.resolve(format("/admin/groups?name=%s", name))); String response = httpGet(baseUrl.resolve(format("/admin/groups?name=%s", name)));
return mapper.readValue(response, Group.class); return mapper.readValue(response, Group.class);
} }


public SanitizedSecret getSanitizedSecretByNameAndVersion(String name, String version) throws IOException { public SanitizedSecret getSanitizedSecretByNameAndVersion(String name, String version) throws IOException {
String response = httpGet(baseUrl.resolve(format("/admin/secrets?name=%s&version=%s", name, version))); checkArgument(!name.isEmpty());
String response = httpGet(baseUrl.resolve(
format("/admin/secrets?name=%s&version=%s", name, version)));
return mapper.readValue(response, SanitizedSecret.class); return mapper.readValue(response, SanitizedSecret.class);
} }


public List<String> getVersionsForSecretName(String name) throws IOException { public List<String> getVersionsForSecretName(String name) throws IOException {
checkNotNull(name);
String response = httpGet(baseUrl.resolve(format("/admin/secrets/versions?name=%s", name))); String response = httpGet(baseUrl.resolve(format("/admin/secrets/versions?name=%s", name)));
return mapper.readValue(response, new TypeReference<List<String>>() {}); return mapper.readValue(response, new TypeReference<List<String>>() {});
} }
Expand Down
Expand Up @@ -99,12 +99,6 @@ public void createDuplicateClients() throws IOException {
keywhizClient.createClient("varys"); keywhizClient.createClient("varys");
} }


@Test(expected = KeywhizClient.ValidationException.class)
public void creatingClientWithEmptyNameFailsOnValidation() throws IOException {
keywhizClient.login(DbSeedCommand.defaultUser, DbSeedCommand.defaultPassword.toCharArray());
keywhizClient.createClient("");
}

@Test(expected = KeywhizClient.NotFoundException.class) @Test(expected = KeywhizClient.NotFoundException.class)
public void notFoundOnMissingId() throws IOException { public void notFoundOnMissingId() throws IOException {
keywhizClient.login(DbSeedCommand.defaultUser, DbSeedCommand.defaultPassword.toCharArray()); keywhizClient.login(DbSeedCommand.defaultUser, DbSeedCommand.defaultPassword.toCharArray());
Expand Down

0 comments on commit 554221b

Please sign in to comment.