Skip to content

Commit

Permalink
WebSockets Next: produce ExecutionModelAnnotationsAllowedBuildItem
Browse files Browse the repository at this point in the history
- so that callback methods can be annotated with Blocking, NonBlocking
and RunOnVirtualThread
  • Loading branch information
mkouba committed Apr 22, 2024
1 parent ace673e commit 8df1abe
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkus.websockets.next.deployment;

import java.util.List;

import org.jboss.jandex.DotName;

import io.quarkus.websockets.next.OnBinaryMessage;
Expand All @@ -12,6 +14,7 @@
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.common.annotation.NonBlocking;
import io.smallrye.common.annotation.RunOnVirtualThread;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
Expand All @@ -33,6 +36,7 @@ final class WebSocketDotNames {
static final DotName MULTI = DotName.createSimple(Multi.class);
static final DotName RUN_ON_VIRTUAL_THREAD = DotName.createSimple(RunOnVirtualThread.class);
static final DotName BLOCKING = DotName.createSimple(Blocking.class);
static final DotName NON_BLOCKING = DotName.createSimple(NonBlocking.class);
static final DotName STRING = DotName.createSimple(String.class);
static final DotName BUFFER = DotName.createSimple(Buffer.class);
static final DotName JSON_OBJECT = DotName.createSimple(JsonObject.class);
Expand All @@ -41,4 +45,7 @@ final class WebSocketDotNames {
static final DotName PATH_PARAM = DotName.createSimple(PathParam.class);
static final DotName HANDSHAKE_REQUEST = DotName.createSimple(WebSocketConnection.HandshakeRequest.class);
static final DotName THROWABLE = DotName.createSimple(Throwable.class);

static final List<DotName> CALLBACK_ANNOTATIONS = List.of(ON_OPEN, ON_CLOSE, ON_BINARY_MESSAGE, ON_TEXT_MESSAGE,
ON_PONG_MESSAGE, ON_ERROR);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -36,6 +37,7 @@
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.Annotations;
import io.quarkus.arc.processor.BeanInfo;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.DotNames;
Expand All @@ -47,6 +49,7 @@
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.execannotations.ExecutionModelAnnotationsAllowedBuildItem;
import io.quarkus.gizmo.BytecodeCreator;
import io.quarkus.gizmo.CatchBlockCreator;
import io.quarkus.gizmo.ClassCreator;
Expand Down Expand Up @@ -117,6 +120,18 @@ void unremovableBeans(BuildProducer<UnremovableBeanBuildItem> unremovableBeans)
unremovableBeans.produce(UnremovableBeanBuildItem.beanTypes(TextMessageCodec.class));
}

@BuildStep
ExecutionModelAnnotationsAllowedBuildItem executionModelAnnotations(
TransformedAnnotationsBuildItem transformedAnnotations) {
return new ExecutionModelAnnotationsAllowedBuildItem(new Predicate<MethodInfo>() {
@Override
public boolean test(MethodInfo method) {
return Annotations.containsAny(transformedAnnotations.getAnnotations(method),
WebSocketDotNames.CALLBACK_ANNOTATIONS);
}
});
}

@BuildStep
public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex,
BeanDiscoveryFinishedBuildItem beanDiscoveryFinished,
Expand Down Expand Up @@ -1006,7 +1021,8 @@ private List<Callback> findErrorHandlers(IndexView index, ClassInfo beanClass, C
List<Callback> errorHandlers = new ArrayList<>();
for (AnnotationInstance annotation : annotations) {
MethodInfo method = annotation.target().asMethod();
Callback callback = new Callback(annotation, method, executionModel(method), callbackArguments,
Callback callback = new Callback(annotation, method, executionModel(method, transformedAnnotations),
callbackArguments,
transformedAnnotations, endpointPath, index);
long errorArguments = callback.arguments.stream().filter(ca -> ca instanceof ErrorCallbackArgument).count();
if (errorArguments != 1) {
Expand Down Expand Up @@ -1052,7 +1068,8 @@ private Callback findCallback(IndexView index, ClassInfo beanClass, DotName anno
} else if (annotations.size() == 1) {
AnnotationInstance annotation = annotations.get(0);
MethodInfo method = annotation.target().asMethod();
Callback callback = new Callback(annotation, method, executionModel(method), callbackArguments,
Callback callback = new Callback(annotation, method, executionModel(method, transformedAnnotations),
callbackArguments,
transformedAnnotations, endpointPath, index);
long messageArguments = callback.arguments.stream().filter(ca -> ca instanceof MessageCallbackArgument).count();
if (callback.acceptsMessage()) {
Expand Down Expand Up @@ -1081,13 +1098,16 @@ private Callback findCallback(IndexView index, ClassInfo beanClass, DotName anno
String.format("There can be only one callback annotated with %s declared on %s", annotationName, beanClass));
}

ExecutionModel executionModel(MethodInfo method) {
if (hasBlockingSignature(method)) {
return method.hasDeclaredAnnotation(WebSocketDotNames.RUN_ON_VIRTUAL_THREAD) ? ExecutionModel.VIRTUAL_THREAD
: ExecutionModel.WORKER_THREAD;
ExecutionModel executionModel(MethodInfo method, TransformedAnnotationsBuildItem transformedAnnotations) {
if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.RUN_ON_VIRTUAL_THREAD)) {
return ExecutionModel.VIRTUAL_THREAD;
} else if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.BLOCKING)) {
return ExecutionModel.WORKER_THREAD;
} else if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.NON_BLOCKING)) {
return ExecutionModel.EVENT_LOOP;
} else {
return hasBlockingSignature(method) ? ExecutionModel.WORKER_THREAD : ExecutionModel.EVENT_LOOP;
}
return method.hasDeclaredAnnotation(WebSocketDotNames.BLOCKING) ? ExecutionModel.WORKER_THREAD
: ExecutionModel.EVENT_LOOP;
}

boolean hasBlockingSignature(MethodInfo method) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package io.quarkus.websockets.next.test.executionmodel;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.net.URI;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.mutiny.Uni;
import io.vertx.core.Context;
import io.vertx.core.Vertx;

public class BlockingAnnotationTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("endpoint")
URI endUri;

@Test
void testEndoint() {
try (WSClient client = new WSClient(vertx).connect(endUri)) {
assertEquals("evenloop:false,worker:true", client.sendAndAwaitReply("foo").toString());
}
}

@WebSocket(path = "/endpoint")
public static class Endpoint {

@Blocking
@OnTextMessage
Uni<String> message(String ignored) {
return Uni.createFrom().item("evenloop:" + Context.isOnEventLoopThread() + ",worker:" + Context.isOnWorkerThread());
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.quarkus.websockets.next.test.executionmodel;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.net.URI;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.smallrye.common.annotation.NonBlocking;
import io.vertx.core.Context;
import io.vertx.core.Vertx;

public class NonBlockingAnnotationTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("endpoint")
URI endUri;

@Test
void testEndoint() {
try (WSClient client = new WSClient(vertx).connect(endUri)) {
assertEquals("evenloop:true,worker:false", client.sendAndAwaitReply("foo").toString());
}
}

@WebSocket(path = "/endpoint")
public static class Endpoint {

@NonBlocking
@OnTextMessage
String message(String ignored) {
return "evenloop:" + Context.isOnEventLoopThread() + ",worker:" + Context.isOnWorkerThread();
}

}

}

0 comments on commit 8df1abe

Please sign in to comment.