Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 43 additions & 50 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2162,6 +2162,9 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
// See the comment on the declaration of the enum case
// `GraphOperationInfo::ArgumentLowering::TensorAttribute` for more
// information on why this cannot be dynamic.
// TODO(SR-9166): Emit a nice diagnostic earlier in the compiler, so that
// we don't get an assertion failure when we write a
// #tfop(..., value$tensor: ...).
assert(0 && "TensorAttributes cannot be dynamic");
case GraphOperationInfo::ArgumentLowering::Input: {
assert(argumentName.empty() && "inputs cannot have names");
Expand Down Expand Up @@ -2796,8 +2799,6 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
// compute the flag by iterating over all the results and checking if they are
// known TensorFlow values that do not need to go through the TensorGroup
// machinery.
// TODO: We can add more types of known TensorFlow values to reduce the amount
// of work that happens at runtime.
bool hasOpaqueTensorGroupResults = false;
if (outParameterAddress.isValid()) {
LLVM_DEBUG(llvm::dbgs() << " Has indirect result of type "
Expand All @@ -2809,9 +2810,7 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
for (auto silResult : i->getResults()) {
LLVM_DEBUG(llvm::dbgs() << " Direct result of type "
<< silResult->getType() << ".\n");
if (silResult->getType().getASTType() == astCtx.TheEmptyTupleType)
continue;
if (tf::isTensorFlowValue(silResult->getType()))
if (tf::isTensorFlowValueOrAggregate(silResult->getType().getASTType()))
continue;
hasOpaqueTensorGroupResults = true;
break;
Expand Down Expand Up @@ -2847,23 +2846,16 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
// add up their counts.
expectedReturnValueCount = llvm::ConstantInt::get(IGM.Int32Ty, 0);
for (auto silResult : i->getResults()) {
// If the result is Void, it corresponds to 0 outputs.
if (silResult->getType().getASTType() == astCtx.TheEmptyTupleType) {
directResultTypeMetadatas.push_back(nullptr);
directResultTensorGroupWitnessTables.push_back(nullptr);
auto *cTensorHandleCount = llvm::ConstantInt::get(IGM.Int32Ty, 0);
directResultCTensorHandleCounts.push_back(cTensorHandleCount);
continue;
}

// If the result is a known TensorFlow type, it corresponds to just 1
// output.
if (tf::isTensorFlowValue(silResult->getType())) {
// If the result is a known TensorFlow type or aggregate of known
// TensorFlow types, then we can count it directly.
SmallVector<Type, 4> flattenedTensorFlowTypes;
if (tf::flattenTensorFlowValueAggregate(
silResult->getType().getASTType(), flattenedTensorFlowTypes)) {
directResultTypeMetadatas.push_back(nullptr);
directResultTensorGroupWitnessTables.push_back(nullptr);
auto *cTensorHandleCount = llvm::ConstantInt::get(IGM.Int32Ty, 1);
auto *cTensorHandleCount = llvm::ConstantInt::get(
IGM.Int32Ty, flattenedTensorFlowTypes.size());
directResultCTensorHandleCounts.push_back(cTensorHandleCount);
// Note that `CreateAdd` constant-folds ConstantInts.
expectedReturnValueCount = Builder.CreateAdd(expectedReturnValueCount,
cTensorHandleCount);
continue;
Expand Down Expand Up @@ -2983,42 +2975,43 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
auto tfOutputAddr = Builder.CreateInBoundsGEP(returnValuesAddress,
tfOutputIdx);

// If the result is Void, it corresponds to 0 outputs.
if (silResult->getType().getASTType() == astCtx.TheEmptyTupleType) {
// If the result is a known TensorFlow type or aggregate of known
// TensorFlow types, get them directly.
SmallVector<Type, 4> flattenedTensorFlowTypes;
if (tf::flattenTensorFlowValueAggregate(
silResult->getType().getASTType(), flattenedTensorFlowTypes)) {
Explosion e;
for (auto tensorFlowType : flattenedTensorFlowTypes) {
auto cTensorHandle = Builder.CreateLoad(tfOutputAddr,
IGM.getPointerAlignment());

// Wrap `cTensorHandle` into a _AnyTensorHandle object, and get an
// untyped pointer to the _AnyTensorHandle.
auto *createHandleFn = IGM.getTFC_CreateTensorHandleFromCFn();
llvm::Value *tensorHandle = Builder.CreateCall(createHandleFn,
{cTensorHandle});

// Cast to a pointer of the expected result type (for example,
// TensorHandle<Float>).
auto silType = SILType::getPrimitiveObjectType(
tensorFlowType->getCanonicalType());
tensorHandle = Builder.CreateBitCast(tensorHandle,
IGM.getStorageType(silType));

// Set the result.
e.add(tensorHandle);

// We consumed one output.
// Note that `CreateAdd` constant-folds ConstantInts.
tfOutputIdx = Builder.CreateAdd(
tfOutputIdx, llvm::ConstantInt::get(IGM.Int32Ty, 1));
tfOutputAddr = Builder.CreateInBoundsGEP(returnValuesAddress,
tfOutputIdx);
}
setLoweredExplosion(silResult, e);
continue;
}

// If the result is a known TensorFlow type, get it directly.
if (tf::isTensorFlowValue(silResult->getType())) {
auto cTensorHandle = Builder.CreateLoad(tfOutputAddr,
IGM.getPointerAlignment());

// Wrap `cTensorHandle` into a _AnyTensorHandle object, and get an untyped
// pointer to the _AnyTensorHandle.
auto *createHandleFn = IGM.getTFC_CreateTensorHandleFromCFn();
llvm::Value *tensorHandle = Builder.CreateCall(createHandleFn,
{cTensorHandle});

// Cast to a pointer of the expected result type (for example,
// TensorHandle<Float>).
tensorHandle = Builder.CreateBitCast(
tensorHandle, IGM.getStorageType(silResult->getType()));

// Set the result.
Explosion e;
e.add(tensorHandle);
setLoweredExplosion(silResult, e);

// We consumed one output.
// Note that `CreateAdd` constant-folds ConstantInts.
tfOutputIdx = Builder.CreateAdd(
tfOutputIdx, llvm::ConstantInt::get(IGM.Int32Ty, 1));

continue;
}

// Otherwise, the result type must conform to TensorGroup, so we can
// create it using TFC_InitTensorGroup.

Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/dataset_1.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-dynamic-compilation-swift
// RUN: %target-run-disable-deabstraction-swift
// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize
//
Expand Down
4 changes: 2 additions & 2 deletions test/TensorFlowRuntime/dynamic_attributes.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-run-dynamic-compilation-swift
// RUN: %target-run-disable-deabstraction-swift
// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize
// REQUIRES: tensorflow
Expand Down Expand Up @@ -213,7 +213,7 @@ DynamicAttributeTests.test("NormalAttribute Float") {
DynamicAttributeTests.test("NormalAttribute String") {
let result: Tensor<Float> = #tfop("Conv2D", convImage, convFilter,
T$dtype: Float.tensorFlowDataType,
strides: [1, 1, 1, 1],
strides: [1, 1, 1, 1] as [Int32],
padding: loadVALIDString())
expectEqual(convExpectedResult, result.array)
}
Expand Down
9 changes: 1 addition & 8 deletions test/TensorFlowRuntime/dynamic_compilation.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: %target-run-simple-swift
// TODO: Revert to %target-run-simple-swift once we fold dynamic compilation into -Onone.
// RUN: %target-run-dynamic-compilation-swift
// RUN: %target-run-disable-deabstraction-swift
// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize

Expand All @@ -14,13 +14,6 @@ import StdlibUnittest

var DynamicCompilationTests = TestSuite("DynamicCompilation")

DynamicCompilationTests.testCPUOrGPU("Const") {
_RuntimeConfig.printsDebugLog = true
let x: TensorHandle<Float> = #tfop("Const", dtype$dtype: Float.tensorFlowDataType, value$tensor: Float(1.0))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in case people do this, will we crash, or give a good error message? ideally should we have a test for that? if so pls consider filing a bug or adding a TODO.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point. we just crash now. I will file a bug.

_hostOp(x)
expectNearlyEqualWithScalarTensor(1.0, Tensor<Float>(handle: x))
}

DynamicCompilationTests.testCPUOrGPU("ScalarNonConst") {
_RuntimeConfig.printsDebugLog = true
func scalarInitializer_CreateHostTensor(_ x: Float) {
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/models.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-dynamic-compilation-swift
// RUN: %target-run-disable-deabstraction-swift
// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize
//
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/shaped_array.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-dynamic-compilation-swift
// RUN: %target-run-disable-deabstraction-swift
// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize
//
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/tensor.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-dynamic-compilation-swift
// RUN: %target-run-disable-deabstraction-swift
// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize
//
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/tensor_api.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-dynamic-compilation-swift
// RUN: %target-run-disable-deabstraction-swift
// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize
//
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/tensor_debuglog.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-dynamic-compilation-swift
// RUN: %target-run-disable-deabstraction-swift
// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize
//
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/tensor_xla_debuglog.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-dynamic-compilation-swift
// RUN: %target-run-disable-deabstraction-swift
// REQUIRES: executable_test
// REQUIRES: swift_test_mode_optimize
//
Expand Down