In [None]:

import pickle
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing.sequence import pad_sequences
from keras.preprocessing.text import Tokenizer
from keras.models import Sequential, Model
from keras.layers import Embedding, Input, Dropout, add, dot, Activation, Permute, concatenate, LSTM, Dense


## Load Data

In [None]:

def _read_pickle(path):
    """Return the object deserialized from the pickle file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
train_data = _read_pickle("train_qa.txt")
test_data = _read_pickle("test_qa.txt")

print("Training samples:", len(train_data))
print("Testing samples:", len(test_data))


## Build Vocabulary

In [None]:

# Pool both splits so the vocabulary and the padding lengths cover every sample.
all_data = train_data + test_data

# Gather each distinct token that appears in any story or question.
vocab = set()
for story_tokens, question_tokens, _ in all_data:
    vocab.update(story_tokens)
    vocab.update(question_tokens)
# The two answer words are added explicitly.
vocab.update(["yes", "no"])

# One slot more than the number of distinct tokens
# (presumably reserving index 0 for padding -- confirm against the tokenizer).
vocab_size = 1 + len(vocab)
max_story_len = max(len(sample[0]) for sample in all_data)
max_question_len = max(len(sample[1]) for sample in all_data)

print("Vocabulary size:", vocab_size)
print("Max story length:", max_story_len)
print("Max question length:", max_question_len)


## Tokenization

In [None]:

# An empty filter string disables character filtering, so no token is stripped.
# (The filters argument is documented as a string, not a list.)
tokenizer = Tokenizer(filters="")

# Fit on a sorted list rather than the raw set: set iteration order for strings
# varies between interpreter runs, which would otherwise assign different
# indices to the same words after every kernel restart.
tokenizer.fit_on_texts(sorted(vocab))
word_index = tokenizer.word_index

# Targets are sized from word_index while the model's output layer is sized
# from vocab_size, so the two must agree. They can diverge if the tokenizer
# merges tokens (e.g. words differing only by case).
assert len(word_index) + 1 == vocab_size, (
    "Tokenizer produced %d entries but vocab_size is %d"
    % (len(word_index), vocab_size)
)


## Vectorization

In [None]:

def vectorize_stories(data, word_index, max_story_len, max_question_len):
    """Convert (story, query, answer) samples into padded index arrays.

    Args:
        data: Iterable of ``(story, query, answer)`` triples. ``story`` and
            ``query`` are sequences of word strings; ``answer`` is a single
            word string.
        word_index: Mapping from lowercase word to integer index.
        max_story_len: Length passed to ``pad_sequences`` for the stories.
        max_question_len: Length passed to ``pad_sequences`` for the queries.

    Returns:
        A tuple ``(stories, queries, answers)`` where ``stories`` and
        ``queries`` are the padded index sequences and ``answers`` is an array
        with one one-hot row of width ``len(word_index) + 1`` per sample.

    Raises:
        KeyError: If a (lowercased) word is not present in ``word_index``.
    """
    # Width of each one-hot answer vector; computed once instead of per sample.
    n_classes = len(word_index) + 1

    X, Xq, Y = [], [], []
    for story, query, answer in data:
        X.append([word_index[w.lower()] for w in story])
        Xq.append([word_index[w.lower()] for w in query])

        y = np.zeros(n_classes)
        # Lowercase the answer like every other token so a capitalised answer
        # does not fail the lookup.
        y[word_index[answer.lower()]] = 1
        Y.append(y)

    return (
        pad_sequences(X, maxlen=max_story_len),
        pad_sequences(Xq, maxlen=max_question_len),
        np.array(Y)
    )

# Both splits are encoded with the same word mapping and padding lengths.
_vectorize_args = (word_index, max_story_len, max_question_len)
inputs_train, queries_train, answers_train = vectorize_stories(train_data, *_vectorize_args)
inputs_test, queries_test, answers_test = vectorize_stories(test_data, *_vectorize_args)

# Report the shape of each training array.
for caption, array in (
    ("Training input shape:", inputs_train),
    ("Training query shape:", queries_train),
    ("Training answer shape:", answers_train),
):
    print(caption, array.shape)


## Build the Model

In [None]:

def _embedding_encoder(output_dim, **embedding_kwargs):
    """Return a Sequential of a vocabulary Embedding followed by Dropout(0.3)."""
    return Sequential([
        Embedding(vocab_size, output_dim, **embedding_kwargs),
        Dropout(0.3),
    ])


# Symbolic inputs: one padded story and one padded question per sample.
input_sequence = Input((max_story_len,))
question = Input((max_question_len,))

# Encoders
# "m" and the question encoder share a 64-wide embedding so they can be
# compared with a dot product; "c" embeds to max_question_len so its output
# has the same shape as the match matrix it is added to below.
input_encoder_m = _embedding_encoder(64)
input_encoder_c = _embedding_encoder(max_question_len)
question_encoder = _embedding_encoder(64, input_length=max_question_len)

# Encoded Inputs
input_encoded_m = input_encoder_m(input_sequence)
input_encoded_c = input_encoder_c(input_sequence)
question_encoded = question_encoder(question)

# Match (Attention): dot product over the embedding axis of story and
# question, then a softmax over the result.
match_scores = dot([input_encoded_m, question_encoded], axes=(2, 2))
match = Activation("softmax")(match_scores)

# Response: add the match matrix to the second story encoding, then swap the
# two non-batch axes so the question axis comes first.
match_plus_memory = add([match, input_encoded_c])
response = Permute((2, 1))(match_plus_memory)

# Combine response + question, then reduce to a distribution over the vocabulary.
combined = concatenate([response, question_encoded])
lstm_out = LSTM(32)(combined)
lstm_dropped = Dropout(0.5)(lstm_out)
vocab_scores = Dense(vocab_size)(lstm_dropped)
answer = Activation("softmax")(vocab_scores)

# Final model
model = Model([input_sequence, question], answer)
model.compile(
    loss="categorical_crossentropy",
    optimizer="rmsprop",
    metrics=["accuracy"],
)
model.summary()


## Train the Model

In [None]:

# Model inputs are [stories, questions]; the test split doubles as the
# validation set reported after every epoch.
training_inputs = [inputs_train, queries_train]
validation_set = ([inputs_test, queries_test], answers_test)

history = model.fit(
    training_inputs,
    answers_train,
    epochs=120,
    batch_size=32,
    validation_data=validation_set,
)


## Training Results

In [None]:

# Draw one accuracy curve per recorded metric on the same axes.
for metric_key, legend_label in (
    ("accuracy", "Train Accuracy"),
    ("val_accuracy", "Validation Accuracy"),
):
    plt.plot(history.history[metric_key], label=legend_label)

plt.title("Model Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()


## Sample Predictions

In [None]:

# Use the first test sample as a qualitative check of the trained model.
story, query, ans = test_data[0]
print("Story:", story)
print("Query:", query)
print("Correct Answer:", ans)


def _encode(words, maxlen):
    """Map lowercased words to indices and pad them into a one-row batch."""
    return pad_sequences([[word_index[w.lower()] for w in words]], maxlen=maxlen)


pred = model.predict([
    _encode(story, max_story_len),
    _encode(query, max_question_len),
])

# Invert the vocabulary once for a direct index -> word lookup.
index_to_word = {idx: word for word, idx in word_index.items()}
val_max = int(np.argmax(pred[0]))

# The output layer can select an index that has no word in word_index (the
# padding index is one such case); report that explicitly instead of silently
# printing nothing.
print(
    "Predicted Answer:",
    index_to_word.get(val_max, "<no word for index {}>".format(val_max)),
)
