### Imports

In [None]:
from datasets import load_dataset
from transformers import ViTImageProcessor
from transformers import ViTFeatureExtractor
from transformers import DefaultDataCollator
from transformers import TFViTForImageClassification, create_optimizer
import tensorflow as tf
from torchvision.transforms import (CenterCrop,
                                    Compose,
                                    Normalize,
                                    RandomHorizontalFlip,
                                    RandomResizedCrop,
                                    Resize,
                                    ToTensor)

### Check if utilizing GPU

In [None]:
# Report the TensorFlow build in use, then fail fast when no GPU is visible.
print("TensorFlow version: ", tf.__version__)

# A falsy (empty) device name is treated as "no GPU available".
device_name = tf.test.gpu_device_name()
if device_name:
    print('Found GPU at: {}'.format(device_name))
else:
    raise SystemError('GPU device not found')


### Prepare data

In [None]:
# Fetch the train and test splits of the Oxford Flowers dataset, in that order.
train, test = (
    load_dataset("nelorth/oxford-flowers", split=split_name)
    for split_name in ('train', 'test')
)

# Class names as exposed by the dataset's "label" feature.
class_labels = train.features["label"].names

### Preprocess dataset

In [None]:
# Two-way lookup between integer class ids and their label names.
id2label = dict(enumerate(train.features["label"].names))
label2id = {name: idx for idx, name in id2label.items()}
id2label


# Load the image-processor configuration published with the ViT checkpoint.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")


# Reuse the processor's normalization statistics and input height so the
# torchvision pipelines below are built from the same configuration.
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size["height"]

# Channel-wise normalization shared by the training and evaluation pipelines.
normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline: random crop + random horizontal flip as augmentation,
# followed by tensor conversion and normalization.
_train_transforms = Compose([
    RandomResizedCrop(size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Evaluation pipeline: resize + center crop (no random steps),
# followed by the same tensor conversion and normalization.
_val_transforms = Compose([
    Resize(size),
    CenterCrop(size),
    ToTensor(),
    normalize,
])

def train_transforms(examples):
    """Attach training-pipeline tensors to a batch of examples.

    Every entry of ``examples["image"]`` is converted to RGB and passed
    through ``_train_transforms``; the results are stored under the
    ``"pixel_values"`` key and the same (mutated) batch is returned.
    """
    pixel_values = []
    for img in examples["image"]:
        pixel_values.append(_train_transforms(img.convert("RGB")))
    examples['pixel_values'] = pixel_values
    return examples

def val_transforms(examples):
    """Attach evaluation-pipeline tensors to a batch of examples.

    Every entry of ``examples["image"]`` is converted to RGB and passed
    through ``_val_transforms``; the results are stored under the
    ``"pixel_values"`` key and the same (mutated) batch is returned.
    """
    # Lazy generator: each image is converted right before it is transformed.
    rgb_images = (img.convert("RGB") for img in examples["image"])
    examples['pixel_values'] = [_val_transforms(rgb) for rgb in rgb_images]
    return examples


# Run the training pipeline over the train split, one batch at a time.
# NOTE(review): the transforms run inside `map`, so the random crop/flip
# appears to be applied once here rather than freshly per epoch - confirm
# this is intended.
processed_train = train.map(function=train_transforms, batched=True)
processed_train

# Run the evaluation pipeline over the test split, one batch at a time.
processed_test = test.map(function=val_transforms, batched=True)
processed_test

### Hyperparameters

In [None]:
# --- training schedule ---
num_train_epochs = 1
train_batch_size = 8
eval_batch_size = 8

# --- optimizer settings ---
learning_rate = 3e-5
weight_decay_rate = 0.01
num_warmup_steps = 0

# Output directory is the checkpoint name without its organization prefix.
output_dir = "google/vit-base-patch16-224-in21k".split("/", 1)[1]

# Mixed-precision switch; set to True if you have a GPU.
fp16 = True

# Train in mixed-precision float16 when the `fp16` flag is set.
# Set `fp16` to False (or comment the policy line out) if you're using a GPU
# that will not benefit from this.
if fp16:
    tf.keras.mixed_precision.set_global_policy("mixed_float16")


### Convert to TensorFlow dataset

In [None]:
# Collator that assembles already-preprocessed examples into TensorFlow
# tensors. DefaultDataCollator does not pad; every `pixel_values` entry is
# expected to already have the same shape after the transforms above.
data_collator = DefaultDataCollator(return_tensors="tf")

# Convert the training split to a tf.data.Dataset; shuffled for training.
tf_train_dataset = processed_train.to_tf_dataset(
   columns=["pixel_values"],
   label_cols=["label"],
   shuffle=True,
   batch_size=train_batch_size,
   collate_fn=data_collator)

# Convert the test split to a tf.data.Dataset.
# Evaluation does not need shuffling. Keeping it off also avoids dropping the
# last partial batch in `datasets` releases where `drop_remainder` defaults
# to the `shuffle` setting, so every test example is scored.
tf_eval_dataset = processed_test.to_tf_dataset(
   columns=["pixel_values"],
   label_cols=["label"],
   shuffle=False,
   batch_size=eval_batch_size,
   collate_fn=data_collator)

### Compile model

In [None]:
# Total number of optimizer steps: batches per epoch times number of epochs.
num_train_steps = num_train_epochs * len(tf_train_dataset)

# Create the optimizer with weight decay, together with its learning-rate schedule.
optimizer, lr_schedule = create_optimizer(
    init_lr=learning_rate,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
    weight_decay_rate=weight_decay_rate,
)

# Load the pre-trained ViT checkpoint with a classification head sized to
# the dataset's classes, and record the id <-> label mappings in its config.
model = TFViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    id2label=id2label,
    label2id=label2id,
    num_labels=len(class_labels),
)

# Loss: integer labels compared against raw (un-softmaxed) model logits.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Metrics: plain accuracy and top-3 accuracy.
metrics = [
    tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name="top-3-accuracy"),
]

# Wire the optimizer, loss and metrics into the model.
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

### Define callback

In [None]:
# TensorBoard callback writing to the `logs` directory, with histogram
# logging every epoch and profiling over the batch range '100,200'.
tboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='logs',
    histogram_freq=1,
    profile_batch='100,200',
)

### Train model

In [None]:
# Fine-tune on the training set, validating against the eval set and
# logging through the TensorBoard callback; the return value of `fit` is
# kept in `train_results`.
train_results = model.fit(
    tf_train_dataset,
    epochs=num_train_epochs,
    validation_data=tf_eval_dataset,
    callbacks=[tboard_callback],
)

### Launch TensorBoard

In [None]:
# IPython magics: load the TensorBoard notebook extension, then open
# TensorBoard on the `logs` directory.
%load_ext tensorboard
%tensorboard --logdir=logs