Skip to content

Commit

Permalink
Make sure Context is set for Java WebSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
gmethvin authored and marcospereira committed Mar 28, 2016
1 parent 47e5d8b commit 22da246
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 22 deletions.
Expand Up @@ -8,6 +8,7 @@
import akka.stream.javadsl.Sink;
import akka.stream.javadsl.Source;
import play.libs.F;
import play.mvc.Http;
import play.mvc.Results;
import play.mvc.WebSocket;
import scala.compat.java8.FutureConverters;
Expand Down Expand Up @@ -35,20 +36,26 @@ public class WebSocketSpecJavaActions {
}

public static WebSocket allowConsumingMessages(Promise<List<String>> messages) {
ensureContext();
return WebSocket.Text.accept(request -> Flow.fromSinkAndSource(getChunks(messages::success), emptySource()));
}

public static WebSocket allowSendingMessages(List<String> messages) {
ensureContext();
return WebSocket.Text.accept(request -> Flow.fromSinkAndSource(Sink.ignore(), Source.from(messages)));
}

public static WebSocket closeWhenTheConsumerIsDone() {
ensureContext();
return WebSocket.Text.accept(request -> Flow.fromSinkAndSource(Sink.cancelled(), emptySource()));
}

public static WebSocket allowRejectingAWebSocketWithAResult(int statusCode) {
ensureContext();
return WebSocket.Text.acceptOrResult(request -> CompletableFuture.completedFuture(F.Either.Left(Results.status(statusCode))));
}


private static Http.Context ensureContext() {
return Http.Context.current();
}
}
Expand Up @@ -3,25 +3,28 @@
*/
package play.it.http.websocket

import java.net.URI
import java.util.concurrent.atomic.AtomicReference
import java.util.function.{Consumer, Function}

import akka.actor._
import akka.stream.scaladsl._
import akka.util.ByteString
import org.specs2.matcher.Matcher
import play.api.Application
import play.api.http.websocket._
import play.api.inject.guice.GuiceApplicationBuilder
import play.api.test._
import play.api.Application
import play.api.mvc.{ Handler, Results, WebSocket }
import play.api.libs.iteratee._
import play.api.mvc.{Handler, Results, WebSocket}
import play.api.test._
import play.core.routing.HandlerDef
import play.it._
import play.it.http.websocket.WebSocketClient.{ContinuationMessage, SimpleMessage, ExtendedMessage}
import scala.concurrent.{ Future, Promise }
import scala.concurrent.duration._
import play.it.http.websocket.WebSocketClient.{ContinuationMessage, ExtendedMessage, SimpleMessage}

import scala.concurrent.ExecutionContext.Implicits.global
import java.net.URI
import java.util.concurrent.atomic.AtomicReference
import java.util.function.{ Consumer, Function }
import scala.concurrent.duration._
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag

object NettyWebSocketSpec extends WebSocketSpec with NettyIntegrationSpecification
object AkkaHttpWebSocketSpec extends WebSocketSpec with AkkaHttpIntegrationSpecification
Expand Down Expand Up @@ -412,15 +415,17 @@ trait WebSocketSpec extends PlaySpecification with WsTestClient with ServerInteg

"allow handling a WebSocket in java" in {

import java.util.{List => JList}

import play.core.routing.HandlerInvokerFactory
import play.core.routing.HandlerInvokerFactory._
import java.util.{List => JList}

import scala.collection.JavaConverters._

implicit def toHandler[J <: AnyRef](javaHandler: J)(implicit factory: HandlerInvokerFactory[J]): Handler = {
implicit def toHandler[J <: AnyRef](javaHandler: => J)(implicit factory: HandlerInvokerFactory[J], ct: ClassTag[J]): Handler = {
val invoker = factory.createInvoker(
javaHandler,
new HandlerDef(javaHandler.getClass.getClassLoader, "package", "controller", "method", Nil, "GET", "", "/stream")
new HandlerDef(ct.runtimeClass.getClassLoader, "package", "controller", "method", Nil, "GET", "", "/stream")
)
invoker.call(javaHandler)
}
Expand Down Expand Up @@ -452,7 +457,7 @@ trait WebSocketSpec extends PlaySpecification with WsTestClient with ServerInteg

import play.core.routing.HandlerInvokerFactory
import play.core.routing.HandlerInvokerFactory._
import play.mvc.{ LegacyWebSocket, WebSocket => JWebSocket, Results => JResults }
import play.mvc.{LegacyWebSocket, Results => JResults, WebSocket => JWebSocket}
import JWebSocket.{In, Out}

implicit def toHandler[J <: AnyRef](javaHandler: J)(implicit factory: HandlerInvokerFactory[J]): Handler = {
Expand Down
Expand Up @@ -4,16 +4,16 @@
package play.core.routing

import java.util.Optional
import java.util.concurrent.{ CompletableFuture, CompletionStage }
import java.util.concurrent.{CompletableFuture, CompletionStage}

import akka.stream.scaladsl.Flow
import org.apache.commons.lang3.reflect.MethodUtils
import play.api.mvc._
import play.core.j
import play.core.j.{ JavaHandlerComponents, JavaHandler, JavaActionAnnotations }
import play.mvc.Http.RequestBody
import play.core.j.{JavaActionAnnotations, JavaHandler, JavaHandlerComponents}
import play.mvc.Http.{Context, RequestBody}

import scala.compat.java8.{ OptionConverters, FutureConverters }
import scala.compat.java8.{FutureConverters, OptionConverters}
import scala.util.control.NonFatal

/**
Expand Down Expand Up @@ -69,9 +69,9 @@ trait HandlerInvokerFactory[-T] {

object HandlerInvokerFactory {

import play.mvc.{ Result => JResult, LegacyWebSocket, WebSocket => JWebSocket }
import play.core.j.JavaWebSocket
import com.fasterxml.jackson.databind.JsonNode
import play.core.j.JavaWebSocket
import play.mvc.{LegacyWebSocket, Result => JResult, WebSocket => JWebSocket}

private[routing] def handlerTags(handlerDef: HandlerDef): Map[String, String] = Map(
play.api.routing.Router.Tags.RoutePattern -> handlerDef.path,
Expand Down Expand Up @@ -196,12 +196,25 @@ object HandlerInvokerFactory {
}

implicit def javaWebSocket: HandlerInvokerFactory[JWebSocket] = new HandlerInvokerFactory[JWebSocket] {
import play.http.websocket.{ Message => JMessage }
import play.api.http.websocket._
import play.api.libs.iteratee.Execution.Implicits.trampoline
import play.http.websocket.{Message => JMessage}

def createInvoker(fakeCall: => JWebSocket, handlerDef: HandlerDef) = new HandlerInvoker[JWebSocket] {
def call(call: => JWebSocket) = WebSocket.acceptOrResult[Message, Message] { request =>
FutureConverters.toScala(call(new j.RequestHeaderImpl(request))).map { resultOrFlow =>

val javaContext = JavaWebSocket.createJavaContext(request)

val callWithContext = {
try {
Context.current.set(javaContext)
FutureConverters.toScala(call(new j.RequestHeaderImpl(request)))
} finally {
Context.current.remove()
}
}

callWithContext.map { resultOrFlow =>
if (resultOrFlow.left.isPresent) {
Left(resultOrFlow.left.get.asScala())
} else {
Expand Down

0 comments on commit 22da246

Please sign in to comment.