Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client-Decompression support #123

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
7 changes: 7 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ let package = Package(
],
targets: [
.target(name: "WebSocketKit", dependencies: [
"CZlib",
.product(name: "NIO", package: "swift-nio"),
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOConcurrencyHelpers", package: "swift-nio"),
Expand All @@ -29,6 +30,12 @@ let package = Package(
.product(name: "NIOTransportServices", package: "swift-nio-transport-services"),
.product(name: "Atomics", package: "swift-atomics")
]),
.target(
name: "CZlib",
dependencies: [],
linkerSettings: [
.linkedLibrary("z")
]),
.testTarget(name: "WebSocketKitTests", dependencies: [
.target(name: "WebSocketKit"),
]),
Expand Down
1 change: 1 addition & 0 deletions Sources/CZlib/empty.c
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

14 changes: 14 additions & 0 deletions Sources/CZlib/include/CZlib.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef C_ZLIB_H
#define C_ZLIB_H

#include <zlib.h>

/* Function-call wrapper around zlib's inflateInit2, which zlib.h provides as
 * a macro; exposing it as a real function gives callers outside C a symbol
 * they can call. Returns zlib's status code unchanged. */
static inline int CZlib_inflateInit2(z_streamp strm, int windowBits) {
    const int status = inflateInit2(strm, windowBits);
    return status;
}

/* Reinterprets an untyped pointer as a pointer to zlib's byte type so it can
 * be assigned to z_stream's next_in / next_out fields. No data is copied. */
static inline Bytef *CZlib_voidPtr_to_BytefPtr(void *in) {
    Bytef *bytes = (Bytef *)in;
    return bytes;
}

#endif
23 changes: 23 additions & 0 deletions Sources/WebSocketKit/Concurrency/Compression/Compression.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

/// Namespace for the compression formats supported when decoding payloads.
public enum Compression {
    /// A compression format, expressed as the zlib `windowBits` value needed to decode it.
    ///
    /// The type is an immutable value and is declared `Sendable` so that the shared
    /// `gzip` / `deflate` constants can be used from any concurrency domain, matching
    /// `Decompression.Limit`.
    public struct Algorithm: Sendable {
        /// The closed set of supported formats.
        enum Base: Sendable {
            case gzip
            case deflate
        }

        private let base: Base

        /// The `windowBits` argument passed to zlib when the decoder is initialised.
        var window: CInt {
            switch base {
            case .deflate:
                // Maximum window size with zlib framing.
                return 15
            case .gzip:
                // Adding 16 to the window size asks zlib to decode the gzip format.
                return 15 + 16
            }
        }

        /// The gzip format.
        public static let gzip = Self(base: .gzip)
        /// The zlib-framed deflate format.
        public static let deflate = Self(base: .deflate)
    }
}
173 changes: 173 additions & 0 deletions Sources/WebSocketKit/Concurrency/Compression/Decompression.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import CZlib

public enum Decompression {

/// Settings used to set up decompression of incoming data.
public struct Configuration {
/// The compression format; its window value is used to initialise the decoder.
public var algorithm: Compression.Algorithm
/// The bound applied to how much data decompression may produce.
public var limit: Limit

/// Creates a configuration.
/// - Parameters:
///   - algorithm: The compression format to decode.
///   - limit: How to limit decompression inflation.
public init(algorithm: Compression.Algorithm, limit: Limit) {
self.algorithm = algorithm
self.limit = limit
}
}

/// Specifies how to limit decompression inflation.
public struct Limit: Sendable {
    private enum Base {
        case none
        case size(Int)
        case ratio(Int)
    }

    private var limit: Base

    /// No limit will be set.
    /// - warning: Setting `limit` to `.none` leaves you vulnerable to denial of service attacks.
    public static let none = Limit(limit: .none)

    /// Limit will be set on the decompressed size, in bytes.
    public static func size(_ value: Int) -> Limit {
        return Limit(limit: .size(value))
    }

    /// Limit will be set on a ratio between compressed body size and decompressed result.
    public static func ratio(_ value: Int) -> Limit {
        return Limit(limit: .ratio(value))
    }

    /// Reports whether the decompressed size is larger than this limit allows.
    /// - Parameters:
    ///   - compressed: Size of the compressed input, in bytes.
    ///   - decompressed: Size of the output produced so far, in bytes.
    /// - Returns: `true` when `decompressed` is strictly greater than the allowed size.
    func exceeded(compressed: Int, decompressed: Int) -> Bool {
        switch self.limit {
        case .none:
            return false
        case .size(let allowed):
            return decompressed > allowed
        case .ratio(let ratio):
            // A plain `compressed * ratio` traps on overflow (e.g. `.ratio(Int.max)`),
            // which would crash the process instead of answering the question.
            let (allowed, overflowed) = compressed.multipliedReportingOverflow(by: ratio)
            guard overflowed else {
                return decompressed > allowed
            }
            // The true product lies outside the range of `Int`. If the operands have
            // different signs it is below `Int.min`, so every size exceeds it;
            // otherwise it is above `Int.max` and no size can exceed it.
            return (compressed < 0) != (ratio < 0)
        }
    }
}

/// Errors reported while setting up or running decompression.
///
/// The factories are immutable (`static let` / `static func`) so that no client
/// can replace them globally and so that they are safe to use from any thread.
public struct DecompressionError: Error, Equatable, CustomStringConvertible {

    private enum Base: Error, Equatable {
        case limit
        case inflationError(Int)
        case initializationError(Int)
        case invalidTrailingData
    }

    private var base: Base

    /// The configured ``Limit`` has been exceeded.
    public static let limit = Self(base: .limit)

    /// An error occurred when inflating. Error code is included to aid diagnosis.
    /// - Parameter code: The status code returned by the inflate call.
    public static func inflationError(_ code: Int) -> Self {
        return Self(base: .inflationError(code))
    }

    /// Decoder could not be initialised. Error code is included to aid diagnosis.
    /// - Parameter code: The status code returned by the initialisation call.
    public static func initializationError(_ code: Int) -> Self {
        return Self(base: .initializationError(code))
    }

    /// Decompression completed but there was invalid trailing data behind the compressed data.
    public static let invalidTrailingData = Self(base: .invalidTrailingData)

    /// The name of the underlying error case, including its error code if it has one.
    public var description: String {
        return String(describing: self.base)
    }
}

/// Inflates compressed data with a zlib stream while enforcing a ``Limit``.
///
/// `initializeDecoder(encoding:)` must be called before `decompress(part:buffer:)`,
/// and `deinitializeDecoder()` once the decompressor is no longer needed.
struct Decompressor {
    private let limit: Limit
    private var stream = z_stream()

    /// Creates a decompressor that enforces `limit`. The zlib stream is not yet initialised.
    init(limit: Limit) {
        self.limit = limit
    }

    /// Inflates `part` into `buffer`.
    ///
    /// Assumes `buffer` is a new empty buffer.
    /// - Parameters:
    ///   - part: The compressed input; consumed bytes are read off this buffer.
    ///   - buffer: The destination for the decompressed bytes.
    /// - Throws: `DecompressionError.limit` when the output grows beyond the configured
    ///   limit, `DecompressionError.invalidTrailingData` when input remains after the
    ///   stream reported completion, or any error thrown while inflating.
    mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer) throws {
        let compressedLength = part.readableBytes
        var isComplete = false

        while part.readableBytes > 0 && !isComplete {
            try self.stream.inflatePart(
                input: &part,
                output: &buffer,
                isComplete: &isComplete
            )

            // Compare the number of bytes actually produced. Using `writerIndex + 1`
            // over-counted by one and rejected payloads exactly at the limit.
            if self.limit.exceeded(
                compressed: compressedLength,
                decompressed: buffer.readableBytes
            ) {
                throw DecompressionError.limit
            }
        }

        // The stream ended before all input was consumed.
        if part.readableBytes > 0 {
            throw DecompressionError.invalidTrailingData
        }
    }

    /// Initialises the zlib stream for the given algorithm.
    /// - Parameter encoding: The algorithm whose window value configures the decoder.
    /// - Throws: `DecompressionError.initializationError` carrying zlib's status code
    ///   when initialisation does not return `Z_OK`.
    mutating func initializeDecoder(encoding: Compression.Algorithm) throws {
        // Clear the custom allocator hooks before handing the stream to zlib.
        self.stream.zalloc = nil
        self.stream.zfree = nil
        self.stream.opaque = nil

        let rc = CZlib_inflateInit2(&self.stream, encoding.window)
        guard rc == Z_OK else {
            throw DecompressionError.initializationError(Int(rc))
        }
    }

    /// Ends the zlib stream via `inflateEnd`.
    mutating func deinitializeDecoder() {
        inflateEnd(&self.stream)
    }
}
}

//MARK: - +z_stream
/// Helpers that run one `inflate` step between two `ByteBuffer`s.
private extension z_stream {
/// Inflates the readable bytes of `input` into `output` with a single `inflate` call.
///
/// - Parameters:
///   - input: The compressed bytes; the number of bytes zlib consumed is reported
///     back to the buffer through the closure's return value.
///   - output: The buffer that receives the decompressed bytes.
///   - isComplete: Set to `true` when `inflate` returned `Z_STREAM_END`.
/// - Throws: `Decompression.DecompressionError.inflationError` when `inflate` fails.
///
/// NOTE(review): if the output region is filled completely while the last input
/// bytes are consumed, zlib may presumably still hold pending output that a caller
/// looping on `input.readableBytes > 0` never drains — confirm against callers.
mutating func inflatePart(
input: inout ByteBuffer,
output: inout ByteBuffer,
isComplete: inout Bool
) throws {
// Ask for output space of at least four times the remaining compressed input.
let minimumCapacity = input.readableBytes * 4
try input.readWithUnsafeMutableReadableBytes { pointer in
self.avail_in = UInt32(pointer.count)
self.next_in = CZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)

// Clear the stream's buffer pointers on every exit path so they are not kept
// after this closure returns.
defer {
self.avail_in = 0
self.next_in = nil
self.avail_out = 0
self.next_out = nil
}

isComplete = try self.inflatePart(to: &output, minimumCapacity: minimumCapacity)

// Bytes consumed = bytes offered minus bytes zlib left unread.
return pointer.count - Int(self.avail_in)
}
}

/// Runs `inflate` once, writing into `buffer`.
///
/// - Parameters:
///   - buffer: The destination buffer.
///   - minimumCapacity: The minimum number of writable bytes requested from `buffer`.
/// - Returns: `true` when `inflate` returned `Z_STREAM_END`.
/// - Throws: `Decompression.DecompressionError.inflationError` with the status code
///   when `inflate` returns anything other than `Z_OK` or `Z_STREAM_END`.
private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> Bool {
var rc = Z_OK

try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in
self.avail_out = UInt32(pointer.count)
self.next_out = CZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)

rc = inflate(&self, Z_SYNC_FLUSH)
guard rc == Z_OK || rc == Z_STREAM_END else {
throw Decompression.DecompressionError.inflationError(Int(rc))
}

// Bytes written = space offered minus space zlib left unused.
return pointer.count - Int(self.avail_out)
}

