In [17]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.layers import Dense, Dropout, LayerNormalization
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
import os
from argparse import ArgumentParser
import tensorflow_datasets as tfds
from tensorflow.keras.callbacks import TensorBoard

In [18]:
class MultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with three separate
    ``Dense(embed_dim)`` layers, split into ``num_heads`` heads of size
    ``embed_dim // num_heads``, attended per head, re-assembled and passed
    through a final ``Dense(embed_dim)`` layer.

    Args:
        embed_dim: Width of the query/key/value projections and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments
            (e.g. ``name``, ``dtype``), forwarded to the base class.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads=8, **kwargs):
        super(MultiHeadSelfAttention, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}"
            )
        self.projection_dim = embed_dim // num_heads
        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)
        self.combine_heads = Dense(embed_dim)

    def attention(self, query, key, value):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Returns:
            A tuple ``(output, weights)`` where ``weights`` is the softmax over
            the last axis of the scaled scores and ``output`` is
            ``weights @ value``.
        """
        score = tf.matmul(query, key, transpose_b=True)
        # Cast the key depth to the dtype of the scores rather than a fixed
        # float32, so the division below does not hit a dtype mismatch when the
        # layer computes in float16 (mixed precision) or float64.
        dim_key = tf.cast(tf.shape(key)[-1], score.dtype)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights

    def separate_heads(self, x, batch_size):
        """Reshape ``x`` to (batch, -1, num_heads, projection_dim) and move the head axis to position 1."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        """Apply self-attention to ``inputs``.

        Args:
            inputs: Tensor whose first axis is the batch axis; presumably
                (batch, sequence, features) — TODO confirm against callers.

        Returns:
            Tensor reshaped to (batch, -1, embed_dim) and passed through the
            output projection.
        """
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        query = self.separate_heads(query, batch_size)
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)

        # The attention weights are not needed here; only the attended values are used.
        attention, _ = self.attention(query, key, value)
        # Undo the head/sequence transpose before merging the heads back together.
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention, (batch_size, -1, self.embed_dim))
        output = self.combine_heads(concat_attention)
        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from its config."""
        config = super(MultiHeadSelfAttention, self).get_config()
        config.update({"embed_dim": self.embed_dim, "num_heads": self.num_heads})
        return config


class TransformerBlock(tf.keras.layers.Layer):
    """Self-attention followed by a two-layer MLP, each wrapped in a residual connection.

    This block applies no layer normalization and no dropout: the output is
    ``mlp(att(x) + x) + (att(x) + x)``.

    Args:
        embed_dim: Width of the attention layer and of the MLP output.
        num_heads: Number of attention heads passed to ``MultiHeadSelfAttention``.
        mlp_dim: Width of the hidden (ReLU) layer of the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim):
        super(TransformerBlock, self).__init__()
        self.att = MultiHeadSelfAttention(embed_dim, num_heads)
        self.mlp = tf.keras.Sequential(
            [
                Dense(mlp_dim, activation=tf.nn.relu),
                Dense(embed_dim),
            ]
        )

    def call(self, inputs, training=None):
        """Run the block on ``inputs``.

        Args:
            inputs: Tensor fed to the attention layer; its last dimension must
                match ``embed_dim`` for the residual additions to be valid.
            training: Optional Keras training flag. It defaults to ``None`` so
                the block can be called without it; it is forwarded to the MLP
                (which currently contains only ``Dense`` layers, so it has no
                effect on the result).

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        attn_output = self.att(inputs)
        out1 = attn_output + inputs  # first residual connection

        mlp_output = self.mlp(out1, training=training)
        return mlp_output + out1  # second residual connection


