Skip to content

Commit

Permalink
DATAMONGO-2341 - Support shard key derivation in save operations via @…
Browse files Browse the repository at this point in the history
…Sharded annotation.

Spring Data MongoDB uses the @Sharded annotation to identify entities stored in sharded collections.
The shard key consists of a single or multiple properties present in every document within the target collection, and is used to distribute them across shards.

Spring Data MongoDB will do best effort optimisations for sharded scenarios when using repositories by adding required shard key information, if not already present, to replaceOne filter queries when upserting entities. This may require an additional server round trip to determine the actual value of the current shard key.

By setting @Sharded(immutableKey = true) no attempt will be made to check if an entities shard key changed.

Please see the MongoDB Documentation for further details and the list below for which operations are eligible to auto include the shard key.

* Reactive/CrudRepository.save(...)
* Reactive/CrudRepository.saveAll(...)
* Reactive/MongoTemplate.save(...)

Original pull request: #833.
  • Loading branch information
christophstrobl authored and mp911de committed Feb 17, 2020
1 parent f153399 commit 6259cd2
Show file tree
Hide file tree
Showing 20 changed files with 1,093 additions and 24 deletions.
Expand Up @@ -80,7 +80,11 @@ public boolean isIdPresent(Class<?> type) {
}

public Bson getIdFilter() {
return Filters.eq(ID_FIELD, document.get(ID_FIELD));
return new Document(ID_FIELD, document.get(ID_FIELD));
}

public Object get(String key) {
return document.get(key);
}