return rc == Z_STREAM_END
}
}
63 changes: 45 additions & 18 deletions Sources/WebSocketKit/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,31 @@ public final class WebSocket {
}

private let channel: Channel
private var onTextCallback: (WebSocket, String) -> ()

private var onTextCallback: ((WebSocket, String) -> ())?
private var onTextBufferCallback: (WebSocket, ByteBuffer) -> ()
private var onBinaryCallback: (WebSocket, ByteBuffer) -> ()
private var onPongCallback: (WebSocket) -> ()
private var onPingCallback: (WebSocket) -> ()

private var frameSequence: WebSocketFrameSequence?
private let type: PeerType

private var decompressor: Decompression.Decompressor?

private var waitingForPong: Bool
private var waitingForClose: Bool
private var scheduledTimeoutTask: Scheduled<Void>?

init(channel: Channel, type: PeerType) {
init(channel: Channel, type: PeerType, decompression: Decompression.Configuration?) throws {
self.channel = channel
self.type = type
if let decompression = decompression {
self.decompressor = Decompression.Decompressor(limit: decompression.limit)
try self.decompressor?.initializeDecoder(encoding: decompression.algorithm)
}
self.onTextCallback = { _, _ in }
self.onTextBufferCallback = { _, _ in }
self.onBinaryCallback = { _, _ in }
self.onPongCallback = { _ in }
self.onPingCallback = { _ in }
Expand All @@ -51,6 +62,11 @@ public final class WebSocket {
self.onTextCallback = callback
}

/// The same as `onText`, but with raw data instead of the decoded `String`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What purpose does this callback serve? I don't see any usage of it anywhere, including in the tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently we only have onTextCallback. What is sent to the WebSocket is, of course, Data, not a String. But when using onTextCallback, ws-kit turns the data into a string and passes that string to the users of the package. The problem is that if the text is, for example, in JSON format, ws-kit users need to turn the string back into Data to pass it to something like JSONDecoder. So we have Data -> String -> Data instead of just Data, which is wasteful. This new callback solves that problem.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like it is mentioned in this issue: #79

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should add some tests, nevertheless.

Copy link
Contributor Author

@MahdiBM MahdiBM Dec 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a test to assert both text callbacks have the same behavior.
I should also add that onBinary is not the same as onTextBuffer, because onBinary is only activated if the ws frame is an actual binary frame. onTextBuffer is for when the ws frame is a text frame, but users might still prefer to access the string's data directly. I did try to mix the two, but it could cause problems.

public func onTextBuffer(_ callback: @escaping (WebSocket, ByteBuffer) -> ()) {
self.onTextBufferCallback = callback
}

public func onBinary(_ callback: @escaping (WebSocket, ByteBuffer) -> ()) {
self.onBinaryCallback = callback
}
Expand All @@ -64,10 +80,10 @@ public final class WebSocket {
}

/// If set, this will trigger automatic pings on the connection. If ping is not answered before
/// the next ping is sent, then the WebSocket will be presumed innactive and will be closed
/// the next ping is sent, then the WebSocket will be presumed inactive and will be closed
/// automatically.
/// These pings can also be used to keep the WebSocket alive if there is some other timeout
/// mechanism shutting down innactive connections, such as a Load Balancer deployed in
/// mechanism shutting down inactive connections, such as a Load Balancer deployed in
/// front of the server.
public var pingInterval: TimeAmount? {
didSet {
Expand Down Expand Up @@ -233,6 +249,7 @@ public final class WebSocket {
} else {
frameSequence = WebSocketFrameSequence(type: frame.opcode)
}

// append this frame and update the sequence
frameSequence.append(frame)
self.frameSequence = frameSequence
Expand All @@ -252,12 +269,27 @@ public final class WebSocket {

// if this frame was final and we have a non-nil frame sequence,
// output it to the websocket and clear storage
if let frameSequence = self.frameSequence, frame.fin {
if var frameSequence = self.frameSequence, frame.fin {
switch frameSequence.type {
case .binary:
self.onBinaryCallback(self, frameSequence.binaryBuffer)
if decompressor != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if let decompressor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to support swifts older than 5.7.

Other than that, iirc we can't trigger a compiler copy of the decompressor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to support swifts older than 5.7.

Oops, I forgot about that. 😆

Other than that, iirc we can't trigger a compiler copy of the decompressor.

Hmm interesting. Then should Decompressor be a class since it manages a resource?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, we will have ARC overhead, which we currently don't have.
This is also not part of the API, so we don't really need to worry about it being friendly to developers, too much.
Swift NIO also had it as a struct, not class.

do {
var buffer = ByteBuffer()
try decompressor!.decompress(part: &frameSequence.buffer, buffer: &buffer)

self.onBinaryCallback(self, buffer)
} catch {
self.close(code: .protocolError, promise: nil)
return
}
} else {
self.onBinaryCallback(self, frameSequence.buffer)
}
case .text:
self.onTextCallback(self, frameSequence.textBuffer)
if let callback = self.onTextCallback {
callback(self, String(buffer: frameSequence.buffer))
}
self.onTextBufferCallback(self, frameSequence.buffer)
case .ping, .pong:
assertionFailure("Control frames never have a frameSequence")
default: break
Expand Down Expand Up @@ -293,30 +325,25 @@ public final class WebSocket {
}

deinit {
self.decompressor?.deinitializeDecoder()
assert(self.isClosed, "WebSocket was not closed before deinit.")
}
}

private struct WebSocketFrameSequence {
var binaryBuffer: ByteBuffer
var textBuffer: String
var buffer: ByteBuffer
var type: WebSocketOpcode

init(type: WebSocketOpcode) {
self.binaryBuffer = ByteBufferAllocator().buffer(capacity: 0)
self.textBuffer = .init()
self.buffer = ByteBufferAllocator().buffer(capacity: 0)
self.type = type
}

mutating func append(_ frame: WebSocketFrame) {
var data = frame.unmaskedData
switch type {
case .binary:
self.binaryBuffer.writeBuffer(&data)
case .text:
if let string = data.readString(length: data.readableBytes) {
self.textBuffer += string
}
case .binary, .text:
var data = frame.unmaskedData
self.buffer.writeBuffer(&data)
default: break
}
}
Expand Down
Loading