In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import os
import numpy as np
import glob
from sklearn.model_selection import train_test_split

# ---- 1. TPU Configuration ----
# Selects the tf.distribute strategy used by the rest of the script:
# TPU when a cluster resolves, otherwise all visible GPUs, otherwise the
# default strategy (CPU). `strategy` is always defined after this section.
try:
    # Detect and initialize TPU. A ValueError raised anywhere in this section
    # is treated as "no TPU available" by the handler below.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU:', tpu.cluster_spec().as_dict()['worker'])

    # Connect to TPU cluster
    tf.config.experimental_connect_to_cluster(tpu)

    # Initialize TPU system
    tf.tpu.experimental.initialize_tpu_system(tpu)

    # Create distribution strategy for TPU
    strategy = tf.distribute.TPUStrategy(tpu)

    print("TPU detected and configured successfully!")
    print(f"Number of accelerators: {strategy.num_replicas_in_sync}")

    # TPUs compute natively in bfloat16. The float16 policy makes Keras add
    # loss scaling, which is not meant to be combined with TPUStrategy, so
    # the bfloat16 policy is the one to use here.
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
    print("Using mixed precision bfloat16 policy for TPU")

    # Print TPU device information
    print("TPU device information:")
    for device in tf.config.list_logical_devices('TPU'):
        print(f" - {device}")

except ValueError:
    print("No TPU detected, falling back to GPU/CPU.")
    # Fallback to GPU configuration
    physical_devices = tf.config.list_physical_devices('GPU')
    if physical_devices:
        # MirroredStrategy uses every visible GPU, and memory growth has to be
        # configured the same way on all of them before they are initialized.
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.keras.mixed_precision.set_global_policy('mixed_float16')
        strategy = tf.distribute.MirroredStrategy()
        print(f"Using GPU with {strategy.num_replicas_in_sync} device(s)")
    else:
        strategy = tf.distribute.get_strategy()
        print("Using CPU")

# ---- 2. Dataset Path and Model Settings ----
# Input resolution and the base batch size (rescaled later per replica count).
IMG_SIZE = 96
BATCH_SIZE = 128

# Where the dataset is read from and where artifacts are written (.keras format).
DATASET_PATH = "/kaggle/input/fruits/fruits-360_100x100/fruits-360"
MODEL_PATH = "./mobilenet_fruits360_optimized.keras"
CHECKPOINT_PATH = "./checkpoints/model_checkpoint.keras"

# Make sure the checkpoint folder is there before any callback writes to it.
os.makedirs(os.path.split(CHECKPOINT_PATH)[0], exist_ok=True)

# ---- 3. Dataset Path Verification ----
# Confirms DATASET_PATH exists; otherwise scans two known Kaggle input roots
# for a directory whose name mentions "fruit" or "360" and switches to the
# first match, raising when nothing suitable is found.
print(f"\nVerifying dataset path: {DATASET_PATH}")
if os.path.exists(DATASET_PATH):
    print(f"Dataset path exists: {DATASET_PATH}")
else:
    print(f"ERROR: Dataset path {DATASET_PATH} does not exist!")
    found_paths = []
    for base in ("/kaggle/input", "/kaggle/input/fruits"):
        if not os.path.exists(base):
            continue
        print(f"Searching in {base} for fruit datasets...")
        for item in os.listdir(base):
            full_path = os.path.join(base, item)
            name_matches = "fruit" in item.lower() or "360" in item
            if os.path.isdir(full_path) and name_matches:
                found_paths.append(full_path)

    if not found_paths:
        raise Exception("No fruit dataset found! Please verify the dataset is available.")

    print(f"Found potential dataset paths: {found_paths}")
    # The first candidate becomes the dataset root from here on.
    DATASET_PATH = found_paths[0]
    print(f"Using alternative path: {DATASET_PATH}")

# ---- 4. Adjust batch size to be divisible by TPU cores ----
# The global batch is 128 samples per replica; nothing changes when no
# strategy (or no replica count) is available.
try:
    replica_count = strategy.num_replicas_in_sync
except (NameError, AttributeError):
    pass
else:
    BATCH_SIZE = 128 * replica_count
    print(f"Using TPU-optimized batch size: {BATCH_SIZE}")

