In [None]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1


In [None]:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras import Sequential
import tensorflow.keras.layers as nn

from tensorflow import einsum
from einops import rearrange, repeat
from einops.layers.tensorflow import Rearrange
import numpy as np

In [None]:
def pair(t):
    """Return ``t`` unchanged when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

class PreNorm(Layer):
    """Apply layer normalization to the input, then run the wrapped layer on the result.

    Args:
        fn: Callable sub-layer; it is invoked as ``fn(normalized_x, training=...)``.
    """

    def __init__(self, fn):
        super().__init__()

        # Normalization is created before the wrapped layer is stored, so the
        # attribute (and weight-tracking) order stays norm -> fn.
        self.norm = nn.LayerNormalization()
        self.fn = fn

    def call(self, x, training=True):
        """Normalize ``x`` and forward it, together with ``training``, to the wrapped layer."""
        normalized = self.norm(x)
        return self.fn(normalized, training=training)

class Attentionlepe(tf.keras.layers.Layer):
    """Self-attention over position-encoded inputs, then dropout, residual add and LayerNorm.

    Fixed sinusoidal position embeddings are added to the inputs. The result is
    split into ``num_heads`` chunks along the feature axis, the chunks are folded
    into the batch axis, and the folded tensor is passed through a Keras
    ``MultiHeadAttention`` layer as both query and value. The attention output
    is reshaped back to ``(batch, seq_len, d_model)``, passed through dropout,
    added to the original inputs and layer-normalized.

    Args:
        d_model: Size of the last axis of the inputs. Must be even (sin/cos
            pairs) and divisible by ``num_heads`` (feature split).
        num_heads: Number of feature chunks; also the ``num_heads`` of the inner
            ``MultiHeadAttention`` layer.
        dropout_rate: Dropout rate used inside the attention layer and on its output.

    Raises:
        ValueError: If ``d_model`` is odd or not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout_rate):
        super(Attentionlepe, self).__init__()

        # Both conditions would otherwise only surface as reshape errors at call time.
        if d_model % 2 != 0:
            raise ValueError("d_model must be even, got %d" % d_model)
        if d_model % num_heads != 0:
            raise ValueError(
                "d_model (%d) must be divisible by num_heads (%d)" % (d_model, num_heads)
            )

        self.num_heads = num_heads
        self.d_model = d_model
        self.dropout_rate = dropout_rate

        self.local_att_layer = tf.keras.layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.d_model, dropout=self.dropout_rate
        )
        self.att_drop = tf.keras.layers.Dropout(self.dropout_rate)
        self.local_att_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, mask=None, training=None):
        """Run the attention block.

        Args:
            inputs: Tensor reshaped here as ``(batch, seq_len, d_model)``.
            mask: Accepted for interface compatibility; it is not used (the
                attention layer is always called with ``attention_mask=None``).
            training: Forwarded to the attention layer and the dropout layer.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        bsz, qlen = tf.shape(inputs)[0], tf.shape(inputs)[1]
        head_dim = self.d_model // self.num_heads

        # Inputs with position information added (previously the inputs were
        # dropped here and attention ran on the position table alone).
        x = self.add_positional_embeddings(inputs)

        # (batch, seq, d_model) -> (batch * heads, seq, head_dim)
        x = tf.reshape(x, [bsz, qlen, self.num_heads, head_dim])
        x = tf.transpose(x, [0, 2, 1, 3])
        x = tf.reshape(x, [bsz * self.num_heads, qlen, head_dim])

        out = self.local_att_layer(
            x, x, attention_mask=None, return_attention_scores=False, training=training
        )

        # (batch * heads, seq, head_dim) -> (batch, seq, d_model)
        out = tf.reshape(out, [bsz, self.num_heads, qlen, head_dim])
        out = tf.transpose(out, [0, 2, 1, 3])
        out = tf.reshape(out, [bsz, qlen, self.d_model])

        out = self.att_drop(out, training=training)
        out = self.local_att_norm(inputs + out)

        return out

    def add_positional_embeddings(self, inputs):
        """Return ``inputs`` with sinusoidal position embeddings added.

        The embedding table has shape ``(seq_len, d_model)`` and is broadcast
        over the batch axis, so the result has the same shape as ``inputs``.
        """
        seq_len = tf.shape(inputs)[1]
        pos_emb = self.get_absolute_position_embeddings(seq_len, self.d_model)

        pos_emb = tf.reshape(pos_emb, [1, seq_len, self.d_model])
        # The table is built in float32; match the input dtype before adding.
        pos_emb = tf.cast(pos_emb, inputs.dtype)

        return inputs + pos_emb

    def get_absolute_position_embeddings(self, seq_len, hidden_size):
        """Build a ``(seq_len, hidden_size)`` float32 table of sinusoidal embeddings.

        For position ``p`` and even index ``i`` the angle is
        ``p / 10000 ** (i / hidden_size)``. Sine and cosine of each angle are
        interleaved along the last axis: ``[sin_0, cos_0, sin_1, cos_1, ...]``.
        """
        freqs = tf.range(0, hidden_size, 2, dtype=tf.float32)
        freqs = 1 / (10000 ** (freqs / hidden_size))

        # (seq_len, 1) * (hidden_size // 2,) broadcasts to (seq_len, hidden_size // 2).
        pos = tf.expand_dims(tf.range(0, seq_len, dtype=tf.float32), axis=1)
        angles = pos * freqs

        pos_emb = tf.stack([tf.sin(angles), tf.cos(angles)], axis=-1)
        pos_emb = tf.reshape(pos_emb, [seq_len, hidden_size])

        return pos_emb


class BilevelRoutingAttention(tf.keras.layers.Layer):
    """Additive-attention pooling of a sequence into one vector per batch element.

    Two sets of attention weights are computed over the sequence axis: a
    "sequence" set from ``tanh(inputs @ w1) @ v`` and a "token" set from
    ``tanh(inputs @ w2) @ v``. Each is used to take a weighted sum of the
    inputs. The token summary is then re-estimated ``num_iterations`` times
    from ``token_output + sequence_output``.

    Args:
        d_model: Size of the last axis of the inputs.
        num_iterations: Number of refinement iterations run after the first pooling.
    """

    def __init__(self, d_model, num_iterations):
        super(BilevelRoutingAttention, self).__init__()

        self.d_model = d_model
        self.num_iterations = num_iterations

        self.w1 = self.add_weight(name='w1', shape=(d_model, d_model), initializer='glorot_uniform', trainable=True)
        self.w2 = self.add_weight(name='w2', shape=(d_model, d_model), initializer='glorot_uniform', trainable=True)
        self.v = self.add_weight(name='v', shape=(d_model, 1), initializer='glorot_uniform', trainable=True)

    def _scores(self, x, projection):
        """Return ``tanh(x @ projection) @ v`` with the trailing size-1 axis removed."""
        return tf.squeeze(tf.matmul(tf.nn.tanh(tf.matmul(x, projection)), self.v), axis=-1)

    def call(self, inputs, mask=None):
        """Pool ``inputs`` over the sequence axis.

        Args:
            inputs: Tensor used as ``(batch_size, seq_len, d_model)``.
            mask: Optional tensor; it is expanded on the last axis and multiplied
                into ``inputs``, so it is assumed to be ``(batch_size, seq_len)``
                with 1/True for positions to keep -- TODO confirm with callers.

        Returns:
            Tensor of shape ``(batch_size, d_model)``.
        """
        # Sequence-level summary. NOTE(review): as in the original code, these
        # weights are computed before the mask is applied.
        sequence_scores = self._scores(inputs, self.w1)  # (batch_size, seq_len)
        sequence_weights = tf.nn.softmax(sequence_scores)
        sequence_output = tf.reduce_sum(tf.expand_dims(sequence_weights, axis=-1) * inputs, axis=1)  # (batch_size, d_model)

        if mask is not None:
            # Cast so bool/int masks can be multiplied into the inputs.
            mask = tf.cast(mask, inputs.dtype)
            inputs = inputs * tf.expand_dims(mask, axis=-1)
            # Keep the bias at (batch_size, seq_len) so it lines up with the
            # scores; the original added a (batch, seq, 1) tensor here.
            mask_bias = (1.0 - mask) * -1e9
        else:
            # No mask: a zero bias keeps a single code path below. The original
            # evaluated (1 - None) inside the loop and raised TypeError.
            mask_bias = tf.zeros_like(sequence_scores)

        token_weights = tf.nn.softmax(self._scores(inputs, self.w2) + mask_bias)
        token_output = tf.reduce_sum(tf.expand_dims(token_weights, axis=-1) * inputs, axis=1)  # (batch_size, d_model)

        for _ in range(self.num_iterations):
            query = tf.expand_dims(token_output + sequence_output, axis=1)  # (batch_size, 1, d_model)
            # NOTE(review): the query yields one score per batch element, so after
            # broadcasting against mask_bias the softmax is uniform over the
            # unmasked positions. Confirm this is the intended refinement rule.
            token_weights = tf.nn.softmax(self._scores(query, self.w2) + mask_bias)
            token_output = tf.reduce_sum(tf.expand_dims(token_weights, axis=-1) * inputs, axis=1)  # (batch_size, d_model)

        return token_output

class MLP(Layer):
    """Feed-forward stack: Dense(hidden_dim) -> GELU -> Dropout -> Dense(dim) -> Dropout.

    Args:
        dim: Number of units of the final dense layer.
        hidden_dim: Number of units of the first dense layer.
        dropout: Rate used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()

        def gelu(x, approximate=False):
            # Exact erf form by default; tanh approximation on request.
            if not approximate:
                return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype)))
            coeff = tf.cast(0.044715, x.dtype)
            return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3))))

        stages = [
            nn.Dense(units=hidden_dim),
            nn.Activation(gelu),
            nn.Dropout(rate=dropout),
            nn.Dense(units=dim),
            nn.Dropout(rate=dropout),
        ]
        self.net = Sequential(stages)

    def call(self, x, training=True):
        """Run ``x`` through the stack, forwarding ``training`` to it."""
        return self.net(x, training=training)

