Skip to content

Commit

Permalink
Approximate faiss k-NN search (#8952)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people committed Feb 26, 2024
1 parent 91f1ad3 commit ee07443
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))
- Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))
- Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))

Expand Down
57 changes: 53 additions & 4 deletions test/nn/pool/test_knn.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import pytest
import torch

from torch_geometric.nn import L2KNNIndex, MIPSKNNIndex
from torch_geometric.nn import (
ApproxL2KNNIndex,
ApproxMIPSKNNIndex,
L2KNNIndex,
MIPSKNNIndex,
)
from torch_geometric.testing import withCUDA, withPackage


@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_L2_knn(device, k):
def test_l2(device, k):
lhs = torch.randn(10, 16, device=device)
rhs = torch.randn(100, 16, device=device)

Expand All @@ -32,7 +37,7 @@ def test_L2_knn(device, k):
@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_MIPS_knn(device, k):
def test_mips(device, k):
lhs = torch.randn(10, 16, device=device)
rhs = torch.randn(100, 16, device=device)

Expand All @@ -53,10 +58,54 @@ def test_MIPS_knn(device, k):
assert torch.equal(out.index, index[:, :k])


@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_approx_l2(device, k):
    """Smoke-test `ApproxL2KNNIndex.search` on random data.

    Only the device, the shape and the index range of the result are
    checked; no comparison against exact neighbors is made here.
    """
    num_queries, num_points, channels = 10, 10_000, 16

    lhs = torch.randn(num_queries, channels, device=device)
    rhs = torch.randn(num_points, channels, device=device)

    knn_index = ApproxL2KNNIndex(
        num_cells=10,
        num_cells_to_visit=10,
        bits_per_vector=8,
        emb=rhs,
    )
    result = knn_index.search(lhs, k)

    # Results must stay on the device the queries live on.
    for tensor in (result.score, result.index):
        assert tensor.device == device

    # One row per query, `k` hits per row.
    for tensor in (result.score, result.index):
        assert tensor.size() == (num_queries, k)

    # Every returned index has to point into `rhs`.
    assert result.index.min() >= 0
    assert result.index.max() < num_points


@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_approx_mips(device, k):
    """Smoke-test `ApproxMIPSKNNIndex.search` on random data.

    Only the device, the shape and the index range of the result are
    checked; no comparison against exact neighbors is made here.
    """
    num_queries, num_points, channels = 10, 10_000, 16

    lhs = torch.randn(num_queries, channels, device=device)
    rhs = torch.randn(num_points, channels, device=device)

    knn_index = ApproxMIPSKNNIndex(
        num_cells=10,
        num_cells_to_visit=10,
        bits_per_vector=8,
        emb=rhs,
    )
    result = knn_index.search(lhs, k)

    # Results must stay on the device the queries live on.
    for tensor in (result.score, result.index):
        assert tensor.device == device

    # One row per query, `k` hits per row.
    for tensor in (result.score, result.index):
        assert tensor.size() == (num_queries, k)

    # Every returned index has to point into `rhs`.
    assert result.index.min() >= 0
    assert result.index.max() < num_points


