<a href="https://colab.research.google.com/github/s34836/WUM/blob/main/Lab_11_Transfer_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transfer Learning
## Example

In [None]:
# Unpack the Imagenette archive from the working directory; the Tasks cells
# below read images from imagenette2/train and imagenette2/val.
!tar -xzf imagenette2.tgz

In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf

# Input resolution fed to MobileNetV2 and the mini-batch size used below.
image_size = (224, 224)
batch_size = 32

# Carve TFDS's "train" split 80/20 into training and validation parts.
# as_supervised=True makes each element an (image, label) tuple;
# with_info=True additionally returns the dataset metadata.
(train_ds, val_ds), ds_info = tfds.load(
    "cats_vs_dogs",
    with_info=True,
    as_supervised=True,
    split=["train[:80%]", "train[80%:]"],
)

In [2]:
# MobileNetV2's own preprocessing function (the input scaling its pretrained
# ImageNet weights expect).
preprocess = tf.keras.applications.mobilenet_v2.preprocess_input

def prepare(image, label):
    """Resize ``image`` to the global ``image_size`` and apply ``preprocess``.

    Intended for ``tf.data.Dataset.map``; ``label`` is passed through untouched.
    """
    resized = tf.image.resize(image, image_size)
    return preprocess(resized), label

# Training pipeline: preprocess in parallel, shuffle with a 1000-example buffer,
# then batch and prefetch so input preparation overlaps with training.
train = train_ds.map(prepare, num_parallel_calls=tf.data.AUTOTUNE)
train = train.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Validation pipeline: identical preprocessing, but no shuffling.
valid = val_ds.map(prepare, num_parallel_calls=tf.data.AUTOTUNE)
valid = valid.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
import tensorflow as tf

# Input resolution and mini-batch size for the directory-based pipeline.
image_size = (224, 224)
batch_size = 32

# One generator serves both subsets: it applies MobileNetV2 preprocessing and
# holds out 20 % of the images in the directory for validation.
data_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    validation_split=0.2,
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
)

# Binary labels come from the class sub-directories. The "training" subset is
# built first, then "validation".
# NOTE(review): the Keras "Image classification from scratch" example removes
# corrupted JPEGs from this dataset before training — confirm these files are clean.
train, valid = (
    data_generator.flow_from_directory(
        "kagglecatsanddogs_5340/PetImages",
        target_size=image_size,
        batch_size=batch_size,
        class_mode='binary',
        subset=subset_name,
    )
    for subset_name in ('training', 'validation')
)


In [None]:
# ImageNet-pretrained MobileNetV2 used as a feature extractor: the classification
# top is dropped and all of its weights are frozen, so only layers stacked on
# top of it get trained.
image_shape = (224, 224, 3)
base_model = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=image_shape,
)
base_model.trainable = False

In [None]:
# Frozen MobileNetV2 features -> global average pooling -> small dense head
# ending in one sigmoid unit for the binary cats-vs-dogs label.
head_layers = [
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
]
model = tf.keras.Sequential(
    [tf.keras.layers.InputLayer(shape=image_shape), base_model, *head_layers]
)

# Binary cross-entropy matches the single sigmoid output.
model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3),
    metrics=['accuracy'],
)

In [None]:
# show_trainable=True adds a per-layer "Trainable" column, which makes it easy
# to confirm that the MobileNetV2 backbone is frozen and only the head trains.
model.summary(show_trainable=True)

In [None]:
# Train the head for 3 epochs. `train`/`valid` already yield batches (their batch
# size was fixed when they were built), so no `batch_size` is passed here: Keras
# documents that it must not be specified for dataset/generator inputs.
# Keeping the History object allows plotting the learning curves later.
history = model.fit(train, validation_data=valid, epochs=3)

