Skip to content

Commit

Permalink
Temporarily restore previous Web Socket tests (#3653)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski committed Apr 11, 2024
1 parent 10bab7b commit c673b18
Showing 1 changed file with 13 additions and 235 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,20 @@ package sttp.tapir.server.tests
import cats.effect.IO
import cats.syntax.all._
import io.circe.generic.auto._
import org.scalatest.EitherValues
import org.scalatest.matchers.should.Matchers._
import sttp.capabilities.{Streams, WebSockets}
import sttp.client3._
import sttp.model.StatusCode
import sttp.monad.MonadError
import sttp.tapir._
import sttp.tapir.generic.auto._
import sttp.tapir.json.circe._
import sttp.tapir.model.UnsupportedWebSocketFrameException
import sttp.tapir.server.interceptor.CustomiseInterceptors
import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor
import sttp.tapir.server.tests.ServerMetricsTest._
import sttp.tapir.tests.Test
import sttp.tapir.tests.data.Fruit
import sttp.ws.{WebSocket, WebSocketFrame}

import scala.concurrent.duration._

abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
createServerTest: CreateServerTest[F, S with WebSockets, OPTIONS, ROUTE],
val streams: S,
Expand All @@ -30,7 +25,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
handlePong: Boolean
)(implicit
m: MonadError[F]
) extends EitherValues {
) {
import createServerTest._

def functionToPipe[A, B](f: A => B): streams.Pipe[A, B]
Expand All @@ -51,28 +46,11 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
_ <- ws.sendText("test2")
m1 <- ws.receiveText()
m2 <- ws.receiveText()
_ <- ws.close()
m3 <- ws.eitherClose(ws.receiveText())
} yield List(m1, m2, m3)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(_.body shouldBe Right(List("echo: test1", "echo: test2", Left(WebSocketFrame.Close(1000, "normal closure")))))
},
testServer(
endpoint.in("elsewhere").out(stringBody),
"WS handshake to a non-existing endpoint"
)((_: Unit) => pureResult("hello".asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test")
m1 <- ws.receiveText()
} yield m1
} yield List(m1, m2)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(_.code shouldBe StatusCode.NotFound)
.map(_.body shouldBe Right(List("echo: test1", "echo: test2")))
}, {

val reqCounter = newRequestCounter[F]
Expand Down Expand Up @@ -143,60 +121,6 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
)
)
},
testServer(
endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(true)
),
"pong on ping"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
_ <- ws.send(WebSocketFrame.Ping("test-ping-text".getBytes()))
m1 <- ws.receive()
_ <- ws.sendText("test2")
m2 <- ws.receive()
} yield List(m1, m2)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map((r: Response[Either[String, List[WebSocketFrame]]]) =>
assert(
r.body.value exists {
case WebSocketFrame.Pong(array) => array sameElements "test-ping-text".getBytes
case _ => false
},
s"Missing Pong(test-ping-text) in ${r.body}"
)
)
},
testServer(
endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(false)
),
"not pong on ping if disabled"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
_ <- ws.send(WebSocketFrame.Ping("test-ping-text".getBytes()))
m1 <- ws.receiveText()
_ <- ws.sendText("test2")
m2 <- ws.receiveText()
} yield List(m1, m2)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(
_.body shouldBe Right(List("echo: test1", "echo: test2"))
)
},
testServer(
endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)),
"empty client stream"
Expand All @@ -209,164 +133,18 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
},
testServer(
endpoint
.in(query[String]("securityToken"))
.in(isWebSocket)
.errorOut(stringBody)
.out(stringWs),
"switch to WS after a normal HTTP request"
)(token => if (token == "correctToken") pureResult(stringEcho.asRight) else pureResult("Incorrect token!".asLeft)) {
(backend, baseUri) =>
for {
response1 <- basicRequest
.response(asString)
.get(uri"$baseUri?securityToken=wrong".scheme("http"))
.send(backend)
response2 <- basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
ws.sendText("testOk") >> ws.receiveText()
})
.get(uri"$baseUri?securityToken=correctToken".scheme("ws"))
.send(backend)
} yield {
response1.body shouldBe Left("Incorrect token!")
response2.body shouldBe Right("echo: testOk")
}
},
testServer(
endpoint
.in(query[String]("securityToken"))
.errorOut(stringBody)
.out(stringWs),
"reject WS handshake, then accept a corrected one"
)(token => if (token == "correctToken") pureResult(stringEcho.asRight) else pureResult("Incorrect token!".asLeft)) {
(backend, baseUri) =>
for {
response1 <- basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
ws.sendText("testWrong") >> ws.receiveText()
})
.get(uri"$baseUri?securityToken=wrong".scheme("ws"))
.send(backend)
response2 <- basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
ws.sendText("testOk") >> ws.receiveText()
})
.get(uri"$baseUri?securityToken=correctToken".scheme("ws"))
.send(backend)
} yield {
response1.code shouldBe StatusCode.BadRequest
response2.body shouldBe Right("echo: testOk")
}
"non web-socket request"
)(isWS => if (isWS) pureResult(stringEcho.asRight) else pureResult("Not a WS!".asLeft)) { (backend, baseUri) =>
basicRequest
.response(asString)
.get(baseUri.scheme("http"))
.send(backend)
.map(_.body shouldBe Left("Not a WS!"))
}
) ++ autoPingTests ++ failingPipeTests ++ handlePongTests

