diff --git a/spring-test/src/main/java/org/springframework/test/json/JsonAssert.java b/spring-test/src/main/java/org/springframework/test/json/JsonAssert.java index 377c85f20510..a8d2bf404aca 100644 --- a/spring-test/src/main/java/org/springframework/test/json/JsonAssert.java +++ b/spring-test/src/main/java/org/springframework/test/json/JsonAssert.java @@ -16,13 +16,14 @@ package org.springframework.test.json; +import org.json.JSONException; import org.skyscreamer.jsonassert.JSONCompare; import org.skyscreamer.jsonassert.JSONCompareMode; import org.skyscreamer.jsonassert.JSONCompareResult; +import org.skyscreamer.jsonassert.comparator.DefaultComparator; import org.skyscreamer.jsonassert.comparator.JSONComparator; import org.springframework.lang.Nullable; -import org.springframework.util.function.ThrowingBiFunction; /** * Useful methods that can be used with {@code org.skyscreamer.jsonassert}. @@ -40,7 +41,9 @@ public abstract class JsonAssert { * @see JSONCompareMode#LENIENT */ public static JsonComparator comparator(JsonCompareMode compareMode) { - return comparator(toJSONCompareMode(compareMode)); + JSONCompareMode jsonAssertCompareMode = (compareMode != JsonCompareMode.LENIENT + ? JSONCompareMode.STRICT : JSONCompareMode.LENIENT); + return comparator(jsonAssertCompareMode); } /** @@ -50,8 +53,7 @@ public static JsonComparator comparator(JsonCompareMode compareMode) { * @return a new {@link JsonComparator} instance */ public static JsonComparator comparator(JSONComparator comparator) { - return comparator((expectedJson, actualJson) -> JSONCompare - .compareJSON(expectedJson, actualJson, comparator)); + return new JsonAssertJsonComparator(comparator); } /** @@ -61,33 +63,41 @@ public static JsonComparator comparator(JSONComparator comparator) { * @return a new {@link JsonComparator} instance */ public static JsonComparator comparator(JSONCompareMode mode) { - return comparator((expectedJson, actualJson) -> JSONCompare - .compareJSON(expectedJson, actualJson, mode)); + return new JsonAssertJsonComparator(mode); } - private static JsonComparator comparator(ThrowingBiFunction compareFunction) { - return (expectedJson, actualJson) -> compare(expectedJson, actualJson, compareFunction); - } + private static class JsonAssertJsonComparator implements JsonComparator { - private static JsonComparison compare(@Nullable String expectedJson, @Nullable String actualJson, - ThrowingBiFunction compareFunction) { + private final JSONComparator jsonAssertComparator; - if (actualJson == null) { - return (expectedJson != null) - ? JsonComparison.mismatch("Expected null JSON") - : JsonComparison.match(); + JsonAssertJsonComparator(JSONComparator jsonAssertComparator) { + this.jsonAssertComparator = jsonAssertComparator; } - if (expectedJson == null) { - return JsonComparison.mismatch("Expected non-null JSON"); + + JsonAssertJsonComparator(JSONCompareMode compareMode) { + this(new DefaultComparator(compareMode)); } - JSONCompareResult result = compareFunction.throwing(IllegalStateException::new).apply(expectedJson, actualJson); - return (!result.passed()) - ? JsonComparison.mismatch(result.getMessage()) - : JsonComparison.match(); - } - private static JSONCompareMode toJSONCompareMode(JsonCompareMode compareMode) { - return (compareMode != JsonCompareMode.LENIENT ? JSONCompareMode.STRICT : JSONCompareMode.LENIENT); + @Override + public JsonComparison compare(@Nullable String expectedJson, @Nullable String actualJson) { + if (actualJson == null) { + return (expectedJson != null) + ? JsonComparison.mismatch("Expected null JSON") + : JsonComparison.match(); + } + if (expectedJson == null) { + return JsonComparison.mismatch("Expected non-null JSON"); + } + try { + JSONCompareResult result = JSONCompare.compareJSON(expectedJson, actualJson, this.jsonAssertComparator); + return (!result.passed()) + ? JsonComparison.mismatch(result.getMessage()) + : JsonComparison.match(); + } + catch (JSONException ex) { + throw new IllegalStateException(ex); + } + } } }