@@ -2276,7 +2276,7 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22762276 loc, getterVJPRef, /* substitutionMap*/ {},
22772277 /* args*/ {getMappedValue (sei->getOperand ())}, /* isNonThrowing*/ false );
22782278
2279- // Get the VJP results (original results and pullback)
2279+ // Get the VJP results (original results and pullback).
22802280 SmallVector<SILValue, 8 > vjpDirectResults;
22812281 extractAllElements (getterVJPApply, getBuilder (), vjpDirectResults);
22822282 ArrayRef<SILValue> originalDirectResults =
@@ -2291,6 +2291,8 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22912291
22922292 // Checkpoint the original results.
22932293 getPrimalInfo ().addStaticPrimalValueDecl (sei);
2294+ getBuilder ().createRetainValue (loc, originalDirectResult,
2295+ getBuilder ().getDefaultAtomicity ());
22942296 staticPrimalValues.push_back (originalDirectResult);
22952297
22962298 // Checkpoint the pullback.
@@ -3612,28 +3614,14 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
36123614 // Construct the pullback arguments.
36133615 SmallVector<SILValue, 8 > args;
36143616 auto seed = getAdjointValue (sei);
3615- auto *seedBuf = builder.createAllocStack (loc, seed.getType ());
3616- materializeAdjointIndirectHelper (seed, seedBuf);
3617- if (seed.getType ().isAddressOnly (getModule ()))
3618- args.push_back (seedBuf);
3619- else {
3620- auto access = builder.createBeginAccess (
3621- loc, seedBuf, SILAccessKind::Read, SILAccessEnforcement::Static,
3622- /* noNestedConflict*/ true ,
3623- /* fromBuiltin*/ false );
3624- args.push_back (builder.createLoad (
3625- loc, access, getBufferLOQ (seed.getSwiftType (), getAdjoint ())));
3626- builder.createEndAccess (loc, access, /* aborted*/ false );
3627- }
3617+ assert (seed.getType ().isObject ());
3618+ args.push_back (materializeAdjointDirect (seed, loc));
36283619
36293620 // Call the pullback.
36303621 auto *pullbackCall = builder.createApply (loc, pullback, SubstitutionMap (),
36313622 args, /* isNonThrowing*/ false );
36323623 assert (!pullbackCall->hasIndirectResults ());
36333624
3634- // Clean up seed allocation.
3635- builder.createDeallocStack (loc, seedBuf);
3636-
36373625 // Set adjoint for the `struct_extract` operand.
36383626 addAdjointValue (sei->getOperand (),
36393627 AdjointValue::getMaterialized (pullbackCall));
0 commit comments