<a href="https://colab.research.google.com/github/yuta-kubo/devenvs/blob/keybindings/vision_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Upgrade pip itself before installing other packages.
!pip install -U pip
# The preinstalled version raises an error. It appears to be a version problem,
# so upgrade tensorflow-addons to the latest release.
# NOTE(review): the captured output further down shows TensorFlow Addons warning
# that its custom ops were compiled against a different TensorFlow version;
# pinning a TFA release that matches the installed TensorFlow may be safer.
!pip install -U tensorflow-addons

In [None]:
# Confirm that TensorFlow recognizes a GPU.
# The trailing bare expression is displayed as the cell output: the GPU device
# name reported by tf.test.gpu_device_name().
import tensorflow as tf
tf.test.gpu_device_name()

In [2]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.layers import (
    Dense,
    Dropout,
    LayerNormalization,
    Add,
    Activation,
    Input,
)
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical


def MultiHead_SelfAttention(inputs, embed_dim, num_heads):
    """Apply multi-head scaled dot-product self-attention to ``inputs``.

    Args:
        inputs: token sequence; assumed to be (batch, seq_len, features).
        embed_dim: width of the query/key/value projections and of the output.
        num_heads: number of attention heads; must divide ``embed_dim``.

    Returns:
        Tensor of shape (batch, seq_len, embed_dim).

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """
    if embed_dim % num_heads != 0:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        )
    projection_dim = embed_dim // num_heads
    # Read the batch size at run time. The static K.int_shape(inputs)[0] is None
    # for an ordinary Keras input (the reshape then fails), or a baked-in constant
    # that makes the model reject every other batch size.
    batch_size = tf.shape(inputs)[0]

    def split_heads(t):
        # (batch, seq, embed_dim) -> (batch, heads, seq, projection_dim)
        t = K.reshape(t, (batch_size, -1, num_heads, projection_dim))
        return K.permute_dimensions(t, (0, 2, 1, 3))

    query = split_heads(Dense(embed_dim)(inputs))
    key = split_heads(Dense(embed_dim)(inputs))
    value = split_heads(Dense(embed_dim)(inputs))

    # Scaled dot-product scores: (batch, heads, seq_q, seq_k), softmax over keys.
    score = tf.matmul(query, key, transpose_b=True)
    score = score / K.sqrt(K.cast(projection_dim, "float32"))
    weights = Activation("softmax")(score)

    # Merge heads back: (batch, heads, seq, proj) -> (batch, seq, embed_dim).
    attention = tf.matmul(weights, value)
    attention = K.permute_dimensions(attention, (0, 2, 1, 3))
    attention = K.reshape(attention, (batch_size, -1, embed_dim))
    return Dense(embed_dim)(attention)


def TransformerBlock(inputs, embed_dim, num_heads, ff_dim, rate=0.1):
    """Post-norm Transformer encoder block: self-attention, then a feed-forward MLP.

    Each sub-layer is followed by dropout, a residual ``Add`` and
    ``LayerNormalization``.

    Args:
        inputs: token sequence; its last dimension must equal ``embed_dim``
            because it is added residually to the attention output.
        embed_dim: model width.
        num_heads: number of attention heads.
        ff_dim: hidden width of the ReLU feed-forward layer.
        rate: dropout rate for both sub-layers (defaults to the previous 0.1).

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    attn_output = MultiHead_SelfAttention(inputs, embed_dim, num_heads)
    attn_output = Dropout(rate)(attn_output)
    out1 = LayerNormalization(epsilon=1e-6)(Add()([inputs, attn_output]))
    ffn_output = Dense(ff_dim, activation="relu")(out1)
    ffn_output = Dense(embed_dim)(ffn_output)
    ffn_output = Dropout(rate)(ffn_output)
    return LayerNormalization(epsilon=1e-6)(Add()([out1, ffn_output]))


class Add_Embedding_Layer(tf.keras.layers.Layer):
    """Prepend a learnable class token and add learnable position embeddings.

    The incoming sequence must have ``num_patches`` tokens of width ``d_model``.
    After the class token is prepended, the sequence has ``num_patches + 1``
    tokens, matching the shape of ``pos_emb``. The class token sits at index 0.

    Args:
        num_patches: number of patch tokens in the incoming sequence.
        d_model: embedding width of each token.
        batch_size: accepted for backward compatibility only. The batch size is
            now read from the input at run time, so any batch size works.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_patches=64, d_model=64, batch_size=16, **kwargs):
        super(Add_Embedding_Layer, self).__init__(**kwargs)
        self.batch_size = batch_size
        # Learnable class token (attribute name kept for compatibility).
        self.patch_emb = self.add_weight(shape=[1, 1, d_model], dtype=tf.float32)
        # One learnable position vector per patch plus one for the class token.
        self.pos_emb = self.add_weight(shape=[1, num_patches + 1, d_model], dtype=tf.float32)

    def call(self, inputs):
        # Tile the class token to the runtime batch size instead of a fixed
        # batch_size, so training, evaluation and inference accept any batch.
        batch = tf.shape(inputs)[0]
        class_token = tf.tile(self.patch_emb, [batch, 1, 1])
        # Prepend (previously appended) so the class token is at index 0, the
        # position the classification head reads with x[:, 0].
        tokens = K.concatenate([class_token, inputs], axis=1)
        # pos_emb has a leading dimension of 1 and broadcasts over the batch.
        return tokens + self.pos_emb


