In [47]:
import tensorflow as tf

class OneHot(tf.keras.layers.Layer):
    """Keras layer that one-hot encodes integer class indices.

    The input is cast to ``tf.int32`` and handed to ``tf.one_hot``, which
    adds a trailing axis of length ``depth`` to the input's shape.

    Args:
        depth: Length of the one-hot axis (i.e. number of classes).
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, depth, **kwargs):
        super(OneHot, self).__init__(**kwargs)
        self.depth = depth

    def call(self, x, mask=None):
        """Return the one-hot encoding of ``x``.

        ``mask`` is accepted in the signature but is not used.
        """
        return tf.one_hot(tf.cast(x, tf.int32), self.depth)

    def get_config(self):
        """Return the layer config, including ``depth``.

        Keras needs every constructor argument in the config to rebuild the
        layer when a model is saved, loaded, or cloned.
        """
        config = super(OneHot, self).get_config()
        config["depth"] = self.depth
        return config

In [30]:
class OneHotModel(tf.keras.Model):
    """Subclassed Keras model whose forward pass is a single ``OneHot`` layer.

    Args:
        vocab_size: Passed to ``OneHot`` as the depth of the one-hot axis.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.one_hot = OneHot(vocab_size)

    def call(self, inputs):
        """One-hot encode ``inputs`` and return the result."""
        return self.one_hot(inputs)

In [43]:
import numpy as np

# Class ids to encode.
batch_inputs = np.array([0, 1, 0, 2, 1, 3])

# tf.one_hot encodes any id >= depth as an all-zero row, so the depth must be
# (largest id + 1). Counting distinct ids only gives the same number when the
# ids happen to be exactly 0..n-1.
vocab_size = int(batch_inputs.max()) + 1

model = OneHotModel(vocab_size)
output = model.predict(batch_inputs)

In [44]:
# Show the array returned by model.predict for the batch above.
output

array([[[1., 0., 0., 0.]],

       [[0., 1., 0., 0.]],

       [[1., 0., 0., 0.]],

       [[0., 0., 1., 0.]],

       [[0., 1., 0., 0.]],

       [[0., 0., 0., 1.]]], dtype=float32)

## Another way to check that the `OneHot` layer works: the Keras functional API

In [48]:
# Same check, but wiring the layer into a model with the Keras functional API.
# The one-hot depth must exceed the largest class id, so derive it from the
# maximum id rather than from the number of distinct ids.
one_hot_depth = int(batch_inputs.max()) + 1

tf_inputs = tf.keras.Input(shape=(1,))
one_hot = OneHot(one_hot_depth)(tf_inputs)
model = tf.keras.Model(inputs=tf_inputs, outputs=one_hot)
output_array = model.predict(batch_inputs)
print(output_array)

[[[1. 0. 0. 0.]]

 [[0. 1. 0. 0.]]

 [[1. 0. 0. 0.]]

 [[0. 0. 1. 0.]]

 [[0. 1. 0. 0.]]

 [[0. 0. 0. 1.]]]
