In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Open Models (Stable Diffusion XL)


## Overview

This notebook demonstrates how you can call open models that are deployed on Vertex AI Prediction. The model used in this notebook is:
 * [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)

### Objective

* Send requests to the Vertex AI Prediction Endpoint that is hosting Stable Diffusion XL


### Assumptions
- The model is already deployed via Model Garden's one-click deploy feature, in one of the regions searched below (`us-central1`, `asia-southeast1`, or `europe-west4`)

## Before you begin

### Installation
Run the cell below if this is your first time running the notebook. Otherwise, you can skip it, because the libraries are already installed.

In [None]:
%pip install --upgrade --user --quiet google-cloud-aiplatform
%pip install --quiet tensorflow
! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git

### Import libraries

In [None]:
import datetime
import importlib
import math
import os
import uuid
from typing import Tuple

from google.cloud import aiplatform

common_util = importlib.import_module(
    "vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util"
)

### Set your project ID and region

In [None]:
# Get the default cloud project id.
PROJECT_ID= !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

print(f"Project ID:", PROJECT_ID)

aiplatform.init(project=PROJECT_ID)

endpoints = {}

## Predicting with Vertex AI Prediction Endpoints

After you've deployed your target model to a Vertex AI Prediction Endpoint, you can send text prompts to the endpoint. Note that the first few requests may take longer to execute.

First, retrieve the details of the deployed endpoint.

In [None]:
# Find the Vertex AI Prediction Endpoint created by Model Garden's one-click deploy
# and register it under the "sdxl" key.
SDXL_ENDPOINT_DISPLAY_NAME = "stabilityai_stable-diffusion-xl-1-mg-one-click-deploy"
check_regions = ["us-central1", "asia-southeast1", "europe-west4"]

sdxl_endpoint = None
for region in check_regions:
    for endpoint in aiplatform.Endpoint.list(location=region):
        if endpoint.display_name == SDXL_ENDPOINT_DISPLAY_NAME:
            full_endpoint = f"projects/{PROJECT_ID}/locations/{region}/endpoints/{endpoint.name}"
            sdxl_endpoint = aiplatform.Endpoint(full_endpoint)
            break
    if sdxl_endpoint is not None:
        # Use the first match and skip listing the remaining regions.
        break

if sdxl_endpoint is None:
    # Without this check the lookup below would fail with a bare KeyError.
    raise RuntimeError(
        f"No endpoint with display name {SDXL_ENDPOINT_DISPLAY_NAME!r} was found in "
        f"regions {check_regions}. Deploy Stable Diffusion XL from Model Garden first, "
        "or add its region to `check_regions`."
    )

endpoints['sdxl'] = sdxl_endpoint

print(f"Stable Diffusion XL Endpoint Name: {endpoints['sdxl'].display_name}")

### Generating Images with Stable Diffusion XL

After your model is deployed, you'll be able to generate images by sending text prompts to the endpoint. Try your hand at generating some images below! 

**Example:**
```
> A photo of an astronaut riding a horse on mars
> A stone castle in a forest by the river
```

In [None]:
# Create your prompts by adding them to a comma-separated prompt list.
# Every comma starts a new prompt, so an individual prompt cannot contain a comma.
comma_separated_prompt_list = "A photo of an astronaut riding a horse on mars, A stone castle in a forest by the river"  # @param {type: "string"}
# Skip blank entries (e.g. from a trailing or doubled comma) so no empty prompt is sent.
prompt_list = [x.strip() for x in comma_separated_prompt_list.split(",") if x.strip()]
if not prompt_list:
    raise ValueError("Provide at least one non-empty prompt in comma_separated_prompt_list.")

# [Optional] Set a negative prompt to define what you don't want to see.
negative_prompt = ""

# Set generation parameters.
height = 768
width = 768
num_inference_steps = 25
guidance_scale = 7.5

# Construct the instance list (one instance per prompt) and the shared parameters.
instances = [{"text": prompt} for prompt in prompt_list]
parameters = {
    "negative_prompt": negative_prompt,
    "height": height,
    "width": width,
    "num_inference_steps": num_inference_steps,
    # Pass the variable so that edits to `guidance_scale` above take effect.
    "guidance_scale": guidance_scale,
}

# Send prompts and parameters to the endpoint.
response = endpoints['sdxl'].predict(
    instances=instances, parameters=parameters
)

# Decode each prediction's base64 "output" and display the images in a roughly square grid.
images = [
    common_util.base64_to_image(prediction.get("output"))
    for prediction in response.predictions
]
display(common_util.image_grid(images, rows=math.ceil(len(images) ** 0.5)))