Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/IRGen/Linking.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ class LinkEntity {
setForDifferentiabilityWitness(Kind kind,
const SILDifferentiabilityWitness *witness) {
Pointer = const_cast<void *>(static_cast<const void *>(witness));
SecondaryPointer = nullptr;
Data = LINKENTITY_SET_FIELD(Kind, unsigned(kind));
}
// SWIFT_ENABLE_TENSORFLOW_END
Expand Down
10 changes: 6 additions & 4 deletions include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -554,12 +554,14 @@ class SILBuilder {
NormalDifferentiableFunctionTypeComponent::Original, TheFunction));
}

DifferentiabilityWitnessFunctionInst *
createDifferentiabilityWitnessFunction(
/// Note: explicit function type may be specified only in lowered SIL.
DifferentiabilityWitnessFunctionInst *createDifferentiabilityWitnessFunction(
SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind,
SILDifferentiabilityWitness *Witness) {
SILDifferentiabilityWitness *Witness,
Optional<SILType> FunctionType = None) {
return insert(new (getModule()) DifferentiabilityWitnessFunctionInst(
getModule(), getSILDebugLocation(Loc), WitnessKind, Witness));
getModule(), getSILDebugLocation(Loc), WitnessKind, Witness,
FunctionType));
}
// SWIFT_ENABLE_TENSORFLOW END

Expand Down
11 changes: 5 additions & 6 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -8111,27 +8111,26 @@ class DifferentiabilityWitnessFunctionInst
DifferentiabilityWitnessFunctionKind witnessKind;
/// The referenced SIL differentiability witness.
SILDifferentiabilityWitness *witness;
/// Whether the instruction has an explicit function type.
bool hasExplicitFunctionType;

static SILType getDifferentiabilityWitnessType(
SILModule &module,
DifferentiabilityWitnessFunctionKind witnessKind,
SILDifferentiabilityWitness *witness);

public:
/// Note: explicit function type may be specified only in lowered SIL.
DifferentiabilityWitnessFunctionInst(
SILModule &module, SILDebugLocation loc,
DifferentiabilityWitnessFunctionKind witnessKind,
SILDifferentiabilityWitness *witness);

static DifferentiabilityWitnessFunctionInst *create(
SILModule &module, SILDebugLocation loc,
DifferentiabilityWitnessFunctionKind witnessKind,
SILDifferentiabilityWitness *witness);
SILDifferentiabilityWitness *witness, Optional<SILType> FunctionType);

DifferentiabilityWitnessFunctionKind getWitnessKind() const {
return witnessKind;
}
SILDifferentiabilityWitness *getWitness() const { return witness; }
bool getHasExplicitFunctionType() const { return hasExplicitFunctionType; }

ArrayRef<Operand> getAllOperands() const { return {}; }
MutableArrayRef<Operand> getAllOperands() { return {}; }
Expand Down
9 changes: 9 additions & 0 deletions include/swift/SIL/SILModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ class SILModule {
/// Lookup table for SIL differentiability witnesses, keyed by mangled name.
llvm::StringMap<SILDifferentiabilityWitness *> DifferentiabilityWitnessMap;

/// Lookup table for SILDifferentiabilityWitnesses, keyed by original
/// function name.
llvm::StringMap<llvm::SmallVector<SILDifferentiabilityWitness *, 1>>
DifferentiabilityWitnessesByFunction;

/// The list of SILDifferentiabilityWitnesses in the module.
DifferentiabilityWitnessListType differentiabilityWitnesses;
// SWIFT_ENABLE_TENSORFLOW END
Expand Down Expand Up @@ -614,6 +619,10 @@ class SILModule {
SILDifferentiabilityWitness *
lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key);

/// Look up the differentiability witness corresponding to the given function.
llvm::ArrayRef<SILDifferentiabilityWitness *>
lookUpDifferentiabilityWitnessesForFunction(StringRef name);

/// Attempt to deserialize the SILDifferentiabilityWitness. Returns true if
/// deserialization succeeded, false otherwise.
bool loadDifferentiabilityWitness(SILDifferentiabilityWitness *W);
Expand Down
71 changes: 71 additions & 0 deletions include/swift/SILOptimizer/Utils/DerivativeLookup.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//===--- DerivativeLookup.h -----------------------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// SWIFT_ENABLE_TENSORFLOW
//
// Utilities for looking up derivatives of functions.
//
//===----------------------------------------------------------------------===//

#ifndef SWIFT_SILOPTIMIZER_UTILS_DERIVATIVELOOKUP_H
#define SWIFT_SILOPTIMIZER_UTILS_DERIVATIVELOOKUP_H

#include "swift/AST/AutoDiff.h"
#include "swift/SIL/SILDeclRef.h"
#include "swift/SIL/SILModule.h"

namespace swift {

/// Returns a differentiability witness (definition or declaration) exactly
/// matching the specified indices. If none are found in the given `module`,
/// returns `nullptr`.
///
/// \param parameterIndices must be lowered to SIL.
/// \param resultIndices must be lowered to SIL.
SILDifferentiabilityWitness *
getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
IndexSubset *parameterIndices,
IndexSubset *resultIndices);

/// Finds the "@differentiable" attribute on `original` whose parameter indices
/// are a minimal superset of the specified parameter indices. Returns `nullptr`
/// if no such attribute exists.
///
/// \param parameterIndices must be lowered to SIL.
/// \param minimalParameterIndices is an output parameter that is set to the SIL
/// indices of the minimal attribute, or to `nullptr` if no attribute exists.
const DifferentiableAttr *
getMinimalASTDifferentiableAttr(AbstractFunctionDecl *original,
IndexSubset *parameterIndices,
IndexSubset *&minimalParameterIndices);

/// Returns a differentiability witness for `original` whose parameter indices
/// are a minimal superset of the specified parameter indices and whose result
/// indices match the given result indices, out of all
/// differentiability witnesses that come from AST "@differentiable" or
/// "@differentiating" attributes.
///
/// This function never creates new differentiability witness definitions.
/// However, this function may create new differentiability witness declarations
/// referring to definitions in other modules when these witnesses have not yet
/// been declared in the current module.
///
/// \param module is the SILModule in which to get or create the witnesses.
/// \param parameterIndices must be lowered to SIL.
/// \param resultIndices must be lowered to SIL.
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
SILModule &module, SILFunction *original, IndexSubset *parameterIndices,
IndexSubset *resultIndices);

} // end namespace swift

