Fine-Tune an Imagen Model

Overview

Imagen is a text-to-image model from Google AI, trained on a massive dataset of images. You can use it to generate images for scenarios such as ads, photo shoots, and blog posts. Sometimes the base model needs to be personalized for a specific subject; we can fine-tune Imagen for that with 5-10 images.

The notebook is structured as follows:
1. We check the prerequisites to make sure all the GCP resources are ready for fine-tuning.
2. We kick off an Imagen fine-tuning pipeline.
3. After the pipeline is kicked off, we poll the job state.
4. Once the job finishes successfully, we get the fine-tuned model endpoint.
5. Finally, we generate images by querying the fine-tuned model.

Prerequisites

Quota check: If the project does not have enough quota, go to the Quotas page and request an increase for the following metric and region:
metric: aiplatform.googleapis.com/restricted_image_training_a2_cpus, region: us-east4

Training images setup:
Create a Cloud Storage bucket and upload the local training images to it (JPEG, PNG, GIF, or BMP files).

Create a CSV input file that lists the Cloud Storage locations of the fine-tuning images.
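
As a rough illustration of the CSV step, the sketch below writes one Cloud Storage URI per row. The single-column, no-header layout, the bucket name `my-bucket`, the object names, and the file name `training_images.csv` are all assumptions made for illustration; check the fine-tuning documentation for the exact layout the pipeline expects.

```python
import csv


def write_image_csv(gcs_uris, csv_path):
    """Write one gs:// URI per row (assumed layout: single column, no header)."""
    for uri in gcs_uris:
        if not uri.startswith("gs://"):
            raise ValueError(f"not a Cloud Storage URI: {uri}")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        for uri in gcs_uris:
            writer.writerow([uri])
    return csv_path


# Hypothetical bucket and object names, for illustration only.
image_uris = [f"gs://my-bucket/training/image_{i}.jpg" for i in range(1, 6)]
write_image_csv(image_uris, "training_images.csv")
```

The resulting file still has to be uploaded to Cloud Storage; its gs:// location is what goes into IMAGE_CSV_URI later in the notebook.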

Performance & Limitations:

Maximum fine-tuning training file size: 10 MB
Concurrent fine-tuning subjects per project: 1
Generate API calls (prompts per minute per project): 5
Time to fine-tune a subject model: up to 120 minutes
Time per generate request: up to 20 seconds
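
Before uploading, it can help to check the local training files against the file-size limit and the supported formats listed above. The following is a minimal sketch; the helper name `find_invalid_images` is hypothetical, and treating 10 MB as 10 * 1024 * 1024 bytes is an assumption.

```python
from pathlib import Path

MAX_BYTES = 10 * 1024 * 1024  # assumption: "10 MB" is read as 10 MiB
ALLOWED_SUFFIXES = {".jpg", ".jpeg", ".png", ".gif", ".bmp"}


def find_invalid_images(paths):
    """Return (path, reason) pairs for files that break the stated limits."""
    problems = []
    for p in map(Path, paths):
        if p.suffix.lower() not in ALLOWED_SUFFIXES:
            problems.append((str(p), "unsupported file type"))
        elif p.stat().st_size > MAX_BYTES:
            problems.append((str(p), "larger than 10 MB"))
    return problems
```

An empty list means every file passed both checks.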

Fine-tune the model for a specific subject


In [None]:
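import json

import requests

# Fill in every value below before running the rest of the notebook.
DISPLAY_NAME = ''
GCS_OUTPUT_DIR = ''
SUBJECT_ID = ''
CLASS_NAME = ''
PROJECT_ID = ''
IMAGE_CSV_URI = ''
LOCATION = ''
DEPLOYMENT_REPLICA_COUNT = 1  # placeholder value; must be an integer, not a string
SERVICE_ACCOUNT = ''
TEMPLATE_URI = ''
BEAR_TOKEN = ''  # sent as-is in the Authorization header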
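
The requests below pass BEAR_TOKEN directly as the Authorization header, so it presumably needs the full `Bearer <token>` form. One possible way to build it is sketched here, assuming the gcloud CLI is installed and authenticated; the helper names are hypothetical.

```python
import subprocess


def fetch_access_token():
    """Ask the gcloud CLI for an access token (needs gcloud installed and logged in)."""
    result = subprocess.run(
        ["gcloud", "auth", "print-access-token"],
        capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


def as_bearer_header_value(token):
    """Format a raw token as an Authorization header value."""
    token = token.strip()
    if not token:
        raise ValueError("empty access token")
    return f"Bearer {token}"


# Example, commented out because it needs a configured gcloud CLI:
# BEAR_TOKEN = as_bearer_header_value(fetch_access_token())
```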

In [None]:
payload = {
    "displayName": DISPLAY_NAME,
    "runtimeConfig": {
        "gcsOutputDirectory": GCS_OUTPUT_DIR,
        "parameterValues": {
            "subject_id": SUBJECT_ID,
            "class_name": CLASS_NAME,
            "project": PROJECT_ID,
            "image_csv_uri": IMAGE_CSV_URI,
            "location": LOCATION,
            "deployment_replica_count": DEPLOYMENT_REPLICA_COUNT
        }
    },
    "templateUri": TEMPLATE_URI,
    "serviceAccount": SERVICE_ACCOUNT
}

In [None]:
json_payload = json.dumps(payload)

In [None]:
url = f'https://{LOCATION}-aiplatform.googleapis.com/ui/projects/{PROJECT_ID}/locations/{LOCATION}/pipelineJobs'
headers = {'Content-Type': 'application/json; charset=utf-8', 'Authorization': BEAR_TOKEN}
try:
    r = requests.post(url, data=json_payload, headers=headers)
    res = r.json()
    print(res)
    # The job ID is the last path segment of the returned resource name.
    pipelinejob_id = res['name'].split('/')[-1]
except Exception as e:
    print(e)

Poll Job States (Optional)

In [None]:
url = f'https://{LOCATION}-aiplatform.googleapis.com/ui/projects/{PROJECT_ID}/locations/{LOCATION}/pipelineJobs/{pipelinejob_id}'
headers = {'Authorization': BEAR_TOKEN}

In [None]:
try:
    r_pipeline_job = requests.get(url, headers=headers)
    job_state = r_pipeline_job.json()['state']
    print(f"pipeline job state: {job_state}")
except Exception as e:
    print(e)
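
The cell above checks the state once. To wait for the job to finish, the check can be repeated until the state is terminal, as in the sketch below. The function name, the 60-second interval, and the 3-hour timeout (chosen to exceed the 120-minute fine-tuning limit) are illustrative assumptions; the state names are the terminal values of the Vertex AI PipelineState enum.

```python
import time

# Terminal values of the Vertex AI PipelineState enum.
TERMINAL_STATES = {
    "PIPELINE_STATE_SUCCEEDED",
    "PIPELINE_STATE_FAILED",
    "PIPELINE_STATE_CANCELLED",
}


def wait_for_pipeline(get_state, poll_seconds=60, timeout_seconds=3 * 60 * 60,
                      sleep=time.sleep):
    """Call get_state() until it returns a terminal state or the timeout elapses."""
    waited = 0
    while True:
        state = get_state()
        if state in TERMINAL_STATES:
            return state
        if waited >= timeout_seconds:
            raise TimeoutError(f"pipeline still in state {state} after {waited} seconds")
        sleep(poll_seconds)
        waited += poll_seconds
```

In this notebook, `get_state` could be `lambda: requests.get(url, headers=headers).json()['state']`. Passing the state lookup in as a callable keeps the loop independent of the HTTP details.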

Get the fine-tuned model endpoint

In [None]:
try:
    r_completed_pipeline_job = requests.get(url, headers=headers)
    res = r_completed_pipeline_job.json()
    # If this lookup fails, the next cell extracts the endpoint from the raw response text.
    endpoint_id = res['output']['endpoint_resource_name']
except Exception as e:
    print(e)

In [None]:
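import re

# Write the response to a text file, one comma-separated field per line.
with open("res_output.txt", "w") as file_one:
    file_one.write(str(res).replace(",", "\n"))

patrn = "output:endpoint_resource_name"
with open("res_output.txt", "r") as file_one:
    for line in file_one:
        if re.search(patrn, line) and 'stringValue' not in line:
            # The last single-quoted value on the line is the endpoint resource name,
            # e.g. projects/<project>/locations/<location>/endpoints/<endpoint_id>.
            resource_name = line.split("'")[-2]
            LOCATION = resource_name.split("/")[-3]
            print(LOCATION)
            endpoint_id = resource_name.split("/")[-1]
            print(endpoint_id)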
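
An alternative to going through a text file is to search the parsed response directly. The sketch below walks the nested dictionaries and lists and returns the first string that looks like an endpoint resource name. It makes no assumption about where in the response the name sits, but unlike the cell above it does not filter on the `output:endpoint_resource_name` key, so it is only a reasonable substitute if the response contains a single endpoint. The function name is hypothetical.

```python
import re

ENDPOINT_RE = re.compile(r"projects/[^/]+/locations/([^/]+)/endpoints/([^/]+)$")


def find_endpoint(obj):
    """Depth-first search of nested dicts/lists for an endpoint resource name.

    Returns (location, endpoint_id), or None if nothing matches.
    """
    if isinstance(obj, str):
        m = ENDPOINT_RE.search(obj)
        return (m.group(1), m.group(2)) if m else None
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    else:
        children = ()
    for child in children:
        found = find_endpoint(child)
        if found:
            return found
    return None
```

With the response from the previous cell, `find_endpoint(res)` would give the values to use for LOCATION and endpoint_id.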

Generate Images from the fine-tuned model endpoint

In [None]:
TEXT_PROMPT = 'a stock photo of [diorperfume] perfume bottle on beach with beautiful sunset in background, cinematic, 8k, highly detailed, amazon.com'
IMAGE_COUNT = 4  # placeholder value: number of images to request; adjust as needed

In [None]:
finetuned_model_payload = {
    "instances": [
       {
         "prompt": TEXT_PROMPT
       } 
    ],
    "parameters": {
        "sampleCount": IMAGE_COUNT
    }
}

In [None]:
json_finetuned_model_payload = json.dumps(finetuned_model_payload)

In [None]:
url = f'https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/{endpoint_id}:predict'
headers = {'Content-Type': 'application/json; charset=utf-8', 'Authorization': BEAR_TOKEN}
try:
    r_generated_images = requests.post(url, data=json_finetuned_model_payload, headers=headers)
except Exception as e:
    print(e)

In [None]:
import base64

from IPython.display import Image, display

predict_list = r_generated_images.json()['predictions']
for c, image in enumerate(predict_list):
    # Each prediction carries the image as a base64-encoded string.
    imgdata = base64.b64decode(image['bytesBase64Encoded'])
    filename = f'perfume{c}.jpg'
    with open(filename, 'wb') as f:
        f.write(imgdata)
    display(Image(filename=filename))