Skip to content

Commit

Permalink
Add 'single_index' (#162)
Browse files Browse the repository at this point in the history
* Add 'single_index' 

Using the same convention as `single_lookup`, I.E:
    - Assumes Tensor needs expanding prior to predict. 
    - Same Method name notation

* Update `single_index` to `index_single`

* Added unit test for `index_single`

Co-authored-by: owenvallis <owensvallis@gmail.com>
  • Loading branch information
phillips96 and owenvallis committed Sep 24, 2021
1 parent cefe309 commit 6ebc2e6
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tensorflow_similarity/models/similarity_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,45 @@ def index(self,
build=build,
verbose=verbose)

def index_single(self,
x: Tensor,
y: IntTensor = None,
data: Optional[Tensor] = None,
build: bool = True,
verbose: int = 1):
"""Index data.
Args:
x: Sample to index.
y: class id associated with the data if any. Defaults to None.
data: store the data associated with the samples in the key
value store. Defaults to None.
build: Rebuild the index after indexing. This is needed to make the
new samples searchable. Set it to false to save processing time
when calling indexing repeatidly without the need to search between
the indexing requests. Defaults to True.
verbose: Output indexing progress info. Defaults to 1.
"""

if not self._index:
raise Exception('You need to compile the model with a valid'
'distance to be able to use the indexing')
if verbose:
print('[Indexing 1 point]')
print('|-Computing embeddings')

x = tf.expand_dims(x, axis=0)
prediction = self.predict(x)
self._index.add(prediction=prediction,
label=y,
data=data,
build=build,
verbose=verbose)

def lookup(self,
x: Tensor,
k: int = 5,
Expand Down
16 changes: 16 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,19 @@ def test_save_no_compile(tmp_path):
model.save(tmp_path)
model2 = tf.keras.models.load_model(tmp_path)
assert isinstance(model2, type(model))


def test_index_single():
"""Unit Test for #161 & #162"""
inputs = tf.keras.layers.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = SimilarityModel(inputs, outputs)
model.compile(optimizer='adam', loss=TripletLoss())

# index data
x = tf.constant([1, 1, 3], dtype='float32')
y = tf.constant([1])

# run individual sample & index
model.index_single(x, y, data=x)
assert model._index.size() == 1

0 comments on commit 6ebc2e6

Please sign in to comment.