In [4]:
import random

import numpy as np
from keras import layers
from keras.layers import Input
from keras.models import Model
from keras.utils import to_categorical

# Fix the random seeds this cell depends on.
# np.random.seed alone only makes the synthetic data reproducible; Keras
# weight initialisation draws from the backend RNG, which must be seeded too.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
try:
    # keras.utils.set_random_seed (Keras >= 2.7 and Keras 3) seeds Python,
    # NumPy and the backend framework in a single call.
    from keras.utils import set_random_seed
except ImportError:
    # Older Keras (or Keras not importable): only Python/NumPy are seeded,
    # so initial weights may still differ between runs.
    pass
else:
    set_random_seed(RANDOM_SEED)

# Vocabulary sizes: text and question share a 10k-token space; the answer is
# a 500-way classification.
text_vocabulary_size = question_vocabulary_size = 10000
answer_vocabulary_size = 500

# Text branch: variable-length int32 token ids -> 64-d embeddings -> one 32-d
# vector (the LSTM's last output, since return_sequences is left at default).
text_input = Input(shape=(None,), dtype="int32", name="text")
embedded_text = layers.Embedding(input_dim=text_vocabulary_size, output_dim=64)(
    text_input
)
encoded_text = layers.LSTM(units=32)(embedded_text)

# Question branch: same structure, with 32-d embeddings and a 16-d LSTM.
question_input = Input(shape=(None,), dtype="int32", name="question")
embedded_question = layers.Embedding(
    input_dim=question_vocabulary_size, output_dim=32
)(question_input)
encoded_question = layers.LSTM(units=16)(embedded_question)

# Join both encodings along the feature axis (question first, then text).
concatenated = layers.concatenate([encoded_question, encoded_text], axis=-1)

# Softmax output layer over the answer vocabulary.
answer = layers.Dense(units=answer_vocabulary_size, activation="softmax")(
    concatenated
)

# Two-input functional model; inputs are ordered [text, question].
model = Model(inputs=[text_input, question_input], outputs=answer)
model.compile(
    optimizer="rmsprop",
    loss="categorical_crossentropy",
    metrics=["acc"],
)
model.summary()

# Synthetic training data. Token ids and labels are drawn independently at
# random, so there is no real signal to learn; this only exercises fit().
num_samples = 1000
# Original (misspelled) name, kept so later cells that reference it still work.
num_sampels = num_samples
max_length = 100
# Token ids are drawn from [1, vocab_size); 0 is never produced
# (NOTE(review): presumably reserved for padding — confirm).
text = np.random.randint(1, text_vocabulary_size, size=(num_samples, max_length))
question = np.random.randint(
    1, question_vocabulary_size, size=(num_samples, max_length)
)
# Integer class ids in [0, answer_vocabulary_size), then one-hot encoded to
# match the categorical_crossentropy loss.
answers = np.random.randint(answer_vocabulary_size, size=(num_samples,))
# NOTE: this rebinds `answer`, which until now held the model's output tensor.
answer = to_categorical(answers, answer_vocabulary_size)
model.fit([text, question], answer, epochs=10, batch_size=128)

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 question (InputLayer)          [(None, None)]       0           []                               
                                                                                                  
 text (InputLayer)              [(None, None)]       0           []                               
                                                                                                  
 embedding_7 (Embedding)        (None, None, 32)     320000      ['question[0][0]']               
                                                                                                  
 embedding_6 (Embedding)        (None, None, 64)     640000      ['text[0][0]']                   
                                                                                            

<keras.callbacks.History at 0x23ef31cebf0>