Skip to content

Commit

Permalink
feat: add support for string arrays in rules engine parameters & supp…
Browse files Browse the repository at this point in the history
…ort for operationContextParams trait (#1119)
  • Loading branch information
0marperez committed Jul 18, 2024
1 parent 1ae66a0 commit fa262f5
Show file tree
Hide file tree
Showing 10 changed files with 407 additions and 43 deletions.
5 changes: 5 additions & 0 deletions .changes/22c07786-9168-425a-960b-e03378ee3ce3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "22c07786-9168-425a-960b-e03378ee3ce3",
"type": "feature",
"description": "Add support for operationContextParams trait"
}
5 changes: 5 additions & 0 deletions .changes/f1afb4d6-fa61-4eba-8695-b9a8bc59418a.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "f1afb4d6-fa61-4eba-8695-b9a8bc59418a",
"type": "feature",
"description": "Add support for string arrays in rules engine parameters"
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.utils.doubleQuote
import software.amazon.smithy.kotlin.codegen.utils.toCamelCase
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.rulesengine.language.evaluation.value.ArrayValue
import software.amazon.smithy.rulesengine.language.evaluation.value.BooleanValue
import software.amazon.smithy.rulesengine.language.evaluation.value.StringValue
import software.amazon.smithy.rulesengine.language.evaluation.value.Value
Expand All @@ -32,7 +34,7 @@ fun ParameterType.toSymbol(): Symbol =
when (this) {
ParameterType.STRING -> KotlinTypes.String
ParameterType.BOOLEAN -> KotlinTypes.Boolean
ParameterType.STRING_ARRAY -> KotlinTypes.Collections.MutableList
ParameterType.STRING_ARRAY -> KotlinTypes.Collections.list(KotlinTypes.String)
}.asNullable()

