In [1]:
import os
# Set TensorFlow's native log-level threshold before the `import tensorflow`
# in the next cell. NOTE(review): "2" presumably hides INFO and WARNING
# messages from the C++ backend -- confirm against the TensorFlow docs.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [2]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50

In [3]:
# Transformer-encoder hyperparameters.
config = {
    "num_layers": 12,       # number of stacked encoder blocks
    "hidden_dim": 768,      # token embedding width
    "mlp_dim": 3072,        # width of the feed-forward hidden layer
    "num_heads": 12,        # attention heads per block
    "dropout_rate": 0.1,
    # Input geometry.
    "image_size": 512,
    "patch_size": 32,
}
# Number of patches on the square image grid (derived from the two sizes above).
config["num_patches"] = (config["image_size"] ** 2) // (config["patch_size"] ** 2)
config.update(num_channels=3, num_classes=10)

In [4]:
class ClassToken(Layer):
    """Learnable classification token.

    The layer owns a single trainable vector of shape ``(1, 1, dim)`` where
    ``dim`` is the last dimension of the input it is built on. ``call`` does
    not modify or concatenate ``inputs``; it only returns the token repeated
    along the batch axis, i.e. a tensor of shape ``(batch, 1, dim)`` cast to
    the dtype of ``inputs``. The caller concatenates it with the sequence.

    Args:
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ``trainable``, ...). Accepting them also lets the layer be
            re-created from its config when a saved model is reloaded.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Register the token through add_weight so the layer tracks it
        # explicitly (name, trainability, checkpointing) instead of relying on
        # attribute tracking of a raw tf.Variable.
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, input_shape[-1]),
            initializer=tf.random_normal_initializer(),
            dtype="float32",
            trainable=True,
        )

    def call(self, inputs):
        # The batch size is only known at run time, so read it dynamically.
        batch_size = tf.shape(inputs)[0]
        hidden_dim = self.w.shape[-1]

        cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
        return tf.cast(cls, dtype=inputs.dtype)
    

In [5]:
def mlp(x, cf):
    """Feed-forward sub-block: expand, apply GELU, project back, with dropout.

    Args:
        x: Input tensor to transform.
        cf: Config mapping providing ``"mlp_dim"`` (expansion width),
            ``"hidden_dim"`` (output width) and ``"dropout_rate"``.

    Returns:
        The tensor after Dense(gelu) -> Dropout -> Dense -> Dropout.
    """
    expanded = Dense(cf["mlp_dim"], activation="gelu")(x)
    expanded = Dropout(cf["dropout_rate"])(expanded)
    projected = Dense(cf["hidden_dim"])(expanded)
    return Dropout(cf["dropout_rate"])(projected)

In [6]:
def transformer_encoder(x, cf):
    """One pre-norm transformer encoder block.

    Applies LayerNorm -> multi-head self-attention -> residual add, followed by
    LayerNorm -> ``mlp`` -> residual add.

    Args:
        x: Token sequence tensor; its last dimension must equal
            ``cf["hidden_dim"]`` so the residual additions line up.
        cf: Config mapping providing ``"num_heads"``, ``"hidden_dim"`` and the
            keys consumed by ``mlp``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # `key_dim` in Keras MultiHeadAttention is the size of EACH head, not the
    # model width. Splitting hidden_dim across the heads (768 / 12 = 64) keeps
    # the attention projections at hidden_dim total width; passing hidden_dim
    # itself would make every head 768 wide and the block ~num_heads x larger.
    head_dim = cf["hidden_dim"] // cf["num_heads"]

    skip_1 = x
    x = LayerNormalization()(x)
    x = MultiHeadAttention(
        num_heads=cf["num_heads"],
        key_dim=head_dim
    )(x, x)
    x = Add()([x, skip_1])

    skip_2 = x
    x = LayerNormalization()(x)
    x = mlp(x, cf)
    x = Add()([x, skip_2])

    return x

