Skip to content

Commit

Permalink
V1.1.0 contrib bug fixes (homenc#383)
Browse files Browse the repository at this point in the history
HElib 1.1.0 Beta 0, August 2020
===============================
(tagged as v1.1.0-beta.0)

* Throw `LogicError` on logic failure
* Fix PAlgebra equality operator
* Initialize `cost` for UpperMemoEntry and LowerMemoEntry to 0
* Change `found`  to bool in function makeMask
* `unsigned long` cannot be guaranteed to be `size_t`
* Typo in comments

Co-authored-by: Dmitry Tsarevich <dimhotepus@gmail.com>
Co-authored-by: Jack L Crawford <Jack.Crawford@ibm.com>
  • Loading branch information
3 people authored and GitHub Enterprise committed Aug 17, 2020
1 parent a3be068 commit 6b7a3e6
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 54 deletions.
82 changes: 44 additions & 38 deletions examples/BGV_country_db_lookup/BGV_country_db_lookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
// It implements a very simple homomorphic encryption based
// db search algorithm for demonstration purposes.

// This country lookup example is derived from the BGV database demo
// code originally written by Jack Crawford for a lunch and learn
// This country lookup example is derived from the BGV database demo
// code originally written by Jack Crawford for a lunch and learn
// session at IBM Research (Hursley) in 2019.
// The original example code ships with HElib and can be found at
// https://github.com/homenc/HElib/tree/master/examples/BGV_database_lookup
Expand All @@ -40,26 +40,29 @@ void printPoly(NTL::ZZX& poly)
}

