In [4]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import os

# ---- 1. TPU Configuration ----
# Choose a tf.distribute strategy plus a matching global mixed-precision policy:
# TPUStrategy + bfloat16 when a TPU resolves, MirroredStrategy + float16 when
# GPUs are visible, otherwise the default strategy (no mixed-precision policy).
try:
    # Detect and initialize TPU; a ValueError anywhere in this block is treated
    # as "no TPU available" by the except clause below.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU:', tpu.cluster_spec().as_dict()['worker'])

    # Connect to TPU cluster
    tf.config.experimental_connect_to_cluster(tpu)

    # Initialize TPU system
    tf.tpu.experimental.initialize_tpu_system(tpu)

    # Create distribution strategy for TPU
    strategy = tf.distribute.TPUStrategy(tpu)

    print("TPU detected and configured successfully!")
    print(f"Number of accelerators: {strategy.num_replicas_in_sync}")

    # Set mixed precision policy for TPU
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
    print("Using mixed precision bfloat16 policy for TPU")

except ValueError:
    print("No TPU detected, falling back to GPU/CPU.")
    # Fallback to GPU configuration
    physical_devices = tf.config.list_physical_devices('GPU')
    if physical_devices:
        # Enable memory growth on *every* GPU, not only the first one:
        # MirroredStrategy below uses all visible GPUs, and TensorFlow refuses
        # configurations where memory growth differs between GPU devices.
        # This has to happen before any GPU is initialized.
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.keras.mixed_precision.set_global_policy('mixed_float16')
        strategy = tf.distribute.MirroredStrategy()
        print(f"Using GPU with {strategy.num_replicas_in_sync} device(s)")
    else:
        strategy = tf.distribute.get_strategy()
        print("Using CPU")

# ---- 2. Kaggle Dataset & Model Paths ----
# Input data is expected under ../input (Kaggle's input directory); trained
# artifacts and checkpoints are written relative to the working directory.
DATASET_PATH = "../input/fruits"  # Adjust if needed
MODEL_PATH = "./mobilenet_fruits360_optimized.h5"
CHECKPOINT_PATH = "./checkpoints/model_checkpoint.h5"

# ModelCheckpoint writes into this folder, so make sure it exists up front.
checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
os.makedirs(checkpoint_dir, exist_ok=True)

# ---- 3. Check if Dataset exists ----
# When the configured path is missing, report it and switch DATASET_PATH to
# the first existing entry from a short list of alternative mount points.
if not os.path.exists(DATASET_PATH):
    print(f"Dataset not found at {DATASET_PATH}")
    print("Please make sure to add the 'fruits-360-dataset' to your Kaggle notebook.")
    # Check common alternate locations in Kaggle (first match wins).
    candidate_paths = (
        "../input/fruits-360",
        "../input/fruit-images-for-object-detection",
        "../input/fruits-360_dataset",
    )
    found_path = next((p for p in candidate_paths if os.path.exists(p)), None)
    if found_path is not None:
        print(f"Found dataset at alternate location: {found_path}")
        DATASET_PATH = found_path

# ---- 4. Optimized Data Loading & Augmentation ----
# Increase batch size for TPU performance
BATCH_SIZE = 512  # TPUs perform better with larger batch sizes
IMG_SIZE = 96

# Augmenting generator; validation_split=0.1 reserves 10% of each class folder
# for the "validation" subset.
# No `rescale` here: MobileNetV2's preprocess_input already maps raw [0, 255]
# pixels to [-1, 1]. ImageDataGenerator would additionally multiply the result
# by `rescale`, squashing inputs to roughly [-0.004, 0.004]. Using
# preprocess_input alone also matches the inference snippet printed at the
# end of this script.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    zoom_range=0.2,
    validation_split=0.1,
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
)

