## Setup

Let’s start by checking to see what GPU we’ve been assigned. Ideally we get a V100, but a P100 is fine too. Other GPUs may lead to issues.

In [1]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-f1333f81-f3a0-e848-d4b5-d0d426278fe2)


In [None]:
# Crop Images and upload to S3 as jpg images with white background
import re
import os
from PIL import Image
from tqdm import tqdm
import numpy as np
from resizeimage import resizeimage
from PIL import ImageFile
from rembg.bg import remove
import io
from numpy import asarray
from IPython.display import display, clear_output, Image as Image2
import json
from urllib.request import urlopen, Request
from io import BytesIO
import base64
from google.cloud import automl_v1beta1
import time

# Set the credentials path read by the Google Cloud client used in
# get_prediction(); presumably ./key.json is a service-account key placed in
# the working directory — confirm before running.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"]="./key.json"

# Ask Pillow to load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def image_to_base64(image_path):
    """Return the image at ``image_path`` re-encoded as JPEG, as base64 text.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        str: Base64 encoding of the JPEG bytes, without a ``data:`` prefix.

    Images in a mode the JPEG encoder cannot write (for example RGBA, LA or
    palette PNGs) are converted to RGB first; any alpha channel is dropped.
    """
    # Modes Pillow's JPEG writer accepts as-is; anything else must be converted.
    jpeg_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')
    output_buffer = BytesIO()
    # Image.open is lazy and keeps the file open; the context manager makes
    # sure the handle is released once the pixels have been encoded.
    with Image.open(image_path) as img:
        if img.mode not in jpeg_modes:
            img = img.convert('RGB')
        img.save(output_buffer, format='JPEG')
    return base64.b64encode(output_buffer.getvalue()).decode('UTF-8')

def base64_to_image(base64_str):
    """Decode base64 image text into a PIL image.

    A leading ``data:image/...;base64,`` prefix, if present, is stripped
    before decoding.

    Args:
        base64_str: Base64 text, optionally in data-URI form.

    Returns:
        The image opened by ``PIL.Image.open`` over the decoded bytes.
    """
    payload = re.sub('^data:image/.+;base64,', '', base64_str)
    return Image.open(BytesIO(base64.b64decode(payload)))

def get_prediction(content, project_id, model_id):
  """Send raw image bytes to an AutoML model and return its response.

  Args:
    content: Image file bytes to classify.
    project_id: Google Cloud project identifier used in the model path.
    model_id: AutoML model identifier used in the model path.

  Returns:
    Whatever ``PredictionServiceClient.predict`` returns for the request.
  """
  client = automl_v1beta1.PredictionServiceClient()

  # The model location is fixed to us-central1.
  model_path = 'projects/{}/locations/us-central1/models/{}'.format(project_id, model_id)
  response = client.predict(
      name=model_path,
      payload={'image': {'image_bytes': content }},
      params={},
  )
  return response


def removeBgRunway(image_path):
    """Send an image to the hosted RunwayML model and return its output image.

    The image is re-encoded via ``image_to_base64`` and POSTed as JSON.

    Args:
        image_path: Path of the image to upload.

    Returns:
        The ``"image"`` field of the JSON response (base64 text is what the
        request sends; presumably the response uses the same form — confirm).

    Raises:
        urllib.error.URLError: If the HTTP request fails.
        KeyError: If the response JSON has no ``"image"`` field.
    """
    # NOTE(security): a bearer token is committed in this file. Set
    # RUNWAY_API_TOKEN in the environment to override it; the literal is kept
    # only as a fallback so existing runs behave the same. It should be
    # rotated and removed from source.
    token = os.environ.get("RUNWAY_API_TOKEN", "msnJO6m56bE/3bu4qlQyBw==")

    inputs = {
        "image": image_to_base64(image_path)
    }
    req = Request(
        "https://voxel.hosted-models.runwayml.cloud/v1/query",
        method="POST",
        headers={
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": "Bearer " + token,
        },
        data=json.dumps(inputs).encode("utf8")
    )
    with urlopen(req) as response:
        output = json.loads(response.read().decode("utf8"))
    return output["image"]

def resize(image_pil, width, height):
    '''
    Fit a PIL image inside a ``width`` x ``height`` white canvas, keeping ratio.

    The image is scaled so that it touches the canvas on its limiting side,
    centred, and composited over white using its alpha channel. The scaled
    image is also shown with ``display`` as a visual check.

    Args:
        image_pil: Source PIL image. Images without an alpha channel are
            treated as fully opaque.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        A new RGB image of size ``(width, height)``.
    '''
    # paste() below uses the image as its own mask, which only works when it
    # carries an alpha channel; normalise to RGBA so RGB/P inputs do not raise.
    if image_pil.mode != 'RGBA':
        image_pil = image_pil.convert('RGBA')

    ratio_w = width / image_pil.width
    ratio_h = height / image_pil.height
    if ratio_w < ratio_h:
        # Width is the limiting side.
        resize_width = width
        # Clamp to 1px: extreme aspect ratios would otherwise round to 0,
        # which Pillow rejects.
        resize_height = max(1, round(ratio_w * image_pil.height))
    else:
        # Height is the limiting side.
        resize_width = max(1, round(ratio_h * image_pil.width))
        resize_height = height

    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    image_resize = image_pil.resize((resize_width, resize_height), Image.LANCZOS)
    background = Image.new('RGB', (width, height), "WHITE")
    display(image_resize)
    offset = (round((width - resize_width) / 2), round((height - resize_height) / 2))
    background.paste(image_resize, offset, image_resize)
    return background

