<a href="https://colab.research.google.com/github/tantowijh/imagerestoration/blob/image-enhancement/GoogleColabAPI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title __Install all the Requirements__
!pip install diffusers transformers accelerate scipy safetensors &> /dev/null

!pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git &> /dev/null
!pip install py-real-esrgan &> /dev/null

!pip install pyngrok &> /dev/null

In [None]:
#@title __Import Libraries and Load Models__
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForInpainting
import torch
from RealESRGAN import RealESRGAN

# Use the GPU when present; both models below are placed on this same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the Real-ESRGAN model (x4 upscaler).
# download=True lets the loader fetch the weights file if it is not already on disk.
enhance_pipe = RealESRGAN(device, scale=4)
enhance_pipe.load_weights('weights/RealESRGAN_x4.pth', download=True)

# Load Inpainting Model (SDXL inpainting).
# float16 is only practical on CUDA; on CPU many half-precision ops are unsupported,
# so fall back to float32 there. The fp16 weight variant is still downloaded (smaller)
# and cast to the requested dtype on load.
inpaint_dtype = torch.float16 if device.type == 'cuda' else torch.float32
inpaint_model = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
inpaint_pipe = AutoPipelineForInpainting.from_pretrained(inpaint_model, torch_dtype=inpaint_dtype, variant="fp16")
inpaint_pipe = inpaint_pipe.to(device)

In [None]:
#@title __Run the Flask App__
#@markdown The server is exposed through a serveo.net SSH tunnel; its public URL is printed below and may change between runs, so update _.env_ with the new URL if you restart this cell.

from flask import Flask, request, send_file, jsonify
import numpy as np

app = Flask(__name__)


@app.get('/')
def home():
  """Root endpoint: answers `GET /` with a fixed plain-text greeting.

  Handy as a quick check that the server (and the tunnel in front of it) responds.
  """
  greeting = 'Hello, World!'
  return greeting

@app.route('/enhance', methods=['POST'])
def enhance():
  """Upscale an uploaded image with the Real-ESRGAN model (`enhance_pipe`).

  Expects multipart form data with a file field `image`; it is converted to RGB
  before upscaling.

  Returns:
    The upscaled image as a PNG attachment named `enhanced_image.png`;
    a JSON `{'error': ...}` with status 400 if `image` is missing or is not a
    readable image; a JSON error with status 500 if upscaling itself fails.
  """
  # Parse the request first so client mistakes map to 400 instead of 500.
  try:
    image_data = request.files['image'].read()
    image = Image.open(BytesIO(image_data)).convert('RGB')
  except KeyError as e:
    # Werkzeug raises a KeyError subclass for a missing form/file field.
    missing = e.args[0] if e.args else e
    return jsonify({'error': f'missing required field: {missing}'}), 400
  except OSError as e:
    # PIL.UnidentifiedImageError (an OSError) for uploads that are not images,
    # including empty uploads.
    return jsonify({'error': f'invalid image: {e}'}), 400
  except Exception as e:
    return jsonify({'error': str(e)}), 500

  try:
    # Perform the upscaling using Real-ESRGAN
    sr_image = enhance_pipe.predict(image)

    # Serialize to an in-memory PNG; rewind so send_file streams from the start.
    buffered = BytesIO()
    sr_image.save(buffered, format="PNG")
    buffered.seek(0)

    return send_file(buffered, mimetype='image/png', as_attachment=True, download_name='enhanced_image.png')

  except Exception as e:
    return jsonify({'error': str(e)}), 500

@app.route('/restore', methods=['POST'])
def restore():
  """Inpaint the masked region of an uploaded image with `inpaint_pipe`.

  Expects multipart form data:
    image (file, required): picture to restore; converted to RGB.
    mask (file, required): inpainting mask; converted to grayscale ("L").
    prompt (str, required): text prompt passed to the pipeline.
    guidance_scale (float, optional, default 8.0)
    num_inference_steps (int, optional, default 20): must be >= 1.
    strength (float, optional, default 0.99): must be in (0, 1].
    seed (int, optional, default 0): seeds the torch generator so results are
      reproducible for the same inputs.

  Returns:
    The restored image as a PNG attachment named `restored_image.png`, resized
    to the input image's size; a JSON `{'error': ...}` with status 400 for
    missing or invalid input; a JSON error with status 500 if inpainting fails.
  """
  # Parse and validate the request first so client mistakes map to 400, not 500.
  try:
    image_data = request.files['image'].read()
    mask_data = request.files['mask'].read()
    prompt = request.form['prompt']

    guidance_scale = float(request.form.get('guidance_scale', 8.0))
    num_inference_steps = int(request.form.get('num_inference_steps', 20))
    strength = float(request.form.get('strength', 0.99))
    seed = int(request.form.get('seed', 0))
    if num_inference_steps < 1:
      raise ValueError('num_inference_steps must be >= 1')
    if not 0.0 < strength <= 1.0:
      raise ValueError('strength must be in (0, 1]')

    image = Image.open(BytesIO(image_data)).convert("RGB")
    mask = Image.open(BytesIO(mask_data)).convert("L")  # Convert mask to grayscale
  except KeyError as e:
    # Werkzeug raises a KeyError subclass for a missing form/file field.
    missing = e.args[0] if e.args else e
    return jsonify({'error': f'missing required field: {missing}'}), 400
  except (ValueError, OSError) as e:
    # ValueError: non-numeric or out-of-range parameter.
    # OSError: covers PIL.UnidentifiedImageError for uploads that are not images.
    return jsonify({'error': f'invalid input: {e}'}), 400
  except Exception as e:
    return jsonify({'error': str(e)}), 500

  try:
    # Create the generator on the device chosen in the model-loading cell
    # rather than assuming CUDA.
    generator = torch.Generator(device=device).manual_seed(seed)

    # Perform the inpainting using the Stable Diffusion model
    restored_image = inpaint_pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        strength=strength,
        generator=generator,
    ).images[0]

    # The pipeline's output size may differ from the input; map it back.
    restored_image = restored_image.resize(image.size, Image.LANCZOS)

    # Serialize to an in-memory PNG; rewind so send_file streams from the start.
    buffered = BytesIO()
    restored_image.save(buffered, format="PNG")
    buffered.seek(0)

    return send_file(buffered, mimetype='image/png', as_attachment=True, download_name='restored_image.png')

  except Exception as e:
    return jsonify({'error': str(e)}), 500


def run_app(host='127.0.0.1', port=5000):
  """Run the Flask development server for `app`; blocks until the server stops.

  Args:
    host: interface to bind. Defaults to localhost, which is what the serveo.net
      SSH tunnel in this cell forwards to.
    port: port to listen on. Defaults to 5000, the port the tunnel command
      forwards to; keep the two in sync if either changes.
  """
  app.run(host=host, port=port)

# Start the Flask app in a separate thread
import threading
import os

flask_thread = threading.Thread(target=run_app)
flask_thread.start()

!ssh -o StrictHostKeyChecking=no -R 80:localhost:5000 serveo.net

