<a href="https://colab.research.google.com/github/wint3rx3/flowers_classification/blob/main/SwinTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install tensorflow tensorflow-hub

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import matplotlib.pyplot as plt

# Hyperparameters.
NUM_CLASSES = 5        # number of target classes; change to match the dataset
EPOCHS = 5             # maximum number of training epochs (adjust as needed)
LEARNING_RATE = 1e-3   # initial Adam learning rate

# Swin-Tiny (patch size 4, window size 7, 224x224 input) from TF Hub.
SWIN_MODEL_URL = "https://tfhub.dev/sayakpaul/swin_tiny_patch4_window7_224/1"

# Width of the Hub model's output vector. It is declared once so the KerasLayer
# and the Lambda's static output_shape below cannot drift apart.
# NOTE(review): 1000 matches the ImageNet-1k class count, so this model presumably
# emits classification logits rather than pooled backbone features. A
# feature-extractor variant of the same backbone (if published) would likely be a
# better input for the Dense head; its output width would then differ.
# TODO confirm against the model card.
SWIN_OUTPUT_DIM = 1000

# NOTE(review): assumes the data generators already yield 224x224 RGB images
# scaled the way the Hub model expects. Confirm the preprocessing.
inputs = Input(shape=(224, 224, 3))

# Frozen backbone: trainable=False keeps the Swin weights fixed, so only the
# Dense layers below are trained.
swin_layer = hub.KerasLayer(
    SWIN_MODEL_URL,
    trainable=False,
    output_shape=[SWIN_OUTPUT_DIM],
)

# The Hub layer is invoked through a Lambda instead of being applied to `inputs`
# directly. This is presumably a workaround for hub.KerasLayer not accepting the
# installed Keras version's symbolic tensors.
# Caveat: Keras discourages using variables inside Lambda layers because they may
# not be tracked. The Swin weights are therefore likely absent from
# model_swin.weights (and from saved models). Verify before relying on save/load
# or before trying to fine-tune the backbone.
features = Lambda(
    lambda x: swin_layer(x),
    output_shape=lambda input_shape: (input_shape[0], SWIN_OUTPUT_DIM),
)(inputs)
x = Dense(256, activation='relu')(features)
outputs = Dense(NUM_CLASSES, activation='softmax')(x)

model_swin = Model(inputs=inputs, outputs=outputs)

# sparse_categorical_crossentropy expects integer class labels.
# NOTE(review): if the generators emit one-hot labels (e.g.
# class_mode='categorical'), the loss must be 'categorical_crossentropy' instead.
# Confirm against how the generators are built.
model_swin.compile(
    optimizer=Adam(learning_rate=LEARNING_RATE),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

model_swin.summary()

# Stop training once val_loss has not improved for 3 consecutive epochs, then
# roll the model back to the weights from its best epoch.
# NOTE(review): in older tf.keras releases, the best weights are restored only
# when this callback actually stops training early, not when fit() simply runs
# out of epochs. With a small epoch budget that can mean the last-epoch weights
# are kept; confirm the behaviour for the installed version.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Multiply the learning rate by 0.2 (cut it to 1/5) after 2 epochs without a
# val_loss improvement, never letting it fall below 1e-6.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=1e-6)

# Train the model. Both callbacks watch the val_loss computed on
# validation_generator. The generators are defined outside this view.
history_swin = model_swin.fit(
    train_generator,
    validation_data=validation_generator,
    callbacks=[early_stopping, reduce_lr],
    epochs=EPOCHS,
)

# Score the model's current weights on the validation data.
# NOTE(review): validation_generator also drives early stopping and LR
# reduction, so these numbers are optimistically biased. Report a held-out test
# split as well if one exists.
loss_swin, accuracy_swin = model_swin.evaluate(validation_generator)
print(f"Swin Validation Loss: {loss_swin:.4f}")
print(f"Swin Validation Accuracy: {accuracy_swin:.4f}")

def plot_history(history, title_prefix='Swin Transformer'):
    """Plot training vs. validation loss and accuracy side by side.

    Args:
        history: ``History`` object returned by ``Model.fit``. Its ``.history``
            dict must contain the ``loss``, ``val_loss``, ``accuracy`` and
            ``val_accuracy`` series.
        title_prefix: model name shown at the start of each subplot title.
    """
    curves = history.history
    # Keras numbers epochs from 1, so plot against 1..N rather than the
    # implicit 0-based list index, which mislabels every point on the
    # "Epoch" axis.
    epochs = range(1, len(curves['loss']) + 1)

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

    ax_loss.plot(epochs, curves['loss'], label='Train Loss')
    ax_loss.plot(epochs, curves['val_loss'], label='Val Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title(f'{title_prefix} Loss over Epochs')
    ax_loss.legend()

    ax_acc.plot(epochs, curves['accuracy'], label='Train Accuracy')
    ax_acc.plot(epochs, curves['val_accuracy'], label='Val Accuracy')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.set_title(f'{title_prefix} Accuracy over Epochs')
    ax_acc.legend()

    fig.tight_layout()
    plt.show()


plot_history(history_swin)
