In [1]:
import gradio as gr
import numpy as np
from PIL import Image
from typing import Tuple
import io
import tensorflow as tf
from sklearn.preprocessing import LabelEncoder

# Import the ArtAnalyzer class
from artclass_class import ArtAnalyzer

# Seed TensorFlow's and NumPy's global random number generators with the
# same fixed value (42).
tf.random.set_seed(42)
np.random.seed(42)



def init_analyzer():
    """Construct the ``ArtAnalyzer`` used by the application.

    The known style names are run through a ``LabelEncoder`` and the
    encoder's ``classes_`` array (the unique labels as stored by the
    encoder) is what gets passed to ``ArtAnalyzer`` as ``style_labels``,
    together with the path of the saved Keras model.

    Returns:
        The constructed ``ArtAnalyzer`` instance.

    Raises:
        Exception: Any error raised while constructing ``ArtAnalyzer``; the
            message is printed and the exception is re-raised unchanged.
    """
    known_styles = [
        "naive_art",
        "baroque",
        "rococo",
        "romanticism",
        "art_deco",
        "american_realism",
        "art_nouveau",
        "expressionism",
        "modernism",
        "post_impressionism",
        "high_renaissance",
        "cubism",
        "abstract_expressionism",
        "art_informel",
        "mannerism",
        "northern_renaissance",
        "surrealism",
        "symbolism",
        "early_renaissance",
        "neo_romantic",
        "ukiyo_e",
        "impressionism",
        "pop_art",
        "fauvism",
        "neoclassicism",
        "minimalism",
    ]

    # fit() returns the encoder itself, so classes_ can be read off the chain.
    # NOTE(review): presumably this ordering must match the label order used
    # when the model was trained -- confirm against the training code.
    encoded_classes = LabelEncoder().fit(known_styles).classes_

    try:
        analyzer_instance = ArtAnalyzer(
            cnn_model_path="art_classifier_curated.keras",
            style_labels=encoded_classes,
        )
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise
    return analyzer_instance

# Build the analyzer once when the module is loaded; process_and_analyze
# reads this module-level instance on every request.
analyzer = init_analyzer()

def preprocess_image(image, target_size=256):
    """Turn an input image into a square, normalized float32 RGB array.

    The image is converted to RGB, scaled (aspect ratio preserved) so that
    it fits inside a ``target_size`` x ``target_size`` square, centered on a
    canvas filled with the scaled image's mean color, and finally divided by
    255 so values lie in ``[0, 1]``.

    Args:
        image: Either a filesystem path (``str``) or a ``numpy.ndarray``
            accepted by ``PIL.Image.fromarray``. Arrays are taken to be in
            RGB channel order already, which is what Gradio's
            ``type="numpy"`` image component delivers.
        target_size: Edge length in pixels of the square output.
            Defaults to 256.

    Returns:
        A ``(array, error)`` tuple. On success ``array`` has shape
        ``(target_size, target_size, 3)`` and dtype ``float32`` and
        ``error`` is ``None``. On failure ``array`` is ``None`` and
        ``error`` is a human-readable message. No exception escapes.
    """
    try:
        if isinstance(image, str):
            # Convert inside the context manager so the pixel data is loaded
            # and the underlying file handle is released before moving on.
            with Image.open(image) as opened:
                img = opened.convert('RGB')
        elif isinstance(image, np.ndarray):
            # No channel reversal here: the array is already RGB, and flipping
            # the last axis would swap red/blue (or mirror a 2-D grayscale
            # array / scramble RGBA) before the model ever sees the image.
            img = Image.fromarray(image)
        else:
            return None, "Unsupported image format"

        # Grayscale, palette, RGBA, ... inputs all end up as 3-channel RGB.
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Scale so the longer side equals target_size. Clamp to at least one
        # pixel: for extreme aspect ratios int() would otherwise produce 0
        # and make resize() fail.
        w, h = img.size
        ratio = min(target_size / w, target_size / h)
        new_w = max(1, int(w * ratio))
        new_h = max(1, int(h * ratio))
        resized = img.resize((new_w, new_h), Image.LANCZOS)

        # Pad with the mean color of the resized image rather than black.
        resized_array = np.array(resized)
        mean_color = tuple(map(int, np.mean(resized_array, axis=(0, 1))))
        new_image = Image.new('RGB', (target_size, target_size), mean_color)

        # Center the resized image on the padded canvas.
        x_offset = (target_size - new_w) // 2
        y_offset = (target_size - new_h) // 2
        new_image.paste(resized, (x_offset, y_offset))

        # Convert to numpy and normalize to [0, 1].
        final_array = np.array(new_image, dtype=np.float32) / 255.0

        return final_array, None

    except Exception as e:
        # Boundary handler: callers expect an error message, not an exception.
        return None, f"Error processing image: {str(e)}"

def process_and_analyze(image: np.ndarray, question: str) -> Tuple[str, str]:
    """Classify an uploaded artwork and answer a question about it.

    The image is run through ``preprocess_image`` and then through the
    module-level ``analyzer``. The five highest-scoring styles are formatted
    as percentage lines, and ``analyzer.answer_question`` produces the text
    answer.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        question: The user's question, forwarded to the analyzer.

    Returns:
        A ``(style_text, answer_text)`` pair. On any failure both elements
        carry error messages instead; no exception is propagated.
    """
    if image is None:
        return "Error: No image provided", "Please upload an image"

    try:
        pixels, problem = preprocess_image(image)
        if problem:
            return f"Error: {problem}", "Unable to process image"

        scores = analyzer.analyze_image(pixels)

        # Highest score first; only the top five are shown.
        ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
        top_lines = [f"{name}: {score*100:.2f}%\n" for name, score in ranked[:5]]
        styles_text = "Top Predicted Styles:\n" + "".join(top_lines)

        answer = analyzer.answer_question(question, scores)
        # When the response contains a "Question:" marker, keep only the
        # text that follows the first "Answer:" marker.
        if "Question:" in answer:
            answer = answer.split("Answer:", 1)[-1].strip()
    except Exception as e:
        print(f"Error in process_and_analyze: {str(e)}")
        return f"Error: {str(e)}", "An unexpected error occurred"

    return styles_text, answer

# Gradio user interface: a title, one row of inputs (image + question), one
# row of outputs (style predictions + answer), and a button that runs
# process_and_analyze on the inputs.
with gr.Blocks() as demo:
    gr.Markdown("# Art Style Analyzer")

    # Input row: the artwork and the user's question.
    with gr.Row():
        # type="numpy" makes Gradio hand the upload to the callback as a
        # numpy array.
        image_input = gr.Image(
            type="numpy",
            label="Upload artwork (supported formats: jpg, png, webp)"
        )
        question_input = gr.Textbox(
            label="Ask a question about the artwork",
            placeholder="What techniques are used in this style?"
        )

    # Output row: receives the two strings returned by process_and_analyze.
    with gr.Row():
        style_output = gr.Textbox(label="Style Predictions")
        answer_output = gr.Textbox(label="Answer")

    analyze_btn = gr.Button("Analyze Artwork")

    # The order of `inputs` matches the callback's parameters (image,
    # question); the order of `outputs` matches its returned tuple.
    analyze_btn.click(
        fn=process_and_analyze,
        inputs=[image_input, question_input],
        outputs=[style_output, answer_output]
    )

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()

Model summary:


: 