<a href="https://colab.research.google.com/github/sotetsuk/LectureColab/blob/main/seq2seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/odashi/small_parallel_enja.git

Cloning into 'small_parallel_enja'...
remote: Enumerating objects: 35, done.[K
remote: Total 35 (delta 0), reused 0 (delta 0), pack-reused 35 (from 1)[K
Receiving objects: 100% (35/35), 1.37 MiB | 9.15 MiB/s, done.
Resolving deltas: 100% (18/18), done.


In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
if "KERAS_BACKEND" not in os.environ:
    os.environ["KERAS_BACKEND"] = "torch"

import keras
from keras import layers, ops


class LuongAttention(layers.Layer):
    def call(self, inputs, mask=None):
        query, value = inputs
        scores = ops.matmul(query, ops.transpose(value, [0, 2, 1]))

        if mask is not None:
            enc_mask = mask[1] if isinstance(mask, (list, tuple)) else mask
            if enc_mask is not None:
                enc_mask = ops.expand_dims(ops.cast(enc_mask, scores.dtype), axis=1)
                scores = scores - 1e9 * (1.0 - enc_mask)

        weights = ops.softmax(scores, axis=-1)
        return ops.matmul(weights, value)

    def compute_output_shape(self, input_shape):
        return input_shape[0]

    def compute_mask(self, inputs, mask=None):
        return mask[0] if isinstance(mask, (list, tuple)) and mask else mask


class TranslationDataset(keras.utils.PyDataset):
    def __init__(self, en_sentences, ja_sentences, en_vectorizer, ja_vectorizer,
                 batch_size=32, shuffle=True, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.shuffle = shuffle

        self.en_sequences = ops.convert_to_tensor(en_vectorizer(en_sentences), dtype="int32")
        self.ja_sequences_input = ops.convert_to_tensor(
            ja_vectorizer(["<s> " + s for s in ja_sentences]), dtype="int32")
        self.ja_sequences_target = ops.convert_to_tensor(
            ja_vectorizer([s + " </s>" for s in ja_sentences]), dtype="int32")

        self.indices = ops.arange(len(en_sentences))
        self.on_epoch_end()

    def __len__(self):
        return len(self.indices) // self.batch_size

    def __getitem__(self, index):
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        return {"encoder_input": self.en_sequences[batch_indices],
                "decoder_input": self.ja_sequences_input[batch_indices]}, \
               self.ja_sequences_target[batch_indices]

    def on_epoch_end(self):
        if self.shuffle:
            self.indices = ops.convert_to_tensor(keras.random.shuffle(self.indices))


def translate_sentences(model, en_sentences, en_vectorizer, ja_vectorizer, max_length=20):
    en_sequences = ops.convert_to_tensor(en_vectorizer(en_sentences), dtype="int32")
    start_token_index = int(ja_vectorizer(["<s>"])[0][0])
    end_token_index = int(ja_vectorizer(["</s>"])[0][0])
    ja_vocab = ja_vectorizer.get_vocabulary()

    translations = []
    for i in range(len(en_sentences)):
        current_sequence = [start_token_index]

        for _ in range(max_length - 1):
            enc_input = ops.expand_dims(en_sequences[i], axis=0)
            dec_input = ops.convert_to_tensor([current_sequence], dtype="int32")
            predictions = model.predict([enc_input, dec_input], verbose=0)
            next_token = int(ops.argmax(predictions[0, len(current_sequence)-1, :]))

            if next_token == end_token_index:
                break
            current_sequence.append(next_token)

        translated_tokens = [ja_vocab[idx] for idx in current_sequence[1:]
                           if idx < len(ja_vocab) and ja_vocab[idx] not in ["", "</s>"]]
        translations.append(" ".join(translated_tokens))

    return translations


def masked_sparse_crossentropy(y_true, y_pred):
    loss = keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
    mask = ops.cast(ops.not_equal(y_true, 0), dtype=loss.dtype)
    loss = loss * mask
    return ops.sum(loss) / ops.sum(mask + keras.backend.epsilon())


def build_seq2seq_model(enc_vocab_size, dec_vocab_size, embedding_dim=256, latent_dim=512, num_layers=2):
    enc_in = keras.Input(shape=(None,), name="encoder_input")
    x = layers.Embedding(enc_vocab_size, embedding_dim, mask_zero=True)(enc_in)
    x = layers.Dropout(0.3)(x)

    enc_states = []
    for i in range(num_layers):
        x, h, c = layers.LSTM(latent_dim, return_sequences=True, return_state=True,
                              dropout=0.3, name=f"enc_lstm_{i+1}")(x)
        enc_states.append([h, c])
    enc_out = x

    dec_in = keras.Input(shape=(None,), name="decoder_input")
    y = layers.Embedding(dec_vocab_size, embedding_dim, mask_zero=True)(dec_in)
    y = layers.Dropout(0.3)(y)

    for i in range(num_layers):
        y = layers.LSTM(latent_dim, return_sequences=True, dropout=0.3,
                        name=f"dec_lstm_{i+1}")(y, initial_state=enc_states[i])

    context = LuongAttention(name="luong_attention")([y, enc_out])
    y = layers.Concatenate(axis=-1)([y, context])
    out = layers.Dense(dec_vocab_size, activation=None, name="output")(y)

    return keras.Model([enc_in, dec_in], out)


def main():
    print(f"Keras {keras.__version__}, Backend: {keras.config.backend()}")

    # Load data
    def load_data(split):
        with open(f"small_parallel_enja/{split}.en", "r") as f:
            en = [line.strip() for line in f]
        with open(f"small_parallel_enja/{split}.ja", "r") as f:
            ja = [line.strip() for line in f]
        return en, ja

    train_en, train_ja = load_data("train")
    val_en, val_ja = load_data("dev")

    # Create vectorizers
    en_vectorizer = layers.TextVectorization(
        max_tokens=5000, output_mode="int", output_sequence_length=20,
        standardize="lower_and_strip_punctuation")
    ja_vectorizer = layers.TextVectorization(
        max_tokens=5000, output_mode="int", output_sequence_length=20,
        standardize=None, split="whitespace")

    en_vectorizer.adapt(train_en)
    ja_vectorizer.adapt(["<s>", "</s>"] + train_ja +
                        ["<s> " + s + " </s>" for s in train_ja[:1000]])

    # Create datasets
    train_dataset = TranslationDataset(train_en, train_ja, en_vectorizer, ja_vectorizer, 64, True)
    val_dataset = TranslationDataset(val_en, val_ja, en_vectorizer, ja_vectorizer, 64, False)

    # Build and compile model
    model = build_seq2seq_model(len(en_vectorizer.get_vocabulary()),
                                len(ja_vectorizer.get_vocabulary()))
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0),
                  loss=masked_sparse_crossentropy, metrics=["accuracy"])

    # Sample sentences for evaluation
    sample_sentences = val_en[:5]

    # Translation callback
    class TranslationCallback(keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            if epoch % 5 == 0:  # Print every 5 epochs
                print(f"\nEpoch {epoch + 1} translations:")
                translations = translate_sentences(self.model, sample_sentences,
                                                 en_vectorizer, ja_vectorizer)
                for en, ja in zip(sample_sentences[:5], translations[:5]):  # Show only 2
                    print(f"EN: {en}\nJA: {ja}\n")

    # Train model
    history = model.fit(
        train_dataset,
        epochs=50,
        validation_data=val_dataset,
        callbacks=[
            keras.callbacks.ModelCheckpoint("seq2seq_model.keras", save_best_only=True,
                                          monitor="val_loss"),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5,
                                            patience=2, min_lr=1e-6),
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=10,
                                        restore_best_weights=True),
            TranslationCallback()
        ]
    )

    print(f"\nTraining complete! Best val_loss: {min(history.history['val_loss']):.4f}")

In [3]:
main()

Keras 3.8.0, Backend: torch
Epoch 1/50
[1m781/781[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 160ms/step - accuracy: 0.1555 - loss: 4.3049
Epoch 1 translations:
EN: show your own business .
JA: 私 は 何 を し て くださ い 。

EN: he lived a hard life .
JA: 彼 は その 仕事 を し た 。

EN: no . i 'm sorry , i 've got to go back early .
JA: 君 は その 仕事 を 見 て い る の は な い 。

EN: she wrote to me to come at once .
JA: 彼女 は 私 の ため に [UNK] し た 。

EN: i can 't swim at all .
JA: 私 は その 仕事 を し て い る 。

[1m781/781[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m131s[0m 166ms/step - accuracy: 0.1555 - loss: 4.3040 - val_accuracy: 0.2830 - val_loss: 2.8162 - learning_rate: 0.0010
Epoch 2/50
[1m781/781[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m117s[0m 150ms/step - accuracy: 0.2884 - loss: 2.7232 - val_accuracy: 0.3224 - val_loss: 2.3553 - learning_rate: 0.0010
Epoch 3/50
[1m781/781[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m116s[0m 148ms/step - accuracy: 0.3290 - loss: 2.2358 - val_accuracy: 0.3596