Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/source/using-executorch-ios.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ let inputTensor = Tensor<Float>(&imageBuffer, shape: [1, 3, 224, 224])
let outputTensor: Tensor<Float> = try module.forward(inputTensor)[0].tensor()!

// Copy the tensor data into logits array for easier access.
let logits = try outputTensor.scalars()
let logits = outputTensor.scalars()

// Use logits...
```
Expand Down Expand Up @@ -444,11 +444,11 @@ Swift:
let tensor = Tensor<Float>([1.0, 2.0, 3.0, 4.0], shape: [2, 2])

// Get data copy as a Swift array.
let scalars = try tensor.scalars()
let scalars = tensor.scalars()
print("All scalars: \(scalars)") // [1.0, 2.0, 3.0, 4.0]

// Access data via a buffer pointer.
try tensor.withUnsafeBytes { buffer in
tensor.withUnsafeBytes { buffer in
print("First float element: \(buffer.first ?? 0.0)")
}

Expand Down Expand Up @@ -482,7 +482,7 @@ Swift:
let tensor = Tensor<Float>([1.0, 2.0, 3.0, 4.0], shape: [2, 2])

// Modify the tensor's data in place.
try tensor.withUnsafeMutableBytes { buffer in
tensor.withUnsafeMutableBytes { buffer in
buffer[1] = 200.0
}
// tensor's data is now [1.0, 200.0, 3.0, 4.0]
Expand Down
39 changes: 16 additions & 23 deletions extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -770,35 +770,29 @@ public final class Tensor<T: Scalar>: Equatable {
/// - Parameter body: A closure that receives an `UnsafeBufferPointer<T>` bound to the tensor’s data.
/// - Returns: The value returned by `body`.
/// - Throws: Any error thrown by `body`.
public func withUnsafeBytes<R>(_ body: (UnsafeBufferPointer<T>) throws -> R) throws -> R {
var result: Result<R, Error>?
anyTensor.bytes { pointer, count, _ in
result = Result { try body(
UnsafeBufferPointer(
start: pointer.assumingMemoryBound(to: T.self),
count: count
)
) }
public func withUnsafeBytes<R>(_ body: (UnsafeBufferPointer<T>) throws -> R) rethrows -> R {
try withoutActuallyEscaping(body) { body in
var result: Result<R, Error>?
anyTensor.bytes { pointer, count, _ in
result = Result { try body(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: T.self), count: count)) }
}
return try result!.get()
}
return try result!.get()
}

/// Calls the closure with a typed, mutable buffer pointer over the tensor’s elements.
///
/// - Parameter body: A closure that receives an `UnsafeMutableBufferPointer<T>` bound to the tensor’s data.
/// - Returns: The value returned by `body`.
/// - Throws: Any error thrown by `body`.
public func withUnsafeMutableBytes<R>(_ body: (UnsafeMutableBufferPointer<T>) throws -> R) throws -> R {
var result: Result<R, Error>?
anyTensor.mutableBytes { pointer, count, _ in
result = Result { try body(
UnsafeMutableBufferPointer(
start: pointer.assumingMemoryBound(to: T.self),
count: count
)
) }
public func withUnsafeMutableBytes<R>(_ body: (UnsafeMutableBufferPointer<T>) throws -> R) rethrows -> R {
try withoutActuallyEscaping(body) { body in
var result: Result<R, Error>?
anyTensor.mutableBytes { pointer, count, _ in
result = Result { try body(UnsafeMutableBufferPointer(start: pointer.assumingMemoryBound(to: T.self), count: count)) }
}
return try result!.get()
}
return try result!.get()
}

/// Resizes the tensor to a new shape.
Expand Down Expand Up @@ -830,9 +824,8 @@ public extension Tensor {
/// Returns the tensor's elements as an array of scalars.
///
/// - Returns: An array of scalars of type `T`.
/// - Throws: An error if the underlying data cannot be accessed.
func scalars() throws -> [T] {
try withUnsafeBytes { Array($0) }
func scalars() -> [T] {
withUnsafeBytes { Array($0) }
}
}

Expand Down
Loading
Loading