Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PhaseRootNMask to QInterface, implement for QEngineCPU #1009

Merged
merged 3 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/common/oclapi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ enum OCLAPI {
OCL_API_Z_SINGLE,
OCL_API_Z_SINGLE_WIDE,
OCL_API_PHASE_PARITY,
OCL_API_PHASE_MASK,
OCL_API_ROL,
OCL_API_APPROXCOMPARE,
OCL_API_NORMALIZE,
Expand Down
1 change: 1 addition & 0 deletions include/qengine_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class QEngineCPU : public QEngine {

void XMask(bitCapInt mask);
void PhaseParity(real1_f radians, bitCapInt mask);
void PhaseRootNMask(bitLenInt n, bitCapInt mask);

/**
* \defgroup ArithGate Arithmetic and other opcode-like gate implemenations.
Expand Down
1 change: 1 addition & 0 deletions include/qengine_opencl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ class QEngineOCL : public QEngine {
void Phase(complex topLeft, complex bottomRight, bitLenInt qubitIndex);
void XMask(bitCapInt mask);
void PhaseParity(real1_f radians, bitCapInt mask);
void PhaseRootNMask(bitLenInt n, bitCapInt mask);

using QEngine::Compose;
bitLenInt Compose(QEngineOCLPtr toCopy);
Expand Down
7 changes: 7 additions & 0 deletions include/qinterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,13 @@ class QInterface : public ParallelFor {
Phase(ONE_CMPLX, pow(-ONE_CMPLX, (real1)(ONE_R1 / pow2Ocl(n - 1U))), qubit);
}

/**
* Masked PhaseRootN gate
*
* Applies a 1/(2^N) phase rotation to all qubits in the mask.
*/
virtual void PhaseRootNMask(bitLenInt n, bitCapInt mask);

/**
* Inverse "PhaseRootN" gate
*
Expand Down
1 change: 1 addition & 0 deletions src/common/oclengine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ const std::vector<OCLKernelHandle> OCLEngine::kernelHandles{
OCLKernelHandle(OCL_API_Z_SINGLE, "zsingle"),
OCLKernelHandle(OCL_API_Z_SINGLE_WIDE, "zsinglewide"),
OCLKernelHandle(OCL_API_PHASE_PARITY, "phaseparity"),
OCLKernelHandle(OCL_API_PHASE_MASK, "phasemask"),
OCLKernelHandle(OCL_API_COMPOSE, "compose"),
OCLKernelHandle(OCL_API_COMPOSE_WIDE, "compose"),
OCLKernelHandle(OCL_API_COMPOSE_MID, "composemid"),
Expand Down
28 changes: 28 additions & 0 deletions src/common/qengine.cl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ inline cmplx conj(const cmplx cmp)
return (cmplx)(cmp.x, -cmp.y);
}

inline cmplx polar_unit(real1 theta) {
return (cmplx)(cos(theta), sin(theta));
}

#define OFFSET2_ARG bitCapIntOclPtr[0]
#define OFFSET1_ARG bitCapIntOclPtr[1]
#define MAXI_ARG bitCapIntOclPtr[2]
Expand Down Expand Up @@ -310,6 +314,30 @@ void kernel phaseparity(global cmplx* stateVec, constant bitCapIntOcl* bitCapInt
}
}

void kernel phasemask(global cmplx* stateVec, constant bitCapIntOcl* bitCapIntOclPtr, constant real1* angle)
{
const bitCapIntOcl Nthreads = get_global_size(0);
const bitCapIntOcl maxI = bitCapIntOclPtr[0];
const bitCapIntOcl mask = bitCapIntOclPtr[1];
const bitCapIntOcl nPhases = bitCapIntOclPtr[2];
const real1 phaseAngle = angle[0];

for (bitCapIntOcl lcv = ID; lcv < maxI; lcv += Nthreads) {
bitCapIntOcl popCount = 0;
bitCapIntOcl v = lcv & mask;
while (v) {
popCount += v & 1;
v >>= 1;
}

const bitCapIntOcl nPhaseSteps = popCount % nPhases;
if (nPhaseSteps != 0) {
const cmplx phaseFactor = polar_unit(nPhaseSteps * phaseAngle);
stateVec[lcv] = zmul(phaseFactor, stateVec[lcv]);
}
}
}

void kernel zsingle(global cmplx* stateVec, constant bitCapIntOcl* bitCapIntOclPtr)
{
const bitCapIntOcl Nthreads = get_global_size(0);
Expand Down
44 changes: 44 additions & 0 deletions src/qengine/opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,50 @@ void QEngineOCL::PhaseParity(real1_f radians, bitCapInt mask)
BitMask((bitCapIntOcl)mask, OCL_API_PHASE_PARITY, radians);
}

void QEngineOCL::PhaseRootNMask(bitLenInt n, bitCapInt mask)
{
if (bi_compare_0(mask) == 0) {
return;
}

const bitCapIntOcl oclMask = (bitCapIntOcl)mask;
if (oclMask >= maxQPowerOcl) {
throw std::invalid_argument("QEngineOCL::BitMask mask out-of-bounds!");
}

CHECK_ZERO_SKIP();

const bitCapIntOcl nPhases = pow2Ocl(n);
const real1_f radians[1] = { -PI_R1 / pow2Ocl(n - 1U) };

if (isPowerOfTwo(mask)) {
const complex phaseFac = std::polar(ONE_R1, radians[0]);
Phase(ONE_CMPLX, phaseFac, log2(mask));
return;
}

const bitCapIntOcl bciArgs[BCI_ARG_LEN]{ maxQPowerOcl, oclMask, nPhases, 0U, 0U, 0U, 0U, 0U, 0U, 0U };
PoolItemPtr poolItem = GetFreePoolItem();

{
EventVecPtr waitVec = ResetWaitEvents();

cl::Event writeIntArgsEvent;
DISPATCH_TEMP_WRITE(waitVec, *(poolItem->ulongBuffer), sizeof(bitCapIntOcl) * 3, bciArgs, writeIntArgsEvent);

cl::Event writeRealArgsEvent;
DISPATCH_LOC_WRITE(*(poolItem->realBuffer), sizeof(real1), radians, writeRealArgsEvent);

writeIntArgsEvent.wait();
writeRealArgsEvent.wait();
wait_refs.clear();
}

const size_t ngc = FixWorkItemCount(bciArgs[0], nrmGroupCount);
const size_t ngs = FixGroupSize(ngc, nrmGroupSize);
QueueCall(OCL_API_PHASE_MASK, ngc, ngs, { stateBuffer, poolItem->ulongBuffer, poolItem->realBuffer });
}

void QEngineOCL::Apply2x2(bitCapIntOcl offset1, bitCapIntOcl offset2, const complex* mtrx, bitLenInt bitCount,
const bitCapIntOcl* qPowersSorted, bool doCalcNorm, SPECIAL_2X2 special, real1_f norm_thresh)
{
Expand Down
60 changes: 60 additions & 0 deletions src/qengine/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,66 @@ void QEngineCPU::PhaseParity(real1_f radians, bitCapInt mask)
});
}

void QEngineCPU::PhaseRootNMask(bitLenInt n, bitCapInt mask)
{
if (bi_compare(mask, maxQPower) >= 0) {
throw std::invalid_argument("QEngineCPU::PhaseRootNMask mask out-of-bounds!");
}
if (n > sizeof(bitCapIntOcl)) {
throw std::invalid_argument("QEngineCPU::PhaseRootNMask: power of 2 out-of-bounds");
}

CHECK_ZERO_SKIP();

if (bi_compare_0(mask) == 0) {
return;
}

if (n == 0) {
return;
}
if (n == 1) {
ZMask(mask);
return;
}

const real1_f radians = -PI_R1 / pow2Ocl(n - 1U);

if (isPowerOfTwo(mask)) {
const complex phaseFac = std::polar(ONE_R1, radians);
Phase(ONE_CMPLX, phaseFac, log2(mask));
return;
}

if (stateVec->is_sparse()) {
QInterface::PhaseRootNMask(n, mask);
return;
}

Dispatch(maxQPowerOcl, [this, n, mask, radians] {
const bitCapIntOcl maskOcl = (bitCapIntOcl)mask;
const bitCapIntOcl nPhases = pow2Ocl(n);
ParallelFunc fn = [&](const bitCapIntOcl& lcv, const unsigned& cpu) {
bitCapIntOcl popCount = 0;
{
bitCapIntOcl v = lcv & maskOcl;
while (v) {
popCount += v & 1;
v >>= 1;
}
}

const bitCapIntOcl nPhaseSteps = popCount % nPhases;
if (nPhaseSteps != 0) {
const complex phaseFac = std::polar(ONE_R1, radians * nPhaseSteps);
stateVec->write(lcv, phaseFac * stateVec->read(lcv));
}
};

par_for(0U, maxQPowerOcl, fn);
});
}

void QEngineCPU::UniformlyControlledSingleBit(const std::vector<bitLenInt>& controls, bitLenInt qubitIndex,
const complex* mtrxs, const std::vector<bitCapInt>& mtrxSkipPowers, bitCapInt mtrxSkipValueMask)
{
Expand Down
10 changes: 10 additions & 0 deletions src/qinterface/gates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@ void QInterface::ZMask(bitCapInt mask)
}
}

void QInterface::PhaseRootNMask(bitLenInt n, bitCapInt mask)
{
bitCapInt v = mask;
while (bi_compare_0(mask) != 0) {
v = v & (v - ONE_BCI);
PhaseRootN(n, log2(mask ^ v));
mask = v;
}
}

void QInterface::Swap(bitLenInt q1, bitLenInt q2)
{
if (q1 == q2) {
Expand Down
24 changes: 24 additions & 0 deletions test/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,30 @@ TEST_CASE_METHOD(QInterfaceTestFixture, "test_zmask")
REQUIRE_THAT(qftReg, HasProbability(0, 20, 0x80001));
}

TEST_CASE_METHOD(QInterfaceTestFixture, "test_phaserootnmask")
{
constexpr BIG_INTEGER_WORD ket = 14062;
constexpr BIG_INTEGER_WORD masks[6] = { 8, 3097, 22225, 16051, 62894, 49134 };
constexpr uint16_t n = 3;
const uint16_t modulus = pow2Ocl(n);
// phaseCounts[ii] = popcount(ket & masks[ii])
constexpr uint16_t phaseCounts[6] = { 1, 2, 5, 7, 8, 10 };

qftReg->SetPermutation(ket);
REQUIRE_THAT(qftReg, HasProbability(0, 20, ket));

for (int ii = 0; ii < 6; ii++) {
const real1_f angle = -PI_R1 * (phaseCounts[ii] % modulus) / pow2Ocl(n - 1U);
const complex expectedPhaseFactor = std::polar(ONE_R1, angle);
const complex amp_before = qftReg->GetAmplitude(ket);

qftReg->PhaseRootNMask(n, masks[ii]);
const complex amp_after = qftReg->GetAmplitude(ket);

REQUIRE_CMPLX(amp_after / amp_before, expectedPhaseFactor);
}
}

TEST_CASE_METHOD(QInterfaceTestFixture, "test_approxcompare")
{
qftReg =
Expand Down
Loading