Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.traits.DocumentationTrait;
import software.amazon.smithy.model.traits.StringTrait;
import software.amazon.smithy.python.codegen.integration.PythonIntegration;
Expand Down Expand Up @@ -123,6 +124,16 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None):
}

private void generateOperationExecutor(PythonWriter writer) {
writer.pushState();

var hasStreaming = hasEventStream();
writer.putContext("hasEventStream", hasStreaming);
if (hasStreaming) {
writer.addImports("smithy_core.deserializers", Set.of(
"ShapeDeserializer", "DeserializeableShape"));
writer.addStdlibImport("typing", "Any");
}

var transportRequest = context.applicationProtocol().requestType();
var transportResponse = context.applicationProtocol().responseType();
var errorSymbol = CodegenUtils.getServiceError(context.settings());
Expand Down Expand Up @@ -191,10 +202,18 @@ async def _execute_operation(
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
operation_name: str,
${?hasEventStream}
has_input_stream: bool = False,
event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
event_response_deserializer: DeserializeableShape | None = None,
${/hasEventStream}
) -> Output:
try:
return await self._handle_execution(
input, plugins, serialize, deserialize, config, operation_name
input, plugins, serialize, deserialize, config, operation_name,
${?hasEventStream}
has_input_stream, event_deserializer, event_response_deserializer,
${/hasEventStream}
)
except Exception as e:
# Make sure every exception that we throw is an instance of $4T so
Expand All @@ -211,6 +230,11 @@ async def _handle_execution(
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
operation_name: str,
${?hasEventStream}
has_input_stream: bool = False,
event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
event_response_deserializer: DeserializeableShape | None = None,
${/hasEventStream}
) -> Output:
logger.debug(f"Making request for operation {operation_name} with parameters: {input}")
context: InterceptorContext[Input, None, None, None] = InterceptorContext(
Expand Down Expand Up @@ -326,7 +350,16 @@ await sleep(retry_token.retry_delay)
execution_context = cast(
InterceptorContext[Input, Output, $2T | None, $3T | None], context
)
${^hasEventStream}
return await self._finalize_execution(interceptors, execution_context)
${/hasEventStream}
${?hasEventStream}
operation_output = await self._finalize_execution(interceptors, execution_context)
if has_input_stream or event_deserializer is not None:
${6C|}
else:
return operation_output
${/hasEventStream}

async def _handle_attempt(
self,
Expand All @@ -342,7 +375,8 @@ async def _handle_attempt(
for interceptor in interceptors:
interceptor.read_before_attempt(context)

""", pluginSymbol, transportRequest, transportResponse, errorSymbol, configSymbol);
""", pluginSymbol, transportRequest, transportResponse, errorSymbol, configSymbol,
writer.consumer(w -> context.protocolGenerator().wrapEventStream(context, w)));

boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty();
writer.pushState(new ResolveIdentitySection());
Expand Down Expand Up @@ -604,6 +638,18 @@ async def _finalize_execution(
return context.response
""", transportRequest, transportResponse);
writer.dedent();
writer.popState();
}

private boolean hasEventStream() {
var streamIndex = EventStreamIndex.of(context.model());
var topDownIndex = TopDownIndex.of(context.model());
for (OperationShape operation : topDownIndex.getContainedOperations(context.settings().service())) {
if (streamIndex.getInputInfo(operation).isPresent() || streamIndex.getOutputInfo(operation).isPresent()) {
return true;
}
}
return false;
}

private void initializeHttpAuthParameters(PythonWriter writer) {
Expand Down Expand Up @@ -649,40 +695,7 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {

writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:", "",
operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol, () -> {
writer.writeDocs(() -> {
var docs = operation.getTrait(DocumentationTrait.class)
.map(StringTrait::getValue)
.orElse(String.format("Invokes the %s operation.", operation.getId().getName()));

var inputDocs = input.getTrait(DocumentationTrait.class)
.map(StringTrait::getValue)
.orElse("The operation's input.");

writer.write("""
$L

:param input: $L

:param plugins: A list of callables that modify the configuration dynamically.
Changes made by these plugins only apply for the duration of the operation
execution and will not affect any other operation invocations.""", docs, inputDocs);
});

var defaultPlugins = new LinkedHashSet<SymbolReference>();
for (PythonIntegration integration : context.integrations()) {
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) {
if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) {
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
}
}
}
writer.write("""
operation_plugins: list[Plugin] = [
$C
]
if plugins:
operation_plugins.extend(plugins)
""", writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));
writeSharedOperationInit(writer, operation, input);

if (context.protocolGenerator() == null) {
writer.write("raise NotImplementedError()");
Expand All @@ -704,16 +717,55 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {
});
}

private void writeSharedOperationInit(PythonWriter writer, OperationShape operation, Shape input) {
writer.writeDocs(() -> {
var docs = operation.getTrait(DocumentationTrait.class)
.map(StringTrait::getValue)
.orElse(String.format("Invokes the %s operation.", operation.getId().getName()));

var inputDocs = input.getTrait(DocumentationTrait.class)
.map(StringTrait::getValue)
.orElse("The operation's input.");

writer.write("""
$L

:param input: $L

:param plugins: A list of callables that modify the configuration dynamically.
Changes made by these plugins only apply for the duration of the operation
execution and will not affect any other operation invocations.""", docs, inputDocs);
});

var defaultPlugins = new LinkedHashSet<SymbolReference>();
for (PythonIntegration integration : context.integrations()) {
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) {
if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) {
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
}
}
}
writer.write("""
operation_plugins: list[Plugin] = [
$C
]
if plugins:
operation_plugins.extend(plugins)
""", writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));

}

private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) {
writer.pushState();
writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM);
writer.addImports("smithy_event_stream.aio.interfaces", Set.of(
"EventStream", "InputEventStream", "OutputEventStream"));
var operationSymbol = context.symbolProvider().toSymbol(operation);
writer.putContext("operationName", operationSymbol.getName());
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
writer.putContext("plugin", pluginSymbol);

var input = context.model().expectShape(operation.getInputShape());
var inputSymbol = context.symbolProvider().toSymbol(input);
writer.putContext("input", inputSymbol);

var eventStreamIndex = EventStreamIndex.of(context.model());
var inputStreamSymbol = eventStreamIndex.getInputInfo(operation)
Expand All @@ -724,22 +776,107 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op

var output = context.model().expectShape(operation.getOutputShape());
var outputSymbol = context.symbolProvider().toSymbol(output);
writer.putContext("output", outputSymbol);

var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation)
.map(EventStreamInfo::getEventStreamTarget)
.map(target -> context.symbolProvider().toSymbol(target))
.orElse(null);
writer.putContext("outputStream", outputStreamSymbol);

writer.write("""
async def $L(self, input: $T, plugins: list[$T] | None = None) -> EventStream[
${?inputStream}InputEventStream[${inputStream:T}]${/inputStream}\
${^inputStream}None${/inputStream},
${?outputStream}OutputEventStream[${outputStream:T}]${/outputStream}\
${^outputStream}None${/outputStream},
$T
]:
raise NotImplementedError()
""", operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol);
writer.putContext("hasProtocol", context.protocolGenerator() != null);
if (context.protocolGenerator() != null) {
var serSymbol = context.protocolGenerator().getSerializationFunction(context, operation);
writer.putContext("serSymbol", serSymbol);
var deserSymbol = context.protocolGenerator().getDeserializationFunction(context, operation);
writer.putContext("deserSymbol", deserSymbol);
} else {
writer.putContext("serSymbol", null);
writer.putContext("deserSymbol", null);
}

if (inputStreamSymbol != null) {
if (outputStreamSymbol != null) {
writer.addImport("smithy_event_stream.aio.interfaces", "DuplexEventStream");
writer.write("""
async def ${operationName:L}(
self,
input: ${input:T},
plugins: list[${plugin:T}] | None = None
) -> DuplexEventStream[${inputStream:T}, ${outputStream:T}, ${output:T}]:
${C|}
${^hasProtocol}
raise NotImplementedError()
${/hasProtocol}
${?hasProtocol}
return await self._execute_operation(
input=input,
plugins=operation_plugins,
serialize=${serSymbol:T},
deserialize=${deserSymbol:T},
config=self._config,
operation_name=${operationName:S},
has_input_stream=True,
event_deserializer=$T().deserialize,
event_response_deserializer=${output:T},
) # type: ignore
${/hasProtocol}
""",
writer.consumer(w -> writeSharedOperationInit(w, operation, input)),
outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER));
} else {
writer.addImport("smithy_event_stream.aio.interfaces", "InputEventStream");
writer.write("""
async def ${operationName:L}(
self,
input: ${input:T},
plugins: list[${plugin:T}] | None = None
) -> InputEventStream[${inputStream:T}, ${output:T}]:
${C|}
${^hasProtocol}
raise NotImplementedError()
${/hasProtocol}
${?hasProtocol}
return await self._execute_operation(
input=input,
plugins=operation_plugins,
serialize=${serSymbol:T},
deserialize=${deserSymbol:T},
config=self._config,
operation_name=${operationName:S},
has_input_stream=True,
) # type: ignore
${/hasProtocol}
""", writer.consumer(w -> writeSharedOperationInit(w, operation, input)));
}
} else {
writer.addImport("smithy_event_stream.aio.interfaces", "OutputEventStream");
writer.write("""
async def ${operationName:L}(
self,
input: ${input:T},
plugins: list[${plugin:T}] | None = None
) -> OutputEventStream[${outputStream:T}, ${output:T}]:
${C|}
${^hasProtocol}
raise NotImplementedError()
${/hasProtocol}
${?hasProtocol}
return await self._execute_operation(
input=input,
plugins=operation_plugins,
serialize=${serSymbol:T},
deserialize=${deserSymbol:T},
config=self._config,
operation_name=${operationName:S},
event_deserializer=$T().deserialize,
event_response_deserializer=${output:T},
) # type: ignore
${/hasProtocol}
""",
writer.consumer(w -> writeSharedOperationInit(w, operation, input)),
outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER));
}

writer.popState();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,26 @@ public final class SmithyPythonDependency {
false
);

/**
* Core interfaces for event streams.
*/
public static final PythonDependency SMITHY_EVENT_STREAM = new PythonDependency(
"smithy_event_stream",
"==0.0.1",
Type.DEPENDENCY,
false
);

/**
* EventStream implementations for application/vnd.amazon.eventstream.
*/
public static final PythonDependency AWS_EVENT_STREAM = new PythonDependency(
"aws_event_stream",
"==0.0.1",
Type.DEPENDENCY,
false
);

/**
* testing framework used in generated functional tests.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import software.amazon.smithy.model.shapes.ToShapeId;
import software.amazon.smithy.python.codegen.ApplicationProtocol;
import software.amazon.smithy.python.codegen.GenerationContext;
import software.amazon.smithy.python.codegen.PythonWriter;
import software.amazon.smithy.utils.CaseUtils;
import software.amazon.smithy.utils.SmithyUnstableApi;

Expand Down Expand Up @@ -167,4 +168,25 @@ default void generateSharedDeserializerComponents(GenerationContext context) {
*/
default void generateProtocolTests(GenerationContext context) {
}

/**
* Generates the code to wrap an operation output into an event stream.
*
* <p>Important context variables are:
* <ul>
* <li>execution_context - Has the context, including the transport input and output.</li>
* <li>operation_output - The deserialized operation output.</li>
* <li>has_input_stream - Whether or not there is an input stream.</li>
* <li>event_deserializer - The deserialize method for output events, or None for no output stream.</li>
* <li>event_response_deserializer - A DeserializeableShape representing the operation's output shape,
* or None for no output stream. This is used when the operation sends the initial response over the
* event stream.
* </li>
* </ul>
*
* @param context Generation context.
* @param writer The writer to write to.
*/
default void wrapEventStream(GenerationContext context, PythonWriter writer) {
}
}
Loading
Loading