# In [None]:
# train.py
import os
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.applications import EfficientNetB0, EfficientNetB0, ResNet50
from tensorflow.keras.preprocessing import image_dataset_from_directory
import numpy as np
import pathlib
import configs
from utils import plot_history

# Seed TensorFlow's global random number generator once, at import time,
# from the project-wide configs.SEED value.
tf.random.set_seed(seed=configs.SEED)

def get_base_model(name="efficientnet", input_shape=(224,224,3)):
    """Return a headless ImageNet-pretrained backbone and its preprocessing function.

    Args:
        name: Backbone identifier: ``"efficientnet"`` (EfficientNetB0) or
            ``"resnet50"`` (ResNet50).
        input_shape: ``(height, width, channels)`` the backbone is built for.

    Returns:
        ``(base, preprocess_input)`` where ``base`` is the Keras application
        model built with ``weights="imagenet"`` and ``include_top=False``, and
        ``preprocess_input`` is the matching preprocessing function from
        ``tf.keras.applications``.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    if name == "efficientnet":
        base = tf.keras.applications.EfficientNetB0(weights="imagenet", include_top=False, input_shape=input_shape)
        preprocess_input = tf.keras.applications.efficientnet.preprocess_input
    elif name == "resnet50":
        base = tf.keras.applications.ResNet50(weights="imagenet", include_top=False, input_shape=input_shape)
        preprocess_input = tf.keras.applications.resnet.preprocess_input
    else:
        # Report the rejected value and the valid options so a typo in the
        # config is immediately visible.
        raise ValueError(
            f"Unknown base model {name!r}; expected one of: 'efficientnet', 'resnet50'"
        )
    return base, preprocess_input

def make_datasets(data_dir, img_size=(224,224), batch_size=32, seed=42):
    """Load the train/val image folders as batched, prefetched ``tf.data`` pipelines.

    Reads ``<data_dir>/train`` and ``<data_dir>/val`` with
    ``image_dataset_from_directory``.

    Args:
        data_dir: Root folder containing the ``train`` and ``val`` splits.
        img_size: ``(height, width)`` every image is resized to.
        batch_size: Images per batch.
        seed: Seed forwarded to ``image_dataset_from_directory`` for both splits.

    Returns:
        ``(train_ds, val_ds, class_names)``. Both datasets yield
        ``(images, labels)`` with images cast to ``float32``. Only ``train_ds``
        is shuffled and passed through random flip/rotation/zoom augmentation.
        ``class_names`` is taken from the training split.
    """
    loader_kwargs = dict(batch_size=batch_size, image_size=img_size, seed=seed)
    raw_train = image_dataset_from_directory(os.path.join(data_dir, "train"), shuffle=True, **loader_kwargs)
    raw_val = image_dataset_from_directory(os.path.join(data_dir, "val"), shuffle=False, **loader_kwargs)
    labels = raw_train.class_names

    tune = tf.data.AUTOTUNE
    augmenter = tf.keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ])

    def _to_float(images, targets):
        return tf.cast(images, tf.float32), targets

    def _augment(images, targets):
        # training=True makes the random layers actually transform the batch.
        return augmenter(images, training=True), targets

    def _pipeline(ds, with_augmentation):
        ds = ds.map(_to_float, num_parallel_calls=tune)
        if with_augmentation:
            ds = ds.map(_augment, num_parallel_calls=tune)
        return ds.prefetch(tune)

    return _pipeline(raw_train, True), _pipeline(raw_val, False), labels

def build_model(num_classes, img_size=(224,224), base_name="efficientnet", dropout_rate=0.3):
    """Build an image classifier on top of a frozen pretrained backbone.

    Graph: ``Input -> preprocess_input -> base (training=False) ->
    GlobalAveragePooling2D -> Dropout -> Dense(num_classes, softmax)``.

    Args:
        num_classes: Number of output classes.
        img_size: ``(height, width)`` of the input images. Any 2-element
            sequence is accepted, including a list loaded from a config file.
            Inputs are built with 3 channels.
        base_name: Backbone identifier forwarded to ``get_base_model``.
        dropout_rate: Dropout rate applied to the pooled features.

    Returns:
        ``(model, base_model)``: the full Keras model and the backbone. The
        backbone is returned separately so the caller can unfreeze its layers
        for fine-tuning.
    """
    # tuple() so a list-valued img_size doesn't raise TypeError on
    # list + tuple concatenation.
    input_shape = tuple(img_size) + (3,)
    base_model, preprocess_input = get_base_model(base_name, input_shape=input_shape)
    # Freeze the backbone so the first training phase only updates the new head.
    base_model.trainable = False
    inputs = layers.Input(shape=input_shape)
    x = preprocess_input(inputs)
    # training=False keeps the backbone (including any BatchNormalization
    # layers) in inference mode, even after it is later unfrozen.
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = models.Model(inputs, outputs)
    return model, base_model

def compile_model(model, lr):
    """(Re)compile ``model`` with Adam at learning rate ``lr``.

    Uses sparse categorical cross-entropy, which expects integer class labels,
    and reports accuracy. Compiles ``model`` in place and returns ``None``.
    """
    adam = optimizers.Adam(learning_rate=lr)
    model.compile(optimizer=adam, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

def _build_callbacks(model_dir):
    """Create the callbacks shared by the head-training and fine-tuning phases.

    - ModelCheckpoint writes ``<model_dir>/best_model.h5`` whenever
      ``val_accuracy`` reaches a new maximum.
    - ReduceLROnPlateau halves the learning rate after 3 epochs without
      ``val_loss`` improvement.
    - EarlyStopping stops after 6 such epochs and restores the best weights.
    """
    ckpt_path = os.path.join(model_dir, "best_model.h5")
    return [
        callbacks.ModelCheckpoint(ckpt_path, monitor="val_accuracy", save_best_only=True, mode="max"),
        callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, verbose=1),
        callbacks.EarlyStopping(monitor="val_loss", patience=6, restore_best_weights=True),
    ]


def _unfreeze_from(base_model, fine_tune_at):
    """Make ``base_model`` trainable from layer index ``fine_tune_at`` onward.

    Layers with index below ``fine_tune_at`` stay frozen; all later layers
    become trainable.
    """
    base_model.trainable = True
    for index, layer in enumerate(base_model.layers):
        layer.trainable = index >= fine_tune_at


def main():
    """Train a transfer-learning classifier in two phases and save it.

    All settings come from ``configs``.
    - Phase 1 trains only the new head on a frozen backbone at
      ``LEARNING_RATE_HEAD`` for ``INITIAL_EPOCHS``.
    - Phase 2 unfreezes the backbone from layer ``FINE_TUNE_AT`` onward and
      continues at ``LEARNING_RATE_FINE`` up to epoch
      ``INITIAL_EPOCHS + FINE_TUNE_EPOCHS``.

    The best checkpoint, the final model and a training-history plot are
    written into ``MODEL_DIR``.
    """
    data_dir = configs.DATA_DIR
    batch_size = configs.BATCH_SIZE
    img_size = configs.IMG_SIZE
    model_dir = configs.MODEL_DIR
    os.makedirs(model_dir, exist_ok=True)

    train_ds, val_ds, class_names = make_datasets(data_dir, img_size=img_size, batch_size=batch_size, seed=configs.SEED)
    num_classes = len(class_names)
    model, base_model = build_model(num_classes, img_size=img_size, base_name=configs.BASE_MODEL)
    compile_model(model, configs.LEARNING_RATE_HEAD)

    # One callback list is reused for both phases.
    cb = _build_callbacks(model_dir)

    # Phase 1: train the classification head only.
    history1 = model.fit(train_ds, validation_data=val_ds, epochs=configs.INITIAL_EPOCHS, callbacks=cb)

    # Phase 2: fine-tune the upper backbone layers. Recompile so the new
    # trainable flags and the lower learning rate take effect.
    _unfreeze_from(base_model, configs.FINE_TUNE_AT)
    compile_model(model, configs.LEARNING_RATE_FINE)
    # Resume the epoch counter after the last head epoch, because EarlyStopping
    # may have ended phase 1 early. history1.epoch can be empty when phase 1
    # ran no epochs (INITIAL_EPOCHS <= 0); start from 0 in that case.
    initial_epoch = history1.epoch[-1] + 1 if history1.epoch else 0
    history2 = model.fit(train_ds, validation_data=val_ds, epochs=configs.INITIAL_EPOCHS + configs.FINE_TUNE_EPOCHS,
                         initial_epoch=initial_epoch, callbacks=cb)

    # Save the final model.
    # NOTE(review): tf.keras in TF 2.x writes an extension-less path as a
    # SavedModel directory; Keras 3 requires a ".keras"/".h5" suffix here.
    # Confirm the installed version.
    saved = os.path.join(model_dir, "saved_model")
    model.save(saved)

    # Plotting is best-effort: a failure is reported but does not abort the run.
    # NOTE(review): only the fine-tuning history (history2) is plotted; the
    # head-training history in history1 is not included.
    try:
        plot_history(history2, save_path=os.path.join(model_dir, "training_history.png"))
    except Exception as e:
        print("Could not save plot:", e)
    print("Class names:", class_names)
    print("Saved model to:", saved)

# Run training only when executed as a script, not when imported.
# (Requires the double-underscore dunders: `_name_` is undefined and would
# raise NameError, and "_main_" never matches.)
if __name__ == "__main__":
    main()