In [None]:
from flask import Flask, render_template, request, redirect, url_for
import io
import sys
import warnings
from functools import lru_cache
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from PIL import Image

# Function to validate and upload images
def upload_image(file_obj):
    """Open an uploaded file as a PIL image, or return None if it is not one.

    Args:
        file_obj: A file-like object (or path) accepted by ``Image.open``.
            Falsy values (``None``, empty upload) are treated as "no file".

    Returns:
        The fully decoded ``PIL.Image.Image``, or ``None`` when the input is
        missing, cannot be identified as an image, or fails to decode.
    """
    if not file_obj:
        return None

    try:
        image = Image.open(file_obj)
        # Image.open only parses the header; load() forces the pixel data to
        # be decoded now so truncated/corrupt files are rejected here rather
        # than blowing up later, outside this error handling.
        image.load()
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError is an OSError subclass; oversized "bomb"
        # images are treated as invalid uploads as well.
        return None
    return image

# Function - Centred square crop
def crop(image):
    """Return the centred square region of ``image``.

    The first two entries of ``image.shape`` are read as height and width.
    The square's side is the smaller of the two, and the surplus along the
    longer dimension is trimmed evenly from both ends.
    """
    dims = image.shape
    rows, cols = dims[0], dims[1]
    side = min(rows, cols)

    # Only the longer dimension gets a non-zero offset (half of the surplus).
    top = max(rows - cols, 0) // 2
    left = max(cols - rows, 0) // 2
    return tf.image.crop_to_bounding_box(image, top, left, side, side)

# Function - Load image
@lru_cache(maxsize=32)
def load_image(content):
    """Decode encoded image bytes into a 256x256 RGB float32 tensor.

    Results are memoised on the raw bytes. The cache is bounded so that a
    long-running process does not keep every image it has ever seen alive.

    Args:
        content: Encoded image data (hashable, e.g. ``bytes``) accepted by
            ``tf.image.decode_image``.

    Returns:
        A tensor holding the centre-cropped image resized to 256x256.
    """
    # expand_animations=False makes animated GIFs decode to their first frame
    # as a 3-D tensor; otherwise they come back 4-D and crop() would misread
    # (frames, height) as (height, width).
    image = tf.image.decode_image(
        content, channels=3, dtype=tf.float32, expand_animations=False)
    # Centred square crop
    image = crop(image)
    # Resize
    return tf.image.resize(image, (256, 256))

# Function - Plot images
def image_plot(images, title, grid):
    """Show ``images`` in a single matplotlib figure laid out on a grid.

    Args:
        images: Iterable of image-like objects accepted by ``plt.imshow``.
        title: Figure-level title.
        grid: ``(rows, columns)`` of the subplot grid; images fill it in
            order, starting at subplot 1.
    """
    plt.figure()
    # Subplot indices are 1-based, hence start=1.
    for position, picture in enumerate(images, start=1):
        plt.subplot(grid[0], grid[1], position)
        plt.imshow(picture)
        plt.axis('off')
    plt.suptitle(title)
    plt.show()

# TF Hub handle of the Magenta arbitrary-image-stylization model (v1-256, revision 2).
model_url = "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
# Loaded once at module level so the request handler reuses the same model object.
style_transfer_model = hub.load(model_url)

# Flask app setup
app = Flask(__name__)

def _to_model_input(pil_image):
    """Turn a PIL image into a (1, 256, 256, 3) float32 batch scaled to [0, 1]."""
    # convert("RGB") drops alpha and expands greyscale/palette images, so the
    # tensor always has exactly three channels.
    pixels = np.asarray(pil_image.convert("RGB"), dtype=np.float32) / 255.0
    square = crop(tf.convert_to_tensor(pixels))
    resized = tf.image.resize(square, (256, 256))
    # The hub model works on batches; add the leading batch dimension.
    return resized[tf.newaxis, ...]


@app.route("/", methods=["GET", "POST"])
def upload_and_style():
    """Serve the upload form (GET) or run style transfer on the upload (POST).

    POST expects one ``content_image`` file and one or more ``style_image``
    files. The content image is stylised once per style image and the results
    are written side by side into ``static/styled_image.jpg``.

    Returns:
        The rendered upload form for GET; a redirect to the result page after
        a successful POST; a ``(message, 400)`` pair when an upload is
        missing or is not a decodable image.
    """
    if request.method == "GET":
        return render_template("index.html")  # Render upload form

    # Validate and preprocess the content image. Decoding can still fail at
    # conversion time because PIL reads pixel data lazily.
    content_image = upload_image(request.files.get("content_image"))
    if content_image is None:
        return "Invalid content image!", 400
    try:
        content = _to_model_input(content_image)
    except OSError:
        return "Invalid content image!", 400

    # Handle multiple style uploads; at least one is required.
    style_files = request.files.getlist("style_image")
    if not style_files:
        return "Invalid style image!", 400

    styles = []
    for style_file in style_files:
        style_image = upload_image(style_file)
        if style_image is None:
            return "Invalid style image!", 400
        try:
            styles.append(_to_model_input(style_image))
        except OSError:
            return "Invalid style image!", 400

    # Run the model once per style; each call yields a (1, H, W, 3) batch as
    # its first output.
    stylized = [style_transfer_model(content, style)[0] for style in styles]
    # Join the per-style results along the width axis so a single file can
    # show every requested style.
    panel = tf.concat(stylized, axis=2)

    # Save styled image (modified for Flask)
    filename = "static/styled_image.jpg"
    tf.keras.utils.save_img(filename, tf.squeeze(panel, axis=0))

    return redirect(url_for("show_result"))

@app.route("/result")
def show_result():
    """Render the result page, passing it the file name of the styled image."""
    template_args = {"generated_image": "styled_image.jpg"}
    return render_template("result.html", **template_args)

# Start Flask's built-in development server only when this file is executed
# directly, not when it is imported.
# NOTE(review): use_reloader=False presumably exists because this runs inside
# a notebook kernel, where the reloader's process restart does not work —
# confirm. debug=True is for local development only.
if __name__ == "__main__":
    app.run(debug=True, use_reloader=False)















 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
INFO:werkzeug:Press CTRL+C to quit
INFO:werkzeug:127.0.0.1 - - [12/May/2024 12:48:27] "GET / HTTP/1.1" 200 -
