Skip to content

Commit

Permalink
scrooge-generator: Validate requests when servers receive them
Browse files Browse the repository at this point in the history
Problem

For services defined in the IDL, we want to validate all
input parameters for all methods defined in the service.
If the parameter is of type struct, union, or exception type,
we want to leverage the validateInstanceValue API.
After we validated all parameters, we will throw a
ThriftValidationException for all validation violations.

Solution

Modified methodService.mustache template for Scala generator
to throw the exception in MethodPerEndpoint definitions

JIRA Issues: CSL-11198

Differential Revision: https://phabricator.twitter.biz/D768362
  • Loading branch information
heligw authored and jenkins committed Nov 20, 2021
1 parent 4645a0f commit 81aa07b
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ class GoldService$FinagleService(
trace.recordRpc("doGreatThings")
trace.recordBinary("srv/thrift_endpoint", "com.twitter.scrooge.test.gold.thriftscala.GoldService#doGreatThings()")
}
try {
val request_item = com.twitter.scrooge.test.gold.thriftscala.Request.validateInstanceValue(args.request)
if (request_item.nonEmpty) throw new com.twitter.scrooge.thrift_validation.ThriftValidationException("doGreatThings", args.request.getClass, request_item)
} catch {
case _: NullPointerException => ()
}
iface.doGreatThings(args.request)
}
}
Expand All @@ -109,6 +115,12 @@ class GoldService$FinagleService(
trace.recordRpc("noExceptionCall")
trace.recordBinary("srv/thrift_endpoint", "com.twitter.scrooge.test.gold.thriftscala.GoldService#noExceptionCall()")
}
try {
val request_item = com.twitter.scrooge.test.gold.thriftscala.Request.validateInstanceValue(args.request)
if (request_item.nonEmpty) throw new com.twitter.scrooge.thrift_validation.ThriftValidationException("noExceptionCall", args.request.getClass, request_item)
} catch {
case _: NullPointerException => ()
}
iface.noExceptionCall(args.request)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class PlatinumService$FinagleService(
trace.recordRpc("moreCoolThings")
trace.recordBinary("srv/thrift_endpoint", "com.twitter.scrooge.test.gold.thriftscala.PlatinumService#moreCoolThings()")
}
try {
val request_item = com.twitter.scrooge.test.gold.thriftscala.Request.validateInstanceValue(args.request)
if (request_item.nonEmpty) throw new com.twitter.scrooge.thrift_validation.ThriftValidationException("moreCoolThings", args.request.getClass, request_item)
} catch {
case _: NullPointerException => ()
}
iface.moreCoolThings(args.request)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,66 @@
package com.twitter.scrooge.backend

import com.twitter.conversions.DurationOps.richDurationFromInt
import com.twitter.finagle.Address
import com.twitter.finagle.Name
import com.twitter.finagle.Thrift
import com.twitter.finagle.ThriftMux
import com.twitter.scrooge.Request
import com.twitter.scrooge.backend.thriftscala._
import com.twitter.scrooge.testutil.JMockSpec
import com.twitter.scrooge.thrift_validation.ThriftValidationViolation
import com.twitter.util.Await
import com.twitter.util.Awaitable
import com.twitter.util.Duration
import com.twitter.util.Future
import java.net.InetAddress
import java.net.InetSocketAddress
import org.apache.thrift.TApplicationException

