Skip to content

Commit

Permalink
Merge pull request #659 from softwaremill/feature/custom-validator
Browse files Browse the repository at this point in the history
Add a validator for integration with custom complex validations
  • Loading branch information
adamw committed Sep 17, 2020
2 parents fdc8c46 + 3874582 commit e34760e
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 72 deletions.
46 changes: 28 additions & 18 deletions core/src/main/scala/sttp/tapir/Validator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
def maxLength[T <: String](value: Int): Validator.Primitive[T] = MaxLength(value)
def minSize[T, C[_] <: Iterable[_]](value: Int): Validator.Primitive[C[T]] = MinSize(value)
def maxSize[T, C[_] <: Iterable[_]](value: Int): Validator.Primitive[C[T]] = MaxSize(value)
def custom[T](doValidate: T => Boolean, message: String): Validator.Primitive[T] = Custom(doValidate, message)
def custom[T](doValidate: T => List[ValidationError[_]], showMessage: Option[String] = None): Validator[T] =
Custom(doValidate, showMessage)

/**
* Creates an enum validator where all subtypes of the sealed hierarchy `T` are `object`s.
Expand All @@ -71,7 +72,7 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
if (implicitly[Numeric[T]].gt(t, value) || (!exclusive && implicitly[Numeric[T]].equiv(t, value))) {
List.empty
} else {
List(ValidationError(this, t))
List(ValidationError.Primitive(this, t))
}
}
}
Expand All @@ -80,7 +81,7 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
if (implicitly[Numeric[T]].lt(t, value) || (!exclusive && implicitly[Numeric[T]].equiv(t, value))) {
List.empty
} else {
List(ValidationError(this, t))
List(ValidationError.Primitive(this, t))
}
}
}
Expand All @@ -89,7 +90,7 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
if (t.matches(value)) {
List.empty
} else {
List(ValidationError(this, t))
List(ValidationError.Primitive(this, t))
}
}
}
Expand All @@ -98,7 +99,7 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
if (t.size >= value) {
List.empty
} else {
List(ValidationError(this, t))
List(ValidationError.Primitive(this, t))
}
}
}
Expand All @@ -107,7 +108,7 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
if (t.size <= value) {
List.empty
} else {
List(ValidationError(this, t))
List(ValidationError.Primitive(this, t))
}
}
}
Expand All @@ -116,7 +117,7 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
if (t.size >= value) {
List.empty
} else {
List(ValidationError(this, t))
List(ValidationError.Primitive(this, t))
}
}
}
Expand All @@ -125,17 +126,13 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
if (t.size <= value) {
List.empty
} else {
List(ValidationError(this, t))
List(ValidationError.Primitive(this, t))
}
}
}
case class Custom[T](doValidate: T => Boolean, message: String) extends Primitive[T] {
case class Custom[T](doValidate: T => List[ValidationError[_]], showMessage: Option[String]) extends Validator.Single[T] {
override def validate(t: T): List[ValidationError[_]] = {
if (doValidate(t)) {
List.empty
} else {
List(ValidationError(this, t))
}
doValidate(t)
}
}

Expand All @@ -144,7 +141,7 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
if (possibleValues.contains(t)) {
List.empty
} else {
List(ValidationError(this, t))
List(ValidationError.Primitive(this, t))
}
}

Expand Down Expand Up @@ -236,7 +233,7 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
case MaxLength(value) => Some(s"length<=$value")
case MinSize(value) => Some(s"size>=$value")
case MaxSize(value) => Some(s"size<=$value")
case Custom(_, message) => Some(message)
case Custom(_, showMessage) => showMessage.orElse(Some("custom"))
case Enum(possibleValues, _) => Some(s"in(${possibleValues.mkString(",")}")
case CollectionElements(el, _) => recurse(el).map(se => s"elements($se)")
case Product(fields) =>
Expand Down Expand Up @@ -281,6 +278,19 @@ object Validator extends ValidatorMagnoliaDerivation with ValidatorEnumMacro {
implicit def openProduct[T: Validator]: Validator[Map[String, T]] = OpenProduct(implicitly[Validator[T]])
}

case class ValidationError[T](validator: Validator.Primitive[T], invalidValue: T, path: List[FieldName] = Nil) {
def prependPath(f: FieldName): ValidationError[T] = copy(path = f :: path)
sealed trait ValidationError[T] {
def prependPath(f: FieldName): ValidationError[T]
def invalidValue: T
def path: List[FieldName]
}

object ValidationError {

case class Primitive[T](validator: Validator.Primitive[T], invalidValue: T, path: List[FieldName] = Nil) extends ValidationError[T] {
override def prependPath(f: FieldName): ValidationError[T] = copy(path = f :: path)
}

case class Custom[T](invalidValue: T, message: String, path: List[FieldName] = Nil) extends ValidationError[T] {
override def prependPath(f: FieldName): ValidationError[T] = copy(path = f :: path)
}
}
33 changes: 20 additions & 13 deletions core/src/main/scala/sttp/tapir/server/ServerDefaults.scala
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,26 @@ object ServerDefaults {
* @param valueName Name of the validated value to be used in error messages
*/
def invalidValueMessage[T](ve: ValidationError[T], valueName: String): String =
ve.validator match {
case Validator.Min(value, exclusive) =>
s"expected $valueName to be greater than ${if (exclusive) "" else "or equal to "}$value, but was ${ve.invalidValue}"
case Validator.Max(value, exclusive) =>
s"expected $valueName to be less than ${if (exclusive) "" else "or equal to "}$value, but was ${ve.invalidValue}"
case Validator.Pattern(value) => s"expected $valueName to match '$value', but was '${ve.invalidValue}'"
case Validator.MinLength(value) => s"expected $valueName to have length greater than or equal to $value, but was ${ve.invalidValue}"
case Validator.MaxLength(value) => s"expected $valueName to have length less than or equal to $value, but was ${ve.invalidValue} "
case Validator.MinSize(value) =>
s"expected size of $valueName to be greater than or equal to $value, but was ${ve.invalidValue.size}"
case Validator.MaxSize(value) => s"expected size of $valueName to be less than or equal to $value, but was ${ve.invalidValue.size}"
case Validator.Custom(_, message) => s"expected $valueName to pass custom validation: $message, but was '${ve.invalidValue}'"
case Validator.Enum(possibleValues, _) => s"expected $valueName to be within $possibleValues, but was '${ve.invalidValue}'"
ve match {
case p: ValidationError.Primitive[T] =>
p.validator match {
case Validator.Min(value, exclusive) =>
s"expected $valueName to be greater than ${if (exclusive) "" else "or equal to "}$value, but was ${ve.invalidValue}"
case Validator.Max(value, exclusive) =>
s"expected $valueName to be less than ${if (exclusive) "" else "or equal to "}$value, but was ${ve.invalidValue}"
case Validator.Pattern(value) => s"expected $valueName to match '$value', but was '${ve.invalidValue}'"
case Validator.MinLength(value) =>
s"expected $valueName to have length greater than or equal to $value, but was ${ve.invalidValue}"
case Validator.MaxLength(value) =>
s"expected $valueName to have length less than or equal to $value, but was ${ve.invalidValue} "
case Validator.MinSize(value) =>
s"expected size of $valueName to be greater than or equal to $value, but was ${ve.invalidValue.size}"
case Validator.MaxSize(value) =>
s"expected size of $valueName to be less than or equal to $value, but was ${ve.invalidValue.size}"
case Validator.Enum(possibleValues, _) => s"expected $valueName to be within $possibleValues, but was '${ve.invalidValue}'"
}
case c: ValidationError.Custom[T] =>
s"expected $valueName to pass custom validation: ${c.message}, but was '${ve.invalidValue}'"
}

/**
Expand Down
94 changes: 65 additions & 29 deletions core/src/test/scala/sttp/tapir/ValidatorTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,69 +14,69 @@ class ValidatorTest extends AnyFlatSpec with Matchers {
val min = 1
val wrong = 0
val v = Validator.min(min)
v.validate(wrong) shouldBe List(ValidationError(v, wrong))
v.validate(wrong) shouldBe List(ValidationError.Primitive(v, wrong))
v.validate(min) shouldBe empty
}

it should "validate for min value (exclusive)" in {
val min = 1
val wrong = 0
val v = Validator.min(min, exclusive = true)
v.validate(wrong) shouldBe List(ValidationError(v, wrong))
v.validate(min) shouldBe List(ValidationError(v, min))
v.validate(wrong) shouldBe List(ValidationError.Primitive(v, wrong))
v.validate(min) shouldBe List(ValidationError.Primitive(v, min))
v.validate(min + 1) shouldBe empty
}

it should "validate for max value" in {
val max = 0
val wrong = 1
val v = Validator.max(max)
v.validate(wrong) shouldBe List(ValidationError(v, wrong))
v.validate(wrong) shouldBe List(ValidationError.Primitive(v, wrong))
v.validate(max) shouldBe empty
}

it should "validate for max value (exclusive)" in {
val max = 0
val wrong = 1
val v = Validator.max(max, exclusive = true)
v.validate(wrong) shouldBe List(ValidationError(v, wrong))
v.validate(max) shouldBe List(ValidationError(v, max))
v.validate(wrong) shouldBe List(ValidationError.Primitive(v, wrong))
v.validate(max) shouldBe List(ValidationError.Primitive(v, max))
v.validate(max - 1) shouldBe empty
}

it should "validate for maxSize of collection" in {
val expected = 1
val actual = List(1, 2, 3)
val v = Validator.maxSize[Int, List](expected)
v.validate(actual) shouldBe List(ValidationError(v, actual))
v.validate(actual) shouldBe List(ValidationError.Primitive(v, actual))
v.validate(List(1)) shouldBe empty
}

it should "validate for minSize of collection" in {
val expected = 3
val v = Validator.minSize[Int, List](expected)
v.validate(List(1, 2)) shouldBe List(ValidationError(v, List(1, 2)))
v.validate(List(1, 2)) shouldBe List(ValidationError.Primitive(v, List(1, 2)))
v.validate(List(1, 2, 3)) shouldBe empty
}

it should "validate for matching regex pattern" in {
val expected = "^apple$|^banana$"
val wrong = "orange"
Validator.pattern(expected).validate(wrong) shouldBe List(ValidationError(Validator.pattern(expected), wrong))
Validator.pattern(expected).validate(wrong) shouldBe List(ValidationError.Primitive(Validator.pattern(expected), wrong))
Validator.pattern(expected).validate("banana") shouldBe empty
}

it should "validate for minLength of string" in {
val expected = 3
val v = Validator.minLength[String](expected)
v.validate("ab") shouldBe List(ValidationError(v, "ab"))
v.validate("ab") shouldBe List(ValidationError.Primitive(v, "ab"))
v.validate("abc") shouldBe empty
}

it should "validate for maxLength of string" in {
val expected = 1
val v = Validator.maxLength[String](expected)
v.validate("ab") shouldBe List(ValidationError(v, "ab"))
v.validate("ab") shouldBe List(ValidationError.Primitive(v, "ab"))
v.validate("a") shouldBe empty
}

Expand All @@ -85,53 +85,56 @@ class ValidatorTest extends AnyFlatSpec with Matchers {
validator.validate(4) shouldBe empty
validator.validate(7) shouldBe empty
validator.validate(11) shouldBe List(
ValidationError(Validator.max(5), 11),
ValidationError(Validator.max(10), 11)
ValidationError.Primitive(Validator.max(5), 11),
ValidationError.Primitive(Validator.max(10), 11)
)
}

it should "validate with all of validators" in {
val validator = Validator.all(Validator.min(3), Validator.max(10))
validator.validate(4) shouldBe empty
validator.validate(2) shouldBe List(ValidationError(Validator.min(3), 2))
validator.validate(11) shouldBe List(ValidationError(Validator.max(10), 11))
validator.validate(2) shouldBe List(ValidationError.Primitive(Validator.min(3), 2))
validator.validate(11) shouldBe List(ValidationError.Primitive(Validator.max(10), 11))
}

it should "validate with custom validator" in {
val v = Validator.custom(
{ x: Int =>
x > 5
},
"X has to be greater than 5!"
if (x > 5) {
List.empty
} else {
List(ValidationError.Custom(x, "X has to be greater than 5!"))
}
}
)
v.validate(0) shouldBe List(ValidationError(v, 0))
v.validate(0) shouldBe List(ValidationError.Custom(0, "X has to be greater than 5!"))
}

it should "validate openProduct" in {
val validator = Validator.openProduct(Validator.min(10))
validator.validate(Map("key" -> 0)).map(noPath(_)) shouldBe List(ValidationError(Validator.min(10), 0))
validator.validate(Map("key" -> 0)).map(noPath(_)) shouldBe List(ValidationError.Primitive(Validator.min(10), 0))
validator.validate(Map("key" -> 12)) shouldBe empty
}

it should "validate option" in {
val validator = Validator.optionElement(Validator.min(10))
validator.validate(None) shouldBe empty
validator.validate(Some(12)) shouldBe empty
validator.validate(Some(5)) shouldBe List(ValidationError(Validator.min(10), 5))
validator.validate(Some(5)) shouldBe List(ValidationError.Primitive(Validator.min(10), 5))
}

it should "validate iterable" in {
val validator = Validator.iterableElements[Int, List](Validator.min(10))
validator.validate(List.empty[Int]) shouldBe empty
validator.validate(List(11)) shouldBe empty
validator.validate(List(5)) shouldBe List(ValidationError(Validator.min(10), 5))
validator.validate(List(5)) shouldBe List(ValidationError.Primitive(Validator.min(10), 5))
}

it should "validate array" in {
val validator = Validator.arrayElements[Int](Validator.min(10))
validator.validate(Array.empty[Int]) shouldBe empty
validator.validate(Array(11)) shouldBe empty
validator.validate(Array(5)) shouldBe List(ValidationError(Validator.min(10), 5))
validator.validate(Array(5)) shouldBe List(ValidationError.Primitive(Validator.min(10), 5))
}

it should "validate product" in {
Expand All @@ -140,13 +143,13 @@ class ValidatorTest extends AnyFlatSpec with Matchers {
implicit val ageValidator: Validator[Int] = Validator.min(18)
val validator = Validator.validatorForCaseClass[Person]
validator.validate(Person("notImportantButOld", 21)).map(noPath(_)) shouldBe List(
ValidationError(Validator.pattern("^[A-Z].*"), "notImportantButOld")
ValidationError.Primitive(Validator.pattern("^[A-Z].*"), "notImportantButOld")
)
validator.validate(Person("notImportantAndYoung", 15)).map(noPath(_)) shouldBe List(
ValidationError(Validator.pattern("^[A-Z].*"), "notImportantAndYoung"),
ValidationError(Validator.min(18), 15)
ValidationError.Primitive(Validator.pattern("^[A-Z].*"), "notImportantAndYoung"),
ValidationError.Primitive(Validator.min(18), 15)
)
validator.validate(Person("ImportantButYoung", 15)).map(noPath(_)) shouldBe List(ValidationError(Validator.min(18), 15))
validator.validate(Person("ImportantButYoung", 15)).map(noPath(_)) shouldBe List(ValidationError.Primitive(Validator.min(18), 15))
validator.validate(Person("ImportantAndOld", 21)) shouldBe empty
}

Expand All @@ -157,7 +160,36 @@ class ValidatorTest extends AnyFlatSpec with Matchers {
it should "validate closed set of ints" in {
val v = Validator.enum(List(1, 2, 3, 4))
v.validate(1) shouldBe empty
v.validate(0) shouldBe List(ValidationError(v, 0))
v.validate(0) shouldBe List(ValidationError.Primitive(v, 0))
}

it should "validate a custom case class" in {
case class InnerCaseClass(innerValue: Long)
case class MyClass(name: String, age: Int, field: InnerCaseClass)
val validator = Validator.custom[MyClass](doValidate = { v =>
val nameErrors =
if (v.name.length < 3) List(ValidationError.Custom(v.name, "Name length should be >= 3", List(FieldName("name", "name"))))
else List.empty
val ageErrors =
if (v.age <= 0) List(ValidationError.Custom(v.age, "Age should be > 0", List(FieldName("age", "age")))) else List.empty
val innerErrors =
if (v.field.innerValue <= 0)
List(
ValidationError.Custom(
v.field.innerValue,
"Inner value should be > 0",
List(FieldName("field", "field"), FieldName("innerValue", "innerValue"))
)
)
else List.empty
nameErrors ++ ageErrors ++ innerErrors
})

validator.validate(MyClass("ab", -1, InnerCaseClass(-3))) shouldBe List(
ValidationError.Custom("ab", "Name length should be >= 3", List(FieldName("name", "name"))),
ValidationError.Custom(-1, "Age should be > 0", List(FieldName("age", "age"))),
ValidationError.Custom(-3, "Inner value should be > 0", List(FieldName("field", "field"), FieldName("innerValue", "innerValue")))
)
}

it should "skip collection validation for array if element validator is passing" in {
Expand Down Expand Up @@ -204,7 +236,11 @@ class ValidatorTest extends AnyFlatSpec with Matchers {
v.show shouldBe Some("subNames->(elements(elements(recursive)))")
}

private def noPath[T](v: ValidationError[T]): ValidationError[T] = v.copy(path = Nil)
private def noPath[T](v: ValidationError[T]): ValidationError[T] =
v match {
case p: ValidationError.Primitive[T] => p.copy(path = Nil)
case c: ValidationError.Custom[T] => c.copy(path = Nil)
}
}

sealed trait Color
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ private[schema] class TSchemaToOSchema(schemaReferenceMapper: SchemaReferenceMap
case Validator.MaxLength(value) => oschema.copy(maxLength = Some(value))
case Validator.MinSize(value) => oschema.copy(minItems = Some(value))
case Validator.MaxSize(value) => oschema.copy(maxItems = Some(value))
case Validator.Custom(_, _) => oschema
case Validator.Enum(_, None) => oschema
case Validator.Enum(v, Some(encode)) =>
val values = v.flatMap(x => encode(x).map(rawToString))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ package object schema {
case Validator.Product(_) => Nil
case Validator.Coproduct(_) => Nil
case Validator.OpenProduct(_) => Nil
case Validator.Custom(_) => Nil
case bv: Validator.Primitive[_] => List(bv)
}
}
Expand Down

0 comments on commit e34760e

Please sign in to comment.