# ---- 5. Dataset Directory Structure Analysis ----
# Diagnostic only: prints up to ten entries of the dataset root, with a
# subdirectory/file count for each directory. Any failure is just reported.
print("\nAnalyzing dataset directory structure...")
try:
    entries = os.listdir(DATASET_PATH)
    for entry in entries[:10]:
        entry_path = os.path.join(DATASET_PATH, entry)
        if not os.path.isdir(entry_path):
            print(f"  - {entry}")
            continue
        child_paths = [os.path.join(entry_path, child) for child in os.listdir(entry_path)]
        subdir_count = sum(1 for child in child_paths if os.path.isdir(child))
        file_count = sum(1 for child in child_paths if os.path.isfile(child))
        print(f"  - {entry}/ (contains {subdir_count} subdirs, {file_count} files)")
    remaining = len(entries) - 10
    if remaining > 0:
        print(f"  ... and {remaining} more items")
except Exception as e:
    print(f"Error listing directory: {e}")

# ---- 6. Check for Training/Test directories ----
# Records whether the standard "Training"/"Test" layout is present and prints
# a short summary; the flags are read later by prepare_datasets().
TRAIN_DIR = os.path.join(DATASET_PATH, "Training")
TEST_DIR = os.path.join(DATASET_PATH, "Test")

# isdir() is already False for a missing path, so it doubles as the existence check.
training_dir_exists = os.path.isdir(TRAIN_DIR)
test_dir_exists = os.path.isdir(TEST_DIR)

if not training_dir_exists:
    print(f"Training directory not found at {TRAIN_DIR}")
else:
    print(f"\nFound Training directory: {TRAIN_DIR}")
    train_classes = os.listdir(TRAIN_DIR)
    print(f"  Contains {len(train_classes)} classes")
    # Report the image count of the first three class folders as a sanity check.
    for cls in train_classes[:3]:
        cls_path = os.path.join(TRAIN_DIR, cls)
        if not os.path.isdir(cls_path):
            continue
        image_count = sum(
            len(glob.glob(os.path.join(cls_path, ext)))
            for ext in ('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG')
        )
        print(f"    - {cls}: {image_count} images")

if not test_dir_exists:
    print(f"Test directory not found at {TEST_DIR}")
else:
    print(f"\nFound Test directory: {TEST_DIR}")
    test_classes = os.listdir(TEST_DIR)
    print(f"  Contains {len(test_classes)} classes")

# ---- 7. Dataset Creation Function ----
def prepare_datasets():
    """Build the train/validation/test datasets for whichever layout was detected.

    Delegates to ``prepare_training_test_datasets`` when both the Training and
    Test directories were found, and to ``prepare_alternative_datasets``
    otherwise, returning that helper's result unchanged.
    """
    has_standard_layout = training_dir_exists and test_dir_exists
    if not has_standard_layout:
        print("\nUsing alternative dataset structure detection")
        return prepare_alternative_datasets()

    print("\nUsing Training/Test directory structure")
    return prepare_training_test_datasets()

