From a806e4215bb6d97dd22857ac56f509cf8b8736ad Mon Sep 17 00:00:00 2001 From: ksasi Date: Wed, 24 Apr 2019 19:47:37 +0000 Subject: [PATCH 1/2] TensorShape.swift file updated to improve TensorShape printing --- stdlib/public/TensorFlow/TensorShape.swift | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/stdlib/public/TensorFlow/TensorShape.swift b/stdlib/public/TensorFlow/TensorShape.swift index b8a604185bada..a75b1a671f6f0 100644 --- a/stdlib/public/TensorFlow/TensorShape.swift +++ b/stdlib/public/TensorFlow/TensorShape.swift @@ -159,3 +159,10 @@ extension TensorShape : Codable { self.init(dimensions) } } + +extension TensorShape : CustomStringConvertible { + @inlinable + public var description: String { + return "TensorShape(\(dimensions))" + } +} From 04cf7da6118fbdb1566ef9c95c036bc310cff485 Mon Sep 17 00:00:00 2001 From: ksasi Date: Thu, 25 Apr 2019 05:06:55 +0000 Subject: [PATCH 2/2] TensorShape.swift file updated to improve TensorShape printing Added test to test/TensorFlowRuntime/tensor.swift --- stdlib/public/TensorFlow/TensorShape.swift | 5 ++--- test/TensorFlowRuntime/tensor.swift | 5 +++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/stdlib/public/TensorFlow/TensorShape.swift b/stdlib/public/TensorFlow/TensorShape.swift index a75b1a671f6f0..091bc3385137c 100644 --- a/stdlib/public/TensorFlow/TensorShape.swift +++ b/stdlib/public/TensorFlow/TensorShape.swift @@ -161,8 +161,7 @@ extension TensorShape : Codable { } extension TensorShape : CustomStringConvertible { - @inlinable public var description: String { - return "TensorShape(\(dimensions))" - } + return dimensions.description + } } diff --git a/test/TensorFlowRuntime/tensor.swift b/test/TensorFlowRuntime/tensor.swift index 08d63630555cc..408291c4199ef 100644 --- a/test/TensorFlowRuntime/tensor.swift +++ b/test/TensorFlowRuntime/tensor.swift @@ -623,6 +623,11 @@ TensorTests.testAllBackends("SimpleCond") { expectEqual(0, selectValue(true).scalar) } +TensorTests.testAllBackends("TensorShapeDescription") { + expectEqual("[2, 2]", Tensor(ones: [2, 2]).shape.description) + expectEqual("[]", Tensor(1).shape.description) +} + @inline(never) func testXORInference() { func xor(_ x: Float, _ y: Float) -> Float {