# 🧪 Test TFLite Model with Your Own Images

This notebook lets you test the converted TFLite model with your own nail images.

## Prerequisites:
1. Run `convert_to_tflite.ipynb` first to generate the TFLite model
2. Have the `.tflite` model file ready
3. Prepare your test images

## Step 1: Install & Import Required Packages

In [None]:
!pip install tensorflow pillow numpy opencv-python matplotlib supervision

In [None]:
import tensorflow as tf
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import cv2
import time
import os

print(f"TensorFlow version: {tf.__version__}")
print(f"NumPy version: {np.__version__}")

## Step 2: Load TFLite Model

In [None]:
# Path to your TFLite model
TFLITE_MODEL_PATH = "./tflite_models/rfdetr_nails_float16.tflite"

# Check if model exists; stop here instead of failing later inside the interpreter
if not os.path.exists(TFLITE_MODEL_PATH):
    raise FileNotFoundError(
        f"❌ Model not found at: {TFLITE_MODEL_PATH}\n"
        "Please run convert_to_tflite.ipynb first to generate the model."
    )

print(f"✅ Model found: {TFLITE_MODEL_PATH}")
model_size = os.path.getsize(TFLITE_MODEL_PATH) / (1024 * 1024)
print(f"📊 Model size: {model_size:.2f} MB")

# Load TFLite model
print("\n📦 Loading TFLite model...")
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_PATH)
interpreter.allocate_tensors()

# Get input and output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("✅ Model loaded successfully!")
print(f"\nInput shape: {input_details[0]['shape']} (dtype: {input_details[0]['dtype']})")
# Detection/segmentation exports can have several outputs, so list them all
for i, detail in enumerate(output_details):
    print(f"Output {i} ({detail['name']}): shape {detail['shape']}")
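If loading fails with an unhelpful error, it can be worth confirming the file really is a TFLite flatbuffer and not, say, a truncated download or an ONNX file. FlatBuffers store a 4-byte root offset followed by an optional 4-byte file identifier, and the TFLite schema uses the identifier `TFL3`. A minimal standard-library sketch; `has_tflite_identifier` and `looks_like_tflite` are hypothetical helpers, not part of TensorFlow:

```python
import os


def has_tflite_identifier(header):
    """True if the first 8 bytes carry the TFLite file identifier at offset 4."""
    return len(header) >= 8 and header[4:8] == b"TFL3"


def looks_like_tflite(path):
    """Cheap sanity check on a model file before handing it to the interpreter."""
    if not os.path.isfile(path):
        return False
    with open(path, "rb") as f:
        return has_tflite_identifier(f.read(8))
```

For example, `looks_like_tflite(TFLITE_MODEL_PATH)` should return `True` for a correctly converted model. It only checks the header, so a file that passes can still be corrupt further in.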

## Step 3: Define Preprocessing Function

In [None]:
def preprocess_image(image_path, target_size=None):
    """
    Preprocess image for the TFLite model

    Args:
        image_path: Path to input image
        target_size: Target size (width, height); if None, it is read from
            the model's input tensor shape

    Returns:
        preprocessed_image: float32 NumPy array in the model's input layout
        original_image: Original PIL Image
        original_size: Original image size (width, height)
    """
    # Infer layout from the model input: [1, 3, H, W] (NCHW) or [1, H, W, 3] (NHWC)
    input_shape = input_details[0]['shape']
    channels_first = input_shape[1] == 3
    if target_size is None:
        height, width = (input_shape[2], input_shape[3]) if channels_first else (input_shape[1], input_shape[2])
        target_size = (int(width), int(height))

    # Load image
    original_image = Image.open(image_path).convert('RGB')
    original_size = original_image.size

    # Resize (PIL expects (width, height))
    resized_image = original_image.resize(target_size, Image.BILINEAR)

    # Convert to float32 and normalize to [0, 1]
    img_array = np.asarray(resized_image, dtype=np.float32) / 255.0

    # Transpose HWC -> CHW only if the model expects channels first
    if channels_first:
        img_array = np.transpose(img_array, (2, 0, 1))

    # Add batch dimension
    img_array = np.expand_dims(img_array, axis=0)

    return img_array, original_image, original_size

print("✅ Preprocessing function defined")
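The function above only scales pixels to [0, 1]. RF-DETR's PyTorch `predict()` additionally normalizes with the ImageNet mean and standard deviation, so unless your export baked that step into the graph, TFLite outputs may drift from PyTorch. A minimal NumPy sketch, assuming the standard ImageNet statistics in RGB order and an input already scaled to [0, 1]; `normalize_imagenet` is a hypothetical helper:

```python
import numpy as np

# Standard ImageNet statistics (RGB order); assumed to match the training setup
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_imagenet(img, channels_first=True):
    """Apply (x - mean) / std per channel.

    img: float32 array in [0, 1], shaped (N, C, H, W) if channels_first,
         otherwise (N, H, W, C).
    """
    if channels_first:
        mean = IMAGENET_MEAN.reshape(1, 3, 1, 1)
        std = IMAGENET_STD.reshape(1, 3, 1, 1)
    else:
        mean = IMAGENET_MEAN.reshape(1, 1, 1, 3)
        std = IMAGENET_STD.reshape(1, 1, 1, 3)
    return ((img - mean) / std).astype(np.float32)
```

If you try it, apply it to the array returned by `preprocess_image()` before `set_tensor`, and compare the outputs with and without it against the PyTorch model to see which one your export expects.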

## Step 4: Define Inference Function

In [None]:
def run_inference(image_path, threshold=0.5):
    """
    Run TFLite inference on an image

    Args:
        image_path: Path to input image
        threshold: Confidence threshold (unused until post-processing is implemented)

    Returns:
        Dictionary with the raw output(s), the original image and size,
        and the inference time in ms
    """
    # Preprocess
    print(f"\n📸 Processing: {image_path}")
    input_data, original_image, original_size = preprocess_image(image_path)

    # Run inference (timing covers set_tensor + invoke only, not preprocessing)
    print("⏱️ Running inference...")
    start_time = time.perf_counter()

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    inference_time = (time.perf_counter() - start_time) * 1000  # Convert to ms

    # Get outputs (the model may have more than one, e.g. boxes, logits, masks)
    outputs = [interpreter.get_tensor(d['index']) for d in output_details]

    print(f"✅ Inference completed in {inference_time:.2f} ms")
    for i, out in enumerate(outputs):
        print(f"   Output {i} shape: {out.shape}")

    # TODO: Post-process outputs to extract detections
    # This depends on your RF-DETR model's output format
    # For now, return the raw outputs

    return {
        'output': outputs[0],
        'outputs': outputs,
        'original_image': original_image,
        'original_size': original_size,
        'inference_time': inference_time
    }

print("✅ Inference function defined")
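The TODO above is where raw tensors become detections. RF-DETR follows the DETR family: each query yields class logits (passed through a sigmoid) and a box predicted as normalized (cx, cy, w, h). A minimal NumPy sketch of that decoding, assuming you have already identified a boxes output of shape (1, Q, 4) and a logits output of shape (1, Q, C); which TFLite output is which is an assumption to verify against the printed shapes, and `decode_detections` is hypothetical. It keeps the best class per query, whereas RF-DETR's own PyTorch post-processing ranks a top-k over all query/class pairs, so results can differ slightly.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def decode_detections(boxes, logits, image_size, threshold=0.5):
    """Turn raw DETR-style outputs into pixel-space detections.

    boxes:  (1, Q, 4) normalized (cx, cy, w, h)
    logits: (1, Q, C) raw class logits
    image_size: (width, height) of the original image
    Returns xyxy boxes in pixels, scores, and class ids above `threshold`.
    """
    probs = sigmoid(logits[0])            # (Q, C) per-class probabilities
    class_ids = probs.argmax(axis=1)      # best class per query
    scores = probs.max(axis=1)
    keep = scores >= threshold

    cx, cy, w, h = boxes[0][keep].T       # each has shape (K,)
    img_w, img_h = image_size
    xyxy = np.stack([
        (cx - w / 2) * img_w,
        (cy - h / 2) * img_h,
        (cx + w / 2) * img_w,
        (cy + h / 2) * img_h,
    ], axis=1)
    return xyxy, scores[keep], class_ids[keep]
```

Usage would look like `decode_detections(boxes_out, logits_out, result['original_size'], threshold=0.5)`, where `boxes_out` and `logits_out` are the (assumed) matching entries of the interpreter's outputs.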

## Step 5: Test with Your Image

### 📝 Instructions:
1. Put your test image in a known location
2. Update the `IMAGE_PATH` variable below
3. Run the cell to see results

In [None]:
# ⚠️ UPDATE THIS PATH TO YOUR IMAGE
IMAGE_PATH = "/home/usama-naveed/nail_AR-rfdeter/usama_nails1.jpeg"

# Check if image exists
if not os.path.exists(IMAGE_PATH):
    print(f"❌ Image not found: {IMAGE_PATH}")
    print("\nPlease update IMAGE_PATH to point to your test image.")
else:
    print(f"✅ Image found: {IMAGE_PATH}")
    
    # Run inference
    result = run_inference(IMAGE_PATH, threshold=0.5)
    
    # Display results side by side
    plt.figure(figsize=(12, 6))
    
    plt.subplot(1, 2, 1)
    plt.imshow(result['original_image'])
    plt.title(f"Original Image\n{result['original_size'][0]}×{result['original_size'][1]}")
    plt.axis('off')
    
    # Placeholder panel: no detections are drawn until post-processing is implemented
    plt.subplot(1, 2, 2)
    plt.imshow(result['original_image'])
    plt.title(f"Inference Result (no overlay yet)\nTime: {result['inference_time']:.2f} ms")
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n📊 Results:")
    print(f"  Image size: {result['original_size']}")
    print(f"  Inference time: {result['inference_time']:.2f} ms")
    print(f"  FPS (inference only): {1000/result['inference_time']:.1f}")
    print(f"  Output shape: {result['output'].shape}")
    print(f"  Output dtype: {result['output'].dtype}")
## Step 6: Compare with PyTorch Model (Optional)

Run inference with both models and compare their inference times.

In [None]:
# Load PyTorch model for comparison (requires the `rfdetr` and `torch` packages)
try:
    from rfdetr import RFDETRSegPreview
    import torch
    
    print("📦 Loading PyTorch model...")
    pytorch_model = RFDETRSegPreview(
        pretrain_weights="/home/usama-naveed/nail_AR-rfdeter/output/checkpoint_best_total.pth"
    )
    # The rfdetr wrapper is not a torch.nn.Module: it handles device placement
    # and inference mode itself, so there are no .eval()/.cuda() calls here
    pytorch_model.optimize_for_inference()
    
    # Only reports whether CUDA is available; rfdetr chooses the device itself
    device = "CUDA" if torch.cuda.is_available() else "CPU"
    print(f"✅ PyTorch model loaded (device: {device})")
    
    # Run PyTorch inference (timing includes predict()'s pre- and post-processing)
    print("\n⏱️ Running PyTorch inference...")
    image = Image.open(IMAGE_PATH).convert('RGB')
    
    start = time.perf_counter()
    pytorch_result = pytorch_model.predict(image, threshold=0.5)
    pytorch_time = (time.perf_counter() - start) * 1000
    
    print(f"✅ PyTorch inference: {pytorch_time:.2f} ms")
    
    # Compare (single-run timings, so treat as a rough indication)
    speedup = pytorch_time / result['inference_time']
    print(f"\n📊 Performance Comparison:")
    print(f"  TFLite (Float16): {result['inference_time']:.2f} ms")
    print(f"  PyTorch (Float32): {pytorch_time:.2f} ms")
    print(f"  Speed ratio: {speedup:.2f}x ({'faster' if speedup > 1 else 'slower'} with TFLite)")
    
except Exception as e:
    print(f"⚠️ PyTorch comparison failed: {e}")
    print("Skipping comparison...")
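The timings above each come from a single call, so they include first-call overhead (lazy initialization, memory allocation), and the PyTorch number also covers `predict()`'s pre- and post-processing. A fairer comparison times several runs after a warm-up and reports the median. A minimal standard-library sketch; `benchmark` is a hypothetical helper and the default run counts are arbitrary:

```python
import statistics
import time


def benchmark(fn, warmup=3, runs=20):
    """Call `fn()` `warmup` times untimed, then `runs` times timed.

    Returns timing statistics in milliseconds.
    """
    for _ in range(warmup):
        fn()
    times_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1000)
    return {
        "median_ms": statistics.median(times_ms),
        "min_ms": min(times_ms),
        "max_ms": max(times_ms),
        "runs": runs,
    }
```

For example, after `input_data, _, _ = preprocess_image(IMAGE_PATH)` and `interpreter.set_tensor(input_details[0]['index'], input_data)`, call `benchmark(interpreter.invoke)` for TFLite and `benchmark(lambda: pytorch_model.predict(image, threshold=0.5))` for PyTorch. Note that the two still measure different amounts of work unless you also time TFLite pre- and post-processing.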

## Step 7: Batch Testing (Multiple Images)

In [None]:
# Test with multiple images
import glob

# ⚠️ UPDATE THIS PATH TO YOUR IMAGE FOLDER
IMAGE_FOLDER = "/home/usama-naveed/nail_AR-rfdeter/"
IMAGE_PATTERN = "*.jpeg"  # Change to *.jpg, *.png, etc.

# Find all images (sorted so repeated runs process the same files)
image_files = sorted(glob.glob(os.path.join(IMAGE_FOLDER, IMAGE_PATTERN)))

if len(image_files) == 0:
    print(f"❌ No images found matching: {os.path.join(IMAGE_FOLDER, IMAGE_PATTERN)}")
else:
    print(f"✅ Found {len(image_files)} images\n")
    
    results = []
    
    for img_path in image_files[:5]:  # Test first 5 images
        try:
            # run_inference() prints the path itself; use a separate variable
            # so `result` from Step 5 (used in Steps 8 and 9) is not overwritten
            batch_result = run_inference(img_path, threshold=0.5)
            results.append({
                'path': img_path,
                'time': batch_result['inference_time'],
                'size': batch_result['original_size']
            })
        except Exception as e:
            print(f"  ❌ Error processing {os.path.basename(img_path)}: {e}")
    
    # Show summary
    if results:
        times = [r['time'] for r in results]
        print(f"\n📊 Batch Testing Summary:")
        print(f"  Images processed: {len(results)}")
        print(f"  Average time: {np.mean(times):.2f} ms")
        print(f"  Min time: {np.min(times):.2f} ms")
        print(f"  Max time: {np.max(times):.2f} ms")
        print(f"  Average FPS: {1000/np.mean(times):.1f}")
## Step 8: Visualize Output (If Applicable)

This section depends on your model's output format; the cell below only inspects the raw tensors.

In [None]:
# Examine raw output structure (all output tensors from the most recent invoke())
print("🔍 Examining model output structure:\n")
for detail in output_details:
    out = interpreter.get_tensor(detail['index'])
    print(f"Output '{detail['name']}':")
    print(f"  shape: {out.shape}, dtype: {out.dtype}")
    print(f"  range: [{out.min():.4f}, {out.max():.4f}]")

    # 4-D is (batch, channels, height, width) for an NCHW feature map;
    # for instance masks it is more likely (batch, queries, mask_h, mask_w)
    if out.ndim == 4:
        batch, channels, height, width = out.shape
        print(f"\n  Interpretation:")
        print(f"    Batch size: {batch}")
        print(f"    Channels/queries: {channels}")
        print(f"    Height: {height}")
        print(f"    Width: {width}")
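RF-DETR segmentation exports can expose several output tensors (for example boxes, class logits and masks), and nothing guarantees their order. A heuristic sketch that labels outputs by rank and last dimension: 4-D as masks, 3-D with a last dimension of 4 as boxes, other 3-D as logits. This is an assumption about the export, not a guarantee, so confirm it against the tensor names in `output_details`; `guess_output_roles` is hypothetical:

```python
def guess_output_roles(shapes):
    """Map output index -> guessed role from tensor shapes.

    Heuristic (verify against your export): (1, Q, 4) -> 'boxes',
    other (1, Q, C) -> 'logits', (1, Q, H, W) -> 'masks'.
    Ambiguous if the model happens to have exactly 4 classes.
    """
    roles = {}
    for i, shape in enumerate(shapes):
        shape = tuple(int(d) for d in shape)
        if len(shape) == 4:
            roles[i] = "masks"
        elif len(shape) == 3 and shape[-1] == 4:
            roles[i] = "boxes"
        elif len(shape) == 3:
            roles[i] = "logits"
        else:
            roles[i] = "unknown"
    return roles
```

Call it as `guess_output_roles([d['shape'] for d in output_details])` and use the result to pick the tensors for post-processing.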

## Step 9: Save Test Results

In [None]:
# Save test results to file
import json
from datetime import datetime

test_results = {
    'timestamp': datetime.now().isoformat(),
    'model_path': TFLITE_MODEL_PATH,
    'model_size_mb': os.path.getsize(TFLITE_MODEL_PATH) / (1024 * 1024),
    'test_image': IMAGE_PATH,
    'inference_time_ms': result['inference_time'],
    'fps': 1000 / result['inference_time'],
    'input_shape': input_details[0]['shape'].tolist(),
    'output_shape': list(result['output'].shape),
    'output_shapes': [d['shape'].tolist() for d in output_details]
}

results_file = "tflite_test_results.json"
with open(results_file, 'w') as f:
    json.dump(test_results, f, indent=2)

print(f"✅ Test results saved to: {results_file}")
print("\n📄 Results:")
print(json.dumps(test_results, indent=2))
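If you extend `test_results` with NumPy values (for instance `np.float32` scores, `np.int64` counts, or arrays), `json.dump` raises `TypeError`, because only Python built-ins are serializable by default. The `default=` hook of `json.dump` lets you convert them. A minimal sketch; `numpy_default` is a hypothetical helper:

```python
import json

import numpy as np


def numpy_default(obj):
    """`default=` hook for json.dump / json.dumps that converts NumPy types."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
```

Usage: `json.dump(test_results, f, indent=2, default=numpy_default)`. Anything the hook does not recognize still raises `TypeError`, so unexpected types are not silently dropped.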

## ‚úÖ Summary

You've successfully tested the TFLite model!

### Next Steps:
1. ✅ Model is working and producing outputs
2. 📊 Note the inference time and FPS
3. 🔧 Integrate into your backend (`model_rf_deter_tflite.py`)
4. 🚀 Deploy to production

### Key Metrics to Monitor:
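- **Inference Time:** Target 2-3x faster than PyTorch; verify this with Step 6 rather than assuming it, since single-run timings are noisy
- **Model Size:** ~50% smaller, because float16 stores each weight in 2 bytes instead of 4
- **Accuracy:** Target < 1% degradation; this notebook does not measure accuracy yet, so compare detections against the PyTorch model once post-processing is in place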
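Once both models produce boxes, a quick agreement check is to match each PyTorch box to its best-overlapping TFLite box by IoU and average those values; numbers close to 1.0 suggest the conversion preserved behaviour. A NumPy sketch, assuming both inputs are (N, 4) arrays of xyxy pixel boxes (for PyTorch, `pytorch_result.xyxy` if, as assumed here, `predict()` returns a `supervision.Detections`). Greedy best-IoU matching is a simplification rather than one-to-one Hungarian matching, and `pairwise_iou` and `mean_best_iou` are hypothetical helpers:

```python
import numpy as np


def pairwise_iou(a, b):
    """IoU matrix between (N, 4) and (M, 4) xyxy boxes."""
    x1 = np.maximum(a[:, None, 0], b[None, :, 0])
    y1 = np.maximum(a[:, None, 1], b[None, :, 1])
    x2 = np.minimum(a[:, None, 2], b[None, :, 2])
    y2 = np.minimum(a[:, None, 3], b[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    union = area_a[:, None] + area_b[None, :] - inter
    return inter / np.maximum(union, 1e-9)


def mean_best_iou(reference, candidate):
    """Average, over reference boxes, of the best IoU with any candidate box.

    Returns 0.0 if either set is empty.
    """
    if len(reference) == 0 or len(candidate) == 0:
        return 0.0
    return float(pairwise_iou(reference, candidate).max(axis=1).mean())
```

Because each reference box takes its best match independently, two reference boxes can match the same candidate; check the detection counts as well as the mean IoU.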

### Need Help?
- Check `tflite_test_results.json` for detailed metrics
- Compare with PyTorch model using Step 6
- Adjust the threshold in Step 5 once post-processing uses it