In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Virtual Try-On - Generation at Scale

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-creative-studio/blob/main/experiments/Imagen_Product_Recontext/imagen_product_recontext_at_scale.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-creative-studio%2Fmain%2Fexperiments%2FImagen_Product_Recontext%2Fimagen_product_recontext_at_scale.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-creative-studio/main/experiments/Imagen_Product_Recontext/imagen_product_recontext_at_scale.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/bigquery/import?url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-creative-studio/main/experiments/Imagen_Product_Recontext/imagen_product_recontext_at_scale.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/bigquery/v1/32px.svg" alt="BigQuery Studio logo"><br> Open in BigQuery Studio
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-creative-studio/blob/main/experiments/Imagen_Product_Recontext/imagen_product_recontext_at_scale.ipynb">
      <img width="32px" src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

Author: Layolin Jesudhass, Gen AI Solution Architect


# Virtual Try-On
This notebook demonstrates virtual try-on with a Vertex AI model.
It loads one person image and a set of outfit images, then requests a try-on result for each outfit in parallel.
The generated looks are displayed side by side in a single view so you can compare them easily.

In [None]:
from PIL import Image
from IPython.display import display
import base64
import io
import time
import os
import concurrent.futures
from IPython.display import display
from google.cloud import aiplatform
from google.cloud.aiplatform.gapic import PredictionServiceClient
from IPython.display import display, HTML

# --- Configuration ---

# Google Cloud project and region used to address the prediction endpoint.
PROJECT_ID = "consumer-genai-experiments"
LOCATION = "us-central1"

# Publisher model that serves the try-on predictions.
MODEL_ID = "virtual-try-on-exp-05-31"

# Product (outfit) image files, in the order they are shown in the notebook.
PRODUCT_IMAGE_FILES = [
    "red.jpg",
    "green.png",
    "dress.png",
    "blue.png",
    "yellow.png",
]

# (width, height) that generated images are resized to by default.
TARGET_SIZE = (250, 550)

# Full resource path of the publisher model passed to the prediction client.
model_endpoint = "/".join(
    (
        "projects", PROJECT_ID,
        "locations", LOCATION,
        "publishers", "google",
        "models", MODEL_ID,
    )
)


In [None]:
# --- Utility functions ---

def load_image_bytes(path):
    """Read the file at ``path`` in binary mode and return its full contents.

    Args:
        path: Location of the file to read (anything ``open`` accepts).

    Returns:
        bytes: Every byte stored in the file.
    """
    with open(path, mode="rb") as image_file:
        raw_bytes = image_file.read()
    return raw_bytes

def encode_image(img_bytes):
    """Convert raw bytes into base64 text.

    Args:
        img_bytes: Bytes-like object to encode.

    Returns:
        str: The base64 encoding of ``img_bytes``, decoded as UTF-8 text.
    """
    b64_bytes = base64.b64encode(img_bytes)
    return b64_bytes.decode("utf-8")

def prediction_to_pil_image(prediction, size=TARGET_SIZE):
    """Turn one prediction entry into a resized RGB PIL image.

    Args:
        prediction: Mapping whose ``"bytesBase64Encoded"`` entry holds the
            base64-encoded image data.
        size: ``(width, height)`` to resize the decoded image to.
            Defaults to ``TARGET_SIZE``.

    Returns:
        The decoded image, converted to RGB and resized to ``size``.
    """
    raw_bytes = base64.b64decode(prediction["bytesBase64Encoded"])
    rgb_image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
    resized = rgb_image.resize(size)
    return resized

