## Imports

In [None]:
# Mount Google Drive into the Colab runtime; the dataset folders and the model
# save location used later in this notebook live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tensorflow import keras

# Turn off tensorflow_datasets progress bars.
tfds.disable_progress_bar()
# Fix TensorFlow's global random seed.
tf.random.set_seed(42)

## Constants

In [None]:
# TF Hub handle of the pre-trained module wrapped by hub.KerasLayer in the
# model-initialization cell.
MODULE_URL = "https://tfhub.dev/google/bit/m-r50x3/1"

# NOTE(review): BATCH_SIZE, NUM_EPOCHS and AUTO are assigned again in the next
# constants cell (BATCH_SIZE = 32, NUM_EPOCHS = 20). On a top-to-bottom run the
# later values are the ones in effect, so the numbers below are overridden.
BATCH_SIZE = 128
SZ = 224  # Height/width of the model input; same value as RESOLUTION below.
NUM_EPOCHS = 10

AUTO = tf.data.AUTOTUNE
NB_CLASSES = 23  # Units in the final Dense layer; same value as NUM_CLASSES below.

In [None]:
# RESOLUTION is the side length images are resized to when the datasets are built.
# NOTE(review): PATCH_SIZE through DROP_PATH_RATE, BASE_LR, WEIGHT_DECAY and
# NUM_CLASSES are not referenced by any other cell visible in this notebook --
# confirm whether they are still needed.
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
    PROJECTION_DIM * 4,
    PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1

# Training
# NOTE(review): this overrides the earlier NUM_EPOCHS = 10. The recorded
# training output shows 10 epochs, so it does not match a fresh top-to-bottom
# run with this value -- confirm which one is intended.
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001

# Data
# NOTE(review): this overrides the earlier BATCH_SIZE = 128; the value here is
# what the dataset builder and the learning-rate scaling see on a full run.
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 23

## Data preprocessing and loading

In [None]:


# Google Drive folders with the training and validation images; each one is
# passed to create_image_dataset below.
train_data = '/content/drive/MyDrive/data60/trainData/'
val_data = '/content/drive/MyDrive/data60/valData/'

def create_image_dataset(data_dir, is_training=True, skip_fraction=0.2):
    """Build a batched image dataset from a directory of images.

    Images are resized to ``(RESOLUTION, RESOLUTION)`` and grouped into batches
    of ``BATCH_SIZE``. Labels are one-hot vectors (``label_mode='categorical'``).

    Args:
        data_dir: Directory handed to ``image_dataset_from_directory``.
        is_training: When True the dataset is shuffled and the leading
            ``skip_fraction`` of its batches is dropped. When False the data
            is neither shuffled nor trimmed.
        skip_fraction: Fraction of batches, in ``[0, 1)``, skipped from the
            front of the training dataset. The skipped batches are discarded;
            they are NOT returned as a validation split. The default of 0.2
            keeps the previous hard-coded behaviour; pass 0 to train on every
            batch. Ignored when ``is_training`` is False.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.

    Raises:
        ValueError: If ``skip_fraction`` is outside ``[0, 1)``.
    """
    # Validate first so a bad argument fails before any files are scanned.
    if not 0.0 <= skip_fraction < 1.0:
        raise ValueError(
            f"skip_fraction must be in [0, 1), got {skip_fraction!r}"
        )

    dataset = tf.keras.utils.image_dataset_from_directory(
        data_dir,
        label_mode='categorical',  # One-hot labels, matching CategoricalCrossentropy.
        batch_size=BATCH_SIZE,
        image_size=(RESOLUTION, RESOLUTION),
        shuffle=is_training,
        seed=123,  # Fixed seed for the shuffle order.
    )

    if is_training:
        # len() of a batched dataset counts batches, not individual images.
        num_batches = len(dataset)
        num_skipped_batches = int(skip_fraction * num_batches)
        if num_skipped_batches > 0:
            dataset = dataset.skip(num_skipped_batches)

    return dataset

# Training split: uses the function's training defaults.
# Validation split: built with is_training=False.
train_dataset = create_image_dataset(data_dir=train_data)
val_dataset = create_image_dataset(data_dir=val_data, is_training=False)