# Directory walked recursively by filterImages() and processImage().
sourcePath = '../homely-search-engine-batch1-bg-removed/'
# Directory that receives the resulting .jpg files (flat, no sub-folders).
targetPath = '../homely-search-engine-batch1-processed/'
# Scratch directory; only used to build a path that is never written in the
# code shown here.
tempPath = '../homely-sofa-temp/'
# Top-level listing of sourcePath. Both functions below rebind ``files``
# locally from os.walk(), so they do not read this value.
files = os.listdir(sourcePath)
# Side length, in pixels, of the square canvas processImage() passes to resize().
size = 1024
# Denominator shown in the "Iteration: i/N" progress text only; it does not
# limit how many files are processed.
numberOfImages = 37966

def filterImages():
    """Copy the images the AutoML classifier labels "good" into ``targetPath``.

    Walks ``sourcePath`` recursively, sends each .png/.jpg file to
    ``get_prediction`` and, when the response text contains ``good``, saves
    the image as an RGB JPEG named after the source file in ``targetPath``.
    Progress (the image, the running count of kept images and the iteration
    number) is shown in the notebook output.

    Returns:
        None.
    """
    image_suffixes = ('.png', '.PNG', '.jpg', '.JPG')
    i = 0
    goodImagesCount = 0

    # Image.save() does not create directories; without this every save fails
    # on a fresh checkout.
    os.makedirs(targetPath, exist_ok=True)

    for subdir, _dirs, filenames in os.walk(sourcePath):
        for file in filenames:
            if not file.endswith(image_suffixes):
                continue

            sourceFile = subdir + '/' + file
            targetFileJpeg = targetPath + os.path.splitext(file)[0] + '.jpg'

            with open(sourceFile, 'rb') as ff:
                content = ff.read()
            prediction = get_prediction(content, '433049012679', 'ICN2368213905810915328')
            # Substring match over the whole response text.
            # NOTE(review): this also matches any label that merely contains
            # "good" — confirm the model's label set.
            if "good" in str(prediction):
                goodImagesCount += 1
                # Close the source file deterministically; this loop may run
                # over tens of thousands of images.
                with Image.open(sourceFile) as img:
                    img.convert("RGB").save(targetFileJpeg, 'JPEG')
                clear_output(wait=True)
                display(Image2(filename=sourceFile))
                display(goodImagesCount)
                display('Iteration: '+ str(i) + '/' + str(numberOfImages))

            i += 1

def processImage():
    """Remove backgrounds and write 1024-style square JPEGs to ``targetPath``.

    Walks ``sourcePath`` recursively. Each .png/.jpg file is passed through
    rembg's ``remove`` (u2net model), fitted onto a white ``size`` x ``size``
    canvas with ``resize`` and saved as a JPEG named after the source file.
    A 512x512 preview and progress text are shown in the notebook output.

    Files that fail to process are skipped, as before, but are now collected
    and listed at the end instead of being dropped silently.

    Returns:
        None.
    """
    image_suffixes = ('.png', '.PNG', '.jpg', '.JPG')
    failures = []
    i = 0

    # Image.save() does not create directories.
    os.makedirs(targetPath, exist_ok=True)

    for subdir, _dirs, filenames in os.walk(sourcePath):
        for file in filenames:
            if not file.endswith(image_suffixes):
                continue

            sourceFile = subdir + '/' + file
            targetFileJpeg = targetPath + os.path.splitext(file)[0] + '.jpg'

            # Read the exact file bytes. np.fromfile() defaults to float64 and
            # silently drops up to 7 trailing bytes when the file size is not
            # a multiple of 8, handing rembg a truncated image.
            with open(sourceFile, 'rb') as fh:
                raw = fh.read()

            try:
                result = remove(raw, model_name="u2net")
                targetImg = resize(Image.open(io.BytesIO(result)), size, size)
                targetImg.save(targetFileJpeg, "JPEG")
            except Exception as exc:
                # Best-effort batch: keep going, but remember what failed.
                failures.append((sourceFile, exc))
            else:
                clear_output(wait=True)
                # Filter 0 is nearest-neighbour; the preview is display-only.
                display(targetImg.resize((512,512),0))
                display(sourceFile)
                display('Iteration: '+ str(i) + '/' + str(numberOfImages))
            i += 1

    if failures:
        display('Failed: ' + str(len(failures)) + '/' + str(i))
        for failedFile, exc in failures:
            display(failedFile + ': ' + repr(exc))
# Run the classifier-based filter pass over sourcePath. processImage() is
# defined in this cell but is not called here.
filterImages()

In [None]:
# Export AWS credentials into the kernel environment (SOME_KEY / SOME_SECRET
# are placeholders to replace before running), then run `aws s3 ls` as a
# quick check that the CLI can authenticate.
%env AWS_ACCESS_KEY_ID=SOME_KEY
%env AWS_SECRET_ACCESS_KEY=SOME_SECRET
!aws s3 ls

In [None]:
# Upload the processed images to the S3 bucket.
# NOTE(review): --delete presumably removes bucket objects that are missing
# from the local directory — confirm before running against a shared bucket.
# NOTE(review): the local directory here differs from targetPath
# ('../homely-search-engine-batch1-processed/') used by the cell above —
# confirm this is the intended folder.
!aws s3 sync ../homely-parsed-images-processed/ s3://homely-parsed-images-processed-1024 --delete