public UpdateDefinition updateWithoutId() {
Expand Down
Expand Up @@ -33,7 +33,6 @@
import org.bson.conversions.Bson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
Expand Down Expand Up @@ -1480,23 +1479,38 @@ protected Object saveDocument(String collectionName, Document dbDoc, Class<?> en
}

return execute(collectionName, collection -> {

MongoAction mongoAction = new MongoAction(writeConcern, MongoActionOperation.SAVE, collectionName, entityClass,
dbDoc, null);
WriteConcern writeConcernToUse = prepareWriteConcern(mongoAction);

MappedDocument mapped = MappedDocument.of(dbDoc);

MongoCollection<Document> collectionToUse = writeConcernToUse == null //
? collection //
: collection.withWriteConcern(writeConcernToUse);

if (!mapped.hasId()) {
if (writeConcernToUse == null) {
collection.insertOne(dbDoc);
} else {
collection.withWriteConcern(writeConcernToUse).insertOne(dbDoc);
}
} else if (writeConcernToUse == null) {
collection.replaceOne(mapped.getIdFilter(), dbDoc, new ReplaceOptions().upsert(true));
collectionToUse.insertOne(dbDoc);
} else {
collection.withWriteConcern(writeConcernToUse).replaceOne(mapped.getIdFilter(), dbDoc,
new ReplaceOptions().upsert(true));

MongoPersistentEntity<?> entity = mappingContext.getPersistentEntity(entityClass);
UpdateContext updateContext = queryOperations.replaceSingleContext(mapped, true);
Document replacement = updateContext.getMappedUpdate(entity);

Document filter = updateContext.getMappedQuery(entity);

if (updateContext.requiresShardKey(filter, entity)) {

if (entity.getShardKey().isImmutable()) {
filter = updateContext.applyShardKey(entity, filter, null);
} else {
filter = updateContext.applyShardKey(entity, filter,
collection.find(filter, Document.class).projection(updateContext.getMappedShardKey(entity)).first());
}
}

collectionToUse.replaceOne(filter, replacement, new ReplaceOptions().upsert(true));
}
return mapped.getId();
});
Expand Down Expand Up @@ -1615,8 +1629,20 @@ protected UpdateResult doUpdate(String collectionName, Query query, UpdateDefini

if (!UpdateMapper.isUpdateObject(updateObj)) {

Document filter = new Document(queryObj);

if (updateContext.requiresShardKey(filter, entity)) {

if (entity.getShardKey().isImmutable()) {
filter = updateContext.applyShardKey(entity, filter, null);
} else {
filter = updateContext.applyShardKey(entity, filter,
collection.find(filter, Document.class).projection(updateContext.getMappedShardKey(entity)).first());
}
}

ReplaceOptions replaceOptions = updateContext.getReplaceOptions(entityClass);
return collection.replaceOne(queryObj, updateObj, replaceOptions);
return collection.replaceOne(filter, updateObj, replaceOptions);
} else {
return multi ? collection.updateMany(queryObj, updateObj, opts)
: collection.updateOne(queryObj, updateObj, opts);
Expand Down
Expand Up @@ -16,7 +16,10 @@
package org.springframework.data.mongodb.core;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -154,6 +157,15 @@ UpdateContext updateSingleContext(UpdateDefinition updateDefinition, Document qu
return new UpdateContext(updateDefinition, query, false, upsert);
}

/**
* @param replacement the {@link MappedDocument mapped replacement} document.
* @param upsert use {@literal true} to insert diff when no existing document found.
* @return new instance of {@link UpdateContext}.
*/
UpdateContext replaceSingleContext(MappedDocument replacement, boolean upsert) {
return new UpdateContext(replacement, upsert);
}

/**
* Create a new {@link DeleteContext} instance removing all matching documents.
*
Expand Down Expand Up @@ -253,7 +265,6 @@ Document getMappedFields(@Nullable MongoPersistentEntity<?> entity) {
*/
Document getMappedSort(@Nullable MongoPersistentEntity<?> entity) {
return queryMapper.getMappedSort(query.getSortObject(), entity);

}

/**
Expand Down Expand Up @@ -353,7 +364,6 @@ Class<?> getMostSpecificConversionTargetType(Class<?> requestedTargetType, Class
if (ClassUtils.isAssignable(requestedTargetType, propertyType)) {
conversionTargetType = propertyType;
}

} catch (PropertyReferenceException e) {
// just don't care about it as we default to Object.class anyway.
}
Expand Down Expand Up @@ -491,7 +501,9 @@ class UpdateContext extends QueryContext {

private final boolean multi;
private final boolean upsert;
private final UpdateDefinition update;
private final @Nullable UpdateDefinition update;
private final @Nullable MappedDocument mappedDocument;
private final Map<Class<?>, Document> mappedShardKey = new ConcurrentHashMap<>(1);

/**
* Create a new {@link UpdateContext} instance.
Expand Down Expand Up @@ -520,6 +532,16 @@ class UpdateContext extends QueryContext {
this.multi = multi;
this.upsert = upsert;
this.update = update;
this.mappedDocument = null;
}

UpdateContext(MappedDocument update, boolean upsert) {

super(new BasicQuery(new Document(BsonUtils.asMap(update.getIdFilter()))));
this.multi = false;
this.upsert = upsert;
this.mappedDocument = update;
this.update = null;
}

/**
Expand All @@ -544,7 +566,7 @@ UpdateOptions getUpdateOptions(@Nullable Class<?> domainType, @Nullable Consumer
UpdateOptions options = new UpdateOptions();
options.upsert(upsert);

if (update.hasArrayFilters()) {
if (update != null && update.hasArrayFilters()) {
options
.arrayFilters(update.getArrayFilters().stream().map(ArrayFilter::asDocument).collect(Collectors.toList()));
}
Expand Down Expand Up @@ -602,6 +624,45 @@ <T> Document getMappedQuery(@Nullable MongoPersistentEntity<T> domainType) {
return mappedQuery;
}

<T> Document applyShardKey(@Nullable MongoPersistentEntity<T> domainType, Document filter,
@Nullable Document existing) {

Document shardKeySource = existing != null ? existing
: mappedDocument != null ? mappedDocument.getDocument() : getMappedUpdate(domainType);

Document filterWithShardKey = new Document(filter);
for (String key : getMappedShardKeyFields(domainType)) {
if (!filterWithShardKey.containsKey(key)) {
filterWithShardKey.append(key, shardKeySource.get(key));
}
}

return filterWithShardKey;
}

<T> boolean requiresShardKey(Document filter, @Nullable MongoPersistentEntity<T> domainType) {

if (multi || domainType == null || !domainType.isSharded() || domainType.idPropertyIsShardKey()) {
return false;
}

if (filter.keySet().containsAll(getMappedShardKeyFields(domainType))) {
return false;
}

return true;
}

Set<String> getMappedShardKeyFields(@Nullable MongoPersistentEntity<?> entity) {
return getMappedShardKey(entity).keySet();
}

Document getMappedShardKey(@Nullable MongoPersistentEntity<?> entity) {

return mappedShardKey.computeIfAbsent(entity.getType(),
key -> queryMapper.getMappedFields(entity.getShardKey().getDocument(), entity));
}

/**
* Get the already mapped aggregation pipeline to use with an {@link #isAggregationUpdate()}.
*
Expand All @@ -625,8 +686,11 @@ List<Document> getUpdatePipeline(@Nullable Class<?> domainType) {
*/
Document getMappedUpdate(@Nullable MongoPersistentEntity<?> entity) {

return update instanceof MappedUpdate ? update.getUpdateObject()
: updateMapper.getMappedObject(update.getUpdateObject(), entity);
if (update != null) {
return update instanceof MappedUpdate ? update.getUpdateObject()
: updateMapper.getMappedObject(update.getUpdateObject(), entity);
}
return mappedDocument.getDocument();
}

/**
Expand Down
Expand Up @@ -39,7 +39,6 @@
import org.reactivestreams.Subscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
Expand Down Expand Up @@ -1638,9 +1637,31 @@ protected Mono<Object> saveDocument(String collectionName, Document document, Cl
? collection //
: collection.withWriteConcern(writeConcernToUse);

Publisher<?> publisher = !mapped.hasId() //
? collectionToUse.insertOne(document) //
: collectionToUse.replaceOne(mapped.getIdFilter(), document, new ReplaceOptions().upsert(true));
Publisher<?> publisher = null;
if (!mapped.hasId()) {
publisher = collectionToUse.insertOne(document);
} else {

MongoPersistentEntity<?> entity = mappingContext.getPersistentEntity(entityClass);
UpdateContext updateContext = queryOperations.replaceSingleContext(mapped, true);
Document filter = updateContext.getMappedQuery(entity);
Document replacement = updateContext.getMappedUpdate(entity);

Mono<Document> theFilter = Mono.just(filter);

if(updateContext.requiresShardKey(filter, entity)) {
if (entity.getShardKey().isImmutable()) {
theFilter = Mono.just(updateContext.applyShardKey(entity, filter, null));
} else {
theFilter = Mono.from(
collection.find(filter, Document.class).projection(updateContext.getMappedShardKey(entity)).first())
.defaultIfEmpty(replacement).map(it -> updateContext.applyShardKey(entity, filter, it));
}
}

publisher = theFilter.flatMap(
it -> Mono.from(collectionToUse.replaceOne(it, replacement, updateContext.getReplaceOptions(entityClass))));
}

return Mono.from(publisher).map(o -> mapped.getId());
});
Expand Down Expand Up @@ -1778,8 +1799,21 @@ protected Mono<UpdateResult> doUpdate(String collectionName, Query query, @Nulla

if (!UpdateMapper.isUpdateObject(updateObj)) {

Document filter = new Document(queryObj);
Mono<Document> theFilter = Mono.just(filter);

if(updateContext.requiresShardKey(filter, entity)) {
if (entity.getShardKey().isImmutable()) {
theFilter = Mono.just(updateContext.applyShardKey(entity, filter, null));
} else {
theFilter = Mono.from(
collection.find(filter, Document.class).projection(updateContext.getMappedShardKey(entity)).first())
.defaultIfEmpty(updateObj).map(it -> updateContext.applyShardKey(entity, filter, it));
}
}

ReplaceOptions replaceOptions = updateContext.getReplaceOptions(entityClass);
return collectionToUse.replaceOne(queryObj, updateObj, replaceOptions);
return theFilter.flatMap(it -> Mono.from(collectionToUse.replaceOne(it, updateObj, replaceOptions)));
}

return multi ? collectionToUse.updateMany(queryObj, updateObj, updateOptions)
Expand Down
Expand Up @@ -37,6 +37,7 @@
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;

/**
Expand All @@ -63,6 +64,8 @@ public class BasicMongoPersistentEntity<T> extends BasicPersistentEntity<T, Mong
private final @Nullable String collation;
private final @Nullable Expression collationExpression;

private final ShardKey shardKey;

/**
* Creates a new {@link BasicMongoPersistentEntity} with the given {@link TypeInformation}. Will default the
* collection name to the entities simple type name.
Expand Down Expand Up @@ -92,6 +95,8 @@ public BasicMongoPersistentEntity(TypeInformation<T> typeInformation) {
this.collation = null;
this.collationExpression = null;
}

this.shardKey = detectShardKey(this);
}

/*
Expand Down Expand Up @@ -160,6 +165,11 @@ public org.springframework.data.mongodb.core.query.Collation getCollation() {
: null;
}

@Override
public ShardKey getShardKey() {
return shardKey;
}

/*
* (non-Javadoc)
* @see org.springframework.data.mapping.model.BasicPersistentEntity#verify()
Expand Down Expand Up @@ -297,6 +307,26 @@ private static Expression detectExpression(@Nullable String potentialExpression)
return expression instanceof LiteralExpression ? null : expression;
}

@Nullable
private static ShardKey detectShardKey(BasicMongoPersistentEntity<?> entity) {

if (!entity.isAnnotationPresent(Sharded.class)) {
return ShardKey.none();
}

Sharded sharded = entity.getRequiredAnnotation(Sharded.class);

String[] keyProperties = sharded.shardKey();
if (ObjectUtils.isEmpty(keyProperties)) {
keyProperties = new String[] { "_id" };
}

ShardKey shardKey = ShardingStrategy.HASH.equals(sharded.shardingStrategy()) ? ShardKey.hash(keyProperties)
: ShardKey.range(keyProperties);

return sharded.immutableKey() ? ShardKey.immutable(shardKey) : shardKey;
}

/**
* Handler to collect {@link MongoPersistentProperty} instances and check that each of them is mapped to a distinct
* field name.
Expand Down

0 comments on commit 6259cd2

Please sign in to comment.