// Utility function to read <K,V> CSV data from file
std::vector<std::pair<std::string, std::string>> read_csv(std::string filename) {
std::vector<std::pair<std::string, std::string>> read_csv(std::string filename)
{
std::vector<std::pair<std::string, std::string>> dataset;
std::ifstream data_file(filename);

if(!data_file.is_open()) throw std::runtime_error("Error: This example failed trying to open the data file: " + filename + "\n Please check this file exists and try again.");
if (!data_file.is_open())
throw std::runtime_error(
"Error: This example failed trying to open the data file: " + filename +
"\n Please check this file exists and try again.");

std::vector<std::string> row;
std::string line, entry, temp;
std::vector<std::string> row;
std::string line, entry, temp;

if(data_file.good())
{
if (data_file.good()) {
// Read each line of file
while(std::getline(data_file, line)){
while (std::getline(data_file, line)) {
row.clear();
std::stringstream ss(line);
while (getline(ss, entry, ',')) {
row.push_back(entry);
}
// Add key value pairs to dataset
dataset.push_back(std::make_pair(row[0],row[1]));
dataset.push_back(std::make_pair(row[0], row[1]));
}
}

Expand Down Expand Up @@ -88,7 +91,7 @@ int main(int argc, char* argv[])
// Size of NTL thread pool (default =1)
unsigned long nthreads = 1;
// input database file name
std::string db_filename="./countries_dataset.csv";
std::string db_filename = "./countries_dataset.csv";
// debug output (default no debug output)
bool debug = false;

Expand All @@ -99,7 +102,9 @@ int main(int argc, char* argv[])
amap.arg("bits", bits, "# of bits in the modulus chain");
amap.arg("c", c, "# fo columns of Key-Switching matrix");
amap.arg("nthreads", nthreads, "Size of NTL thread pool");
amap.arg("db_filename", db_filename, "Qualified name for the database filename");
amap.arg("db_filename",
db_filename,
"Qualified name for the database filename");
amap.toggle().arg("-debug", debug, "Toggle debug output", "");
amap.parse(argc, argv);

Expand All @@ -125,15 +130,14 @@ int main(int argc, char* argv[])
HELIB_NTIMER_START(timer_Context);
helib::Context context(m, p, r);
HELIB_NTIMER_STOP(timer_Context);

// Modify the context, adding primes to the modulus chain
// This defines the ciphertext space
std::cout << "\nBuilding modulus chain ... ";
HELIB_NTIMER_START(timer_CHAIN);
helib::buildModChain(context, bits, c);
HELIB_NTIMER_STOP(timer_CHAIN);


// Secret key management
std::cout << "\nCreating Secret Key ...";
HELIB_NTIMER_START(timer_SecKey);
Expand All @@ -154,27 +158,28 @@ int main(int argc, char* argv[])
HELIB_NTIMER_START(timer_PubKey);
const helib::PubKey& public_key = secret_key;
HELIB_NTIMER_STOP(timer_PubKey);

// Get the EncryptedArray of the context
const helib::EncryptedArray& ea = *(context.ea);

// Print the context
std::cout << std::endl;
if (debug)
context.zMStar.printout();
context.zMStar.printout();

// Print the security level
// Note: This will be negligible to improve performance time.
std::cout << "\n***Security Level: " << context.securityLevel()
<< " *** Negligible for this example ***" << std::endl;

// Get the number of slot (phi(m))
long nslots = ea.size();
std::cout << "\nNumber of slots: " << nslots << std::endl;

/************ Read in the database ************/
std::vector<std::pair<std::string, std::string>> country_db = read_csv(db_filename);

std::vector<std::pair<std::string, std::string>> country_db =
read_csv(db_filename);

// Convert strings into numerical vectors
std::cout << "\n---Initializing the encrypted key,value pair database ("
<< country_db.size() << " entries)...";
Expand All @@ -187,28 +192,27 @@ int main(int argc, char* argv[])
std::vector<std::pair<helib::Ptxt<helib::BGV>, helib::Ptxt<helib::BGV>>>
country_db_ptxt;
for (const auto& country_capital_pair : country_db) {
if (debug) {
if (debug) {
std::cout << "\t\tname_addr_pair.first size = "
<< country_capital_pair.first.size() << " (" << country_capital_pair.first
<< ")"
<< country_capital_pair.first.size() << " ("
<< country_capital_pair.first << ")"
<< "\tname_addr_pair.second size = "
<< country_capital_pair.second.size() << " (" << country_capital_pair.second
<< ")" << std::endl;
}
<< country_capital_pair.second.size() << " ("
<< country_capital_pair.second << ")" << std::endl;
}

helib::Ptxt<helib::BGV> country(context);
// std::cout << "\tname size = " << country.size() << std::endl;
for (long i = 0; i < country_capital_pair.first.size(); ++i)
helib::Ptxt<helib::BGV> country(context);
// std::cout << "\tname size = " << country.size() << std::endl;
for (long i = 0; i < country_capital_pair.first.size(); ++i)
country.at(i) = country_capital_pair.first[i];

helib::Ptxt<helib::BGV> capital(context);
for (long i = 0; i < country_capital_pair.second.size(); ++i)
helib::Ptxt<helib::BGV> capital(context);
for (long i = 0; i < country_capital_pair.second.size(); ++i)
capital.at(i) = country_capital_pair.second[i];
country_db_ptxt.emplace_back(std::move(country), std::move(capital));
country_db_ptxt.emplace_back(std::move(country), std::move(capital));
}
HELIB_NTIMER_STOP(timer_PtxtCountryDB);


// Encrypt the Country DB
std::cout << "Encrypting the database..." << std::endl;
HELIB_NTIMER_START(timer_CtxtCountryDB);
Expand All @@ -218,9 +222,10 @@ int main(int argc, char* argv[])
helib::Ctxt encrypted_capital(public_key);
public_key.Encrypt(encrypted_country, country_capital_pair.first);
public_key.Encrypt(encrypted_capital, country_capital_pair.second);
encrypted_country_db.emplace_back(std::move(encrypted_country), std::move(encrypted_capital));
encrypted_country_db.emplace_back(std::move(encrypted_country),
std::move(encrypted_capital));
}

HELIB_NTIMER_STOP(timer_CtxtCountryDB);

// Print DB Creation Timers
Expand Down Expand Up @@ -260,8 +265,6 @@ int main(int argc, char* argv[])
public_key.Encrypt(query, query_ptxt);
HELIB_NTIMER_STOP(timer_EncryptQuery);



/************ Perform the database search ************/

HELIB_NTIMER_START(timer_QuerySearch);
Expand All @@ -287,7 +290,7 @@ int main(int argc, char* argv[])
// from using the STL and do not use std::accumulate
helib::Ctxt value = mask[0];
for (int i = 1; i < mask.size(); i++)
value += mask[i];
value += mask[i];

HELIB_NTIMER_STOP(timer_QuerySearch);

Expand All @@ -314,7 +317,10 @@ int main(int argc, char* argv[])
}

if (string_result.at(0) == 0x00) {
string_result = "Country name not in the database.\n*** Please make sure to enter the name of a European Country\n*** with the first letter in upper case.";
string_result =
"Country name not in the database."
"\n*** Please make sure to enter the name of a European Country"
"\n*** with the first letter in upper case.";
}
std::cout << "\nQuery result: " << string_result << std::endl;
helib::printNamedTimer(std::cout, "timer_TotalQuery");
Expand Down
6 changes: 4 additions & 2 deletions examples/BGV_packed_arithmetic/BGV_packed_arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ int main(int argc, char* argv[])
// "entry-wise".

// Square the ciphertext
// [0] [1] [2] [3] [4] ... [nslots-1] -> [0] [1] [4] [9] [16] ... [(nslots-1)*(nslots-1)]
// [0] [1] [2] [3] [4] ... [nslots-1]
// -> [0] [1] [4] [9] [16] ... [(nslots-1)*(nslots-1)]
ctxt.multiplyBy(ctxt);
// Plaintext version
ptxt.multiplyBy(ptxt);
Expand Down Expand Up @@ -158,7 +159,8 @@ int main(int argc, char* argv[])
ptxt.addConstant(NTL::ZZX(1l));

// And multiply by constants
// [1] [1] [1] ... [1] [1] -> [1*1] [1*1] [1*1] ... [1*1] [1*1] = [1] [1] [1] ... [1] [1]
// [1] [1] [1] ... [1] [1]
// -> [1*1] [1*1] [1*1] ... [1*1] [1*1] = [1] [1] [1] ... [1] [1]
ctxt *= NTL::ZZX(1l);
// Plaintext version
ptxt *= NTL::ZZX(1l);
Expand Down
2 changes: 1 addition & 1 deletion include/helib/PAlgebra.h
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ class PAlgebraMod

bool operator==(const PAlgebraMod& other) const
{
return getZMStar() == getZMStar() && getR() == other.getR();
return getZMStar() == other.getZMStar() && getR() == other.getR();
}
// comparison

Expand Down
2 changes: 1 addition & 1 deletion src/NumbTh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ NTL::ZZX Cyclotomic(long n)
// less than 2^27 in absolute value.
// So we compute the coefficients using 32-bit arithmetic.

// NOTE: _ntl_uint32 is either int or long.
// NOTE: _ntl_uint32 is either unsigned int or unsigned long.

NTL::Vec<_ntl_uint32> A;
A.SetLength(D + 1);
Expand Down
6 changes: 3 additions & 3 deletions src/OptimizePermutations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class BenesMemoEntry
solution = _solution;
}

BenesMemoEntry() {}
BenesMemoEntry() : cost(0) {}
};