def prepare_training_test_datasets():
    """Prepare datasets using the Training/Test directory structure.

    Every sub-directory of ``TRAIN_DIR`` is one class. Class names are sorted
    before indices are assigned so the label mapping is the same on every run
    (``os.listdir`` order is arbitrary). Ten percent of the training images are
    held out, stratified by class, as the validation split; ``TEST_DIR``
    supplies the test split.

    The mapping is also published as the module-level ``class_to_idx`` so the
    saving step at the end of the script can export it.

    Returns:
        ``(train_ds, val_ds, test_ds, num_classes, train_count, val_count)``.

    Raises:
        Exception: if either directory yields no images.
    """
    # The module-level save step looks this name up to write class_indices.json.
    global class_to_idx

    image_patterns = ('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG')

    def list_images(directory):
        """Return the image files located directly inside ``directory``."""
        matches = []
        for pattern in image_patterns:
            matches.extend(glob.glob(os.path.join(directory, pattern)))
        return matches

    # Process training directory
    class_dirs = sorted(
        d for d in os.listdir(TRAIN_DIR)
        if os.path.isdir(os.path.join(TRAIN_DIR, d))
    )
    class_to_idx = {cls_name: i for i, cls_name in enumerate(class_dirs)}

    print(f"Found {len(class_dirs)} classes in Training directory")

    all_train_images = []
    all_train_labels = []
    for cls_name in class_dirs:
        for img_path in list_images(os.path.join(TRAIN_DIR, cls_name)):
            all_train_images.append(img_path)
            all_train_labels.append(class_to_idx[cls_name])

    print(f"Found {len(all_train_images)} training images")

    # Test classes that were never seen in training cannot be labelled; report
    # them by scanning the Test directory itself.
    test_class_dirs = sorted(
        d for d in os.listdir(TEST_DIR)
        if os.path.isdir(os.path.join(TEST_DIR, d))
    )
    for cls_name in test_class_dirs:
        if cls_name not in class_to_idx:
            print(f"Warning: Class {cls_name} in test set not found in training set")

    # Process test directory, limited to classes known from training
    all_test_images = []
    all_test_labels = []
    for cls_name in class_dirs:
        cls_path = os.path.join(TEST_DIR, cls_name)
        if not os.path.isdir(cls_path):
            print(f"Warning: Test directory for class {cls_name} not found")
            continue

        for img_path in list_images(cls_path):
            all_test_images.append(img_path)
            all_test_labels.append(class_to_idx[cls_name])

    print(f"Found {len(all_test_images)} test images")

    if not all_train_images or not all_test_images:
        raise Exception("No images found in Training/Test directories")

    # Split training data into train and validation
    # Use 10% of training data for validation
    train_imgs, val_imgs, train_labels, val_labels = train_test_split(
        all_train_images, all_train_labels, test_size=0.1, stratify=all_train_labels, random_state=42
    )

    print(f"Split: {len(train_imgs)} training, {len(val_imgs)} validation, {len(all_test_images)} test images")

    # Create TF datasets
    num_classes = len(class_dirs)
    train_ds = create_tpu_dataset(train_imgs, train_labels, num_classes, is_training=True)
    val_ds = create_tpu_dataset(val_imgs, val_labels, num_classes, is_training=False)
    test_ds = create_tpu_dataset(all_test_images, all_test_labels, num_classes, is_training=False)

    return train_ds, val_ds, test_ds, num_classes, len(train_imgs), len(val_imgs)

