Skip to content

Commit ccfc0f5

Browse files
authored
Merge pull request #21944 from rxwei/broadcasting-assignment
[TF] [stdlib] Add `.=` for broadcasting assignment.
2 parents 580d9d4 + 76482c7 commit ccfc0f5

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

stdlib/public/TensorFlow/Ops.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ infix operator .>= : ComparisonPrecedence
4646
infix operator .> : ComparisonPrecedence
4747
infix operator .== : ComparisonPrecedence
4848
infix operator .!= : ComparisonPrecedence
49+
infix operator .=
4950

5051
// TODO:
5152
// - Consider explicit broadcasting for elementwise binary ops when
@@ -1440,6 +1441,11 @@ public extension Tensor where Scalar : Numeric {
14401441
func unbroadcast(to shape: TensorShape) -> Tensor {
14411442
return unbroadcast(toShape: Tensor<Int32>(shape.dimensions))
14421443
}
1444+
1445+
@inlinable @inline(__always)
1446+
static func .= (lhs: inout Tensor, rhs: Tensor) {
1447+
lhs = rhs.broadcast(like: lhs)
1448+
}
14431449
}
14441450

14451451
//===----------------------------------------------------------------------===//

test/TensorFlowRuntime/tensor.swift

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -466,18 +466,6 @@ TensorTests.testAllBackends("ReshapeTensor") {
466466
expectEqual([1, 3, 1, 2, 1], result.shape)
467467
}
468468

469-
// FIXME: This test crashes in dynamic compilation + GPU.
470-
#if !CUDA
471-
TensorTests.testAllBackends("BroadcastTensor") {
472-
// 1 -> 2 x 3 x 4
473-
let one = Tensor<Float>(1)
474-
let target = Tensor<Float>(shape: [2, 3, 4], repeating: 0.0)
475-
let broadcasted = one.broadcast(like: target)
476-
expectEqual([2, 3, 4], broadcasted.shape)
477-
expectEqual(Array(repeating: 1, count: 24), broadcasted.scalars)
478-
}
479-
#endif // !CUDA
480-
481469
TensorTests.testAllBackends("Unbroadcast1") {
482470
let x = Tensor<Float>(shape: [2, 3, 4, 5], repeating: 1)
483471
let y = Tensor<Float>(shape: [4, 5], repeating: 1)

test/TensorFlowRuntime/tensor_api.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,14 @@ TensorNonTPUTests.testAllBackends("SliceUpdate") {
3030
expectEqual(ShapedArray(shape:[2, 3], repeating: false), t4.array)
3131
}
3232

33+
TensorNonTPUTests.testAllBackends("BroadcastTensor") {
34+
// 1 -> 2 x 3 x 4
35+
let one = Tensor<Float>(1)
36+
var target = Tensor<Float>(shape: [2, 3, 4], repeating: 0.0)
37+
let broadcasted = one.broadcast(like: target)
38+
expectEqual(Tensor(shape: [2, 3, 4], repeating: 1), broadcasted)
39+
target .= Tensor(shape: [1, 3, 1], repeating: 1)
40+
expectEqual(Tensor(shape: [2, 3, 4], repeating: 1), target)
41+
}
42+
3343
runAllTests()

0 commit comments

Comments
 (0)