Skip to content

Commit

Permalink
Fix cancellation for URLSessionSource
Browse files Browse the repository at this point in the history
  • Loading branch information
kdubb committed Jul 10, 2023
1 parent f7b8291 commit c01c385
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 16 deletions.
68 changes: 61 additions & 7 deletions Sources/IOStreams/URLSessionStreams.swift
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

import Atomics
import Foundation

@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
Expand All @@ -36,6 +37,7 @@ public class URLSessionSource: Source {
public private(set) var bytesRead = 0

private var iterator: Stream.AsyncIterator?
private var availableData: Data? = Data()

public convenience init(url: URL, session: URLSession = .shared) {
self.init(request: URLRequest(url: url), session: session)
Expand Down Expand Up @@ -65,9 +67,32 @@ public class URLSessionSource: Source {
public func read(max: Int) async throws -> Data? {
guard iterator != nil else { throw IOError.streamClosed }

let next = try await iterator?.next()
guard let availableData = availableData else {

bytesRead += next?.count ?? 0
// iterator done, we're done

return nil
}

// Honor cancellation before any work
try Task.checkCancellation()

guard !availableData.isEmpty else {

// no data to return, grab some more and try again

self.availableData = try await iterator?.next()

return try await read(max: max)
}

// Since we cannot control how much data the URL session task provides
// in a single callback, we ensure this function honors the `max` parameter.

let next = availableData.prefix(max)
self.availableData = availableData.dropFirst(next.count)

bytesRead += next.count

return next
}
Expand All @@ -78,14 +103,35 @@ public class URLSessionSource: Source {

private final class DataTaskDelegate: NSObject, URLSessionDataDelegate {

let continuation: Stream.Continuation
var continuation: Stream.Continuation?

init(continuation: Stream.Continuation) {
self.continuation = continuation
}

func finish(throwing error: Error? = nil) {
self.continuation?.finish(throwing: error)
self.continuation = nil
}

func checkCancel(task: URLSessionTask) -> Bool {
if task.state == .canceling {
finish(throwing: CancellationError())
return false
}
return true
}

public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
continuation.finish(throwing: error)
var error = error

// URLSessionTask is hidden so canceellation comes
// from Task cancellation so we normalize errors.
if let urlError = error as? URLError, urlError.code == .cancelled {
error = CancellationError()
}

finish(throwing: error)
}

public func urlSession(
Expand All @@ -94,15 +140,19 @@ public class URLSessionSource: Source {
didReceive response: URLResponse,
completionHandler: @escaping (URLSession.ResponseDisposition) -> Void
) {
guard checkCancel(task: dataTask) else {
completionHandler(.cancel)
return
}

guard let httpResponse = response as? HTTPURLResponse else {
continuation.finish(throwing: HTTPError.invalidResponse)
finish(throwing: HTTPError.invalidResponse)
completionHandler(.cancel)
return
}

if 300 ..< 600 ~= httpResponse.statusCode {
continuation.finish(throwing: HTTPError.invalidStatus)
finish(throwing: HTTPError.invalidStatus)
completionHandler(.cancel)
return
}
Expand All @@ -111,7 +161,11 @@ public class URLSessionSource: Source {
}

public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
continuation.yield(data)
guard checkCancel(task: dataTask) else {
return
}

continuation?.yield(data)
}

}
Expand Down
34 changes: 25 additions & 9 deletions Tests/IOStreamsTests/URLSessionStreamTests.swift
Expand Up @@ -24,8 +24,8 @@ final class URLSessionStreamsTests: XCTestCase {

let source = URL(string: "https://github.com")!.source()

for try await buffer in source.buffers() {
print("### Received \(buffer.count) bytes")
for try await _ /* data */ in source.buffers() {
// print("### Received \(buffer.count) bytes")
}

XCTAssertGreaterThan(source.bytesRead, 50 * 1024)
Expand All @@ -36,16 +36,22 @@ final class URLSessionStreamsTests: XCTestCase {
let source = URL(string: "https://github.com")!.source()

let reader = Task {
for try await buffer in source.buffers(size: 3079) {
print("### Received \(buffer.count) bytes")
for try await _ /* data */ in source.buffers(size: 3079) {
// print("### Received \(buffer.count) bytes")
}
}

do {
reader.cancel()
try await reader.value
XCTFail("Expected cancellation error")
}
catch is CancellationError {
// expected
}
catch {
XCTFail("Unexpected error thrown: \(error.localizedDescription)")
}
catch is CancellationError {}

XCTAssertEqual(source.bytesRead, 0)
}
Expand All @@ -54,12 +60,22 @@ final class URLSessionStreamsTests: XCTestCase {

let source = URL(string: "https://github.com")!.source()

let task = Task {
for try await _ in source.buffers(size: 1024) {
withUnsafeCurrentTask { $0?.cancel() }
let reader = Task {
for try await _ in source.buffers(size: 133) {
withUnsafeCurrentTask { $0!.cancel() }
}
}
try await task.value

do {
try await reader.value
XCTFail("Expected cancellation error")
}
catch is CancellationError {
// expected
}
catch {
XCTFail("Unexpected error thrown: \(error.localizedDescription)")
}

XCTAssert(source.bytesRead > 0, "Data should have been read from source")
XCTAssert(source.bytesRead < 50 * 1024, "Source should have cancelled iteration")
Expand Down

0 comments on commit c01c385

Please sign in to comment.