def prepare_alternative_datasets():
    """Prepare datasets when the standard Training/Test layout is missing.

    Two discovery passes are tried in order:

    1. Sub-directories of ``DATASET_PATH`` that directly contain images are
       taken as classes.
    2. Otherwise the tree is walked recursively; the siblings of the directory
       holding the most images are taken as classes, or, when it has no
       siblings with images, every directory with at least five images.

    Class names are sorted before indices are assigned, and directories that
    share a basename are merged into one class, so labels are contiguous and
    reproducible. The images are split 80/10/10 (stratified) into
    train/validation/test.

    The mapping is also published as the module-level ``class_to_idx`` so the
    saving step at the end of the script can export it.

    Returns:
        ``(train_ds, val_ds, test_ds, num_classes, train_count, val_count)``.

    Raises:
        Exception: if no images are found by either pass.
    """
    # The module-level save step looks this name up to write class_indices.json.
    global class_to_idx

    image_patterns = ('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG')

    def list_images(directory):
        """Return the image files located directly inside ``directory``."""
        matches = []
        for pattern in image_patterns:
            matches.extend(glob.glob(os.path.join(directory, pattern)))
        return matches

    all_images = []
    all_labels = []

    # First check if classes are directly in main directory
    potential_class_dirs = sorted(
        d for d in os.listdir(DATASET_PATH)
        if os.path.isdir(os.path.join(DATASET_PATH, d))
    )

    # Only directories that actually hold images count as classes
    class_dirs = [
        d for d in potential_class_dirs
        if any(glob.glob(os.path.join(DATASET_PATH, d, pattern)) for pattern in image_patterns)
    ]

    if class_dirs:
        print(f"Found {len(class_dirs)} classes in main directory")
        class_to_idx = {cls_name: i for i, cls_name in enumerate(class_dirs)}

        # Collect images from each class
        for cls_name in class_dirs:
            for img_path in list_images(os.path.join(DATASET_PATH, cls_name)):
                all_images.append(img_path)
                all_labels.append(class_to_idx[cls_name])

    # If no classes found, try recursive search
    if not all_images:
        print("No class directories found in main directory, trying recursive search...")

        # Map of directory to count of image files - to identify likely class dirs
        dir_to_img_count = {}
        for root, _, files in os.walk(DATASET_PATH):
            img_count = sum(1 for f in files if f.lower().endswith(('.jpg', '.jpeg', '.png')))
            if img_count > 0:
                dir_to_img_count[root] = img_count

        # Most images first; ties are broken by path so the order is stable.
        sorted_dirs = sorted(dir_to_img_count.items(), key=lambda item: (-item[1], item[0]))

        # Print directories with most images
        print("Directories with most images:")
        for dir_path, count in sorted_dirs[:10]:
            print(f"  {dir_path}: {count} images")

        # Try to infer classes from directories with images
        # Strategy: directories at same level with similar image counts are likely classes
        potential_class_dirs = []

        if sorted_dirs:
            first_dir = sorted_dirs[0][0]
            parent_dir = os.path.dirname(first_dir)

            # Check if siblings have images too
            sibling_dirs = [d for d, _ in sorted_dirs if os.path.dirname(d) == parent_dir]

            if len(sibling_dirs) > 1:
                print(f"Found {len(sibling_dirs)} potential class directories under {parent_dir}")
                potential_class_dirs = sibling_dirs
            else:
                # Just use all directories with images as classes
                potential_class_dirs = [d for d, c in sorted_dirs if c >= 5]  # At least 5 images

        if potential_class_dirs:
            # Directory basenames are the class names. Indexing the unique,
            # sorted names keeps labels contiguous even when two directories
            # share a basename (they are then treated as the same class).
            class_names = sorted({os.path.basename(d) for d in potential_class_dirs})
            class_to_idx = {cls_name: i for i, cls_name in enumerate(class_names)}

            # Collect images from each potential class directory
            for cls_path in potential_class_dirs:
                label = class_to_idx[os.path.basename(cls_path)]
                for img_path in list_images(cls_path):
                    all_images.append(img_path)
                    all_labels.append(label)

    # Final check - did we find any images?
    if not all_images:
        raise Exception("No images found in the dataset with any common structure!")

    print(f"Total images found: {len(all_images)}")
    print(f"Total classes found: {len(set(all_labels))}")

    # Split into train/val/test (80/10/10)
    train_imgs, temp_imgs, train_labels, temp_labels = train_test_split(
        all_images, all_labels, test_size=0.2, stratify=all_labels, random_state=42
    )

    val_imgs, test_imgs, val_labels, test_labels = train_test_split(
        temp_imgs, temp_labels, test_size=0.5, stratify=temp_labels, random_state=42
    )

    print(f"Split: {len(train_imgs)} training, {len(val_imgs)} validation, {len(test_imgs)} test images")

    # Size the one-hot depth from the mapping so every label index fits, even
    # if a class directory contributed no images matching the glob patterns.
    num_classes = len(class_to_idx)
    train_ds = create_tpu_dataset(train_imgs, train_labels, num_classes, is_training=True)
    val_ds = create_tpu_dataset(val_imgs, val_labels, num_classes, is_training=False)
    test_ds = create_tpu_dataset(test_imgs, test_labels, num_classes, is_training=False)

    return train_ds, val_ds, test_ds, num_classes, len(train_imgs), len(val_imgs)

def decode_img(file_path):
    """Load one image file as a float32 RGB tensor scaled to [0, 1].

    The file is read, decoded to three channels (animated images are not
    expanded into frames), resized to ``IMG_SIZE`` x ``IMG_SIZE`` and divided
    by 255.
    """
    raw_bytes = tf.io.read_file(file_path)
    pixels = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    resized = tf.image.resize(pixels, [IMG_SIZE, IMG_SIZE])
    return tf.cast(resized, tf.float32) / 255.0

