In [7]:
import tensorflow as tf
import numpy as np

In [8]:
# Vocabulary size and embedding dimension
vocab_size = 20  # Small vocabulary
embedding_dim = 5  # Embedding size

In [9]:
# Create an embedding layer
embedding_layer = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)

In [10]:
# Define a simple loss function and optimizer
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

In [11]:
# Simulated input batch of word indices (sparse updates)
word_indices = tf.constant([[1, 3, 7], [2, 5, 8]], dtype=tf.int32)  # Batch of word indices
targets = tf.random.normal((2, 3, embedding_dim))  # Random target embeddings

In [12]:
# Training loop
for epoch in range(1):
    with tf.GradientTape() as tape:
        # Forward pass
        embeddings = embedding_layer(word_indices)  # Get embeddings for the words in the batch
        loss = loss_fn(targets, embeddings)  # Compute MSE loss with the target embeddings
    
    # Compute gradients
    gradients = tape.gradient(loss, embedding_layer.trainable_variables)
    
    # Apply gradients
    optimizer.apply_gradients(zip(gradients, embedding_layer.trainable_variables))
    
    print(f"Epoch {epoch + 1}: Loss = {loss.numpy()}")
    
    # Display gradient sparsity
    print("Sparse Gradient Example (only non-zero updates):")
    non_zero_updates = tf.reduce_sum(tf.cast(tf.not_equal(gradients[0], 0), tf.int32)).numpy()
    print(f"Non-zero gradients: {non_zero_updates}/{vocab_size * embedding_dim} parameters")
    print("Gradients (truncated):")
    print(gradients[:10])  # Print gradients for the first 10 words

Epoch 1: Loss = 1.0070326328277588
Sparse Gradient Example (only non-zero updates):
Non-zero gradients: 30/100 parameters
Gradients (truncated):
[<tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x000001961BD0A210>]


In [13]:
# Verify embeddings for untouched indices remain unchanged
print("\nFinal embeddings (first 5 words):")
print(embedding_layer.embeddings[:5])


Final embeddings (first 5 words):
tf.Tensor(
[[-0.04614676  0.00548912  0.0450556  -0.04679706  0.00980228]
 [ 0.04905393 -0.01446868  0.0105222   0.00238351  0.04675721]
 [-0.02414486 -0.03999238 -0.05066098  0.0208729   0.02861297]
 [ 0.03482476  0.02229689  0.04842939 -0.01246209 -0.02342523]
 [ 0.03767601 -0.0380299  -0.04815232 -0.02079639 -0.04310198]], shape=(5, 5), dtype=float32)
