Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client-Decompression support #123

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ let package = Package(
],
targets: [
.target(name: "WebSocketKit", dependencies: [
"CZlib",
.product(name: "NIO", package: "swift-nio"),
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOConcurrencyHelpers", package: "swift-nio"),
Expand All @@ -29,6 +30,12 @@ let package = Package(
.product(name: "NIOTransportServices", package: "swift-nio-transport-services"),
.product(name: "Atomics", package: "swift-atomics")
]),
.target(
name: "CZlib",
dependencies: [],
linkerSettings: [
.linkedLibrary("z")
]),
.testTarget(name: "WebSocketKitTests", dependencies: [
.target(name: "WebSocketKit"),
]),
Expand Down
1 change: 1 addition & 0 deletions Sources/CZlib/empty.c
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

14 changes: 14 additions & 0 deletions Sources/CZlib/include/CZlib.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef C_ZLIB_H
#define C_ZLIB_H

#include <zlib.h>

/* Function-shaped wrapper around zlib's `inflateInit2`.
 *
 * zlib provides `inflateInit2` as a preprocessor macro, which cannot be called
 * from Swift; this inline function gives Swift callers a real symbol instead.
 * `strm` and `windowBits` are forwarded unchanged and zlib's status code is
 * returned as-is. */
static inline int CZlib_inflateInit2(z_streamp strm, int windowBits) {
    const int status = inflateInit2(strm, windowBits);
    return status;
}

/* Reinterprets an untyped pointer as a pointer to zlib bytes (`Bytef`).
 *
 * Performs a plain cast only: the address is returned unchanged and nothing
 * is copied or validated. */
static inline Bytef *CZlib_voidPtr_to_BytefPtr(void *in) {
    Bytef *bytes = (Bytef *)in;
    return bytes;
}

#endif
13 changes: 13 additions & 0 deletions Sources/WebSocketKit/Compression/Compression.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

/// Namespace for compression-related types.
enum Compression {
    /// The compression algorithms that can be selected.
    enum Algorithm {
        case deflate
    }
}

extension Compression.Algorithm {
    /// The window-size value associated with this algorithm.
    ///
    /// `deflate` maps to `15`.
    /// NOTE(review): in zlib terms 15 should select the largest (32 KiB) window
    /// with a zlib header — confirm against the zlib manual.
    var window: CInt {
        switch self {
        case .deflate:
            return 15
        }
    }
}
122 changes: 122 additions & 0 deletions Sources/WebSocketKit/Compression/Decompression.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import CZlib

/// Namespace for the decompression configuration, its error type and the
/// zlib-backed decompressor.
public enum Decompression {

    /// Opt-in decompression settings.
    ///
    /// The initializer is private, so `Configuration.enabled` is the only value
    /// obtainable from outside this type.
    public struct Configuration {
        /// For now we only support `deflate`, because it's the main compression
        /// algorithm for web-sockets (RFC 7692).
        let algorithm: Compression.Algorithm = .deflate

        private init() { }

        /// Decompression switched on, using the `deflate` algorithm.
        public static let enabled = Configuration()
    }

    /// Errors thrown while initializing the decoder or inflating data.
    ///
    /// Values are created through the static members below. These are immutable
    /// (`static func` / `static let`) so that client code cannot reassign them and
    /// so that they are not shared mutable global state.
    public struct DecompressionError: Error, Equatable, CustomStringConvertible {

        private enum Base: Error, Equatable {
            case inflationError(Int)
            case initializationError(Int)
            case invalidTrailingData
        }

        private let base: Base

        /// An error occurred when inflating. Error code is included to aid diagnosis.
        ///
        /// - Parameter code: The status code returned by zlib.
        public static func inflationError(_ code: Int) -> Self {
            return Self(base: .inflationError(code))
        }

        /// Decoder could not be initialized. Error code is included to aid diagnosis.
        ///
        /// - Parameter code: The status code returned by zlib.
        public static func initializationError(_ code: Int) -> Self {
            return Self(base: .initializationError(code))
        }

        /// Decompression completed but there was invalid trailing data behind the compressed data.
        public static let invalidTrailingData = Self(base: .invalidTrailingData)

        /// A textual rendering of the underlying error case, including its code if any.
        public var description: String {
            return String(describing: self.base)
        }
    }

    /// Thin stateful wrapper around a zlib `z_stream` used for inflating.
    ///
    /// The struct does not manage the stream's lifetime by itself: the owner must
    /// call `initializeDecoder(encoding:)` before decompressing and
    /// `deinitializeDecoder()` once it is done with the value.
    struct Decompressor {
        private var stream = z_stream()

        /// Inflates the readable bytes of `part` and writes the result into `buffer`.
        ///
        /// Assumes `buffer` is a new empty buffer.
        ///
        /// - Parameters:
        ///   - part: The compressed bytes. Consumed bytes are read out of it.
        ///   - buffer: Destination for the inflated bytes.
        /// - Throws: `DecompressionError.inflationError` when zlib reports a failure,
        ///   or `DecompressionError.invalidTrailingData` when zlib signalled the end
        ///   of the stream while `part` still had unread bytes.
        mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer) throws {
            var isComplete = false

            // Keep feeding zlib until the input is drained or the stream ends.
            // NOTE(review): the loop stops as soon as `part` is empty; confirm zlib
            // cannot still hold pending output at that point when the output space
            // offered by `inflatePart` was completely filled.
            while part.readableBytes > 0 && !isComplete {
                try self.stream.inflatePart(
                    input: &part,
                    output: &buffer,
                    isComplete: &isComplete
                )
            }

            // Anything left over was not consumed before the stream ended.
            guard part.readableBytes == 0 else {
                throw DecompressionError.invalidTrailingData
            }
        }

        /// Prepares the underlying stream for inflating.
        ///
        /// - Parameter encoding: Algorithm whose `window` value is handed to zlib.
        /// - Throws: `DecompressionError.initializationError` carrying zlib's status
        ///   code when initialization does not return `Z_OK`.
        mutating func initializeDecoder(encoding: Compression.Algorithm) throws {
            // `nil` allocator hooks: no custom allocation callbacks are supplied.
            self.stream.zalloc = nil
            self.stream.zfree = nil
            self.stream.opaque = nil

            let rc = CZlib_inflateInit2(&self.stream, encoding.window)
            guard rc == Z_OK else {
                throw DecompressionError.initializationError(Int(rc))
            }
        }

        /// Ends the underlying stream via `inflateEnd`; the status code is ignored.
        mutating func deinitializeDecoder() {
            inflateEnd(&self.stream)
        }
    }
}

//MARK: - +z_stream
private extension z_stream {
    /// Hands the readable bytes of `input` to zlib and appends the inflated bytes to `output`.
    ///
    /// The bytes zlib consumed are read out of `input`. The base address of the
    /// readable region is force-unwrapped; the caller in this file only invokes
    /// this method while `input` has readable bytes.
    ///
    /// - Parameters:
    ///   - input: Source of compressed bytes.
    ///   - output: Buffer the inflated bytes are written to.
    ///   - isComplete: Set to `true` when zlib reports `Z_STREAM_END`, `false` otherwise.
    /// - Throws: `Decompression.DecompressionError.inflationError` on a zlib failure.
    mutating func inflatePart(
        input: inout ByteBuffer,
        output: inout ByteBuffer,
        isComplete: inout Bool
    ) throws {
        // Request writable space of four times the compressed size.
        let outputReservation = input.readableBytes * 4

        try input.readWithUnsafeMutableReadableBytes { compressed in
            let offered = compressed.count
            self.avail_in = UInt32(offered)
            self.next_in = CZlib_voidPtr_to_BytefPtr(compressed.baseAddress!)

            // Clear every pointer/length on the way out so the stream never keeps
            // referring to the buffers' storage after this closure returns.
            defer {
                self.next_in = nil
                self.avail_in = 0
                self.next_out = nil
                self.avail_out = 0
            }

            isComplete = try self.inflatePart(to: &output, minimumCapacity: outputReservation)

            // Report how many input bytes zlib actually consumed.
            let leftover = Int(self.avail_in)
            return offered - leftover
        }
    }

    /// Runs one `inflate` call (with `Z_SYNC_FLUSH`) writing into `buffer`.
    ///
    /// - Parameters:
    ///   - buffer: Destination; the number of bytes zlib produced is reported as written.
    ///   - minimumCapacity: Minimum writable space requested from `buffer`.
    /// - Returns: `true` when zlib returned `Z_STREAM_END`.
    /// - Throws: `Decompression.DecompressionError.inflationError` when zlib returns
    ///   anything other than `Z_OK` or `Z_STREAM_END`.
    private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> Bool {
        var status = Z_OK

        try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { destination in
            let capacity = destination.count
            self.avail_out = UInt32(capacity)
            self.next_out = CZlib_voidPtr_to_BytefPtr(destination.baseAddress!)

            status = inflate(&self, Z_SYNC_FLUSH)
            if status != Z_OK && status != Z_STREAM_END {
                throw Decompression.DecompressionError.inflationError(Int(status))
            }

            // Bytes produced = space offered minus space zlib left untouched.
            let unused = Int(self.avail_out)
            return capacity - unused
        }

        return status == Z_STREAM_END
    }
}
4 changes: 3 additions & 1 deletion Sources/WebSocketKit/HTTPInitialRequestHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa
let host: String
let path: String
let query: String?
let decompression: Decompression.Configuration?
let headers: HTTPHeaders
let upgradePromise: EventLoopPromise<Void>

init(host: String, path: String, query: String?, headers: HTTPHeaders, upgradePromise: EventLoopPromise<Void>) {
init(host: String, path: String, query: String?, decompression: Decompression.Configuration?, headers: HTTPHeaders, upgradePromise: EventLoopPromise<Void>) {
self.host = host
self.path = path
self.query = query
self.decompression = decompression
self.headers = headers
self.upgradePromise = upgradePromise
}
Expand Down
63 changes: 45 additions & 18 deletions Sources/WebSocketKit/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,31 @@ public final class WebSocket {
}

private let channel: Channel
private var onTextCallback: (WebSocket, String) -> ()

private var onTextCallback: ((WebSocket, String) -> ())?
private var onTextBufferCallback: (WebSocket, ByteBuffer) -> ()
private var onBinaryCallback: (WebSocket, ByteBuffer) -> ()
private var onPongCallback: (WebSocket) -> ()
private var onPingCallback: (WebSocket) -> ()

private var frameSequence: WebSocketFrameSequence?
private let type: PeerType

private var decompressor: Decompression.Decompressor?

private var waitingForPong: Bool
private var waitingForClose: Bool
private var scheduledTimeoutTask: Scheduled<Void>?

init(channel: Channel, type: PeerType) {
init(channel: Channel, type: PeerType, decompression: Decompression.Configuration?) throws {
self.channel = channel
self.type = type
if let decompression = decompression {
self.decompressor = Decompression.Decompressor()
try self.decompressor?.initializeDecoder(encoding: decompression.algorithm)
}
self.onTextCallback = { _, _ in }
self.onTextBufferCallback = { _, _ in }
self.onBinaryCallback = { _, _ in }
self.onPongCallback = { _ in }
self.onPingCallback = { _ in }
Expand All @@ -51,6 +62,11 @@ public final class WebSocket {
self.onTextCallback = callback
}

/// The same as `onText`, but with raw data instead of the decoded `String`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What purpose does this callback serve? I don't see any usage of it anywhere, including in the tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently we only have onTextCallback. What is sent to the ws is, of course, Data, not a string. But when using onTextCallback, ws-kit turns the data into a string and passes the string to the users of the package. The problem is that if the text is, for example, in JSON format, ws-kit users need to turn the string into Data again to pass it to something like JSONDecoder. So we have Data -> String -> Data instead of just Data, which is wasteful. This new callback solves that problem.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like it is mentioned in this issue: #79

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should add some tests, nevertheless.

Copy link
Contributor Author

@MahdiBM MahdiBM Dec 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a test to assert both text callbacks have the same behavior.
I should also add that onBinary is not the same as onTextBuffer, because onBinary is only activated if the ws frame is an actual binary frame. onTextBuffer is for when the ws frame is a text frame, but users might still prefer to access the string's data directly. I did try to mix the two, but it could cause problems.

public func onTextBuffer(_ callback: @escaping (WebSocket, ByteBuffer) -> ()) {
self.onTextBufferCallback = callback
}

public func onBinary(_ callback: @escaping (WebSocket, ByteBuffer) -> ()) {
self.onBinaryCallback = callback
}
Expand All @@ -64,10 +80,10 @@ public final class WebSocket {
}

/// If set, this will trigger automatic pings on the connection. If ping is not answered before
/// the next ping is sent, then the WebSocket will be presumed innactive and will be closed
/// the next ping is sent, then the WebSocket will be presumed inactive and will be closed
/// automatically.
/// These pings can also be used to keep the WebSocket alive if there is some other timeout
/// mechanism shutting down innactive connections, such as a Load Balancer deployed in
/// mechanism shutting down inactive connections, such as a Load Balancer deployed in
/// front of the server.
public var pingInterval: TimeAmount? {
didSet {
Expand Down Expand Up @@ -233,6 +249,7 @@ public final class WebSocket {
} else {
frameSequence = WebSocketFrameSequence(type: frame.opcode)
}

// append this frame and update the sequence
frameSequence.append(frame)
self.frameSequence = frameSequence
Expand All @@ -252,12 +269,27 @@ public final class WebSocket {

// if this frame was final and we have a non-nil frame sequence,
// output it to the websocket and clear storage
if let frameSequence = self.frameSequence, frame.fin {
if var frameSequence = self.frameSequence, frame.fin {
switch frameSequence.type {
case .binary:
self.onBinaryCallback(self, frameSequence.binaryBuffer)
if decompressor != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if let decompressor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to support swifts older than 5.7.

Other than that, iirc we can't trigger a compiler copy of the decompressor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to support swifts older than 5.7.

Oops, I forgot about that. 😆

Other than that, iirc we can't trigger a compiler copy of the decompressor.

Hmm interesting. Then should Decompressor be a class since it manages a resource?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, we will have ARC overhead, which we currently don't have.
This is also not part of the API, so we don't really need to worry about it being friendly to developers, too much.
Swift NIO also had it as a struct, not class.

do {
var buffer = ByteBuffer()
try decompressor!.decompress(part: &frameSequence.buffer, buffer: &buffer)

self.onBinaryCallback(self, buffer)
} catch {
self.close(code: .protocolError, promise: nil)
return
}
} else {
self.onBinaryCallback(self, frameSequence.buffer)
}
case .text:
self.onTextCallback(self, frameSequence.textBuffer)
if let callback = self.onTextCallback {
callback(self, String(buffer: frameSequence.buffer))
}
self.onTextBufferCallback(self, frameSequence.buffer)
case .ping, .pong:
assertionFailure("Control frames never have a frameSequence")
default: break
Expand Down Expand Up @@ -293,30 +325,25 @@ public final class WebSocket {
}

deinit {
self.decompressor?.deinitializeDecoder()
assert(self.isClosed, "WebSocket was not closed before deinit.")
}
}

private struct WebSocketFrameSequence {
var binaryBuffer: ByteBuffer
var textBuffer: String
var buffer: ByteBuffer
var type: WebSocketOpcode

init(type: WebSocketOpcode) {
self.binaryBuffer = ByteBufferAllocator().buffer(capacity: 0)
self.textBuffer = .init()
self.buffer = ByteBufferAllocator().buffer(capacity: 0)
self.type = type
}

mutating func append(_ frame: WebSocketFrame) {
var data = frame.unmaskedData
switch type {
case .binary:
self.binaryBuffer.writeBuffer(&data)
case .text:
if let string = data.readString(length: data.readableBytes) {
self.textBuffer += string
}
case .binary, .text:
var data = frame.unmaskedData
self.buffer.writeBuffer(&data)
default: break
}
}
Expand Down
22 changes: 19 additions & 3 deletions Sources/WebSocketKit/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,25 @@ public final class WebSocketClient {
/// Settings for a WebSocket client.
public struct Configuration {
    /// Optional TLS configuration; defaults to `nil`.
    public var tlsConfiguration: TLSConfiguration?
    /// Maximum frame size; defaults to `1 << 14`.
    public var maxFrameSize: Int
    /// Decompression settings; `nil` unless supplied through the
    /// three-argument initializer.
    public var decompression: Decompression.Configuration?

    /// Creates a configuration with `decompression` left as `nil`.
    public init(
        tlsConfiguration: TLSConfiguration? = nil,
        maxFrameSize: Int = 1 << 14
    ) {
        self.init(
            tlsConfiguration: tlsConfiguration,
            maxFrameSize: maxFrameSize,
            decompression: nil
        )
    }

    /// Creates a configuration with an explicit `decompression` value.
    public init(
        tlsConfiguration: TLSConfiguration? = nil,
        maxFrameSize: Int = 1 << 14,
        decompression: Decompression.Configuration?
    ) {
        self.decompression = decompression
        self.maxFrameSize = maxFrameSize
        self.tlsConfiguration = tlsConfiguration
    }
}

Expand Down Expand Up @@ -69,6 +81,7 @@ public final class WebSocketClient {
host: host,
path: path,
query: query,
decompression: self.configuration.decompression,
headers: headers,
upgradePromise: upgradePromise
)
Expand All @@ -82,7 +95,11 @@ public final class WebSocketClient {
maxFrameSize: self.configuration.maxFrameSize,
automaticErrorHandling: true,
upgradePipelineHandler: { channel, req in
return WebSocket.client(on: channel, onUpgrade: onUpgrade)
return WebSocket.client(
on: channel,
decompression: self.configuration.decompression,
onUpgrade: onUpgrade
)
}
)

Expand Down Expand Up @@ -130,7 +147,6 @@ public final class WebSocketClient {
}
}


public func syncShutdown() throws {
switch self.eventLoopGroupProvider {
case .shared:
Expand Down
Loading