@@ -53,8 +53,10 @@
import javax.persistence.metamodel.PluralAttribute;

import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Sort.Order;
import org.springframework.data.jpa.domain.JpaSort.JpaOrder;
import org.springframework.data.mapping.PropertyPath;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@@ -66,6 +68,7 @@
* @author Kevin Raymond
* @author Thomas Darimont
* @author Komi Innocent
* @author Christoph Strobl
*/
public abstract class QueryUtils {

@@ -102,6 +105,10 @@
private static final int QUERY_JOIN_ALIAS_GROUP_INDEX = 2;
private static final int VARIABLE_NAME_GROUP_INDEX = 4;

private static final Pattern PUNCTATION_PATTERN = Pattern.compile(".*((?![\\._])[\\p{Punct}|\\s])");
private static final String FUNCTION_ALIAS_GROUP_NAME = "alias";
private static final Pattern FUNCTION_PATTERN;

static {

StringBuilder builder = new StringBuilder();
@@ -145,6 +152,13 @@
builder.append("\\)");

CONSTRUCTOR_EXPRESSION = compile(builder.toString(), CASE_INSENSITIVE + DOTALL);

builder = new StringBuilder();
builder.append("\\s+"); // at least one space
builder.append("\\w+\\([0-9a-zA-z\\._,\\s']+\\)"); // any function call including parameters within the brackets
builder.append("\\s+[as|AS]+\\s+(?<" + FUNCTION_ALIAS_GROUP_NAME + ">[\\w\\.]+)"); // the potential alias

FUNCTION_PATTERN = compile(builder.toString());
}

/**
@@ -227,9 +241,10 @@ public static String applySorting(String query, Sort sort, String alias) {
}

Set<String> aliases = getOuterJoinAliases(query);
Set<String> functionAliases = getFunctionAliases(query);

for (Order order : sort) {
builder.append(getOrderClause(aliases, alias, order)).append(", ");
builder.append(getOrderClause(aliases, functionAliases, alias, order)).append(", ");
}

builder.delete(builder.length() - 2, builder.length());
@@ -246,9 +261,16 @@ public static String applySorting(String query, Sort sort, String alias) {
* @param order the order object to build the clause for.
* @return
*/
private static String getOrderClause(Set<String> joinAliases, String alias, Order order) {
private static String getOrderClause(Set<String> joinAliases, Set<String> functionAlias, String alias, Order order) {

String property = order.getProperty();

checkSortExpression(order);

if (functionAlias.contains(property)) {
return String.format("%s %s", property, toJpaDirection(order));
}

boolean qualifyReference = !property.contains("("); // ( indicates a function

for (String joinAlias : joinAliases) {
@@ -287,6 +309,28 @@ private static String getOrderClause(Set<String> joinAliases, String alias, Orde
return result;
}

/**
* Returns the aliases used for aggregate functions like {@code SUM, COUNT, ...}.
*
* @param query
* @return
*/
static Set<String> getFunctionAliases(String query) {

Set<String> result = new HashSet<String>();
Matcher matcher = FUNCTION_PATTERN.matcher(query);

while (matcher.find()) {

String alias = matcher.group(FUNCTION_ALIAS_GROUP_NAME);
if (StringUtils.hasText(alias)) {
result.add(alias);
}
}

return result;
}

private static String toJpaDirection(Order order) {
return order.getDirection().name().toLowerCase(Locale.US);
}
@@ -617,4 +661,23 @@ private static boolean isAlreadyFetched(From<?, ?> from, String attribute) {

return false;
}

/**
* Check any given {@link JpaOrder#isUnsafe()} order for presence of at least one property offending the
* {@link #PUNCTATION_PATTERN} and throw an {@link Exception} indicating potential unsafe order by expression.
*
* @param order
*/
private static void checkSortExpression(Order order) {

if (order instanceof JpaOrder && ((JpaOrder) order).isUnsafe()) {
return;
}

if (PUNCTATION_PATTERN.matcher(order.getProperty()).find()) {
throw new InvalidDataAccessApiUsageException(String
.format("Sort expression '%s' must not contain functions or expressions. Please use JpaSort.unsafe.", order));
}
}

}
@@ -1,5 +1,5 @@
/*
* Copyright 2013-2015 the original author or authors.
* Copyright 2013-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,6 +29,7 @@
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Sort.Direction;
import org.springframework.data.domain.Sort.Order;
import org.springframework.data.jpa.domain.JpaSort.JpaOrder;
import org.springframework.data.jpa.domain.JpaSort.Path;
import org.springframework.data.jpa.domain.sample.Address_;
import org.springframework.data.jpa.domain.sample.MailMessage_;
@@ -47,6 +48,7 @@
* @see DATAJPA-12
* @author Thomas Darimont
* @author Oliver Gierke
* @author Christoph Strobl
*/
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration("classpath:infrastructure.xml")
@@ -173,4 +175,56 @@ public void buildsUpPathForPluralAttributesCorrectly() {
assertThat(new JpaSort(path(User_.colleagues).dot(User_.roles).dot(Role_.name)), //
hasItem(new Order(ASC, "colleagues.roles.name")));
}

/**
* @see DATAJPA-???
*/
@Test
public void createsUnsafeSortCorrectly() {

JpaSort sort = JpaSort.unsafe(DESC, "foo.bar");

assertThat(sort, hasItem(new Order(DESC, "foo.bar")));
assertThat(sort.getOrderFor("foo.bar"), is(instanceOf(JpaOrder.class)));
}

/**
* @see DATAJPA-???
*/
@Test
public void createsUnsafeSortWithMultiplePropertiesCorrectly() {

JpaSort sort = JpaSort.unsafe(DESC, "foo.bar", "spring.data");

assertThat(sort, hasItems(new Order(DESC, "foo.bar"), new Order(DESC, "spring.data")));
assertThat(sort.getOrderFor("foo.bar"), is(instanceOf(JpaOrder.class)));
assertThat(sort.getOrderFor("spring.data"), is(instanceOf(JpaOrder.class)));
}

/**
* @see DATAJPA-???
*/
@Test
public void combinesSafeAndUnsafeSortCorrectly() {

JpaSort sort = new JpaSort(path(User_.colleagues).dot(User_.roles).dot(Role_.name)).andUnsafe(DESC, "foo.bar");

assertThat(sort, hasItems(new Order(ASC, "colleagues.roles.name"), new Order(DESC, "foo.bar")));
assertThat(sort.getOrderFor("colleagues.roles.name"), is(not(instanceOf(JpaOrder.class))));
assertThat(sort.getOrderFor("foo.bar"), is(instanceOf(JpaOrder.class)));
}

/**
* @see DATAJPA-???
*/
@Test
public void combinesUnsafeAndSafeSortCorrectly() {

Sort sort = JpaSort.unsafe(DESC, "foo.bar").and(ASC, path(User_.colleagues).dot(User_.roles).dot(Role_.name));

assertThat(sort, hasItems(new Order(ASC, "colleagues.roles.name"), new Order(DESC, "foo.bar")));
assertThat(sort.getOrderFor("colleagues.roles.name"), is(not(instanceOf(JpaOrder.class))));
assertThat(sort.getOrderFor("foo.bar"), is(instanceOf(JpaOrder.class)));
}

}
@@ -1,5 +1,5 @@
/*
* Copyright 2008-2015 the original author or authors.
* Copyright 2008-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,14 +23,17 @@

import org.hamcrest.Matcher;
import org.junit.Test;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.domain.JpaSort;

/**
* Unit test for {@link QueryUtils}.
*
* @author Oliver Gierke
* @author Thomas Darimont
* @author Komi Innocent
* @author Christoph Strobl
*/
public class QueryUtilsUnitTests {

@@ -230,7 +233,7 @@ public void projectsCOuntQueriesForQueriesWithSubselects() {
/**
* @see DATAJPA-148
*/
@Test
@Test(expected = InvalidDataAccessApiUsageException.class)
public void doesNotPrefixSortsIfFunction() {

Sort sort = new Sort("sum(foo)");
@@ -361,6 +364,134 @@ public void doesNotQualifySortIfNoAliasDetected() {
endsWith("order by firstname asc"));
}

/**
* @see DATAJPA-???
*/
@Test(expected = InvalidDataAccessApiUsageException.class)
public void doesNotAllowWhitespaceInSort() {

Sort sort = new Sort("case when foo then bar");
applySorting("select p from Person p", sort, "p");
}

/**
* @see DATAJPA-???
*/
@Test
public void doesNotPrefixUnsageJpaSortFunctionCalls() {

JpaSort sort = JpaSort.unsafe("sum(foo)");
assertThat(applySorting("select p from Person p", sort, "p"), endsWith("order by sum(foo) asc"));
}

/**
* @see DATAJPA-???
*/
@Test
public void doesNotPrefixMultipleAliasedFunctionCalls() {

String query = "SELECT AVG(m.price) AS avgPrice, SUM(m.stocks) AS sumStocks FROM Magazine m";
Sort sort = new Sort("avgPrice", "sumStocks");

assertThat(applySorting(query, sort, "m"), endsWith("order by avgPrice asc, sumStocks asc"));
}

/**
* @see DATAJPA-???
*/
@Test
public void doesNotPrefixSingleAliasedFunctionCalls() {

String query = "SELECT AVG(m.price) AS avgPrice FROM Magazine m";
Sort sort = new Sort("avgPrice");

assertThat(applySorting(query, sort, "m"), endsWith("order by avgPrice asc"));
}

/**
* @see DATAJPA-???
*/
@Test
public void prefixesSingleNonAliasedFunctionCallRelatedSortProperty() {

String query = "SELECT AVG(m.price) AS avgPrice FROM Magazine m";
Sort sort = new Sort("someOtherProperty");

assertThat(applySorting(query, sort, "m"), endsWith("order by m.someOtherProperty asc"));
}

/**
* @see DATAJPA-???
*/
@Test
public void prefixesNonAliasedFunctionCallRelatedSortPropertyWhenSelectClauseContainesAliasedFunctionForDifferentProperty() {

String query = "SELECT m.name, AVG(m.price) AS avgPrice FROM Magazine m";
Sort sort = new Sort("name", "avgPrice");

assertThat(applySorting(query, sort, "m"), endsWith("order by m.name asc, avgPrice asc"));
}

/**
* @see DATAJPA-???
*/
@Test
public void doesNotPrefixAliasedFunctionCallNameWithMultipleNumericParameters() {

String query = "SELECT SUBSTRING(m.name, 2, 5) AS trimmedName FROM Magazine m";
Sort sort = new Sort("trimmedName");

assertThat(applySorting(query, sort, "m"), endsWith("order by trimmedName asc"));
}

/**
* @see DATAJPA-???
*/
@Test
public void doesNotPrefixAliasedFunctionCallNameWithMultipleStringParameters() {

String query = "SELECT CONCAT(m.name, 'foo') AS extendedName FROM Magazine m";
Sort sort = new Sort("extendedName");

assertThat(applySorting(query, sort, "m"), endsWith("order by extendedName asc"));
}

/**
* @see DATAJPA-???
*/
@Test
public void doesNotPrefixAliasedFunctionCallNameWithUnderscores() {

String query = "SELECT AVG(m.price) AS avg_price FROM Magazine m";
Sort sort = new Sort("avg_price");

assertThat(applySorting(query, sort, "m"), endsWith("order by avg_price asc"));
}

/**
* @see DATAJPA-???
*/
@Test
public void doesNotPrefixAliasedFunctionCallNameWithDots() {

String query = "SELECT AVG(m.price) AS m.avg FROM Magazine m";
Sort sort = new Sort("m.avg");

assertThat(applySorting(query, sort, "m"), endsWith("order by m.avg asc"));
}

/**
* @see DATAJPA-???
*/
@Test
public void doesNotPrefixAliasedFunctionCallNameWhenQueryStringContainsMultipleWhiteSpaces() {

String query = "SELECT AVG( m.price ) AS avgPrice FROM Magazine m";
Sort sort = new Sort("avgPrice");

assertThat(applySorting(query, sort, "m"), endsWith("order by avgPrice asc"));
}

private static void assertCountQuery(String originalQuery, String countQuery) {
assertThat(createCountQueryFor(originalQuery), is(countQuery));
}