def create_tpu_dataset(image_paths, labels, num_classes, is_training=True):
    """Create a batched ``tf.data`` pipeline from file paths and integer labels.

    Pipeline order: decode -> cache -> (shuffle -> augment, training only)
    -> MobileNetV2 preprocessing and one-hot labels -> batch -> prefetch.

    The cache sits directly after decoding, so only the deterministic work is
    cached. Shuffling and the random augmentations come after it and are
    therefore re-drawn every epoch instead of being frozen at whatever the
    first pass produced.

    Args:
        image_paths: list of image file paths.
        labels: list of integer class indices, aligned with ``image_paths``.
        num_classes: depth of the one-hot label vectors.
        is_training: when True, shuffle and apply random flip, brightness and
            contrast augmentation.

    Returns:
        A dataset of ``(image_batch, one_hot_label_batch)`` pairs with
        ``BATCH_SIZE`` items per batch; incomplete final batches are dropped.
    """
    # Convert Python lists to a dataset of (path, label) pairs
    paths_ds = tf.data.Dataset.from_tensor_slices(image_paths)
    labels_ds = tf.data.Dataset.from_tensor_slices(labels)
    dataset = tf.data.Dataset.zip((paths_ds, labels_ds))

    def load_example(file_path, label):
        """Decode the image; the label passes through untouched."""
        return decode_img(file_path), label

    def augment(img, label):
        """Random training-time perturbations on the [0, 1] image."""
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, 0.2)
        img = tf.image.random_contrast(img, 0.8, 1.2)
        return img, label

    def finalize(img, label):
        """Apply MobileNetV2 input scaling and one-hot encode the label."""
        # preprocess_input is fed 0-255 pixel values, so undo the [0, 1] scaling.
        img = tf.keras.applications.mobilenet_v2.preprocess_input(img * 255.0)
        label = tf.one_hot(label, depth=num_classes)
        return img, label

    # Decode once and keep the decoded images for later epochs
    dataset = dataset.map(load_example, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.cache()

    if is_training:
        # After the cache: a new order and new augmentations on every epoch
        dataset = dataset.shuffle(buffer_size=min(10000, len(image_paths)))
        dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    dataset = dataset.map(finalize, num_parallel_calls=tf.data.AUTOTUNE)

    # Important for TPU: fixed batch shapes, hence drop_remainder=True
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

    # Prefetch for better performance
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

# ---- 8. Prepare Datasets ----
# Builds the datasets and derives the per-epoch step counts. On failure, a
# recursive glob reports what image files exist, and the run is then aborted.
try:
    print("\nPreparing datasets...")
    prepared = prepare_datasets()
    train_ds, val_ds, test_ds, num_classes, train_size, val_size = prepared

    # Whole batches per split, never fewer than one step.
    steps_per_epoch = max(1, train_size // BATCH_SIZE)
    validation_steps = max(1, val_size // BATCH_SIZE)

    print("Dataset prepared successfully:")
    print(f"Number of classes: {num_classes}")
    print(f"Steps per epoch: {steps_per_epoch}")
    print(f"Validation steps: {validation_steps}")

except Exception as e:
    print(f"Error preparing datasets: {e}")
    # Diagnostic pass: list image files anywhere below the dataset root.
    try:
        print("\nAttempting to find any images in the dataset...")
        all_image_paths = []
        for ext in ('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG'):
            pattern = os.path.join(DATASET_PATH, "**", f"*.{ext}")
            matches = glob.glob(pattern, recursive=True)
            all_image_paths += matches
            print(f"Found {len(matches)} .{ext} files")

        if len(all_image_paths) == 0:
            raise Exception("No image files found in the dataset")

        print(f"Total images found: {len(all_image_paths)}")
        print("Sample paths:")
        for path in all_image_paths[:5]:
            print(f"  {path}")

        # Images exist but the layout could not be interpreted; stop here.
        raise Exception("Dataset structure not compatible with automatic detection. Please check paths.")
    except Exception as e2:
        print(f"Final error: {e2}")
        raise

# ---- 9. Model Creation ----
def create_model():
    """Create and compile the MobileNetV2 fruit classifier.

    The network is an ImageNet-pretrained MobileNetV2 backbone (alpha=0.75,
    no top, ``IMG_SIZE`` x ``IMG_SIZE`` x 3 input) with all layers frozen,
    followed by global average pooling, a 128-unit ReLU layer, dropout (0.4)
    and a float32 softmax over the module-level ``num_classes``.

    The optimizer is given a plain float learning rate. The training callbacks
    include ``ReduceLROnPlateau``, which works by assigning a new value to the
    optimizer's learning rate; that assignment fails when the optimizer was
    built with a ``LearningRateSchedule``, so learning-rate decay is left to
    the callback.

    Returns:
        ``(model, base_model)``: the compiled model and its backbone, the
        latter so callers can unfreeze layers for fine-tuning.
    """
    # Use smaller input size and alpha parameter for faster inference
    base_model = MobileNetV2(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        alpha=0.75  # Smaller network (75% of filters)
    )

    # Freeze base model for initial training
    base_model.trainable = False

    # Efficient Model Head
    x = base_model.output
    x = GlobalAveragePooling2D(name="gap")(x)
    x = Dense(128, activation="relu", name="dense_1")(x)
    x = Dropout(0.4, name="dropout_1")(x)
    # Keep the softmax output in float32 even under a mixed-precision policy
    output_layer = Dense(num_classes, activation="softmax", dtype='float32', name="output")(x)

    model = Model(inputs=base_model.input, outputs=output_layer)

    # Settable (non-schedule) learning rate so ReduceLROnPlateau can lower it
    initial_learning_rate = 0.001

    # Compilation
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=initial_learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(k=3, name="top3_acc")]
    )

    return model, base_model

# Build and compile the model under the distribution strategy so that its
# variables are created in the strategy's scope.
with strategy.scope():
    model, base_model = create_model()

# Model summary
print("\nModel Architecture Summary:")
model.summary()

# ---- 10. Training Callbacks ----
# Keep only the checkpoint with the highest validation accuracy.
checkpoint_cb = ModelCheckpoint(
    filepath=CHECKPOINT_PATH,
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
)

# Stop after 5 epochs without val_loss improvement and restore the best weights.
early_stopping_cb = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Multiply the learning rate by 0.2 after 2 stagnant epochs, down to 1e-6.
reduce_lr_cb = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=2,
    min_lr=1e-6,
)