# Use TF data API for more efficient data loading
def create_dataset(generator, subset):
    """Wrap ``generator.flow_from_directory`` in a prefetched ``tf.data.Dataset``.

    Args:
        generator: An ImageDataGenerator configured with ``validation_split``.
        subset: ``"training"`` or ``"validation"``; only the training subset
            is shuffled.

    Returns:
        A dataset of ``(images, labels)`` batches with image shape
        ``(batch, IMG_SIZE, IMG_SIZE, 3)`` and categorical (one-hot) labels.
        The underlying directory iterator never ends, so callers must pass
        ``steps_per_epoch`` / ``validation_steps``.
    """
    is_training = subset == "training"
    dataset = tf.data.Dataset.from_generator(
        lambda: generator.flow_from_directory(
            DATASET_PATH,
            target_size=(IMG_SIZE, IMG_SIZE),
            batch_size=BATCH_SIZE,
            class_mode="categorical",
            subset=subset,
            shuffle=is_training,
        ),
        output_signature=(
            tf.TensorSpec(shape=(None, IMG_SIZE, IMG_SIZE, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, None), dtype=tf.float32),
        ),
    )

    # No .cache(): the directory iterator loops forever, so a cache could never
    # be completed. It would only grow in memory until the partial cache is
    # discarded. For training it would also replay the first pass's random
    # augmentations instead of drawing new ones.
    # NOTE(review): from_generator runs Python code in the host process; confirm
    # it is usable with TPUStrategy before relying on this path on a TPU.
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

# Try to find the directory whose immediate subdirectories are the classes.
# flow_from_directory treats every immediate subdirectory as one class, so the
# candidates are checked most-specific first. A generic container also "has
# subdirectories": for example, a folder holding only `fruits-360_dataset`,
# or one holding the `Training` and `Test` split folders. Accepting such a
# folder would turn those containers into classes and mix Test images into
# training.
possible_data_dirs = [
    os.path.join(DATASET_PATH, "fruits-360_dataset", "fruits-360", "Training"),
    os.path.join(DATASET_PATH, "fruits-360", "Training"),
    os.path.join(DATASET_PATH, "Training"),
    os.path.join(DATASET_PATH, "fruits-360_dataset", "fruits-360"),
    os.path.join(DATASET_PATH, "fruits-360"),
    DATASET_PATH,
]

data_dir_found = False
for dir_path in possible_data_dirs:
    # isdir rather than exists: os.listdir would raise on a regular file.
    if not os.path.isdir(dir_path):
        continue
    # Accept the directory only if it contains subdirectories (candidate classes).
    if any(os.path.isdir(os.path.join(dir_path, f)) for f in os.listdir(dir_path)):
        print(f"Found valid dataset directory: {dir_path}")
        DATASET_PATH = dir_path
        data_dir_found = True
        break

if not data_dir_found:
    print("Warning: Could not automatically find the correct dataset directory structure.")
    print("Please verify the dataset path and structure manually.")

