diff --git a/stdlib/public/TensorFlow/TensorShape.swift b/stdlib/public/TensorFlow/TensorShape.swift index b8a604185bada..091bc3385137c 100644 --- a/stdlib/public/TensorFlow/TensorShape.swift +++ b/stdlib/public/TensorFlow/TensorShape.swift @@ -159,3 +159,9 @@ extension TensorShape : Codable { self.init(dimensions) } } + +extension TensorShape : CustomStringConvertible { + public var description: String { + return dimensions.description + } +} diff --git a/test/TensorFlowRuntime/tensor.swift b/test/TensorFlowRuntime/tensor.swift index 08d63630555cc..408291c4199ef 100644 --- a/test/TensorFlowRuntime/tensor.swift +++ b/test/TensorFlowRuntime/tensor.swift @@ -623,6 +623,11 @@ TensorTests.testAllBackends("SimpleCond") { expectEqual(0, selectValue(true).scalar) } +TensorTests.testAllBackends("TensorShapeDescription") { + expectEqual("[2, 2]", Tensor(ones: [2, 2]).shape.description) + expectEqual("[]", Tensor(1).shape.description) +} + @inline(never) func testXORInference() { func xor(_ x: Float, _ y: Float) -> Float {