In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np

# --- 1. Define input parameters ---
BATCH_SIZE = 2
SEQUENCE_LENGTH = 5
EMBEDDING_DIM = 8
SEED = 42  # fixed seed so Restart & Run All reproduces the same random inputs

# Create a sequence of input data (e.g., word embeddings)
# Shape: (batch_size, sequence_length, embedding_dim)
# Values are drawn uniformly from [0, 1) with a seeded generator.
rng = np.random.default_rng(SEED)
input_sequence = tf.constant(
    rng.random((BATCH_SIZE, SEQUENCE_LENGTH, EMBEDDING_DIM)),
    dtype=tf.float32
)

# --- 2. Create the Attention Layer ---
# use_scale=True makes the layer create a trainable scalar weight that multiplies
# the query-key dot-product scores. Note: this is NOT the fixed 1/sqrt(dim)
# factor used in Transformer scaled dot-product attention.
attention_layer = tf.keras.layers.Attention(use_scale=True)

# --- 3. Apply Self-Attention ---
# The layer takes its inputs as a list: [query, value] or [query, value, key].
# When key is omitted (as here), value is also used as the key, so passing the
# same tensor twice makes query, value, and key identical (self-attention).
output_sequence = attention_layer([input_sequence, input_sequence])

# The attention output has the query's (batch, sequence) dims and the value's
# feature dim; for self-attention that must equal the input shape.
assert output_sequence.shape == input_sequence.shape

# --- 4. Inspect Results ---
print(f"Input Sequence Shape: {input_sequence.shape}")
print(f"Output Sequence Shape: {output_sequence.shape}")

# Output:
# Input Sequence Shape: (2, 5, 8)
# Output Sequence Shape: (2, 5, 8)

Input Sequence Shape: (2, 5, 8)
Output Sequence Shape: (2, 5, 8)


In [2]:
# Show the input tensor; as the cell's last expression, the notebook renders its repr.
input_sequence

<tf.Tensor: shape=(2, 5, 8), dtype=float32, numpy=
array([[[0.6163462 , 0.4502409 , 0.54020715, 0.4567726 , 0.1706332 ,
         0.512923  , 0.43535164, 0.44838443],
        [0.0976066 , 0.16478907, 0.79595566, 0.60687643, 0.6379799 ,
         0.37943003, 0.22030625, 0.914196  ],
        [0.89062434, 0.7686348 , 0.01389515, 0.5637549 , 0.0274174 ,
         0.2825459 , 0.77198696, 0.14052314],
        [0.06659249, 0.6762018 , 0.712621  , 0.14408113, 0.9379386 ,
         0.87725616, 0.8915948 , 0.24940827],
        [0.8653926 , 0.8190601 , 0.81058645, 0.8047972 , 0.22301292,
         0.09609768, 0.26960856, 0.52175957]],

       [[0.6946656 , 0.19220184, 0.38341936, 0.37072253, 0.9913491 ,
         0.7016437 , 0.59946257, 0.749612  ],
        [0.637732  , 0.5509997 , 0.79309505, 0.9927132 , 0.9506386 ,
         0.11061171, 0.4487013 , 0.32262543],
        [0.06646844, 0.3703626 , 0.7601154 , 0.28812945, 0.3191472 ,
         0.44479486, 0.94008446, 0.22960438],
        [0.54868305, 0.2786