In [None]:
import os

# TensorFlow reads these variables at start-up, so they are set here, before
# the next cell imports tensorflow.
_TF_ENV_OVERRIDES = {
    # Allocate GPU memory on demand instead of reserving the whole card.
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    # Meant to disable TensorFlow's layout optimizer.
    # NOTE(review): these two names do not appear among TensorFlow's documented
    # environment variables and may be ignored; the documented switch is
    # tf.config.optimizer.set_experimental_options({"layout_optimizer": False}).
    "TF_DISABLE_OPTIMIZER_IN_LAYOUT": "1",
    "TF_DISABLE_LAYOUT_OPTIMIZER": "1",
}
os.environ.update(_TF_ENV_OVERRIDES)

In [None]:
# ============================================================
# FAST + MEMORY EFFICIENT RETINA vs NON-RETINA CLASSIFIER
# EfficientNetB3 + Mixed Precision + tf.data (no CLAHE)
# Output: retina_model_best.h5 (for your app.py)
# ============================================================

import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import mixed_precision
import matplotlib.pyplot as plt
from google.colab import drive

# ============================================================
# MIXED PRECISION (Fast + Accurate)
# ============================================================
# "mixed_float16": layers compute in float16 while their variables stay in
# float32. Layers created with dtype="float32" (the classifier head) opt out.
mixed_precision.set_global_policy(mixed_precision.Policy("mixed_float16"))
policy = mixed_precision.global_policy()
print("Mixed Precision:", policy)

# ============================================================
# MOUNT DRIVE (Your dataset is in Drive)
# ============================================================
import os
import shutil
from zipfile import ZipFile

drive.mount('/content/drive')

ZIP_PATH = "/content/drive/MyDrive/retina_nonretina_dataset_balanced2.zip"
EXTRACT_PATH = "/content/retina_nonretina_dataset_balanced2"

train_dir = f"{EXTRACT_PATH}/train"
val_dir   = f"{EXTRACT_PATH}/val"

# Extracting the archive is slow, so skip it on re-runs (Restart & Run All).
# The marker is written only after extractall() returns, so an interrupted
# extraction is simply redone on the next run.
extract_marker = os.path.join(EXTRACT_PATH, ".extract_complete")
if os.path.exists(extract_marker):
    print("Dataset already extracted, skipping.")
else:
    print("Extracting dataset...")
    with ZipFile(ZIP_PATH, 'r') as archive:
        archive.extractall(EXTRACT_PATH)
    open(extract_marker, "w").close()

# Fail fast with a readable message if the archive does not unpack to
# <EXTRACT_PATH>/train and <EXTRACT_PATH>/val (e.g. it contains an extra
# top-level folder), instead of a confusing error from the dataset loader.
for split_dir in (train_dir, val_dir):
    if not os.path.isdir(split_dir):
        raise FileNotFoundError(f"Expected dataset folder not found: {split_dir}")

print("Train:", train_dir)
print("Val:", val_dir)

# ============================================================
# LOAD DATASET (FAST tf.data pipeline)
# ============================================================
IMG_SIZE = (300, 300)
BATCH_SIZE = 16      # safe + fast
SEED = 42

# Seed Python's `random`, NumPy and TensorFlow in one call so head
# initialisation, augmentation and shuffling repeat across re-runs
# (some GPU kernels may still be nondeterministic).
tf.keras.utils.set_random_seed(SEED)

train_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    label_mode="categorical",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=SEED
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    val_dir,
    label_mode="categorical",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False
)

# Label indices come from each split's own (alphanumerically ordered) class
# folders, so both splits must contain exactly the same folders; otherwise
# the one-hot indices silently disagree between train and val.
if train_ds.class_names != val_ds.class_names:
    raise ValueError(
        f"Train/val class folders differ: {train_ds.class_names} vs {val_ds.class_names}"
    )
# The classifier head in this notebook is built with exactly 2 outputs.
if len(train_ds.class_names) != 2:
    raise ValueError(f"Expected 2 classes, found {train_ds.class_names}")

print("Class indices:", train_ds.class_names)

# ============================================================
# GPU AUGMENTATION + EFFICIENTNET PREPROCESS
# ============================================================
from tensorflow.keras.applications.efficientnet import preprocess_input as eff_preprocess

