Skip to content

Commit

Permalink
Replace incorrect uses of unsafe task with task cancellation handler
Browse files Browse the repository at this point in the history
  • Loading branch information
kdubb committed Jul 10, 2023
1 parent dae87c8 commit f7b8291
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 51 deletions.
33 changes: 22 additions & 11 deletions Package.resolved
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
{
"pins" : [
{
"identity" : "swift-docc-plugin",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-docc-plugin",
"state" : {
"revision" : "3303b164430d9a7055ba484c8ead67a52f7b74f6",
"version" : "1.0.0"
"object": {
"pins": [
{
"package": "swift-atomics",
"repositoryURL": "https://github.com/apple/swift-atomics.git",
"state": {
"branch": null,
"revision": "6c89474e62719ddcc1e9614989fff2f68208fe10",
"version": "1.1.0"
}
},
{
"package": "SwiftDocCPlugin",
"repositoryURL": "https://github.com/apple/swift-docc-plugin",
"state": {
"branch": null,
"revision": "3303b164430d9a7055ba484c8ead67a52f7b74f6",
"version": "1.0.0"
}
}
}
],
"version" : 2
]
},
"version": 1
}
7 changes: 6 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@ let package = Package(
name: "IOStreams",
targets: ["IOStreams"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-atomics.git", .upToNextMinor(from: "1.1.0"))
],
targets: [
.target(
name: "IOStreams",
dependencies: []),
dependencies: [
.product(name: "Atomics", package: "swift-atomics")
]),
.testTarget(
name: "IOStreamsTests",
dependencies: ["IOStreams"]),
Expand Down
69 changes: 38 additions & 31 deletions Sources/IOStreams/FileStreams.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

import Atomics
import Darwin
import Foundation

Expand Down Expand Up @@ -47,52 +48,45 @@ public class FileSource: FileStream, Source {
public func read(max: Int) async throws -> Data? {
guard let dispatchIO = dispatchIO, !closedState.closed else { throw IOError.streamClosed }

let data: Data? = try await withCheckedThrowingContinuation { continuation in
withUnsafeCurrentTask { task in
let data: Data? = try await withTaskCancellationHandler {

try await withCheckedThrowingContinuation { continuation in

var collectedData = Data()

dispatchIO.read(offset: 0, length: max, queue: .taskPriority) { done, data, error in

if task?.isCancelled ?? false {
continuation.resume(throwing: CancellationError())
return
}
if error == ECANCELED {

if error != 0 {
return continuation.resume(throwing: CancellationError())
}

let errorCode = POSIXError.Code(rawValue: error) ?? .EIO
if let data = data, !data.isEmpty {

continuation.resume(throwing: POSIXError(errorCode))
collectedData.append(Data(data))
}
else if let data = data, !data.isEmpty {

collectedData.append(Data(data))
if error != 0 {

if done {
continuation.resume(returning: collectedData)
}
return continuation.resume(throwing: POSIXError(POSIXError.Code(rawValue: error) ?? .EIO))
}
else if done {
// error is 0, data is empty, and done is true.. flags EOF

if collectedData.isEmpty {
// Signal EOF to caller
continuation.resume(returning: nil)
}
else {
// Return the collected data... EOF will be signaled on next read
continuation.resume(returning: collectedData)
}
return continuation.resume(returning: collectedData.isEmpty ? nil : collectedData)
}
}

}

} onCancel: {
cancel()
}

bytesRead += data?.count ?? 0

return data
}

}


Expand Down Expand Up @@ -126,17 +120,17 @@ public class FileSink: FileStream, Sink {
public func write(data: Data) async throws {
guard let dispatchIO = dispatchIO, !closedState.closed else { throw IOError.streamClosed }

try await withCheckedThrowingContinuation { continuation in
try await withTaskCancellationHandler {

withUnsafeCurrentTask { task in
try await withCheckedThrowingContinuation { continuation in

data.withUnsafeBytes { dataPtr in

let data = DispatchData(bytes: dataPtr)

dispatchIO.write(offset: 0, data: data, queue: .taskPriority) { done, _, error in

if task?.isCancelled ?? false {
if error == ECANCELED {
continuation.resume(throwing: CancellationError())
return
}
Expand All @@ -147,19 +141,19 @@ public class FileSink: FileStream, Sink {

if error != 0 {

let errorCode = POSIXError.Code(rawValue: error) ?? .EIO

continuation.resume(throwing: POSIXError(errorCode))
continuation.resume(throwing: POSIXError(POSIXError.Code(rawValue: error) ?? .EIO))
}
else {
continuation.resume()
}
}

}
} as Void

}
} as Void
} onCancel: {
cancel()
}

bytesWritten += Int(data.count)
}
Expand Down Expand Up @@ -196,6 +190,19 @@ public class FileStream: Stream {

self.fileHandle = fileHandle

reset()
}

fileprivate func cancel() {
// Cancel current dispatches
dispatchIO?.close(flags: .stop)
dispatchIO = nil

reset()
}

fileprivate func reset() {

let dispatchIO =
DispatchIO(type: .stream, fileDescriptor: fileHandle.fileDescriptor, queue: .taskPriority) { error in
let closeError: Error?
Expand Down
150 changes: 142 additions & 8 deletions Tests/IOStreamsTests/FileStreamsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ final class FileStreamsTests: XCTestCase {
let fileHandle = try FileHandle(forUpdating: fileURL)
try fileHandle.truncate(atOffset: UInt64(fileSize))
try fileHandle.seek(toOffset: 0)
try fileHandle.close()

let source = try FileSource(fileHandle: fileHandle)
let source = try FileSource(url: fileURL)

for try await _ in source.buffers() {
// read all buffers to test bytesRead
Expand All @@ -55,12 +56,12 @@ final class FileStreamsTests: XCTestCase {
let fileHandle = try FileHandle(forUpdating: fileURL)
try fileHandle.truncate(atOffset: UInt64(fileSize))
try fileHandle.seek(toOffset: 0)
try fileHandle.close()

let source = try FileSource(url: fileURL)

let reader = Task {
for try await data in source.buffers(size: 3079) {
_ = data.count
for try await _ /* data */ in source.buffers(size: 3079) {
// print("Read \(data.count) bytes of data")
}
}
Expand All @@ -80,12 +81,105 @@ final class FileStreamsTests: XCTestCase {
let fileHandle = try FileHandle(forUpdating: fileURL)
try fileHandle.truncate(atOffset: UInt64(fileSize))
try fileHandle.seek(toOffset: 0)
try fileHandle.close()

let source = try FileSource(fileHandle: fileHandle)
let source = try FileSource(url: fileURL)

let reader = Task {
for try await _ in source.buffers(size: 133) {
// read all buffers to test bytesRead
withUnsafeCurrentTask { $0!.cancel() }
}
}

do {
try await reader.value
XCTFail("Expected cancellation error")
}
catch is CancellationError {
// expected
}
catch {
XCTFail("Unexpected error thrown: \(error.localizedDescription)")
}

XCTAssert(source.bytesRead > 0, "Data should have been read from source")
XCTAssert(source.bytesRead < fileSize, "Source should have cancelled iteration")
}

func testSourceContinuesAfterCancel() async throws {

let fileSize = 256 * 1024
let fileHandle = try FileHandle(forUpdating: fileURL)
try fileHandle.truncate(atOffset: UInt64(fileSize))
try fileHandle.seek(toOffset: 0)
try fileHandle.close()

let source = try FileSource(url: fileURL)

let reader = Task {
for try await _ /* data */ in source.buffers(size: 3079) {
// print("Read \(data.count) bytes of data")
}
}

do {
reader.cancel()
try await reader.value
XCTFail("Expected cancellation error")
}
catch is CancellationError {
// expected
}
catch {
XCTFail("Unexpected error thrown: \(error.localizedDescription)")
}

XCTAssertEqual(source.bytesRead, 0)

do {
_ = try await source.read(exactly: 1000)
}
catch {
XCTFail("Unexpected error thrown: \(error.localizedDescription)")
}

XCTAssertEqual(source.bytesRead, 1000)
}

func testSinkCancels() async throws {

let source = DataSource(data: Data(count: 1024 * 1024))
let sink = try FileSink(url: fileURL)

let reader = Task {
for try await buffer in source.buffers() {
try await sink.write(data: buffer)
}
}

do {
reader.cancel()
try await reader.value
XCTFail("Expected cancellation error")
}
catch is CancellationError {
// expected
}
catch {
XCTFail("Unexpected error thrown: \(error.localizedDescription)")
}

XCTAssertEqual(sink.bytesWritten, 0)
}

func testSinkCancelsAfterStart() async throws {

let source = DataSource(data: Data(count: 1024 * 1024))
let sink = try FileSink(url: fileURL)

let reader = Task {
for try await buffer in source.buffers(size: 113) {
try await sink.write(data: buffer)
}
}

Expand All @@ -94,11 +188,51 @@ final class FileStreamsTests: XCTestCase {
do {
reader.cancel()
try await reader.value
XCTFail("Expected cancellation error")
}
catch is CancellationError {
// expected
}
catch {
XCTFail("Unexpected error thrown: \(error.localizedDescription)")
}
catch is CancellationError {}

XCTAssert(source.bytesRead > 0, "Data should have been read from source")
XCTAssert(source.bytesRead < fileSize, "Source should have cancelled iteration")
XCTAssert(sink.bytesWritten > 0, "Data should have been written to sink")
XCTAssert(sink.bytesWritten < source.data.count, "Sink should have cancelled iteration")
}

func testSinkContinuesAfterCancel() async throws {

let source = DataSource(data: Data(count: 1024 * 1024))
let sink = try FileSink(url: fileURL)

let reader = Task {
for try await buffer in source.buffers(size: 100) {
try await sink.write(data: buffer)
}
}

do {
reader.cancel()
try await reader.value
XCTFail("Expected cancellation error")
}
catch is CancellationError {
// expected
}
catch {
XCTFail("Unexpected error thrown: \(error.localizedDescription)")
}

XCTAssertEqual(sink.bytesWritten, 0)
XCTAssertEqual(try fileURL.resourceValues(forKeys: [.fileSizeKey]).fileSize, 0)

try await sink.write(data: Data(count: 1000))
try sink.close()

XCTAssertEqual(sink.bytesWritten, 1000)
fileURL.removeAllCachedResourceValues()
XCTAssertEqual(try fileURL.resourceValues(forKeys: [.fileSizeKey]).fileSize, 1000)
}

func testInvalidFileSourceThrows() async throws {
Expand Down

0 comments on commit f7b8291

Please sign in to comment.