In [1]:
!pip install --quiet vit-keras

[0m

In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from vit_keras import vit

In [3]:
# Training-time image pipeline: multiply pixel values by 1/255 and apply
# random shear, zoom and horizontal-flip augmentation.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    horizontal_flip=True,
    zoom_range=0.2,
    shear_range=0.2,
    rescale=1.0 / 255,
)

# Evaluation-time image pipeline: rescaling only, no augmentation.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1.0 / 255,
)

In [4]:
# Training batches: images read from the train/ directory through the
# augmenting generator, resized to 128x128, 8 per batch, with one binary
# label per image (class_mode='binary').
train_generator = train_datagen.flow_from_directory(
        '/kaggle/input/chest-xray-pneumonia/chest_xray/train/',
        target_size=(128, 128),
        batch_size=8,
        class_mode='binary')

# Validation batches go through test_datagen (rescale only). Feeding them
# through train_datagen would apply random shear/zoom/flip to the evaluation
# images, so validation metrics would be noisy and not repeatable.
# NOTE(review): the test/ split is used as validation data here — confirm
# that this is intended and that no separate hold-out evaluation is needed.
validation_generator = test_datagen.flow_from_directory(
        '/kaggle/input/chest-xray-pneumonia/chest_xray/test/',
        target_size=(128, 128),
        batch_size=8,
        class_mode='binary')

Found 5216 images belonging to 2 classes.
Found 624 images belonging to 2 classes.


In [5]:
# ViT-B/16 backbone for 128x128 inputs, randomly initialised (no pretrained
# weights) and built without its classification head, so it outputs a
# 768-dimensional feature vector per image.
# `classes` and `activation` are intentionally not passed: they only
# configure the head layer, which include_top=False omits. (A softmax over a
# single class would also be constant 1.0 if the head were ever enabled.)
vit_base_16 = vit.vit_b16(
        image_size = (128,128),
        pretrained = False,
        include_top = False,
        pretrained_top = False)

# Binary classifier: backbone features -> single sigmoid unit.
# Flatten leaves the already-flat (None, 768) features unchanged; it is kept
# so the layer stack stays the same as before.
vit_base_16_model = tf.keras.Sequential([
        vit_base_16,
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

In [6]:
# Print the layer-by-layer architecture and parameter counts of the classifier.
vit_base_16_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vit-b16 (Functional)        (None, 768)               85697280  
                                                                 
 flatten (Flatten)           (None, 768)               0         
                                                                 
 dense (Dense)               (None, 1)                 769       
                                                                 
Total params: 85,698,049
Trainable params: 85,698,049
Non-trainable params: 0
_________________________________________________________________


In [7]:
# Metrics reported for both training and validation.
METRICS = ['accuracy']

# Binary cross-entropy loss with the Adam optimizer (Keras defaults,
# selected by name).
vit_base_16_model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=METRICS,
)

In [8]:
def exponential_decay(lr0, s):
    """Build a per-epoch exponential learning-rate schedule.

    The returned callable maps an epoch index to
    ``lr0 * 0.1 ** (epoch / s)``: the rate starts at ``lr0`` and is divided
    by 10 every ``s`` epochs.

    Args:
        lr0: Learning rate at epoch 0.
        s: Number of epochs over which the rate drops tenfold; must be
            non-zero.

    Returns:
        A function taking an epoch index and returning the learning rate
        for that epoch.

    Raises:
        ValueError: If ``s`` is zero (the schedule would divide by zero).
    """
    # Fail at construction time with a clear message instead of raising a
    # ZeroDivisionError from inside the training loop.
    if s == 0:
        raise ValueError("s must be non-zero")

    def exponential_decay_fn(epoch):
        return lr0 * 0.1 ** (epoch / s)

    return exponential_decay_fn

# Schedule starting at 0.01 and dropping tenfold every 20 epochs.
# NOTE(review): 0.01 looks high as a starting rate for Adam on a
# transformer trained from scratch — confirm that training converges.
exponential_decay_fn = exponential_decay(lr0=0.01, s=20)

# Callback that asks the schedule for the learning rate at the start of
# each epoch and applies it to the optimizer.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(
    schedule=exponential_decay_fn,
)

In [9]:
# Train for 10 epochs on train_generator, evaluating on validation_generator
# after each epoch; lr_scheduler supplies the learning rate per epoch.
# The returned History object is kept for later inspection.
history = vit_base_16_model.fit(
    x=train_generator,
    validation_data=validation_generator,
    epochs=10,
    callbacks=[lr_scheduler],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [10]:
# Persist the trained model to a single .h5 file in the working directory.
vit_base_16_model.save(filepath="ViT_xray_Pneumonia.h5")