#endif // SWIFT_SILOPTIMIZER_UTILS_DERIVATIVELOOKUP_H
22 changes: 7 additions & 15 deletions lib/IRGen/GenDiffWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,13 @@ void IRGenModule::emitSILDifferentiabilityWitness(
ConstantInitBuilder builder(*this);
auto diffWitnessContents = builder.beginStruct();

// TODO(TF-894): When the differentiation transform canonicalizes all
// differentiability witnesses to have JVP/VJP functions, remove the nullptr
// cases and assert that JVP/VJP functions exist.
if (dw->getJVP()) {
diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
} else {
diffWitnessContents.addNullPointer(Int8PtrTy);
}
if (dw->getVJP()) {
diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);
} else {
diffWitnessContents.addNullPointer(Int8PtrTy);
}
assert(dw->getJVP() && "diff witness should be canonicalized");
assert(dw->getVJP() && "diff witness should be canonicalized");

diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);

getAddrOfDifferentiabilityWitness(
dw, diffWitnessContents.finishAndCreateFuture());
Expand Down
35 changes: 35 additions & 0 deletions lib/IRGen/LoadableByAddress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,10 @@ class LoadableByAddress : public SILModuleTransform {
bool fixStoreToBlockStorageInstr(SILInstruction &I,
SmallVectorImpl<SILInstruction *> &Delete);

// SWIFT_ENABLE_TENSORFLOW
bool recreateDifferentiabilityWitnessFunction(
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete);

private:
llvm::SetVector<SILFunction *> modFuncs;
llvm::SetVector<SingleValueInstruction *> conversionInstrs;
Expand Down Expand Up @@ -2672,6 +2676,34 @@ bool LoadableByAddress::fixStoreToBlockStorageInstr(
return true;
}

// SWIFT_ENABLE_TENSORFLOW
bool LoadableByAddress::recreateDifferentiabilityWitnessFunction(
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete) {
auto *instr = dyn_cast<DifferentiabilityWitnessFunctionInst>(&I);
if (!instr)
return false;

// Check if we need to recreate the instruction.
auto *currIRMod = getIRGenModule()->IRGen.getGenModule(instr->getFunction());
auto resultFnTy = instr->getType().castTo<SILFunctionType>();
auto genSig = resultFnTy->getSubstGenericSignature();
GenericEnvironment *genEnv = nullptr;
if (genSig)
genEnv = genSig->getGenericEnvironment();
auto newResultFnTy =
MapperCache.getNewSILFunctionType(genEnv, resultFnTy, *currIRMod);
if (resultFnTy == newResultFnTy)
return true;

SILBuilderWithScope builder(instr);
auto *newInstr = builder.createDifferentiabilityWitnessFunction(
instr->getLoc(), instr->getWitnessKind(), instr->getWitness(),
SILType::getPrimitiveObjectType(newResultFnTy));
instr->replaceAllUsesWith(newInstr);
Delete.push_back(instr);
return true;
}

bool LoadableByAddress::recreateTupleInstr(
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete) {
auto *tupleInstr = dyn_cast<TupleInst>(&I);
Expand Down Expand Up @@ -2994,6 +3026,9 @@ void LoadableByAddress::run() {
continue;
else if (recreateApply(I, Delete))
continue;
// SWIFT_ENABLE_TENSORFLOW
else if (recreateDifferentiabilityWitnessFunction(I, Delete))
continue;
else
fixStoreToBlockStorageInstr(I, Delete);
}
Expand Down
11 changes: 9 additions & 2 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3217,8 +3217,15 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
P.diagnose(keyStartLoc, diag::sil_diff_witness_undefined);
return true;
}
ResultVal = B.createDifferentiabilityWitnessFunction(
InstLoc, witnessKind, witness);
// Parse an optional explicit function type.
Optional<SILType> functionType = None;
if (P.consumeIf(tok::kw_as)) {
functionType = SILType();
if (parseSILType(*functionType))
return true;
}
ResultVal = B.createDifferentiabilityWitnessFunction(InstLoc, witnessKind,
witness, functionType);
break;
}
// SWIFT_ENABLE_TENSORFLOW END
Expand Down
4 changes: 4 additions & 0 deletions lib/SIL/SILDifferentiabilityWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDeclaration(
assert(!module.DifferentiabilityWitnessMap.count(mangledKey) &&
"Cannot create duplicate differentiability witness in a module");
module.DifferentiabilityWitnessMap[mangledKey] = diffWitness;
module.DifferentiabilityWitnessesByFunction[originalFunction->getName()]
.push_back(diffWitness);
module.getDifferentiabilityWitnessList().push_back(diffWitness);
return diffWitness;
}
Expand All @@ -56,6 +58,8 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
assert(!module.DifferentiabilityWitnessMap.count(mangledKey) &&
"Cannot create duplicate differentiability witness in a module");
module.DifferentiabilityWitnessMap[mangledKey] = diffWitness;
module.DifferentiabilityWitnessesByFunction[originalFunction->getName()]
.push_back(diffWitness);
module.getDifferentiabilityWitnessList().push_back(diffWitness);
return diffWitness;
}
Expand Down
11 changes: 7 additions & 4 deletions lib/SIL/SILInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,10 +781,13 @@ SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType(
DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst(
SILModule &module, SILDebugLocation debugLoc,
DifferentiabilityWitnessFunctionKind witnessKind,
SILDifferentiabilityWitness *witness)
: InstructionBase(debugLoc, getDifferentiabilityWitnessType(
module, witnessKind, witness)),
witnessKind(witnessKind), witness(witness) {}
SILDifferentiabilityWitness *witness, Optional<SILType> FunctionType)
: InstructionBase(debugLoc, FunctionType
? *FunctionType
: getDifferentiabilityWitnessType(
module, witnessKind, witness)),
witnessKind(witnessKind), witness(witness),
hasExplicitFunctionType(FunctionType) {}
// SWIFT_ENABLE_TENSORFLOW END

FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind,
Expand Down
6 changes: 6 additions & 0 deletions lib/SIL/SILModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,12 @@ SILModule::lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key) {
mangler.mangleSILDifferentiabilityWitnessKey(key));
}

/// Look up the differentiability witness corresponding to the given indices.
llvm::ArrayRef<SILDifferentiabilityWitness *>
SILModule::lookUpDifferentiabilityWitnessesForFunction(StringRef name) {
return DifferentiabilityWitnessesByFunction[name];
}

bool SILModule::loadDifferentiabilityWitness(SILDifferentiabilityWitness *W) {
auto *NewW = getSILLoader()->lookupDifferentiabilityWitness(W->getKey());
if (!NewW)
Expand Down
4 changes: 4 additions & 0 deletions lib/SIL/SILPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
*this << " ";
}
printSILFunctionNameAndType(PrintState.OS, witness->getOriginalFunction());
if (dwfi->getHasExplicitFunctionType()) {
*this << " as ";
*this << dwfi->getType();
}
}
// SWIFT_ENABLE_TENSORFLOW END

Expand Down
10 changes: 3 additions & 7 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -775,9 +775,8 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
if (auto *vjpDecl = diffAttr->getVJPFunction())
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
AutoDiffConfig config(
diffAttr->getParameterIndices(), resultIndices,
diffAttr->getDerivativeGenericSignature().getPointer());
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
diffAttr->getDerivativeGenericSignature());
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp);
}
};
Expand Down Expand Up @@ -819,14 +818,11 @@ void SILGenModule::emitDifferentiabilityWitness(
};
bool reorderSelf = shouldReorderSelf();

CanGenericSignature derivativeCanGenSig;
if (auto derivativeGenSig = config.derivativeGenericSignature)
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
// Create new SIL differentiability witness.
// Witness JVP and VJP are set below.
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
M, originalFunction->getLinkage(), originalFunction, loweredParamIndices,
config.resultIndices, derivativeCanGenSig,
config.resultIndices, config.derivativeGenericSignature,
/*jvp*/ nullptr, /*vjp*/ nullptr, originalFunction->isSerialized());

// Set derivative function in differentiability witness.
Expand Down
1 change: 1 addition & 0 deletions lib/SILOptimizer/Mandatory/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ silopt_register_sources(
DefiniteInitialization.cpp
# SWIFT_ENABLE_TENSORFLOW
Differentiation.cpp
# SWIFT_ENABLE_TENSORFLOW END
DIMemoryUseCollector.cpp
DataflowDiagnostics.cpp
DiagnoseInfiniteRecursion.cpp
Expand Down
Loading