class ValidationsSpec extends JMockSpec {

def await[T](a: Awaitable[T], d: Duration = 5.seconds): T = Await.result(a, d)

val validationStruct =
ValidationStruct(
"email",
-1,
101,
0,
0,
Map("1" -> "1", "2" -> "2"),
boolField = false,
"anything",
Some("nothing"))

val validationException = ValidationException("")

val validationUnion = ValidationUnion.IntField(-1)

val iface = new ValidationService.MethodPerEndpoint {
override def validate(
structRequest: ValidationStruct,
unionRequest: ValidationUnion,
exceptionRequest: ValidationException
): Future[Boolean] = Future.True

override def validateOption(
structRequest: Option[ValidationStruct],
unionRequest: Option[ValidationUnion],
exceptionRequest: Option[ValidationException]
): Future[Boolean] = Future.True
}

val muxServer =
ThriftMux.server.serveIface(new InetSocketAddress(InetAddress.getLoopbackAddress, 0), iface)

val muxClient = ThriftMux.client.build[ValidationService.MethodPerEndpoint](
Name.bound(Address(muxServer.boundAddress.asInstanceOf[InetSocketAddress])),
"client"
)

"validateInstanceValue" should {
"validate Struct" in { _ =>
val validationStruct =
ValidationStruct(
"email",
-1,
101,
0,
0,
Map("1" -> "1", "2" -> "2"),
boolField = false,
"anything")
val validationViolations = ValidationStruct.validateInstanceValue(validationStruct)
val violationMessages = Set(
"length must be between 6 and 2147483647",
Expand All @@ -32,17 +76,6 @@ class ValidationsSpec extends JMockSpec {
}

"validate nested Struct" in { _ =>
val validationStruct =
ValidationStruct(
"email",
-1,
101,
0,
0,
Map("1" -> "1", "2" -> "2"),
boolField = false,
"anything",
Some("nothing"))
val nestedValidationStruct = NestedValidationStruct(
"not an email",
validationStruct,
Expand Down Expand Up @@ -72,7 +105,6 @@ class ValidationsSpec extends JMockSpec {
}

"validate exception" in { _ =>
val validationException = ValidationException("")
val validationViolations = ValidationException.validateInstanceValue(validationException)
assertViolations(validationViolations, 1, Set("must not be empty"))
}
Expand All @@ -82,6 +114,76 @@ class ValidationsSpec extends JMockSpec {
val validationViolations = NonValidationStruct.validateInstanceValue(nonValidationStruct)
assertViolations(validationViolations, 0, Set.empty)
}

"validate struct, union and exception request" in { _ =>
intercept[TApplicationException] {
await(muxClient.validate(validationStruct, validationUnion, validationException))
}
}

"validate Option type with None and Some() request" in { _ =>
intercept[TApplicationException] {
await(muxClient
.validateOption(Some(validationStruct), Some(validationUnion), Some(validationException)))
}
// check for option that has None as value
// it shouldn't return an exception
assert(await(muxClient.validateOption(None, None, None)))
}

"validate with Thrift client with servicePerEndpoint[ServicePerEndpoint]" in { _ =>
val clientIface = Thrift.server.serveIface(
new InetSocketAddress(InetAddress.getLoopbackAddress, 0),
iface
)
val clientValidationService =
Thrift.client.servicePerEndpoint[ValidationService.ServicePerEndpoint](
Name.bound(Address(clientIface.boundAddress.asInstanceOf[InetSocketAddress])),
"clientValidationService"
)
intercept[TApplicationException] {
await(clientValidationService.validate(
ValidationService.validate$args(validationStruct, validationUnion, validationException)))
}
intercept[TApplicationException] {
await(
clientValidationService.validateOption(
ValidationService.validateOption$args(
Some(validationStruct),
Some(validationUnion),
Some(validationException))))
}
}

"validate with Thrift client with reqRepServiceEndPoint[ReqRepServiceEndPoint]" in { _ =>
val clientIface = Thrift.server.serveIface(
new InetSocketAddress(InetAddress.getLoopbackAddress, 0),
iface
)
val clientValidationService =
Thrift.client.servicePerEndpoint[ValidationService.ReqRepServicePerEndpoint](
Name.bound(Address(clientIface.boundAddress.asInstanceOf[InetSocketAddress])),
"clientValidationService"
)
intercept[TApplicationException] {
await(clientValidationService.validate(Request(
ValidationService.validate$args(validationStruct, validationUnion, validationException))))
}
intercept[TApplicationException] {
await(
clientValidationService.validateOption(
Request(
ValidationService.validateOption$args(
Some(validationStruct),
Some(validationUnion),
Some(validationException)))))
}
}

"validate if null parameters are passed as requests" in { _ =>
//nullPointerException is handled in the mustache file
assert(await(muxClient.validate(null, null, null)))
}
}

private def assertViolations(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ struct ValidationStruct {
9: optional string optionalField
}

service ValidationService {
bool validate(
1: ValidationStruct structRequest,
2: ValidationUnion unionRequest,
3: ValidationException exceptionRequest)
bool validateOption (
1: optional ValidationStruct structRequest,
2: optional ValidationUnion unionRequest,
3: optional ValidationException exceptionRequest)
}

// skip annotations not used for ThriftValidator
struct NonValidationStruct {
1: string stringField (structFieldKey = "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,22 @@ addService("{{methodSvcNameForWire}}", {
trace.recordRpc("{{methodSvcNameForWire}}")
trace.recordBinary("srv/thrift_endpoint", "{{package}}.{{ServiceName}}#{{methodSvcNameForCompile}}()")
}
{{#args}}
{{#isValidationType}}
{{#isOption}}
if ({{arg}}.isDefined) {
{{/isOption}}
try {
val {{argResult}} = {{typeParameter}}.validateInstanceValue({{deReferencedArg}})
if ({{argResult}}.nonEmpty) throw new com.twitter.scrooge.thrift_validation.ThriftValidationException("{{methodSvcNameForCompile}}", {{arg}}.getClass, {{argResult}})
} catch {
case _: NullPointerException => ()
}
{{#isOption}}
}
{{/isOption}}
{{/isValidationType}}
{{/args}}
iface.{{methodSvcNameForCompile}}({{argNames}})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ trait ServiceTemplate { self: TemplateGenerator =>
v(s"($typesString)")
}
},
"args" -> v(function.args.map { arg => Dictionary("arg" -> genID(arg.sid)) }),
"args" -> v(function.args.map { arg =>
Dictionary(
"arg" -> genID(arg.sid)
)
}),
"isVoid" -> v(function.funcType == Void || function.funcType == OnewayVoid),
"is_oneway" -> v(function.funcType == OnewayVoid),
"functionType" -> {
Expand Down Expand Up @@ -108,6 +112,14 @@ trait ServiceTemplate { self: TemplateGenerator =>
)
}

/**
* Check if the fieldType is optional, If it is we append `.get` to get the value
* And if it is not we leave it as fieldName
*/
private[this] def genOptional(fieldName: CodeFragment, isOptional: Boolean): CodeFragment = {
if (isOptional) fieldName.append(".get") else fieldName
}

def functionObjectName(f: Function): SimpleID = f.funcName.toTitleCase

/**
Expand Down Expand Up @@ -173,6 +185,18 @@ trait ServiceTemplate { self: TemplateGenerator =>
.map { field => "args." + genID(field.sid).toData }
.mkString(", ")
),
"args" -> v(f.args.map {
arg =>
Dictionary(
"arg" -> v("args." + genID(arg.sid)),
"argResult" -> genID(arg.sid.append("_item")),
"typeParameter" -> genType(arg.fieldType),
"deReferencedArg" -> v(
"args." + genOptional(genID(arg.sid), arg.requiredness.isOptional)),
"isOption" -> v(arg.requiredness.isOptional),
"isValidationType" -> v(arg.fieldType.isInstanceOf[StructType])
)
}),
"typeName" -> genType(f.funcType),
"isVoid" -> v(f.funcType == Void || f.funcType == OnewayVoid),
"resultNamedArg" ->
Expand Down

0 comments on commit 81aa07b

Please sign in to comment.