# Training hyper-parameters consumed by model.fit() further down.
# (make_ViT's own default batch_size is also 400.)
epochs, batch_size = 30, 400


def make_ViT(img_size=32, ch_size=3, patch_size=4, batch_size=400, num_layers=4, d_model=64, num_heads=4, mlp_dim=128, num_classes=10):
    """Build a small Vision Transformer image classifier.

    Args:
        img_size: height and width of the square input images (pixel values are
            rescaled by 1/255 inside the model).
        ch_size: number of input channels.
        patch_size: side length of each square patch; must divide ``img_size``.
        batch_size: forwarded to ``Add_Embedding_Layer``.
        num_layers: number of Transformer encoder blocks.
        d_model: token embedding width.
        num_heads: attention heads per block.
        mlp_dim: hidden width of each block's feed-forward layer and of the head.
        num_classes: number of output classes (softmax probabilities).

    Returns:
        An uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: if ``img_size`` is not a multiple of ``patch_size``.
    """
    if img_size % patch_size != 0:
        raise ValueError(
            f"img_size ({img_size}) must be a multiple of patch_size ({patch_size})"
        )
    num_patches = (img_size // patch_size) ** 2
    patch_dim = ch_size * patch_size ** 2

    # Honour the requested geometry; this was hard-coded to (32, 32, 3), which
    # silently ignored img_size and ch_size.
    inputs = Input(shape=(img_size, img_size, ch_size))

    x = Rescaling(1.0 / 255)(inputs)
    # space_to_depth moves each patch_size x patch_size block into the channel
    # axis; the reshape then flattens the patch grid into a token sequence.
    x = tf.nn.space_to_depth(x, patch_size)
    x = K.reshape(x, (-1, num_patches, patch_dim))
    x = Dense(d_model)(x)

    x = Add_Embedding_Layer(num_patches, d_model, batch_size)(x)
    for _ in range(num_layers):
        x = TransformerBlock(x, d_model, num_heads, mlp_dim)

    # The token at sequence index 0 feeds the classification head.
    # NOTE(review): Add_Embedding_Layer decides which token sits at index 0; in
    # standard ViT that is the learnable class token — confirm.
    x = Dense(mlp_dim, activation=tfa.activations.gelu)(x[:, 0])
    x = Dropout(0.1)(x)
    y = Dense(num_classes, activation="softmax")(x)
    return Model(inputs=inputs, outputs=y)


# Build the model with the same batch size used for training below, instead of
# relying on make_ViT's independent default, so editing `batch_size` above
# cannot silently desynchronise the model from model.fit().
model = make_ViT(batch_size=batch_size)
model.compile(optimizer="Adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.summary()

# One-hot encode the integer class labels to match categorical_crossentropy.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
history = model.fit(
    x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test), verbose=1,
)


TensorFlow Addons has compiled its custom ops against TensorFlow 2.2.0, and there are no compatibility guarantees between the two versions. 
This means that you might get segfaults when loading the custom op, or other kind of low-level errors.
 If you do, do not file an issue on Github. This is a known limitation.

It might help you to fallback to pure Python ops with TF_ADDONS_PY_OPS . To do that, see https://github.com/tensorflow/addons#gpucpu-custom-ops 

You can also change the TensorFlow version installed on your system. You would need a TensorFlow version equal to or above 2.2.0 and strictly below 2.3.0.
 Note that nightly versions of TensorFlow, as well as non-pip TensorFlow like `conda install tensorflow` or compiled from source are not supported.

The last solution is to find the TensorFlow Addons version that has custom ops compatible with the TensorFlow installed on your system. To do that, refer to the readme: https://github.com/tensorflow/addons


Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
rescaling (Rescaling)           (None, 32, 32, 3)    0           input_1[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_SpaceToDepth (Tenso [(None, 8, 8, 48)]   0           rescaling[0][0]                  
__________________________________________________________________________________________________
tf_op_layer_Reshape (TensorFlow [(None, 64, 48)]     0           tf_op_layer_SpaceToDepth[0][0]   
_______________________________________________________________________________________