<a href="https://colab.research.google.com/github/tantowijh/imagerestoration/blob/main/GoogleColabAPI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title __1. Install all the Requirements__
#@markdown ⬅ Run this cell to install all the requirements.
!pip install diffusers transformers accelerate scipy safetensors &> /dev/null

!pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git &> /dev/null
!pip install py-real-esrgan &> /dev/null

!pip install pyngrok &> /dev/null

In [None]:
#@title __2. Import Libraries and Load Models__
#@markdown ⬅ Run this cell to import all the libraries and load the models.
import requests
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForInpainting
import torch
from RealESRGAN import RealESRGAN

# Use the GPU when one is available; both models below follow this choice so
# the notebook does not crash on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the Real-ESRGAN 4x upscaler. `download=True` presumably fetches the
# weights file when it is not already present at this path.
enhance_pipe = RealESRGAN(device, scale=4)
enhance_pipe.load_weights('weights/RealESRGAN_x4.pth', download=True)

# Load the SDXL inpainting model. Half precision is only used on the GPU;
# float16 inference is poorly supported on CPU, so fall back to float32 there.
inpaint_model = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
inpaint_dtype = torch.float16 if device.type == 'cuda' else torch.float32
inpaint_pipe = AutoPipelineForInpainting.from_pretrained(inpaint_model, torch_dtype=inpaint_dtype, variant="fp16")
inpaint_pipe = inpaint_pipe.to(device)

In [3]:
#@title __3. Initialize the Flask App__
#@markdown ⬅ Run this cell to initialize the Flask app, then run the following steps.

from flask import Flask, request, send_file, jsonify

# The Flask application that the routes below register on.
app = Flask(__name__)


@app.route('/', methods=['GET'])
def home():
  """Answer GET / with a fixed plain-text greeting (handy to check the server is reachable)."""
  greeting = 'Hello, World!'
  return greeting

@app.route('/enhance', methods=['POST'])
def enhance():
  """Upscale an uploaded image with Real-ESRGAN.

  Expects a multipart form with an ``image`` file field; the image is
  converted to RGB before upscaling.

  Returns:
    The upscaled image as a PNG attachment (``enhanced_image.png``); a JSON
    ``{'error': ...}`` body with status 400 when the ``image`` field is missing
    or cannot be decoded, or with status 500 if upscaling fails.
  """
  try:
    # Client mistakes are reported as 400 so callers can tell them apart
    # from genuine server failures.
    try:
      image_data = request.files['image'].read()
      image = Image.open(BytesIO(image_data)).convert('RGB')
    except KeyError:
      # Flask raises a KeyError subclass when the file field is absent.
      return jsonify({'error': "Missing required file field: 'image'"}), 400
    except OSError as e:
      # Includes PIL's UnidentifiedImageError (an OSError subclass) and
      # truncated-image errors: the upload is not a readable image.
      return jsonify({'error': f'Invalid image: {e}'}), 400

    # Perform the upscaling using Real-ESRGAN
    sr_image = enhance_pipe.predict(image)

    # Encode the result as PNG in memory and stream it back.
    buffered = BytesIO()
    sr_image.save(buffered, format="PNG")
    buffered.seek(0)

    return send_file(buffered, mimetype='image/png', as_attachment=True, download_name='enhanced_image.png')

  except Exception as e:
    return jsonify({'error': str(e)}), 500

@app.route('/restore', methods=['POST'])
def restore():
  """Inpaint the masked region of an uploaded image with the SDXL inpainting pipeline.

  Multipart form fields:
    image (file, required): image to restore; converted to RGB.
    mask (file, required): inpainting mask; converted to grayscale ("L").
      NOTE(review): per the diffusers docs white pixels are repainted and
      black pixels kept — confirm against the installed version.
    prompt (text, required): text prompt for the pipeline.
    guidance_scale (float, default 8.0), num_inference_steps (int >= 1,
      default 20), strength (float in (0, 1], default 0.99), seed (int,
      default 0): optional pipeline settings.

  Returns:
    The restored image as a PNG attachment (``restored_image.png``), resized
    back to the uploaded image's size; a JSON ``{'error': ...}`` body with
    status 400 for missing or invalid input, or with status 500 if inference
    fails.
  """
  try:
    # Parse and validate everything up front so client mistakes (missing
    # fields, non-numeric or out-of-range settings, undecodable uploads) get
    # a 400 instead of being reported as a server error.
    try:
      image_data = request.files['image'].read()
      mask_data = request.files['mask'].read()
      prompt = request.form['prompt']

      guidance_scale = float(request.form.get('guidance_scale', 8.0))
      num_inference_steps = int(request.form.get('num_inference_steps', 20))
      strength = float(request.form.get('strength', 0.99))
      seed = int(request.form.get('seed', 0))
      if num_inference_steps < 1:
        raise ValueError('num_inference_steps must be at least 1')
      if not 0.0 < strength <= 1.0:
        raise ValueError('strength must be greater than 0 and at most 1')

      image = Image.open(BytesIO(image_data)).convert("RGB")
      mask = Image.open(BytesIO(mask_data)).convert("L")  # Convert mask to grayscale
    except KeyError as e:
      # Flask raises a KeyError subclass for absent form/file fields.
      field = e.args[0] if e.args else 'unknown'
      return jsonify({'error': f'Missing required field: {field}'}), 400
    except (ValueError, OSError) as e:
      # ValueError: bad numeric settings. OSError (incl. PIL's
      # UnidentifiedImageError): an upload is not a readable image.
      return jsonify({'error': f'Invalid request: {e}'}), 400

    # Seed the generator on the pipeline's own device rather than a
    # hard-coded "cuda", so it always matches where the model lives.
    generator = torch.Generator(device=inpaint_pipe.device).manual_seed(seed)

    # Perform the inpainting using the Stable Diffusion model
    restored_image = inpaint_pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        strength=strength,
        generator=generator,
    ).images[0]

    # The pipeline may return a different resolution; match the upload's size.
    restored_image = restored_image.resize(image.size, Image.LANCZOS)

    # Encode the result as PNG in memory and stream it back.
    buffered = BytesIO()
    restored_image.save(buffered, format="PNG")
    buffered.seek(0)

    return send_file(buffered, mimetype='image/png', as_attachment=True, download_name='restored_image.png')

  except Exception as e:
    return jsonify({'error': str(e)}), 500


In [None]:
#@title __4. Serve the Flask App on a Public URL__
#@markdown ⬅ Run this cell to start Flask, then copy the printed public URL into the configuration page.

#@markdown __Choose a Service for the Public URL__

#@markdown We recommend using **ngrok** to obtain your public URL, as **Serveo** is not always available (it may be down).


from pyngrok import ngrok
import subprocess
import time
import re

# Tunnel ports: Serveo forwards remote_port to the Flask server on local_port.
local_port = 5000
remote_port = 80
service = "ngrok" # @param ["ngrok","serveo"]

# Terminate ngrok processes managed by pyngrok (e.g. left over from a previous
# run of this cell) before opening a new tunnel.
ngrok.kill()


def run_app():
  """Start the Flask development server on all interfaces at ``local_port``; blocks until it stops."""
  server_options = {'host': '0.0.0.0', 'port': local_port}
  app.run(**server_options)

#@markdown ---

#@markdown __Obtain Your ngrok Token__
#@markdown
#@markdown Please obtain your ngrok auth token from the official ngrok website at
#@markdown [ngrok authtoken](https://dashboard.ngrok.com/get-started/your-authtoken).
#@markdown > **Note:** Entering your ngrok token is mandatory only if you choose the ngrok service. If you select Serveo, you can skip this step.

#@markdown Please put the token you received below:
if service == "ngrok":
  authtoken = "" #@param {type:"string"}
  # Pasted tokens often carry stray whitespace.
  ngrok_token = authtoken.strip()
  if ngrok_token:
    ngrok.set_auth_token(ngrok_token)
  else:
    # Don't overwrite a previously saved token with an empty one; presumably
    # connect() below reports an authentication error if none is configured.
    print("No ngrok authtoken entered; using any token already saved in the ngrok config.\n")

  ngrok_run = ngrok.connect(local_port).public_url
  print(f"Ngrok public URL: {ngrok_run}\n")
  run_app()

else:
  import queue
  import threading

  # Reverse-forward serveo.net's remote_port to the local Flask port.
  command = ["ssh", "-o", "StrictHostKeyChecking=no", "-R", f"{remote_port}:localhost:{local_port}", "serveo.net"]

  # Merge stderr into stdout so one reader drains both pipes and ssh never
  # blocks on a full, unread pipe.
  process = subprocess.Popen(command, stdout=subprocess.STDOUT if False else subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

  # A live tunnel never exits, so waiting for the process (communicate())
  # would always time out. Instead, stream its output on a daemon thread; the
  # thread keeps draining the pipe while run_app() blocks below.
  output_lines = queue.Queue()

  def _pump_output():
    """Forward each line of ssh output to the queue; None marks EOF."""
    for out_line in process.stdout:
      output_lines.put(out_line)
    output_lines.put(None)

  threading.Thread(target=_pump_output, daemon=True).start()

  # Wait up to 10 s (the original 5 s sleep + 5 s timeout) for a URL line.
  serveo_url = None
  deadline = time.monotonic() + 10
  while serveo_url is None:
    remaining = deadline - time.monotonic()
    if remaining <= 0:
      break
    try:
      out_line = output_lines.get(timeout=remaining)
    except queue.Empty:
      break
    if out_line is None:
      break  # ssh exited without printing a URL (e.g. service down)
    # Stop at whitespace or ESC so a trailing ANSI colour code is not captured.
    url_match = re.search(r'(https?://[^\s\x1b]+)', out_line)
    if url_match:
      serveo_url = url_match.group(0)

  if serveo_url:
    print(f'Serveo public URL: {serveo_url}\n')
    run_app()
  else:
    # Don't leave a half-open ssh tunnel running in the background.
    process.terminate()
    print('No URL! The service may be down.')