def augment(images, labels):
    """Apply light, label-preserving augmentation to a batch of images.

    Images coming from ``image_dataset_from_directory`` are float32 pixel
    values in the [0, 255] range (not [0, 1]), so every magnitude below is
    expressed on that scale.

    Args:
        images: image batch, presumably (batch, H, W, 3) float32 in [0, 255].
        labels: one-hot labels; returned unchanged.

    Returns:
        ``(images, labels)`` with images still in [0, 255].
    """
    images = tf.image.random_flip_left_right(images)
    # random_brightness *adds* a delta to the pixel values. On a [0, 255]
    # scale a delta of 0.12 is invisible, so scale it to the pixel range
    # (±12% of full brightness, which is what 0.12 means for [0, 1] images).
    images = tf.image.random_brightness(images, max_delta=0.12 * 255.0)
    images = tf.image.random_contrast(images, 0.85, 1.15)
    images = tf.image.random_saturation(images, 0.9, 1.1)
    # NOTE(review): for a 4-D batch the brightness/contrast/saturation factors
    # are drawn once per batch, not per image.
    # The adjustments can push values outside the valid pixel range.
    images = tf.clip_by_value(images, 0.0, 255.0)
    return images, labels

def preprocess(images, labels):
    """Convert a batch to float32 and run EfficientNet's ``preprocess_input``.

    Args:
        images: image batch from the tf.data pipeline.
        labels: labels, passed through untouched.

    Returns:
        A ``(images, labels)`` tuple with images as float32.

    NOTE: Keras documents ``efficientnet.preprocess_input`` as a pass-through
    (the model rescales internally), so pixel values presumably stay in
    [0, 255] — confirm for the installed Keras version.
    """
    as_float = tf.cast(images, tf.float32)
    return eff_preprocess(as_float), labels

AUTOTUNE = tf.data.AUTOTUNE

# At this point every dataset element is already a whole batch
# (BATCH_SIZE x 300 x 300 x 3 float32 ≈ 17 MB with the default settings), so
# a shuffle buffer is measured in *batches*: a buffer of 512 could hold up to
# ~8.8 GB in RAM. image_dataset_from_directory(shuffle=True) already shuffles
# individual files, so this extra batch-order shuffle needs only a small buffer.
SHUFFLE_BUFFER_BATCHES = 16

