Index - Single Input #161

Closed
phillips96 opened this issue Sep 23, 2021 · 9 comments

Comments

@phillips96
Contributor

phillips96 commented Sep 23, 2021

Hello,

The SimilarityModel.index method seems to require a batch size greater than 1: it uses the Indexer's batch_add method, which raises strided slice errors when only a single example is passed.

Could index be updated to check for a batch size of one and call the Indexer's add instead of batch_add? Or, given there's already a single_lookup, would a new single_index method be preferred?
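Something along these lines is what I have in mind (a rough, untested sketch; index_single_or_batch is a hypothetical helper name and the exact add/batch_add signatures may differ):

def index_single_or_batch(model, x, y=None, data=None):
    # Hypothetical helper: route a batch of one example to Indexer.add and
    # larger batches to Indexer.batch_add. Exact signatures may differ.
    embeddings = model.predict(x)
    if x.shape[0] == 1:
        model._index.add(embeddings[0],
                         label=None if y is None else y[0],
                         data=None if data is None else data[0])
    else:
        model._index.batch_add(embeddings, y, data)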

As a side note, is there a reason that SimilarityModel._index is only set up in the compile and load_index methods? I'm using the subclassing API with custom training loops (to accommodate dynamic sizes), so I never call compile and had to set _index manually the first time through, before there was a saved index to load.

Thank you for your time :)

(Note: I'm happy to make a small PR to fix it!)

@owenvallis
Collaborator

Thanks for the feedback. Can you provide some more details on the strided slice errors you're seeing?

Currently the batch_add method is meant to accept batch sizes >= 1, e.g.:

x_index, y_index = sampler.get_slice(begin=0, size=1)
model.reset_index()
model.index(x_index, y_index, data=x_index)

Regarding initializing the index within compile: I think we could add a create_index method and call it from within compile; that way you could also set up the index directly if needed. We'll take a look and see if we can add a solution in the 0.14 update.
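As a very rough sketch of what I mean (a hypothetical method on SimilarityModel; the method name and arguments are illustrative, not the actual implementation):

from tensorflow_similarity.indexer import Indexer

def create_index(self, embedding_size, distance='cosine', search='nmslib',
                 kv_store='memory', evaluator='memory', stat_buffer_size=1000):
    # Expose index creation directly so subclassed models that never call
    # compile() can still set up the indexer; compile() would call this too.
    self._index = Indexer(embedding_size=embedding_size,
                          distance=distance,
                          search=search,
                          kv_store=kv_store,
                          evaluator=evaluator,
                          stat_buffer_size=stat_buffer_size)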

@phillips96
Contributor Author

phillips96 commented Sep 23, 2021

So if I have an example model built and the indexer set up, I can run this and get the anticipated output:
In:

model = SimilarityModelTest()
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
CLASSES = [2, 3, 1, 7, 9, 6, 8, 5, 0, 4]
x_index, y_index = select_examples(x_train, y_train, CLASSES, 20)
x_index = tf.cast(tf.expand_dims(x_index, axis=-1), dtype=tf.float32)
print(x_index.shape)
print(y_index.shape)
model.index(x_index, y_index, data=x_index)

Out:

(200, 28, 28, 1)
(200,)
[Indexing 200 points]
|-Computing embeddings
|-Storing data points in key value store
|-Adding embeddings to index.

All good. But if I select the first entry and expand dims to add the batch dimension at the front, to make it compatible with .predict():
In:

...
x_single = tf.expand_dims(x_index[0], axis=0)
y_single = tf.expand_dims(y_index[0], axis=0)
print(x_single.shape)
print(y_single.shape)

model.index(x_single, y_single, data=x_single)

Out:

(200, 28, 28, 1)
(200,)
(1, 28, 28, 1)
(1,)
[Indexing 1 points]
|-Computing embeddings
Traceback (most recent call last):
  File ".\tmp.py", line 56, in <module>
    model.index(x_single, y_single, data=x_single)
  File "D:\Python\lib\site-packages\tensorflow_similarity\models\similarity_model.py", line 248, in index
    verbose=verbose)
  File "D:\Python\lib\site-packages\tensorflow_similarity\indexer.py", line 276, in batch_add
    idxs = self.kv_store.batch_add(embeddings, labels, data)
  File "D:\Python\lib\site-packages\tensorflow_similarity\stores\memory_store.py", line 81, in batch_add
    label = None if labels is None else labels[idx]
  File "D:\Python\lib\site-packages\tensorflow\python\util\dispatch.py", line 206, in wrapper
    return target(*args, **kwargs)
  File "D:\Python\lib\site-packages\tensorflow\python\ops\array_ops.py", line 1052, in _slice_helper
    name=name)
  File "D:\Python\lib\site-packages\tensorflow\python\util\dispatch.py", line 206, in wrapper
    return target(*args, **kwargs)
  File "D:\Python\lib\site-packages\tensorflow\python\ops\array_ops.py", line 1224, in strided_slice
    shrink_axis_mask=shrink_axis_mask)
  File "D:\Python\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 10510, in strided_slice
    _ops.raise_from_not_ok_status(e, name)
  File "D:\Python\lib\site-packages\tensorflow\python\framework\ops.py", line 6941, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: slice index 1 of dimension 0 out of bounds. [Op:StridedSlice] name: strided_slice/  

In my head they're identical apart from the batch size of 1, so I'm not sure why labels[idx] is looking for index 1 rather than index 0.

Much appreciated,
Alex

@owenvallis
Collaborator

Thanks for the details, Alex. I think the issue might be with your custom SimilarityModelTest(). Would you be able to share more details on the SimilarityModelTest() setup?

I'm able to index single examples using the following:

Model

from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler

# Data sampler that generates balanced batches from MNIST dataset
sampler = TFDatasetMultiShotMemorySampler(dataset_name='mnist', classes_per_batch=10)

from tensorflow.keras import layers
from tensorflow_similarity.layers import MetricEmbedding
from tensorflow_similarity.models import SimilarityModel

# Build a Similarity model using standard Keras layers
inputs = layers.Input(shape=(28, 28, 1))
x = layers.experimental.preprocessing.Rescaling(1/255)(inputs)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
outputs = MetricEmbedding(64)(x)

# Build a specialized Similarity model
model = SimilarityModel(inputs, outputs)

from tensorflow_similarity.losses import MultiSimilarityLoss

# Train Similarity model using contrastive loss
model.compile('adam', loss=MultiSimilarityLoss())
model.fit(sampler, epochs=5)

Indexing

import tensorflow as tf
from tensorflow_similarity.samplers import select_examples

# Index the raw examples from tf keras datasets
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
CLASSES = [2, 3, 1, 7, 9, 6, 8, 5, 0, 4]
x_index, y_index = select_examples(x_train, y_train, CLASSES, 20)
x_index = tf.cast(tf.expand_dims(x_index, axis=-1), dtype=tf.float32)
print(x_index.shape)
print(y_index.shape)
model.index(x_index, y_index, data=x_index)

# Index single example from tf keras datasets
x_single = tf.expand_dims(x_index[0], axis=0)
y_single = tf.expand_dims(y_index[0], axis=0)
print(x_single.shape)
print(y_single.shape)
model.index(x_single, y_single, data=x_single)

# Index single example using sampler.get_slice()
x_single_v2, y_single_v2 = sampler.get_slice(0,1)
print(x_single_v2.shape)
print(y_single_v2.shape)
model.index(x_single_v2, y_single_v2, data=x_single_v2)

I'm on the tf.sim 0.14 branch and using:

  • Tensorflow 2.6.0
  • nmslib 2.1.1
  • python 3.7.10

@phillips96
Contributor Author

phillips96 commented Sep 23, 2021

As an example

import tensorflow as tf
from tensorflow_similarity.indexer import Indexer
from tensorflow_similarity.layers import MetricEmbedding
from tensorflow_similarity.models import SimilarityModel


class SimilarityModelTest(SimilarityModel):
    def __init__(self):
        super().__init__()

        self.features = tf.keras.Sequential(layers=[
            tf.keras.layers.Conv2D(4, 2),
            tf.keras.layers.Conv2D(8, 2),
            tf.keras.layers.Conv2D(16, 2),
            MetricEmbedding(16)
        ])

        self.embedding_size = 16
        self._index = Indexer(embedding_size=self.embedding_size,
                              distance='cosine',
                              search='nmslib',
                              kv_store='memory',
                              evaluator='memory',
                              embedding_output=0,
                              stat_buffer_size=1000)

    def __call__(self, inputs, training=False, mask=None):
        return self.features(inputs, training=training)

    def call(self, inputs, training=False, mask=None):
        return self.__call__(inputs, training, mask)

However, by just swapping embedding_output to None instead of 0 it works!

Happy to close the PR I submitted for the single_index method if you're satisfied that this is resolved.

Note that add (instead of batch_add) works totally fine when embedding_output is 0, which is why the code in the PR seemed to make sense!

Thanks!

@owenvallis
Collaborator

The embedding_output argument tells the indexer which output head to take as the embedding. If embedding_output is None, we take prediction[0] (see indexer.py, line 186); otherwise, if embedding_output >= 0, we assume a multi-headed output and take prediction[embedding_output][0].
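Roughly, the selection looks like this (a simplified sketch, not the actual indexer code):

def select_embedding(prediction, embedding_output=None):
    # Simplified sketch of how the indexer picks the embedding from a model
    # prediction; see indexer.py for the real logic.
    if embedding_output is None:
        # Single-output model: prediction is already the batch of embeddings.
        return prediction[0]
    # Multi-headed model: pick the requested head, then the first example.
    return prediction[embedding_output][0]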

@owenvallis
Collaborator

Hi Alex,

I refactored the model class and added a create_index method as part of pull request #164

I think that addresses the issues. Let me know if we can close out your single_index PR.

@owenvallis
Collaborator

owenvallis commented Sep 24, 2021

Actually, had a chat and support for adding a single embedding would be very helpful. I'll add a few notes on the PR and then we can merge it in. Thanks for taking the time to submit the PR.

@phillips96
Contributor Author

Thanks for explaining that!

I've made the requested updates for review! Feel free to close this if you're happy with the PR 👍

@owenvallis
Collaborator

Thanks! PR looks good. I'll merge once the checks are finished.