In [7]:
class _AddPositionEmbedding(Layer):
    """Add a trainable per-position embedding to a ``(batch, seq, dim)`` tensor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # One vector per sequence position, sized from the actual input so it
        # always matches the number of patches produced upstream.
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=(1, input_shape[1], input_shape[2]),
            initializer="random_uniform",
            trainable=True,
        )

    def call(self, inputs):
        return inputs + tf.cast(self.pos_embed, inputs.dtype)


def ResNet50ViT(cf):
    """Build a hybrid ResNet50 + Vision Transformer classifier.

    The ImageNet-pretrained ResNet50 (without its top) turns the image into a
    feature map; a strided convolution projects that map into ``hidden_dim``
    wide patch tokens; a learned position embedding and a class token are
    added; the sequence runs through ``cf["num_layers"]`` encoder blocks; and
    the class-token output feeds a softmax classification head.

    Args:
        cf: Config mapping with ``"image_size"``, ``"num_channels"``,
            ``"patch_size"``, ``"hidden_dim"``, ``"num_layers"``,
            ``"num_classes"`` and the keys used by ``transformer_encoder``.

    Returns:
        A Keras ``Model`` mapping ``(image_size, image_size, num_channels)``
        images to ``num_classes`` softmax probabilities.
    """
    inputs = Input((cf["image_size"], cf["image_size"], cf["num_channels"]))

    # Pretrained ResNet50 backbone.
    resnet50 = ResNet50(include_top=False, weights="imagenet", input_tensor=inputs)
    feature_map = resnet50.output  # (None, 16, 16, C) for 512x512 inputs

    # Patch embeddings.
    # The backbone has already downsampled the image by 32x, so one feature-map
    # cell corresponds to a 32x32 pixel patch. Express the requested patch size
    # in feature-map cells (1 for patch_size=32) and use a non-overlapping
    # convolution. Using patch_size directly as the kernel would slide a 32x32
    # kernel over a 16x16 map, creating an enormous projection layer (on the
    # order of a billion weights, assuming ResNet50's 2048 output channels).
    backbone_stride = 32
    cells_per_patch = max(1, cf["patch_size"] // backbone_stride)
    patch_embed = Conv2D(
        cf["hidden_dim"],
        kernel_size=cells_per_patch,
        strides=cells_per_patch,
        padding="valid"
    )(feature_map)  # (None, 16, 16, 768) with the notebook's config

    _, h, w, f = patch_embed.shape
    patch_embed = Reshape((h * w, f))(patch_embed)  # (None, 256, 768)

    # Patch + position embeddings.
    # The embedding lives inside a layer so it is a trainable weight of the
    # model. Calling Embedding on an eager tf.range tensor and adding the
    # result with `+` would instead bake fixed random values into the graph.
    embed = _AddPositionEmbedding()(patch_embed)  # (None, 256, 768)

    # Prepend the class token.
    token = ClassToken()(embed)
    x = Concatenate(axis=1)([token, embed])  # (None, 257, 768)

    # Transformer encoder stack.
    for _ in range(cf["num_layers"]):
        x = transformer_encoder(x, cf)

    # Keep only the class-token position for classification.
    x = x[:, 0, :]  # (None, 768)

    x = Dense(cf["num_classes"], activation="softmax")(x)

    model = Model(inputs, x)
    return model

In [8]:
# Instantiate the hybrid network from the notebook-level config and print its
# layer-by-layer summary.
model = ResNet50ViT(cf=config)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 512, 512, 3)]        0         []                            
                                                                                                  
 conv1_pad (ZeroPadding2D)   (None, 518, 518, 3)          0         ['input_1[0][0]']             
                                                                                                  
 conv1_conv (Conv2D)         (None, 256, 256, 64)         9472      ['conv1_pad[0][0]']           
                                                                                                  
 conv1_bn (BatchNormalizati  (None, 256, 256, 64)         256       ['conv1_conv[0][0]']          
 on)                                                                                          

In [9]:
from keras.preprocessing.image import ImageDataGenerator


In [10]:
# The backbone is initialised with ImageNet weights, which were trained on
# inputs run through ResNet50's own preprocessing, not on images scaled to
# [0, 1]. Apply that same preprocessing to every image so the pretrained
# features see the input distribution they expect.
resnet_preprocess = tf.keras.applications.resnet50.preprocess_input
train_datagen = ImageDataGenerator(preprocessing_function=resnet_preprocess)
val_datagen = ImageDataGenerator(preprocessing_function=resnet_preprocess)


In [11]:
# Directories holding the training and validation images (relative paths).
train_dir, val_dir = 'tomato/train', 'tomato/val'

In [12]:
# Both splits are read with identical settings: images resized to the model's
# square input size, batches of 32, one-hot ("categorical") labels.
flow_options = dict(
    target_size=(config["image_size"], config["image_size"]),
    batch_size=32,
    class_mode='categorical',
)

train_generator = train_datagen.flow_from_directory(train_dir, **flow_options)
val_generator = val_datagen.flow_from_directory(val_dir, **flow_options)


Found 10000 images belonging to 10 classes.
Found 1000 images belonging to 10 classes.


In [13]:
# Epoch count for training.
epochs = 10

In [14]:
# Configure training: Adam optimizer, categorical cross-entropy loss, and
# accuracy reported as the metric.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)


In [16]:
# Inspection only: display the training iterator's repr.
train_generator

<keras.src.preprocessing.image.DirectoryIterator at 0x7fe8ccf46320>

In [None]:
# Train the model and keep the per-epoch metrics in `history`.
# - The epoch count comes from the `epochs` variable defined above instead of
#   a duplicated literal, so changing it in one place takes effect here.
# - `batch_size` is intentionally not passed: the directory iterators already
#   yield batches, and Keras documents that `batch_size` should not be
#   specified for generator inputs.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
)

Epoch 1/10
