Skip to content

Commit 13f448c

Browse files
authored
fix problems caused by the rebase to c8ac18d (#17217)
Includes cherry-pick of d093879 to fix a test. These fixes make everything compile, and all the test pass.
1 parent e9fd7aa commit 13f448c

24 files changed

+91
-86
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ DECL_ATTR(differentiable, Differentiable,
383383
OnFunc | LongAttribute, 78)
384384

385385
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
386-
OnFunc | OnConstructor | OnSubscript,
386+
OnAccessor | OnFunc | OnConstructor | OnSubscript,
387387
/* Not serialized */ 79)
388388

389389
#undef TYPE_ATTR

include/swift/SILOptimizer/Utils/SILInliner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ InlineCost instructionInlineCost(SILInstruction &I);
4040
/// Scan the given function body, mandatory inlining calls to callees that
4141
/// satisfy the specified predicate, including call sites exposed by inlining
4242
/// other functions.
43-
typedef std::function<bool (FullApplySite site, const SILFunction &callee)>
43+
typedef std::function<bool(FullApplySite site, SILFunction &callee)>
4444
ShouldMandatoryInlineFnPred;
4545
void inlineForTFDeabstraction(SILFunction &fn,
4646
const ShouldMandatoryInlineFnPred &predicate);

lib/AST/ASTPrinter.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4165,14 +4165,8 @@ void ProtocolConformance::printName(llvm::raw_ostream &os,
41654165
case ProtocolConformanceKind::Specialized: {
41664166
auto spec = cast<SpecializedProtocolConformance>(this);
41674167
os << "specialize <";
4168-
auto subMap = spec->getSubstitutionMap();
4169-
auto genericParams =
4170-
spec->getGenericConformance()->getGenericSignature()->getGenericParams();
4171-
interleave(genericParams,
4172-
[&](Type type) {
4173-
type.subst(subMap).print(os, PO);
4174-
4175-
},
4168+
interleave(spec->getSubstitutionMap().toList(),
4169+
[&](const Substitution &s) { s.print(os, PO); },
41764170
[&] { os << ", "; });
41774171

41784172
os << "> (";

lib/SIL/SILPrinter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
10651065
printDebugVar(ABI->getVarInfo());
10661066
}
10671067

1068+
void printSubstitutions(SubstitutionMap Subs) {
1069+
printSubstitutions(Subs.toList());
1070+
}
1071+
10681072
void printSubstitutions(SubstitutionList Subs) {
10691073
if (Subs.empty())
10701074
return;

lib/SILOptimizer/Mandatory/MandatoryInlining.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -407,20 +407,13 @@ static SILFunction *getCalleeFunction(
407407
return nullptr;
408408
}
409409

410+
// TODO(SR-8015): `shouldInlinePredicate` does too much. Fix this.
410411
// If the CalleeFunction is a not-transparent definition, we can not process
411412
// it.
412413
// SWIFT_ENABLE_TENSORFLOW
413414
if (!shouldInlinePredicate(AI, *CalleeFunction))
414415
return nullptr;
415416

416-
// If CalleeFunction is a declaration, see if we can load it.
417-
if (CalleeFunction->empty())
418-
AI.getModule().loadFunction(CalleeFunction);
419-
420-
// If we fail to load it, bail.
421-
if (CalleeFunction->empty())
422-
return nullptr;
423-
424417
if (F->isSerialized() &&
425418
!CalleeFunction->hasValidLinkageForFragileInline()) {
426419
if (!CalleeFunction->hasValidLinkageForFragileRef()) {
@@ -679,8 +672,16 @@ class MandatoryInlining : public SILModuleTransform {
679672
SetFactory, SetFactory.getEmptySet(), CHA,
680673
// SWIFT_ENABLE_TENSORFLOW
681674
SILInliner::InlineKind::MandatoryInline,
682-
[&](FullApplySite site, const SILFunction &callee) -> bool {
683-
return callee.isTransparent() == IsTransparent;
675+
[&](FullApplySite site, SILFunction &callee) -> bool {
676+
if (callee.isTransparent() != IsTransparent)
677+
return false;
678+
679+
// If CalleeFunction is a declaration, see if we can load it.
680+
if (callee.empty())
681+
site.getModule().loadFunction(&callee);
682+
683+
// If we failed to load it, bail.
684+
return !callee.empty();
684685
}
685686
);
686687
}

lib/SILOptimizer/Mandatory/TFConstExpr.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212

1313
#define DEBUG_TYPE "TFConstExpr"
1414
#include "TFConstExpr.h"
15-
#include "swift/SIL/SILBuilder.h"
16-
#include "swift/SIL/SILConstants.h"
17-
#include "swift/Serialization/SerializedSILLoader.h"
1815
#include "swift/AST/ProtocolConformance.h"
1916
#include "swift/AST/SubstitutionMap.h"
2017
#include "swift/Basic/Defer.h"
18+
#include "swift/Demangling/Demangle.h"
19+
#include "swift/SIL/FormalLinkage.h"
20+
#include "swift/SIL/SILBuilder.h"
21+
#include "swift/SIL/SILConstants.h"
22+
#include "swift/Serialization/SerializedSILLoader.h"
2123
#include "llvm/Support/CommandLine.h"
2224

2325
using namespace swift;

lib/SILOptimizer/Mandatory/TFDeabstraction.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,12 @@ void TFDeabstraction::inlineCalls() {
247247
// Use the mandatory inlining algorithm to expose call sites that contain
248248
// TensorFlow values as their argument or result lists.
249249
inlineForTFDeabstraction(fn,
250-
[&](FullApplySite site, const SILFunction &callee) -> bool {
250+
[&](FullApplySite site, SILFunction &callee) -> bool {
251+
if (callee.empty() &&
252+
!site.getModule().linkFunction(&callee,
253+
SILModule::LinkingMode::LinkAll))
254+
return nullptr;
255+
251256
if (!shouldInline(site, callee))
252257
return false;
253258

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,22 @@
2626
#include "swift/AST/DiagnosticsSIL.h"
2727
#include "swift/AST/Expr.h"
2828
#include "swift/AST/GenericEnvironment.h"
29-
#include "swift/AST/SubstitutionMap.h"
30-
#include "swift/AST/SubstitutionList.h"
31-
#include "swift/AST/ProtocolConformance.h"
29+
#include "swift/AST/Module.h"
3230
#include "swift/AST/ParameterList.h"
31+
#include "swift/AST/ProtocolConformance.h"
32+
#include "swift/AST/SubstitutionList.h"
33+
#include "swift/AST/SubstitutionMap.h"
3334
#include "swift/Basic/Defer.h"
34-
#include "swift/Serialization/SerializedSILLoader.h"
3535
#include "swift/SIL/AbstractionPattern.h"
3636
#include "swift/SIL/Dominance.h"
37-
#include "swift/SIL/SILCloner.h"
37+
#include "swift/SIL/FormalLinkage.h"
3838
#include "swift/SIL/SILBuilder.h"
39+
#include "swift/SIL/SILCloner.h"
3940
#include "swift/SIL/TypeLowering.h"
4041
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
4142
#include "swift/SILOptimizer/PassManager/Passes.h"
4243
#include "swift/SILOptimizer/PassManager/Transforms.h"
44+
#include "swift/Serialization/SerializedSILLoader.h"
4345
#include "llvm/ADT/DenseMap.h"
4446
#include "llvm/ADT/Hashing.h"
4547
#include "llvm/ADT/SmallPtrSet.h"
@@ -977,9 +979,8 @@ static void convertFromIntegerLiteral(intmax_t value,
977979
auto *ebilProto =
978980
astCtx.getProtocol(KnownProtocolKind::ExpressibleByBuiltinIntegerLiteral);
979981
// `init(_builtinIntegerLiteral:)`
980-
DeclName builtinLitInitName(astCtx, astCtx.Id_init, {
981-
astCtx.getIdentifier("_builtinIntegerLiteral")
982-
});
982+
DeclName builtinLitInitName(astCtx, DeclBaseName::createConstructor(),
983+
{astCtx.getIdentifier("_builtinIntegerLiteral")});
983984
auto *initBILDecl =
984985
cast<ConstructorDecl>(ebilProto->lookupDirect(builtinLitInitName)[0]);
985986
SILDeclRef initBILDeclRef(initBILDecl);
@@ -1022,9 +1023,8 @@ static void convertFromIntegerLiteral(intmax_t value,
10221023
// `ExpressibleByIntegerLiteral.init(integerLiteral: %4)`.
10231024
auto *eilProto =
10241025
astCtx.getProtocol(KnownProtocolKind::ExpressibleByIntegerLiteral);
1025-
DeclName intLitInitName(astCtx, astCtx.Id_init, {
1026-
astCtx.getIdentifier("integerLiteral")
1027-
});
1026+
DeclName intLitInitName(astCtx, DeclBaseName::createConstructor(),
1027+
{astCtx.getIdentifier("integerLiteral")});
10281028
auto *initILDecl =
10291029
cast<ConstructorDecl>(eilProto->lookupDirect(intLitInitName)[0]);
10301030
SILDeclRef initILDeclRef(initILDecl);
@@ -1091,7 +1091,8 @@ static void convertToIndirectSeed(intmax_t value, CanType type,
10911091
CanMetatypeType::get(type, MetatypeRepresentation::Thick));
10921092
auto *metatype = builder.createMetatype(loc, metatypeTy);
10931093
// Call `init(_:)` through `VectorNumeric` protocol.
1094-
DeclName initName(astCtx, astCtx.Id_init, { Identifier() });
1094+
DeclName initName(astCtx, DeclBaseName::createConstructor(),
1095+
{Identifier()});
10951096
// Allocate buffer for passing the indirect scalar value.
10961097
// %2 = alloc_stack $<scalar type>
10971098
auto scalarValBuf =
@@ -1242,9 +1243,9 @@ static SILFunction *getOrCreateGradient(
12421243
} else {
12431244
auto loq = seedSILTy.isTrivial(module)
12441245
? LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Take;
1245-
auto seedBufAccess =
1246-
builder.createBeginAccess(loc, seedBuf, SILAccessKind::Read,
1247-
SILAccessEnforcement::Static);
1246+
auto seedBufAccess = builder.createBeginAccess(
1247+
loc, seedBuf, SILAccessKind::Read, SILAccessEnforcement::Static,
1248+
/*noNestedConflict=*/false);
12481249
auto seed = builder.createLoad(loc, seedBufAccess, loq);
12491250
builder.createEndAccess(loc, seedBufAccess, /*aborted*/ false);
12501251
args.push_back(seed);

lib/SILOptimizer/Mandatory/TFPartition.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2615,7 +2615,8 @@ static SILValue createHostSend(SILBuilder &B, SILLocation loc, SILValue value,
26152615
B.createStore(loc, value, stackAlloc, storeOwnership);
26162616

26172617
auto access = B.createBeginAccess(loc, stackAlloc, SILAccessKind::Read,
2618-
SILAccessEnforcement::Static);
2618+
SILAccessEnforcement::Static,
2619+
/*noNestedConflict=*/false);
26192620

26202621
auto createScalarTensorFnRef =
26212622
B.createFunctionRef(loc, createScalarTensorFn);
@@ -3106,7 +3107,8 @@ static SILValue convertScalarToHostTensorHandle(SILValue value, SILBuilder &B,
31063107
B.createStore(loc, value, stackAlloc, storeOwnership);
31073108

31083109
auto access = B.createBeginAccess(loc, stackAlloc, SILAccessKind::Read,
3109-
SILAccessEnforcement::Static);
3110+
SILAccessEnforcement::Static,
3111+
/*noNestedConflict=*/false);
31103112

31113113
Substitution sub(value->getType().getSwiftRValueType(), {});
31123114
auto result = B.createApply(loc, fnRef, {sub}, { access, dtype },
@@ -3396,7 +3398,8 @@ insertTensorComputationStartEndTerminate(ArrayRef<SILValue> resultValues)
33963398
// address of the first element as an UnsafePointer<CTensorHandle>.
33973399
// %1 = begin_access [read] [static] %0 : $*(CTensorHandle, CTensorHandle)
33983400
auto access = B.createBeginAccess(loc, stackAlloc, SILAccessKind::Read,
3399-
SILAccessEnforcement::Static);
3401+
SILAccessEnforcement::Static,
3402+
/*noNestedConflict=*/false);
34003403

34013404
// %2 = tuple_element_addr %1 : $*(CTensorHandle, CTensorHandle), 0
34023405
SILValue firstPtr = access;
@@ -3465,7 +3468,8 @@ insertTensorComputationStartEndTerminate(ArrayRef<SILValue> resultValues)
34653468
// UnsafePointer<CTensorHandle>.
34663469
// %1 = begin_access [write] [static] %0 : $*(CTensorHandle, CTensorHandle)
34673470
access = B.createBeginAccess(loc, stackAlloc, SILAccessKind::Modify,
3468-
SILAccessEnforcement::Static);
3471+
SILAccessEnforcement::Static,
3472+
/*noNestedConflict=*/false);
34693473

34703474
// %2 = tuple_element_addr %1 : $*(CTensorHandle, CTensorHandle), 0
34713475
firstPtr = access;

lib/SILOptimizer/Mandatory/TFUtilities.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "TFUtilities.h"
1515
#include "swift/AST/ASTContext.h"
1616
#include "swift/AST/DiagnosticsSIL.h"
17+
#include "swift/AST/Module.h"
1718
#include "swift/SIL/SILBuilder.h"
1819
#include "swift/SIL/SILModule.h"
1920
#include "swift/SILOptimizer/Utils/Local.h"
@@ -1571,7 +1572,7 @@ bool TensorFunctionClassifier::shouldBePartitioned(SILFunction *fn) {
15711572
return false;
15721573

15731574
auto hasInlinableAttrs = [&](Decl *decl) -> bool {
1574-
if (decl->getAttrs().hasAttribute<InlineableAttr>())
1575+
if (decl->getAttrs().hasAttribute<InlinableAttr>())
15751576
return true;
15761577
return false;
15771578
};

0 commit comments

Comments
 (0)