Skip to content

Commit

Permalink
optimization + unit test done
Browse files Browse the repository at this point in the history
  • Loading branch information
nitrieu committed May 7, 2018
1 parent 0511c97 commit b4d2942
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 135 deletions.
29 changes: 20 additions & 9 deletions Tests/PSU_Tests.cpp
Expand Up @@ -532,7 +532,7 @@ namespace tests_libOTe
void PSU_Test_Impl()
{
setThreadName("Sender");
u64 setSize = 1 << 7, psiSecParam = 40, numThreads(2);
u64 setSize = 1 << 10, psiSecParam = 40, numThreads(1);

PRNG prng0(_mm_set_epi32(4253465, 3434565, 234435, 23987045));
PRNG prng1(_mm_set_epi32(4253465, 3434565, 234435, 23987025));
Expand All @@ -542,12 +542,14 @@ namespace tests_libOTe
for (u64 i = 0; i < setSize; ++i)
{
sendSet[i] = prng0.get<block>();
recvSet[i] = prng0.get<block>();
recvSet[i] = sendSet[i];// prng0.get<block>();
}
sendSet[0] = recvSet[0];
sendSet[2] = recvSet[2];
std::cout << "intersection: " << sendSet[0] << "\n";
std::cout << "intersection: " << sendSet[2] << "\n";
//std::random_shuffle(sendSet.begin(), sendSet.begin(), prng0);

sendSet[0] = prng0.get<block>();
sendSet[1] = prng0.get<block>();
std::cout << "disjonted: " << sendSet[0] << " vs " << recvSet[0] << "\n";
std::cout << "disjonted: " << sendSet[1] << " vs " << recvSet[1]<< "\n";


// set up networking
Expand All @@ -567,16 +569,25 @@ namespace tests_libOTe
KrtwSender sender;
KrtwReceiver recv;
auto thrd = std::thread([&]() {
recv.init(40, prng1, recvSet, recvChls);
recv.init(setSize, setSize,40, prng1, recvChls);
recv.output(recvSet, recvChls);

});

sender.init(40, prng0, sendSet, sendChls);
sender.init(setSize, setSize, 40, prng0,sendChls);
sender.output(sendSet, sendChls);
thrd.join();

/*std::cout << "=======sender.simple.print(3);===========\n";
sender.simple.print();
std::cout << "=======recv.simple.print(3)===========\n";
recv.simple.print();*/

std::cout << "recv.mDisjointedOutput.size(): " << recv.mDisjointedOutput.size() << std::endl;
for (u64 i = 0; i < recv.mDisjointedOutput.size(); ++i)//thrds.size()
{
std::cout << "#id: " << recv.mDisjointedOutput[i] << std::endl;
}


for (u64 i = 0; i < numThreads; ++i)
Expand Down Expand Up @@ -694,7 +705,7 @@ namespace tests_libOTe
for (u64 i = 0; i < 10; ++i)
{
poly.evalPolynomial(coeffs, setX[i], y1);
if(memcmp((u8*)&y1, (u8*)&y,64/8))
if(!memcmp((u8*)&y1, (u8*)&y,64/8))
std::cout << y << "\t" << y1 << std::endl;

}
Expand Down
41 changes: 28 additions & 13 deletions frontend/main.cpp
Expand Up @@ -74,6 +74,9 @@ using namespace osuCrypto;
#include <thread>
#include <vector>

u64 disjontedSetSize = 10;
bool isTest=false;

template<typename ... Args>
std::string string_format(const std::string& format, Args ... args)
{
Expand Down Expand Up @@ -103,7 +106,7 @@ void Sender(u64 setSize, span<block> inputs, u64 numThreads)

gTimer.reset();
gTimer.setTimePoint("s start");
sender.init(40, prng0, inputs, sendChls);
sender.init(setSize, inputs.size(), 40, prng0, sendChls);
gTimer.setTimePoint("s offline");

/*std::cout << sender.mBaseOTSend[0][0] << "\t";
Expand Down Expand Up @@ -140,7 +143,7 @@ void Receiver(u64 setSize, span<block> inputs,u64 numThreads)
KrtwReceiver recv;
gTimer.reset();
gTimer.setTimePoint("r start");
recv.init(40, prng1, inputs, recvChls);
recv.init(setSize, inputs.size(), 40, prng1, recvChls);

/*std::cout << recv.mBaseOTRecv[0] << "\n";
std::cout << recv.mBaseOTSend[0][0] << "\t";
Expand All @@ -160,8 +163,13 @@ void Receiver(u64 setSize, span<block> inputs,u64 numThreads)
recvChls[g].resetStats();
}

std::cout << " Total Comm = " << string_format("%5.2f", (dataRecv + dataSent) / std::pow(2.0, 20)) << " MB\n";
std::cout << "Total Comm = " << string_format("%5.2f", (dataRecv + dataSent) / std::pow(2.0, 20)) << " MB\n";

if (isTest)
{
std::cout << "recv.mDisjointedOutput.size(): " << recv.mDisjointedOutput.size() << std::endl;
std::cout << "expectedDisjontedSetSize: " << disjontedSetSize << std::endl;
}

for (u64 i = 0; i < numThreads; ++i)
recvChls[i].close();
Expand Down Expand Up @@ -207,7 +215,7 @@ void PSU_Test_Impl()
KrtwSender sender;
KrtwReceiver recv;
auto thrd = std::thread([&]() {
recv.init(40, prng1, recvSet, recvChls);
recv.init(setSize, recvSet.size(), 40, prng1, recvChls);

std::cout << recv.mBaseOTRecv[0] << "\n";

Expand All @@ -218,7 +226,7 @@ void PSU_Test_Impl()

});

sender.init(40, prng0, sendSet, sendChls);
sender.init(setSize, sendSet.size(), 40, prng0, sendChls);


std::cout << sender.mBaseOTSend[0][0] << "\t";
Expand Down Expand Up @@ -254,7 +262,8 @@ void usage(const char* argv0)

int main(int argc, char** argv)
{
u64 setSize = 1 << 16, numThreads = 2;

u64 setSize = 1 << 12, numThreads = 1;


if (argv[3][0] == '-' && argv[3][1] == 'n'
Expand All @@ -264,6 +273,7 @@ int main(int argc, char** argv)
numThreads = atoi(argv[6]);
}


std::cout << "SetSize: " << setSize << " vs " << setSize << " | numThreads: " << numThreads << "\t";

PRNG prng0(_mm_set_epi32(4253465, 3434565, 234435, 23987045));
Expand All @@ -273,15 +283,17 @@ int main(int argc, char** argv)
for (u64 i = 0; i < setSize; ++i)
{
sendSet[i] = prng0.get<block>();
recvSet[i] = prng0.get<block>();
recvSet[i] = sendSet[i];
}
sendSet[0] = recvSet[0];
sendSet[2] = recvSet[2];
/*std::cout << "intersection: " << sendSet[0] << "\n";
std::cout << "intersection: " << sendSet[2] << "\n";*/


for (u64 i = 0; i < disjontedSetSize; ++i)
sendSet[i] = prng0.get<block>();

//std::random_shuffle(sendSet.begin(), sendSet.begin(), prng0);


#if 0
isTest = true;
std::thread thrd = std::thread([&]() {
Sender(setSize, sendSet, numThreads);
});
Expand All @@ -294,6 +306,7 @@ int main(int argc, char** argv)

if (argv[1][0] == '-' && argv[1][1] == 't') {

isTest = true;
std::thread thrd = std::thread([&]() {
Sender(setSize,sendSet, numThreads);
});
Expand All @@ -304,10 +317,12 @@ int main(int argc, char** argv)

}
else if (argv[1][0] == '-' && argv[1][1] == 'r' && atoi(argv[2]) == 0) {

Sender(setSize, sendSet, numThreads);
}
else if (argv[1][0] == '-' && argv[1][1] == 'r' && atoi(argv[2]) == 1) {
Receiver(setSize, sendSet, numThreads);

Receiver(setSize, recvSet, numThreads);
}
else {
usage(argv[0]);
Expand Down
93 changes: 60 additions & 33 deletions libPSU/PSU/KrtwReceiver.cpp
Expand Up @@ -6,13 +6,12 @@
#include <cryptoTools/Crypto/Commit.h>
#include <cryptoTools/Network/Channel.h>
#include "libPSU/PsuDefines.h"
#include "Tools/SimpleIndex.h"

using namespace std;

namespace osuCrypto
{
void KrtwReceiver::init(u64 psiSecParam, PRNG & prng, span<block> inputs, span<Channel> chls)
void KrtwReceiver::init(u64 myInputSize, u64 theirInputSize, u64 psiSecParam, PRNG & prng, span<Channel> chls)
{
mPsiSecParam = psiSecParam;
mPrng.SetSeed(prng.get<block>());
Expand All @@ -34,36 +33,44 @@ namespace osuCrypto
baseOTs.send(mBaseOTSend, mPrng, chls[0], 1);
recvOprf.setBaseOts(mBaseOTSend);


simple.init(myInputSize);
Ss.resize(simple.mNumBins);

theirMaxBinSize = simple.mMaxBinSize; //assume same set size, sender has mMaxBinSize, receiver has mMaxBinSize+1

polyMaskBytes = (mPsiSecParam + log2(pow(simple.mMaxBinSize, 2)*simple.mNumBins) + 7) / 8;

for (u64 i = 0; i < simple.mNumBins; i++)
{
Ss[i].resize(theirMaxBinSize);
for (u64 j = 0; j < theirMaxBinSize; j++)
{
Ss[i][j] = ZeroBlock;
block tem = mPrng.get<block>();
memcpy((u8*)&Ss[i][j],(u8*)&tem, polyMaskBytes);
}
}

/* for (u64 j = 0; j < 6; j++)
std::cout << "Ss [3][" << j << "]: " << Ss[3][j] << "\n";
*/



}
void KrtwReceiver::output(span<block> inputs, span<Channel> chls)
{
u64 numThreads(chls.size());
const bool isMultiThreaded = numThreads > 1;

// std::cout << "Receiver: numThreads" << numThreads << "\n";

std::mutex mtx;

SimpleIndex simple;
simple.init(inputs.size(),true);
//simple.print();

//std::cout << "Receiver: " << simple.mMaxBinSize << "\t " <<simple.mNumBins<< std::endl ;

u64 theirMaxBinSize = simple.mMaxBinSize - 1; //assume same set size, sender has mMaxBinSize, receiver has mMaxBinSize+1
u64 numOTs = simple.mNumBins*(theirMaxBinSize);

std::vector<block> coeffs;

std::vector<std::vector<block>> Ss(simple.mNumBins);
for (u64 i = 0; i < simple.mNumBins; i++)
{
Ss[i].resize(theirMaxBinSize);
for (u64 j = 0; j < theirMaxBinSize; j++)
Ss[i][j] = mPrng.get<block>();
}

#ifdef DEBUG
std::cout << IoStream::lock << "mBins[1].items[1] " << simple.mBins[1].items[1] << std::endl << IoStream::unlock;
Expand Down Expand Up @@ -96,7 +103,6 @@ namespace osuCrypto
std::cout << IoStream::lock << recvOTMsg[0] << std::endl << IoStream::unlock;
#endif
//poly
u64 polyMaskBytes = (mPsiSecParam + log2(pow(simple.mMaxBinSize + 1,2)*simple.mNumBins) + 7) / 8;

auto routine = [&](u64 t)
{
Expand All @@ -122,7 +128,9 @@ namespace osuCrypto

sendOprf.recvCorrection(chl, curStepSize*theirMaxBinSize); //OPRF

std::vector<u8> sendBuff(curStepSize*theirMaxBinSize*(simple.mMaxBinSize + 1)*polyMaskBytes);
std::vector<u8> sendBuff(curStepSize*theirMaxBinSize*(simple.mMaxBinSize)*polyMaskBytes);

u64 iterSend = 0;

//==========================PMT==========================
for (u64 k = 0; k < curStepSize; ++k)
Expand All @@ -133,28 +141,40 @@ namespace osuCrypto
{

//std::vector<block> setY(simple.mBins[binIdx].mBinRealSizes, Ss[binIdx][itemTheirIdx]);
std::vector<block>sendEncoding(simple.mBins[binIdx].mBinRealSizes);
std::vector<block>sendEncoding(simple.mBins[binIdx].mBinRealSizes+1);
u64 idxBot = simple.mBins[binIdx].mBinRealSizes;

for (u64 itemIdx = 0; itemIdx < simple.mBins[binIdx].mBinRealSizes; ++itemIdx) //compute many F(k,xi)
for (u64 itemIdx = 0; itemIdx < simple.mBins[binIdx].mBinRealSizes; itemIdx++) //compute many F(k,xi)
{
sendOprf.encode(binIdx*theirMaxBinSize + itemTheirIdx
, &simple.mBins[binIdx].items[itemIdx], (u8*)&sendEncoding[itemIdx], sizeof(block));
}


//############ Global Item \bot ####################
sendOprf.encode(binIdx*theirMaxBinSize + itemTheirIdx
, &AllOneBlock, (u8*)&sendEncoding[idxBot], sizeof(block));

//setY.emplace_back(mPrng.get<block>()); //add randome point
//sendEncoding.emplace_back(mPrng.get<block>());

//poly
#ifdef _MSC_VER
std::cout << IoStream::lock;
poly.getBlkCoefficients(simple.mMaxBinSize + 1, sendEncoding, Ss[binIdx][itemTheirIdx], coeffs);
poly.getBlkCoefficients(simple.mMaxBinSize-1, sendEncoding, Ss[binIdx][itemTheirIdx], coeffs);
std::cout << IoStream::unlock;
/*coeffs.resize(simple.mMaxBinSize);
for (u64 c = 0; c < coeffs.size(); ++c)
coeffs[c] = ZeroBlock;*/

#else
poly.getBlkCoefficients(simple.mMaxBinSize + 1, sendEncoding, Ss[binIdx][itemTheirIdx], coeffs);
poly.getBlkCoefficients(simple.mMaxBinSize-1, sendEncoding, Ss[binIdx][itemTheirIdx], coeffs);

#endif


for (u64 c = 0; c < coeffs.size(); ++c)
memcpy(sendBuff.data() + (k*itemTheirIdx*(simple.mMaxBinSize + 1) + c)* polyMaskBytes, (u8*)&coeffs[c], polyMaskBytes);
{
memcpy(sendBuff.data() + iterSend, (u8*)&coeffs[c], polyMaskBytes);
iterSend += polyMaskBytes;
}

//std::cout << IoStream::lock <<"r "<< binIdx << "\t" << itemTheirIdx << std::endl << IoStream::unlock;
}
Expand All @@ -164,6 +184,7 @@ namespace osuCrypto


//==========================PEQT==========================
#if 1
std::vector<block> recvEncoding(curStepSize*theirMaxBinSize);

for (u64 k = 0; k < curStepSize; ++k)
Expand Down Expand Up @@ -273,24 +294,31 @@ namespace osuCrypto
block psuItem;
memcpy((u8*)&psuItem, recvBuff.data() + (k*theirMaxBinSize + itemTheirIdx)* maskOTlength, maskOTlength);

psuItem = psuItem + recvOTMsg[binIdx*theirMaxBinSize + itemTheirIdx];

psuItem = psuItem^recvOTMsg[binIdx*theirMaxBinSize + itemTheirIdx];

/* std::cout << "psuItem: " << psuItem << std::endl;
std::cout << "itemTheirIdx: " << itemTheirIdx << std::endl;
std::cout << "binIdx: " << binIdx << std::endl;
*/
if (isMultiThreaded)
{
std::lock_guard<std::mutex> lock(mtx);
PsuOutput.emplace_back(psuItem);
{
mDisjointedOutput.emplace_back(psuItem);
/*std::cout << "itemTheirIdx: " << itemTheirIdx << std::endl;
std::cout << "binIdx: " << binIdx << std::endl;*/
}
}
else
{
PsuOutput.emplace_back(psuItem);
mDisjointedOutput.emplace_back(psuItem);
}
}

}
}
#endif

#endif
}
};

Expand All @@ -306,7 +334,6 @@ namespace osuCrypto
for (auto& thrd : thrds)
thrd.join();

std::cout << "PsuOutput.size() " <<PsuOutput.size() << std::endl;

}
}

0 comments on commit b4d2942

Please sign in to comment.