Skip to content

Commit

Permalink
Handle transpose of input matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
reikdas committed Aug 4, 2021
1 parent 72eb040 commit b50f837
Showing 1 changed file with 39 additions and 16 deletions.
55 changes: 39 additions & 16 deletions enzyme/Enzyme/AdjointGenerator.h
Expand Up @@ -4058,6 +4058,8 @@ class AdjointGenerator
Mode == DerivativeMode::ReverseModeGradient) {
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);
assert(call.getArgOperand(0) == Builder2.getInt32(101) &&
"Unhandled Order");
auto dfunc = gutils->oldFunc->getParent()->getOrInsertFunction(
funcName, Builder2.getVoidTy(), Builder2.getInt32Ty(),
Builder2.getInt32Ty(), Builder2.getInt32Ty(), Builder2.getInt32Ty(),
Expand All @@ -4066,38 +4068,61 @@ class AdjointGenerator
Builder2.getInt32Ty(), call.getArgOperand(7)->getType(),
Builder2.getInt32Ty(), call.getArgOperand(6)->getType(),
call.getArgOperand(7)->getType(), Builder2.getInt32Ty());
auto zeroval = Builder2.getInt32(0);
auto dzeroval = Builder2.CreateBitCast(zeroval, innerType);
auto oneval = Builder2.getInt32(1);
auto doneval = Builder2.CreateBitCast(oneval, innerType);
Value *sabtrans, *sabcol, *sbatrans, *sacol, *sbarow;
if (call.getArgOperand(1) == Builder2.getInt32(112)) {
sbatrans = Builder2.getInt32(111);
sacol = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)),
Builder2);
sbarow = lookup(gutils->getNewFromOriginal(call.getArgOperand(3)),
Builder2);
} else if (call.getArgOperand(1) == Builder2.getInt32(111)) {
sbatrans = Builder2.getInt32(112);
sacol = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)),
Builder2);
sbarow = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)),
Builder2);
}
if (call.getArgOperand(2) == Builder2.getInt32(112)) {
sabtrans = Builder2.getInt32(111);
sabcol = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)),
Builder2);
} else if (call.getArgOperand(2) == Builder2.getInt32(111)) {
sabtrans = Builder2.getInt32(112);
sabcol = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)),
Builder2);
} else
assert(false && "Unhandled: Notify developers");
SmallVector<Value *, 13> safuncargs = {
lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(1)), Builder2),
Builder2.getInt32(112),
Builder2.getInt32(111),
sabtrans,
lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(6)), Builder2),
gutils->invertPointerM(call.getArgOperand(12), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(9)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), Builder2),
dzeroval,
sabcol,
doneval,
gutils->invertPointerM(call.getArgOperand(7), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(5)),
Builder2)};
sacol};
auto safunccall = Builder2.CreateCall(dfunc, safuncargs);
SmallVector<Value *, 13> sbfuncargs = {
lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2),
Builder2.getInt32(112),
sbatrans,
Builder2.getInt32(111),
lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(6)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(7)), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), Builder2),
sbarow,
gutils->invertPointerM(call.getArgOperand(12), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), Builder2),
dzeroval,
lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), Builder2),
doneval,
gutils->invertPointerM(call.getArgOperand(9), Builder2),
lookup(gutils->getNewFromOriginal(call.getArgOperand(4)),
Builder2)};
Expand All @@ -4115,10 +4140,8 @@ class AdjointGenerator
gutils->invertPointerM(call.getArgOperand(12), Builder2),
Builder2.getInt32(1)};
auto scfunccall = Builder2.CreateCall(scfunc, scfuncargs);
// setDiffe(&call,
// Constant::getNullValue(call.getArgOperand(7)->getType()), Builder2);
return;
}
return;
}
}

Expand Down

0 comments on commit b50f837

Please sign in to comment.