# 1. Import Modules
## (Korean) 기본 라이브러리를 불러옵니다.
## (English) Import essential libraries.

In [2]:
! pip install -q tensorflow tensorflow_hub tensorflow_datasets

: 

In [None]:
# From TensorFlow 2.16 on, tf.keras is Keras 3, which cannot use hub.KerasLayer inside a
# tf.keras model. Opting into legacy Keras 2 (the tf_keras package) must happen before
# TensorFlow is first imported. setdefault keeps any value the user already exported.
import os

os.environ.setdefault("TF_USE_LEGACY_KERAS", "1")

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
# tf.test.is_gpu_available() is deprecated; report whether any GPU device is visible.
print(bool(tf.config.list_physical_devices("GPU")))

# 2. Load and Split TF Flowers Dataset
## (Korean) TF Flowers 데이터를 train, val, test로 분할합니다.
## (English) Split the TF Flowers dataset into train, val, and test.

In [None]:
dataset_name = "tf_flowers"
# All three subsets are carved out of the "train" split with TFDS percent slicing:
# 70% train / 15% validation / 15% test.
train_split, val_split, test_split = (
    "train[:70%]",
    "train[70%:85%]",
    "train[85%:]",
)

(raw_train, raw_val, raw_test), info = tfds.load(
    dataset_name,
    split=[train_split, val_split, test_split],
    with_info=True,
    as_supervised=True,  # yield (image, label) tuples rather than feature dicts
)

num_classes = info.features['label'].num_classes
print("Number of classes:", num_classes)
for _caption, _subset in (
    ("Training samples:", raw_train),
    ("Validation samples:", raw_val),
    ("Test samples:", raw_test),
):
    print(_caption, len(_subset))

# 3. Data Preprocessing
## (Korean) 이미지 크기를 (224, 224)로 리사이즈하고, 0~1 범위로 스케일링합니다.
## (English) Resize images to (224, 224) and scale pixel values to [0, 1].

In [None]:
IMG_SIZE = 224
BATCH_SIZE = 32

def preprocess_image(image, label):
    """Resize ``image`` to IMG_SIZE x IMG_SIZE and divide its pixels by 255.

    ``label`` is passed through untouched. For uint8 input in [0, 255] the result
    lies in [0, 1].

    NOTE(review): confirm the TF Hub ViT below expects [0, 1] inputs; many ViT
    checkpoints are trained on [-1, 1]-normalized images.
    """
    resized = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    return resized / 255.0, label

# Shuffle *before* map so the 1000-element shuffle buffer holds the raw uint8 examples
# instead of resized float32 copies (1000 x 224 x 224 x 3 x 4 bytes is ~600 MB).
train_ds = (raw_train
    .shuffle(1000)
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)


def _eval_pipeline(ds):
    """Preprocess, batch and prefetch ``ds`` in its stored order (no shuffling)."""
    return (ds
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )


val_ds = _eval_pipeline(raw_val)
test_ds = _eval_pipeline(raw_test)

# 4. Load a Pretrained ViT Model via TensorFlow Hub
## (Korean) TensorFlow Hub에서 사전학습된 ViT 모델을 가져옵니다.
## (English) Fetch a pretrained ViT model from TensorFlow Hub.

In [None]:
# NOTE(review): tfhub.dev handles now resolve through Kaggle Models; confirm this
# handle still exists and what it outputs (feature vector vs. ImageNet logits).
vit_url = "https://tfhub.dev/google/vit/base_patch16_224/1"

# trainable=True: the whole ViT backbone is fine-tuned together with the head.
vit_layer = hub.KerasLayer(
    vit_url,
    name="vit_layer",
    trainable=True,
)
print("ViT layer loaded.")

# 5. Build the Keras Model
## (Korean) ViT 임베딩 뒤에 Dense 레이어를 연결해 분류용 Keras 모델을 구성합니다.
## (English) Build a Keras classifier by adding a Dense classification layer on top of the ViT layer's output.

In [None]:
inputs = tf.keras.Input(name="input_images", shape=(IMG_SIZE, IMG_SIZE, 3))
# Output shape depends on the hub handle. Presumably a pooled 2-D (batch, features)
# tensor; TODO confirm.
x = vit_layer(inputs)

# Classification head: one softmax unit per flower class.
outputs = tf.keras.layers.Dense(
    units=num_classes,
    activation='softmax',
    name="output_dense",
)(x)

model = tf.keras.Model(inputs, outputs, name="ViT_Flowers")
# Integer labels from the dataset -> sparse categorical cross-entropy on probabilities.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["accuracy"],
)
model.summary()

# 6. Train the Model
## (Korean) 모델을 일정 에폭 동안 학습합니다.
## (English) Train the model for a certain number of epochs.

In [None]:
EPOCHS = 3  # kept short for demo purposes

# Per-epoch loss/accuracy for train and validation end up in history.history.
history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)

# 7. Evaluate on Test Set
## (Korean) 테스트 세트를 통해 모델 성능을 확인합니다.
## (English) Check model performance on the test set.

In [None]:
# With metrics=["accuracy"], evaluate() returns [loss, accuracy].
test_loss, test_acc = model.evaluate(test_ds)
for _caption, _value in (("Test Loss:", test_loss), ("Test Accuracy:", test_acc)):
    print(_caption, _value)

# 8. Sample Prediction
## (Korean) 임의의 배치에 대해 예측을 수행해봅니다.
## (English) Run predictions on the first batch of the test set.

In [None]:
import itertools

# Grab the first batch of the (unshuffled) test pipeline. images, labels and pred_labels
# stay in scope for the visualization cell that follows.
images, labels = next(iter(test_ds.take(1)))
predictions = model.predict(images)
# Index of the highest-probability class for each image.
pred_labels = tf.argmax(predictions, axis=1)
print("Predicted:", pred_labels.numpy())
print("Actual:   ", labels.numpy())

# 9. Visualization
## (Korean) 예측 결과를 시각화하여 확인합니다.
## (English) Visualize the prediction results.

In [None]:
# Take the names from the dataset metadata. The integer labels index into TFDS's own
# ClassLabel order, which is not necessarily alphabetical, so a hard-coded list can
# silently mislabel the plots.
class_names = info.features['label'].names

# Reuse images, labels and pred_labels from the prediction cell above.
num_shown = min(6, len(images))  # a short batch would otherwise raise IndexError
plt.figure(figsize=(12, 8))
for i in range(num_shown):
    plt.subplot(2, 3, i + 1)
    plt.imshow(images[i])
    title = f"Pred: {class_names[int(pred_labels[i])]}\nActual: {class_names[int(labels[i])]}"
    plt.title(title)
    plt.axis("off")
plt.tight_layout()
plt.show()

# (Korean) 이렇게 하면 ViT로 TF Flowers를 학습하고 평가해볼 수 있습니다.
# (English) This completes training and evaluating a ViT model on the TF Flowers dataset.