typedef std::
Expand Down Expand Up @@ -510,7 +510,7 @@ class LowerMemoEntry
solution = _solution;
}

LowerMemoEntry() {}
LowerMemoEntry() : cost(0) {}
};

typedef std::
Expand Down Expand Up @@ -611,7 +611,7 @@ class UpperMemoEntry
solution = _solution;
}

UpperMemoEntry() {}
UpperMemoEntry() : cost(0) {}
};

typedef std::
Expand Down
2 changes: 1 addition & 1 deletion src/PermNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ static std::pair<long, bool> makeMask(std::vector<long>& mask,
NTL::Vec<long>& haystack,
long needle)
{
long found = false;
bool found = false;
long fstNonZeroIdx = -1;
for (long i = 0; i < (long)mask.size(); i++) {
if (haystack[i] == needle) { // found a needle
Expand Down
2 changes: 1 addition & 1 deletion src/Ptxt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ long Ptxt<Scheme>::coordToIndex(const std::vector<long>& coords)
{
const PAlgebra& zMStar = context->zMStar;
assertEq<LogicError>(coords.size(),
static_cast<unsigned long>(zMStar.numOfGens()),
static_cast<std::size_t>(zMStar.numOfGens()),
"Coord must have same size as hypercube structure");
long index = 0;
// Convert the coordinates into its corresponding index by computing the
Expand Down
14 changes: 7 additions & 7 deletions src/norms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ static double basic_embeddingLargestCoeff(const std::vector<double>& f,
long sz = f.size();

if (sz > m)
LogicError("vector too big f canonicalEmbedding");
throw LogicError("vector too big f canonicalEmbedding");

std::vector<cx_double> buf(m);
for (long i : range(0, sz))
Expand Down Expand Up @@ -172,7 +172,7 @@ static double half_embeddingLargestCoeff(const std::vector<double>& f,
long sz = f.size();

if (sz > m / 2)
LogicError("vector too big f canonicalEmbedding");
throw LogicError("vector too big f canonicalEmbedding");

const half_FFT& hfft = palg.getHalfFFTInfo();
const cx_double* pow = &hfft.pow[0];
Expand Down Expand Up @@ -209,7 +209,7 @@ static double quarter_embeddingLargestCoeff(const std::vector<double>& f,
long sz = f.size();

if (sz > m / 2)
LogicError("vector too big f canonicalEmbedding");
throw LogicError("vector too big f canonicalEmbedding");

const quarter_FFT& qfft = palg.getQuarterFFTInfo();
const cx_double* pow1 = &qfft.pow1[0];
Expand Down Expand Up @@ -275,7 +275,7 @@ static void basic_embeddingLargestCoeff_x2(double& norm1,
long sz2 = f2.size();

if (sz1 > m || sz2 > m)
LogicError("vector too big in canonicalEmbedding");
throw LogicError("vector too big in canonicalEmbedding");

long sz_max = std::max(sz1, sz2);
long sz_min = std::min(sz1, sz2);
Expand Down Expand Up @@ -348,7 +348,7 @@ static void half_embeddingLargestCoeff_x2(double& norm1,
long sz2 = f2.size();

if (sz1 > m / 2 || sz2 > m / 2)
LogicError("vector too big in canonicalEmbedding");
throw LogicError("vector too big in canonicalEmbedding");

long sz_max = std::max(sz1, sz2);
long sz_min = std::min(sz1, sz2);
Expand Down Expand Up @@ -583,7 +583,7 @@ void CKKS_embedInSlots(zzX& f,
long m = palg.getM();

if (!(palg.getP() == -1 && palg.getPow2() >= 2))
LogicError("bad args to CKKS_canonicalEmbedding");
throw LogicError("bad args to CKKS_canonicalEmbedding");

std::vector<cx_double> buf(m / 2, cx_double(0));
for (long i : range(m / 4)) {
Expand Down Expand Up @@ -645,7 +645,7 @@ canonicalEmbedding(std::vector<cx_double>& v,
long m = palg.getM();

if (long(in.size()) > m)
LogicError("std::vector too big in canonicalEmbedding");
throw LogicError("std::vector too big in canonicalEmbedding");

vector<cx_double> buf(m);
for (long i: range(in.size())) buf[i] = in[i];
Expand Down

0 comments on commit 6b7a3e6

Please sign in to comment.