In [19]:
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
import tensorflow as tf
import os
import gc

# Force TensorFlow to use CPU to avoid Metal plugin issues on M1
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
tf.config.set_visible_devices([], 'GPU')  # Disable GPU
gc.collect()  # Clear memory

# Load the model
try:
    model_path = '/Users/tkarim45/Documents/Personal Github Repositories/CureWise-AI-Medical-Healthcare/backend/data/kidney_disease/kidney_ct_model_1.h5'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model = load_model(model_path, compile=False)
    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

def predict(image_path):
    """
    Predict the class of the given image using the loaded model.
    
    :param image_path: Path to the image file to be classified.
    :return: Dictionary with predicted class and confidence.
    """
    try:
        # Check if image file exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found at {image_path}")

        class_labels = ['Cyst', 'Normal', 'Stone']
        
        # Load and preprocess the image
        image = Image.open(image_path).convert('RGB')
        image = image.resize((28, 28))
        image = np.array(image, dtype=np.float32) / 255.0
        image = np.expand_dims(image, axis=0)  # Add batch dimension

        # Make prediction
        predictions = model.predict(image, verbose=0)
        predicted_class = np.argmax(predictions[0])
        predicted_label = class_labels[predicted_class]
        confidence = float(np.max(predictions[0]))

        return {
            "predicted_class": predicted_label,
            "confidence": confidence
        }
    except Exception as e:
        print(f"Error in prediction: {e}")
        return {"error": str(e)}

# Test prediction
try:
    image_path = '/Users/tkarim45/Downloads/Normal/Normal- (1009).jpg'
    result = predict(image_path)
    print("Prediction result:", result)
except Exception as e:
    print(f"Error during prediction: {e}")


Model loaded successfully
Prediction result: {'predicted_class': 'Normal', 'confidence': 7.963675022125244}
