In [1]:
import pandas as pd
import tensorflow as tf

from modules.transformer_layers import Transformer, CustomSchedule, masked_loss, masked_accuracy
from sklearn.model_selection import train_test_split

In [2]:
# --- Values handed to the Transformer constructor ---
d_model = 128        # also reused by prepare_batch as its token-length cap
num_heads = 8
num_layers = 4
dff = 512
dropout_rate = 0.1

# --- Vocabulary cap ---
# Used as max_tokens for both TextVectorization layers and as both
# vocabulary sizes given to the Transformer.
max_features = 5000

In [3]:
# Load the key-pair table; the 'public' column is used as x and the
# 'private' column as y for the split below.
df = pd.read_csv('./dataset/rsa_key_dataset.csv')

# Pin the split so the train/validation partition is identical on every
# Restart & Run All. Without random_state, train_test_split draws a new
# random partition each run and validation metrics are not comparable.
# (Dataset shuffling and weight initialisation are still unseeded.)
SPLIT_SEED = 42
x_train, x_test, y_train, y_test = train_test_split(
    df['public'].to_numpy(),
    df['private'].to_numpy(),
    random_state=SPLIT_SEED,
)

# Each dataset element is an (x, y) pair of raw, un-tokenized values.
train_examples = tf.data.Dataset.from_tensor_slices((x_train, y_train))
val_examples = tf.data.Dataset.from_tensor_slices((x_test, y_test))

In [4]:
# Two independent character-level vectorizers with identical settings:
# one for the 'public' strings, one for the 'private' strings.
# They are created in the same order as before (input first, then output).
tokenizer_input, tokenizer_output = (
    tf.keras.layers.TextVectorization(
        max_tokens=max_features,
        split='character',
        output_mode='int',
    )
    for _ in range(2)
)

In [5]:
# Fit each character vocabulary on the training split only, so that no
# statistics from the held-out rows leak into preprocessing state.
# Characters seen only in validation data fall back to the layer's
# out-of-vocabulary token.
tokenizer_input.adapt(x_train)
tokenizer_output.adapt(y_train)

2023-03-26 16:40:36.103435: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [6]:
def prepare_batch(input, output, max_tokens=None):
    """Turn a batch of raw string pairs into teacher-forcing tensors.

    Args:
        input: Batch of raw strings, vectorized with ``tokenizer_input``.
            This is the sequence the labels are taken from: it is split
            into a shifted model input and next-position labels.
        output: Batch of raw strings, vectorized with ``tokenizer_output``
            and passed through (truncated) as the first model input.
        max_tokens: Cap on the kept sequence length. ``None`` (default)
            falls back to ``d_model``, which is the value this function
            previously hard-coded; pass an explicit value to decouple the
            sequence length from the embedding width.

    Returns:
        ``((output_tokens, shifted_inputs), labels)`` where
        ``shifted_inputs`` is the tokenized ``input`` without its last
        position and ``labels`` is the same sequence without its first.

    NOTE(review): the argument names suggest ``input`` is the source and
    ``output`` the target, but here ``output`` is the conditioning sequence
    and ``input`` supplies the labels. Confirm this direction is intended.
    NOTE(review): ``input`` shadows the Python builtin; the parameter names
    are kept unchanged for caller compatibility.
    """
    if max_tokens is None:
        max_tokens = d_model

    conditioning = tokenizer_output(output)[:, :max_tokens]

    # Keep one extra position so that, after the one-step shift below,
    # both halves are at most max_tokens long.
    sequence = tokenizer_input(input)[:, :max_tokens + 1]
    shifted_inputs = sequence[:, :-1]   # drop last position
    labels = sequence[:, 1:]            # drop first position

    return (conditioning, shifted_inputs), labels

In [7]:
BATCH_SIZE = 64
BUFFER_SIZE = 20000


def make_batches(ds):
    """Build the input pipeline for a dataset of raw (x, y) pairs.

    Steps, in order: shuffle with a BUFFER_SIZE buffer, group into batches
    of BATCH_SIZE, run ``prepare_batch`` on each batch, then prefetch.
    """
    shuffled = ds.shuffle(BUFFER_SIZE)
    batched = shuffled.batch(BATCH_SIZE)
    # prepare_batch runs on whole batches, not single examples.
    prepared = batched.map(prepare_batch, tf.data.AUTOTUNE)
    return prepared.prefetch(buffer_size=tf.data.AUTOTUNE)

In [8]:
# Run both splits through the same pipeline (training first, then validation).
train_batches, val_batches = (
    make_batches(split) for split in (train_examples, val_examples)
)

In [9]:
# Instantiate the project's Transformer from the hyperparameters defined above.
transformer = Transformer(
    # Vocabulary sizes: both use the same cap that was given to the
    # TextVectorization layers.
    input_vocab_size=max_features,
    target_vocab_size=max_features,
    # Architecture.
    d_model=d_model,
    dff=dff,
    num_heads=num_heads,
    num_layers=num_layers,
    # Regularisation.
    dropout_rate=dropout_rate,
)

In [10]:
# Learning-rate schedule from the project module, parameterised by d_model.
learning_rate = CustomSchedule(d_model)

# Adam driven by the schedule above instead of a fixed rate.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

In [11]:
# Attach the optimizer plus the project's masked loss and accuracy metric.
transformer.compile(
    optimizer=optimizer,
    loss=masked_loss,
    metrics=[masked_accuracy],
)

In [12]:
# Train for 20 epochs, evaluating on the validation batches after each one.
# Left as the cell's last expression so the returned History object is displayed.
transformer.fit(
    train_batches,
    validation_data=val_batches,
    epochs=20,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x28128b010>