# In [None]:
import os
import cv2
import numpy as np
import tensorflow as tf
from glob import glob
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# Runtime setup: GPU memory growth, mixed precision and the distribution strategy.
from tensorflow.keras import mixed_precision

# Configure GPU memory growth. This is done before any other TensorFlow call
# that could touch the devices, because memory growth can only be changed
# while the GPUs are still uninitialized.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Enable mixed precision for faster GPU training and lower memory usage
mixed_precision.set_global_policy('mixed_float16')

# Distributed strategy (multi-GPU) if available
tf_strategy = tf.distribute.MirroredStrategy()
# Report the physical GPU count separately from the replica count: the
# strategy still has one replica when it falls back to the CPU, so the
# replica count alone would claim a GPU on a CPU-only machine.
print(f"GPUs detected: {len(gpus)}")
print(f"Replicas in sync: {tf_strategy.num_replicas_in_sync}")

# Configuration
DATA_DIR = 'data/'  # root folder that is listed to discover the class sub-directories
IMAGE_SIZE = 224  # images are resized to IMAGE_SIZE x IMAGE_SIZE pixels
GLOBAL_BATCH_SIZE = 32  # batch size handed to Dataset.batch
AUTO = tf.data.AUTOTUNE  # lets tf.data pick map parallelism and prefetch depth
EPOCHS_HEAD = 20  # number of epochs passed to model.fit
EPOCHS_FINE = 10  # NOTE(review): not referenced in the visible code - confirm it is still needed
SEED = 42  # seeds the path permutation, the train/test split and the dataset shuffle

# 1. Load file paths and labels
# Only sub-directories count as classes. Stray files in DATA_DIR would
# otherwise take up a slot in class_names without contributing samples, so the
# label indices, len(class_names) and the report's target names would disagree.
class_names = sorted(
    entry for entry in os.listdir(DATA_DIR)
    if os.path.isdir(os.path.join(DATA_DIR, entry))
)
paths, labels = [], []
for idx, cls in enumerate(class_names):
    cls_dir = os.path.join(DATA_DIR, cls)
    # glob returns entries in filesystem order; sorting makes the seeded
    # shuffle/split below reproducible. Nested directories whose name contains
    # a dot also match '*.*', so keep regular files only.
    cls_paths = sorted(
        p for p in glob(os.path.join(cls_dir, '*.*')) if os.path.isfile(p)
    )
    paths += cls_paths
    labels += [idx] * len(cls_paths)
if not paths:
    raise FileNotFoundError(f"No files found in class folders under {DATA_DIR!r}")
paths = np.array(paths)
# Force int64: the tf.py_function wrapper declares int64 labels, while NumPy's
# default integer width depends on the platform.
labels = np.array(labels, dtype=np.int64)

# Shuffle and split (stratified, so every class keeps its share in both parts)
i = np.random.RandomState(SEED).permutation(len(paths))
paths, labels = paths[i], labels[i]
train_paths, test_paths, train_labels, test_labels = train_test_split(
    paths, labels, test_size=0.2, stratify=labels, random_state=SEED
)

# 2. Preprocessing function
def preprocess_image(path, label):
    """Load an image, black out its Otsu background and scale it to [0, 1].

    The function is meant to run eagerly (it is called through
    ``tf.py_function``), so ``path`` must expose ``.numpy()`` returning the
    UTF-8 encoded file path.

    Args:
        path: scalar string tensor with the image file path.
        label: class label; passed through unchanged.

    Returns:
        ``(image, label)`` where ``image`` is a float32 array of shape
        ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` in RGB order with values in [0, 1].

    Raises:
        ValueError: if OpenCV cannot read or decode the file.
    """
    filename = path.numpy().decode('utf-8')
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of inside cv2.resize.
        raise ValueError(f"Could not read image file: {filename}")

    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Build a binary mask: equalize -> light blur -> Otsu threshold -> closing
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    eq = cv2.equalizeHist(gray)
    blur = cv2.GaussianBlur(eq, (3, 3), 0)
    _, binar = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    closed = cv2.morphologyEx(binar, cv2.MORPH_CLOSE, kernel)

    # The mask holds only 0.0 and 1.0, so the product keeps exact pixel values
    # and no intermediate uint8 conversion is needed before normalizing.
    mask = closed.astype(np.float32) / 255.0
    masked = rgb * mask[..., np.newaxis]

    # Normalize for custom CNN
    masked = masked.astype(np.float32) / 255.0
    return masked, label

@tf.function
def tf_preprocess(path, label):
    """Graph-compatible wrapper around ``preprocess_image``.

    Delegates the OpenCV work to ``tf.py_function`` and then restores the
    static shapes, which are unknown on tensors coming out of a py_function.

    Args:
        path: scalar string tensor with the image file path.
        label: scalar integer label tensor.

    Returns:
        ``(image, label)``: a float32 ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` tensor
        and a scalar int64 tensor.
    """
    image, target = tf.py_function(
        func=preprocess_image,
        inp=[path, label],
        Tout=[tf.float32, tf.int64],
    )
    image.set_shape([IMAGE_SIZE, IMAGE_SIZE, 3])
    target.set_shape([])
    return image, target

