diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 6a728f5e1..62facd948 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -6,24 +6,20 @@ import java.util.Collection; import java.util.LinkedHashSet; -import java.util.Set; +import java.util.List; +import software.amazon.smithy.codegen.core.SymbolProvider; import software.amazon.smithy.codegen.core.SymbolReference; +import software.amazon.smithy.model.Model; import software.amazon.smithy.model.knowledge.EventStreamIndex; -import software.amazon.smithy.model.knowledge.EventStreamInfo; -import software.amazon.smithy.model.knowledge.ServiceIndex; import software.amazon.smithy.model.knowledge.TopDownIndex; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.model.traits.DocumentationTrait; import software.amazon.smithy.model.traits.StringTrait; import software.amazon.smithy.python.codegen.integrations.PythonIntegration; import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin; -import software.amazon.smithy.python.codegen.sections.InitializeHttpAuthParametersSection; -import software.amazon.smithy.python.codegen.sections.ResolveEndpointSection; -import software.amazon.smithy.python.codegen.sections.ResolveIdentitySection; -import software.amazon.smithy.python.codegen.sections.SendRequestSection; -import software.amazon.smithy.python.codegen.sections.SignRequestSection; import software.amazon.smithy.python.codegen.writer.PythonWriter; import software.amazon.smithy.utils.SmithyInternalApi; @@ -35,10 +31,14 @@ final class ClientGenerator implements Runnable { private final GenerationContext context; private final ServiceShape service; + private final SymbolProvider symbolProvider; + private final Model model; ClientGenerator(GenerationContext context, ServiceShape service) { this.context = context; this.service = service; + this.symbolProvider = context.symbolProvider(); + this.model = context.model(); } @Override @@ -47,7 +47,7 @@ public void run() { } private void generateService(PythonWriter writer) { - var serviceSymbol = context.symbolProvider().toSymbol(service); + var serviceSymbol = symbolProvider.toSymbol(service); var configSymbol = CodegenUtils.getConfigSymbol(context.settings()); var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); writer.addLogger(); @@ -58,6 +58,7 @@ private void generateService(PythonWriter writer) { Output = TypeVar("Output") """); + // TODO: Extend a base client class writer.openBlock("class $L:", "", serviceSymbol.getName(), () -> { var docs = service.getTrait(DocumentationTrait.class) .map(StringTrait::getValue) @@ -77,12 +78,25 @@ private void generateService(PythonWriter writer) { for (PythonIntegration integration : context.integrations()) { for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins(context)) { - if (runtimeClientPlugin.matchesService(context.model(), service)) { + if (runtimeClientPlugin.matchesService(model, service)) { runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); } } } + // TODO: Set default client protocol + // Need some mapping between protocol shape and the python implementation + // For now, test with restjson + writer.addImport("smithy_core.protocols", "RestJsonClientProtocol"); + writer.write("protocol = RestJsonClientProtocol"); + + writer.addImport("smithy_core.type_registry", "TypeRegistry"); + writer.write(""" + type_registry = TypeRegistry({ + $C + }) + """, writer.consumer(this::writeErrorTypeRegistry)); + writer.write(""" def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None): self._config = config or $1T() @@ -97,606 +111,37 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None): plugin(self._config) """, configSymbol, pluginSymbol, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); - var topDownIndex = TopDownIndex.of(context.model()); - var eventStreamIndex = EventStreamIndex.of(context.model()); + var topDownIndex = TopDownIndex.of(model); + var eventStreamIndex = EventStreamIndex.of(model); for (OperationShape operation : topDownIndex.getContainedOperations(service)) { if (eventStreamIndex.getInputInfo(operation).isPresent() || eventStreamIndex.getOutputInfo(operation).isPresent()) { - generateEventStreamOperation(writer, operation); + // TODO: event streaming operations } else { generateOperation(writer, operation); } } }); - - if (context.protocolGenerator() != null) { - generateOperationExecutor(writer); - } } - private void generateOperationExecutor(PythonWriter writer) { - writer.pushState(); - - var hasStreaming = hasEventStream(); - writer.putContext("hasEventStream", hasStreaming); - if (hasStreaming) { - writer.addImports("smithy_core.deserializers", - Set.of( - "ShapeDeserializer", - "DeserializeableShape")); - writer.addStdlibImport("typing", "Any"); - } - - var transportRequest = context.applicationProtocol().requestType(); - var transportResponse = context.applicationProtocol().responseType(); - var errorSymbol = CodegenUtils.getServiceError(context.settings()); - var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); - var configSymbol = CodegenUtils.getConfigSymbol(context.settings()); - - writer.addStdlibImport("typing", "Callable"); - writer.addStdlibImport("typing", "Awaitable"); - writer.addStdlibImport("typing", "cast"); - writer.addStdlibImport("copy", "deepcopy"); - writer.addStdlibImport("asyncio", "sleep"); - - writer.addDependency(SmithyPythonDependency.SMITHY_CORE); - writer.addImport("smithy_core.exceptions", "SmithyRetryException"); - writer.addImports("smithy_core.interceptors", Set.of("Interceptor", "InterceptorContext")); - writer.addImports("smithy_core.interfaces.retries", Set.of("RetryErrorInfo", "RetryErrorType")); - writer.addImport("smithy_core.interfaces.exceptions", "HasFault"); - - writer.indent(); - writer.write(""" - def _classify_error( - self, - *, - error: Exception, - context: InterceptorContext[Input, Output, $1T, $2T | None] - ) -> RetryErrorInfo: - logger.debug("Classifying error: %s", error) - """, transportRequest, transportResponse); - writer.indent(); - - if (context.applicationProtocol().isHttpProtocol()) { - writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); - writer.write(""" - if not isinstance(error, HasFault) and not context.transport_response: - return RetryErrorInfo(error_type=RetryErrorType.TRANSIENT) - - if context.transport_response: - if context.transport_response.status in [429, 503]: - retry_after = None - retry_header = context.transport_response.fields["retry-after"] - if retry_header and retry_header.values: - retry_after = float(retry_header.values[0]) - return RetryErrorInfo(error_type=RetryErrorType.THROTTLING, retry_after_hint=retry_after) - - if context.transport_response.status >= 500: - return RetryErrorInfo(error_type=RetryErrorType.SERVER_ERROR) - - """); - } - - writer.write(""" - error_type = RetryErrorType.CLIENT_ERROR - if isinstance(error, HasFault) and error.fault == "server": - error_type = RetryErrorType.SERVER_ERROR - - return RetryErrorInfo(error_type=error_type) - - """); - writer.dedent(); - - writer.write( - """ - async def _execute_operation( - self, - input: Input, - plugins: list[$1T], - serialize: Callable[[Input, $5T], Awaitable[$2T]], - deserialize: Callable[[$3T, $5T], Awaitable[Output]], - config: $5T, - operation_name: str, - ${?hasEventStream} - has_input_stream: bool = False, - event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, - event_response_deserializer: DeserializeableShape | None = None, - ${/hasEventStream} - ) -> Output: - try: - return await self._handle_execution( - input, plugins, serialize, deserialize, config, operation_name, - ${?hasEventStream} - has_input_stream, event_deserializer, event_response_deserializer, - ${/hasEventStream} - ) - except Exception as e: - # Make sure every exception that we throw is an instance of $4T so - # customers can reliably catch everything we throw. - if not isinstance(e, $4T): - raise $4T(e) from e - raise e - - async def _handle_execution( - self, - input: Input, - plugins: list[$1T], - serialize: Callable[[Input, $5T], Awaitable[$2T]], - deserialize: Callable[[$3T, $5T], Awaitable[Output]], - config: $5T, - operation_name: str, - ${?hasEventStream} - has_input_stream: bool = False, - event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, - event_response_deserializer: DeserializeableShape | None = None, - ${/hasEventStream} - ) -> Output: - logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input) - context: InterceptorContext[Input, None, None, None] = InterceptorContext( - request=input, - response=None, - transport_request=None, - transport_response=None, - ) - _client_interceptors = config.interceptors - client_interceptors = cast( - list[Interceptor[Input, Output, $2T, $3T]], _client_interceptors - ) - interceptors = client_interceptors - - try: - # Step 1a: Invoke read_before_execution on client-level interceptors - for interceptor in client_interceptors: - interceptor.read_before_execution(context) - - # Step 1b: Run operation-level plugins - config = deepcopy(config) - for plugin in plugins: - plugin(config) - - _client_interceptors = config.interceptors - interceptors = cast( - list[Interceptor[Input, Output, $2T, $3T]], - _client_interceptors, - ) - - # Step 1c: Invoke the read_before_execution hooks on newly added - # interceptors. - for interceptor in interceptors: - if interceptor not in client_interceptors: - interceptor.read_before_execution(context) - - # Step 2: Invoke the modify_before_serialization hooks - for interceptor in interceptors: - context._request = interceptor.modify_before_serialization(context) - - # Step 3: Invoke the read_before_serialization hooks - for interceptor in interceptors: - interceptor.read_before_serialization(context) - - # Step 4: Serialize the request - context_with_transport_request = cast( - InterceptorContext[Input, None, $2T, None], context - ) - logger.debug("Serializing request for: %s", context_with_transport_request.request) - context_with_transport_request._transport_request = await serialize( - context_with_transport_request.request, config - ) - logger.debug("Serialization complete. Transport request: %s", context_with_transport_request._transport_request) - - # Step 5: Invoke read_after_serialization - for interceptor in interceptors: - interceptor.read_after_serialization(context_with_transport_request) - - # Step 6: Invoke modify_before_retry_loop - for interceptor in interceptors: - context_with_transport_request._transport_request = ( - interceptor.modify_before_retry_loop(context_with_transport_request) - ) - - # Step 7: Acquire the retry token. - retry_strategy = config.retry_strategy - retry_token = retry_strategy.acquire_initial_retry_token() - - while True: - # Make an attempt, creating a copy of the context so we don't pass - # around old data. - context_with_response = await self._handle_attempt( - deserialize, - interceptors, - context_with_transport_request.copy(), - config, - operation_name, - ) - - # We perform this type-ignored re-assignment because `context` needs - # to point at the latest context so it can be generically handled - # later on. This is only an issue here because we've created a copy, - # so we're no longer simply pointing at the same object in memory - # with different names and type hints. It is possible to address this - # without having to fall back to the type ignore, but it would impose - # unnecessary runtime costs. - context = context_with_response # type: ignore - - if isinstance(context_with_response.response, Exception): - # Step 7u: Reacquire retry token if the attempt failed - try: - retry_token = retry_strategy.refresh_retry_token_for_retry( - token_to_renew=retry_token, - error_info=self._classify_error( - error=context_with_response.response, - context=context_with_response, - ) - ) - except SmithyRetryException: - raise context_with_response.response - logger.debug( - "Retry needed. Attempting request #%s in %.4f seconds.", - retry_token.retry_count + 1, - retry_token.retry_delay - ) - await sleep(retry_token.retry_delay) - current_body = context_with_transport_request.transport_request.body - if (seek := getattr(current_body, "seek", None)) is not None: - await seek(0) - else: - # Step 8: Invoke record_success - retry_strategy.record_success(token=retry_token) - break - except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e - - # At this point, the context's request will have been definitively set, and - # The response will be set either with the modeled output or an exception. The - # transport_request and transport_response may be set or None. - execution_context = cast( - InterceptorContext[Input, Output, $2T | None, $3T | None], context - ) - ${^hasEventStream} - return await self._finalize_execution(interceptors, execution_context) - ${/hasEventStream} - ${?hasEventStream} - operation_output = await self._finalize_execution(interceptors, execution_context) - if has_input_stream or event_deserializer is not None: - ${6C|} - else: - return operation_output - ${/hasEventStream} - - async def _handle_attempt( - self, - deserialize: Callable[[$3T, $5T], Awaitable[Output]], - interceptors: list[Interceptor[Input, Output, $2T, $3T]], - context: InterceptorContext[Input, None, $2T, None], - config: $5T, - operation_name: str, - ) -> InterceptorContext[Input, Output, $2T, $3T | None]: - try: - # assert config.interceptors is not None - # Step 7a: Invoke read_before_attempt - for interceptor in interceptors: - interceptor.read_before_attempt(context) - - """, - pluginSymbol, - transportRequest, - transportResponse, - errorSymbol, - configSymbol, - writer.consumer(w -> context.protocolGenerator().wrapEventStream(context, w))); - - boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty(); - writer.pushState(new ResolveIdentitySection()); - if (context.applicationProtocol().isHttpProtocol() && supportsAuth) { - writer.pushState(new InitializeHttpAuthParametersSection()); - writer.write(""" - # Step 7b: Invoke service_auth_scheme_resolver.resolve_auth_scheme - auth_parameters: $1T = $1T( - operation=operation_name, - ${2C|} - ) - - """, - CodegenUtils.getHttpAuthParamsSymbol(context.settings()), - writer.consumer(this::initializeHttpAuthParameters)); - writer.popState(); - - writer.addDependency(SmithyPythonDependency.SMITHY_CORE); - writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); - writer.addImport("smithy_core.interfaces.identity", "Identity"); - writer.addImports("smithy_http.aio.interfaces.auth", Set.of("HTTPSigner", "HTTPAuthOption")); - writer.addStdlibImport("typing", "Any"); - writer.write(""" - auth_options = config.http_auth_scheme_resolver.resolve_auth_scheme( - auth_parameters=auth_parameters - ) - auth_option: HTTPAuthOption | None = None - for option in auth_options: - if option.scheme_id in config.http_auth_schemes: - auth_option = option - break - - signer: HTTPSigner[Any, Any] | None = None - identity: Identity | None = None - - if auth_option: - auth_scheme = config.http_auth_schemes[auth_option.scheme_id] - - # Step 7c: Invoke auth_scheme.identity_resolver - identity_resolver = auth_scheme.identity_resolver(config=config) - - # Step 7d: Invoke auth_scheme.signer - signer = auth_scheme.signer - - # Step 7e: Invoke identity_resolver.get_identity - identity = await identity_resolver.get_identity( - identity_properties=auth_option.identity_properties - ) - - """); - } - writer.popState(); - - writer.pushState(new ResolveEndpointSection()); - if (context.applicationProtocol().isHttpProtocol()) { - writer.addDependency(SmithyPythonDependency.SMITHY_CORE); - writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); - writer.addImport("smithy_core", "URI"); - writer.write(""" - # Step 7f: Invoke endpoint_resolver.resolve_endpoint - endpoint_resolver_parameters = $1T.build(config=config) - logger.debug("Calling endpoint resolver with parameters: %s", endpoint_resolver_parameters) - endpoint = await config.endpoint_resolver.resolve_endpoint( - endpoint_resolver_parameters - ) - logger.debug("Endpoint resolver result: %s", endpoint) - if not endpoint.uri.path: - path = "" - elif endpoint.uri.path.endswith("/"): - path = endpoint.uri.path[:-1] - else: - path = endpoint.uri.path - if context.transport_request.destination.path: - path += context.transport_request.destination.path - context._transport_request.destination = URI( - scheme=endpoint.uri.scheme, - host=context.transport_request.destination.host + endpoint.uri.host, - path=path, - port=endpoint.uri.port, - query=context.transport_request.destination.query, - ) - context._transport_request.fields.extend(endpoint.headers) - - """, - CodegenUtils.getEndpointParametersSymbol(context.settings())); - } - writer.popState(); - - writer.write(""" - # Step 7g: Invoke modify_before_signing - for interceptor in interceptors: - context._transport_request = interceptor.modify_before_signing(context) - - # Step 7h: Invoke read_before_signing - for interceptor in interceptors: - interceptor.read_before_signing(context) - - """); - - writer.pushState(new SignRequestSection()); - if (context.applicationProtocol().isHttpProtocol() && supportsAuth) { - writer.write(""" - # Step 7i: sign the request - if auth_option and signer: - logger.debug("HTTP request to sign: %s", context.transport_request) - logger.debug( - "Signer properties: %s", - auth_option.signer_properties - ) - context._transport_request = await signer.sign( - http_request=context.transport_request, - identity=identity, - signing_properties=auth_option.signer_properties, - ) - logger.debug("Signed HTTP request: %s", context._transport_request) - """); - } - writer.popState(); - - writer.write(""" - # Step 7j: Invoke read_after_signing - for interceptor in interceptors: - interceptor.read_after_signing(context) - - # Step 7k: Invoke modify_before_transmit - for interceptor in interceptors: - context._transport_request = interceptor.modify_before_transmit(context) - - # Step 7l: Invoke read_before_transmit - for interceptor in interceptors: - interceptor.read_before_transmit(context) - - """); - - writer.pushState(new SendRequestSection()); - if (context.applicationProtocol().isHttpProtocol()) { - writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); - writer.addImport("smithy_http.interfaces", "HTTPRequestConfiguration"); - writer.write(""" - # Step 7m: Invoke http_client.send - request_config = config.http_request_config or HTTPRequestConfiguration() - context_with_response = cast( - InterceptorContext[Input, None, $1T, $2T], context - ) - logger.debug("HTTP request config: %s", request_config) - logger.debug("Sending HTTP request: %s", context_with_response.transport_request) - context_with_response._transport_response = await config.http_client.send( - request=context_with_response.transport_request, - request_config=request_config, - ) - logger.debug("Received HTTP response: %s", context_with_response.transport_response) - - """, transportRequest, transportResponse); - } - writer.popState(); - - writer.write(""" - # Step 7n: Invoke read_after_transmit - for interceptor in interceptors: - interceptor.read_after_transmit(context_with_response) - - # Step 7o: Invoke modify_before_deserialization - for interceptor in interceptors: - context_with_response._transport_response = ( - interceptor.modify_before_deserialization(context_with_response) - ) - - # Step 7p: Invoke read_before_deserialization - for interceptor in interceptors: - interceptor.read_before_deserialization(context_with_response) - - # Step 7q: deserialize - context_with_output = cast( - InterceptorContext[Input, Output, $1T, $2T], - context_with_response, - ) - logger.debug("Deserializing transport response: %s", context_with_output._transport_response) - context_with_output._response = await deserialize( - context_with_output._transport_response, config - ) - logger.debug("Deserialization complete. Response: %s", context_with_output._response) - - # Step 7r: Invoke read_after_deserialization - for interceptor in interceptors: - interceptor.read_after_deserialization(context_with_output) - except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e - - # At this point, the context's request and transport_request have definitively been set, - # the response is either set or an exception, and the transport_resposne is either set or - # None. This will also be true after _finalize_attempt because there is no opportunity - # there to set the transport_response. - attempt_context = cast( - InterceptorContext[Input, Output, $1T, $2T | None], context - ) - return await self._finalize_attempt(interceptors, attempt_context) - - async def _finalize_attempt( - self, - interceptors: list[Interceptor[Input, Output, $1T, $2T]], - context: InterceptorContext[Input, Output, $1T, $2T | None], - ) -> InterceptorContext[Input, Output, $1T, $2T | None]: - # Step 7s: Invoke modify_before_attempt_completion - try: - for interceptor in interceptors: - context._response = interceptor.modify_before_attempt_completion( - context - ) - except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e - - # Step 7t: Invoke read_after_attempt - for interceptor in interceptors: - try: - interceptor.read_after_attempt(context) - except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e - - return context - - async def _finalize_execution( - self, - interceptors: list[Interceptor[Input, Output, $1T, $2T]], - context: InterceptorContext[Input, Output, $1T | None, $2T | None], - ) -> Output: - try: - # Step 9: Invoke modify_before_completion - for interceptor in interceptors: - context._response = interceptor.modify_before_completion(context) - - # Step 10: Invoke trace_probe.dispatch_events - try: - pass - except Exception as e: - # log and ignore exceptions - logger.exception("Exception occurred while dispatching trace events: %s", e) - pass - except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e - - # Step 11: Invoke read_after_execution - for interceptor in interceptors: - try: - interceptor.read_after_execution(context) - except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e - - # Step 12: Return / throw - if isinstance(context.response, Exception): - raise context.response - - # We may want to add some aspects of this context to the output types so we can - # return it to the end-users. - return context.response - """, transportRequest, transportResponse); - writer.dedent(); - writer.popState(); - } - - private boolean hasEventStream() { - var streamIndex = EventStreamIndex.of(context.model()); - var topDownIndex = TopDownIndex.of(context.model()); - for (OperationShape operation : topDownIndex.getContainedOperations(context.settings().service())) { - if (streamIndex.getInputInfo(operation).isPresent() || streamIndex.getOutputInfo(operation).isPresent()) { - return true; - } + private void writeDefaultPlugins(PythonWriter writer, Collection plugins) { + for (SymbolReference plugin : plugins) { + writer.write("$T,", plugin); } - return false; } - private void initializeHttpAuthParameters(PythonWriter writer) { - var derived = new LinkedHashSet(); - for (PythonIntegration integration : context.integrations()) { - for (RuntimeClientPlugin plugin : integration.getClientPlugins(context)) { - if (plugin.matchesService(context.model(), service) - && plugin.getAuthScheme().isPresent() - && plugin.getAuthScheme().get().getApplicationProtocol().isHttpProtocol()) { - derived.addAll(plugin.getAuthScheme().get().getAuthProperties()); - } - } - } - - for (DerivedProperty property : derived) { - var source = property.source().scopeLocation(); - if (property.initializationFunction().isPresent()) { - writer.write("$L=$T($L),", property.name(), property.initializationFunction().get(), source); - } else if (property.sourcePropertyName().isPresent()) { - writer.write("$L=$L.$L,", property.name(), source, property.sourcePropertyName().get()); - } + /** + * Generates the type-registry for the modeled errors of the service. + * TODO: Implicit errors + */ + private void writeErrorTypeRegistry(PythonWriter writer) { + List errors = service.getErrors(); + if (!errors.isEmpty()) { + writer.addImport("smithy_core.shapes", "ShapeID"); } - } - - private void writeDefaultPlugins(PythonWriter writer, Collection plugins) { - for (SymbolReference plugin : plugins) { - writer.write("$T,", plugin); + for (var error : errors) { + var errSymbol = symbolProvider.toSymbol(model.expectShape(error)); + writer.write("ShapeID($S): $T,", error, errSymbol); } } @@ -704,45 +149,38 @@ private void writeDefaultPlugins(PythonWriter writer, Collection $T:", "", - operationSymbol.getName(), + operationMethodSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol, () -> { - writeSharedOperationInit(writer, operation, input); + writeOperationDocs(writer, operation, input); if (context.protocolGenerator() == null) { writer.write("raise NotImplementedError()"); } else { - var protocolGenerator = context.protocolGenerator(); - var serSymbol = protocolGenerator.getSerializationFunction(context, operation); - var deserSymbol = protocolGenerator.getDeserializationFunction(context, operation); - writer.write(""" - return await self._execute_operation( - input=input, - plugins=operation_plugins, - serialize=$T, - deserialize=$T, - config=self._config, - operation_name=$S, - ) - """, serSymbol, deserSymbol, operation.getId().getName()); + // TODO: override config + // TODO: try/except with error + // TODO: align with implementation of call() in the base client class (when its created) + writer.write("return await call(input, $T)", operationSymbol); } }); + } - private void writeSharedOperationInit(PythonWriter writer, OperationShape operation, Shape input) { + private void writeOperationDocs(PythonWriter writer, OperationShape operation, Shape input) { writer.writeDocs(() -> { var docs = operation.getTrait(DocumentationTrait.class) .map(StringTrait::getValue) @@ -756,153 +194,7 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat $L :param input: $L - - :param plugins: A list of callables that modify the configuration dynamically. - Changes made by these plugins only apply for the duration of the operation - execution and will not affect any other operation invocations.""", docs, inputDocs); + """, docs, inputDocs); }); - - var defaultPlugins = new LinkedHashSet(); - for (PythonIntegration integration : context.integrations()) { - for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins(context)) { - if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { - runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); - } - } - } - writer.write(""" - operation_plugins: list[Plugin] = [ - $C - ] - if plugins: - operation_plugins.extend(plugins) - """, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); - - } - - private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) { - writer.pushState(); - writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM); - var operationSymbol = context.symbolProvider().toSymbol(operation); - writer.putContext("operationName", operationSymbol.getName()); - var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); - writer.putContext("plugin", pluginSymbol); - - var input = context.model().expectShape(operation.getInputShape()); - var inputSymbol = context.symbolProvider().toSymbol(input); - writer.putContext("input", inputSymbol); - - var eventStreamIndex = EventStreamIndex.of(context.model()); - var inputStreamSymbol = eventStreamIndex.getInputInfo(operation) - .map(EventStreamInfo::getEventStreamTarget) - .map(target -> context.symbolProvider().toSymbol(target)) - .orElse(null); - writer.putContext("inputStream", inputStreamSymbol); - - var output = context.model().expectShape(operation.getOutputShape()); - var outputSymbol = context.symbolProvider().toSymbol(output); - writer.putContext("output", outputSymbol); - - var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation) - .map(EventStreamInfo::getEventStreamTarget) - .map(target -> context.symbolProvider().toSymbol(target)) - .orElse(null); - writer.putContext("outputStream", outputStreamSymbol); - - writer.putContext("hasProtocol", context.protocolGenerator() != null); - if (context.protocolGenerator() != null) { - var serSymbol = context.protocolGenerator().getSerializationFunction(context, operation); - writer.putContext("serSymbol", serSymbol); - var deserSymbol = context.protocolGenerator().getDeserializationFunction(context, operation); - writer.putContext("deserSymbol", deserSymbol); - } else { - writer.putContext("serSymbol", null); - writer.putContext("deserSymbol", null); - } - - if (inputStreamSymbol != null) { - if (outputStreamSymbol != null) { - writer.addImport("smithy_event_stream.aio.interfaces", "DuplexEventStream"); - writer.write(""" - async def ${operationName:L}( - self, - input: ${input:T}, - plugins: list[${plugin:T}] | None = None - ) -> DuplexEventStream[${inputStream:T}, ${outputStream:T}, ${output:T}]: - ${C|} - ${^hasProtocol} - raise NotImplementedError() - ${/hasProtocol} - ${?hasProtocol} - return await self._execute_operation( - input=input, - plugins=operation_plugins, - serialize=${serSymbol:T}, - deserialize=${deserSymbol:T}, - config=self._config, - operation_name=${operationName:S}, - has_input_stream=True, - event_deserializer=$T().deserialize, - event_response_deserializer=${output:T}, - ) # type: ignore - ${/hasProtocol} - """, - writer.consumer(w -> writeSharedOperationInit(w, operation, input)), - outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER)); - } else { - writer.addImport("smithy_event_stream.aio.interfaces", "InputEventStream"); - writer.write(""" - async def ${operationName:L}( - self, - input: ${input:T}, - plugins: list[${plugin:T}] | None = None - ) -> InputEventStream[${inputStream:T}, ${output:T}]: - ${C|} - ${^hasProtocol} - raise NotImplementedError() - ${/hasProtocol} - ${?hasProtocol} - return await self._execute_operation( - input=input, - plugins=operation_plugins, - serialize=${serSymbol:T}, - deserialize=${deserSymbol:T}, - config=self._config, - operation_name=${operationName:S}, - has_input_stream=True, - ) # type: ignore - ${/hasProtocol} - """, writer.consumer(w -> writeSharedOperationInit(w, operation, input))); - } - } else { - writer.addImport("smithy_event_stream.aio.interfaces", "OutputEventStream"); - writer.write(""" - async def ${operationName:L}( - self, - input: ${input:T}, - plugins: list[${plugin:T}] | None = None - ) -> OutputEventStream[${outputStream:T}, ${output:T}]: - ${C|} - ${^hasProtocol} - raise NotImplementedError() - ${/hasProtocol} - ${?hasProtocol} - return await self._execute_operation( - input=input, - plugins=operation_plugins, - serialize=${serSymbol:T}, - deserialize=${deserSymbol:T}, - config=self._config, - operation_name=${operationName:S}, - event_deserializer=$T().deserialize, - event_response_deserializer=${output:T}, - ) # type: ignore - ${/hasProtocol} - """, - writer.consumer(w -> writeSharedOperationInit(w, operation, input)), - outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER)); - } - - writer.popState(); } } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java index 51f1d189d..54a2531f0 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java @@ -21,6 +21,7 @@ import software.amazon.smithy.codegen.core.directed.GenerateIntEnumDirective; import software.amazon.smithy.codegen.core.directed.GenerateListDirective; import software.amazon.smithy.codegen.core.directed.GenerateMapDirective; +import software.amazon.smithy.codegen.core.directed.GenerateOperationDirective; import software.amazon.smithy.codegen.core.directed.GenerateServiceDirective; import software.amazon.smithy.codegen.core.directed.GenerateStructureDirective; import software.amazon.smithy.codegen.core.directed.GenerateUnionDirective; @@ -35,6 +36,7 @@ import software.amazon.smithy.python.codegen.generators.IntEnumGenerator; import software.amazon.smithy.python.codegen.generators.ListGenerator; import software.amazon.smithy.python.codegen.generators.MapGenerator; +import software.amazon.smithy.python.codegen.generators.OperationGenerator; import software.amazon.smithy.python.codegen.generators.ProtocolGenerator; import software.amazon.smithy.python.codegen.generators.SchemaGenerator; import software.amazon.smithy.python.codegen.generators.ServiceErrorGenerator; @@ -108,9 +110,10 @@ public void customizeBeforeShapeGeneration(CustomizeDirective directive) { + DirectedCodegen.super.generateOperation(directive); + + directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> { + OperationGenerator generator = new OperationGenerator( + directive.context(), + writer, + directive.shape()); + generator.run(); + }); + } + @Override public void generateStructure(GenerateStructureDirective directive) { directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> { diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java index 10394bc7d..b9867eddd 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java @@ -276,8 +276,14 @@ public Symbol bigDecimalShape(BigDecimalShape shape) { public Symbol operationShape(OperationShape shape) { // Operation names are escaped like members because ultimately they're // properties on an object too. - var name = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service))); - return createGeneratedSymbolBuilder(shape, name, "client").build(); + var methodName = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service))); + var methodSymbol = createGeneratedSymbolBuilder(shape, methodName, "client", false).build(); + + // We add a symbol for the method in the client as a property, whereas the actual + // operation symbol points to the generated type for it + return createGeneratedSymbolBuilder(shape, getDefaultShapeName(shape), SHAPES_FILE) + .putProperty(SymbolProperties.OPERATION_METHOD, methodSymbol) + .build(); } @Override diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java index 46862aa72..63c26343d 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java @@ -73,5 +73,11 @@ public final class SymbolProperties { */ public static final Property DESERIALIZER = Property.named("deserializer"); + /** + * Contains a symbol pointing to an operation shape's method in the client. This is + * only used for operations. + */ + public static final Property OPERATION_METHOD = Property.named("operationMethod"); + private SymbolProperties() {} } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java new file mode 100644 index 000000000..a09517ccc --- /dev/null +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java @@ -0,0 +1,101 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.python.codegen.generators; + +import java.util.List; +import java.util.logging.Logger; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.ServiceIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.SymbolProperties; +import software.amazon.smithy.python.codegen.writer.PythonWriter; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class OperationGenerator implements Runnable { + private static final Logger LOGGER = Logger.getLogger(OperationGenerator.class.getName()); + + private final GenerationContext context; + private final PythonWriter writer; + private final OperationShape shape; + private final SymbolProvider symbolProvider; + private final Model model; + + public OperationGenerator(GenerationContext context, PythonWriter writer, OperationShape shape) { + this.context = context; + this.writer = writer; + this.shape = shape; + this.symbolProvider = context.symbolProvider(); + this.model = context.model(); + } + + @Override + public void run() { + var opSymbol = symbolProvider.toSymbol(shape); + var inSymbol = symbolProvider.toSymbol(model.expectShape(shape.getInputShape())); + var outSymbol = symbolProvider.toSymbol(model.expectShape(shape.getOutputShape())); + + writer.addStdlibImport("dataclasses", "dataclass"); + writer.addImport("smithy_core.schemas", "APIOperation"); + writer.addImport("smithy_core.type_registry", "TypeRegistry"); + + writer.write(""" + @dataclass(kw_only=True, frozen=True) + class $1L(APIOperation["$2T", "$3T"]): + input = $2T + output = $3T + schema = $4T + input_schema = $5T + output_schema = $6T + error_registry = TypeRegistry({ + $7C + }) + effective_auth_schemes = [ + $8C + ] + """, + opSymbol.getName(), + inSymbol, + outSymbol, + opSymbol.expectProperty(SymbolProperties.SCHEMA), + inSymbol.expectProperty(SymbolProperties.SCHEMA), + outSymbol.expectProperty(SymbolProperties.SCHEMA), + writer.consumer(this::writeErrorTypeRegistry), + writer.consumer(this::writeAuthSchemes) + // TODO: Docs? Maybe not necessary on the operation type itself + // TODO: Singleton? + ); + } + + private void writeErrorTypeRegistry(PythonWriter writer) { + List errors = shape.getErrors(); + if (!errors.isEmpty()) { + writer.addImport("smithy_core.shapes", "ShapeID"); + } + for (var error : errors) { + var errSymbol = symbolProvider.toSymbol(model.expectShape(error)); + writer.write("ShapeID($S): $T,", error, errSymbol); + } + } + + private void writeAuthSchemes(PythonWriter writer) { + var authSchemes = ServiceIndex.of(model) + .getEffectiveAuthSchemes(context.settings().service(), + shape.getId(), + ServiceIndex.AuthSchemeMode.NO_AUTH_AWARE); + + if (!authSchemes.isEmpty()) { + writer.addImport("smithy_core.shapes", "ShapeID"); + } + + for (var authSchemeId : authSchemes.keySet()) { + writer.write("ShapeID($S)", authSchemeId); + } + + } +} diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/PythonWriter.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/PythonWriter.java index 4e9e7a335..b4caeee9e 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/PythonWriter.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/PythonWriter.java @@ -21,10 +21,8 @@ import software.amazon.smithy.model.node.NumberNode; import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.model.node.StringNode; -import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.python.codegen.CodegenUtils; import software.amazon.smithy.python.codegen.PythonSettings; -import software.amazon.smithy.python.codegen.SymbolProperties; import software.amazon.smithy.utils.SmithyUnstableApi; import software.amazon.smithy.utils.StringUtils; @@ -305,9 +303,7 @@ public String apply(Object type, String indent) { Symbol typeSymbol = (Symbol) type; // Check if the symbol is an operation - we shouldn't add imports for operations, since // they are methods of the service object and *can't* be imported - if (!isOperationSymbol(typeSymbol)) { - addUseImports(typeSymbol); - } + addUseImports(typeSymbol); return typeSymbol.getName(); } else if (type instanceof SymbolReference) { SymbolReference typeSymbol = (SymbolReference) type; @@ -320,10 +316,6 @@ public String apply(Object type, String indent) { } } - private Boolean isOperationSymbol(Symbol typeSymbol) { - return typeSymbol.getProperty(SymbolProperties.SHAPE).map(Shape::isOperationShape).orElse(false); - } - private final class PythonNodeFormatter implements BiFunction { private final PythonWriter writer; diff --git a/examples/weather/model/weather.smithy b/examples/weather/model/weather.smithy index 42a491ccb..76080f72e 100644 --- a/examples/weather/model/weather.smithy +++ b/examples/weather/model/weather.smithy @@ -26,6 +26,21 @@ resource City { resource Forecast { identifiers: { cityId: CityId } read: GetForecast, + update: PutForecast +} + +@http(method: "PUT", uri: "/city/{cityId}/forecast", code: 200) +@idempotent +operation PutForecast { + input := for Forecast { + @required + @httpLabel + $cityId + + chanceOfRain: Float + } + + output := {} } // "pattern" is a trait. @@ -154,3 +169,5 @@ structure GetForecastInput { structure GetForecastOutput { chanceOfRain: Float } + + diff --git a/packages/smithy-core/src/smithy_core/aio/client.py b/packages/smithy-core/src/smithy_core/aio/client.py index 12d46549a..72a552229 100644 --- a/packages/smithy-core/src/smithy_core/aio/client.py +++ b/packages/smithy-core/src/smithy_core/aio/client.py @@ -468,7 +468,7 @@ async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape]( operation=call.operation, request=response_context.transport_request, response=response_context.transport_response, - error_registry="foo", + error_registry=call.operation.error_registry, context=response_context.properties, ) diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py index 7908bedbd..eb845b607 100644 --- a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py @@ -1,14 +1,14 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from collections.abc import AsyncIterable -from typing import Protocol, runtime_checkable, TYPE_CHECKING, Any, Callable +from typing import Protocol, runtime_checkable, TYPE_CHECKING, Callable from ...exceptions import UnsupportedStreamException from ...interfaces import URI, Endpoint, TypedProperties from ...interfaces import StreamingBlob as SyncStreamingBlob from .eventstream import EventPublisher, EventReceiver - +from ...type_registry import TypeRegistry if TYPE_CHECKING: from ...schemas import APIOperation @@ -129,7 +129,7 @@ async def deserialize_response[ operation: "APIOperation[OperationInput, OperationOutput]", request: I, response: O, - error_registry: Any, # TODO: add error registry + error_registry: TypeRegistry, context: TypedProperties, ) -> OperationOutput: """Deserializes the output from the tranport response or throws an exception. diff --git a/packages/smithy-core/src/smithy_core/documents.py b/packages/smithy-core/src/smithy_core/documents.py index b9af55cf5..13eecc707 100644 --- a/packages/smithy-core/src/smithy_core/documents.py +++ b/packages/smithy-core/src/smithy_core/documents.py @@ -143,6 +143,11 @@ def shape_type(self) -> ShapeType: """The Smithy data model type for the underlying contents of the document.""" return self._type + @property + def discriminator(self) -> ShapeID: + """The shape ID that corresponds to the contents of the document.""" + return self._schema.id + def is_none(self) -> bool: """Indicates whether the document contains a null value.""" return self._value is None and self._raw_value is None diff --git a/packages/smithy-core/src/smithy_core/protocols.py b/packages/smithy-core/src/smithy_core/protocols.py new file mode 100644 index 000000000..b1551ad49 --- /dev/null +++ b/packages/smithy-core/src/smithy_core/protocols.py @@ -0,0 +1,135 @@ +import os +from inspect import iscoroutinefunction +from io import BytesIO + +from smithy_core.aio.interfaces import ClientProtocol +from smithy_core.codecs import Codec +from smithy_core.deserializers import DeserializeableShape +from smithy_core.interfaces import Endpoint, TypedProperties, URI +from smithy_core.schemas import APIOperation +from smithy_core.serializers import SerializeableShape +from smithy_core.shapes import ShapeID +from smithy_core.traits import HTTPTrait, EndpointTrait, RestJson1Trait +from smithy_core.type_registry import TypeRegistry +from smithy_http.aio.interfaces import HTTPRequest, HTTPResponse +from smithy_http.deserializers import HTTPResponseDeserializer +from smithy_http.serializers import HTTPRequestSerializer +from smithy_json import JSONCodec + + +class HttpClientProtocol(ClientProtocol[HTTPRequest, HTTPResponse]): + def set_service_endpoint( + self, + *, + request: HTTPRequest, + endpoint: Endpoint, + ) -> HTTPRequest: + """Update the endpoint of a transport request. + + :param request: The request whose endpoint should be updated. + :param endpoint: The endpoint to set on the request. + """ + uri = endpoint.uri + uri_builder = request.destination + + if uri.scheme: + uri_builder.scheme = uri.scheme + if uri.host: + uri_builder.host = uri.host + if uri.port and uri.port > -1: + uri_builder.port = uri.port + if uri.path: + # TODO: verify, uri helper? + uri_builder.path = os.path.join(uri.path, uri_builder.path or "") + # TODO: merge headers from the endpoint properties bag + return request + + +class HttpBindingClientProtocol(HttpClientProtocol): + @property + def codec(self) -> Codec: + """The codec used for the serde of input and output shapes.""" + ... + + @property + def content_type(self) -> str: + """The media type of the http payload.""" + ... + + def serialize_request[ + OperationInput: "SerializeableShape", + OperationOutput: "DeserializeableShape", + ]( + self, + *, + operation: APIOperation[OperationInput, OperationOutput], + input: OperationInput, + endpoint: URI, + context: TypedProperties, + ) -> HTTPRequest: + # TODO: request binding cache like done in SJ + serializer = HTTPRequestSerializer( + payload_codec=self.codec, + http_trait=operation.schema.expect_trait(HTTPTrait), # TODO + endpoint_trait=operation.schema.get_trait(EndpointTrait), + ) + + input.serialize(serializer=serializer) + request = serializer.result + + if request is None: + raise ValueError("Request is None") # TODO + + request.fields["content-type"].add(self.content_type) + return request + + async def deserialize_response[ + OperationInput: "SerializeableShape", + OperationOutput: "DeserializeableShape", + ]( + self, + *, + operation: APIOperation[OperationInput, OperationOutput], + request: HTTPRequest, + response: HTTPResponse, + error_registry: TypeRegistry, + context: TypedProperties, + ) -> OperationOutput: + if not (200 <= response.status <= 299): # TODO: extract to utility + # TODO: implement error serde from type registry + raise NotImplementedError + + body = response.body + # TODO: extract to utility, seems common + if (read := getattr(body, "read", None)) is not None and iscoroutinefunction( + read + ): + body = BytesIO(await read()) + + # TODO: response binding cache like done in SJ + deserializer = HTTPResponseDeserializer( + payload_codec=self.codec, + http_trait=operation.schema.expect_trait(HTTPTrait), + response=response, + body=body, # type: ignore + ) + + return operation.output.deserialize(deserializer) + + +class RestJsonClientProtocol(HttpBindingClientProtocol): + _id: ShapeID = RestJson1Trait.id + _codec: JSONCodec = JSONCodec() + _contentType: str = "application/json" + + @property + def id(self) -> ShapeID: + return self._id + + @property + def codec(self) -> Codec: + return self._codec + + @property + def content_type(self) -> str: + return self._contentType diff --git a/packages/smithy-core/src/smithy_core/schemas.py b/packages/smithy-core/src/smithy_core/schemas.py index 72f8f66f8..6fbbe4ed8 100644 --- a/packages/smithy-core/src/smithy_core/schemas.py +++ b/packages/smithy-core/src/smithy_core/schemas.py @@ -7,7 +7,7 @@ from .exceptions import ExpectationNotMetException, SmithyException from .shapes import ShapeID, ShapeType from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait - +from .type_registry import TypeRegistry if TYPE_CHECKING: from .serializers import SerializeableShape @@ -289,8 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]: output_schema: Schema """The schema of the operation's output shape.""" - # TODO: Add a type registry for errors - error_registry: Any + error_registry: TypeRegistry """A TypeRegistry used to create errors.""" effective_auth_schemes: Sequence[ShapeID] diff --git a/packages/smithy-core/src/smithy_core/traits.py b/packages/smithy-core/src/smithy_core/traits.py index 5fbbe3765..b30b5fa7a 100644 --- a/packages/smithy-core/src/smithy_core/traits.py +++ b/packages/smithy-core/src/smithy_core/traits.py @@ -313,3 +313,9 @@ def host_prefix(self) -> str: class HostLabelTrait(Trait, id=ShapeID("smithy.api#hostLabel")): def __post_init__(self): assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class RestJson1Trait(Trait, id=ShapeID("aws.protocols#restJson1")): + def __post_init__(self): + assert self.document_value is None diff --git a/packages/smithy-core/src/smithy_core/type_registry.py b/packages/smithy-core/src/smithy_core/type_registry.py new file mode 100644 index 000000000..36d89d96d --- /dev/null +++ b/packages/smithy-core/src/smithy_core/type_registry.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from smithy_core.deserializers import ( + DeserializeableShape, +) # TODO: fix typo in deserializable +from smithy_core.documents import Document +from smithy_core.shapes import ShapeID + + +# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers. +# TODO: protocol? Also, move into documents.py? +class TypeRegistry: + def __init__( + self, + types: dict[ShapeID, type[DeserializeableShape]], + sub_registry: "TypeRegistry | None" = None, + ): + self._types = types + self._sub_registry = sub_registry + + def get(self, shape: ShapeID) -> type[DeserializeableShape]: + if shape in self._types: + return self._types[shape] + if self._sub_registry is not None: + return self._sub_registry.get(shape) + raise KeyError(f"Unknown shape: {shape}") + + def deserialize(self, document: Document) -> DeserializeableShape: + return document.as_shape(self.get(document.discriminator)) diff --git a/packages/smithy-core/tests/unit/test_type_registry.py b/packages/smithy-core/tests/unit/test_type_registry.py new file mode 100644 index 000000000..db17e0d94 --- /dev/null +++ b/packages/smithy-core/tests/unit/test_type_registry.py @@ -0,0 +1,48 @@ +from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer +from smithy_core.documents import Document +from smithy_core.schemas import Schema +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.type_registry import TypeRegistry +import pytest + + +class TestTypeRegistry: + def test_get(self): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + result = registry.get(ShapeID("com.example#Test")) + + assert result == TestShape + + def test_get_sub_registry(self): + sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + registry = TypeRegistry({}, sub_registry) + + result = registry.get(ShapeID("com.example#Test")) + + assert result == TestShape + + def test_get_no_match(self): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"): + registry.get(ShapeID("com.example#Test2")) + + def test_deserialize(self): + shape_id = ShapeID("com.example#Test") + registry = TypeRegistry({shape_id: TestShape}) + + result = registry.deserialize(Document("abc123", schema=TestShape.schema)) + + assert isinstance(result, TestShape) and result.value == "abc123" + + +class TestShape(DeserializeableShape): + schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING) + + def __init__(self, value: str): + self.value = value + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape": + return TestShape(deserializer.read_string(schema=TestShape.schema))