## Tasks
1. Use one of the pretrained models available in TensorFlow to classify images in the `imagenette2` dataset. See the list of available models [here](https://keras.io/api/applications/).
2. (optional) Fine-tune the pretrained model. Unfreeze the last few convolutional layers of the model trained in Task 1 by setting `trainable=True`. Then recompile the model and train it for a few more epochs with a low learning rate.

In [None]:
import tensorflow as tf

# --- Settings: input resolution, batch size and Imagenette directory layout
image_size = (224, 224)
batch_size = 32
train_dir = "imagenette2/train"
val_dir = "imagenette2/val"

# --- Generators (preprocessing FOR MobileNetV2): two independent instances,
# both applying MobileNetV2's own preprocess_input.
mobilenet_preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
train_gen, val_gen = (
    tf.keras.preprocessing.image.ImageDataGenerator(
        preprocessing_function=mobilenet_preprocess
    )
    for _ in range(2)
)

# One-hot ("categorical") labels; Imagenette has 10 classes and the class count
# is taken from the sub-directories. Only the training stream is shuffled.
train_ds = train_gen.flow_from_directory(
    train_dir,
    shuffle=True,
    class_mode="categorical",
    target_size=image_size,
    batch_size=batch_size,
)
val_ds = val_gen.flow_from_directory(
    val_dir,
    shuffle=False,
    class_mode="categorical",
    target_size=image_size,
    batch_size=batch_size,
)

# --- Pretrained backbone: ImageNet MobileNetV2 without its classifier top,
# frozen for phase 1 so that only the new head learns.
base = tf.keras.applications.MobileNetV2(
    weights="imagenet",
    include_top=False,
    input_shape=image_size + (3,),
)
base.trainable = False

# --- Model: backbone called with training=False (inference-mode BatchNorm, as
# recommended by the Keras transfer-learning guide), then pooling, dropout and
# a softmax layer with one unit per class.
inputs = tf.keras.Input(shape=image_size + (3,))
x = base(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(train_ds.num_classes, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

# --- Phase 1: train the head only (Adam, lr = 1e-3) for 5 epochs
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(1e-3),
    metrics=["accuracy"],
)
history = model.fit(train_ds, epochs=5, validation_data=val_ds)

# --- Phase 2 (optional fine-tuning): make the backbone trainable, re-freeze
# all but its last 30 layers, and recompile with a much smaller lr.
base.trainable = True
for layer in base.layers[:-30]:
    layer.trainable = False

model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(1e-5),
    metrics=["accuracy"],
)
history_ft = model.fit(train_ds, epochs=5, validation_data=val_ds)

In [None]:
import tensorflow as tf

# --- Settings: input resolution, batch size and Imagenette directory layout
image_size = (224, 224)
batch_size = 32
train_dir = "imagenette2/train"
val_dir = "imagenette2/val"

# --- Generators (preprocessing FOR EfficientNetV2): two independent instances.
# Per the Keras docs, efficientnet_v2.preprocess_input is a pass-through because
# the EfficientNetV2 models include their own preprocessing; it is applied here
# to mirror the MobileNetV2 setup.
efficientnet_preprocess = tf.keras.applications.efficientnet_v2.preprocess_input
train_gen, val_gen = (
    tf.keras.preprocessing.image.ImageDataGenerator(
        preprocessing_function=efficientnet_preprocess
    )
    for _ in range(2)
)

# One-hot ("categorical") labels; Imagenette has 10 classes and the class count
# is taken from the sub-directories. Only the training stream is shuffled.
train_ds = train_gen.flow_from_directory(
    train_dir,
    shuffle=True,
    class_mode="categorical",
    target_size=image_size,
    batch_size=batch_size,
)
val_ds = val_gen.flow_from_directory(
    val_dir,
    shuffle=False,
    class_mode="categorical",
    target_size=image_size,
    batch_size=batch_size,
)

# --- Pretrained backbone: ImageNet EfficientNetV2B0 without its classifier top,
# frozen for phase 1 so that only the new head learns.
base = tf.keras.applications.EfficientNetV2B0(
    weights="imagenet",
    include_top=False,
    input_shape=image_size + (3,),
)
base.trainable = False

# --- Model: backbone called with training=False (inference-mode BatchNorm, as
# recommended by the Keras transfer-learning guide), then pooling, dropout and
# a softmax layer with one unit per class.
inputs = tf.keras.Input(shape=image_size + (3,))
x = base(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(train_ds.num_classes, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

# --- Phase 1: train the head only (Adam, lr = 1e-3) for 5 epochs
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(1e-3),
    metrics=["accuracy"],
)
history = model.fit(train_ds, epochs=5, validation_data=val_ds)

# --- Phase 2 (optional fine-tuning): make the backbone trainable, re-freeze
# all but its last 30 layers, and recompile with a much smaller lr.
base.trainable = True
for layer in base.layers[:-30]:
    layer.trainable = False

model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(1e-5),
    metrics=["accuracy"],
)
history_ft = model.fit(train_ds, epochs=5, validation_data=val_ds)

In [None]:
import tensorflow as tf

# 1) Unfreeze the last few CONVOLUTIONAL layers (Conv2D / DepthwiseConv2D);
#    every other layer of the backbone stays frozen.
def unfreeze_last_conv_layers(base_model, n_conv_layers_to_unfreeze=20, conv_types=None):
    """Make only the last ``n_conv_layers_to_unfreeze`` conv layers of ``base_model`` trainable.

    ``base_model.trainable`` is switched on, then every layer in
    ``base_model.layers`` is set explicitly. The selected conv layers become
    trainable; all other layers (BatchNormalization, activations, earlier convs,
    ...) are frozen.

    Args:
        base_model: model whose ``layers`` list is scanned in order.
        n_conv_layers_to_unfreeze: how many conv layers, counted from the end,
            to unfreeze. A value larger than the number of conv layers unfreezes
            all of them; 0 or a negative value unfreezes none.
        conv_types: tuple of layer classes treated as "convolutional". Defaults
            to ``(tf.keras.layers.Conv2D, tf.keras.layers.DepthwiseConv2D)``.

    Prints how many of the backbone's layers ended up trainable. Returns None.
    """
    if conv_types is None:
        conv_types = (tf.keras.layers.Conv2D, tf.keras.layers.DepthwiseConv2D)
    conv_layers = [l for l in base_model.layers if isinstance(l, conv_types)]

    # Clamp to [0, len(conv_layers)]. Slicing from `len - n` (instead of `-n`)
    # matters for n == 0: conv_layers[-0:] would select the WHOLE list.
    n = max(0, min(n_conv_layers_to_unfreeze, len(conv_layers)))
    to_unfreeze = {id(l) for l in conv_layers[len(conv_layers) - n:]}

    base_model.trainable = True
    for layer in base_model.layers:
        # Only the selected conv layers stay trainable.
        layer.trainable = id(layer) in to_unfreeze

    # Report how many backbone layers are now trainable.
    trainable_cnt = sum(l.trainable for l in base_model.layers)
    print(f"Base model layers trainable: {trainable_cnt} / {len(base_model.layers)}")

# `base` is the EfficientNetV2B0 backbone built in the earlier cell; unfreeze
# only its last 20 conv layers.
unfreeze_last_conv_layers(base, n_conv_layers_to_unfreeze=20)

# 2) Recompile so the new `trainable` flags take effect, with a low learning
#    rate so the pretrained weights are only nudged (the key step).
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    metrics=["accuracy"],
)

# 3) Train a few more epochs with two callbacks:
#    - EarlyStopping on val_accuracy (patience 2, restore_best_weights=True);
#    - ReduceLROnPlateau: multiply the lr by 0.2 when val_loss stalls for one
#      epoch, never going below 1e-7.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy", patience=2, restore_best_weights=True
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.2, patience=1, min_lr=1e-7
)
callbacks = [early_stopping, reduce_lr]

history_ft = model.fit(
    train_ds,
    epochs=5,
    callbacks=callbacks,
    validation_data=val_ds,
)