val autoPingTests =
if (autoPing)
List(
testServer(
endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)
.autoPing(Some((50.millis, WebSocketFrame.ping)))
),
"auto ping"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
_ <- IO.sleep(150.millis)
_ <- ws.sendText("test2")
m1 <- ws.receive()
m2 <- ws.receive()
_ <- ws.sendText("test3")
m3 <- ws.receive()
} yield List(m1, m2, m3)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map((r: Response[Either[String, List[WebSocketFrame]]]) =>
assert(r.body.value.exists(_.isInstanceOf[WebSocketFrame.Ping]), s"Missing Ping frame in WS responses: $r")
)
}
)
else List.empty

// Optional, because some backends don't handle exceptions in the pipe gracefully, they just swallow it silently and hang forever
val failingPipeTests =
if (failingPipe)
List(
testServer(
endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)),
"failing pipe"
)((_: Unit) =>
pureResult(functionToPipe[String, String] {
case "error-trigger" => throw new Exception("Boom!")
case msg => s"echo: $msg"
}.asRight[Unit])
) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
_ <- ws.sendText("test2")
_ <- ws.sendText("error-trigger")
m1 <- ws.eitherClose(ws.receiveText())
m2 <- ws.eitherClose(ws.receiveText())
m3 <- ws.eitherClose(ws.receiveText())
} yield List(m1, m2, m3)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { r =>
val results = r.body.map(_.map(_.left.map(_.statusCode))).value
results.take(2) shouldBe
List(Right("echo: test1"), Right("echo: test2"))
val closeCode = results.last.left.value
assert(closeCode == 1000 || closeCode == 1011) // some servers respond with Close(normal), some with Close(error)
}
}
)
else List.empty

val handlePongTests =
if (handlePong)
List(
testServer(
{
implicit def textOrPongWebSocketFrame[A, CF <: CodecFormat](implicit
stringCodec: Codec[String, A, CF]
): Codec[WebSocketFrame, A, CF] =
Codec // A custom codec to handle Pongs
.id[WebSocketFrame, CF](stringCodec.format, Schema.string)
.mapDecode {
case WebSocketFrame.Text(p, _, _) => stringCodec.decode(p)
case WebSocketFrame.Pong(payload) =>
stringCodec.decode(new String(payload))
case f => DecodeResult.Error(f.toString, new UnsupportedWebSocketFrameException(f))
}(a => WebSocketFrame.text(stringCodec.encode(a)))
.schema(stringCodec.schema)
)

endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.ignorePong(false)
)
},
"not ignore pong"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
_ <- ws.send(WebSocketFrame.Pong("test-pong-text".getBytes()))
m1 <- ws.receiveText()
_ <- ws.sendText("test2")
m2 <- ws.receiveText()
} yield List(m1, m2)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(_.body shouldBe Right(List("echo: test1", "echo: test-pong-text")))
}
)
else List.empty
// TODO: tests for ping/pong (control frames handling)
}

0 comments on commit c673b18

Please sign in to comment.