# Training: shuffle batch order, augment, then apply the model preprocessing.
train_ds = (
    train_ds
    .shuffle(SHUFFLE_BUFFER_BATCHES, seed=SEED)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

# Validation: no shuffling or augmentation, only the shared preprocessing.
val_ds = (
    val_ds
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

# ============================================================
# BUILD EfficientNetB3 MODEL
# ============================================================
# Derive the model input shape from IMG_SIZE so it always matches the images
# produced by the tf.data pipeline.
INPUT_SHAPE = (*IMG_SIZE, 3)   # (height, width, RGB channels)
NUM_CLASSES = 2                # retina vs non-retina

base = tf.keras.applications.EfficientNetB3(
    include_top=False,
    weights="imagenet",
    input_shape=INPUT_SHAPE
)

# Create classifier head
inputs = layers.Input(shape=INPUT_SHAPE, dtype=tf.float32)
# training=False runs the backbone's BatchNormalization layers in inference
# mode, including later when the backbone is unfrozen for fine-tuning.
x = base(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(256, activation="relu", dtype="float32")(x)
x = layers.Dropout(0.3)(x)
# Under the mixed_float16 policy the softmax output is kept in float32 for
# numerically stable probabilities and loss.
outputs = layers.Dense(NUM_CLASSES, activation="softmax", dtype="float32")(x)

model = Model(inputs, outputs)
model.summary()

# ============================================================
# STAGE 1 TRAINING — Freeze Backbone
# ============================================================
STAGE1_EPOCHS = 2
STAGE1_LR = 1e-3

# Train only the new classifier head: every backbone layer is frozen.
for layer in base.layers:
    layer.trainable = False

# Compile after changing trainable flags so the change takes effect.
model.compile(
    optimizer=tf.keras.optimizers.Adam(STAGE1_LR),
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

# Writes retina_model_best.h5 whenever val_accuracy improves. These same
# callback objects are also passed to the Stage 2 fit below.
checkpoint = ModelCheckpoint(
    "retina_model_best.h5",
    save_best_only=True,
    monitor="val_accuracy",
    mode="max"
)

# NOTE(review): patience=6 exceeds both STAGE1_EPOCHS and the Stage 2 epoch
# count, so early stopping cannot trigger in these short runs; likewise
# ReduceLROnPlateau(patience=3) cannot fire within Stage 1.
earlystop = EarlyStopping(
    monitor="val_loss",
    patience=6,
    restore_best_weights=True
)

reduce_lr = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.3,
    patience=3
)

print("\n🔵 Stage 1 training...")
history1 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=STAGE1_EPOCHS,
    callbacks=[checkpoint, earlystop, reduce_lr]
)

# ============================================================
# STAGE 2 — UNFREEZE ALL + FINE-TUNE
# ============================================================
STAGE2_EPOCHS = 5
STAGE2_LR = 1e-5  # far smaller than Stage 1 so pretrained weights move gently

# Setting `trainable` on the backbone model propagates to all of its layers,
# and also turns the container flag itself on. In Keras a layer whose own
# `trainable` is False reports no trainable weights, whatever its sublayers say.
# The BatchNormalization layers still run in inference mode because the
# backbone was called with training=False when the model was built.
# NOTE(review): Keras' EfficientNet fine-tuning example also keeps the
# BatchNormalization layers non-trainable; consider that if fine-tuning with
# BATCH_SIZE=16 is unstable.
base.trainable = True

# Recompile so the new trainable flags and learning rate take effect.
model.compile(
    optimizer=tf.keras.optimizers.Adam(STAGE2_LR),
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

print("\n🟢 Stage 2 fine-tuning...")
history2 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=STAGE2_EPOCHS,
    callbacks=[checkpoint, earlystop, reduce_lr]
)

# Final save: whatever weights the model holds after Stage 2.
# NOTE(review): with EarlyStopping(restore_best_weights=True) these may be the
# best-val_loss weights rather than the last epoch's, depending on Keras version.
model.save("retina_model_final.h5")
print("\n✔ Saved final: retina_model_final.h5")
print("✔ Saved best:  retina_model_best.h5")

# ============================================================
# UPLOAD TO DRIVE
# ============================================================
import shutil

DRIVE_DIR = "/content/drive/MyDrive"
# Copy both model files to the root of the mounted Drive.
for model_file in ("retina_model_best.h5", "retina_model_final.h5"):
    shutil.copy(model_file, f"{DRIVE_DIR}/{model_file}")

print("\n📤 Uploaded to Drive")

# ============================================================
# PLOT TRAINING CURVES
# ============================================================
def plot_history(h1, h2, metrics=("accuracy",)):
    """Plot train/validation curves across both training stages.

    The per-epoch values of the two Keras History objects are concatenated,
    so the x-axis runs through Stage 1 and then Stage 2. A dashed vertical
    line marks where Stage 2 (fine-tuning) starts.

    Args:
        h1: History returned by the Stage 1 ``model.fit`` call.
        h2: History returned by the Stage 2 ``model.fit`` call.
        metrics: metric names to plot, one subplot each. Both ``name`` and
            ``"val_" + name`` must be keys of each history's ``.history``.

    Raises:
        ValueError: if ``metrics`` is empty.
        KeyError: if a requested metric was not recorded.
    """
    metrics = list(metrics)
    if not metrics:
        raise ValueError("plot_history needs at least one metric name")

    stage1_epochs = len(h1.history[metrics[0]])
    fig, axes = plt.subplots(
        len(metrics), 1, figsize=(10, 5 * len(metrics)), squeeze=False
    )
    for ax, name in zip(axes[:, 0], metrics):
        train_vals = h1.history[name] + h2.history[name]
        val_vals = h1.history["val_" + name] + h2.history["val_" + name]
        epochs = range(1, len(train_vals) + 1)
        ax.plot(epochs, train_vals, label="train")
        ax.plot(epochs, val_vals, label="val")
        # Boundary between the last Stage 1 epoch and the first Stage 2 epoch.
        ax.axvline(stage1_epochs + 0.5, color="gray", linestyle="--",
                   label="fine-tune start")
        ax.set(title=name.replace("_", " ").capitalize(),
               xlabel="epoch", ylabel=name)
        ax.legend()
    fig.tight_layout()
    plt.show()

plot_history(history1, history2, metrics=("accuracy", "loss"))
