In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Bus stop image generation using Imagen 3

## Overview

### Imagen 3

Imagen 3 on Vertex AI brings Google's state-of-the-art generative AI capabilities to application developers. Imagen 3 is Google's highest-quality text-to-image model to date. It's capable of creating images with astonishing detail. Thus, developers have more control when building next-generation AI products that transform their imagination into high-quality visual assets. Learn more about [Imagen on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/image/overview).


In this notebook, you will generate test images of bus stops, which you can then process using the multimodal capabilities of Google Cloud implemented in this repo.

**NOTE**: this notebook is experimental and not fully integrated with the rest of the solution.

## Get started


### Install Vertex AI SDK for Python


In [None]:
# Install or upgrade the Vertex AI SDK into the user site-packages of this kernel's environment.
# NOTE(review): the version is not pinned, so a later run may pull a different SDK release.
%pip install --quiet --upgrade --user google-cloud-aiplatform

### Restart runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [None]:
import IPython

# Fetch the running IPython application and ask its kernel to shut down and
# restart (the True argument), so the packages installed above can be imported.
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your notebook environment (Colab only)

If you are running this notebook on Google Colab, run the following cell to authenticate your environment.


In [None]:
import sys

# The Colab runtime pre-loads the `google.colab` package; use its presence to
# decide whether an interactive user login is needed.
running_in_colab = "google.colab" in sys.modules
if running_in_colab:
    from google.colab import auth

    auth.authenticate_user()

### Set Google Cloud project information, initialize Vertex AI SDK and import libraries

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel, Image
import ipywidgets as widgets

# Project settings (rendered as Colab form fields): fill these in before running.
PROJECT_ID = ""  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}
IMAGE_BUCKET = ""  # @param {type:"string"}

# Cloud Storage folders derived from the bucket name.
BASE_IMAGE_FOLDER = f"gs://{IMAGE_BUCKET}/base-images"
MASK_IMAGE_FOLDER = f"gs://{IMAGE_BUCKET}/masks"
PROCESSING_FOLDER = f"gs://{IMAGE_BUCKET}/images"

# Point the Vertex AI SDK at the project and region chosen above.
vertexai.init(project=PROJECT_ID, location=LOCATION)

### Define a helper function

In [None]:
import typing

import IPython.display
from PIL import Image as PIL_Image
from PIL import ImageOps as PIL_ImageOps


def display_image(
    image,
    max_width: int = 600,
    max_height: int = 350,
) -> None:
    """Show an image inline in the notebook, shrunk to fit the given bounds.

    Args:
        image: Object exposing the underlying PIL image as `_pil_image`.
        max_width: Largest width, in pixels, to display.
        max_height: Largest height, in pixels, to display.

    Returns:
        None. The image is rendered through `IPython.display.display`.
    """
    picture = typing.cast(PIL_Image.Image, image._pil_image)

    # RGB is supported by all Jupyter environments (e.g. RGBA is not yet)
    if picture.mode != "RGB":
        picture = picture.convert("RGB")

    width, height = picture.size
    fits_within_bounds = width <= max_width and height <= max_height
    if not fits_within_bounds:
        # Scale down so the notebook shows a smaller image
        picture = PIL_ImageOps.contain(picture, (max_width, max_height))

    IPython.display.display(picture)

## Select the image generation model and generation parameters





With Imagen 3, you also have the option to use Imagen 3 Fast. These two model options give you the choice to optimize for quality and latency, depending on your use case.

**Imagen 3:** Generates high quality images with natural lighting and increased photorealism.

**Imagen 3 Fast:** Suitable for creating brighter images with a higher contrast. Overall, you can see a 40% decrease in latency in Imagen 3 Fast compared to Imagen 2.

With Imagen 3 and Imagen 3 Fast, you can also configure the `aspect ratio` to any of the following:
* 1:1
* 9:16
* 16:9
* 3:4
* 4:3

In [None]:
# Model dropdown: each option is (label shown, model ID used as the widget value).
model_selector = widgets.Dropdown(
    options=[
        ("Imagen 3", "imagen-3.0-generate-001"),
        ("Imagen 3 Fast", "imagen-3.0-fast-generate-001"),
    ],
    description="Model type:",
    disabled=False,
)

# Aspect-ratio dropdown, preselected to 16:9.
ratio_selector = widgets.Dropdown(
    options=["1:1", "9:16", "16:9", "3:4", "4:3"],
    description="Aspect ratio:",
    value="16:9",
    disabled=False,
)

display(model_selector, ratio_selector)

### Generate image

In [None]:
import matplotlib.pyplot as plt

# Resolve the choices made in the model and aspect-ratio dropdowns.
generation_model = ImageGenerationModel.from_pretrained(model_selector.value)
ratio = ratio_selector.value
model_name = "Imagen 3 Fast" if "fast" in model_selector.value else "Imagen 3"

prompt = """
A photo of a bus stop with clean glass, city street background.
Add "Main Street" name to the stop.
"""

generate_image_response = generation_model.generate_images(
    prompt=prompt,
    number_of_images=1,
    aspect_ratio=ratio,
    safety_filter_level="block_some",
    person_generation="allow_adult",
)

# The response can come back without any image (e.g. if the result was
# filtered); fail with a clear message instead of an IndexError below.
if not generate_image_response.images:
    raise RuntimeError(
        "No image was generated; adjust the prompt or the safety settings "
        "and re-run this cell."
    )

# Display the generated image in a single panel (only one image is requested).
fig, ax = plt.subplots(figsize=(12, 6))
ax.imshow(generate_image_response[0]._pil_image)
ax.set_title("Model: " + model_name + ", ratio: " + ratio)
ax.axis("off")
plt.show()

## Save the image as the base bus stop image

Once you are happy with the image, save it to the image bucket. Later on you will be using these images to simulate taking new pictures.

In [None]:
# Identifier of the bus stop; reused in file and object names further down.
stop_name = "main-at-4th"  # @param {type:"string"}

# Output format dropdown: each option is (label shown, file extension value).
extension_selector = widgets.Dropdown(
    options=[
        ("PNG", "png"),
        ("JPEG", "jpeg"),
    ],
    description="Format:",
)

display(extension_selector)

In [None]:
image_extension = extension_selector.value
original_image = generate_image_response[0]
temp_file = "/tmp" + "/" + stop_name + "." + image_extension
original_image.save(
    location=temp_file, include_generation_parameters=True)

!gcloud storage cp {temp_file} {BASE_IMAGE_FOLDER}/{stop_name}.{image_extension}

In [None]:

temp_file = "/tmp" + "/" + stop_name + "." + image_extension
edited_image.save(
    location=temp_file, include_generation_parameters=True)

import datetime
now = datetime.datetime.now()
timestamp = now.strftime('%Y-%m-%dT%H:%M:%S')

!gcloud storage cp {temp_file} {PROCESSING_FOLDER}/{stop_name}-{timestamp}.jpeg --custom-metadata=stop_id={stop_name}
