Skip to content

Commit

Permalink
Improve: Broader types support in usearch.io
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Jul 29, 2023
1 parent 9dff0fb commit b1a1439
Showing 1 changed file with 38 additions and 21 deletions.
59 changes: 38 additions & 21 deletions python/usearch/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,43 @@
import numpy as np


def numpy_scalar_size(dtype) -> int:
    """Return the size in bytes of a single scalar of the given NumPy type.

    Accepts anything ``numpy.dtype`` can interpret: scalar classes such as
    ``np.float32``, ``np.dtype`` instances such as ``array.dtype``, and
    dtype strings such as ``"float32"``.

    :param dtype: NumPy scalar class, ``np.dtype`` instance, or dtype string
    :return: number of bytes occupied by one scalar
    :rtype: int
    :raises KeyError: if ``dtype`` is ``None``, cannot be interpreted as a
        NumPy data type, or has no fixed non-zero scalar width
    """
    if dtype is None:
        # ``np.dtype(None)`` silently resolves to float64; keep rejecting
        # a missing type instead of guessing one.
        raise KeyError(dtype)
    try:
        size = np.dtype(dtype).itemsize
    except (TypeError, ValueError) as err:
        # Preserve the exception type callers saw from the lookup table.
        raise KeyError(dtype) from err
    if size == 0:
        # Flexible dtypes (e.g. unsized ``str``) have no per-scalar width.
        raise KeyError(dtype)
    return size


def guess_numpy_dtype_from_filename(filename) -> typing.Optional[type]:
    """Infer the NumPy scalar type of a matrix file from its extension.

    The match is a case-sensitive suffix comparison against the known
    extensions: ``.fbin``, ``.dbin``, ``.hbin``, ``.ibin`` and ``.bbin``.

    :param filename: path or name of the matrix file
    :type filename: str
    :return: the matching NumPy scalar type, or ``None`` if the
        extension is not recognized
    :rtype: typing.Optional[type]
    """
    known_suffixes = (
        (".fbin", np.float32),
        (".dbin", np.float64),
        (".hbin", np.float16),
        (".ibin", np.int32),
        (".bbin", np.uint8),
    )
    for suffix, scalar_type in known_suffixes:
        if filename.endswith(suffix):
            return scalar_type
    return None


def load_matrix(
filename: str,
start_row: int = 0,
count_rows: int = None,
view: bool = False,
dtype: typing.Optional[type] = None,
) -> typing.Optional[np.ndarray]:
"""Read *.ibin, *.bbin, *.hbin, *.fbin, *.dbin files with matrices.
Expand All @@ -21,25 +53,11 @@ def load_matrix(
:return: parsed matrix
:rtype: numpy.ndarray
"""
dtype = np.float32
scalar_size = 4
if filename.endswith(".fbin"):
dtype = np.float32
scalar_size = 4
elif filename.endswith(".dbin"):
dtype = np.float64
scalar_size = 8
elif filename.endswith(".hbin"):
dtype = np.float16
scalar_size = 2
elif filename.endswith(".ibin"):
dtype = np.int32
scalar_size = 4
elif filename.endswith(".bbin"):
dtype = np.uint8
scalar_size = 1
else:
raise Exception("Unknown file type")
if dtype is None:
dtype = guess_numpy_dtype_from_filename(filename)
if dtype is None:
raise Exception("Unknown file type")
scalar_size = numpy_scalar_size(dtype)

if not os.path.exists(filename):
return None
Expand Down Expand Up @@ -74,7 +92,6 @@ def save_matrix(vectors: np.ndarray, filename: str):
:param filename: path to the matrix file
:type filename: str
"""
dtype = np.float32
if filename.endswith(".fbin"):
dtype = np.float32
elif filename.endswith(".dbin"):
Expand All @@ -86,7 +103,7 @@ def save_matrix(vectors: np.ndarray, filename: str):
elif filename.endswith(".bbin"):
dtype = np.uint8
else:
raise Exception("Unknown file type")
dtype = vectors.dtype

assert len(vectors.shape) == 2, "Input array must have 2 dimensions"
with open(filename, "wb") as f:
Expand Down

0 comments on commit b1a1439

Please sign in to comment.