/**
Expand All @@ -42,5 +44,16 @@ fun Value.toLiteral(): String =
when (this) {
is StringValue -> value.doubleQuote()
is BooleanValue -> value.toString()
is ArrayValue -> values.joinToString(", ", "listOf(", ")") { value ->
value.expectStringValue().value.doubleQuote()
}
else -> throw IllegalArgumentException("unrecognized parameter value type $type")
}

/**
* Format a list of string nodes for codegen
*/
fun List<Node>.format(): String =
this.joinToString(", ", "listOf(", ")") { element ->
element.expectStringNode().value.doubleQuote()
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import software.amazon.smithy.model.knowledge.KnowledgeIndex
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rulesengine.traits.ContextParamTrait
import software.amazon.smithy.rulesengine.traits.OperationContextParamDefinition
import software.amazon.smithy.rulesengine.traits.OperationContextParamsTrait
import software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait

/**
Expand Down Expand Up @@ -44,14 +46,23 @@ class EndpointParameterIndex private constructor(model: Model) : KnowledgeIndex
}
}

/**
* Get the [operationContextParams](https://smithy.io/2.0/additional-specs/rules-engine/parameters.html#smithy-rules-operationcontextparams-trait)
* for an operation.
*
* @param op the operation shape to get context params for.
*/
fun operationContextParams(op: OperationShape): Map<String, OperationContextParamDefinition>? =
op.getTrait<OperationContextParamsTrait>()?.parameters

/**
* Check if there are any context parameters bound to an operation
*
* @param op operation to check parameters for
* @return true if there are any static or input context parameters for the given operation
* @return true if there are any static, input, or operation context parameters for the given operation
*/
fun hasContextParams(op: OperationShape): Boolean =
staticContextParams(op) != null || inputContextParams(op).isNotEmpty()
staticContextParams(op) != null || inputContextParams(op).isNotEmpty() || operationContextParams(op) != null

companion object {
fun of(model: Model): EndpointParameterIndex = EndpointParameterIndex(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.kotlin.codegen.model.format
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.kotlin.codegen.utils.toCamelCase
import software.amazon.smithy.model.node.ArrayNode
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.StringNode
Expand Down Expand Up @@ -130,6 +132,7 @@ class DefaultEndpointProviderTestGenerator(
when (v) {
is StringNode -> writer.writeInline("#S", v.value)
is BooleanNode -> writer.writeInline("#L", v.value)
is ArrayNode -> writer.writeInline("#L", v.elements.format())
else -> throw IllegalArgumentException("unexpected test case param value")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@ package software.amazon.smithy.kotlin.codegen.rendering.endpoints

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.jmespath.JmespathExpression
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.integration.SectionId
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.model.knowledge.EndpointParameterIndex
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.kotlin.codegen.rendering.waiters.KotlinJmespathExpressionVisitor
import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter
import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType
import software.amazon.smithy.rulesengine.traits.ClientContextParamsTrait
import software.amazon.smithy.rulesengine.traits.OperationContextParamDefinition
import software.amazon.smithy.rulesengine.traits.StaticContextParamDefinition
import software.amazon.smithy.utils.StringUtils

object EndpointBusinessMetrics : SectionId
Expand Down Expand Up @@ -77,6 +83,7 @@ class EndpointResolverAdapterGenerator(
val topDownIndex = TopDownIndex.of(ctx.model)
val operations = topDownIndex.getContainedOperations(ctx.service)
val epParameterIndex = EndpointParameterIndex.of(ctx.model)
val operationsWithContextBindings = operations.filter { epParameterIndex.hasContextParams(it) }

writer.write(
"private typealias BindOperationContextParamsFn = (#T.Builder, #T) -> Unit",
Expand All @@ -88,24 +95,28 @@ class EndpointResolverAdapterGenerator(
"private val opContextBindings = mapOf<String, BindOperationContextParamsFn> (",
")",
) {
val operationsWithContextBindings = operations.filter { epParameterIndex.hasContextParams(it) }
operationsWithContextBindings.forEach { op ->
val bindFn = op.bindEndpointContextFn(ctx.settings) { fnWriter ->
fnWriter.withBlock(
"private fun #L(builder: #T.Builder, request: #T): Unit {",
"}",
op.bindEndpointContextFnName(),
EndpointParametersGenerator.getSymbol(ctx.settings),
RuntimeTypes.HttpClient.Operation.ResolveEndpointRequest,
) {
renderBindOperationContextParams(epParameterIndex, op, fnWriter)
}
}
write("#S to ::#T,", op.id.name, bindFn)
write("#S to ::#L,", op.id.name, op.bindEndpointContextFnName())
}
}

operationsWithContextBindings.forEach { op ->
renderBindOperationContextFunction(op, epParameterIndex)
}
}

private fun renderBindOperationContextFunction(op: OperationShape, epParameterIndex: EndpointParameterIndex) =
writer.write("")
.withBlock(
"private fun #L(builder: #T.Builder, request: #T): Unit {",
"}",
op.bindEndpointContextFnName(),
EndpointParametersGenerator.getSymbol(ctx.settings),
RuntimeTypes.HttpClient.Operation.ResolveEndpointRequest,
) {
renderBindOperationContextParams(epParameterIndex, op)
}

private fun renderResolveEndpointParams() {
// NOTE: this is internal as it's re-used for auth scheme resolver generators in specific instances where they
// fallback to endpoint rules (e.g. S3 & EventBridge)
Expand All @@ -119,14 +130,21 @@ class EndpointResolverAdapterGenerator(
) {
writer.addImport(RuntimeTypes.Core.Collections.get)
withBlock("return #T {", "}", EndpointParametersGenerator.getSymbol(ctx.settings)) {
// The SEP dictates a specific source order to use when binding parameters (from most specific to least):
// 1. staticContextParams (from operation shape)
// 2. contextParam (from member of operation input shape)
// 3. clientContextParams (from service shape)
// 4. builtin binding
// 5. builtin default
// Sources 4 and 5 are SDK-specific, builtin bindings are plugged in and rendered beforehand such that any bindings
// from source 1 or 2 can supersede them.
/*
The spec dictates a specific source order to use when binding parameters (from most specific to least):
1. staticContextParams (from operation shape)
2. contextParam (from member of operation input shape)
3. operationContextParams (from operation shape)
4. clientContextParams (from service shape)
5. builtin binding
6. builtin default
Sources 5 and 6 are SDK-specific
Builtin bindings are plugged in and rendered beforehand such that any bindings from source 1, 2, or 3
can supersede them.
*/

// Render builtins
if (rules != null) {
Expand All @@ -140,7 +158,7 @@ class EndpointResolverAdapterGenerator(
// Render client context
renderBindClientContextParams(ctx, writer)

// Render operation static/input context (if any)
// Render operation static/input/operation context (if any)
write("val opName = request.context[#T.OperationName]", RuntimeTypes.SmithyClient.SdkClientOption)
write("opContextBindings[opName]?.invoke(this, request)")
}
Expand All @@ -167,42 +185,87 @@ class EndpointResolverAdapterGenerator(
private fun renderBindOperationContextParams(
epParameterIndex: EndpointParameterIndex,
op: OperationShape,
writer: KotlinWriter,
) {
if (rules == null) return

val staticContextParams = epParameterIndex.staticContextParams(op)
val inputContextParams = epParameterIndex.inputContextParams(op)
val operationContextParams = epParameterIndex.operationContextParams(op)

if (inputContextParams.isNotEmpty()) {
writer.addImport(RuntimeTypes.Core.Collections.get)
writer.write("@Suppress(#S)", "UNCHECKED_CAST")
val opInputShape = ctx.model.expectShape(op.inputShape)
val inputSymbol = ctx.symbolProvider.toSymbol(opInputShape)
writer.write("val input = request.context[#T.OperationInput] as #T", RuntimeTypes.HttpClient.Operation.HttpOperationContext, inputSymbol)
}
if (inputContextParams.isNotEmpty()) renderInput(op)

for (param in rules.parameters.toList()) {
val paramName = param.name.toString()
val paramDefaultName = param.defaultName()

// Check static params
val staticParam = staticContextParams?.parameters?.get(paramName)

if (staticParam != null) {
writer.writeInline("builder.#L = ", paramDefaultName)
when (param.type) {
ParameterType.STRING -> writer.write("#S", staticParam.value.expectStringNode().value)
ParameterType.BOOLEAN -> writer.write("#L", staticParam.value.expectBooleanNode().value)
else -> throw CodegenException("unexpected static context param type ${param.type}")
}
renderStaticParam(staticParam, paramDefaultName, param)
continue
}

// Check input params
val inputParam = inputContextParams[paramName]
if (inputParam != null) {
renderInputParam(inputParam, paramDefaultName)
continue
}

inputContextParams[paramName]?.let {
writer.write("builder.#L = input.#L", paramDefaultName, it.defaultName())
// Check operation params
val operationParam = operationContextParams?.get(paramName)
if (operationParam != null) {
renderOperationParam(operationParam, paramDefaultName, op, inputContextParams)
}
}
}

private fun renderInput(op: OperationShape) {
writer.addImport(RuntimeTypes.Core.Collections.get)
writer.write("@Suppress(#S)", "UNCHECKED_CAST")
val opInputShape = ctx.model.expectShape(op.inputShape)
val inputSymbol = ctx.symbolProvider.toSymbol(opInputShape)
writer.write("val input = request.context[#T.OperationInput] as #T", RuntimeTypes.HttpClient.Operation.HttpOperationContext, inputSymbol)
}

private fun renderStaticParam(staticParam: StaticContextParamDefinition, paramDefaultName: String, param: Parameter) {
writer.writeInline("builder.#L = ", paramDefaultName)
when (param.type) {
ParameterType.STRING -> writer.write("#S", staticParam.value.expectStringNode().value)
ParameterType.BOOLEAN -> writer.write("#L", staticParam.value.expectBooleanNode().value)
ParameterType.STRING_ARRAY -> writer.write("#L", staticParam.value.expectArrayNode().elements.format())
else -> throw CodegenException("unexpected static context param type ${param.type}")
}
}

private fun renderInputParam(inputParam: MemberShape, paramDefaultName: String) {
writer.write("builder.#L = input.#L", paramDefaultName, inputParam.defaultName())
}

private fun renderOperationParam(operationParam: OperationContextParamDefinition, paramDefaultName: String, op: OperationShape, inputContextParams: Map<String, MemberShape>) {
val opInputShape = ctx.model.expectShape(op.inputShape)

if (inputContextParams.isEmpty()) {
// This will already be rendered in the block if inputContextParams is not empty
renderInput(op)
}

val jmespathVisitor = KotlinJmespathExpressionVisitor(
GenerationContext(
ctx.model,
ctx.symbolProvider,
ctx.settings,
),
writer,
opInputShape,
"input", // reference the operation input during jmespath codegen
)
val expression = JmespathExpression.parse(operationParam.path)
val expressionResult = expression.accept(jmespathVisitor)

writer.write("builder.#L = #L", paramDefaultName, expressionResult.identifier)
}

private fun renderBindClientContextParams(ctx: ProtocolGenerator.GenerationContext, writer: KotlinWriter) {
val clientContextParams = ctx.service.getTrait<ClientContextParamsTrait>() ?: return
if (rules == null) return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ private val suffixSequence = sequenceOf("") + generateSequence(2) { it + 1 }.map
* @param ctx The surrounding [CodegenContext].
* @param writer The [KotlinWriter] to generate code into.
* @param shape The modeled [Shape] on which this JMESPath expression is operating.
* @param topLevelParentName The name used to reference the top level "parent" of an expression during codegen.
* Defaults to `it`. E.g. `it.field`.
*/
class KotlinJmespathExpressionVisitor(
val ctx: CodegenContext,
val writer: KotlinWriter,
shape: Shape,
private val topLevelParentName: String = "it",
) : ExpressionVisitor<VisitedExpression> {
private val tempVars = mutableSetOf<String>()

Expand Down Expand Up @@ -172,7 +175,8 @@ class KotlinJmespathExpressionVisitor(

override fun visitExpressionType(expression: ExpressionTypeExpression): VisitedExpression = throw CodegenException("ExpressionTypeExpression is unsupported")

override fun visitField(expression: FieldExpression): VisitedExpression = subfield(expression, "it")
override fun visitField(expression: FieldExpression): VisitedExpression =
if (shapeCursor.size == 1) subfield(expression, topLevelParentName) else subfield(expression, "it")

override fun visitFilterProjection(expression: FilterProjectionExpression): VisitedExpression {
val left = expression.left.accept(this)
Expand Down Expand Up @@ -444,6 +448,10 @@ class KotlinJmespathExpressionVisitor(
private fun projection(expression: ProjectionExpression, parentName: String): VisitedExpression {
val left = when (expression.left) {
is SliceExpression -> slice(expression.left as SliceExpression, parentName)
is FieldExpression -> subfield(expression.left as FieldExpression, parentName)
is IndexExpression -> index(expression.left as IndexExpression, parentName)
is Subexpression -> subexpression(expression.left as Subexpression, parentName)
is ProjectionExpression -> projection(expression.left as ProjectionExpression, parentName)
else -> expression.left.accept(this)
}
requireNotNull(left.shape) { "projection is operating on nothing" }
Expand Down
Loading

0 comments on commit fa262f5

Please sign in to comment.