# Now create datasets.
# Primary path: "training"/"validation" subsets of DATASET_PATH produced by
# `train_datagen` (validation_split=0.1). If anything in that setup raises,
# fall back to a manual stratified 90/10 split of DATASET_PATH/Training fed
# through a file-based tf.data pipeline.
try:
    train_ds = create_dataset(train_datagen, "training")
    val_ds = create_dataset(train_datagen, "validation")

    # Directory iterators with batch_size=1 are used only to read the class
    # mapping and the sample counts.
    temp_generator = train_datagen.flow_from_directory(
        DATASET_PATH,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=1,
        class_mode="categorical",
        subset="training",
    )
    num_classes = len(temp_generator.class_indices)
    # Ensure at least 1 step: with fewer than BATCH_SIZE training images the
    # floor division would otherwise yield 0 steps per epoch.
    steps_per_epoch = max(1, temp_generator.samples // BATCH_SIZE)

    # Get validation samples count
    val_generator = train_datagen.flow_from_directory(
        DATASET_PATH,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=1,
        class_mode="categorical",
        subset="validation",
    )
    validation_steps = max(1, val_generator.samples // BATCH_SIZE)  # Ensure at least 1 step

    # Print dataset statistics
    print(f"Number of classes: {num_classes}")
    print(f"Training samples: {temp_generator.samples}")
    print(f"Validation samples: {val_generator.samples}")
    print(f"Steps per epoch: {steps_per_epoch}")
    print(f"Validation steps: {validation_steps}")

except Exception as e:
    print(f"Error setting up dataset: {e}")
    print("\nTrying alternate dataset structure (Training/Test directories)...")

    # Try alternate directory structure common in Kaggle datasets
    TRAIN_DIR = os.path.join(DATASET_PATH, "Training")
    TEST_DIR = os.path.join(DATASET_PATH, "Test")

    if not (os.path.exists(TRAIN_DIR) and os.path.exists(TEST_DIR)):
        # Chain the original failure so its traceback is not lost.
        raise Exception("Could not find valid dataset structure. Please check your dataset path.") from e

    import glob
    from sklearn.model_selection import train_test_split

    # Collect image paths and integer labels. Class folders and files are
    # sorted because os.listdir/glob order is arbitrary. Sorting keeps the label
    # ids (and the random_state=42 split) reproducible, and the label ids follow
    # the same alphabetical order flow_from_directory uses.
    class_dirs = sorted(
        d for d in os.listdir(TRAIN_DIR) if os.path.isdir(os.path.join(TRAIN_DIR, d))
    )
    class_to_idx = {cls_name: i for i, cls_name in enumerate(class_dirs)}

    all_images = []
    all_labels = []
    for cls_name in class_dirs:
        cls_path = os.path.join(TRAIN_DIR, cls_name)
        for img_path in sorted(glob.glob(os.path.join(cls_path, "*.jpg"))):
            all_images.append(img_path)
            all_labels.append(class_to_idx[cls_name])

    # Split with 90% training, 10% validation
    train_imgs, val_imgs, train_labels, val_labels = train_test_split(
        all_images, all_labels, test_size=0.1, stratify=all_labels, random_state=42
    )

    print(f"Manual split created: {len(train_imgs)} training images, {len(val_imgs)} validation images")

    def create_tpu_dataset_from_files(image_paths, labels, num_classes, is_training=True):
        """Build a batched, prefetched tf.data pipeline from JPEG paths and int labels.

        Training data is reshuffled and re-augmented on every epoch. Validation
        data is deterministic and is cached after the first full pass.
        NOTE(review): on Cloud TPU the input files usually have to be readable
        from the TPU host (e.g. GCS); confirm local ../input paths work there.
        """
        def _load(filename, label):
            # Decode, resize and scale to MobileNetV2's [-1, 1] input range.
            image_string = tf.io.read_file(filename)
            image = tf.image.decode_jpeg(image_string, channels=3)
            image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
            image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
            return image, tf.one_hot(label, depth=num_classes)

        def _augment(image, label):
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_brightness(image, 0.2)
            image = tf.image.random_contrast(image, 0.8, 1.2)
            # Keep augmented pixels inside the [-1, 1] range preprocess_input produced.
            image = tf.clip_by_value(image, -1.0, 1.0)
            return image, label

        dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))

        if is_training:
            # Shuffle the cheap (path, label) pairs; shuffle() reshuffles on each
            # iteration by default. There is deliberately no cache() after this
            # point, because it would freeze the first epoch's order and
            # augmentations.
            dataset = dataset.shuffle(buffer_size=len(image_paths))
            dataset = dataset.map(_load, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.map(_augment, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.batch(BATCH_SIZE)
        else:
            # Deterministic pipeline: safe to cache the decoded batches.
            dataset = dataset.map(_load, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.batch(BATCH_SIZE)
            dataset = dataset.cache()

        return dataset.prefetch(tf.data.AUTOTUNE)

    num_classes = len(class_dirs)

    # Create TPU-optimized datasets
    train_ds = create_tpu_dataset_from_files(train_imgs, train_labels, num_classes, is_training=True)
    val_ds = create_tpu_dataset_from_files(val_imgs, val_labels, num_classes, is_training=False)

    steps_per_epoch = (len(train_imgs) + BATCH_SIZE - 1) // BATCH_SIZE  # Ceiling division
    validation_steps = (len(val_imgs) + BATCH_SIZE - 1) // BATCH_SIZE

    print(f"Alternative dataset structure processed for TPU:")
    print(f"Number of classes: {num_classes}")
    print(f"Training samples: {len(train_imgs)}")
    print(f"Validation samples: {len(val_imgs)}")
    print(f"Steps per epoch: {steps_per_epoch}")
    print(f"Validation steps: {validation_steps}")

# ---- 5. Define Model Creation Function within Strategy Scope ----
# TPU Strategy Scope for model creation
def create_model(learning_rate=0.001):
    """Build and compile a MobileNetV2 classifier with a frozen ImageNet backbone.

    Call it inside ``strategy.scope()`` so the variables are created for the
    active distribution strategy. Reads the module-level ``IMG_SIZE`` and
    ``num_classes``.

    Args:
        learning_rate: Initial Adam learning rate. This is a plain float, not a
            LearningRateSchedule. The training callbacks include
            ReduceLROnPlateau, which reads and overwrites the optimizer's
            learning rate, and Keras cannot do that for an optimizer built
            from a schedule object.

    Returns:
        ``(model, base_model)``: the compiled full model and the frozen
        MobileNetV2 backbone, returned so layers can be unfrozen later.
    """
    # Use smaller input size and alpha parameter for faster inference
    base_model = MobileNetV2(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        alpha=0.75,  # Smaller network (75% of filters)
    )

    # Freeze base model for initial training
    base_model.trainable = False

    # Efficient Model Head
    x = base_model.output
    x = GlobalAveragePooling2D(name="gap")(x)
    x = Dense(128, activation="relu", name="dense_1")(x)
    x = Dropout(0.4, name="dropout_1")(x)
    # Keep the softmax output in float32. Under the mixed-precision policy set
    # in section 1 (bfloat16/float16), a low-precision softmax and loss can be
    # numerically unstable.
    output_layer = Dense(num_classes, activation="softmax", dtype="float32", name="output")(x)

    model = Model(inputs=base_model.input, outputs=output_layer)

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(k=3, name="top3_acc")],
    )

    return model, base_model

# Build the model under the active distribution strategy (TPU, mirrored GPU,
# or default) so its variables are created for that strategy.
with strategy.scope():
    built = create_model()
model, base_model = built

# Print the layer-by-layer architecture.
print("Model Architecture Summary:")
model.summary()

# ---- 8. Callbacks for Better Training ----
# One callback list, shared by the initial and the fine-tuning fit() calls.
# NOTE: ReduceLROnPlateau reads and overwrites the optimizer's learning rate,
# so the optimizer must be built with a plain (settable) learning rate.

# Stop after 2 epochs without val_loss improvement and roll back to the best weights.
early_stopping_cb = EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)

