Skip to content

Commit

Permalink
Improve: Out-of-bounds checks
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Jul 24, 2023
1 parent 1b40f13 commit 54cecb6
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions python/usearch/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import math
from typing import Optional, Union, NamedTuple, List, Iterable
from dataclasses import dataclass

import numpy as np
from tqdm import tqdm
Expand Down Expand Up @@ -101,24 +102,28 @@ def _normalize_metric(metric):

return metric


class Match(NamedTuple):
@dataclass
class Match:
label: int
distance: float


class Matches(NamedTuple):
@dataclass
class Matches:
labels: np.ndarray
distances: np.ndarray

def __len__(self) -> int:
return len(self.labels)

def __getitem__(self, index: int) -> Match:
return Match(
label=self.labels[index],
distance=self.distances[index],
)
if isinstance(index, int) and index < len(self):
return Match(
label=self.labels[index],
distance=self.distances[index],
)
else:
raise IndexError(f"`index` must be an integer under {len(self)}")

def to_list(self) -> List[tuple]:
return [(int(l), float(d)) for l, d in zip(self.labels, self.distances)]
Expand All @@ -127,7 +132,8 @@ def __repr__(self) -> str:
return f"usearch.Matches({len(self)})"


class BatchMatches(NamedTuple):
@dataclass
class BatchMatches:
labels: np.ndarray
distances: np.ndarray
counts: np.ndarray
Expand All @@ -136,10 +142,13 @@ def __len__(self) -> int:
return len(self.counts)

def __getitem__(self, index: int) -> Matches:
return Matches(
labels=self.labels[index, : self.counts[index]],
distances=self.distances[index, : self.counts[index]],
)
if isinstance(index, int) and index < len(self):
return Matches(
labels=self.labels[index, : self.counts[index]],
distances=self.distances[index, : self.counts[index]],
)
else:
raise IndexError(f"`index` must be an integer under {len(self)}")

def to_list(self) -> List[List[tuple]]:
lists = [self.__getitem__(row) for row in range(self.__len__())]
Expand Down

0 comments on commit 54cecb6

Please sign in to comment.