Skip to content

Commit

Permalink
Add PhaseRootNMask to QInterface, implement for QEngineCPU (#1009)
Browse files Browse the repository at this point in the history
* Add `PhaseRootNMask` to `QInterface`, implement for `QEngineCPU`

* fix sign convention and expand test

* Add openCL implementation of `PhaseRootNMask`, +1 unit test
  • Loading branch information
jpacold committed Jun 17, 2024
1 parent e63620a commit 90af46e
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/common/oclapi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ enum OCLAPI {
OCL_API_Z_SINGLE,
OCL_API_Z_SINGLE_WIDE,
OCL_API_PHASE_PARITY,
OCL_API_PHASE_MASK,
OCL_API_ROL,
OCL_API_APPROXCOMPARE,
OCL_API_NORMALIZE,
Expand Down
1 change: 1 addition & 0 deletions include/qengine_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class QEngineCPU : public QEngine {

void XMask(bitCapInt mask);
void PhaseParity(real1_f radians, bitCapInt mask);
/** Masked "PhaseRootN": one PhaseRootN(n) phase step per qubit selected by "mask" (CPU state vector version). */
void PhaseRootNMask(bitLenInt n, bitCapInt mask);

/**
* \defgroup ArithGate Arithmetic and other opcode-like gate implementations.
Expand Down
1 change: 1 addition & 0 deletions include/qengine_opencl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ class QEngineOCL : public QEngine {
void Phase(complex topLeft, complex bottomRight, bitLenInt qubitIndex);
void XMask(bitCapInt mask);
void PhaseParity(real1_f radians, bitCapInt mask);
/** Masked "PhaseRootN": one PhaseRootN(n) phase step per qubit selected by "mask" (OpenCL kernel version). */
void PhaseRootNMask(bitLenInt n, bitCapInt mask);

using QEngine::Compose;
bitLenInt Compose(QEngineOCLPtr toCopy);
Expand Down
7 changes: 7 additions & 0 deletions include/qinterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,13 @@ class QInterface : public ParallelFor {
Phase(ONE_CMPLX, pow(-ONE_CMPLX, (real1)(ONE_R1 / pow2Ocl(n - 1U))), qubit);
}

/**
 * Masked PhaseRootN gate
 *
 * Applies the "PhaseRootN" gate, with the same "n", to every qubit whose bit is set in "mask." Equivalently, each
 * permutation basis state is multiplied by one "PhaseRootN" phase factor (a 2^(n-1)-th root of -1) for every
 * masked qubit that is set in that basis state. An empty mask leaves the state unchanged.
 *
 * @param n    Root order, with the same meaning as the "n" argument of "PhaseRootN"
 * @param mask Bit mask selecting the qubits to act on (bit "i" set selects qubit "i")
 */
virtual void PhaseRootNMask(bitLenInt n, bitCapInt mask);

/**
* Inverse "PhaseRootN" gate
*
Expand Down
1 change: 1 addition & 0 deletions src/common/oclengine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ const std::vector<OCLKernelHandle> OCLEngine::kernelHandles{
OCLKernelHandle(OCL_API_Z_SINGLE, "zsingle"),
OCLKernelHandle(OCL_API_Z_SINGLE_WIDE, "zsinglewide"),
OCLKernelHandle(OCL_API_PHASE_PARITY, "phaseparity"),
OCLKernelHandle(OCL_API_PHASE_MASK, "phasemask"),
OCLKernelHandle(OCL_API_COMPOSE, "compose"),
OCLKernelHandle(OCL_API_COMPOSE_WIDE, "compose"),
OCLKernelHandle(OCL_API_COMPOSE_MID, "composemid"),
Expand Down
28 changes: 28 additions & 0 deletions src/common/qengine.cl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ inline cmplx conj(const cmplx cmp)
return (cmplx)(cmp.x, -cmp.y);
}

// Unit-magnitude complex number with argument "theta" (in radians): cos(theta) + i * sin(theta).
inline cmplx polar_unit(real1 theta)
{
    const real1 re = cos(theta);
    const real1 im = sin(theta);
    return (cmplx)(re, im);
}

#define OFFSET2_ARG bitCapIntOclPtr[0]
#define OFFSET1_ARG bitCapIntOclPtr[1]
#define MAXI_ARG bitCapIntOclPtr[2]
Expand Down Expand Up @@ -310,6 +314,30 @@ void kernel phaseparity(global cmplx* stateVec, constant bitCapIntOcl* bitCapInt
}
}

// Masked phase kernel: every amplitude is multiplied by exp(i * angle[0] * k), where k is the number of
// bits its index shares with the mask, reduced modulo the phase count.
//
// bitCapIntOclPtr[0]: number of amplitudes in stateVec
// bitCapIntOclPtr[1]: qubit mask
// bitCapIntOclPtr[2]: number of distinct phase steps (the modulus applied to the bit count)
// angle[0]:           phase angle of a single step, in radians
void kernel phasemask(global cmplx* stateVec, constant bitCapIntOcl* bitCapIntOclPtr, constant real1* angle)
{
    const bitCapIntOcl stride = get_global_size(0);
    const bitCapIntOcl ampCount = bitCapIntOclPtr[0];
    const bitCapIntOcl qubitMask = bitCapIntOclPtr[1];
    const bitCapIntOcl phaseModulus = bitCapIntOclPtr[2];
    const real1 stepAngle = angle[0];

    for (bitCapIntOcl idx = ID; idx < ampCount; idx += stride) {
        // Count the masked bits that are set in this basis state index.
        bitCapIntOcl setBits = 0;
        for (bitCapIntOcl bits = idx & qubitMask; bits; bits >>= 1) {
            setBits += bits & 1;
        }

        // A whole number of full turns is the identity, so leave the amplitude untouched.
        const bitCapIntOcl steps = setBits % phaseModulus;
        if (steps == 0) {
            continue;
        }

        stateVec[idx] = zmul(polar_unit(steps * stepAngle), stateVec[idx]);
    }
}

void kernel zsingle(global cmplx* stateVec, constant bitCapIntOcl* bitCapIntOclPtr)
{
const bitCapIntOcl Nthreads = get_global_size(0);
Expand Down
44 changes: 44 additions & 0 deletions src/qengine/opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,50 @@ void QEngineOCL::PhaseParity(real1_f radians, bitCapInt mask)
BitMask((bitCapIntOcl)mask, OCL_API_PHASE_PARITY, radians);
}

/**
 * Masked PhaseRootN gate (OpenCL implementation).
 *
 * Multiplies the amplitude of every basis state by exp(-i * PI * k / 2^(n-1)), where k is the number of qubits
 * selected by "mask" that are set in that basis state. A single-qubit mask is forwarded to "Phase()"; any other
 * mask is handled by the "phasemask" kernel.
 *
 * @param n    Root order; n == 0 is treated as the identity, matching QEngineCPU::PhaseRootNMask
 * @param mask Bit mask of the qubits to act on; an empty mask is a no-op
 * @throws std::invalid_argument if "mask" addresses qubits outside the register, or if 2^n does not fit in
 *         bitCapIntOcl
 */
void QEngineOCL::PhaseRootNMask(bitLenInt n, bitCapInt mask)
{
    if (bi_compare_0(mask) == 0) {
        return;
    }

    // Bounds-check the full-width mask BEFORE narrowing it, so that high bits cannot be silently truncated away
    // (same check as QEngineCPU::PhaseRootNMask).
    if (bi_compare(mask, maxQPower) >= 0) {
        throw std::invalid_argument("QEngineOCL::PhaseRootNMask mask out-of-bounds!");
    }
    const bitCapIntOcl oclMask = (bitCapIntOcl)mask;

    CHECK_ZERO_SKIP();

    if (n == 0) {
        // Identity, as in the CPU engine. This also keeps "n - 1U" below from wrapping around.
        return;
    }
    // "n" counts bits, so it is bounded by the bit width of bitCapIntOcl, not by its size in bytes.
    if (n >= (8U * sizeof(bitCapIntOcl))) {
        throw std::invalid_argument("QEngineOCL::PhaseRootNMask: power of 2 out-of-bounds");
    }

    const bitCapIntOcl nPhases = pow2Ocl(n);
    // The kernel reads this argument as "real1" and sizeof(real1) bytes are uploaded below, so the host-side
    // value must be stored as real1 as well (real1_f is a separate type name and need not have the same size).
    const real1 radians[1] = { (real1)(-PI_R1 / pow2Ocl(n - 1U)) };

    if (isPowerOfTwo(mask)) {
        // Exactly one qubit is selected: this is an ordinary single-qubit phase gate.
        const complex phaseFac = std::polar(ONE_R1, radians[0]);
        Phase(ONE_CMPLX, phaseFac, log2(mask));
        return;
    }

    // Kernel integer arguments: [0] amplitude count, [1] qubit mask, [2] number of distinct phase steps.
    const bitCapIntOcl bciArgs[BCI_ARG_LEN]{ maxQPowerOcl, oclMask, nPhases, 0U, 0U, 0U, 0U, 0U, 0U, 0U };
    PoolItemPtr poolItem = GetFreePoolItem();

    {
        EventVecPtr waitVec = ResetWaitEvents();

        cl::Event writeIntArgsEvent;
        DISPATCH_TEMP_WRITE(waitVec, *(poolItem->ulongBuffer), sizeof(bitCapIntOcl) * 3, bciArgs, writeIntArgsEvent);

        cl::Event writeRealArgsEvent;
        DISPATCH_LOC_WRITE(*(poolItem->realBuffer), sizeof(real1), radians, writeRealArgsEvent);

        // Both source arrays are locals, so wait for the writes to finish before they can go out of scope.
        writeIntArgsEvent.wait();
        writeRealArgsEvent.wait();
        wait_refs.clear();
    }

    const size_t ngc = FixWorkItemCount(bciArgs[0], nrmGroupCount);
    const size_t ngs = FixGroupSize(ngc, nrmGroupSize);
    QueueCall(OCL_API_PHASE_MASK, ngc, ngs, { stateBuffer, poolItem->ulongBuffer, poolItem->realBuffer });
}

void QEngineOCL::Apply2x2(bitCapIntOcl offset1, bitCapIntOcl offset2, const complex* mtrx, bitLenInt bitCount,
const bitCapIntOcl* qPowersSorted, bool doCalcNorm, SPECIAL_2X2 special, real1_f norm_thresh)
{
Expand Down
60 changes: 60 additions & 0 deletions src/qengine/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,66 @@ void QEngineCPU::PhaseParity(real1_f radians, bitCapInt mask)
});
}

/**
 * Masked PhaseRootN gate (CPU state vector implementation).
 *
 * Multiplies the amplitude of every basis state by exp(-i * PI * k / 2^(n-1)), where k is the number of qubits
 * selected by "mask" that are set in that basis state. Shortcuts: n == 0 is the identity, n == 1 is "ZMask()",
 * a single-qubit mask is a plain "Phase()" gate, and sparse state vectors use the generic per-qubit
 * QInterface::PhaseRootNMask.
 *
 * @param n    Root order
 * @param mask Bit mask of the qubits to act on; an empty mask is a no-op
 * @throws std::invalid_argument if "mask" addresses qubits outside the register, or if 2^n does not fit in
 *         bitCapIntOcl
 */
void QEngineCPU::PhaseRootNMask(bitLenInt n, bitCapInt mask)
{
    if (bi_compare(mask, maxQPower) >= 0) {
        throw std::invalid_argument("QEngineCPU::PhaseRootNMask mask out-of-bounds!");
    }
    // 2^n has to fit in a bitCapIntOcl, so "n" is bounded by that type's width in BITS. (sizeof() alone is a byte
    // count, which would wrongly reject perfectly representable powers such as n == 9.)
    if (n >= (8U * sizeof(bitCapIntOcl))) {
        throw std::invalid_argument("QEngineCPU::PhaseRootNMask: power of 2 out-of-bounds");
    }

    CHECK_ZERO_SKIP();

    if ((bi_compare_0(mask) == 0) || (n == 0)) {
        // Nothing selected, or a zero rotation: identity.
        return;
    }
    if (n == 1) {
        // A phase of exp(-i * PI) = -1 per masked qubit is exactly a masked Z gate.
        ZMask(mask);
        return;
    }

    const real1_f radians = -PI_R1 / pow2Ocl(n - 1U);

    if (isPowerOfTwo(mask)) {
        // Exactly one qubit is selected: this is an ordinary single-qubit phase gate.
        // The cast keeps both std::polar arguments the same type even if real1_f differs from real1.
        const complex phaseFac = std::polar(ONE_R1, (real1)radians);
        Phase(ONE_CMPLX, phaseFac, log2(mask));
        return;
    }

    if (stateVec->is_sparse()) {
        // Fall back to the generic qubit-by-qubit implementation.
        QInterface::PhaseRootNMask(n, mask);
        return;
    }

    Dispatch(maxQPowerOcl, [this, n, mask, radians] {
        const bitCapIntOcl maskOcl = (bitCapIntOcl)mask;
        const bitCapIntOcl nPhases = pow2Ocl(n);
        ParallelFunc fn = [&](const bitCapIntOcl& lcv, const unsigned& cpu) {
            // Count the masked qubits that are set in this basis state.
            bitCapIntOcl popCount = 0;
            {
                bitCapIntOcl v = lcv & maskOcl;
                while (v) {
                    popCount += v & 1;
                    v >>= 1;
                }
            }

            // The phase is periodic with period 2^n in the bit count; zero steps is the identity.
            const bitCapIntOcl nPhaseSteps = popCount % nPhases;
            if (nPhaseSteps != 0) {
                const complex phaseFac = std::polar(ONE_R1, (real1)(radians * nPhaseSteps));
                stateVec->write(lcv, phaseFac * stateVec->read(lcv));
            }
        };

        par_for(0U, maxQPowerOcl, fn);
    });
}

void QEngineCPU::UniformlyControlledSingleBit(const std::vector<bitLenInt>& controls, bitLenInt qubitIndex,
const complex* mtrxs, const std::vector<bitCapInt>& mtrxSkipPowers, bitCapInt mtrxSkipValueMask)
{
Expand Down
10 changes: 10 additions & 0 deletions src/qinterface/gates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@ void QInterface::ZMask(bitCapInt mask)
}
}

