In [2]:
!pip install gradio

Collecting gradio
  Downloading gradio-5.9.1-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.6-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.5.2 (from gradio)
  Downloading gradio_client-1.5.2-py3-none-any.whl.metadata (7.1 kB)
Collecting httpx>=0.24.1 (from gradio)
  Downloading httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting orjson~=3.0 (from gradio)
  Downloading orjson-3.10.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.8/41.8 kB[0m [31m2.2 MB/s[0m e

In [3]:
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
import cv2

In [4]:
# Define the custom ClassToken layer for Vision Transformer (ViT)
class ClassToken(tf.keras.layers.Layer):
    """Learnable class token used by the ViT model.

    The layer owns one trainable vector of shape ``(1, 1, hidden_dim)``, where
    ``hidden_dim`` is the last dimension of the input seen at build time. When
    called, it returns that vector broadcast to ``(batch_size, 1, hidden_dim)``
    and cast to the input's dtype. It does not combine the token with
    ``inputs``; ``inputs`` only supplies the batch size and dtype. (Presumably
    the surrounding model concatenates it with the patch embeddings — not
    visible here.)
    """

    def __init__(self, **kwargs):
        # Forward standard Layer kwargs (name, trainable, dtype, ...). Keras
        # passes these back in when it rebuilds the layer from a saved config
        # (load_model(..., custom_objects={'ClassToken': ClassToken})).
        # An __init__ that rejects them cannot be deserialized.
        super().__init__(**kwargs)

    def build(self, input_shape):
        # The token's width matches the last (feature) axis of the input.
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value=w_init(shape=(1, 1, input_shape[-1]), dtype=tf.float32),
            trainable=True
        )

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        hidden_dim = self.w.shape[-1]
        # One copy of the token per example in the batch.
        cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
        # The variable is float32; match whatever dtype the caller feeds in.
        cls = tf.cast(cls, dtype=inputs.dtype)
        return cls

In [5]:
# Load models (used for inference only).
# NOTE(review): these paths live on Google Drive and assume it is already mounted
# at /content/drive (e.g. via google.colab.drive.mount); no mount call is visible here.
MODEL_DIR = "/content/drive/MyDrive/galaxy_type_classification_with_ViT_and_CNN"

vit_model = load_model(
    f"{MODEL_DIR}/Vit_Model.h5",
    # ClassToken is defined in this notebook, so Keras needs it supplied
    # explicitly to rebuild the ViT graph.
    custom_objects={'ClassToken': ClassToken},
    # Only predict() is used; skip restoring optimizer/loss/metric state.
    compile=False,
)
cnn_model = load_model(f"{MODEL_DIR}/CNN_Model.h5", compile=False)

In [8]:
import tensorflow

In [9]:
# TensorFlow version this notebook was run with (tf is the same module as tensorflow).
tf.__version__

'2.15.0'

In [10]:
# Output labels, indexed by model output position: predict() maps the argmax
# index directly into this list, so the order must match the label order used
# at training time. The list is alphabetical — presumably the sorted class-folder
# order a directory-based loader assigns; verify against the training code.
class_names = """
    Barred_Spiral_Galaxies
    Cigar_Shaped_Smooth_Galaxies
    Disturbed_Galaxies
    Edge_On_Galaxies_With_Bulge
    Edge_On_Galaxies_Without_Bulge
    In_Between_Round_Smooth_Galaxies
    Merging_Galaxies
    Round_Smooth_Galaxies
    Unbarred_Loose_Spiral_Galaxies
    Unbarred_Tight_Spiral_Galaxies
""".split()

In [11]:
# Preprocessing function for CNN
def preprocess_for_cnn(img):
    """Turn one image array into a CNN input batch of size 1.

    The image is resized to 224x224 with cv2's default interpolation and its
    pixel values are divided by 255. A leading batch axis is then added, so an
    (H, W, C) input comes back as (1, 224, 224, C).
    """
    resized = cv2.resize(img, (224, 224))
    scaled = resized / 255.0
    return scaled[np.newaxis, ...]

In [12]:
# Preprocessing function for ViT
def preprocess_for_vit(img, image_size=200, patch_size=25):
    """Turn one image array into the flattened-patch batch the ViT model takes.

    The image is resized to ``image_size`` x ``image_size`` and divided by 255.
    It is then cut into non-overlapping ``patch_size`` x ``patch_size`` tiles.
    Tiles are ordered row by row (left to right within a row, rows top to
    bottom), and each tile is flattened in (row, column, channel) order. This is
    the same ordering as the original nested-loop implementation.

    Args:
        img: image array with a trailing channel axis, shape (H, W, C).
        image_size: side length to resize to (default 200).
        patch_size: side length of each square patch (default 25).

    Returns:
        Array of shape (1, (image_size // patch_size) ** 2, patch_size**2 * C).
        With the defaults and a 3-channel image this is (1, 64, 1875).

    Raises:
        ValueError: if ``image_size`` is not a multiple of ``patch_size``. Edge
            patches would then be smaller than the rest and could not form a
            regular array.
    """
    if image_size % patch_size != 0:
        raise ValueError(
            f"image_size ({image_size}) must be a multiple of patch_size ({patch_size})"
        )
    scaled = cv2.resize(img, (image_size, image_size)) / 255.0
    grid = image_size // patch_size
    channels = scaled.shape[-1]
    # (H, W, C) -> (grid_row, patch_row, grid_col, patch_col, C)
    #           -> (grid_row, grid_col, patch_row, patch_col, C)
    #           -> (num_patches, patch_row * patch_col * C)
    patches = (
        scaled.reshape(grid, patch_size, grid, patch_size, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(grid * grid, patch_size * patch_size * channels)
    )
    # Add the batch axis.
    return patches[np.newaxis, ...]

In [13]:
# Prediction function
def predict(image, model_choice):
    """Classify one uploaded galaxy image with the selected model.

    Args:
        image: the upload from the ``gr.Image(type="pil")`` input (a PIL image),
            or None when nothing has been uploaded.
        model_choice: "CNN" or "ViT"; any other value (including None) is rejected.

    Returns:
        On success, the predicted class name and that class's score as a
        percentage. Otherwise, a message describing the missing or invalid input.
    """
    # Validate inputs before any image work: np.array(None) would make
    # cv2.cvtColor raise instead of giving the user a readable message.
    if image is None:
        return "Please upload an image first."
    if model_choice == "CNN":
        preprocess, model = preprocess_for_cnn, cnn_model
    elif model_choice == "ViT":
        preprocess, model = preprocess_for_vit, vit_model
    else:
        return "Invalid model choice!"

    # NOTE(review): the RGB upload is converted to BGR channel order. This
    # presumably matches training on cv2.imread (BGR) images — confirm, since a
    # channel-order mismatch silently degrades accuracy.
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # verbose=0 suppresses Keras' per-call progress output on every request.
    predictions = model.predict(preprocess(img), verbose=0)
    predicted_class_idx = int(np.argmax(predictions[0]))
    # Reported as a percentage. Assumes the model ends in a softmax so this reads
    # as a probability — TODO confirm against the model definitions.
    confidence_score = predictions[0][predicted_class_idx] * 100

    predicted_class = class_names[predicted_class_idx]
    return f"Predicted Class: {predicted_class}\nConfidence Score: {confidence_score:.2f}%"

In [14]:
# Gradio Interface
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload a Galaxy Image"),
        # Pre-select a model. Otherwise a submit without touching the radio
        # button sends None and predict() answers "Invalid model choice!".
        gr.Radio(["CNN", "ViT"], value="CNN", label="Select Model"),
    ],
    outputs=gr.Textbox(label="Prediction and Confidence Score"),
    title="Galaxy Type Classification",
    description="Upload an image of a galaxy and select a model to classify its type. The prediction includes the galaxy type and the confidence score."
)

# Launch the interface. In Colab, Gradio enables share=True automatically
# (public *.gradio.live URL); pass share=False to launch() to turn that off.
interface.launch()

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://b176a5ffee21a57a65.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




<hr>