Skip to content

Commit

Permalink
fix for partial messages
Browse files Browse the repository at this point in the history
  • Loading branch information
trikko committed Apr 13, 2024
1 parent b3de66b commit 9c138e8
Showing 1 changed file with 91 additions and 88 deletions.
179 changes: 91 additions & 88 deletions source/serverino/interfaces.d
Original file line number Diff line number Diff line change
Expand Up @@ -1508,126 +1508,129 @@ class WebSocket
{
_isDirty = true;

if (_toParse.length == 0) return WebSocketMessage.init;

while(true)
{
if (_toParse.length == 0) return WebSocketMessage.init;

ubyte[] cursor = _toParse;

if (cursor.length < 2)
{
return WebSocketMessage.init;
}
ubyte[] cursor = _toParse;

import std.system : endian, Endian;
import std.bitmanip : swapEndian;
if (cursor.length < 2)
{
return WebSocketMessage.init;
}

static if(endian == Endian.littleEndian)
const ushort header = swapEndian((cast(ushort[])(cursor[0..2]))[0]);
else
const ushort header = (cast(ushort[])(cursor[0..2]))[0];
import std.system : endian, Endian;
import std.bitmanip : swapEndian;

cursor = cursor[2..$];
static if(endian == Endian.littleEndian)
const ushort header = swapEndian((cast(ushort[])(cursor[0..2]))[0]);
else
const ushort header = (cast(ushort[])(cursor[0..2]))[0];

bool flagFIN = (header & Flags.FIN) == Flags.FIN;
bool flagMASK = (header & Flags.MASK) == Flags.MASK;
cursor = cursor[2..$];

auto opcode = cast(ushort)(header & (0xF << 8)); // MASK = 0xF<<8
bool flagFIN = (header & Flags.FIN) == Flags.FIN;
bool flagMASK = (header & Flags.MASK) == Flags.MASK;

auto payloadLength = cast(size_t)cast(byte)(header & Flags.PAYLOAD_MASK);
ubyte[] payload;
ubyte[] mask = [0, 0, 0, 0];
auto opcode = cast(ushort)(header & (0xF << 8)); // MASK = 0xF<<8

if (payloadLength == 126)
{
if (cursor.length < ushort.sizeof)
return WebSocketMessage.init;
auto payloadLength = cast(size_t)cast(byte)(header & Flags.PAYLOAD_MASK);
ubyte[] payload;
ubyte[] mask = [0, 0, 0, 0];

static if (endian == Endian.littleEndian)
payloadLength = swapEndian((cast(ushort[])(cursor[0..ushort.sizeof]))[0]);
else
payloadLength = (cast(ushort[])(cursor[0..ushort.sizeof]))[0];
if (payloadLength == 126)
{
if (cursor.length < ushort.sizeof)
return WebSocketMessage.init;

cursor = cursor[ushort.sizeof..$];
}
else if (payloadLength == 127)
{
if (cursor.length < size_t.sizeof)
return WebSocketMessage.init;
static if (endian == Endian.littleEndian)
payloadLength = swapEndian((cast(ushort[])(cursor[0..ushort.sizeof]))[0]);
else
payloadLength = (cast(ushort[])(cursor[0..ushort.sizeof]))[0];

payloadLength = (cast(size_t[])(cursor[0..size_t.sizeof]))[0];
cursor = cursor[ushort.sizeof..$];
}
else if (payloadLength == 127)
{
if (cursor.length < size_t.sizeof)
return WebSocketMessage.init;

static if (endian == Endian.littleEndian)
payloadLength = swapEndian((cast(size_t[])(cursor[0..size_t.sizeof]))[0]);
else
payloadLength = (cast(size_t[])(cursor[0..size_t.sizeof]))[0];

cursor = cursor[size_t.sizeof..$];
}
static if (endian == Endian.littleEndian)
payloadLength = swapEndian((cast(size_t[])(cursor[0..size_t.sizeof]))[0]);
else
payloadLength = (cast(size_t[])(cursor[0..size_t.sizeof]))[0];

if (flagMASK)
{
if (cursor.length < 4)
return WebSocketMessage.init;
cursor = cursor[size_t.sizeof..$];
}

mask = cursor[0..4];
cursor = cursor[4..$];
}
if (flagMASK)
{
if (cursor.length < 4)
return WebSocketMessage.init;

mask = cursor[0..4];
cursor = cursor[4..$];
}

if (cursor.length < payloadLength)
return WebSocketMessage.init;

payload = cursor[0..payloadLength];
if (cursor.length < payloadLength)
return WebSocketMessage.init;

if (flagMASK)
foreach(i, ref ubyte b; payload)
b ^= mask[i % 4];
payload = cursor[0..payloadLength];

_parsedData ~= payload;
_toParse = cursor[payloadLength..$];
if (flagMASK)
foreach(i, ref ubyte b; payload)
b ^= mask[i % 4];

if (flagFIN)
{
scope(exit) _parsedData = null;
_parsedData ~= payload;
_toParse = cursor[payloadLength..$];

if (opcode == WebSocketMessage.OpCode.Ping)
if (flagFIN)
{
debug log("PING received, sending PONG");
sendMessage(WebSocketMessage(WebSocketMessage.OpCode.Pong, _parsedData));
return WebSocketMessage.init;
}
scope(exit) _parsedData = null;

auto msg = WebSocketMessage
(
cast(WebSocketMessage.OpCode)opcode,
_parsedData
);
if (opcode == WebSocketMessage.OpCode.Ping)
{
debug log("PING received, sending PONG");
sendMessage(WebSocketMessage(WebSocketMessage.OpCode.Pong, _parsedData));
return WebSocketMessage.init;
}

msg.isValid = true;
auto msg = WebSocketMessage
(
cast(WebSocketMessage.OpCode)opcode,
_parsedData
);

bool propagate = true;
msg.isValid = true;

switch(cast(WebSocketMessage.OpCode)opcode)
{
case WebSocketMessage.OpCode.Binary:
if (propagate && onBinaryMessage !is null) propagate = onBinaryMessage(msg.as!(ubyte[]));
break;

case WebSocketMessage.OpCode.Text:
if (propagate && onTextMessage !is null) propagate = onTextMessage(msg.as!string);
break;

case WebSocketMessage.OpCode.Close:
if (propagate && onCloseMessage !is null) propagate = onCloseMessage(msg);
break;
default: break;
}
bool propagate = true;

if (propagate && onMessage !is null) propagate = onMessage(msg);
switch(cast(WebSocketMessage.OpCode)opcode)
{
case WebSocketMessage.OpCode.Binary:
if (propagate && onBinaryMessage !is null) propagate = onBinaryMessage(msg.as!(ubyte[]));
break;

case WebSocketMessage.OpCode.Text:
if (propagate && onTextMessage !is null) propagate = onTextMessage(msg.as!string);
break;

case WebSocketMessage.OpCode.Close:
if (propagate && onCloseMessage !is null) propagate = onCloseMessage(msg);
break;
default: break;
}

return msg;
if (propagate && onMessage !is null) propagate = onMessage(msg);

return msg;
}
}
else return WebSocketMessage.init;
return WebSocketMessage.init;
}

enum Flags : ushort
Expand Down

0 comments on commit 9c138e8

Please sign in to comment.