Skip to content
This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

Commit

Permalink
Refine tiled (#689)
Browse files Browse the repository at this point in the history
Add tiled API and tests with refined preconditions.
  • Loading branch information
t-ae committed Feb 21, 2020
1 parent 62e5802 commit 84e0da0
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
19 changes: 19 additions & 0 deletions Sources/TensorFlow/Operators/Basic.swift
Expand Up @@ -132,6 +132,25 @@ public extension Tensor {
splitDim: Tensor<Int32>(Int32(axis)),
numSplit: Int64(sizes.shape[0]))
}

/// Returns a tiled tensor, constructed by tiling this tensor.
///
/// This constructor creates a new tensor by replicating this tensor `multiples` times. The
/// constructed tensor's `i`'th dimension has `self.shape[i] * multiples[i]` elements, and the
/// values of this tensor are replicated `multiples[i]` times along the `i`'th dimension. For
/// example, tiling `[a b c d]` by `[2]` produces `[a b c d a b c d]`.
///
/// - Precondition: The expected `rank` of multiples must be `1`.
/// - Precondition: The shape of `multiples` must be `[tensor.rank]`.
/// - Precondition: All scalars in `multiples` must be non-negative.
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func tiled(multiples: [Int]) -> Tensor {
precondition(multiples.allSatisfy { $0 >= 0 },
"All scalars in multiples must be non-negative.")
// TODO(TF-433): Remove workaround for differentiating `map`.
return tiled(multiples: Tensor<Int32>({multiples.map(Int32.init)}()))
}

/// Returns a tiled tensor, constructed by tiling this tensor.
///
Expand Down
11 changes: 11 additions & 0 deletions Tests/TensorFlowTests/OperatorTests/BasicTests.swift
Expand Up @@ -475,6 +475,16 @@ final class BasicOperatorTests: XCTestCase {
XCTAssertEqual(grad.shape, [3, 2, 1])
XCTAssertEqual(grad.scalars, [1, 1, 1, 1, 1, 1])
}

func testTile() {
let tensor = Tensor<Int32>([[0, 1, 2], [3, 4, 5]])
let tiled = tensor.tiled(multiples: [3, 2])

XCTAssertEqual(tiled.shape, [6, 6])
XCTAssertEqual(tiled, [[0, 1, 2, 0, 1, 2], [3, 4, 5, 3, 4, 5],
[0, 1, 2, 0, 1, 2], [3, 4, 5, 3, 4, 5],
[0, 1, 2, 0, 1, 2], [3, 4, 5, 3, 4, 5]])
}

func testReshape() {
// 2 x 3 -> 1 x 3 x 1 x 2 x 1
Expand Down Expand Up @@ -692,6 +702,7 @@ final class BasicOperatorTests: XCTestCase {
("testConcatenation", testConcatenation),
("testVJPConcatenation", testVJPConcatenation),
("testTranspose", testTranspose),
("testTile", testTile),
("testReshape", testReshape),
("testFlatten", testFlatten),
("testFlatten0D", testFlatten0D),
Expand Down
10 changes: 10 additions & 0 deletions Tests/TensorFlowTests/TensorAutoDiffTests.swift
Expand Up @@ -482,6 +482,15 @@ final class TensorAutoDiffTests: XCTestCase {
XCTAssertEqual(pullback(at: [[3, 5]], in: f1)([1, 1]), [[6, 10]])
XCTAssertEqual(pullback(at: [[3, 5]], in: f2)([1, 1]), [[6, 10]])
}

func testTiled() {
let input = Tensor<Float>([[1, 2, 3], [4, 5, 6]])
let tiledPullback = pullback(at: input) { (a: Tensor<Float>) in
a.tiled(multiples: [2, 1])
}
let tiled = Tensor<Float>([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
XCTAssertEqual(input * 2, tiledPullback(tiled))
}

func testReshapedBackprop() {
func f1(a: Tensor<Float>) -> Tensor<Float> { a.reshaped(toShape: Tensor<Int32>([2, 1])).squared() }
Expand Down Expand Up @@ -795,6 +804,7 @@ final class TensorAutoDiffTests: XCTestCase {
("testTensorInitStacking", testTensorInitStacking),
("testExpandingShape", testExpandingShape),
("testSqueezingShape", testSqueezingShape),
("testTiled", testTiled),
("testReshapedBackprop", testReshapedBackprop),
("testReshaped", testReshaped),
("testConcatenationPlusPlus", testConcatenationPlusPlus),
Expand Down

0 comments on commit 84e0da0

Please sign in to comment.