class VisionTransformer(tf.keras.Model):
    """Vision Transformer classifier built from ``TransformerBlock`` layers.

    Images are rescaled by 1/255, cut into non-overlapping square patches,
    linearly projected to ``d_model``, prefixed with a learned class token,
    offset by a learned position embedding, passed through ``num_layers``
    transformer blocks and finally classified from the class-token position
    by a small MLP head whose last layer has no activation.

    Args:
        image_size: Side length of the (square) input images, in pixels.
        patch_size: Side length of each square patch, in pixels.
        num_layers: Number of ``TransformerBlock`` layers.
        num_classes: Number of output units of the classification head.
        d_model: Embedding width used throughout the encoder.
        num_heads: Number of attention heads per block.
        mlp_dim: Hidden width of the MLPs (in each block and in the head).
        channels: Number of image channels.
    """

    def __init__(self, image_size, patch_size, num_layers, num_classes, d_model, num_heads, mlp_dim, channels=1):
        super(VisionTransformer, self).__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size
        self.d_model = d_model
        self.num_layers = num_layers

        self.rescale = Rescaling(1.0 / 255)
        # One position vector per patch plus one for the class token.
        # `name`/`shape` are passed by keyword so the call does not depend on
        # the positional order of `add_weight`'s parameters.
        self.pos_emb = self.add_weight(
            name="pos_emb", shape=(1, num_patches + 1, d_model)
        )
        self.class_emb = self.add_weight(name="class_emb", shape=(1, 1, d_model))
        self.patch_proj = Dense(d_model)
        self.enc_layers = [
            TransformerBlock(d_model, num_heads, mlp_dim)
            for _ in range(num_layers)
        ]
        self.mlp_head = tf.keras.Sequential(
            [
                Dense(mlp_dim, activation=tfa.activations.gelu),
                Dense(num_classes),
            ]
        )

    def extract_patches(self, images):
        """Split ``images`` into non-overlapping patches and flatten each patch.

        Patch size and stride are both ``patch_size`` with ``VALID`` padding.

        Returns:
            Tensor reshaped to (batch, -1, patch_dim).
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patches = tf.reshape(patches, [batch_size, -1, self.patch_dim])
        return patches

    def call(self, x, training=None):
        """Compute class scores for a batch of images.

        Args:
            x: Batch of images; presumably (batch, image_size, image_size,
                channels) with pixel values in [0, 255] — TODO confirm.
            training: Optional Keras training flag, forwarded to every
                transformer block.

        Returns:
            Tensor of shape (batch, num_classes) with un-normalized scores
            (the head's last ``Dense`` layer has no activation).
        """
        batch_size = tf.shape(x)[0]
        x = self.rescale(x)
        patches = self.extract_patches(x)
        x = self.patch_proj(patches)

        # Prepend the learned class token to every sequence in the batch.
        class_emb = tf.broadcast_to(self.class_emb, [batch_size, 1, self.d_model])
        x = tf.concat([class_emb, x], axis=1)
        x = x + self.pos_emb

        for layer in self.enc_layers:
            x = layer(x, training=training)

        # First (class token) is used for classification
        x = self.mlp_head(x[:, 0])
        return x

    def load_weight(self, filepath):
        """Load all layer weights from a TensorFlow or an HDF5 weight file.

        This is a thin alias for the public ``tf.keras.Model.load_weights``.
        The previous body was a partial copy of Keras internals that
        referenced names never imported in this notebook (``path_to_string``,
        ``h5py``, ``hdf5_format``) and skipped the actual load whenever the
        file contained a ``model_weights`` group.

        Arguments:
            filepath: String, path to the weights file to load. For weight
                files in TensorFlow format, this is the file prefix (the same
                as was passed to `save_weights`).

        Returns:
            Whatever ``tf.keras.Model.load_weights`` returns for the detected
            file format.

        Raises:
            Any exception raised by ``tf.keras.Model.load_weights``, for
            example when an HDF5 file is loaded into a subclassed model whose
            variables have not been created yet (call the model once first).
        """
        # NOTE(review): `pos_emb` and `class_emb` are created directly on the
        # model rather than inside a sub-layer. Confirm that the chosen weight
        # format stores and restores them (HDF5 weight files are organised
        # per layer).
        return self.load_weights(filepath)

In [19]:
# Quick run: train for a single epoch, then load previously saved weights and
# evaluate them on the MNIST test split.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Hyper-parameters and paths for this run.
logdir = "logs"
image_size = 28
patch_size = 4
num_layers = 4
d_model = 64
num_heads = 4
mlp_dim = 128
lr = 3e-4
weight_decay = 1e-4
batch_size = 4096
epochs = 1

ds = tfds.load("mnist", as_supervised=True)

# Cache and prefetch the training split as well, so its input pipeline matches
# `ds_test` below and the full training cell later in the notebook.
ds_train = (
    ds["train"]
    .cache()
    .shuffle(5 * batch_size)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

ds_test = (
    ds["test"]
    .cache()
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)


model = VisionTransformer(
    image_size=image_size,
    patch_size=patch_size,
    num_layers=num_layers,
    num_classes=10,
    d_model=d_model,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
    channels=1,
)
model.compile(
    # The model outputs un-normalized scores, hence from_logits=True.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tfa.optimizers.AdamW(learning_rate=lr, weight_decay=weight_decay),
    metrics=["accuracy"],
)
# Fitting also builds the subclassed model, which is required before weights
# can be loaded from an HDF5 file.
model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=epochs,
    callbacks=[TensorBoard(log_dir=logdir, profile_batch=0),],
)
# NOTE(review): loading weights here overrides what the fit above just trained,
# so the evaluation below reflects the file's weights. Confirm this is intended
# and that the file contains every variable of the model (including the
# `pos_emb` / `class_emb` weights owned directly by the model).
model.load_weights("../saved_weights/vit.hdf5")
model.evaluate(ds_test)



[3.9850449562072754, 0.16609999537467957]

In [20]:
from tensorflow.python.client import device_lib

def get_available_gpus():
    """Return the names of all local devices whose ``device_type`` is ``'GPU'``."""
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            gpu_names.append(device.name)
    return gpu_names

In [23]:
# Print the per-layer parameter summary of the `model` bound by an earlier cell.
model.summary()

Model: "vision_transformer_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
rescaling_11 (Rescaling)     multiple                  0         
_________________________________________________________________
dense_297 (Dense)            multiple                  1088      
_________________________________________________________________
transformer_block_44 (Transf multiple                  33216     
_________________________________________________________________
transformer_block_45 (Transf multiple                  33216     
_________________________________________________________________
transformer_block_46 (Transf multiple                  33216     
_________________________________________________________________
transformer_block_47 (Transf multiple                  33216     
_________________________________________________________________
sequential_59 (Sequential)   (None, 10)      

In [41]:
# Full training run: 300 epochs on MNIST under a MirroredStrategy, then save
# the trained weights under `<logdir>/weights/vit`.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Hyper-parameters and paths for this run.
logdir = "logs"
image_size = 28
patch_size = 4
num_layers = 4
d_model = 64
num_heads = 4
mlp_dim = 128
lr = 3e-4
weight_decay = 1e-4
batch_size = 4096
epochs = 300

ds = tfds.load("mnist", as_supervised=True)

# Training split: cache -> shuffle -> batch -> prefetch.
ds_train = ds["train"].cache()
ds_train = ds_train.shuffle(5 * batch_size)
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(AUTOTUNE)

# Test split: same pipeline without shuffling.
ds_test = ds["test"].cache()
ds_test = ds_test.batch(batch_size)
ds_test = ds_test.prefetch(AUTOTUNE)

strategy = tf.distribute.MirroredStrategy()

# Constructor arguments for the model, gathered in one place.
vit_kwargs = dict(
    image_size=image_size,
    patch_size=patch_size,
    num_layers=num_layers,
    num_classes=10,
    d_model=d_model,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
    channels=1,
)

# Model, loss and optimizer are all created inside the strategy scope.
with strategy.scope():
    model = VisionTransformer(**vit_kwargs)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=weight_decay)
    model.compile(loss=loss_fn, optimizer=optimizer, metrics=["accuracy"])

tensorboard_cb = TensorBoard(log_dir=logdir, profile_batch=0)
model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=epochs,
    callbacks=[tensorboard_cb],
)

# No file extension is given, so this is a weights-file prefix.
weights_prefix = os.path.join(logdir, "weights", "vit")
model.save_weights(weights_prefix)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


Epoch 1/300








Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300


Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 79/300
Epoch 80/300
Epoch 81/300
Epoch 82/300
Epoch 83/300
Epoch 84/300
Epoch 85/300
Epoch 86/300
Epoch 87/300
Epoch 88/300
Epoch 89/300
Epoch 90/300
Epoch 91/300
Epoch 92/300
Epoch 93/300
Epoch 94/300
Epoch 95/300
Epoch 96/300
Epoch 97/300
Epoch 98/300
Epoch 99/300
Epoch 100/300
Epoch 101/300
Epoch 102/300
Epoch 103/300
Epoch 104/300
Epoch 105/300
Epoch 106/300
Epoch 107/300
Epoch 108/300
Epoch 109/300
Epoch 110/300
Epoch 111/300
Epoch 112/300
Epoch 113/300
Epoch 114/300


Epoch 115/300
Epoch 116/300
Epoch 117/300
Epoch 118/300
Epoch 119/300
Epoch 120/300
Epoch 121/300
Epoch 122/300
Epoch 123/300
Epoch 124/300
Epoch 125/300
Epoch 126/300
Epoch 127/300
Epoch 128/300
Epoch 129/300
Epoch 130/300
Epoch 131/300
Epoch 132/300
Epoch 133/300
Epoch 134/300
Epoch 135/300
Epoch 136/300
Epoch 137/300
Epoch 138/300
Epoch 139/300
Epoch 140/300
Epoch 141/300
Epoch 142/300
Epoch 143/300
Epoch 144/300
Epoch 145/300
Epoch 146/300
Epoch 147/300
Epoch 148/300
Epoch 149/300
Epoch 150/300
Epoch 151/300
Epoch 152/300
Epoch 153/300
Epoch 154/300
Epoch 155/300
Epoch 156/300
Epoch 157/300
Epoch 158/300
Epoch 159/300
Epoch 160/300
Epoch 161/300
Epoch 162/300
Epoch 163/300
Epoch 164/300
Epoch 165/300
Epoch 166/300
Epoch 167/300
Epoch 168/300
Epoch 169/300
Epoch 170/300
Epoch 171/300
Epoch 172/300
Epoch 173/300
Epoch 174/300
Epoch 175/300
Epoch 176/300
Epoch 177/300
Epoch 178/300
Epoch 179/300
Epoch 180/300
Epoch 181/300
Epoch 182/300
Epoch 183/300
Epoch 184/300
Epoch 185/300
Epoch 

Epoch 228/300
Epoch 229/300
Epoch 230/300
Epoch 231/300
Epoch 232/300
Epoch 233/300
Epoch 234/300
Epoch 235/300
Epoch 236/300
Epoch 237/300
Epoch 238/300
Epoch 239/300
Epoch 240/300
Epoch 241/300
Epoch 242/300
Epoch 243/300
Epoch 244/300
Epoch 245/300
Epoch 246/300
Epoch 247/300
Epoch 248/300
Epoch 249/300
Epoch 250/300
Epoch 251/300
Epoch 252/300
Epoch 253/300
Epoch 254/300
Epoch 255/300
Epoch 256/300
Epoch 257/300
Epoch 258/300
Epoch 259/300
Epoch 260/300
Epoch 261/300
Epoch 262/300
Epoch 263/300
Epoch 264/300
Epoch 265/300
Epoch 266/300
Epoch 267/300
Epoch 268/300
Epoch 269/300
Epoch 270/300
Epoch 271/300
Epoch 272/300
Epoch 273/300
Epoch 274/300
Epoch 275/300
Epoch 276/300
Epoch 277/300
Epoch 278/300
Epoch 279/300
Epoch 280/300
Epoch 281/300
Epoch 282/300
Epoch 283/300
Epoch 284/300
Epoch 285/300
Epoch 286/300
Epoch 287/300
Epoch 288/300
Epoch 289/300
Epoch 290/300
Epoch 291/300
Epoch 292/300
Epoch 293/300
Epoch 294/300
Epoch 295/300
Epoch 296/300
Epoch 297/300
Epoch 298/300
Epoch 

In [50]:
# Rebuild an identical model under the existing `strategy`, restore the weights
# saved under `<logdir>/weights/vit`, and evaluate on the test split.
# Relies on `strategy`, `ds_test` and the hyper-parameters bound by an earlier cell.
vit_kwargs = dict(
    image_size=image_size,
    patch_size=patch_size,
    num_layers=num_layers,
    num_classes=10,
    d_model=d_model,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
    channels=1,
)
weights_prefix = os.path.join(logdir, "weights", "vit")

with strategy.scope():
    model = VisionTransformer(**vit_kwargs)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=weight_decay)
    model.compile(loss=loss_fn, optimizer=optimizer, metrics=["accuracy"])
    model.load_weights(weights_prefix)

# Kept as the last expression so the [loss, accuracy] result is displayed.
model.evaluate(ds_test)



[0.0930701419711113, 0.9764000177383423]

In [51]:
# Repeat the evaluation on the test split; relies on `model` and `ds_test` bound by earlier cells.
model.evaluate(ds_test)



[0.0930701419711113, 0.9764000177383423]