From c01c385ae80ef1f5f0ecdb73af6871084b88ab5f Mon Sep 17 00:00:00 2001 From: Kevin Wooten Date: Sun, 9 Jul 2023 09:06:09 -0700 Subject: [PATCH] Fix cancellation for URLSessionSource --- Sources/IOStreams/URLSessionStreams.swift | 68 +++++++++++++++++-- .../URLSessionStreamTests.swift | 34 +++++++--- 2 files changed, 86 insertions(+), 16 deletions(-) diff --git a/Sources/IOStreams/URLSessionStreams.swift b/Sources/IOStreams/URLSessionStreams.swift index 6513a30..378deaa 100644 --- a/Sources/IOStreams/URLSessionStreams.swift +++ b/Sources/IOStreams/URLSessionStreams.swift @@ -14,6 +14,7 @@ * limitations under the License. */ +import Atomics import Foundation @available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) @@ -36,6 +37,7 @@ public class URLSessionSource: Source { public private(set) var bytesRead = 0 private var iterator: Stream.AsyncIterator? + private var availableData: Data? = Data() public convenience init(url: URL, session: URLSession = .shared) { self.init(request: URLRequest(url: url), session: session) @@ -65,9 +67,32 @@ public class URLSessionSource: Source { public func read(max: Int) async throws -> Data? { guard iterator != nil else { throw IOError.streamClosed } - let next = try await iterator?.next() + guard let availableData = availableData else { - bytesRead += next?.count ?? 0 + // iterator done, we're done + + return nil + } + + // Honor cancellation before any work + try Task.checkCancellation() + + guard !availableData.isEmpty else { + + // no data to return, grab some more and try again + + self.availableData = try await iterator?.next() + + return try await read(max: max) + } + + // Since we cannot control how much data the URL session task provides + // in a single callback, we ensure this function honors the `max` parameter. + + let next = availableData.prefix(max) + self.availableData = availableData.dropFirst(next.count) + + bytesRead += next.count return next } @@ -78,14 +103,35 @@ public class URLSessionSource: Source { private final class DataTaskDelegate: NSObject, URLSessionDataDelegate { - let continuation: Stream.Continuation + var continuation: Stream.Continuation? init(continuation: Stream.Continuation) { self.continuation = continuation } + func finish(throwing error: Error? = nil) { + self.continuation?.finish(throwing: error) + self.continuation = nil + } + + func checkCancel(task: URLSessionTask) -> Bool { + if task.state == .canceling { + finish(throwing: CancellationError()) + return false + } + return true + } + public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { - continuation.finish(throwing: error) + var error = error + + // URLSessionTask is hidden so canceellation comes + // from Task cancellation so we normalize errors. + if let urlError = error as? URLError, urlError.code == .cancelled { + error = CancellationError() + } + + finish(throwing: error) } public func urlSession( @@ -94,15 +140,19 @@ public class URLSessionSource: Source { didReceive response: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void ) { + guard checkCancel(task: dataTask) else { + completionHandler(.cancel) + return + } guard let httpResponse = response as? HTTPURLResponse else { - continuation.finish(throwing: HTTPError.invalidResponse) + finish(throwing: HTTPError.invalidResponse) completionHandler(.cancel) return } if 300 ..< 600 ~= httpResponse.statusCode { - continuation.finish(throwing: HTTPError.invalidStatus) + finish(throwing: HTTPError.invalidStatus) completionHandler(.cancel) return } @@ -111,7 +161,11 @@ public class URLSessionSource: Source { } public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { - continuation.yield(data) + guard checkCancel(task: dataTask) else { + return + } + + continuation?.yield(data) } } diff --git a/Tests/IOStreamsTests/URLSessionStreamTests.swift b/Tests/IOStreamsTests/URLSessionStreamTests.swift index 41af61c..0fa3234 100644 --- a/Tests/IOStreamsTests/URLSessionStreamTests.swift +++ b/Tests/IOStreamsTests/URLSessionStreamTests.swift @@ -24,8 +24,8 @@ final class URLSessionStreamsTests: XCTestCase { let source = URL(string: "https://github.com")!.source() - for try await buffer in source.buffers() { - print("### Received \(buffer.count) bytes") + for try await _ /* data */ in source.buffers() { + // print("### Received \(buffer.count) bytes") } XCTAssertGreaterThan(source.bytesRead, 50 * 1024) @@ -36,16 +36,22 @@ final class URLSessionStreamsTests: XCTestCase { let source = URL(string: "https://github.com")!.source() let reader = Task { - for try await buffer in source.buffers(size: 3079) { - print("### Received \(buffer.count) bytes") + for try await _ /* data */ in source.buffers(size: 3079) { + // print("### Received \(buffer.count) bytes") } } do { reader.cancel() try await reader.value + XCTFail("Expected cancellation error") + } + catch is CancellationError { + // expected + } + catch { + XCTFail("Unexpected error thrown: \(error.localizedDescription)") } - catch is CancellationError {} XCTAssertEqual(source.bytesRead, 0) } @@ -54,12 +60,22 @@ final class URLSessionStreamsTests: XCTestCase { let source = URL(string: "https://github.com")!.source() - let task = Task { - for try await _ in source.buffers(size: 1024) { - withUnsafeCurrentTask { $0?.cancel() } + let reader = Task { + for try await _ in source.buffers(size: 133) { + withUnsafeCurrentTask { $0!.cancel() } } } - try await task.value + + do { + try await reader.value + XCTFail("Expected cancellation error") + } + catch is CancellationError { + // expected + } + catch { + XCTFail("Unexpected error thrown: \(error.localizedDescription)") + } XCTAssert(source.bytesRead > 0, "Data should have been read from source") XCTAssert(source.bytesRead < 50 * 1024, "Source should have cancelled iteration")