Skip to content

Commit

Permalink
add LSH based KNN
Browse files Browse the repository at this point in the history
  • Loading branch information
Saurabh7 committed Aug 11, 2016
1 parent c45f8b0 commit 05cbf89
Show file tree
Hide file tree
Showing 33 changed files with 6,136 additions and 18 deletions.
21 changes: 21 additions & 0 deletions src/shogun/lib/external/falconn/LICENSE.txt
@@ -0,0 +1,21 @@
The MIT License (MIT)

Copyright (c) 2015 Alexandr Andoni, Piotr Indyk, Thijs Laarhoven, Ilya Razenshteyn, and Ludwig Schmidt

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
178 changes: 178 additions & 0 deletions src/shogun/lib/external/falconn/core/bit_packed_flat_hash_table.h
@@ -0,0 +1,178 @@
#ifndef __BIT_PACKED_FLAT_HASH_TABLE_H__
#define __BIT_PACKED_FLAT_HASH_TABLE_H__

#include <algorithm>
#include <vector>

#include "bit_packed_vector.h"
#include "hash_table_helpers.h"
#include "math_helpers.h"

namespace falconn {
namespace core {

/**
 * Exception type thrown by BitPackedFlatHashTable (and its Factory) when it is
 * given invalid sizes or keys, or when it is used out of order.
 */
class BitPackedFlatHashTableError : public HashTableError {
 public:
  /** @param msg error description, forwarded unchanged to HashTableError. */
  BitPackedFlatHashTableError(const char* msg)
      : HashTableError(msg) {
  }
};

template <typename KeyType, typename ValueType = int_fast64_t,
typename IndexType = int_fast64_t>
class BitPackedFlatHashTable {
public:
class Factory {
public:
Factory(IndexType num_buckets, ValueType num_items)
: num_buckets_(num_buckets), num_items_(num_items) {
if (num_buckets_ < 1) {
throw BitPackedFlatHashTableError(
"Number of buckets must be at least 1.");
}
if (num_items_ < 1) {
throw BitPackedFlatHashTableError(
"Number of items must be at least 1.");
}
}

BitPackedFlatHashTable<KeyType, ValueType, IndexType>* new_hash_table() {
return new BitPackedFlatHashTable<KeyType, ValueType, IndexType>(
num_buckets_, num_items_);
}

private:
IndexType num_buckets_ = 0;
ValueType num_items_ = 0;
};

class Iterator {
public:
Iterator() : index_(0), parent_(nullptr) {}

Iterator(ValueType index, const BitPackedFlatHashTable* parent)
: index_(index), parent_(parent) {}

ValueType operator*() const { return parent_->indices_.get(index_); }

bool operator!=(const Iterator& iter) const {
if (parent_ != iter.parent_) {
return false;
} else {
return index_ != iter.index_;
}
}

bool operator==(const Iterator& iter) const { return !(*this != iter); }

Iterator& operator++() {
index_ += 1;
return *this;
}

private:
ValueType index_;
const BitPackedFlatHashTable* parent_;
};

BitPackedFlatHashTable(IndexType num_buckets, ValueType num_items)
: num_buckets_(num_buckets),
num_items_(num_items),
bucket_start_(num_buckets, log2ceil(num_items)),
indices_(num_items, log2ceil(num_items)) {
// printf("num_buckets = %d num_items_ = %d\n", num_buckets_, num_items_);
if (num_buckets_ < 1) {
throw BitPackedFlatHashTableError(
"Number of buckets must be at least 1.");
}
if (num_items_ < 1) {
throw BitPackedFlatHashTableError("Number of items must be at least 1.");
}
}

void add_entries(const std::vector<KeyType>& keys) {
if (entries_added_) {
throw BitPackedFlatHashTableError("Entries were already added.");
}
entries_added_ = true;
if (static_cast<ValueType>(keys.size()) != num_items_) {
throw BitPackedFlatHashTableError(
"Incorrect number of items in add_entries.");
}

KeyComparator comp(keys);
std::vector<ValueType> tmp_indices(keys.size());
for (IndexType ii = 0; ii < static_cast<IndexType>(tmp_indices.size());
++ii) {
if (static_cast<IndexType>(keys[ii]) >= num_buckets_ || keys[ii] < 0) {
throw BitPackedFlatHashTableError("Key value out of range.");
}
tmp_indices[ii] = ii;
}
std::sort(tmp_indices.begin(), tmp_indices.end(), comp);

for (IndexType ii = 0; ii < static_cast<IndexType>(tmp_indices.size());
++ii) {
indices_.set(ii, tmp_indices[ii]);
}

IndexType cur_index = 0;
std::vector<bool> bucket_empty(num_buckets_, true);

while (cur_index < static_cast<IndexType>(tmp_indices.size())) {
IndexType end_index = cur_index;
do {
end_index += 1;
} while (end_index < static_cast<IndexType>(tmp_indices.size()) &&
keys[tmp_indices[cur_index]] == keys[tmp_indices[end_index]]);

bucket_start_.set(keys[tmp_indices[cur_index]], cur_index);
bucket_empty[keys[tmp_indices[cur_index]]] = false;
cur_index = end_index;
}

if (bucket_empty[num_buckets_ - 1]) {
bucket_start_.set(num_buckets_ - 1, num_items_);
}
for (IndexType ii = num_buckets_ - 2; ii >= 0; --ii) {
if (bucket_empty[ii]) {
bucket_start_.set(ii, bucket_start_.get(ii + 1));
}
}
}

std::pair<Iterator, Iterator> retrieve(const KeyType& key) {
ValueType start = bucket_start_.get(key);
ValueType end = num_items_;
if (static_cast<IndexType>(key) < num_buckets_ - 1) {
end = bucket_start_.get(key + 1);
}
// printf("retrieve for key %u\n", key);
// printf(" start: %lld end %lld\n", start, end);
return std::make_pair(Iterator(start, this), Iterator(end, this));
}

private:
IndexType num_buckets_ = 0;
ValueType num_items_ = 0;
bool entries_added_ = false;

// start of the respective hash bucket
BitPackedVector<ValueType> bucket_start_;
// point indices
BitPackedVector<ValueType> indices_;

class KeyComparator {
public:
KeyComparator(const std::vector<KeyType>& keys) : keys_(keys) {}

bool operator()(IndexType ii, IndexType jj) {
return keys_[ii] < keys_[jj];
}

const std::vector<KeyType>& keys_;
};
};

} // namespace core
} // namespace falconn

