In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, LSTM, Bidirectional, Dense, MultiHeadAttention

# --- Configuration -----------------------------------------------------------
# All tunable values live in one place so the data generator and the model
# cannot drift apart (e.g. the Embedding vocabulary must cover every token id
# the generator can produce, and the Dense output width must equal the number
# of label classes).
SEED = 42
N_SAMPLES = 1000
SEQ_LEN = 64
VOCAB_SIZE = 20          # token ids are drawn from [0, VOCAB_SIZE)
N_CLASSES = 3            # labels are drawn from [0, N_CLASSES)
EMBED_DIM = 50
LSTM_UNITS = 64          # per direction; the BiLSTM output width is 2 * LSTM_UNITS
num_heads = 8            # You can adjust the number of heads
HEAD_LSTM_UNITS = 32     # width of the final (non-sequence) LSTM
EPOCHS = 10
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2

# Seed every RNG involved so Restart & Run All reproduces the same data,
# weight initialisation and batch shuffling.
np.random.seed(SEED)
tf.random.set_seed(SEED)
rng = np.random.default_rng(SEED)

# --- Data --------------------------------------------------------------------
# Random placeholder data: integer token sequences and integer class labels.
# Because the labels are random, validation accuracy around 1 / N_CLASSES is
# the expected outcome; this run only shows the pipeline trains end to end.
input_data = rng.integers(0, VOCAB_SIZE, size=(N_SAMPLES, SEQ_LEN))
output_data = rng.integers(0, N_CLASSES, size=(N_SAMPLES,))
# The feature array was originally (misleadingly) named `input_shape`; keep the
# alias so any later cell that still references it keeps working.
input_shape = input_data

assert input_data.shape == (N_SAMPLES, SEQ_LEN)
assert 0 <= input_data.min() and input_data.max() < VOCAB_SIZE, "token ids must fit the Embedding vocabulary"
assert output_data.shape == (N_SAMPLES,)
assert 0 <= output_data.min() and output_data.max() < N_CLASSES, "labels must fit the softmax output"

# --- Model -------------------------------------------------------------------
input_layer = Input(shape=(SEQ_LEN,))

# Embedding: each token id -> EMBED_DIM vector, giving (batch, SEQ_LEN, EMBED_DIM).
embedding_layer = Embedding(input_dim=VOCAB_SIZE, output_dim=EMBED_DIM)(input_layer)

# Bidirectional LSTM with return_sequences=True keeps one vector per time step:
# (batch, SEQ_LEN, 2 * LSTM_UNITS).
bidirectional_lstm = Bidirectional(LSTM(LSTM_UNITS, return_sequences=True))(embedding_layer)

# Multi-head self-attention over the BiLSTM sequence (query and value are the
# same tensor). Keras projects the result back to the query width, so the
# output shape matches `bidirectional_lstm`.
# NOTE(review): key_dim is LSTM_UNITS // num_heads (= 8, as in the original
# `64 // num_heads`), i.e. derived from the per-direction width rather than the
# 2 * LSTM_UNITS BiLSTM output. Kept to preserve the architecture; confirm intent.
multihead_attention = MultiHeadAttention(num_heads=num_heads, key_dim=LSTM_UNITS // num_heads)(
    bidirectional_lstm, bidirectional_lstm
)

# Concatenate attention output with the BiLSTM output along the feature axis:
# (batch, SEQ_LEN, 2 * (2 * LSTM_UNITS)).
concatenated_output = tf.keras.layers.Concatenate(axis=-1)([bidirectional_lstm, multihead_attention])

# Final LSTM (return_sequences defaults to False) collapses the sequence to a
# single vector per sample: (batch, HEAD_LSTM_UNITS).
lstm_layer = LSTM(HEAD_LSTM_UNITS)(concatenated_output)

# Softmax classifier over N_CLASSES; pairs with sparse (integer) labels below.
output_layer = Dense(N_CLASSES, activation='softmax')(lstm_layer)

model = tf.keras.Model(inputs=input_layer, outputs=output_layer)

# sparse_categorical_crossentropy because `output_data` holds integer class
# ids, not one-hot vectors.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

model.summary()

# --- Training ----------------------------------------------------------------
# Replace `input_data` / `output_data` with real data. Keras holds out the last
# VALIDATION_SPLIT fraction of samples for validation. The History object is
# kept for later plotting instead of being echoed as cell output.
history = model.fit(
    input_data,
    output_data,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 64)]         0           []                               
                                                                                                  
 embedding (Embedding)          (None, 64, 50)       1000        ['input_1[0][0]']                
                                                                                                  
 bidirectional (Bidirectional)  (None, 64, 128)      58880       ['embedding[0][0]']              
                                                                                                  
 multi_head_attention (MultiHea  (None, 64, 128)     33088       ['bidirectional[0][0]',          
 dAttention)                                                      'bidirectional[0][0]']      

<keras.callbacks.History at 0x1cff352aa90>