From 2dde4d4105fa21dd6a6941a9162b47bfbdab2e6e Mon Sep 17 00:00:00 2001 From: solomon Date: Wed, 10 Feb 2016 17:10:49 +0100 Subject: [PATCH] permitted handling of userdefined elementwise functions with non-identity alpha, since this is needed when scaling for symmetric permutations --- examples/dft.cxx | 2 ++ src/contraction/sym_seq_ctr.cxx | 16 +++++++++------- src/interface/fun_term.cxx | 3 ++- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/dft.cxx b/examples/dft.cxx index a509442b..0e0f6932 100644 --- a/examples/dft.cxx +++ b/examples/dft.cxx @@ -47,6 +47,8 @@ int test_dft(int64_t n, std::complex (0.0, 0.0), "ik");*/ DFT["ik"] = .5*DFT["ij"]*IDFT["jk"]; + Scalar< std::complex > ss(wrld); + ss[""] = ((Function< std::complex, std::complex, std::complex >)([](std::complex a, std::complex b){ return a+b; }))(DFT["ij"],DFT["ij"]); DFT.read_local(&np, &idx, &data); int pass = 1; diff --git a/src/contraction/sym_seq_ctr.cxx b/src/contraction/sym_seq_ctr.cxx index 1d224534..11964204 100644 --- a/src/contraction/sym_seq_ctr.cxx +++ b/src/contraction/sym_seq_ctr.cxx @@ -248,18 +248,20 @@ printf("HERE1\n"); } CTF_FLOPS_ADD(2*(imax-imin)); } else { - ASSERT(0); - assert(0); + //ASSERT(0); + //assert(0); //printf("HERTE alpha = %d\n",*(int*)alpha); for (int i=imin; iel_size]; - sr_C->mul(A+offsets_A[0][i], + func->apply_f(A+offsets_A[0][i], + B+offsets_B[0][i], + tmp); + sr_C->mul(tmp, alpha, tmp); - func->acc_f(tmp, - B+offsets_B[0][i], - C+offsets_C[0][i], - sr_C); + sr_C->add(tmp, + C+offsets_C[0][i], + C+offsets_C[0][i]); } CTF_FLOPS_ADD(3*(imax-imin)); } diff --git a/src/interface/fun_term.cxx b/src/interface/fun_term.cxx index 35c3fa8a..5c531fb0 100644 --- a/src/interface/fun_term.cxx +++ b/src/interface/fun_term.cxx @@ -104,7 +104,8 @@ namespace CTF_int { ASSERT(0); assert(0); } - contraction c(opA.parent, opA.idx_map, opB.parent, opB.idx_map, NULL, output.parent, output.idx_map, output.scale, func); + contraction c(opA.parent, opA.idx_map, opB.parent, opB.idx_map, output.sr->mulid(), output.parent, output.idx_map, output.scale, func); + //contraction c(opA.parent, opA.idx_map, opB.parent, opB.idx_map, NULL, output.parent, output.idx_map, output.scale, func); c.execute(); // if (scl != NULL) cdealloc(scl); }