Skip to content

Commit

Permalink
Allow to use send as handler name, by adding a prefix in front of it …
Browse files Browse the repository at this point in the history
…when generating the client.
  • Loading branch information
slinkydeveloper committed Apr 11, 2024
1 parent 37c83e3 commit 017538c
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,24 @@
import java.io.Writer;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class HandlebarsTemplateEngine {

private final String baseTemplateName;
private final Map<ComponentType, Template> templates;
private final Set<String> handlerNamesToPrefix;

public HandlebarsTemplateEngine(
String baseTemplateName,
TemplateLoader templateLoader,
Map<ComponentType, String> templates) {
String baseTemplateName,
TemplateLoader templateLoader,
Map<ComponentType, String> templates,
Set<String> handlerNamesToPrefix) {
this.baseTemplateName = baseTemplateName;
this.handlerNamesToPrefix = handlerNamesToPrefix;

Handlebars handlebars = new Handlebars(templateLoader);
Handlebars handlebars = new Handlebars(templateLoader);
handlebars.registerHelpers(StringHelpers.class);

this.templates =
Expand Down Expand Up @@ -65,7 +69,7 @@ public void generate(ThrowingFunction<String, Writer> createFile, Component comp
this.templates
.get(component.getComponentType())
.apply(
Context.newBuilder(new ComponentTemplateModel(component, this.baseTemplateName))
Context.newBuilder(new ComponentTemplateModel(component, this.baseTemplateName, this.handlerNamesToPrefix))
.resolver(FieldValueResolver.INSTANCE)
.build(),
out);
Expand All @@ -86,7 +90,7 @@ static class ComponentTemplateModel {
public final boolean isService;
public final List<HandlerTemplateModel> handlers;

private ComponentTemplateModel(Component inner, String baseTemplateName) {
private ComponentTemplateModel(Component inner, String baseTemplateName, Set<String> handlerNamesToPrefix) {
this.originalClassPkg = inner.getTargetPkg().toString();
this.originalClassFqcn = inner.getTargetFqcn().toString();
this.generatedClassSimpleNamePrefix = inner.getSimpleComponentName();
Expand All @@ -99,12 +103,13 @@ private ComponentTemplateModel(Component inner, String baseTemplateName) {
this.isService = inner.getComponentType() == ComponentType.SERVICE;

this.handlers =
inner.getMethods().stream().map(HandlerTemplateModel::new).collect(Collectors.toList());
inner.getMethods().stream().map(h -> new HandlerTemplateModel(h, handlerNamesToPrefix)).collect(Collectors.toList());
}
}

static class HandlerTemplateModel {
public final String name;
public final String methodName;
public final String handlerType;
public final boolean isWorkflow;
public final boolean isShared;
Expand All @@ -123,8 +128,9 @@ static class HandlerTemplateModel {
public final String boxedOutputFqcn;
public final String outputSerdeFieldName;

private HandlerTemplateModel(Handler inner) {
private HandlerTemplateModel(Handler inner, Set<String> handlerNamesToPrefix) {
this.name = inner.getName().toString();
this.methodName = (handlerNamesToPrefix.contains(this.name) ? "_" : "") + this.name;
this.handlerType = inner.getHandlerType().toString();
this.isWorkflow = inner.getHandlerType() == HandlerType.WORKFLOW;
this.isShared = inner.getHandlerType() == HandlerType.SHARED;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public class ComponentProcessor extends AbstractProcessor {
private HandlebarsTemplateEngine bindableComponentCodegen;
private HandlebarsTemplateEngine clientCodegen;

private final static Set<String> RESERVED_METHOD_NAMES = Set.of("send");

@Override
public synchronized void init(ProcessingEnvironment processingEnv) {
super.init(processingEnv);
Expand All @@ -55,7 +57,8 @@ public synchronized void init(ProcessingEnvironment processingEnv) {
ComponentType.SERVICE,
"templates/BindableComponentFactory.hbs",
ComponentType.VIRTUAL_OBJECT,
"templates/BindableComponentFactory.hbs"));
"templates/BindableComponentFactory.hbs"),
RESERVED_METHOD_NAMES);
this.bindableComponentCodegen =
new HandlebarsTemplateEngine(
"BindableComponent",
Expand All @@ -66,7 +69,7 @@ public synchronized void init(ProcessingEnvironment processingEnv) {
ComponentType.SERVICE,
"templates/BindableComponent.hbs",
ComponentType.VIRTUAL_OBJECT,
"templates/BindableComponent.hbs"));
"templates/BindableComponent.hbs"), RESERVED_METHOD_NAMES);
this.clientCodegen =
new HandlebarsTemplateEngine(
"Client",
Expand All @@ -77,7 +80,7 @@ public synchronized void init(ProcessingEnvironment processingEnv) {
ComponentType.SERVICE,
"templates/Client.hbs",
ComponentType.VIRTUAL_OBJECT,
"templates/Client.hbs"));
"templates/Client.hbs"), RESERVED_METHOD_NAMES);
}

@Override
Expand Down
28 changes: 14 additions & 14 deletions sdk-api-gen/src/main/resources/templates/Client.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class {{generatedClassSimpleName}} {
}

{{#handlers}}
public Awaitable<{{{boxedOutputFqcn}}}> {{name}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
public Awaitable<{{{boxedOutputFqcn}}}> {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
return this.ctx.call(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand All @@ -65,7 +65,7 @@ public class {{generatedClassSimpleName}} {
}

{{#handlers}}
public void {{name}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
public void {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
ContextClient.this.ctx.send(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, ContextClient.this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand All @@ -86,13 +86,13 @@ public class {{generatedClassSimpleName}} {
}

{{#handlers}}
public {{#if outputEmpty}}void{{else}}{{{outputFqcn}}}{{/if}} {{name}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
{{^outputEmpty}}return {{/outputEmpty}}this.{{name}}(
public {{#if outputEmpty}}void{{else}}{{{outputFqcn}}}{{/if}} {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
{{^outputEmpty}}return {{/outputEmpty}}this.{{methodName}}(
{{^inputEmpty}}req, {{/inputEmpty}}
dev.restate.sdk.client.RequestOptions.DEFAULT);
}

public {{#if outputEmpty}}void{{else}}{{{outputFqcn}}}{{/if}} {{name}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
public {{#if outputEmpty}}void{{else}}{{{outputFqcn}}}{{/if}} {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
{{^outputEmpty}}return {{/outputEmpty}}this.ingressClient.call(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand All @@ -101,13 +101,13 @@ public class {{generatedClassSimpleName}} {
requestOptions);
}

public {{#if outputEmpty}}java.util.concurrent.CompletableFuture<Void>{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
return this.{{name}}Async(
public {{#if outputEmpty}}java.util.concurrent.CompletableFuture<Void>{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
return this.{{methodName}}Async(
{{^inputEmpty}}req, {{/inputEmpty}}
dev.restate.sdk.client.RequestOptions.DEFAULT);
}

public {{#if outputEmpty}}java.util.concurrent.CompletableFuture<Void>{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
public {{#if outputEmpty}}java.util.concurrent.CompletableFuture<Void>{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
return this.ingressClient.callAsync(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand All @@ -133,13 +133,13 @@ public class {{generatedClassSimpleName}} {
}

{{#handlers}}
public String {{name}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
return this.{{name}}(
public String {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
return this.{{methodName}}(
{{^inputEmpty}}req, {{/inputEmpty}}
dev.restate.sdk.client.RequestOptions.DEFAULT);
}

public String {{name}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
public String {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
return IngressClient.this.ingressClient.send(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, IngressClient.this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand All @@ -148,13 +148,13 @@ public class {{generatedClassSimpleName}} {
requestOptions);
}

public java.util.concurrent.CompletableFuture<String> {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
return this.{{name}}Async(
public java.util.concurrent.CompletableFuture<String> {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) {
return this.{{methodName}}Async(
{{^inputEmpty}}req, {{/inputEmpty}}
dev.restate.sdk.client.RequestOptions.DEFAULT);
}

public java.util.concurrent.CompletableFuture<String> {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
public java.util.concurrent.CompletableFuture<String> {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) {
return IngressClient.this.ingressClient.sendAsync(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, IngressClient.this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand Down
9 changes: 9 additions & 0 deletions sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ public void primitiveInput(Context context, int input) {
}
}

@VirtualObject
static class CornerCases {
@Exclusive
public String send(ObjectContext context, String request) {
// Just needs to compile
return CodegenTestCornerCasesClient.fromContext(context, request)._send("my_send").await();
}
}

@Override
public Stream<TestDefinitions.TestDefinition> definitions() {
return Stream.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,31 @@ import java.nio.charset.Charset
class ComponentProcessor(private val logger: KSPLogger, private val codeGenerator: CodeGenerator) :
SymbolProcessor {

companion object {
private val RESERVED_METHOD_NAMES: Set<String> = setOf("send")
}

private val bindableComponentFactoryCodegen: HandlebarsTemplateEngine =
HandlebarsTemplateEngine(
"BindableComponentFactory",
ClassPathTemplateLoader(),
mapOf(
ComponentType.SERVICE to "templates/BindableComponentFactory",
ComponentType.VIRTUAL_OBJECT to "templates/BindableComponentFactory"))
ComponentType.VIRTUAL_OBJECT to "templates/BindableComponentFactory"), RESERVED_METHOD_NAMES)
private val bindableComponentCodegen: HandlebarsTemplateEngine =
HandlebarsTemplateEngine(
"BindableComponent",
ClassPathTemplateLoader(),
mapOf(
ComponentType.SERVICE to "templates/BindableComponent",
ComponentType.VIRTUAL_OBJECT to "templates/BindableComponent"))
ComponentType.VIRTUAL_OBJECT to "templates/BindableComponent"), RESERVED_METHOD_NAMES)
private val clientCodegen: HandlebarsTemplateEngine =
HandlebarsTemplateEngine(
"Client",
ClassPathTemplateLoader(),
mapOf(
ComponentType.SERVICE to "templates/Client",
ComponentType.VIRTUAL_OBJECT to "templates/Client"))
ComponentType.VIRTUAL_OBJECT to "templates/Client"), RESERVED_METHOD_NAMES)

override fun process(resolver: Resolver): List<KSAnnotated> {
val converter = KElementConverter(logger, resolver.builtIns)
Expand Down
8 changes: 4 additions & 4 deletions sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ object {{generatedClassSimpleName}} {

class ContextClient(private val ctx: Context{{#isObject}}, private val key: String{{/isObject}}){
{{#handlers}}
suspend fun {{name}}({{^inputEmpty}}req: {{{inputFqcn}}}{{/inputEmpty}}): Awaitable<{{{boxedOutputFqcn}}}> {
suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}{{/inputEmpty}}): Awaitable<{{{boxedOutputFqcn}}}> {
return this.ctx.callAsync(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand All @@ -46,7 +46,7 @@ object {{generatedClassSimpleName}} {

inner class Send(private val delay: Duration) {
{{#handlers}}
suspend fun {{name}}({{^inputEmpty}}req: {{{inputFqcn}}}{{/inputEmpty}}) {
suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}{{/inputEmpty}}) {
this@ContextClient.ctx.send(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this@ContextClient.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand All @@ -59,7 +59,7 @@ object {{generatedClassSimpleName}} {
class IngressClient(private val ingressClient: dev.restate.sdk.client.IngressClient{{#isObject}}, private val key: String{{/isObject}}) {

{{#handlers}}
suspend fun {{name}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): {{{boxedOutputFqcn}}} {
suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): {{{boxedOutputFqcn}}} {
return this.ingressClient.callSuspend(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand All @@ -74,7 +74,7 @@ object {{generatedClassSimpleName}} {

inner class Send(private val delay: Duration) {
{{#handlers}}
suspend fun {{name}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): String {
suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): String {
return this@IngressClient.ingressClient.sendSuspend(
{{#if isObject}}Target.virtualObject(COMPONENT_NAME, this@IngressClient.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}},
{{inputSerdeFieldName}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ class CodegenTest : TestDefinitions.TestSuite {
}
}

@VirtualObject
class CornerCases {
@Exclusive
suspend fun send(context: ObjectContext, request: String): String {
// Just needs to compile
return CodegenTestCornerCasesClient.fromContext(context, request)._send("my_send").await()
}
}

override fun definitions(): Stream<TestDefinition> {
return Stream.of(
testInvocation({ ServiceGreeter() }, "greet")
Expand Down

0 comments on commit 017538c

Please sign in to comment.