#endif
157 changes: 157 additions & 0 deletions src/shogun/lib/external/falconn/core/bit_packed_vector.h
@@ -0,0 +1,157 @@
#ifndef __BIT_PACKED_VECTOR_H__
#define __BIT_PACKED_VECTOR_H__

#include <cstdint>
#include <limits>
#include <vector>

#include "../falconn_global.h"

namespace falconn {
namespace core {

/**
 * Exception type thrown by BitPackedVector when it is constructed with
 * parameters it cannot represent.
 */
class BitPackedVectorError : public FalconnError {
 public:
  /** @param msg error description, forwarded unchanged to FalconnError. */
  BitPackedVectorError(const char* msg)
      : FalconnError(msg) {
  }
};

/**
 * Fixed-length array of small unsigned integers, each occupying exactly
 * `item_size` bits, packed back to back into words ("packages") of type
 * StorageType. An item may straddle the boundary between two consecutive
 * packages, but never more than two, because item_size is limited to the
 * package width.
 *
 * Only the low `item_size` bits of a value are stored; get() returns them
 * zero-extended.
 *
 * @tparam DataType    integral type of the values handed to set() / returned
 *                     by get().
 * @tparam StorageType word type of the backing array. The shift arithmetic
 *                     assumes an unsigned type.
 * @tparam IndexType   integral type used to address items.
 */
template <typename DataType, typename StorageType = uint64_t,
          typename IndexType = int_fast64_t>
class BitPackedVector {
 public:
  /**
   * Creates a vector of `num_items` items of `item_size` bits each, all
   * initialized to zero.
   *
   * @throws BitPackedVectorError if a size is negative, if item_size exceeds
   *         the bit width of DataType or of StorageType, if num_items does
   *         not fit into IndexType, or if the total number of bits does not
   *         fit into the internal 64-bit bit index.
   */
  BitPackedVector(int_fast64_t num_items, int_fast64_t item_size)
      : num_items_(num_items), item_size_(item_size) {
    if (num_items < 0) {
      throw BitPackedVectorError("Number of items must not be negative.");
    }
    if (item_size < 0) {
      throw BitPackedVectorError("Item size must not be negative.");
    }
    if (item_size > 8 * static_cast<int_fast64_t>(sizeof(DataType))) {
      throw BitPackedVectorError(
          "DataType too small for the number of bits "
          "specified.");
    }
    if (item_size > num_bits_per_package_) {
      throw BitPackedVectorError(
          "Currently the item size must be at most the "
          "data package size.");
    }
    if (num_items > std::numeric_limits<IndexType>::max()) {
      throw BitPackedVectorError(
          "IndexType too small for the vector size "
          "specified.");
    }
    // get()/set() compute index * item_size in 64-bit signed arithmetic.
    if (item_size > 0 &&
        num_items > std::numeric_limits<int_fast64_t>::max() / item_size) {
      throw BitPackedVectorError(
          "Total number of bits too large for the vector size "
          "specified.");
    }

    // Mask with the low item_size bits set. The full-width case is handled
    // separately to avoid a shift by the bit width of StorageType.
    if (item_size_ >= num_bits_per_package_) {
      item_mask_ = static_cast<StorageType>(~static_cast<StorageType>(0));
    } else {
      item_mask_ = static_cast<StorageType>(
          (static_cast<StorageType>(1) << item_size_) - 1);
    }

    const int_fast64_t total_bits = num_items_ * item_size_;
    num_data_packets_ = total_bits / num_bits_per_package_;
    if (total_bits % num_bits_per_package_ != 0) {
      num_data_packets_ += 1;
    }
    // resize() value-initializes the new elements, i.e. sets them to zero.
    data_.resize(num_data_packets_);
  }