callbacks = [checkpoint_cb, early_stopping_cb, reduce_lr_cb]

# ---- 11. Initial TPU Compatibility Test ----
# Runs a single training/validation step as a smoke test. If that fails, a
# simpler model is tried on the same strategy; if that fails too, everything
# is rebuilt for the default strategy with a smaller batch size.
print("\nRunning a minimal test to check hardware compatibility...")
try:
    # Take just one batch and run for one epoch as a test
    test_train_ds = train_ds.take(1).repeat(1)
    test_val_ds = val_ds.take(1).repeat(1)

    test_history = model.fit(
        test_train_ds,
        epochs=1,
        steps_per_epoch=1,
        validation_data=test_val_ds,
        validation_steps=1
    )

    print("Hardware compatibility test successful!")
except Exception as e:
    print(f"Hardware test failed: {e}")
    print("Trying alternate configuration...")

    # Try re-initializing with different settings
    try:
        if 'tpu' in locals():
            tf.tpu.experimental.initialize_tpu_system(tpu)

        # Recreate model with simpler configuration
        with strategy.scope():
            # Rebind base_model as well: the fine-tuning phase unfreezes layers
            # through this name and must act on the network actually in use.
            base_model = tf.keras.applications.MobileNetV2(
                input_shape=(IMG_SIZE, IMG_SIZE, 3),
                include_top=False,
                weights='imagenet',
                pooling='avg'
            )
            model = tf.keras.Sequential([
                base_model,
                tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32')
            ])

            model.compile(
                optimizer='adam',
                loss='categorical_crossentropy',
                metrics=['accuracy']
            )

            # Try test again
            test_train_ds = train_ds.take(1).repeat(1)
            test_history = model.fit(
                test_train_ds,
                epochs=1,
                steps_per_epoch=1
            )

            print("Alternate model configuration successful!")
    except Exception as e2:
        print(f"Alternate configuration also failed: {e2}")
        print("Falling back to CPU training with smaller batches...")

        # Reduce batch size for CPU training. This is module-level code, so
        # BATCH_SIZE is rebound directly; a `global` statement here would be a
        # SyntaxError because the name is already assigned earlier in the file.
        original_batch_size = BATCH_SIZE
        BATCH_SIZE = 32
        print(f"Reduced batch size from {original_batch_size} to {BATCH_SIZE}")

        # Recreate datasets with smaller batch size and bring the step counts
        # in line with it, so an epoch still covers the whole split.
        train_ds, val_ds, test_ds, _, train_size, val_size = prepare_datasets()
        steps_per_epoch = max(1, train_size // BATCH_SIZE)
        validation_steps = max(1, val_size // BATCH_SIZE)

        # Recreate model
        strategy = tf.distribute.get_strategy()
        with strategy.scope():
            model, base_model = create_model()

# ---- 12. Training Phase 1 ----
# Trains with the current model for up to 10 epochs using the shared
# callbacks. If that raises, one shorter run without callbacks is attempted
# before giving up.
print("\nStarting initial training phase (base model frozen)...")
try:
    full_run = dict(
        epochs=10,
        validation_data=val_ds,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
    )
    history = model.fit(train_ds, **full_run)

    print("Initial training phase completed successfully!")
except Exception as e:
    print(f"Error during initial training: {e}")
    try:
        print("Attempting simplified training...")
        # Fewer epochs, no callbacks, and at most 10 training / 5 validation steps.
        reduced_run = dict(
            epochs=5,
            validation_data=val_ds,
            callbacks=[],
            steps_per_epoch=min(steps_per_epoch, 10),
            validation_steps=min(validation_steps, 5),
        )
        history = model.fit(train_ds, **reduced_run)
    except Exception as e2:
        print(f"Simplified training also failed: {e2}")
        raise Exception("Training failed. Please check hardware and dataset.")

# ---- 13. Fine-tuning Phase ----
# Unfreezes the last 12 backbone layers (except BatchNormalization),
# recompiles with a small learning rate and trains for 5 more epochs.
# Any failure is reported and the phase is skipped.
print("\nStarting fine-tuning phase (unfreeze top layers)...")
try:
    # Unfreeze the base model (partially)
    with strategy.scope():
        # Unfreeze the last block of the MobileNetV2 model
        for layer in base_model.layers[-12:]:
            # BatchNormalization layers are left as they are: making them
            # trainable would let them update their moving statistics during
            # fine-tuning, which disturbs the pretrained features.
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                continue
            layer.trainable = True

        # Count trainable parameters
        trainable_count = sum(tf.keras.backend.count_params(w) for w in model.trainable_weights)
        non_trainable_count = sum(tf.keras.backend.count_params(w) for w in model.non_trainable_weights)
        print(f"Trainable parameters: {trainable_count:,}")
        print(f"Non-trainable parameters: {non_trainable_count:,}")

        # Recompile so the changed trainable flags take effect, with a much
        # smaller learning rate for fine-tuning
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1e-5),
            loss="categorical_crossentropy",
            metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(k=3, name="top3_acc")]
        )

    # Fine-tune
    history_finetune = model.fit(
        train_ds,
        epochs=5,
        validation_data=val_ds,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps
    )

    print("Fine-tuning phase completed successfully!")
