Skip to content

Commit

Permalink
Rough implementation of setting query/form/header params on AHC
Browse files Browse the repository at this point in the history
  • Loading branch information
kelnos committed Mar 5, 2019
1 parent 05742f3 commit 6b2b552
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 9 deletions.
Expand Up @@ -12,15 +12,15 @@ import com.github.javaparser.ast.body.{ClassOrInterfaceDeclaration, MethodDeclar
import com.github.javaparser.ast.expr.{MethodCallExpr, NameExpr, _}
import com.github.javaparser.ast.stmt._
import com.twilio.guardrail.SwaggerUtil.jpaths
import com.twilio.guardrail.generators.Response
import com.twilio.guardrail.generators.{Response, ScalaParameter}
import com.twilio.guardrail.generators.syntax.Java._
import com.twilio.guardrail.languages.JavaLanguage
import com.twilio.guardrail.protocol.terms.client._
import com.twilio.guardrail.terms.RouteMeta
import com.twilio.guardrail.{RenderedClientOperation, StaticDefns, Target}
import java.net.URI
import java.util
import java.util.Locale
import javax.lang.model.`type`.PrimitiveType

object AsyncHttpClientClientGenerator {
private val URI_TYPE = JavaParser.parseClassOrInterfaceType("URI")
Expand All @@ -31,6 +31,7 @@ object AsyncHttpClientClientGenerator {
private val REQUEST_BUILDER_TYPE = JavaParser.parseClassOrInterfaceType("RequestBuilder")
private val REQUEST_TYPE = JavaParser.parseClassOrInterfaceType("Request")
private val RESPONSE_TYPE = JavaParser.parseClassOrInterfaceType("Response")
private val STRING_PART_TYPE = JavaParser.parseClassOrInterfaceType("StringPart")
private val OBJECT_MAPPER_TYPE = JavaParser.parseClassOrInterfaceType("ObjectMapper")
private val BUILDER_TYPE = JavaParser.parseClassOrInterfaceType("Builder")
private val MARSHALLING_EXCEPTION_TYPE = JavaParser.parseClassOrInterfaceType("MarshallingException")
Expand All @@ -42,6 +43,82 @@ object AsyncHttpClientClientGenerator {
completionStageType.setTypeArguments(RESPONSE_TYPE)
))

private def showParam(param: ScalaParameter[JavaLanguage], overrideParamName: Option[String] = None): Expression = {
val paramName = overrideParamName.getOrElse(param.paramName.asString)

def doShow(tpe: Type): Expression = tpe match {
case _: PrimitiveType =>
new MethodCallExpr(new NameExpr("String"), "valueOf", new NodeList[Expression](new NameExpr(paramName)))
case cls: ClassOrInterfaceType if cls.isOptional =>
doShow(cls.containedType)
case cls: ClassOrInterfaceType if cls.isBoxedType =>
new MethodCallExpr(new NameExpr(paramName), "toString")
case cls: ClassOrInterfaceType if cls.isNamed("List") =>
doShow(cls.containedType)
case cls: ClassOrInterfaceType if cls.getName.asString() == "String" =>
new NameExpr(paramName)
case _: ClassOrInterfaceType =>
// FIXME: this will cover our autogenerated enum types, but it would be really nice if we could at identify
// our enum times via a list of some sort. this will fail to compile in most non-enum cases, but there's
// the possibility that this could generate incorrect but compilable code.
new MethodCallExpr(new NameExpr(paramName), "getValue")
case _: VoidType =>
new NullLiteralExpr
case other =>
println(s"WARN: Unhandled arg type ${other.getClass.getName} for arg typed ${other.name} ${param.paramName}")
new NameExpr("UNSUPPORTED_PARAMETER_TYPE_PLEASE_FILE_AN_ISSUE")
}

doShow(param.argType)
}

private def optionIfPresent(optionVarType: Type, optionVarName: String, innerStatement: Statement): Statement = {
new ExpressionStmt(new MethodCallExpr(new NameExpr(optionVarName), "ifPresent", new NodeList[Expression](
new LambdaExpr(new NodeList(new Parameter(util.EnumSet.of(FINAL), optionVarType, new SimpleName("arg"))),
innerStatement,
true
))
))
}

private def generateBuilderMethodCalls(params: List[ScalaParameter[JavaLanguage]], builderMethodName: String): List[Statement] = {
val needsMultipart = params.exists(_.isFile)
params.map({ param =>
val finalMethodName = if (needsMultipart) "addBodyPart" else builderMethodName
val argName = if (param.required) param.paramName.asString else "arg"
val containedType = param.argType.containedType
val isList = if (param.required) param.argType.isNamed("List") else containedType.isNamed("List")
val listType = if (param.required) containedType else containedType.containedType

val makeArgList: String => NodeList[Expression] = name =>
if (containedType.isNamed("FilePart") || listType.isNamed("FilePart")) {
new NodeList[Expression](new NameExpr(name))
} else if (needsMultipart) {
new NodeList[Expression](new ObjectCreationExpr(null, STRING_PART_TYPE, new NodeList(showParam(param, Some(name)))))
} else {
new NodeList[Expression](new StringLiteralExpr(param.argName.value), showParam(param, Some(name)))
}

val builderStatement: Statement = if (isList) {
new ForEachStmt(
new VariableDeclarationExpr(listType, "member", FINAL),
new NameExpr(argName),
new BlockStmt(new NodeList(
new ExpressionStmt(new MethodCallExpr(new NameExpr("builder"), finalMethodName, makeArgList("member")))
))
)
} else {
new ExpressionStmt(new MethodCallExpr(new NameExpr("builder"), finalMethodName, makeArgList(argName)))
}

if (param.required) {
builderStatement
} else {
optionIfPresent(containedType, param.paramName.asString, builderStatement)
}
})
}

object ClientTermInterp extends (ClientTerm[JavaLanguage, ?] ~> Target) {
def apply[T](term: ClientTerm[JavaLanguage, T]): Target[T] = term match {
case GenerateClientOperation(_, RouteMeta(pathStr, httpMethod, operation), methodName, tracing, parameters, responses) =>
Expand Down Expand Up @@ -71,13 +148,19 @@ object AsyncHttpClientClientGenerator {
"setUrl", new NodeList[Expression](pathExpr)
)

val builderMethodCalls: List[Statement] = List(
generateBuilderMethodCalls(parameters.queryStringParams, "addQueryParam"),
generateBuilderMethodCalls(parameters.formParams, "addFormParam"),
generateBuilderMethodCalls(parameters.headerParams, "addHeader")
).flatten

val httpMethodCallExpr = new MethodCallExpr(
new FieldAccessExpr(new ThisExpr, "httpClient"),
"apply",
new NodeList[Expression](new MethodCallExpr(new NameExpr("builder"), "build"))
)
val requestCall = new MethodCallExpr(httpMethodCallExpr, "thenApply", new NodeList[Expression](
new LambdaExpr(new NodeList(new Parameter(RESPONSE_TYPE, "response")), new BlockStmt(new NodeList(
new LambdaExpr(new NodeList(new Parameter(util.EnumSet.of(FINAL), RESPONSE_TYPE, new SimpleName("response"))), new BlockStmt(new NodeList(
new SwitchStmt(new MethodCallExpr(new NameExpr("response"), "getStatusCode"), new NodeList(
responses.value.map(response => new SwitchEntryStmt(new IntegerLiteralExpr(response.statusCode), new NodeList(response.value match {
case None => new ReturnStmt(new ObjectCreationExpr(null, JavaParser.parseClassOrInterfaceType(s"${responseParentName}.${response.statusCodeName.asString}"), new NodeList()))
Expand Down Expand Up @@ -120,8 +203,9 @@ object AsyncHttpClientClientGenerator {
))

method.setBody(new BlockStmt(new NodeList(
new ExpressionStmt(requestBuilder),
new ReturnStmt(requestCall)
new ExpressionStmt(requestBuilder) +:
builderMethodCalls :+
new ReturnStmt(requestCall): _*
)))

RenderedClientOperation[JavaLanguage](method, List.empty)
Expand Down
Expand Up @@ -34,7 +34,7 @@ object JacksonGenerator {
private def sortParams(params: List[ProtocolParameter[JavaLanguage]]): (List[ParameterTerm], List[ParameterTerm]) = {
// TODO: if a required field has a default specified, include it in optionalTerms instead
val (req, opt) = params.partition(_.term.getType match {
case cls: ClassOrInterfaceType => !isOptionalType(cls)
case cls: ClassOrInterfaceType => !cls.isOptional
case _ => true
})

Expand Down Expand Up @@ -70,9 +70,6 @@ object JacksonGenerator {
})
}

private def isOptionalType(cls: ClassOrInterfaceType): Boolean =
(cls.getScope.asScala.fold("")(_.asString + ".") + cls.getName.asString) == "java.util.Optional"

private def lookupTypeName(tpeName: String, concreteTypes: List[PropMeta[JavaLanguage]])(f: Type => Target[Type]): Option[Target[Type]] =
concreteTypes
.find(_.clsName == tpeName)
Expand Down
Expand Up @@ -18,6 +18,35 @@ object Java {
def asScala: Option[T] = if (o.isPresent) Option(o.get) else None
}

implicit class RichType(val tpe: Type) extends AnyVal {
def isOptional: Boolean =
tpe match {
case cls: ClassOrInterfaceType =>
val scope = cls.getScope.asScala
cls.getNameAsString == "Optional" && (scope.isEmpty || scope.map(_.asString).contains("java.util"))
case _ => false
}

def containedType: Type =
tpe match {
case cls: ClassOrInterfaceType => cls.getTypeArguments.asScala.filter(_.size == 1).fold(tpe)(_.get(0))
case _ => tpe
}

def isNamed(name: String): Boolean =
tpe match {
case cls: ClassOrInterfaceType if name.contains(".") => (cls.getScope.asScala.fold("")(_ + ".") + cls.getNameAsString) == name
case cls: ClassOrInterfaceType => cls.getNameAsString == name
case _ => false
}

def name: Option[String] =
tpe match {
case cls: ClassOrInterfaceType => Some(cls.getScope.asScala.fold("")(_ + ".") + cls.getNameAsString)
case _ => None
}
}

private[this] def safeParse[T](log: String)(parser: String => T, s: String)(implicit cls: ClassTag[T]): Target[T] = {
Target.log.debug(log)(s) >> (
Try(parser(s)).toEither.fold(t => Target.raiseError(s"Unable to parse '${s}' to a ${cls.runtimeClass.getName}: ${t.getMessage}"), Target.pure)
Expand Down

0 comments on commit 6b2b552

Please sign in to comment.