def run_tryon(person_b64, name, b64):
    """Request one virtual try-on image and measure how long the call takes.

    Args:
        person_b64: Base64 text of the person image, sent as ``personImage``.
        name: Label of the product image; used only to identify the request
            in the error raised when no prediction comes back.
        b64: Base64 text of the product image, sent as the single entry of
            ``productImages``.

    Returns:
        tuple: ``(output_img, elapsed)`` where ``output_img`` is the first
        prediction converted by ``prediction_to_pil_image`` and ``elapsed``
        is the time in seconds spent creating the client and running the
        prediction call.

    Raises:
        IndexError: If the response contains no predictions.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    client = PredictionServiceClient(client_options={"api_endpoint": f"{LOCATION}-aiplatform.googleapis.com"})
    instances = [{
        "personImage": {"image": {"bytesBase64Encoded": person_b64}},
        "productImages": [{"image": {"bytesBase64Encoded": b64}}],
    }]
    response = client.predict(endpoint=model_endpoint, instances=instances, parameters={})
    elapsed = time.perf_counter() - start

    # Fail with a message that names the product instead of a bare
    # "list index out of range" when the service returns nothing.
    if len(response.predictions) == 0:
        raise IndexError(f"Virtual try-on returned no predictions for {name!r}")

    output_img = prediction_to_pil_image(response.predictions[0])
    return output_img, elapsed


In [None]:
# Load the person image once and base64-encode it for the try-on requests.
# The shared helpers are reused here rather than redefined, so there is a
# single definition of encode_image in the notebook.
person_image_path = "model.png"
person_bytes = load_image_bytes(person_image_path)
person_b64 = encode_image(person_bytes)

# Build a small preview from the bytes already in memory (no second disk read).
# thumbnail() shrinks in place and keeps the aspect ratio, so the preview fits
# within TARGET_SIZE without distortion.
person_img = Image.open(io.BytesIO(person_bytes))
person_img.thumbnail(TARGET_SIZE)

# Show the preview inline.
display(person_img)

In [None]:
def pil_image_to_base64(img):
    """Serialize an image as PNG and return the result as base64 text.

    Args:
        img: Object exposing ``save(file_obj, format=...)``, such as a PIL
            image.

    Returns:
        str: Base64 encoding of the PNG data written by ``img.save``.
    """
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()

# Fixed preview size for every product image: (width, height) in pixels.
# resize() forces exactly this size, so the aspect ratio is not preserved.
PRODUCT_THUMBNAIL_SIZE = (150, 250)

product_pil_images = []
for file_name in PRODUCT_IMAGE_FILES:
    # The context manager closes the source file deterministically;
    # resize() returns a separate image that stays usable afterwards.
    with Image.open(file_name) as source_img:
        product_pil_images.append((file_name, source_img.resize(PRODUCT_THUMBNAIL_SIZE)))

# Build one HTML card per product (thumbnail above its file name) and join
# the cards once instead of growing a string with repeated +=.
product_cards = []
for name, img in product_pil_images:
    img_b64 = pil_image_to_base64(img)
    product_cards.append(f"""
    <div style="text-align:center; margin-right:15px;">
        <img src="data:image/png;base64,{img_b64}" style="max-height:250px; max-width:150px;"><br>
        <b>{name}</b><br>
    </div>
    """)
imgs_html = "".join(product_cards)

display(HTML(f"<div style='display:flex; align-items:center;'>{imgs_html}</div>"))

In [None]:
# Prepare product_data: a list of (file_name, base64_encoded_image) tuples,
# one per entry of PRODUCT_IMAGE_FILES and in the same order.
# The shared helpers do the reading and encoding so that logic lives in one place.
product_data = [
    (file_name, encode_image(load_image_bytes(file_name)))
    for file_name in PRODUCT_IMAGE_FILES
]

In [None]:
# --- Run inference in parallel ---
results = []


def run_thread(item):
    """Run a single try-on for one ``(name, base64_image)`` pair.

    Args:
        item: Two-element sequence holding the product name and its
            base64-encoded image.

    Returns:
        tuple: ``(name, output_image, elapsed)`` as produced by ``run_tryon``
        for this product.
    """
    garment_name, garment_b64 = item
    tryon_img, seconds = run_tryon(person_b64, garment_name, garment_b64)
    return (garment_name, tryon_img, seconds)


# Submit one job per product, then collect each result as soon as its job
# finishes (completion order, not submission order).
with concurrent.futures.ThreadPoolExecutor() as pool:
    futures = [pool.submit(run_thread, product) for product in product_data]
    results.extend(
        finished.result()
        for finished in concurrent.futures.as_completed(futures)
    )

In [None]:
# Results are gathered as jobs finish, so restore the order of
# PRODUCT_IMAGE_FILES before displaying them.
ordered_results = sorted(results, key=lambda result: PRODUCT_IMAGE_FILES.index(result[0]))

# Build one HTML card per result (generated image above its request time)
# and join the cards once instead of growing a string with repeated +=.
result_cards = []
for _name, out_img, elapsed in ordered_results:
    img_b64 = pil_image_to_base64(out_img)
    result_cards.append(f"""
    <div style="text-align:center; margin-right:15px;">
        <img src="data:image/png;base64,{img_b64}" style="max-height:250px; max-width:150px;"><br>
        <b>Time Taken</b><br>
        <small>{elapsed:.2f}s</small>
    </div>
    """)
imgs_html = "".join(result_cards)

display(HTML(f"<div style='display:flex; align-items:center;'>{imgs_html}</div>"))