  // For (potential) performance reasons, get() does no bounds checking.
  // Returns the item at position `index`, zero-extended to DataType.
  DataType get(IndexType index) const {
    if (item_size_ == 0) {
      // Zero-width items carry no information and occupy no storage.
      return static_cast<DataType>(0);
    }
    const int_fast64_t first_bit = index * item_size_;
    const int_fast64_t first_package = first_bit / num_bits_per_package_;
    const int_fast64_t offset_in_package =
        first_bit - first_package * num_bits_per_package_;
    // Number of bits of the first package at or above the item's first bit.
    const int_fast64_t bits_in_first =
        num_bits_per_package_ - offset_in_package;

    // Move the item's first bits to the beginning of the result.
    StorageType result =
        static_cast<StorageType>(data_[first_package] >> offset_in_package);
    if (bits_in_first < item_size_) {
      // The item continues in the next package. Here offset_in_package >= 1,
      // so the shift count is strictly smaller than the package width.
      result = static_cast<StorageType>(
          result | (data_[first_package + 1] << bits_in_first));
    }
    // Drop everything above the item (bits of the following items).
    return static_cast<DataType>(result & item_mask_);
  }

  // For (potential) performance reasons, set() does no bounds checking.
  // Stores the low item_size bits of `value` at position `index`; all other
  // items are left unchanged, even if `value` has higher bits set.
  void set(IndexType index, DataType value) {
    if (item_size_ == 0) {
      // Zero-width items carry no information and occupy no storage.
      return;
    }
    const int_fast64_t first_bit = index * item_size_;
    const int_fast64_t first_package = first_bit / num_bits_per_package_;
    const int_fast64_t offset_in_package =
        first_bit - first_package * num_bits_per_package_;
    const int_fast64_t bits_in_first =
        num_bits_per_package_ - offset_in_package;

    // Truncate to the item width so that neighbouring items cannot be
    // overwritten by an oversized or sign-extended value.
    const StorageType bits =
        static_cast<StorageType>(static_cast<StorageType>(value) & item_mask_);

    // Part of the item that lives in the first package. Bits shifted beyond
    // the package width are discarded by the conversion to StorageType.
    const StorageType low_mask =
        static_cast<StorageType>(item_mask_ << offset_in_package);
    const StorageType low_bits =
        static_cast<StorageType>(bits << offset_in_package);
    data_[first_package] = static_cast<StorageType>(
        (data_[first_package] & static_cast<StorageType>(~low_mask)) |
        low_bits);

    if (bits_in_first < item_size_) {
      // Remaining high bits go to the bottom of the next package. Here
      // offset_in_package >= 1, so bits_in_first is a valid shift count.
      const StorageType high_mask =
          static_cast<StorageType>(item_mask_ >> bits_in_first);
      const StorageType high_bits =
          static_cast<StorageType>(bits >> bits_in_first);
      data_[first_package + 1] = static_cast<StorageType>(
          (data_[first_package + 1] & static_cast<StorageType>(~high_mask)) |
          high_bits);
    }
  }

 private:
  // Width of one storage word in bits.
  const int_fast64_t num_bits_per_package_ = 8 * sizeof(StorageType);

  int_fast64_t num_items_;
  int_fast64_t item_size_;
  int_fast64_t num_data_packets_ = 0;
  std::vector<StorageType> data_;
  // Low item_size_ bits set; computed once in the constructor.
  StorageType item_mask_ = 0;
};

} // namespace core
} // namespace falconn

#endif

0 comments on commit 05cbf89

Please sign in to comment.