class Attention(Layer):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Number of units of the output projection.
        heads: Number of attention heads.
        dim_head: Width of each head.
        dropout: Dropout rate applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head that already has the model width needs no output projection.
        skip_projection = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax()
        # One bias-free projection produces queries, keys and values together.
        self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False)

        out_layers = [] if skip_projection else [nn.Dense(units=dim), nn.Dropout(rate=dropout)]
        self.to_out = Sequential(out_layers)

    def call(self, x, training=True):
        """Attend over the token axis of ``x``, rearranged here as ``b n (h d)``."""
        projected = tf.split(self.to_qkv(x), num_or_size_splits=3, axis=-1)
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads) for part in projected
        )

        # Scaled pairwise similarity of every query with every key, per head.
        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.attend(scores)

        # Weighted sum of values, then merge the heads back into the feature axis.
        context = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(context, 'b h n d -> b n (h d)')

        return self.to_out(merged, training=training)

class block(Layer):
    """Stack of ``depth`` pre-norm (attention, MLP) pairs with residual connections.

    Args:
        dim: Feature width of the tokens.
        depth: Number of (attention, MLP) pairs.
        heads: Number of attention heads.
        dim_head: Width of each head for the ``Attention`` variant.
        mlp_dim: Hidden width of each MLP.
        topk: Selects the attention variant: ``-1`` -> ``Attention``,
            ``-2`` -> ``Attentionlepe``, any positive value ->
            ``BilevelRoutingAttention``.
        layer_scale_val: Optional constant multiplied into every sub-layer
            output before the residual add. ``None`` disables scaling.
        dropout: Dropout rate forwarded to the sub-layers.

    Raises:
        ValueError: If ``topk`` does not select a known attention variant.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim,topk=-1,layer_scale_val=None, dropout=0.0):
        super(block, self).__init__()
        self.layer_scale_val = layer_scale_val

        if topk > 0:
            def make_attention():
                # NOTE(review): dim_head is passed as num_iterations here, as in
                # the original code -- confirm this is the intended value.
                return BilevelRoutingAttention(dim, dim_head)
        elif topk == -1:
            def make_attention():
                return Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        elif topk == -2:
            def make_attention():
                # Attentionlepe's second parameter is num_heads; the original
                # passed dim_head in that position.
                return Attentionlepe(dim, heads, dropout)
        else:
            # Previously this fell through and failed later with an unbound `att`.
            raise ValueError(
                "topk must be -1 (Attention), -2 (Attentionlepe) or > 0 "
                "(BilevelRoutingAttention), got %r" % (topk,)
            )

        self.layers = []
        for _ in range(depth):
            # Build a fresh attention layer per depth step; the original reused
            # a single instance, tying the attention weights across all layers
            # while every MLP had its own weights.
            self.layers.append([
                PreNorm(make_attention()),
                PreNorm(MLP(dim, mlp_dim, dropout=dropout))
            ])

    def call(self, x, training=True):
        """Apply every (attention, MLP) pair in order, each with a residual connection."""
        scale = self.layer_scale_val
        for attn, mlp in self.layers:
            # `is not None` so that an explicit scale of 0.0 is honoured rather
            # than being treated as "no scaling".
            if scale is not None:
                x = attn(x, training=training) * scale + x
                x = mlp(x, training=training) * scale + x
            else:
                x = attn(x, training=training) + x
                x = mlp(x, training=training) + x

        return x
        
class BiFormer(Model):
    """Image classifier: conv stem -> patch embedding -> transformer block -> linear head.

    Args:
        image_size: Int or ``(height, width)`` tuple of the input images.
        embed_size: Iterable of filter counts; one 3x3 ``Conv2D`` (ReLU, same
            padding) plus ``BatchNormalization`` is added to the stem per entry.
        patch_size: Int or ``(height, width)`` tuple; must divide the image size.
        num_classes: Number of output logits.
        dim: Token embedding width.
        depth: Number of (attention, MLP) pairs in the transformer block.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the transformer MLPs.
        pool: ``'cls'`` to classify from the class token, ``'mean'`` to
            average over all tokens.
        dim_head: Width of each attention head.
        dropout: Dropout rate inside the transformer block.
        emb_dropout: Dropout rate applied to the embedded tokens.
        layer_scale_val: Constant residual-branch scale forwarded to ``block``
            (previously hard-coded to 0.4, which remains the default).

    Raises:
        AssertionError: If the image size is not divisible by the patch size or
            ``pool`` is not ``'cls'`` / ``'mean'``.
    """

    def __init__(self, image_size, embed_size,patch_size, num_classes, dim, depth, heads, mlp_dim,
                 pool='cls', dim_head=64, dropout=0.0, emb_dropout=0.0, layer_scale_val=0.4):

        super(BiFormer, self).__init__()

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Cut the feature map into non-overlapping patches, flatten each patch
        # and project it to `dim` features.
        self.patch_embedding = Sequential([
            Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Dense(units=dim)
        ], name='patch_embedding')

        # Convolutional stem. The convolutions use "same" padding and no
        # explicit stride, which is why num_patches above is derived directly
        # from the image size.
        self.downsample = Sequential()
        for i in embed_size:
            self.downsample.add(nn.Conv2D(i, 3, activation='relu', padding="same"))
            self.downsample.add(nn.BatchNormalization())

        # One learned position vector per patch plus one for the class token.
        self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]))
        self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]))
        self.dropout = nn.Dropout(rate=emb_dropout)

        self.block = block(dim, depth, heads, dim_head, mlp_dim, layer_scale_val=layer_scale_val, dropout=dropout)

        self.pool = pool

        self.mlp_head = Sequential([
            nn.LayerNormalization(),
            nn.Dense(units=num_classes)
        ], name='mlp_head')

    def call(self, img, training=True, **kwargs):
        """Compute class logits for a batch of images.

        Args:
            img: Image batch in channels-last layout (``b h w c``, as consumed
                by the patch rearrangement).
            training: Forwarded to every sub-layer that distinguishes training
                from inference (batch norm, dropout).

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalized logits.
        """
        x = self.downsample(img, training=training)
        x = self.patch_embedding(x, training=training)

        # Use the dynamic batch size: the static `x.shape[0]` is None whenever
        # the batch dimension is unknown at trace time, which broke the
        # class-token repeat.
        batch = tf.shape(x)[0]
        cls_tokens = tf.tile(self.cls_token, [batch, 1, 1])
        x = tf.concat([cls_tokens, x], axis=1)

        # Slice the position table to the current token count (patches + cls).
        x += self.pos_embedding[:, :tf.shape(x)[1]]
        x = self.dropout(x, training=training)

        x = self.block(x, training=training)

        if self.pool == 'mean':
            y = tf.reduce_mean(x, axis=1)
        else:
            # 'cls' pooling: the class token sits at index 0.
            y = x[:, 0]

        y = self.mlp_head(y, training=training)
        return y

In [None]:
# Model under test: 32x32 inputs cut into 4x4 patches, 10 output classes.
v = BiFormer(
    image_size=32,
    embed_size=[2, 4, 6, 8],
    patch_size=4,
    num_classes=10,
    dim=128,
    depth=6,
    heads=16,
    mlp_dim=256,
    dropout=0.1,
    emb_dropout=0.1,
)

In [None]:
# Smoke test: forward a random batch of two 32x32 RGB images through the model.
# `training` is not passed, so the model's own call default applies.
img = tf.random.normal(shape=[2, 32, 32, 3])
outa = v(img)
print(outa)

tf.Tensor(
[[ 1.8189895  -2.9456773   3.1443267  -1.1026647   0.23636061  0.88873965
  -2.1820972   0.377267   -1.5959678  -1.7842641 ]
 [ 0.8837092  -2.3297973   1.4872115  -1.6793201   1.268717    0.75414175
  -0.925108   -0.1175254  -1.3871202  -0.61367285]], shape=(2, 10), dtype=float32)


In [None]:
import keras

# Dataset and batching constants used by the cells below.
num_classes = 10
input_shape = (32, 32, 3)
batch_size = 32

# CIFAR-10 comes back as two (images, integer labels) pairs.
train_split, test_split = keras.datasets.cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

# One-hot encode the labels of both splits.
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes) for labels in (y_train, y_test)
)

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 10)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 10)


In [None]:
# Training pipeline: shuffle through a 1024-element buffer, then batch.
# (A bicubic 224x224 `tf.keras.layers.Resizing` step was sketched here by the
# author but left disabled.)
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

# Validation pipeline: batched only, no shuffling.
val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

In [None]:
# The loss is configured with from_logits=True, i.e. it expects unnormalized
# model outputs.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Adam with its default hyper-parameters.
optimizer = keras.optimizers.Adam()

# Separate accuracy trackers for the training and validation passes.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()

In [None]:
import numpy as np 

@tf.function
def train_step(x, y):
    """Run one optimization step on a batch and update the training accuracy.

    Args:
        x: Batch of input images.
        y: Matching batch of target labels.

    Returns:
        The loss computed on the training-mode forward pass.
    """
    with tf.GradientTape() as tape:
        train_logits = v(x, training=True)
        loss_value = loss_fn(y, train_logits)

    weights = v.trainable_weights
    gradients = tape.gradient(loss_value, weights)
    optimizer.apply_gradients(zip(gradients, weights))

    # The accuracy is measured with a second, inference-mode forward pass that
    # runs after the weights have been updated.
    eval_logits = v(x, training=False)
    train_acc_metric.update_state(y, eval_logits)
    return loss_value
@tf.function
def test_step(x, y):
    """Update the validation accuracy metric with inference-mode predictions for one batch."""
    val_acc_metric.update_state(y, v(x, training=False))
import time
# Per-epoch history that the plotting cells read.
train_acc_list = []
train_loss_list = []
epochs = 10

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Training pass: one optimization step per batch.
    train_loss = []
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        batch_loss = float(train_step(x_batch_train, y_batch_train))
        train_loss.append(batch_loss)
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, batch_loss)
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Epoch loss is the mean of the per-batch losses.
    train_loss_list.append(np.mean(train_loss))

    train_acc = float(train_acc_metric.result())
    print("Training acc over epoch: %.4f" % (train_acc,))
    train_acc_list.append(train_acc)
    # `reset_state()` replaces the deprecated `reset_states()` alias, which
    # newer Keras releases no longer provide.
    train_acc_metric.reset_state()

    # Validation pass over the whole validation set.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))


Start of epoch 0
Training loss (for one batch) at step 0: 3.1575
Seen so far: 32 samples
Training loss (for one batch) at step 200: 1.8259
Seen so far: 6432 samples
Training loss (for one batch) at step 400: 1.8718
Seen so far: 12832 samples
Training loss (for one batch) at step 600: 1.9602
Seen so far: 19232 samples
Training loss (for one batch) at step 800: 1.7096
Seen so far: 25632 samples
Training loss (for one batch) at step 1000: 1.5373
Seen so far: 32032 samples
Training loss (for one batch) at step 1200: 1.2513
Seen so far: 38432 samples
Training loss (for one batch) at step 1400: 1.2664
Seen so far: 44832 samples
Training acc over epoch: 0.4150
Validation acc: 0.4709
Time taken: 110.04s

Start of epoch 1
Training loss (for one batch) at step 0: 1.5677
Seen so far: 32 samples
Training loss (for one batch) at step 200: 1.7078
Seen so far: 6432 samples
Training loss (for one batch) at step 400: 1.8163
Seen so far: 12832 samples
Training loss (for one batch) at step 600: 1.4782
S

In [None]:
import matplotlib.pyplot as plt
# One point per epoch: the mean of that epoch's per-batch training losses.
fig, ax = plt.subplots()
ax.plot(train_loss_list)
ax.set(title="Training loss per epoch", xlabel="Epoch", ylabel="Mean batch loss")
plt.show()

In [None]:
import matplotlib.pyplot as plt
# One point per epoch: the training accuracy metric read at the end of that epoch.
fig, ax = plt.subplots()
ax.plot(train_acc_list)
ax.set(title="Training accuracy per epoch", xlabel="Epoch", ylabel="Categorical accuracy")
plt.show()