# tf.flowers Dataset with VIT Pretrained Model

In [2]:

import tensorflow_datasets as tfds
import tensorflow as tf

# 데이터셋 로드 및 준비
dataset_name = "tf_flowers"
(train_ds, val_ds), ds_info = tfds.load(
    dataset_name,
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
    with_info=True
)
num_classes = ds_info.features['label'].num_classes

# 데이터 전처리 함수 정의
def preprocess_image(image, label, img_size=(224, 224)):
    image = tf.image.resize(image, img_size)
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_ds = train_ds.map(lambda x, y: preprocess_image(x, y)).batch(32).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.map(lambda x, y: preprocess_image(x, y)).batch(32).prefetch(tf.data.AUTOTUNE)


## VIT Pretrained Model

In [4]:

from tensorflow.keras import layers, models
import tensorflow_hub as hub

# VIT 모델 로드
vit_url = "https://tfhub.dev/sayakpaul/vit_b16_fe/1"  # 올바른 URL로 수정
vit_layer = hub.KerasLayer(vit_url, trainable=False, name="vit_base")

# 모델 정의
def create_model():
    inputs = layers.Input(shape=(224, 224, 3))
    # Convert KerasTensor to a regular Tensor
    vit_features = vit_layer(inputs)
    x = layers.Dense(128, activation="relu")(vit_features)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return models.Model(inputs, outputs)

model = create_model()
model.summary()


ValueError: Exception encountered when calling layer 'vit_base' (type KerasLayer).

A KerasTensor is symbolic: it's a placeholder for a shape an a dtype. It doesn't have any actual numerical value. You cannot convert it to a NumPy array.

Call arguments received by layer 'vit_base' (type KerasLayer):
  • inputs=<KerasTensor shape=(None, 224, 224, 3), dtype=float32, sparse=False, name=keras_tensor_1>
  • training=None

## WandB Setup and Training

In [None]:

import wandb
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint

wandb.init(
    project="vit-tf-flowers",
    config={
        "epochs": 10,
        "batch_size": 32,
        "learning_rate": 0.001,
    }
)
config = wandb.config

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=config.learning_rate),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=config.epochs,
    callbacks=[
        WandbMetricsLogger(log_freq="epoch"),
        WandbModelCheckpoint("model_checkpoint.keras")
    ]
)


## Sweep Configuration and Execution

In [None]:

# Sweep 설정
sweep_config = {
    "method": "bayes",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"values": [0.001, 0.0001, 0.00001]},
        "dropout": {"values": [0.3, 0.5]},
        "batch_size": {"values": [16, 32]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="vit-tf-flowers")

# Sweep 실행 함수
def sweep_train(config_defaults=None):
    with wandb.init(config=config_defaults):
        config = wandb.config
        model = create_model()
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=config.learning_rate),
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"]
        )
        model.fit(
            train_ds,
            validation_data=val_ds,
            epochs=10,
            batch_size=config.batch_size,
            callbacks=[WandbMetricsLogger(log_freq="epoch")]
        )

wandb.agent(sweep_id, sweep_train, count=10)


## Artifacts for Model Management

In [None]:

# 모델 저장 및 Artifacts 등록
model.save("vit_tuned_model.h5")
artifact = wandb.Artifact("vit_tuned_model", type="model")
artifact.add_file("vit_tuned_model.h5")
wandb.log_artifact(artifact)
