# CIFAR-10 Transfer Learning 팀 과제 (실습 기반 버전)

In [None]:
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

# Fetch the CIFAR-10 train/test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Cast the pixel arrays to float32 and divide them by 255.
x_train, x_test = (
    images.astype('float32') / 255.
    for images in (x_train, x_test)
)

# One-hot encode the labels over the 10 classes.
y_train, y_test = (
    to_categorical(labels, 10)
    for labels in (y_train, y_test)
)

# Per-sample shape (everything after the leading axis); read by build_model.
input_shape = x_train.shape[1:]


In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Random augmentation pipeline for the training images.
# The images are divided by 255 before this point, so every value-based
# setting has to be expressed on that reduced scale rather than on 0-255.
datagen = ImageDataGenerator(horizontal_flip=True,
                             vertical_flip=True,
                             rotation_range=40,
                             zoom_range=0.2,
                             # NOTE(review): brightness adjustment in Keras goes
                             # through a PIL image conversion; verify the value
                             # range it produces for float inputs that were
                             # already divided by 255 - TODO confirm.
                             brightness_range=[0.5, 1.5],
                             fill_mode='nearest',
                             # A shift of 50 is a 0-255 magnitude; divide by 255
                             # so it matches the scaled pixel values instead of
                             # saturating every channel.
                             channel_shift_range=50.0 / 255.)
# No datagen.fit(...) call: fitting only computes the statistics used by
# featurewise_center, featurewise_std_normalization and zca_whitening, and
# none of those options is enabled above.


In [None]:
from tensorflow.keras.applications import VGG16, MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, BatchNormalization, Activation
from tensorflow.keras.initializers import HeNormal
from tensorflow.keras.optimizers import Adam, SGD

def build_model(base_model_type='vgg16', optimizer='adam', learning_rate=0.001, activation='relu'):
    """Build and compile a 10-class classifier on top of a pretrained backbone.

    Args:
        base_model_type: 'vgg16' or 'mobilenet'; selects the convolutional
            base, loaded with ImageNet weights and without its top layers.
        optimizer: 'adam' or 'sgd'.
        learning_rate: learning rate handed to the selected optimizer.
        activation: activation applied after batch normalization in the
            128-unit hidden layer.

    Returns:
        A compiled Keras ``Model`` ending in a 10-unit softmax layer, using
        categorical cross-entropy loss and an accuracy metric.

    Raises:
        ValueError: if ``base_model_type`` or ``optimizer`` is not one of the
            supported names.
    """
    # The input size comes from the module-level ``input_shape``.
    image_input = Input(shape=input_shape)

    if base_model_type not in ('vgg16', 'mobilenet'):
        raise ValueError("지원하지 않는 모델임.")
    backbone_cls = VGG16 if base_model_type == 'vgg16' else MobileNetV2
    backbone = backbone_cls(include_top=False, weights='imagenet', input_tensor=image_input)
    # NOTE(review): nothing here sets ``trainable = False``, so the backbone is
    # presumably fine-tuned together with the head - confirm that is intended.

    # Classification head: pooled features -> Dense(128) -> BN -> activation -> softmax.
    features = GlobalAveragePooling2D()(backbone.output)
    hidden = Dense(128, kernel_initializer=HeNormal())(features)
    hidden = BatchNormalization()(hidden)
    hidden = Activation(activation)(hidden)
    class_probs = Dense(10, activation='softmax')(hidden)
    model = Model(inputs=backbone.input, outputs=class_probs)

    # The optimizer name is validated only after the network has been assembled.
    if optimizer not in ('adam', 'sgd'):
        raise ValueError("지원하지 않는 optimizer임.")
    optimizer_cls = Adam if optimizer == 'adam' else SGD

    model.compile(
        optimizer=optimizer_cls(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler

def lr_schedule(epoch, lr):
    """Return the learning rate to use for ``epoch`` given the current ``lr``.

    ``lr`` is returned unchanged unless ``epoch`` is greater than 10, in which
    case half of ``lr`` is returned.
    """
    return lr * 0.5 if epoch > 10 else lr

# Training callbacks, collected in the order they are registered.
callbacks = []
# Checkpoint to best_model.h5, keeping only the best model by val_accuracy.
callbacks.append(
    ModelCheckpoint("best_model.h5", monitor='val_accuracy', save_best_only=True, verbose=1)
)
# Let lr_schedule decide the learning rate for each epoch.
callbacks.append(LearningRateScheduler(lr_schedule))


In [None]:
!pip install scikeras

from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import RandomizedSearchCV

# Wrap the Keras builder in a scikit-learn style estimator.
wrapped_model = KerasClassifier(
    model=build_model,
    batch_size=64,
    epochs=5,
    verbose=1,
)

# Candidate values; keys are build_model's parameter names with a
# ``model__`` prefix.
param_dist = dict(
    model__optimizer=['adam', 'sgd'],
    model__learning_rate=[0.001, 0.0001],
    model__activation=['relu', 'tanh'],
    model__base_model_type=['vgg16', 'mobilenet'],
)

# Randomized search: 4 sampled configurations, each scored with cv=3.
search = RandomizedSearchCV(
    estimator=wrapped_model,
    param_distributions=param_dist,
    cv=3,
    n_iter=4,
    verbose=2,
)
search.fit(x_train, y_train)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import numpy as np

# Plot the mean cross-validation score of each sampled configuration.
results = pd.DataFrame(search.cv_results_)
plt.plot(results['mean_test_score'], marker='o')
plt.xlabel('Trial')
plt.ylabel('Mean Accuracy')
plt.title('Cross Validation Accuracy')
plt.grid(True)
plt.show()

# Score the underlying Keras model of the best estimator on the test split.
best_model = search.best_estimator_.model_
y_pred = best_model.predict(x_test)

# Reduce predictions and one-hot targets to class indices.
y_pred_labels = y_pred.argmax(axis=1)
y_true_labels = y_test.argmax(axis=1)

test_acc = accuracy_score(y_true_labels, y_pred_labels)
print(f"✅ Test Accuracy: {test_acc:.4f}")

# Confusion matrix with one row/column label per class name.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
cm = confusion_matrix(y_true_labels, y_pred_labels)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap='Blues', xticks_rotation='vertical')
plt.title("CIFAR-10 Confusion Matrix")
plt.show()
