# Vertex AI GenAI Embeddings - As Features For Hierarchical Classification

**IN DEVELOPMENT**

Embeddings are vector representations of text, images, or both.  These vectors of floating-point numbers come from a model that has been trained to represent the content efficiently.

Getting embeddings for text, or for multimodal text and images, using Vertex AI foundation models is demonstrated in notebook []().

This notebook shows a use case for embeddings as features.


Workflow:
- Review product catalog data in the BigQuery public table: `bigquery-public-data.thelook_ecommerce.products`
- Create a table with embeddings for:
    - `name` = A brief description of the product
    - `department` = The first level of the product catalog
    - `category` = The second level of the product catalog
- Set up a BigQuery resource connection
- Use ML.* to generate embeddings using Vertex AI


---
## Setup

In [1]:
project = !gcloud config get-value project
PROJECT_ID = project[0]
PROJECT_ID

'statmike-mlops-349915'

In [10]:
REGION = 'us-central1'
SERIES = 'applied-genai'
EXPERIMENT = 'embed-feature-classifier'

In [11]:
# make this the BQ Project / Dataset / Table prefix to store results
BQ_PROJECT = PROJECT_ID
BQ_DATASET = SERIES.replace('-', '_')
BQ_TABLE = EXPERIMENT
BQ_REGION = REGION[0:2] # subset to first two characters for multi-region
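
The derived names above can be checked without any cloud calls. The sketch below repeats the same string operations and adds a hypothetical helper, `multi_region_for`, which is not part of the original notebook. It makes the two-character shortcut explicit: the prefix maps onto a BigQuery multi-region only for `us` and `eu`, so any other prefix falls back to the region itself. That fallback is an assumption about the desired behavior, not something the notebook defines.

```python
REGION = 'us-central1'
SERIES = 'applied-genai'
EXPERIMENT = 'embed-feature-classifier'

# dataset IDs allow only letters, digits, and underscores, hence the replace
BQ_DATASET = SERIES.replace('-', '_')


def multi_region_for(region: str) -> str:
    """Hypothetical helper: return 'us' or 'eu' when the region prefix matches
    a BigQuery multi-region, otherwise return the region unchanged."""
    prefix = region[0:2]
    return prefix if prefix in ('us', 'eu') else region


BQ_REGION = multi_region_for(REGION)
CONNECTION_ID = f'{SERIES}_{EXPERIMENT}'

print(BQ_DATASET)     # applied_genai
print(BQ_REGION)      # us
print(CONNECTION_ID)  # applied-genai_embed-feature-classifier
print(multi_region_for('asia-east1'))  # asia-east1
```

With `REGION = 'asia-east1'`, the plain slice `REGION[0:2]` would give `'as'`, which is not a BigQuery location, so a guard of this kind may be worth adding when working outside the US or EU.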

In [168]:
import json
import numpy as np
import vertexai.language_models
import bigframes.pandas as bf
import bigframes.ml as bfml
from bigframes.ml import llm
from bigframes.ml import model_selection
from google.cloud import bigquery_connection_v1 as bq_connection
from google.cloud import bigquery

In [69]:
vertexai.init(project = PROJECT_ID, location = REGION)
bq = bigquery.Client(project = PROJECT_ID)
bf.reset_session()
bf.options.bigquery.project = BQ_PROJECT
bf.options.bigquery.location = BQ_REGION
bf_session = bf.get_global_session()

---
## Review Data Source

BigQuery Public table `bigquery-public-data.thelook_ecommerce.products`.

In [143]:
products = bf.read_gbq('bigquery-public-data.thelook_ecommerce.products')

In [144]:
products.dtypes

id                                  Int64
cost                              Float64
category                  string[pyarrow]
name                      string[pyarrow]
brand                     string[pyarrow]
retail_price                      Float64
department                string[pyarrow]
sku                       string[pyarrow]
distribution_center_id              Int64
dtype: object

In [145]:
products['department'].unique().tolist()

['Men', 'Women']

In [146]:
products['category'].unique().tolist()

['Swim',
 'Jeans',
 'Pants',
 'Socks',
 'Active',
 'Shorts',
 'Sweaters',
 'Underwear',
 'Accessories',
 'Tops & Tees',
 'Sleep & Lounge',
 'Outerwear & Coats',
 'Suits & Sport Coats',
 'Fashion Hoodies & Sweatshirts',
 'Plus',
 'Suits',
 'Skirts',
 'Dresses',
 'Leggings',
 'Intimates',
 'Maternity',
 'Clothing Sets',
 'Pants & Capris',
 'Socks & Hosiery',
 'Blazers & Jackets',
 'Jumpsuits & Rompers']

In [147]:
products['name'].head()

0       2XU Men's Swimmers Compression Long Sleeve Top
1           TYR Sport Men's Square Leg Short Swim Suit
2      TYR Sport Men's Solid Durafast Jammer Swim Suit
3    TYR Sport Men's Swim Short/Resistance Short Sw...
4                      TYR Alliance Team Splice Jammer
Name: name, dtype: string

---
## Create BigQuery Dataset

In [148]:
# create/link to dataset
ds = bigquery.DatasetReference(BQ_PROJECT, BQ_DATASET)
ds.location = BQ_REGION
ds.labels = {'series': f'{SERIES}'}
ds = bq.create_dataset(dataset = ds, exists_ok = True) 

---
## BigQuery ML: Connect To Vertex AI Embedding Models

BigQuery ML can create models (with `CREATE MODEL`) that are actually connections to remote models. [Reference](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-remote-model)

Using the `REMOTE_SERVICE_TYPE = "CLOUD_AI_LARGE_LANGUAGE_MODEL_V1"` option links the remote model to LLMs in Vertex AI. In this notebook, the remote model for text embeddings is created through `bigframes` in a later section.
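
For orientation, a remote model statement of the kind described here might look like the string built below. This is a minimal sketch that follows the `CREATE MODEL ... REMOTE WITH CONNECTION` syntax from the linked reference. The project, model name, and the helper `remote_model_ddl` are hypothetical, and the notebook itself does not run this SQL directly.

```python
def remote_model_ddl(model_id: str, connection_id: str, service_type: str) -> str:
    """Build a CREATE MODEL statement for a remote model (illustrative only)."""
    return (
        f"CREATE OR REPLACE MODEL `{model_id}`\n"
        f"REMOTE WITH CONNECTION `{connection_id}`\n"
        f"OPTIONS (REMOTE_SERVICE_TYPE = '{service_type}')"
    )


ddl = remote_model_ddl(
    # hypothetical project and model name
    model_id='my-project.applied_genai.example_remote_model',
    # hypothetical project; connection ID in the <project>.<location>.<id> form
    connection_id='my-project.us.applied-genai_embed-feature-classifier',
    service_type='CLOUD_AI_LARGE_LANGUAGE_MODEL_V1',
)
print(ddl)
```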

### Connection Requirement

To reach a remote model, BigQuery ML uses a `CLOUD_RESOURCE` connection. [Reference](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-remote-model#connection)

Make sure the [BigQuery Connection API](https://cloud.google.com/bigquery/docs/create-cloud-resource-connection) is enabled:

In [149]:
!gcloud services enable bigqueryconnection.googleapis.com

Create a connection of type `CLOUD_RESOURCE`, checking first whether one already exists:

In [152]:
try:
    response = bq_connection.ConnectionServiceClient().get_connection(
            request = bq_connection.GetConnectionRequest(
                name = f"projects/{BQ_PROJECT}/locations/{BQ_REGION}/connections/{SERIES}_{EXPERIMENT}"
            )
    )
    print(f'Found existing connection with service account: {response.cloud_resource.service_account_id}')
    service_account = response.cloud_resource.service_account_id
except Exception:
    request = bq_connection.CreateConnectionRequest(
        {
            "parent": f"projects/{BQ_PROJECT}/locations/{BQ_REGION}",
            "connection_id": f"{SERIES}_{EXPERIMENT}",
            "connection": bq_connection.types.Connection(
                {
                    "friendly_name": f"{SERIES}_{EXPERIMENT}",
                    "cloud_resource": bq_connection.CloudResourceProperties({})
                }
            )
        }
    )
    response = bq_connection.ConnectionServiceClient().create_connection(request)
    print(f'Created new connection with service account: {response.cloud_resource.service_account_id}')
    service_account = response.cloud_resource.service_account_id
    # assign the service account the Vertex AI User Role:
    !gcloud projects add-iam-policy-binding {BQ_PROJECT} --member=serviceAccount:{service_account} --role=roles/aiplatform.user

Created new connection with service account: bqcx-1026793852137-pdxa@gcp-sa-bigquery-condel.iam.gserviceaccount.com
Updated IAM policy for project [statmike-mlops-349915].
bindings:
- members:
  - serviceAccount:service-1026793852137@gcp-sa-aiplatform-cc.iam.gserviceaccount.com
  role: roles/aiplatform.customCodeServiceAgent
- members:
  - serviceAccount:service-1026793852137@gcp-sa-aiplatform-vm.iam.gserviceaccount.com
  role: roles/aiplatform.notebookServiceAgent
- members:
  - serviceAccount:service-1026793852137@gcp-sa-aiplatform.iam.gserviceaccount.com
  role: roles/aiplatform.serviceAgent
- members:
  - deleted:serviceAccount:bqcx-1026793852137-79ue@gcp-sa-bigquery-condel.iam.gserviceaccount.com?uid=108216671037418333398
  - deleted:serviceAccount:bqcx-1026793852137-a2ne@gcp-sa-bigquery-condel.iam.gserviceaccount.com?uid=113722269076525797130
  - deleted:serviceAccount:bqcx-1026793852137-iszu@gcp-sa-bigquery-condel.iam.gserviceaccount.com?uid=106642351460101305872
  - serviceAcco

**NOTE**: The step above created a service account and assigned it the Vertex AI User role.  The role assignment may take a moment to be recognized in the steps below.  If a cell below returns an error, try rerunning it.
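
The rerun advice in the note above can be automated with a small retry wrapper. The sketch below is not part of the original notebook: `retry` and `flaky_step` are hypothetical names, and `flaky_step` only simulates a call that fails until the permission becomes visible. The attempt count and delay are arbitrary choices.

```python
import time


def retry(fn, attempts=5, delay_seconds=10.0):
    """Call fn() until it succeeds or attempts run out; re-raise the last error."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception as error:
            if attempt == attempts:
                raise
            print(f'attempt {attempt} failed ({error}); retrying in {delay_seconds}s')
            time.sleep(delay_seconds)


# stand-in for a cell that fails until the IAM binding has propagated
calls = {'count': 0}


def flaky_step():
    calls['count'] += 1
    if calls['count'] < 3:
        raise PermissionError('permission not yet visible')
    return 'ok'


result = retry(flaky_step, attempts=5, delay_seconds=0.01)
print(result, calls['count'])
```

In the notebook, the wrapped function would be whichever call first needs the new permission, for example the cell that creates the remote model.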

### Create The Remote Model In BigQuery

Create a temporary model that connects to a text embedding model on Vertex AI - [Reference](https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.ml.llm.PaLM2TextEmbeddingGenerator)

In [153]:
embed_model = bfml.llm.PaLM2TextEmbeddingGenerator(
    session = bf_session,
    connection_name = f'{BQ_PROJECT}.{BQ_REGION}.{SERIES}_{EXPERIMENT}'
)

---
## Create Embeddings

### For Product Descriptions: Name

**NOTE**: The following cell will create embedding requests for all 29k+ values in the `name` column and could take around **10 minutes** to run.

In [156]:
products = products.join(embed_model.predict(products['name']).rename(columns={'text_embedding':'name_embedding'}))

In [157]:
products.head()

| | id | cost | category | name | brand | retail_price | department | sku | distribution_center_id | name_embedding |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 27569 | 92.652563 | Swim | 2XU Men's Swimmers Compression Long Sleeve Top | 2XU | 150.410004 | Men | B23C5765E165D83AA924FA8F13C05F25 | 1 | [0.021811574697494507, -0.0068705384619534016,... |
| 1 | 27445 | 24.719661 | Swim | TYR Sport Men's Square Leg Short Swim Suit | TYR | 38.990002 | Men | 2AB7D3B23574C3DEA2BD278AFD0939AB | 1 | [0.04419781640172005, -0.009351101703941822, 0... |
| 2 | 27457 | 15.8976 | Swim | TYR Sport Men's Solid Durafast Jammer Swim Suit | TYR | 27.6 | Men | 8F831227B0EB6C6D09A0555531365933 | 1 | [0.0471641980111599, -0.03273119032382965, 0.0... |
| 3 | 27466 | 17.85 | Swim | TYR Sport Men's Swim Short/Resistance Short Sw... | TYR | 30.0 | Men | 67317D6DCC4CB778AEB9219565F5456B | 1 | [0.049136240035295486, 0.0037870346568524837, ... |
| 4 | 27481 | 29.408001 | Swim | TYR Alliance Team Splice Jammer | TYR | 45.950001 | Men | 213C888198806EF1A0E2BBF2F4855C6C | 1 | [0.0008693744894117117, -0.00447087874636054, ... |


### For Level 1 of Product Hierarchy: Department

This step runs quickly because it creates embedding requests only for the unique values of `department`.

In [158]:
department = products['department'].unique().to_frame()
department = department.join(embed_model.predict(department).rename(columns={'text_embedding':'department_embedding'}))

In [159]:
department.head()

| | department | department_embedding |
|---|---|---|
| 0 | Men | [-0.04335380345582962, 0.00046764410217292607,... |
| 13131 | Women | [-0.03191345930099487, -0.006457726005464792, ... |


In [160]:
products = products.merge(department, on = 'department')
products.dtypes

id                                  Int64
cost                              Float64
category                  string[pyarrow]
name                      string[pyarrow]
brand                     string[pyarrow]
retail_price                      Float64
department                string[pyarrow]
sku                       string[pyarrow]
distribution_center_id              Int64
name_embedding                     object
department_embedding               object
dtype: object
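
The pattern in the department step above is to embed each distinct label once and then merge the result back onto every row, rather than requesting an embedding per row. The plain-Python sketch below illustrates only that idea. `fake_embed` is a hypothetical stand-in for the real embedding call, and the rows are made up.

```python
def fake_embed(text: str) -> list[float]:
    """Stand-in for a real embedding call; returns a tiny deterministic vector."""
    return [float(len(text)), float(sum(map(ord, text)) % 7)]


# made-up rows standing in for the products table
rows = [
    {'name': 'Swim Short', 'department': 'Men'},
    {'name': 'Jammer', 'department': 'Men'},
    {'name': 'Skirt', 'department': 'Women'},
]

# embed each distinct label exactly once
embed_calls = 0
label_embeddings = {}
for label in sorted({row['department'] for row in rows}):
    label_embeddings[label] = fake_embed(label)
    embed_calls += 1

# attach the shared embedding to every row (the equivalent of the merge)
for row in rows:
    row['department_embedding'] = label_embeddings[row['department']]

print(embed_calls)  # 2 calls for 3 rows
```

For the product data this means two embedding requests for `department` instead of one per product row.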

### For Level 2 of Product Hierarchy: Category

This step runs quickly because it creates embedding requests only for the unique values of `category`.

In [161]:
category = products['category'].unique().to_frame()
category = category.join(embed_model.predict(category).rename(columns={'text_embedding':'category_embedding'}))

In [162]:
category.head()

| | category | category_embedding |
|---|---|---|
| 0 | Swim | [0.01972227729856968, -0.01661798171699047, -0... |
| 906 | Jeans | [-0.016635224223136902, 0.0025853165425360203,... |
| 2023 | Pants | [-7.701403956161812e-05, 0.015807654708623886,... |
| 3064 | Socks | [0.05488337576389313, 0.006218045484274626, 0.... |
| 3969 | Active | [0.028275124728679657, -0.00869790930300951, 0... |


In [163]:
products = products.merge(category, on = 'category')
products.dtypes

id                                  Int64
cost                              Float64
category                  string[pyarrow]
name                      string[pyarrow]
brand                     string[pyarrow]
retail_price                      Float64
department                string[pyarrow]
sku                       string[pyarrow]
distribution_center_id              Int64
name_embedding                     object
department_embedding               object
category_embedding                 object
dtype: object

### Make BigQuery Tables of Results

The `products`, `department`, and `category` dataframes are currently temporary tables in BigQuery.  To reuse them later, store them as permanent BigQuery tables using the [.to_gbq](https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.dataframe.DataFrame#bigframes_dataframe_DataFrame_to_gbq) method as follows.

In [210]:
products.to_gbq(f'{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE}_products', if_exists = 'replace', index = False)
department.to_gbq(f'{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE}_department', if_exists = 'replace', index = False)
category.to_gbq(f'{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE}_category', if_exists = 'replace', index = False)

---
## Prepare Data For ML

### Create product hierarchy set: department

In [217]:
department.dtypes

department              string[pyarrow]
department_embedding             object
dtype: object

In [215]:
department_hierarchy = department[['department', 'department_embedding']].rename(columns = {"department":"hierarchy_node", "department_embedding":"hierarchy_node_embedding"})
department_hierarchy['hierarchy_level'] = 'department'
department_hierarchy.dtypes

hierarchy_node              string[pyarrow]
hierarchy_node_embedding             object
hierarchy_level             string[pyarrow]
dtype: object

In [216]:
department_hierarchy.head()

| | hierarchy_node | hierarchy_node_embedding | hierarchy_level |
|---|---|---|---|
| 0 | Men | [-0.04335380345582962, 0.00046764410217292607,... | department |
| 13131 | Women | [-0.03191345930099487, -0.006457726005464792, ... | department |


### Create product hierarchy set: category

In [218]:
category.dtypes

category              string[pyarrow]
category_embedding             object
dtype: object

In [220]:
category_hierarchy = category[['category', 'category_embedding']].rename(columns = {"category":"hierarchy_node", "category_embedding":"hierarchy_node_embedding"})
category_hierarchy['hierarchy_level'] = 'category'
category_hierarchy.dtypes

hierarchy_node              string[pyarrow]
hierarchy_node_embedding             object
hierarchy_level             string[pyarrow]
dtype: object

In [221]:
category_hierarchy.head()

| | hierarchy_node | hierarchy_node_embedding | hierarchy_level |
|---|---|---|---|
| 0 | Swim | [0.01972227729856968, -0.01661798171699047, -0... | category |
| 906 | Jeans | [-0.016635224223136902, 0.0025853165425360203,... | category |
| 2023 | Pants | [-7.701403956161812e-05, 0.015807654708623886,... | category |
| 3064 | Socks | [0.05488337576389313, 0.006218045484274626, 0.... | category |
| 3969 | Active | [0.028275124728679657, -0.00869790930300951, 0... | category |


### Create product hierarchy: combine department and category

In [223]:
product_hierarchy = bf.concat([department_hierarchy, category_hierarchy])

In [225]:
product_hierarchy.head()

| | hierarchy_node | hierarchy_node_embedding | hierarchy_level |
|---|---|---|---|
| 0 | Men | [-0.04335380345582962, 0.00046764410217292607,... | department |
| 13131 | Women | [-0.03191345930099487, -0.006457726005464792, ... | department |
| 0 | Swim | [0.01972227729856968, -0.01661798171699047, -0... | category |
| 906 | Jeans | [-0.016635224223136902, 0.0025853165425360203,... | category |
| 2023 | Pants | [-7.701403956161812e-05, 0.015807654708623886,... | category |
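
The combined output above carries the index label `0` twice, once from each source frame. The sketch below reproduces the same rename, tag, and concatenate steps locally with plain pandas and made-up two-element vectors, as a stand-in for the `bigframes` dataframes. The helper `to_hierarchy` is hypothetical, and `ignore_index=True` is one optional way to get a fresh, unique index in pandas.

```python
import pandas as pd

# made-up vectors standing in for real embeddings
department = pd.DataFrame({
    'department': ['Men', 'Women'],
    'department_embedding': [[0.1, 0.2], [0.3, 0.4]],
})
category = pd.DataFrame({
    'category': ['Swim', 'Jeans', 'Pants'],
    'category_embedding': [[0.5, 0.6], [0.7, 0.8], [0.9, 1.0]],
})


def to_hierarchy(frame: pd.DataFrame, level: str) -> pd.DataFrame:
    """Rename a level's columns to the shared schema and tag the level name."""
    out = frame.rename(columns={
        level: 'hierarchy_node',
        f'{level}_embedding': 'hierarchy_node_embedding',
    })
    out['hierarchy_level'] = level
    return out


product_hierarchy = pd.concat(
    [to_hierarchy(department, 'department'), to_hierarchy(category, 'category')],
    ignore_index=True,  # renumber rows so index labels are unique
)
print(product_hierarchy[['hierarchy_node', 'hierarchy_level']])
```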


### Create test and train subsets of products

Create indexes for the rows allocated to the training and test splits:

In [211]:
# retrieve index of all rows
full_index = products.index.to_numpy()
# randomly sort the full index (in place)
np.random.shuffle(full_index)
# split the shuffled index into 10 sequential parts of near-equal size;
# array_split also works when the row count is not divisible by 10
split_index = np.array_split(full_index, 10)
# allocate the first 9 splits (90%) to a training index
train_index = np.concatenate(split_index[0:9])
# allocate the last split (10%) to a test index
test_index = split_index[9]

# print out the sizes of the indexes:
full_index.shape[0], train_index.shape[0], test_index.shape[0]

(29120, 26208, 2912)
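
The cell above shuffles without a fixed seed, so each run yields a different split. If a repeatable split is wanted, one option is a seeded NumPy generator, sketched below on a plain range of positions rather than the real index. The seed value is arbitrary, and `np.array_split` is used because it tolerates row counts that are not divisible by 10.

```python
import numpy as np

n_rows = 29120  # row count reported in the output above
rng = np.random.default_rng(seed=42)  # arbitrary seed, chosen for repeatability

# shuffled row positions 0..n_rows-1
shuffled = rng.permutation(n_rows)

# 10 near-equal parts: first 9 for training, last one for test
parts = np.array_split(shuffled, 10)
train_index = np.concatenate(parts[0:9])
test_index = parts[9]

print(n_rows, train_index.shape[0], test_index.shape[0])

# array_split with a size that 10 does not divide evenly
uneven_sizes = [len(part) for part in np.array_split(np.arange(23), 10)]
print(uneven_sizes)
```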

In [212]:
train_products = products[['name_embedding', 'category', 'department']].iloc[train_index.tolist()]
test_products = products[['name_embedding', 'category', 'department']].iloc[test_index.tolist()]

In [151]:
# optional cleanup: delete the BigQuery connection created earlier in this notebook
bq_connection.ConnectionServiceClient().delete_connection(name = f"projects/{BQ_PROJECT}/locations/{BQ_REGION}/connections/{SERIES}_{EXPERIMENT}")