# 3. Dataset builder
# The input pipeline does not create model variables, so it is built outside
# the distribution-strategy scope; only the model below needs the scope.
def make_dataset(paths, labels, training=True, shuffle_buffer=None):
    """Build a batched, prefetched ``tf.data`` pipeline over image paths.

    Args:
        paths: sequence of image file paths.
        labels: integer class indices aligned with ``paths``.
        training: when True the dataset is shuffled and repeated
            indefinitely; otherwise it is a single pass in the given order.
        shuffle_buffer: shuffle buffer size used when ``training`` is True.
            Defaults to the number of paths, which gives a full shuffle; this
            is cheap because shuffling happens before the images are loaded,
            so the buffer only holds (path, label) pairs.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches of at most
        ``GLOBAL_BATCH_SIZE`` elements.
    """
    # tf_preprocess declares int64 labels; cast up front so the pipeline does
    # not depend on the platform's default integer width.
    labels = np.asarray(labels, dtype=np.int64)
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if training:
        if shuffle_buffer is None:
            shuffle_buffer = max(len(paths), 1)
        ds = ds.shuffle(shuffle_buffer, seed=SEED)
    ds = ds.map(tf_preprocess, num_parallel_calls=AUTO)
    if training:
        ds = ds.repeat()
    ds = ds.batch(GLOBAL_BATCH_SIZE)
    ds = ds.prefetch(AUTO)
    return ds


train_ds = make_dataset(train_paths, train_labels, training=True)
val_ds = make_dataset(test_paths, test_labels, training=False)

with tf_strategy.scope():
    # 4. Model definition with augmentation and custom CNN
    aug = tf.keras.Sequential([
        tf.keras.layers.RandomFlip('horizontal'),
        tf.keras.layers.RandomRotation(0.1),
        tf.keras.layers.RandomZoom(0.1),
        tf.keras.layers.RandomContrast(0.1)
    ])

    def build_custom_cnn(input_shape, num_classes):
        """Create the CNN classifier with the augmentation stack in front.

        Architecture: augmentation -> three Conv2D(3x3, relu) + MaxPooling
        stages with 32, 64 and 128 filters -> Conv2D with 256 filters ->
        global average pooling -> Dropout(0.3) -> softmax.

        Args:
            input_shape: shape of a single input image, e.g. (H, W, 3).
            num_classes: number of output classes.

        Returns:
            An uncompiled ``tf.keras.Model``.
        """
        inputs = tf.keras.Input(shape=input_shape)
        x = aug(inputs)

        for filters in (32, 64, 128):
            x = tf.keras.layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
            x = tf.keras.layers.MaxPooling2D((2, 2))(x)

        x = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
        x = tf.keras.layers.GlobalAveragePooling2D()(x)

        x = tf.keras.layers.Dropout(0.3)(x)
        # The output layer is pinned to float32 so the softmax probabilities
        # are not produced in float16 under the mixed-precision policy.
        outputs = tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32')(x)

        return tf.keras.Model(inputs, outputs)

    model = build_custom_cnn((IMAGE_SIZE, IMAGE_SIZE, 3), len(class_names))
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# 5. Callbacks
# Both callbacks watch the validation loss: the learning rate is reduced after
# 3 epochs without improvement, training stops after 5 and the best weights
# are restored.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3)
]

# 6. Train the model
# train_ds repeats indefinitely, so the epoch length has to be given
# explicitly; max() avoids zero steps when there is less than one full batch.
steps_per_epoch = max(1, len(train_paths) // GLOBAL_BATCH_SIZE)
history_head = model.fit(
    train_ds,
    epochs=EPOCHS_HEAD,
    steps_per_epoch=steps_per_epoch,
    # val_ds is finite, so validation_steps is left unset: Keras then runs
    # validation until the dataset is exhausted, which includes the final
    # partial batch that a floor-divided step count would skip.
    validation_data=val_ds,
    callbacks=callbacks
)


# 7. Final evaluation on the held-out split
y_preds = []
y_true = []
# val_ds is a finite, unshuffled single pass, so iterating it directly visits
# every held-out sample exactly once.
for imgs, lbls in val_ds:
    # verbose=0 suppresses the progress bar that would be printed per batch
    preds = np.argmax(model.predict(imgs, verbose=0), axis=1)
    y_preds.extend(preds)
    y_true.extend(lbls.numpy())

# Passing labels explicitly ties each target name to its class index instead
# of relying on which labels happen to occur in y_true / y_preds.
print(classification_report(
    y_true,
    y_preds,
    labels=np.arange(len(class_names)),
    target_names=class_names,
))