# Keep only the checkpoint with the highest val_accuracy seen so far.
checkpoint_cb = ModelCheckpoint(
    filepath=CHECKPOINT_PATH,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
)

# Multiply the learning rate by 0.2 after a single epoch without val_loss improvement.
reduce_lr_cb = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=1, min_lr=1e-6)

# Scalar logs plus per-epoch weight histograms under ./logs.
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir="./logs", histogram_freq=1)

callbacks = [early_stopping_cb, checkpoint_cb, reduce_lr_cb, tensorboard_cb]

# ---- 9. Initial Training Phase ----
# Train only the new classification head; the backbone is still frozen.
print("\nStarting initial training phase on TPU...")
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=callbacks,
)

# ---- 10. Selective Fine-Tuning ----
print("\nStarting fine-tuning phase...")
# Trainability changes and the recompile happen inside the strategy scope, so
# the new optimizer is created for the active distribution strategy.
with strategy.scope():
    # Unfreeze the last 12 MobileNetV2 layers, except BatchNormalization.
    # The backbone's layers are wired directly into `model` (it is not called
    # with training=False). An unfrozen BatchNormalization layer would
    # therefore run in training mode and overwrite its ImageNet moving
    # mean/variance with statistics from the fine-tuning batches, which
    # typically hurts accuracy. Frozen BN layers run in inference mode.
    for layer in base_model.layers[-12:]:
        if not isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = True

    # Count trainable parameters
    trainable_count = sum(tf.keras.backend.count_params(w) for w in model.trainable_weights)
    non_trainable_count = sum(tf.keras.backend.count_params(w) for w in model.non_trainable_weights)
    print(f"Trainable parameters: {trainable_count:,}")
    print(f"Non-trainable parameters: {non_trainable_count:,}")

    # Recompile so the trainability change takes effect, with a much smaller
    # learning rate so the pretrained weights are only nudged.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-5),
        loss="categorical_crossentropy",
        metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(k=3, name="top3_acc")],
    )

# Fine-tune with early stopping; the same callback instances are reused from
# the initial phase.
history_finetune = model.fit(
    train_ds,
    validation_data=val_ds,
    validation_steps=validation_steps,
    steps_per_epoch=steps_per_epoch,
    epochs=5,
    callbacks=callbacks,
)

# ---- 11. Evaluation ----
print("\nEvaluating model on validation set...")
# evaluate() returns [loss, accuracy, top3_acc], in the order set by compile().
evaluation = model.evaluate(val_ds, steps=validation_steps)
metric_labels = ("loss", "accuracy", "top-3 accuracy")
for metric_label, metric_value in zip(metric_labels, evaluation):
    print(f"Final validation {metric_label}: {metric_value:.4f}")