/// Generic masked PhaseRootN: apply "PhaseRootN(n, ...)" to each qubit whose bit is set in "mask", lowest bit first.
void QInterface::PhaseRootNMask(bitLenInt n, bitCapInt mask)
{
    while (bi_compare_0(mask) != 0) {
        // "x & (x - 1)" clears the lowest set bit; XOR with the previous value isolates that bit.
        bitCapInt withoutLowest = mask & (mask - ONE_BCI);
        PhaseRootN(n, log2(mask ^ withoutLowest));
        mask = withoutLowest;
    }
}

void QInterface::Swap(bitLenInt q1, bitLenInt q2)
{
if (q1 == q2) {
Expand Down
24 changes: 24 additions & 0 deletions test/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,30 @@ TEST_CASE_METHOD(QInterfaceTestFixture, "test_zmask")
REQUIRE_THAT(qftReg, HasProbability(0, 20, 0x80001));
}

// Applies PhaseRootNMask with several masks to one basis state and checks the phase the amplitude picks up each time.
TEST_CASE_METHOD(QInterfaceTestFixture, "test_phaserootnmask")
{
    constexpr uint16_t n = 3;
    const uint16_t modulus = pow2Ocl(n);

    constexpr BIG_INTEGER_WORD ket = 14062;
    constexpr BIG_INTEGER_WORD masks[6] = { 8, 3097, 22225, 16051, 62894, 49134 };
    // phaseCounts[ii] = popcount(ket & masks[ii])
    constexpr uint16_t phaseCounts[6] = { 1, 2, 5, 7, 8, 10 };

    qftReg->SetPermutation(ket);
    REQUIRE_THAT(qftReg, HasProbability(0, 20, ket));

    for (int ii = 0; ii < 6; ii++) {
        // Expected relative phase for this mask: one step of -PI / 2^(n-1) per shared set bit, modulo 2^n steps.
        const real1_f angle = -PI_R1 * (phaseCounts[ii] % modulus) / pow2Ocl(n - 1U);
        const complex expectedPhaseFactor = std::polar(ONE_R1, angle);

        // The gates accumulate, so compare against the amplitude just before this application.
        const complex amp_before = qftReg->GetAmplitude(ket);
        qftReg->PhaseRootNMask(n, masks[ii]);
        const complex amp_after = qftReg->GetAmplitude(ket);

        REQUIRE_CMPLX(amp_after / amp_before, expectedPhaseFactor);
    }
}

TEST_CASE_METHOD(QInterfaceTestFixture, "test_approxcompare")
{
qftReg =
Expand Down

0 comments on commit 90af46e

Please sign in to comment.