@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [50])
def test_MIPS_exclude(device, k):
def test_mips_exclude(device, k):
lhs = torch.randn(10, 16, device=device)
rhs = torch.randn(100, 16, device=device)

Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
from .edge_pool import EdgePooling
from .glob import global_add_pool, global_max_pool, global_mean_pool
from .knn import KNNIndex, L2KNNIndex, MIPSKNNIndex
from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
ApproxMIPSKNNIndex)
from .graclus import graclus
from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
from .mem_pool import MemPooling
Expand Down Expand Up @@ -324,6 +325,8 @@ def nearest(
'KNNIndex',
'L2KNNIndex',
'MIPSKNNIndex',
'ApproxL2KNNIndex',
'ApproxMIPSKNNIndex',
'TopKPooling',
'SAGPooling',
'EdgePooling',
Expand Down
92 changes: 89 additions & 3 deletions torch_geometric/nn/pool/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@ class KNNIndex:
depending on whether you plan to use GPU-processing for :math:`k`-NN search.
Args:
index_factory (str): The name of the index factory to use, *e.g.*,
:obj:`"IndexFlatL2"` or :obj:`"IndexFlatIP"`. See `here
index_factory (str, optional): The name of the index factory to use,
*e.g.*, :obj:`"IndexFlatL2"` or :obj:`"IndexFlatIP"`. See `here
<https://github.com/facebookresearch/faiss/wiki/
The-index-factory>`_ for more information.
emb (torch.Tensor, optional): The data points to add.
(default: :obj:`None`)
"""
def __init__(self, index_factory: str, emb: Optional[Tensor] = None):
def __init__(
self,
index_factory: Optional[str] = None,
emb: Optional[Tensor] = None,
):
warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')

import faiss
Expand Down Expand Up @@ -73,6 +77,8 @@ def add(self, emb: Tensor):
self.index,
)

self.index.train(emb)

self.numel += emb.size(0)
self.index.add(emb.detach())

Expand Down Expand Up @@ -216,3 +222,83 @@ def __init__(self, emb: Optional[Tensor] = None):
def _create_index(self, channels: int):
    """Return a new `faiss.IndexFlatIP` for `channels`-dimensional vectors."""
    # Imported lazily, as in the rest of this module, so that `faiss` is
    # only required once an index is actually built.
    import faiss

    flat_ip_index = faiss.IndexFlatIP(channels)
    return flat_ip_index


class ApproxL2KNNIndex(KNNIndex):
    r"""Performs fast approximate :math:`k`-nearest neighbor search
    (:math:`k`-NN) based on the :math:`L_2` metric via the :obj:`faiss`
    library.
    Hyperparameters need to be tuned for the speed-accuracy trade-off.

    Args:
        num_cells (int): The number of cells.
        num_cells_to_visit (int): The number of cells that are visited to
            perform the search.
        bits_per_vector (int): Controls the size of the compressed code
            stored for each vector. It is forwarded to
            :obj:`faiss.IndexIVFPQ` as its fourth positional argument.
        emb (torch.Tensor, optional): The data points to add.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        num_cells: int,
        num_cells_to_visit: int,
        bits_per_vector: int,
        emb: Optional[Tensor] = None,
    ):
        # Stored before calling the base constructor: `_create_index` reads
        # these attributes. Presumably the base constructor may build the
        # index right away when `emb` is given -- confirm against
        # `KNNIndex.__init__`.
        self.num_cells = num_cells
        self.num_cells_to_visit = num_cells_to_visit
        self.bits_per_vector = bits_per_vector
        super().__init__(index_factory=None, emb=emb)

    def _create_index(self, channels: int):
        """Build the :obj:`faiss.IndexIVFPQ` index for the :math:`L_2` metric.

        Args:
            channels (int): The dimensionality of the indexed vectors.

        Returns:
            The index, with :obj:`nprobe` set to :obj:`num_cells_to_visit`.
        """
        import faiss

        # Fifth positional argument of `faiss.IndexIVFPQ`; kept at the value
        # that was previously passed as a bare literal.
        code_bits = 8

        coarse_quantizer = faiss.IndexFlatL2(channels)
        index = faiss.IndexIVFPQ(
            coarse_quantizer,
            channels,
            self.num_cells,
            # NOTE(review): faiss documents this slot as `M`, the number of
            # sub-quantizers, and the next one as the bits per sub-vector
            # code. Confirm that the `bits_per_vector` name matches the
            # intended meaning.
            self.bits_per_vector,
            code_bits,
            faiss.METRIC_L2,
        )
        index.nprobe = self.num_cells_to_visit
        return index


class ApproxMIPSKNNIndex(KNNIndex):
    r"""Performs fast approximate :math:`k`-nearest neighbor search
    (:math:`k`-NN) based on the maximum inner product via the :obj:`faiss`
    library.
    Hyperparameters need to be tuned for the speed-accuracy trade-off.

    Args:
        num_cells (int): The number of cells.
        num_cells_to_visit (int): The number of cells that are visited to
            perform the search.
        bits_per_vector (int): Controls the size of the compressed code
            stored for each vector. It is forwarded to
            :obj:`faiss.IndexIVFPQ` as its fourth positional argument.
        emb (torch.Tensor, optional): The data points to add.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        num_cells: int,
        num_cells_to_visit: int,
        bits_per_vector: int,
        emb: Optional[Tensor] = None,
    ):
        # Stored before calling the base constructor: `_create_index` reads
        # these attributes. Presumably the base constructor may build the
        # index right away when `emb` is given -- confirm against
        # `KNNIndex.__init__`.
        self.num_cells = num_cells
        self.num_cells_to_visit = num_cells_to_visit
        self.bits_per_vector = bits_per_vector
        super().__init__(index_factory=None, emb=emb)

    def _create_index(self, channels: int):
        """Build the :obj:`faiss.IndexIVFPQ` index for inner-product search.

        Args:
            channels (int): The dimensionality of the indexed vectors.

        Returns:
            The index, with :obj:`nprobe` set to :obj:`num_cells_to_visit`.
        """
        import faiss

        # Fifth positional argument of `faiss.IndexIVFPQ`; kept at the value
        # that was previously passed as a bare literal.
        code_bits = 8

        coarse_quantizer = faiss.IndexFlatIP(channels)
        index = faiss.IndexIVFPQ(
            coarse_quantizer,
            channels,
            self.num_cells,
            # NOTE(review): faiss documents this slot as `M`, the number of
            # sub-quantizers, and the next one as the bits per sub-vector
            # code. Confirm that the `bits_per_vector` name matches the
            # intended meaning.
            self.bits_per_vector,
            code_bits,
            faiss.METRIC_INNER_PRODUCT,
        )
        index.nprobe = self.num_cells_to_visit
        return index

0 comments on commit ee07443

Please sign in to comment.