except Exception as e:
    print(f"Error during fine-tuning: {e}")
    print("Skipping fine-tuning phase.")
    history_finetune = None

# ---- 14. Evaluation ----
# Evaluates on the test split and prints loss, accuracy and, when a third
# value is returned, top-3 accuracy. Failures are reported, not raised.
print("\nEvaluating model on test dataset...")
try:
    test_results = model.evaluate(test_ds)
    print("Test loss: {:.4f}".format(test_results[0]))
    print("Test accuracy: {:.4f}".format(test_results[1]))
    has_top3 = len(test_results) > 2
    if has_top3:
        print("Test top-3 accuracy: {:.4f}".format(test_results[2]))
except Exception as e:
    print(f"Error during evaluation: {e}")

# ---- 15. Save Model ----
# Writes three artifacts: the Keras model, an optimized TFLite conversion next
# to it, and class_indices.json. Any failure is reported, not raised.
print("\nSaving model...")
try:
    model.save(MODEL_PATH)
    print(f"Saved Keras model to {MODEL_PATH}")

    # Convert to TensorFlow Lite with the default optimizations
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    # The TFLite file is placed in the same directory as the Keras model
    output_dir = os.path.dirname(MODEL_PATH)
    tflite_path = os.path.join(output_dir, 'model.tflite')
    with open(tflite_path, 'wb') as tflite_file:
        tflite_file.write(tflite_model)
    print(f"Saved TFLite model to {tflite_path}")

    # Export the class-name -> index mapping when one is visible in this
    # scope; otherwise an empty mapping is written.
    class_indices = dict(class_to_idx.items()) if 'class_to_idx' in locals() else {}

    import json
    with open('class_indices.json', 'w') as indices_file:
        json.dump(class_indices, indices_file)
    print("Saved class indices to class_indices.json")

except Exception as e:
    print(f"Error saving model: {e}")