In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image
import ipywidgets as widgets
from IPython.display import display
from PIL import Image
import io

# Saved Keras model produced by training; the .h5 suffix selects the
# HDF5 loader in tf.keras.
MODEL_PATH = "brain_tumor_detection_model.h5"
model = tf.keras.models.load_model(MODEL_PATH)

# Human-readable labels, looked up by the index of the model's highest score.
# NOTE(review): this order must match the class order used at training time
# (alphabetical if the data was loaded from class-named folders) — confirm.
categories = [
    "glioma",
    "meningioma",
    "notumor",
    "pituitary",
]

# Single-file uploader whose file picker is filtered to image types
upload_widget = widgets.FileUpload(multiple=False, accept="image/*")

def _first_upload_content(value):
    """Return the raw bytes of the first file held in a FileUpload ``value``.

    Supports both value layouts used by ipywidgets:

    * ipywidgets 8: a tuple of dicts, each with a ``'content'`` memoryview;
    * ipywidgets 7: a dict keyed by filename whose values carry ``'content'``.

    Args:
        value: The widget's current ``value`` (must be non-empty).

    Returns:
        The file contents as ``bytes``.
    """
    entries = list(value.values()) if isinstance(value, dict) else list(value)
    # bytes() turns the ipywidgets 8 memoryview into a plain bytes object.
    return bytes(entries[0]["content"])


# Function to process the uploaded image and make a prediction
def on_upload_change(change):
    """Classify the image currently held by ``upload_widget``.

    Observer for the widget's ``value`` trait. Decodes the uploaded file,
    converts it to RGB, resizes it to 150x150, scales pixels to [0, 1],
    runs ``model.predict`` and reports the arg-max label from ``categories``
    both as a plot title and on stdout.

    Args:
        change: Trait-change notification from ``observe`` (unused; the
            widget's current value is read directly).
    """
    # Also fires when the widget is cleared; nothing to classify then.
    if not upload_widget.value:
        return

    image_data = _first_upload_content(upload_widget.value)

    try:
        # Force 3-channel RGB: grayscale ("L"), palette ("P") and RGBA PNG
        # uploads would otherwise yield 2-D or 4-channel arrays, and a 2-D
        # array would be drawn by imshow with a false-colour colormap.
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        print(f"Could not read the uploaded file as an image: {err}")
        return

    # Resize the image to the required input size (150x150)
    img = img.resize((150, 150))

    # Scale pixels to [0, 1] and add the batch axis -> shape (1, 150, 150, 3)
    img_array = np.asarray(img, dtype=np.float32) / 255.0
    img_array = np.expand_dims(img_array, axis=0)

    # Make a prediction using the trained model (single-image batch)
    prediction = model.predict(img_array)
    predicted_class = categories[int(np.argmax(prediction[0]))]

    # Display the image and the prediction result
    plt.imshow(img)
    plt.title(f"Predicted: {predicted_class}")
    plt.axis('off')
    plt.show()

    # Print the predicted category in the console
    print(f"Predicted Category: {predicted_class}")

# Output printed or plotted inside a widget callback is not reliably shown
# in the notebook cell (JupyterLab sends it to the log console), so the
# handler's text, figures and any traceback are captured in an Output widget.
prediction_output = widgets.Output()


def _on_upload_change_captured(change):
    """Run ``on_upload_change`` with its output routed into ``prediction_output``."""
    with prediction_output:
        on_upload_change(change)


# Attach the handler to the upload widget's ``value`` trait
upload_widget.observe(_on_upload_change_captured, names='value')

# Display the upload widget together with the area where results appear
display(upload_widget, prediction_output)




FileUpload(value=(), accept='image/*', description='Upload')