# Analyzing sequence data with transformers

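This notebook works in three steps:

1. It generates a synthetic binary-classification dataset with scikit-learn.
2. It treats each sample's 128 features as a sequence and appends a position ("location") channel to every step.
3. It trains a small classifier with Keras: a single transformer encoder block followed by a sigmoid output.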

In [None]:
import numpy as np
import sklearn.datasets
import sklearn.model_selection

# generate sequence data: each of the 128 synthetic features is treated as one
# step of a length-128 sequence carrying a single scalar value.
# NOTE(review): make_classification shuffles feature columns by default
# (shuffle=True), so the step order carries no real sequential structure —
# confirm this is intended for the experiment.
x, y = sklearn.datasets.make_classification(n_samples=38262,
                                            n_features=128,
                                            n_informative=32,
                                            random_state=8792439)
train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
    x, y, test_size=0.2, random_state=849691)
print(train_x.shape, valid_x.shape)
print(train_y.shape, valid_y.shape)

# add 'location' to sequence data: a second channel holding each step's
# position, evenly spaced over [-2, 2]. The grid length comes from the data so
# it cannot drift out of sync with n_features above.
loc = np.linspace(start=-2.0, stop=+2.0, num=train_x.shape[1])
# np.broadcast_to yields a read-only (n_samples, seq_len) view of `loc` without
# building a Python list of n row references; np.stack then copies both
# channels into a fresh (n_samples, seq_len, 2) array.
train_x = np.stack([train_x, np.broadcast_to(loc, train_x.shape)], axis=-1)
valid_x = np.stack([valid_x, np.broadcast_to(loc, valid_x.shape)], axis=-1)

print(train_x.shape, valid_x.shape)
print(valid_x)

In [None]:
import tensorflow as tf

BATCH_SIZE = 32

# Shuffle the training samples on every pass over the data. Keras' fit() does
# not shuffle a tf.data.Dataset on its own (its `shuffle` argument is ignored
# for datasets), so without this every epoch would see identical batches.
# A buffer spanning the whole training set gives a full uniform shuffle.
train_data = (tf.data.Dataset.from_tensor_slices((train_x, train_y))
              .shuffle(buffer_size=train_x.shape[0], reshuffle_each_iteration=True)
              .batch(BATCH_SIZE)
              .prefetch(tf.data.AUTOTUNE))
# Sample order does not affect validation metrics, so this set is only batched.
valid_data = (tf.data.Dataset.from_tensor_slices((valid_x, valid_y))
              .batch(BATCH_SIZE)
              .prefetch(tf.data.AUTOTUNE))
print(train_data, valid_data)

In [None]:
# build model using functional api: one post-norm transformer encoder block
# (self-attention and a position-wise feed-forward sub-layer, each followed by
# a residual Add and LayerNormalization), then a flatten + sigmoid head.

# Take (sequence length, channels per step) from the prepared data instead of
# hard-coding (128, 2), so the model always matches the preprocessing.
seq_len, d_model = train_x.shape[1:]
ff_units = 2  # hidden width of the feed-forward sub-layer

inlayer = tf.keras.Input(shape=(seq_len, d_model))
# self-attention: the input sequence is passed as query, value and key
mha1 = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=d_model)(inlayer, inlayer, inlayer)
res1 = tf.keras.layers.Add()([inlayer, mha1])
nrm1 = tf.keras.layers.LayerNormalization()(res1)
ffa1 = tf.keras.layers.Dense(units=ff_units, activation=tf.keras.activations.relu)(nrm1)
# project back to d_model channels so the residual Add below has matching shapes
ffb1 = tf.keras.layers.Dense(units=d_model)(ffa1)
res2 = tf.keras.layers.Add()([nrm1, ffb1])
nrm2 = tf.keras.layers.LayerNormalization()(res2)
flt = tf.keras.layers.Flatten()(nrm2)
outlayer = tf.keras.layers.Dense(units=1, activation=tf.keras.activations.sigmoid)(flt)

model = tf.keras.Model(inputs=inlayer, outputs=outlayer)
model.summary()

In [None]:
# Training configuration: Adam with its default settings, binary
# cross-entropy against the single sigmoid output, accuracy reported per epoch.
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

# Train for 20 epochs, scoring the held-out split at the end of each epoch.
model.fit(
    train_data,
    validation_data=valid_data,
    epochs=20,
)