# ---- 12. Save Models ----
# Full Keras model (HDF5) next to the notebook.
model.save(MODEL_PATH)
print(f"Saved Keras model to {MODEL_PATH}")

# TFLite export with the converter's default optimizations enabled.
tflite_converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_bytes = tflite_converter.convert()

# The .tflite file goes into the same directory as MODEL_PATH.
tflite_path = os.path.join(os.path.dirname(MODEL_PATH), 'model.tflite')
with open(tflite_path, 'wb') as tflite_file:
    tflite_file.write(tflite_bytes)
print(f"Saved TFLite model to {tflite_path}")

# ---- 13. Output class indices for later use ----
# Save the class-name -> integer-label mapping so inference code can turn a
# predicted index back into a class name.
import json

# Use whichever mapping the dataset setup produced: `temp_generator`
# (ImageDataGenerator path) or `class_to_idx` (manual Training/ split path).
# At module level locals() is the module namespace.
if 'temp_generator' in locals():
    class_indices = temp_generator.class_indices
elif 'class_to_idx' in locals():
    class_indices = dict(class_to_idx)
else:
    # Still write the file, but make the problem visible: an empty mapping
    # cannot translate predictions back into class names.
    print("Warning: no class mapping found; class_indices.json will be empty.")
    class_indices = {}

with open('class_indices.json', 'w', encoding='utf-8') as f:
    json.dump(class_indices, f)
print("Saved class indices to class_indices.json")

# ---- 14. Sample prediction code ----
# Ready-to-paste inference snippet: loads the saved model and
# class_indices.json, and classifies one 96x96 image preprocessed with
# MobileNetV2's preprocess_input.
_SAMPLE_PREDICTION_CODE = """
# Code to load and use the model for prediction
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing import image
import json

# Load the model
model = tf.keras.models.load_model('mobilenet_fruits360_optimized.h5')

# Load class indices
with open('class_indices.json', 'r') as f:
    class_indices = json.load(f)
    
# Invert the dictionary to map indices to class names
idx_to_class = {v: k for k, v in class_indices.items()}

# Function to preprocess and predict
def predict_fruit(img_path):
    img = image.load_img(img_path, target_size=(96, 96))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = tf.keras.applications.mobilenet_v2.preprocess_input(img_array)
    
    predictions = model.predict(img_array)
    predicted_class_idx = np.argmax(predictions[0])
    confidence = predictions[0][predicted_class_idx] * 100
    
    return idx_to_class[predicted_class_idx], confidence

# Example usage
# fruit_name, confidence = predict_fruit('path/to/your/fruit/image.jpg')
# print(f'Predicted fruit: {fruit_name} with {confidence:.2f}% confidence')
"""

print("\nSample code for making predictions:")
print(_SAMPLE_PREDICTION_CODE)

print("\nTraining and optimization with TPU complete!")

# ---- 15. Create a simple visualization of training history ----
# Best-effort plot of train/validation accuracy for both training phases.
# Any failure (e.g. matplotlib not installed) is reported instead of raised.
try:
    import matplotlib.pyplot as plt

    def _plot_accuracy_panel(run_history, panel_title):
        """Draw one fit() run's train/validation accuracy curves on the current axes."""
        plt.plot(run_history.history['accuracy'])
        plt.plot(run_history.history['val_accuracy'])
        plt.title(panel_title)
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')

    # Side-by-side panels: initial training on the left, fine-tuning on the right.
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    _plot_accuracy_panel(history, 'Model Accuracy (Initial Training)')
    plt.subplot(1, 2, 2)
    _plot_accuracy_panel(history_finetune, 'Model Accuracy (Fine-tuning)')

    plt.tight_layout()
    plt.savefig('training_history.png')
    print("Saved training history visualization to 'training_history.png'")
except Exception as e:
    print(f"Could not create visualization: {e}")

No TPU detected, falling back to GPU/CPU.
Using CPU
Dataset not found at ../input/fruits-360-dataset
Please make sure to add the 'fruits-360-dataset' to your Kaggle notebook.
Please verify the dataset path and structure manually.
Error setting up dataset: [Errno 2] No such file or directory: '../input/fruits-360-dataset'

Trying alternate dataset structure (Training/Test directories)...


Exception: Could not find valid dataset structure. Please check your dataset path.