In [1]:
import tensorflow as tf

class CBOWModel(tf.keras.Model):
    """Continuous-bag-of-words (CBOW) word-embedding model.

    Looks up an embedding vector for each context word id, averages those
    vectors along axis 1, and feeds the mean through a softmax Dense layer
    that scores every word in the vocabulary.

    Args:
        vocab_sz: number of distinct word ids; used as both the number of
            embedding rows and the number of output units.
        emb_sz: dimensionality of each word embedding.
        window_sz: number of context words on each side of the target word,
            so each example presumably carries ``2 * window_sz`` context ids.
        **kwargs: forwarded unchanged to ``tf.keras.Model``.
    """

    def __init__(self, vocab_sz, emb_sz, window_sz, **kwargs):
        super().__init__(**kwargs)
        # Kept for reference (e.g. to build the model with the right input
        # shape); none of the layers below depend on it.
        self.window_sz = window_sz
        # `input_length` is intentionally not passed: it is deprecated in
        # Keras 3 and has no effect here because `call` averages the context
        # axis away before the Dense layer.
        self.embedding = tf.keras.layers.Embedding(
            input_dim=vocab_sz,
            output_dim=emb_sz,
            embeddings_initializer="glorot_uniform",
        )
        self.dense = tf.keras.layers.Dense(
            vocab_sz,
            kernel_initializer="glorot_uniform",
            activation="softmax",
        )

    def call(self, x):
        """Map context word ids to a distribution over the vocabulary.

        Args:
            x: integer word ids; assumed to be (batch, context_len).

        Returns:
            Softmax probabilities, (batch, vocab_sz) under that assumption.
        """
        x = self.embedding(x)          # one embedding vector per context id
        x = tf.reduce_mean(x, axis=1)  # average over the context words
        return self.dense(x)


VOCAB_SIZE = 5000
EMBED_SIZE = 300
WINDOW_SIZE = 1  # 3 word window, 1 on left, 1 on right
CONTEXT_LEN = 2 * WINDOW_SIZE  # number of context word ids fed per example
SEED = 42

# Make weight initialisation (and thus the printed embeddings) reproducible.
tf.random.set_seed(SEED)

model = CBOWModel(VOCAB_SIZE, EMBED_SIZE, WINDOW_SIZE)
# The second input dimension is the number of context ids per example,
# not the vocabulary size.
model.build(input_shape=(None, CONTEXT_LEN))
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss="categorical_crossentropy",  # expects one-hot targets of width VOCAB_SIZE
    metrics=["accuracy"],
)

model.summary()

# train the model here

# Retrieve the learned embedding matrix directly from the layer attribute.
# Filtering by layer/weight *names* (e.g. ".../embeddings:0") is fragile:
# Keras 3 variable names carry neither the scope prefix nor the ":0" suffix,
# so a name-based filter finds nothing and indexing [0] raises IndexError.
emb_layer = model.embedding
emb_weight = emb_layer.embeddings.numpy()
assert emb_weight.shape == (VOCAB_SIZE, EMBED_SIZE), emb_weight.shape
print(emb_weight, emb_weight.shape)


Model: "cbow_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        multiple                  1500000   
_________________________________________________________________
dense (Dense)                multiple                  1505000   
Total params: 3,005,000
Trainable params: 3,005,000
Non-trainable params: 0
_________________________________________________________________
[[-0.00747491 -0.00965876  0.00058828 ... -0.03070909 -0.02074809
  -0.01729073]
 [-0.02535501  0.00546921  0.03070942 ...  0.03015628 -0.00556514
   0.01839813]
 [-0.00685583  0.01317265 -0.01770947 ...  0.02328325  0.00451733
   0.00912225]
 ...
 [-0.01492023  0.02622795  0.02864654 ... -0.03148879 -0.02111962
   0.02042102]
 [-0.00662356  0.01656102 -0.01947211 ... -0.02060558 -0.00068993
  -0.01413563]
 [ 0.01263543  0.0252607   0.02902652 ...  0.00723125  0.03273669
  -0.00805153]] (5000, 300)
