Skip to content
This repository has been archived by the owner on Nov 22, 2023. It is now read-only.

Commit

Permalink
Adding owner support to ClientDAO (#1167)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmontgomery-square committed Dec 6, 2022
1 parent 62f0575 commit 2aae402
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 30 deletions.
4 changes: 3 additions & 1 deletion api/src/main/java/keywhiz/api/model/Client.java
Expand Up @@ -72,7 +72,7 @@ public class Client {
private final boolean enabled;

@JsonProperty
private final String owner;
private String owner;

/**
* True if client is enabled to do automationAllowed tasks.
Expand Down Expand Up @@ -191,6 +191,8 @@ public boolean isEnabled() {

public String getOwner() { return owner; }

public void setOwner(String owner) { this.owner = owner; }

public boolean isAutomationAllowed() {
return automationAllowed;
}
Expand Down
92 changes: 80 additions & 12 deletions server/src/main/java/keywhiz/service/daos/ClientDAO.java
Expand Up @@ -24,27 +24,37 @@
import java.time.OffsetDateTime;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import javax.inject.Inject;
import keywhiz.api.model.Client;
import keywhiz.auth.mutualssl.CertificatePrincipal;
import keywhiz.jooq.tables.Groups;
import keywhiz.jooq.tables.records.ClientsRecord;
import keywhiz.jooq.tables.records.GroupsRecord;
import keywhiz.service.config.Readonly;
import keywhiz.service.crypto.RowHmacGenerator;
import org.jooq.Condition;
import org.jooq.Configuration;
import org.jooq.DSLContext;
import org.jooq.Param;
import org.jooq.Record;
import org.jooq.Result;
import org.jooq.impl.DSL;

import static com.google.common.base.Preconditions.checkNotNull;
import static java.time.Instant.EPOCH;
import static keywhiz.jooq.tables.Clients.CLIENTS;
import static keywhiz.jooq.tables.Groups.GROUPS;
import static keywhiz.jooq.tables.Memberships.MEMBERSHIPS;
import static org.jooq.impl.DSL.greatest;
import static org.jooq.impl.DSL.when;

public class ClientDAO {
private final static Duration LAST_SEEN_THRESHOLD = Duration.ofSeconds(24 * 60 * 60);
private static final Groups CLIENT_OWNERS = GROUPS.as("owners");
private static final Duration LAST_SEEN_THRESHOLD = Duration.ofSeconds(24 * 60 * 60);
private static final Long NO_OWNER = null;

private final DSLContext dslContext;
private final ClientMapper clientMapper;
Expand All @@ -57,8 +67,25 @@ private ClientDAO(DSLContext dslContext, ClientMapper clientMapper,
this.rowHmacGenerator = rowHmacGenerator;
}

public long createClient(String name, String user, String description,
public long createClient(
String name,
String user,
String description,
@Nullable URI spiffeId) {
return createClient(
name,
user,
description,
spiffeId,
NO_OWNER);
}

public long createClient(
String name,
String user,
String description,
@Nullable URI spiffeId,
@Nullable Long ownerId) {
ClientsRecord r = dslContext.newRecord(CLIENTS);

long now = OffsetDateTime.now().toEpochSecond();
Expand All @@ -85,6 +112,7 @@ public long createClient(String name, String user, String description,
r.setAutomationallowed(false);
r.setSpiffeId(spiffeStr);
r.setRowHmac(rowHmac);
r.setOwner(ownerId);
r.store();

return r.getId();
Expand Down Expand Up @@ -138,26 +166,66 @@ public void sawClient(Client client, @Nullable Principal principal) {
}

public Optional<Client> getClientByName(String name) {
ClientsRecord r = dslContext.fetchOne(CLIENTS, CLIENTS.NAME.eq(name));
return Optional.ofNullable(r).map(clientMapper::map);
return getClient(CLIENTS.NAME.eq(name));
}

public Optional<Client> getClientBySpiffeId(URI spiffeId) {
ClientsRecord r = dslContext.fetchOne(CLIENTS, CLIENTS.SPIFFE_ID.eq(spiffeId.toASCIIString()));
return Optional.ofNullable(r).map(clientMapper::map);
return getClient(CLIENTS.SPIFFE_ID.eq(spiffeId.toASCIIString()));
}

public Optional<Client> getClientById(long id) {
ClientsRecord r = dslContext.fetchOne(CLIENTS, CLIENTS.ID.eq(id));
return Optional.ofNullable(r).map(clientMapper::map);
return getClient(CLIENTS.ID.eq(id));
}

private Optional<Client> getClient(Condition condition) {
Record record = dslContext
.select(CLIENTS.fields())
.select(CLIENT_OWNERS.ID, CLIENT_OWNERS.NAME)
.from(CLIENTS)
.leftJoin(CLIENT_OWNERS)
.on(CLIENTS.OWNER.eq(CLIENT_OWNERS.ID))
.where(condition)
.fetchOne();

return Optional.ofNullable(recordToClient(record));
}

private Client recordToClient(Record record) {
if (record == null) {
return null;
}

ClientsRecord clientRecord = record.into(CLIENTS);
GroupsRecord ownerRecord = record.into(CLIENT_OWNERS);

boolean danglingOwner = clientRecord.getOwner() != null && ownerRecord.getId() == null;
if (danglingOwner) {
throw new IllegalStateException(
String.format(
"Owner %s for client %s is missing.",
clientRecord.getOwner(),
clientRecord.getName()));
}

Client client = clientMapper.map(clientRecord);
if (ownerRecord != null) {
client.setOwner(ownerRecord.getName());
}

return client;
}

public ImmutableSet<Client> getClients() {
List<Client> r = dslContext
.selectFrom(CLIENTS)
List<Client> clients = dslContext
.select(CLIENTS.fields())
.select(CLIENT_OWNERS.NAME)
.from(CLIENTS)
.leftJoin(CLIENT_OWNERS)
.on(CLIENTS.OWNER.eq(CLIENT_OWNERS.ID))
.fetch()
.map(clientMapper);
return ImmutableSet.copyOf(r);
.map(this::recordToClient);

return ImmutableSet.copyOf(clients);
}

public static class ClientDAOFactory implements DAOFactory<ClientDAO> {
Expand Down
35 changes: 23 additions & 12 deletions server/src/main/java/keywhiz/service/daos/GroupDAO.java
Expand Up @@ -30,13 +30,15 @@
import keywhiz.jooq.tables.Groups;
import keywhiz.jooq.tables.records.GroupsRecord;
import keywhiz.service.config.Readonly;
import org.jooq.Condition;
import org.jooq.Configuration;
import org.jooq.DSLContext;
import org.jooq.Record;
import org.jooq.Result;
import org.jooq.impl.DSL;

import static com.google.common.base.Preconditions.checkNotNull;
import static keywhiz.jooq.Tables.CLIENTS;
import static keywhiz.jooq.tables.Accessgrants.ACCESSGRANTS;
import static keywhiz.jooq.tables.Groups.GROUPS;
import static keywhiz.jooq.tables.Memberships.MEMBERSHIPS;
Expand Down Expand Up @@ -125,30 +127,30 @@ public void deleteGroup(Group group) {
.set(GROUPS.OWNER, (Long) null)
.where(GROUPS.OWNER.eq(group.getId()))
.execute();
DSL.using(configuration)
.update(CLIENTS)
.set(CLIENTS.OWNER, (Long) null)
.where(CLIENTS.OWNER.eq(group.getId()))
.execute();
});
}

public Optional<Group> getGroup(String name) {
Record record = dslContext
.select(GROUPS.fields())
.select(GROUP_OWNERS.NAME)
.from(GROUPS)
.leftJoin(GROUP_OWNERS)
.on(GROUPS.OWNER.eq(GROUP_OWNERS.ID))
.where(GROUPS.NAME.eq(name))
.fetchOne();

return Optional.ofNullable(recordToGroup(record));
return getGroup(GROUPS.NAME.eq(name));
}

public Optional<Group> getGroupById(long id) {
return getGroup(GROUPS.ID.eq(id));
}

private Optional<Group> getGroup(Condition condition) {
Record record = dslContext
.select(GROUPS.fields())
.select(GROUP_OWNERS.NAME)
.select(GROUP_OWNERS.ID, GROUP_OWNERS.NAME)
.from(GROUPS)
.leftJoin(GROUP_OWNERS)
.on(GROUPS.OWNER.eq(GROUP_OWNERS.ID))
.where(GROUPS.ID.eq(id))
.where(condition)
.fetchOne();

return Optional.ofNullable(recordToGroup(record));
Expand All @@ -162,6 +164,15 @@ private Group recordToGroup(Record record) {
GroupsRecord groupRecord = record.into(GROUPS);
GroupsRecord ownerRecord = record.into(GROUP_OWNERS);

boolean danglingOwner = groupRecord.getOwner() != null && ownerRecord.getId() == null;
if (danglingOwner) {
throw new IllegalStateException(
String.format(
"Owner %s for group %s is missing.",
groupRecord.getOwner(),
groupRecord.getName()));
}

Group group = groupMapper.map(groupRecord);
if (ownerRecord != null) {
group.setOwner(ownerRecord.getName());
Expand Down
98 changes: 94 additions & 4 deletions server/src/test/java/keywhiz/service/daos/ClientDAOTest.java
Expand Up @@ -16,10 +16,12 @@

package keywhiz.service.daos;

import com.google.common.collect.ImmutableMap;
import java.net.URI;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.util.Set;
import java.util.UUID;
import javax.inject.Inject;
import keywhiz.KeywhizTestRunner;
import keywhiz.api.ApiDate;
Expand All @@ -34,21 +36,29 @@
import static java.time.temporal.ChronoField.NANO_OF_SECOND;
import static keywhiz.jooq.tables.Clients.CLIENTS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@RunWith(KeywhizTestRunner.class)
public class ClientDAOTest {
private static final ImmutableMap<String, String> NO_METADATA = ImmutableMap.of();
private static final Long NO_OWNER = null;
private static final URI NO_SPIFFE_URI = null;

@Inject DSLContext jooqContext;
@Inject ClientDAOFactory clientDAOFactory;
@Inject GroupDAO.GroupDAOFactory groupDAOFactory;

Client client1, client2;
ClientDAO clientDAO;
private Client client1, client2;
private ClientDAO clientDAO;
private GroupDAO groupDAO;

@Before public void setUp() {

clientDAO = clientDAOFactory.readwrite();
groupDAO = groupDAOFactory.readwrite();

long now = OffsetDateTime.now().toEpochSecond();

jooqContext.insertInto(CLIENTS, CLIENTS.NAME, CLIENTS.DESCRIPTION, CLIENTS.CREATEDBY,
Expand Down Expand Up @@ -102,6 +112,16 @@ public class ClientDAOTest {
assertThat(clientDAO.getClientByName("non-existent")).isEmpty();
}

@Test public void getClientByNamePopulatesOwner() {
String ownerName = randomName();
long ownerId = createGroup(ownerName);

String clientName = randomName();
createClient(clientName, ownerId);
Client client = getClientByName(clientName);
assertEquals(ownerName, client.getOwner());
}

@Test public void getClientById() {
Client client = clientDAO.getClientById(client1.getId()).orElseThrow(RuntimeException::new);
assertThat(client).isEqualTo(client1);
Expand All @@ -111,11 +131,44 @@ public class ClientDAOTest {
assertThat(clientDAO.getClientById(-1)).isEmpty();
}

@Test public void getClientByIdPopulatesOwner() {
String ownerName = randomName();
long ownerId = createGroup(ownerName);

long clientId = createClient(randomName(), ownerId);
Client client = getClientById(clientId);
assertEquals(ownerName, client.getOwner());
}

@Test public void getClientBySpiffeIdPopulatesOwner() throws Exception {
String ownerName = randomName();
long ownerId = createGroup(ownerName);

URI spiffeId = new URI("spiffe://test.env.com/" + randomName());
createClientWithSpiffeId(randomName(), spiffeId, ownerId);

Client client = getClientBySpiffeId(spiffeId);
assertEquals(ownerName, client.getOwner());
}

@Test public void getsClients() {
Set<Client> clients = clientDAO.getClients();
assertThat(clients).containsOnly(client1, client2);
}

@Test public void getClientsPopulatesOwner() {
String ownerName = randomName();
long ownerId = createGroup(ownerName);

long clientId = createClient(randomName(), ownerId);

Client client = clientDAO.getClients().stream()
.filter(x -> x.getId() == clientId)
.findFirst()
.get();
assertEquals(ownerName, client.getOwner());
}

@Test public void sawClientTest() {
assertThat(client1.getLastSeen()).isNull();
assertThat(client2.getLastSeen()).isNull();
Expand Down Expand Up @@ -144,8 +197,45 @@ public class ClientDAOTest {
assertThat(client2v2.getExpiration()).isNull();
}


private int tableSize() {
return jooqContext.fetchCount(CLIENTS);
}

private long createClient(String name, Long ownerId) {
return clientDAO.createClient(
name,
"user",
"description",
NO_SPIFFE_URI,
ownerId);
}

private long createClientWithSpiffeId(String name, URI spiffeId, Long ownerId) {
return clientDAO.createClient(
name,
"user",
"description",
spiffeId,
ownerId);
}

private long createGroup(String name) {
return groupDAO.createGroup(name, "creator", "description", NO_METADATA, NO_OWNER);
}

private Client getClientById(long id) {
return clientDAO.getClientById(id).get();
}

private Client getClientByName(String name) {
return clientDAO.getClientByName(name).get();
}

private Client getClientBySpiffeId(URI spiffeId) {
return clientDAO.getClientBySpiffeId(spiffeId).get();
}

private static String randomName() {
return UUID.randomUUID().toString();
}
}

0 comments on commit 2aae402

Please sign in to comment.