def preprocess_dataset(image, label, is_training=True):
    """Apply training-time augmentation to a batch; leave pixel values as-is.

    Pixel values are deliberately NOT rescaled here. The model built in this
    notebook starts with ``keras.layers.Rescaling(scale=1.0 / 255)``, so it
    expects raw 0-255 inputs; normalizing here as well would scale the images
    twice and feed the backbone values very close to zero.

    Args:
        image: Image tensor (or batch of images) as produced by the dataset.
        label: Label(s) for ``image``; passed through unchanged.
        is_training: When True, a random horizontal flip is applied.

    Returns:
        A ``(image, label)`` tuple.
    """
    if is_training:
        # Data augmentation for the training dataset.
        image = tf.image.random_flip_left_right(image)
        # Add more data augmentation techniques as needed.

    return image, label

# Run preprocess_dataset over both splits (augmentation enabled only for the
# training split), then prefetch with AUTOTUNE so input preparation can overlap
# with model execution.
train_dataset = train_dataset.map(
    lambda images, labels: preprocess_dataset(images, labels, is_training=True),
    num_parallel_calls=AUTO,
).prefetch(AUTO)
val_dataset = val_dataset.map(
    lambda images, labels: preprocess_dataset(images, labels, is_training=False),
    num_parallel_calls=AUTO,
).prefetch(AUTO)

Found 6399 files belonging to 23 classes.
Found 2129 files belonging to 23 classes.


## Model initialization

In [None]:
# Pre-trained backbone loaded from TF Hub.
# NOTE(review): no trainable=True is passed to hub.KerasLayer, so presumably
# the backbone weights stay frozen -- confirm that is intended.
hub_module = hub.KerasLayer(MODULE_URL)

# Stack: input -> scale pixels by 1/255 -> backbone -> zero-initialized Dense
# head with no activation (the loss is configured with from_logits=True).
model = keras.Sequential(name="bit_teacher_flowers")
model.add(keras.Input((SZ, SZ, 3)))
model.add(keras.layers.Rescaling(scale=1.0 / 255))
model.add(hub_module)
model.add(keras.layers.Dense(NB_CLASSES, kernel_initializer="zeros"))

num_params_millions = model.count_params() / 1e6
print(f"Number of parameters (millions): {num_params_millions}.")

Number of parameters (millions): 211.315415.


## Optimizer and loss function

In [None]:
# NOTE(review): SCHEDULE_LENGTH is computed (as a float) but is not referenced
# by any other cell visible in this notebook.
SCHEDULE_LENGTH = 500
SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE

# Optimizer steps at which the learning rate drops.
# NOTE(review): the training cell runs steps_per_epoch=10 for NUM_EPOCHS epochs;
# with NUM_EPOCHS of 10 or 20 that is at most 200 steps, so these boundaries
# look like they are never crossed -- confirm the intended schedule length.
SCHEDULE_BOUNDARIES = [200, 300, 400]
# Base learning rate 0.003, scaled linearly with batch size relative to 512.
lr = 0.003 * BATCH_SIZE / 512

# Decay learning rate by a factor of 10 at each of SCHEDULE_BOUNDARIES.
# (The previous values jumped from 0.1x straight to 0.001x, i.e. a 100x drop
# at the second boundary, which contradicted the stated factor of 10.)
lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=SCHEDULE_BOUNDARIES,
    values=[lr, lr * 0.1, lr * 0.01, lr * 0.001],
)
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

# The model's Dense head has no activation, so the loss consumes raw logits.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

## Train the model and save it

In [None]:
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

# train_dataset is already batched by the dataset builder, so batch_size is not
# passed to fit(); Keras documents that it must not be specified for dataset
# inputs. repeat() makes the training stream endless, so steps_per_epoch is
# what defines where each epoch ends.
history = model.fit(
    train_dataset.repeat(),
    steps_per_epoch=10,
    epochs=NUM_EPOCHS,
    validation_data=val_dataset,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [None]:
# Write the trained model to Google Drive (requires the Drive mount from the first cell).
model.save("/content/drive/MyDrive/hyperKvasir")

## References

* [Official Colab Notebook from BiT authors](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb)