<a href="https://colab.research.google.com/github/simonsny/image_background_removal/blob/dev_simon/FlaskApp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Connecting Google Colab to your Google Drive


In [1]:
# Mount Google Drive at /content/drive so later cells can read and write the
# project folders under /content/drive/MyDrive/BeCode/Faktion.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Install flask-ngrok so the Flask app running inside Colab can be reached through a public ngrok URL.


In [2]:

!pip install flask-ngrok 


Collecting flask-ngrok
  Downloading https://files.pythonhosted.org/packages/af/6c/f54cb686ad1129e27d125d182f90f52b32f284e6c8df58c1bae54fa1adbc/flask_ngrok-0.0.25-py3-none-any.whl
Installing collected packages: flask-ngrok
Successfully installed flask-ngrok-0.0.25


Clone the MODNet repository and download its pretrained checkpoints (webcam/video matting and photographic/image matting).

In [3]:
import os

# clone the repository
%cd /content/drive/MyDrive/BeCode/Faktion/
if not os.path.exists('MODNet'):
  !git clone https://github.com/ZHKKKe/MODNet
%cd MODNet/

# dowload the pre-trained ckpt for video matting
pretrained_webcam_ckpt = 'pretrained/modnet_webcam_portrait_matting.ckpt'
if not os.path.exists(pretrained_webcam_ckpt):
  !gdown --id 1Nf1ZxeJZJL8Qx9KadcYYyEmmlKhTADxX \
          -O pretrained/modnet_webcam_portrait_matting.ckpt

# dowload the pre-trained ckpt for image matting
pretrained_image_ckpt = 'pretrained/modnet_photographic_portrait_matting.ckpt'
if not os.path.exists(pretrained_image_ckpt):
  !gdown --id 1mcr7ALciuAsHCpLnrtG_eop5-EYhbCmz \
          -O pretrained/modnet_photographic_portrait_matting.ckpt

/content/drive/MyDrive/BeCode/Faktion
/content/drive/MyDrive/BeCode/Faktion/MODNet


Define the image-matting code: predict an alpha matte with MODNet, then composite the foreground image.

In [7]:
import os
import sys
import numpy as np
from PIL import Image

import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from src.models.modnet import MODNet

# Root of the project folder on the mounted Google Drive.
project_root = '/content/drive/MyDrive/BeCode/Faktion'

# Flask static sub-folders: uploaded originals, predicted alpha mattes and
# composited results. The trailing slash matters: callers build file paths by
# plain string concatenation (dir + filename).
upload_dir = project_root + '/Flask/static/uploads/'
matte_dir = project_root + '/Flask/static/matte/'
output_dir = project_root + '/Flask/static/output/'

# MODNet checkpoints; the setup cell downloads both into MODNet/pretrained/.
pretrained_image_ckpt = project_root + '/MODNet/pretrained/modnet_photographic_portrait_matting.ckpt'
pretrained_webcam_ckpt = project_root + '/MODNet/pretrained/modnet_webcam_portrait_matting.ckpt'

# Loaded MODNet models keyed by checkpoint path, so the network is built and
# its weights are read from Drive once per kernel instead of on every request.
_modnet_cache = {}


def _load_modnet(ckpt_path):
    """Return an eval-mode MODNet on the GPU with ``ckpt_path`` loaded (cached per path)."""
    model = _modnet_cache.get(ckpt_path)
    if model is None:
        model = MODNet(backbone_pretrained=False)
        # Wrap in DataParallel *before* loading the state dict, as the original
        # inference code does (presumably the checkpoint was saved from a
        # DataParallel model, so its keys carry that wrapper's prefix - TODO confirm).
        model = nn.DataParallel(model).cuda()
        model.load_state_dict(torch.load(ckpt_path))
        model.eval()
        _modnet_cache[ckpt_path] = model
    return model


def image_ckpt(filename):
    """Predict a MODNet alpha matte for an uploaded image and save it as a PNG.

    Args:
        filename: Name of the image file inside ``upload_dir``.

    Returns:
        Full path (inside ``matte_dir``) of the saved 8-bit grayscale matte.
        The matte keeps the upload's base name with a ``.png`` extension.

    Raises:
        FileNotFoundError: if a configured folder, the checkpoint, or the
            uploaded image does not exist.
    """
    # Raise instead of calling exit(): exit() raises SystemExit, which is not an
    # ordinary Exception and would abort the web request rather than report a
    # readable error.
    required_paths = (
        ('input', upload_dir),
        ('matte', matte_dir),
        ('output', output_dir),
        ('ckpt', pretrained_image_ckpt),
    )
    for label, path in required_paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f'Cannot find {label} path: {path}')

    input_path = os.path.join(upload_dir, filename)
    if not os.path.isfile(input_path):
        raise FileNotFoundError(f'Cannot find uploaded image: {input_path}')
    print(f'Process image: {filename}')

    # Target length of the shorter side when the image has to be rescaled.
    ref_size = 512

    # Image -> tensor, mapping each channel from [0, 1] to [-1, 1].
    im_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # Read the image as 3-channel RGB. convert('RGB') replicates grayscale,
    # drops an alpha channel, and expands palette images (e.g. GIFs), whose raw
    # arrays hold palette indices rather than colours. The `with` closes the file.
    with Image.open(input_path) as img:
        im = im_transform(img.convert('RGB'))

    # Add the mini-batch dimension: (C, H, W) -> (1, C, H, W).
    im = im[None, :, :, :]

    # Choose the network input size. If the image is entirely smaller or
    # entirely larger than ref_size, scale it so the shorter side equals
    # ref_size; otherwise keep its size. Both sides are then rounded down to a
    # multiple of 32 (as in the upstream MODNet demo), but never below 32 so a
    # very thin image cannot produce a zero-sized tensor.
    _, _, im_h, im_w = im.shape
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh, im_rw = ref_size, int(im_w / im_h * ref_size)
        else:
            im_rw, im_rh = ref_size, int(im_h / im_w * ref_size)
    else:
        im_rh, im_rw = im_h, im_w
    im_rh = max(32, im_rh - im_rh % 32)
    im_rw = max(32, im_rw - im_rw % 32)
    im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

    # Run inference without autograd bookkeeping: no gradients are needed, and
    # tracking them would keep the forward graph alive in GPU memory.
    modnet = _load_modnet(pretrained_image_ckpt)
    with torch.no_grad():
        _, _, matte = modnet(im.cuda(), True)
        # Resize the matte back to the original image resolution.
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
    matte = matte[0][0].cpu().numpy()

    # Save as an 8-bit grayscale PNG named after the upload; splitext keeps
    # dots inside the base name ("a.b.jpg" -> "a.b.png").
    matte_name = os.path.splitext(filename)[0] + '.png'
    matte_path = os.path.join(matte_dir, matte_name)
    Image.fromarray((matte * 255).astype('uint8'), mode='L').save(matte_path)
    print(f"Successfully saved the alpha matte for {matte_name}")
    return matte_path

def foreground_image(image, matte, filename):
    """Composite ``image`` over a white background using ``matte`` and save it as a PNG.

    Args:
        image: Source image (PIL image or array-like). A 2-D or single-channel
            image is replicated to 3 channels; a 4th (alpha) channel is dropped.
        matte: Alpha matte (PIL image or 2-D array-like) with the same height and
            width as ``image``; its values are divided by 255, so 0-255 is expected.
        filename: Original file name; the result is saved in ``output_dir`` under
            the same base name with a ``.png`` extension.

    Returns:
        The saved file's name (not the full path).
    """
    print(f'Image name: {filename}')
    rgb = np.asarray(image)
    if rgb.ndim == 2:
        rgb = rgb[:, :, None]
    if rgb.shape[2] == 1:
        rgb = np.repeat(rgb, 3, axis=2)
    elif rgb.shape[2] == 4:
        rgb = rgb[:, :, :3]

    # (H, W) -> (H, W, 1) so the matte broadcasts across the colour channels.
    alpha = np.asarray(matte)[:, :, None] / 255
    # Keep the pixel where alpha is 1; fill with white (255) where alpha is 0.
    foreground = rgb * alpha + 255 * (1 - alpha)

    # splitext keeps dots inside the base name ("a.b.jpg" -> "a.b.png").
    foreground_name = os.path.splitext(filename)[0] + '.png'
    # Let PIL infer the mode ("RGB" for an HxWx3 uint8 array); a single-channel
    # mode such as "L" does not fit 3-channel data.
    Image.fromarray(foreground.astype('uint8')).save(os.path.join(output_dir, foreground_name))
    print(f"Successfully saved the foreground image {foreground_name}")
    return foreground_name

def foreground_img(image, alpha, filename):
    """Multiply an image by its alpha matte (black background) and save the result.

    Args:
        image: Path to the source image.
        alpha: Path to the alpha-matte image; must have the same height/width.
        filename: Name for the result file inside ``output_dir``.

    Returns:
        ``filename`` unchanged.

    Raises:
        OSError: if either input cannot be read or the result cannot be written.
        ValueError: if the matte and the image have different sizes.
    """
    # IMREAD_IGNORE_ORIENTATION keeps the raw pixel layout. By default OpenCV
    # rotates an image according to its EXIF orientation tag, while PIL's
    # Image.open (used to compute the matte) does not, so rotated phone photos
    # would otherwise not line up with their matte.
    img = cv2.imread(image, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    matte = cv2.imread(alpha)
    # cv2.imread reports failure by returning None instead of raising.
    if img is None:
        raise OSError(f'Could not read image: {image}')
    if matte is None:
        raise OSError(f'Could not read alpha matte: {alpha}')
    if img.shape != matte.shape:
        raise ValueError(
            f'Matte size {matte.shape[:2]} does not match image size {img.shape[:2]}'
        )

    # Work in float; normalise the matte from 0-255 to 0-1.
    img = img.astype(float)
    matte = matte.astype(float) / 255

    # Weight every pixel by its alpha: background (alpha 0) becomes black.
    foreground = cv2.multiply(matte, img)

    # Convert back to 8-bit explicitly (round + clip) rather than handing
    # float64 data to cv2.imwrite.
    foreground = np.clip(np.rint(foreground), 0, 255).astype(np.uint8)
    output_path = os.path.join(output_dir, filename)
    # cv2.imwrite returns False (no exception) when it cannot write the file.
    if not cv2.imwrite(output_path, foreground):
        raise OSError(f'Could not write foreground image: {output_path}')
    print(f"Successfully saved the foreground image for {filename}.")
    return filename

The Flask application starts here.

In [None]:
import os
import secrets
import time

from flask import Flask, flash, render_template, request, redirect, url_for, send_file
from flask_ngrok import run_with_ngrok
from werkzeug.utils import secure_filename

# App configuration:
#   upload_dir         -- folder for uploaded original images
#   output_dir         -- folder for output/result images
#   matte_dir          -- folder for the predicted alpha mattes
#   allowed_extensions -- image file extensions accepted for upload
upload_dir = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/uploads/'
output_dir = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/output/'
matte_dir = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/matte/'
allowed_extensions = {"png", "jpg", "jpeg", "gif"}

app = Flask(
    __name__,
    template_folder='/content/drive/MyDrive/BeCode/Faktion/Flask/templates/',
    static_folder='/content/drive/MyDrive/BeCode/Faktion/Flask/static',
)
# Make the app reachable from outside Colab through an ngrok tunnel.
run_with_ngrok(app)
# Key used to sign the session cookie (flash() messages are stored there).
# Read it from the environment instead of hard-coding it; otherwise fall back
# to a random per-kernel key, which only invalidates sessions on kernel restart.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or secrets.token_hex(32)
app.config["UPLOAD_FOLDER"] = upload_dir

def allowed_file(filename):
    """Return True when ``filename`` ends in an extension from ``allowed_extensions``.

    Only the text after the last dot counts, compared case-insensitively;
    a name without any dot is rejected.
    """
    _, dot, extension = filename.rpartition(".")
    return bool(dot) and extension.lower() in allowed_extensions

@app.route("/")
def home():
    """Landing page: render the ``home.html`` template."""
    template_name = "home.html"
    return render_template(template_name)

@app.route("/start", methods=["GET", "POST"])
def start():
    """Show the upload form (GET) or process an uploaded image (POST).

    On POST with a supported image, the following happens:

    - The file is saved under a sanitised name.
    - An alpha matte is predicted and the foreground is composited.
    - ``upload_image.html`` is rendered with ``filename`` set to the result's
      file name.

    An unsupported file flashes an error and redirects back to the form.
    Video uploads are not processed yet; they return the home page.
    """
    # Anything other than POST shows the form. This includes the HEAD method
    # that Flask adds automatically for GET routes, which must not fall through
    # and return None.
    if request.method != "POST":
        return render_template("start.html")

    image = request.files.get("image")
    print(f"Image Input: {image}")
    video = request.files.get("video")
    print(f"Video Input: {video}")
    start_time = time.perf_counter()

    # An empty file input is still submitted, as a FileStorage whose filename
    # is "" - and such a FileStorage is falsy - so `not x` covers both
    # "field missing" and "nothing selected".
    if not image and not video:
        return redirect(request.url)

    if image:
        print("Elif image file")
        # Sanitise first and validate the name actually used on disk; the raw
        # client-supplied name must never be used to build file paths.
        image_upload = secure_filename(image.filename)
        if not allowed_file(image_upload):
            flash("Chosen file is not supported! Please upload an image file.")
            flash("Allowed image types are -> png, jpg, jpeg, gif")
            return redirect(request.url)

        image_path = os.path.join(app.config["UPLOAD_FOLDER"], image_upload)
        image.save(image_path)
        print(f'Image Path: {image_path}')

        # Predict the alpha matte, then composite the foreground image.
        matte_path = image_ckpt(image_upload)
        print(f'Alpha Matte Path: {matte_path}')
        print(f'Image FileName: {image_upload}')
        result_name = foreground_img(image_path, matte_path, image_upload)
        print(f'Edited Image: {result_name}')

        print(f"Program runs for {time.perf_counter() - start_time} seconds.")
        # Pass the result's file name (a string); presumably the template builds
        # the display/download URLs from it - TODO confirm against upload_image.html.
        return render_template("upload_image.html", filename=result_name)

    # Only a video was uploaded: video matting is not implemented yet.
    print("Elif video file")
    print(f"Program runs for {time.perf_counter() - start_time} seconds.")
    return render_template("home.html")

@app.route("/display/<filename>")
def display_image(filename):
    """Permanently redirect (301) to the uploaded image ``filename`` in static/uploads/."""
    target = url_for("static", filename=f"uploads/{filename}")
    return redirect(target, code=301)

@app.route("/edited/<filename>")
def edited_image(filename):
    """Redirect (301) to the edited/result image ``filename``.

    Result images are written to ``output_dir`` (static/output/) by
    ``foreground_img``, so the redirect targets that folder; static/uploads/
    only holds the original upload.
    """
    return redirect(url_for("static", filename="output/" + filename), code=301)

@app.route("/save_file/<filename>")
def save_file(filename):
    """Send ``filename`` from the uploads folder as a download (attachment).

    send_from_directory joins the path safely, rejecting names that would
    escape the folder. It answers 404 for a missing file instead of failing
    with a server error.

    NOTE(review): this serves the *uploaded* original. If the download is
    meant to be the edited result, use the static/output folder instead;
    confirm against the template.
    """
    from flask import send_from_directory

    directory = "/content/drive/MyDrive/BeCode/Faktion/Flask/static/uploads"
    print(f"Send File Path: {directory}/{filename}")
    return send_from_directory(directory, filename, as_attachment=True)

def _render_static_page(template_name):
    """Render ``template_name`` with no extra context (shared by the page routes below)."""
    return render_template(template_name)


@app.route("/upload_image")
def upload_image():
    """Page: the upload/result view, rendered without a filename."""
    return _render_static_page("upload_image.html")


@app.route("/howtouse")
def howtouse():
    """Page: usage instructions."""
    return _render_static_page("howtouse.html")


@app.route("/about")
def about():
    """Page: about the project."""
    return _render_static_page("about.html")


@app.route("/live_feed")
def live_feed():
    """Page: live feed view."""
    return _render_static_page("live_feed.html")

# Start the Flask server. Inside a notebook the kernel's __name__ is
# "__main__", so this runs when the cell executes. The cell stays blocked
# until the server is interrupted.
if __name__ == "__main__":
    app.run()