Skip to content

Commit

Permalink
Better error messages. More tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
treeform committed Jul 19, 2020
1 parent 46f3bce commit d0b31d1
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/ws.nim
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ proc handshake*(ws: WebSocket, headers: HttpHeaders) {.async.} =
ws.version = parseInt(headers["Sec-WebSocket-Version"])
ws.key = headers["Sec-WebSocket-Key"].strip()
if headers.hasKey("Sec-WebSocket-Protocol"):
ws.protocol = headers["Sec-WebSocket-Protocol"].strip()
let wantProtocol = headers["Sec-WebSocket-Protocol"].strip()
if ws.protocol != wantProtocol:
raise newException(WebSocketError,
&"Protocol mismatch (expected: {ws.protocol}, got: {wantProtocol})")

let
sh = secureHash(ws.key & "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
Expand All @@ -79,16 +82,17 @@ proc handshake*(ws: WebSocket, headers: HttpHeaders) {.async.} =
await ws.tcpSocket.send(response)
ws.readyState = Open

proc newWebSocket*(req: Request): Future[WebSocket] {.async.} =
proc newWebSocket*(req: Request, protocol: string = ""): Future[WebSocket] {.async.} =
## Creates a new socket from a request.
try:
if not req.headers.hasKey("Sec-WebSocket-Version"):
await req.respond(Http404, "Not Found")
raise newException(WebSocketError, "Not a valid websocket handshake.")
raise newException(WebSocketError, "Invalid WebSocket handshake")

var ws = WebSocket()
ws.masked = false
ws.tcpSocket = req.client
ws.protocol = protocol
await ws.handshake(req.headers)
return ws

Expand Down Expand Up @@ -118,7 +122,7 @@ proc newWebSocket*(url: string, protocol: string = ""): Future[WebSocket] {.asyn
port = Port(80)
else:
raise newException(WebSocketError,
&"Scheme {uri.scheme} not supported yet.")
&"Scheme {uri.scheme} not supported yet")
if uri.port.len > 0:
port = Port(parseInt(uri.port))

Expand All @@ -142,8 +146,10 @@ proc newWebSocket*(url: string, protocol: string = ""): Future[WebSocket] {.asyn
client.headers["Sec-WebSocket-Protocol"] = ws.protocol
var res = await client.get($uri)
if ws.protocol != "":
if ws.protocol != res.headers["Sec-WebSocket-Protocol"]:
raise newException(WebSocketError, "Protocols don't match")
let resProtocol = res.headers["Sec-WebSocket-Protocol"]
if ws.protocol != resProtocol:
raise newException(WebSocketError,
&"Protocol mismatch (expected: {ws.protocol}, got: {resProtocol})")
ws.tcpSocket = client.getSocket()

ws.readyState = Open
Expand Down Expand Up @@ -293,7 +299,7 @@ proc recvFrame(ws: WebSocket): Future[Frame] {.async.} =

if header.len != 2:
ws.readyState = Closed
raise newException(WebSocketError, "socket closed")
raise newException(WebSocketError, "Socket closed")

let b0 = header[0].uint8
let b1 = header[1].uint8
Expand All @@ -308,7 +314,7 @@ proc recvFrame(ws: WebSocket): Future[Frame] {.async.} =
# If any of the rsv are set close the socket.
if result.rsv1 or result.rsv2 or result.rsv3:
ws.readyState = Closed
raise newException(WebSocketError, "WebSocket Protocol mismatch")
raise newException(WebSocketError, "WebSocket rsv mismatch")

# Payload length can be 7 bits, 7+16 bits, or 7+64 bits.
var finalLen: uint = 0
Expand Down Expand Up @@ -378,7 +384,7 @@ proc receiveStrPacket*(ws: WebSocket): Future[string] {.async.} =
return data
of Binary:
raise newException(WebSocketError,
"Got binary packet when looking for a string packet")
"Expected string packet, received binary packet")
of Ping:
await ws.send(data, Pong)
of Pong:
Expand All @@ -395,7 +401,7 @@ proc receiveBinaryPacket*(ws: WebSocket): Future[seq[byte]] {.async.} =
case opcode:
of Text:
raise newException(WebSocketError,
"Got text packet when looking for a binary packet")
"Expected binary packet, received string packet")
of Binary:
return cast[seq[byte]](data)
of Ping:
Expand Down
5 changes: 5 additions & 0 deletions test.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import ws, asyncdispatch

proc sendRemote(ws: WebSocket, data: string): Future[string] {.async.} =
await ws.send(data)
result = await ws.receiveStrPacket()
27 changes: 27 additions & 0 deletions tests/test_protocol.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
include ../src/ws

block: # protocol mismatch
# Start server
var
hadProtocolMismatch = false
hadFailedNewSocket = false
proc cb(req: Request) {.async.} =
try:
var ws = await newWebSocket(req, protocol = "foo")
await ws.send("Welcome")
ws.close()
except:
if getCurrentExceptionMsg().startsWith("Protocol mismatch (expected: foo, got: foo2)"):
hadProtocolMismatch = true
req.client.close()
var server = newAsyncHttpServer()
asyncCheck server.serve(Port(9002), cb)
# Send request
try:
var ws = waitFor newWebSocket("ws://127.0.0.1:9002/ws", protocol = "foo2")
ws.close()
except:
hadFailedNewSocket = true
server.close()
assert hadProtocolMismatch
assert hadFailedNewSocket
17 changes: 17 additions & 0 deletions tests/test_ws.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
include ../src/ws

# Start server
proc cb(req: Request) {.async.} =
var ws = await newWebSocket(req)
await ws.send("Welcome")
ws.close()
var server = newAsyncHttpServer()
asyncCheck server.serve(Port(9001), cb)
# Send request
var ws = waitFor newWebSocket("ws://127.0.0.1:9001/ws")
let packet = waitFor ws.receiveStrPacket()

assert packet == "Welcome"

ws.close()
server.close()

0 comments on commit d0b31d1

Please sign in to comment.