<!--- header table --->
<table align="left">
  <td style="text-align: center">
    <a href="https://github.com/statmike/vertex-ai-mlops/blob/main/Applied%20ML/Solution%20Prototypes/document-processing/4a-document-classification.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo">
      <br>View on<br>GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/statmike/vertex-ai-mlops/blob/main/Applied%20ML/Solution%20Prototypes/document-processing/4a-document-classification.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo">
      <br>Run in<br>Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https%3A%2F%2Fraw.githubusercontent.com%2Fstatmike%2Fvertex-ai-mlops%2Fmain%2FApplied%2520ML%2FSolution%2520Prototypes%2Fdocument-processing%2F4a-document-classification.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo">
      <br>Run in<br>Colab Enterprise
    </a>
  </td>      
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/bigquery/import?url=https://github.com/statmike/vertex-ai-mlops/blob/main/Applied%20ML/Solution%20Prototypes/document-processing/4a-document-classification.ipynb">
      <img width="32px" src="https://www.gstatic.com/images/branding/gcpiconscolors/bigquery/v1/32px.svg" alt="BigQuery logo">
      <br>Open in<br>BigQuery Studio
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/statmike/vertex-ai-mlops/main/Applied%20ML/Solution%20Prototypes/document-processing/4a-document-classification.ipynb">
      <img width="32px" src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo">
      <br>Open in<br>Vertex AI Workbench
    </a>
  </td>
</table>

# Document Classification: More than Similarity With Embeddings

> This workflow is part of a series of workflows for the solution prototype: [Document Processing With Generative AI: Parse, Extract, Validate Authenticity, and More](./readme.md)

In the previous workflow ([4-document-similarity](./4-document-similarity.ipynb)) we used the distance between embeddings as a way of understanding document similarity.  In this workflow the embeddings are used to train a predictive model that classifies each document as belonging to one of the vendors.  This model can then be used to predict which vendor a new document most likely comes from.  For a more direct approach that classifies with the embeddings themselves, check out the next workflow in this series: [5-document-anomalies](./5-document-anomalies.ipynb).

**References:**
- BigQuery ML [model journeys](https://cloud.google.com/bigquery/docs/e2e-journey)
  - BigQuery ML [random forest model type](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-random-forest)
  - An end-to-end workflow for [BQML Random Forest for Classification](../../../03%20-%20BigQuery%20ML%20%28BQML%29/03c%20-%20BQML%20Random%20Forest.ipynb)

## Setup

Note that this notebook expects to use a local virtual environment with the `./requirements.txt` installed.  

A potential workaround if using this notebook standalone is running:

>```python
>pip install -r requirements.txt
>```

And then restart the kernel.

In [None]:
# package imports for this work
import os, subprocess

from google.cloud import bigquery

In [2]:
# what project are we working in?
PROJECT_ID = subprocess.run(['gcloud', 'config', 'get-value', 'project'], capture_output=True, text=True, check=True).stdout.strip()
PROJECT_ID

'statmike-mlops-349915'

In [3]:
LOCATION = 'us-central1'

SERIES = 'applied-ml-solution-prototypes'
EXPERIMENT = 'document-processing'

In [4]:
# setup google cloud bigquery client
bq = bigquery.Client(project = PROJECT_ID)

# load the bigquery magics for jupyter with:
%load_ext bigquery_magics

---
## Review The Data Source

So far this project has:
- Built a custom data extractor with Document AI
- Created an object table in BigQuery that maps to documents stored in GCS
- Created a new table by processing the documents in the object table with the `ML.PROCESS_DOCUMENTS` function that uses the custom parser built with Document AI
  - Augmented the table with generated image embeddings for the documents using Vertex AI hosted embeddings models with the `ML.GENERATE_EMBEDDING` function

Before we proceed to work with these embeddings, let's review the data so far:


In [5]:
%%bigquery
SELECT *
FROM `statmike-mlops-349915.solution_prototype_document_processing.known_authenticity`
LIMIT 5

|  | ml_process_document_result | ml_process_document_status | vendor_name | vendor_address | company_name | company_address | invoice_id | invoice_total | line_item | uri | updated | vendor | embedding |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | {"entities":[{"confidence":1,"id":"0","mention... |  |  |  | BioTech Innovations Corp | 666 Genome Way\nSan Diego, CA 92121 | KD-2024-0315 | \$37108.50 | [{'item_sku': 'WEB- DEV- 001', 'item_descripti... | gs://statmike-mlops-349915/applied-ml-solution... | 2025-04-23 20:53:37.673000+00:00 | vendor_2 | [0.0089465566, 0.0536203124, 0.0614382476, -0.... |
| 1 | {"entities":[{"confidence":1,"id":"0","mention... |  |  |  | HealthAI Innovations | 123 Main Street\nSan Francisco, CA 94111 | INV-2024-0315 | \$21924.00 | [{'item_sku': 'CSD- 001', 'item_description': ... | gs://statmike-mlops-349915/applied-ml-solution... | 2025-04-23 20:53:35.765000+00:00 | vendor_2 | [0.0121024447, 0.0618715286, 0.0657297298, -0.... |
| 2 | {"entities":[{"confidence":1,"id":"1","propert... |  |  |  | GlobalMed Health | 123 Serene Drive\nSan Diego, CA 92101 | INV-2024-1122 | \$23600.00 | [{'item_sku': 'WEB- DEV- 001', 'item_descripti... | gs://statmike-mlops-349915/applied-ml-solution... | 2025-04-23 20:53:35.038000+00:00 | vendor_2 | [0.0141071267, 0.0521478951, 0.0623387806, -0.... |
| 3 | {"entities":[{"confidence":1,"id":"0","mention... |  |  |  | Swift Logistics Solutions | 987 Elm Street\nDallas, TX 75201 | KD-2024-0722 | \$19920.00 | [{'item_sku': 'WEB- DEV- 001', 'item_descripti... | gs://statmike-mlops-349915/applied-ml-solution... | 2025-04-23 20:53:39.182000+00:00 | vendor_2 | [0.0090936739, 0.0496088825, 0.0622638054, -0.... |
| 4 | {"entities":[{"confidence":1,"id":"1","propert... |  |  |  | Style Forward Retail | 99 Fashion Blvd Los Angeles, CA 90015 | INV-2024-1105 | \$34800.00 | [{'item_sku': None, 'item_description': 'Web D... | gs://statmike-mlops-349915/applied-ml-solution... | 2025-04-23 20:52:52.221000+00:00 | vendor_12 | [0.0019979002, 0.0560282543, 0.0625561327, -0.... |


In [6]:
%%bigquery
# document count per vendor:
SELECT vendor, count(*) as document_count
FROM `statmike-mlops-349915.solution_prototype_document_processing.known_authenticity`
GROUP BY vendor

|  | vendor | document_count |
|---|---|---|
| 0 | vendor_2 | 16 |
| 1 | vendor_12 | 22 |
| 2 | vendor_7 | 29 |
| 3 | vendor_0 | 19 |
| 4 | vendor_10 | 22 |
| 5 | vendor_13 | 19 |
| 6 | vendor_5 | 16 |
| 7 | vendor_6 | 19 |
| 8 | vendor_9 | 23 |
| 9 | vendor_3 | 19 |


In [7]:
%%bigquery temp
# review an embedding:
SELECT embedding
FROM `statmike-mlops-349915.solution_prototype_document_processing.known_authenticity`
LIMIT 1

In [8]:
temp['embedding'][0]

array([ 0.00894656,  0.05362031,  0.06143825, ...,  0.01621641,
       -0.03250906, -0.01569333], shape=(1408,))

---
## Train Model

Use BigQuery ML to train a multiclass random forest model, which BigQuery ML builds with XGBoost:
- [Random Forest](https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-create-random-forest) with BigQuery ML (BQML)
    
This example includes, commented out at the end of the `OPTIONS` list, the [training options](https://cloud.google.com/bigquery-ml/docs/create_vertex) to register the resulting model in the [Vertex AI Model Registry](https://cloud.google.com/vertex-ai/docs/model-registry/introduction).  Uncomment them and fill in the placeholders to register the model.
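The registry options in the training statement use placeholders (`{SERIES}`, `{EXPERIMENT}`, `{RUN_NAME}`) that are not substituted by the `%%bigquery` cell as written. A minimal sketch of how they could be rendered from the notebook's variables follows. The helper name `registry_options` and the `RUN_NAME` value are hypothetical, and the sketch does not check whether the resulting ID satisfies Vertex AI naming rules.

```python
def registry_options(series: str, experiment: str, run_name: str) -> str:
    """Render the three Vertex AI Model Registry options as SQL option text."""
    lines = [
        "MODEL_REGISTRY = 'VERTEX_AI'",
        f"VERTEX_AI_MODEL_ID = 'bqml_{series}_{experiment}'",
        f"VERTEX_AI_MODEL_VERSION_ALIASES = ['{run_name}']",
    ]
    # Comma-separated so the text can be appended inside OPTIONS(...)
    return ",\n".join(lines)


SERIES = 'applied-ml-solution-prototypes'
EXPERIMENT = 'document-processing'
RUN_NAME = 'run-example'  # hypothetical value; the notebook does not define RUN_NAME

options_sql = registry_options(SERIES, EXPERIMENT, RUN_NAME)
print(options_sql)
```

The rendered text could then be spliced into the `OPTIONS(...)` list of a query string passed to `bq.query(...)`, which is one way to avoid hard-coding values in the SQL cell.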

In [None]:
%%bigquery
CREATE OR REPLACE MODEL `statmike-mlops-349915.solution_prototype_document_processing.rf_classifier`
    OPTIONS(
        model_type = 'RANDOM_FOREST_CLASSIFIER',
        auto_class_weights = TRUE,
        input_label_cols = ['vendor'],
        enable_global_explain = TRUE,
        data_split_method = 'AUTO_SPLIT',
        num_parallel_tree = 200,
        early_stop = TRUE,
        min_rel_progress = 0.01,
        tree_method = 'HIST',
        subsample = 0.85,
        colsample_bytree = 0.9
        #MODEL_REGISTRY = 'VERTEX_AI',
        #VERTEX_AI_MODEL_ID = 'bqml_{SERIES}_{EXPERIMENT}',
        #VERTEX_AI_MODEL_VERSION_ALIASES = ['{RUN_NAME}']
    )
    AS
    SELECT vendor, embedding
    FROM `statmike-mlops-349915.solution_prototype_document_processing.known_authenticity`
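The `auto_class_weights = TRUE` option asks BigQuery ML to balance the classes by weighting each label in inverse proportion to its frequency. The exact formula is not shown in this notebook, so the sketch below uses a common "balanced" heuristic, `total / (n_classes * count)`, as an assumption. It also uses only the ten vendor counts displayed earlier, while the model has more labels than that.

```python
# Document counts per vendor as displayed in the data review (ten rows shown).
counts = {
    "vendor_2": 16, "vendor_12": 22, "vendor_7": 29, "vendor_0": 19,
    "vendor_10": 22, "vendor_13": 19, "vendor_5": 16, "vendor_6": 19,
    "vendor_9": 23, "vendor_3": 19,
}


def balanced_weights(counts: dict) -> dict:
    """Assumed heuristic: weight = total / (n_classes * class_count)."""
    total = sum(counts.values())
    n_classes = len(counts)
    return {label: total / (n_classes * c) for label, c in counts.items()}


weights = balanced_weights(counts)
for label, w in sorted(weights.items(), key=lambda kv: kv[1]):
    print(f"{label}: count={counts[label]}, weight={w:.3f}")
```

Under this heuristic the rarest vendors receive the largest weights, so each class contributes the same total weight to training.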

### Feature and Training Review

Review the Model Inputs, the feature information, with [ML.FEATURE_INFO](https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-feature):

In [13]:
%%bigquery
SELECT *
FROM ML.FEATURE_INFO(MODEL `statmike-mlops-349915.solution_prototype_document_processing.rf_classifier`)

|  | input | min | max | mean | median | stddev | category_count | null_count | dimension |
|---|---|---|---|---|---|---|---|---|---|
| 0 | embedding |  |  |  |  |  |  | 0 |  |


Review the iterations from training with [ML.TRAINING_INFO](https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-train):

In [None]:
%%bigquery
SELECT *
FROM ML.TRAINING_INFO(MODEL `statmike-mlops-349915.solution_prototype_document_processing.rf_classifier`)
ORDER BY iteration

|  | training_run | iteration | loss | eval_loss | learning_rate | duration_ms |
|---|---|---|---|---|---|---|
| 0 | 0 | 1 | 0.096077 | 0.092539 | 1.0 | 673222 |


---
## Evaluate Model

### Metrics

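Review the model evaluation statistics with [ML.EVALUATE](https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-evaluate).  Called without input data, it reports the metrics computed during training.  With `AUTO_SPLIT` and fewer than 500 input rows, BigQuery ML uses all rows for training, so the perfect scores in the first result below appear to describe the training documents rather than held-out data.  The second query evaluates the model on the separate `unknown_authenticity` table: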

In [16]:
%%bigquery
SELECT *
FROM ML.EVALUATE (MODEL `statmike-mlops-349915.solution_prototype_document_processing.rf_classifier`)


|  | precision | recall | accuracy | f1_score | log_loss | roc_auc |
|---|---|---|---|---|---|---|
| 0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.118445 | 1.0 |


In [21]:
%%bigquery
SELECT *
FROM ML.EVALUATE (
    MODEL `statmike-mlops-349915.solution_prototype_document_processing.rf_classifier`,
    (SELECT * FROM `statmike-mlops-349915.solution_prototype_document_processing.unknown_authenticity`)
)

|  | precision | recall | accuracy | f1_score | log_loss | roc_auc |
|---|---|---|---|---|---|---|
| 0 | 0.688929 | 0.655433 | 0.658863 | 0.652755 | 1.447006 | 0.926528 |


### Confusion Matrix

Review the confusion matrix for each split with [ML.CONFUSION_MATRIX](https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-confusion):

In [17]:
%%bigquery
SELECT *
FROM ML.CONFUSION_MATRIX(MODEL `statmike-mlops-349915.solution_prototype_document_processing.rf_classifier`)

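|  | expected_label | vendor_0 | vendor_1 | vendor_10 | vendor_11 | vendor_12 | vendor_13 | vendor_14 | vendor_2 | vendor_3 | vendor_4 | vendor_5 | vendor_6 | vendor_7 | vendor_8 | vendor_9 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | vendor_0 | 19 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | vendor_1 | 0 | 24 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | vendor_10 | 0 | 0 | 22 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | vendor_11 | 0 | 0 | 0 | 21 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | vendor_12 | 0 | 0 | 0 | 0 | 22 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 5 | vendor_13 | 0 | 0 | 0 | 0 | 0 | 19 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 6 | vendor_14 | 0 | 0 | 0 | 0 | 0 | 0 | 18 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 7 | vendor_2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 16 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 8 | vendor_3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 19 | 0 | 0 | 0 | 0 | 0 | 0 |
| 9 | vendor_4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 14 | 0 | 0 | 0 | 0 | 0 |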


In [20]:
%%bigquery
SELECT *
FROM ML.CONFUSION_MATRIX(
    MODEL `statmike-mlops-349915.solution_prototype_document_processing.rf_classifier`,
    (SELECT * FROM `statmike-mlops-349915.solution_prototype_document_processing.unknown_authenticity`)
)

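|  | expected_label | vendor_0 | vendor_1 | vendor_10 | vendor_11 | vendor_12 | vendor_13 | vendor_14 | vendor_2 | vendor_3 | vendor_4 | vendor_5 | vendor_6 | vendor_7 | vendor_8 | vendor_9 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | vendor_0 | 11 | 0 | 0 | 1 | 0 | 4 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 2 | 0 |
| 1 | vendor_1 | 0 | 21 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 3 |
| 2 | vendor_10 | 0 | 0 | 15 | 0 | 2 | 2 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 1 | 0 |
| 3 | vendor_11 | 0 | 1 | 2 | 14 | 2 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
| 4 | vendor_12 | 0 | 0 | 0 | 0 | 10 | 0 | 0 | 0 | 0 | 1 | 2 | 0 | 0 | 8 | 1 |
| 5 | vendor_13 | 0 | 0 | 1 | 1 | 0 | 17 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 6 | vendor_14 | 0 | 2 | 0 | 1 | 2 | 1 | 5 | 0 | 1 | 0 | 1 | 0 | 0 | 5 | 0 |
| 7 | vendor_2 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 14 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
| 8 | vendor_3 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 2 | 12 | 1 | 0 | 0 | 0 | 0 | 2 |
| 9 | vendor_4 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 2 | 8 | 0 | 0 | 0 | 2 | 0 |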
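Each row of the confusion matrix holds the predictions for one expected vendor, so per-vendor recall is the diagonal entry divided by the row total. The sketch below computes it for the ten rows displayed for the `unknown_authenticity` documents. Per-vendor precision is left out because it needs every row of the matrix, and only ten are shown here.

```python
# Column order of the confusion matrix, copied from its header.
labels = [
    "vendor_0", "vendor_1", "vendor_10", "vendor_11", "vendor_12",
    "vendor_13", "vendor_14", "vendor_2", "vendor_3", "vendor_4",
    "vendor_5", "vendor_6", "vendor_7", "vendor_8", "vendor_9",
]

# The ten displayed rows: expected label -> predicted counts per column.
rows = {
    "vendor_0":  [11, 0, 0, 1, 0, 4, 0, 0, 1, 0, 0, 0, 0, 2, 0],
    "vendor_1":  [0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
    "vendor_10": [0, 0, 15, 0, 2, 2, 0, 0, 0, 2, 0, 0, 0, 1, 0],
    "vendor_11": [0, 1, 2, 14, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
    "vendor_12": [0, 0, 0, 0, 10, 0, 0, 0, 0, 1, 2, 0, 0, 8, 1],
    "vendor_13": [0, 0, 1, 1, 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    "vendor_14": [0, 2, 0, 1, 2, 1, 5, 0, 1, 0, 1, 0, 0, 5, 0],
    "vendor_2":  [0, 0, 1, 0, 0, 0, 0, 14, 0, 0, 0, 1, 0, 0, 0],
    "vendor_3":  [1, 0, 0, 1, 0, 0, 0, 2, 12, 1, 0, 0, 0, 0, 2],
    "vendor_4":  [0, 0, 0, 2, 0, 0, 0, 0, 2, 8, 0, 0, 0, 2, 0],
}


def recall_per_class(labels, rows):
    """Recall = correct predictions / all documents with that expected label."""
    recalls = {}
    for expected, predicted_counts in rows.items():
        total = sum(predicted_counts)
        correct = predicted_counts[labels.index(expected)]
        recalls[expected] = correct / total if total else float("nan")
    return recalls


recalls = recall_per_class(labels, rows)
for vendor, r in sorted(recalls.items(), key=lambda kv: kv[1]):
    print(f"{vendor}: recall={r:.3f}")
```

Sorting by recall surfaces the vendors the model struggles with most. In the displayed rows, `vendor_14` and `vendor_12` have the lowest recall, and both are frequently predicted as `vendor_8`.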


---
## Predictions With BigQuery ML (BQML)

Retrieve predictions for five documents from the `unknown_authenticity` table into a pandas dataframe using [ML.PREDICT](https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-predict):

In [22]:
%%bigquery preds
SELECT *
FROM ML.PREDICT(
  MODEL `statmike-mlops-349915.solution_prototype_document_processing.rf_classifier`,
  (SELECT vendor, embedding FROM `statmike-mlops-349915.solution_prototype_document_processing.unknown_authenticity` LIMIT 5)
)

In [32]:
preds

|  | predicted_vendor | predicted_vendor_probs | vendor | embedding |
|---|---|---|---|---|
| 0 | vendor_1 | [{'label': 'vendor_9', 'prob': 0.0444977171719... | vendor_8 | [0.0160381682, 0.0116377119, 0.00194062886, 0.... |
| 1 | vendor_5 | [{'label': 'vendor_9', 'prob': 0.0584864690899... | vendor_5 | [0.00087831, 0.0551444367, 0.0510212779, -0.00... |
| 2 | vendor_1 | [{'label': 'vendor_9', 'prob': 0.0441318489611... | vendor_5 | [0.0131794149, 0.0133737465, 0.00720742205, -0... |
| 3 | vendor_10 | [{'label': 'vendor_9', 'prob': 0.0048541207797... | vendor_10 | [-0.00644508842, 0.0360948518, 0.0467085019, -... |
| 4 | vendor_10 | [{'label': 'vendor_9', 'prob': 0.0040431618690... | vendor_10 | [-0.0120130433, 0.0339877978, 0.0499950275, -0... |


In [33]:
preds['predicted_vendor'].iloc[0]

'vendor_1'

In [None]:
preds['predicted_vendor_probs'].iloc[0]

array([{'label': 'vendor_9', 'prob': 0.044497717171907425},
       {'label': 'vendor_8', 'prob': 0.019621562212705612},
       {'label': 'vendor_7', 'prob': 0.00821880903095007},
       {'label': 'vendor_6', 'prob': 0.017902441322803497},
       {'label': 'vendor_5', 'prob': 0.02178679034113884},
       {'label': 'vendor_4', 'prob': 0.008760645054280758},
       {'label': 'vendor_3', 'prob': 0.009960188530385494},
       {'label': 'vendor_2', 'prob': 0.010932879522442818},
       {'label': 'vendor_14', 'prob': 0.008463104255497456},
       {'label': 'vendor_13', 'prob': 0.009418832138180733},
       {'label': 'vendor_12', 'prob': 0.008814872242510319},
       {'label': 'vendor_11', 'prob': 0.00910558644682169},
       {'label': 'vendor_10', 'prob': 0.25243809819221497},
       {'label': 'vendor_1', 'prob': 0.5611459016799927},
       {'label': 'vendor_0', 'prob': 0.008932585828006268}], dtype=object)
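The `predicted_vendor_probs` value is a list of label and probability pairs, one per vendor. A small sketch of ranking them, using the values printed above, shows how far ahead the top label is and confirms that the probabilities sum to roughly one. The variable names here are illustrative only.

```python
# Probabilities for the first prediction, copied from the output above.
probs = [
    {'label': 'vendor_9', 'prob': 0.044497717171907425},
    {'label': 'vendor_8', 'prob': 0.019621562212705612},
    {'label': 'vendor_7', 'prob': 0.00821880903095007},
    {'label': 'vendor_6', 'prob': 0.017902441322803497},
    {'label': 'vendor_5', 'prob': 0.02178679034113884},
    {'label': 'vendor_4', 'prob': 0.008760645054280758},
    {'label': 'vendor_3', 'prob': 0.009960188530385494},
    {'label': 'vendor_2', 'prob': 0.010932879522442818},
    {'label': 'vendor_14', 'prob': 0.008463104255497456},
    {'label': 'vendor_13', 'prob': 0.009418832138180733},
    {'label': 'vendor_12', 'prob': 0.008814872242510319},
    {'label': 'vendor_11', 'prob': 0.00910558644682169},
    {'label': 'vendor_10', 'prob': 0.25243809819221497},
    {'label': 'vendor_1', 'prob': 0.5611459016799927},
    {'label': 'vendor_0', 'prob': 0.008932585828006268},
]

# Rank labels from most to least probable.
ranked = sorted(probs, key=lambda p: p['prob'], reverse=True)
top_label = ranked[0]['label']
margin = ranked[0]['prob'] - ranked[1]['prob']
total = sum(p['prob'] for p in probs)

print(f"top label: {top_label}, runner-up: {ranked[1]['label']}")
print(f"margin over runner-up: {margin:.3f}")
print(f"sum of probabilities: {total:.6f}")
```

For this document the actual vendor was `vendor_8`, which received a probability of only about 0.02, so a high top probability alone does not guarantee a correct prediction.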

In [28]:
%%bigquery explains
SELECT *
FROM ML.EXPLAIN_PREDICT(
  MODEL `statmike-mlops-349915.solution_prototype_document_processing.rf_classifier`,
  (SELECT vendor, embedding FROM `statmike-mlops-349915.solution_prototype_document_processing.unknown_authenticity` LIMIT 5)
)

In [29]:
explains

|  | predicted_vendor | probability | top_feature_attributions | baseline_prediction_value | prediction_value | approximation_error | vendor | embedding |
|---|---|---|---|---|---|---|---|---|
| 0 | vendor_1 | 0.561146 | [{'feature': 'embedding', 'attribution': 3.851... | 0.354058 | 4.205379 | 0.0 | vendor_8 | [0.0160381682, 0.0116377119, 0.00194062886, 0.... |
| 1 | vendor_5 | 0.151268 | [{'feature': 'embedding', 'attribution': 0.565... | 0.366685 | 0.932279 | 0.0 | vendor_5 | [0.00087831, 0.0551444367, 0.0510212779, -0.00... |
| 2 | vendor_1 | 0.560835 | [{'feature': 'embedding', 'attribution': 3.851... | 0.354058 | 4.205379 | 0.0 | vendor_5 | [0.0131794149, 0.0133737465, 0.00720742205, -0... |
| 3 | vendor_10 | 0.920909 | [{'feature': 'embedding', 'attribution': 4.870... | 0.357127 | 5.227555 | 0.0 | vendor_10 | [-0.00644508842, 0.0360948518, 0.0467085019, -... |
| 4 | vendor_10 | 0.916454 | [{'feature': 'embedding', 'attribution': 5.058... | 0.357127 | 5.415724 | 0.0 | vendor_10 | [-0.0120130433, 0.0339877978, 0.0499950275, -0... |


In [31]:
explains['top_feature_attributions'].iloc[0]

array([{'feature': 'embedding', 'attribution': 3.8513212900579674}],
      dtype=object)
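Because the model sees `embedding` as a single input, the explanation assigns the whole attribution to that one feature. In the output above, the baseline plus the attribution matches `prediction_value`, which is consistent with the additive property of Shapley-style attributions when `approximation_error` is 0. The sketch below checks this for the first row using the displayed (rounded) numbers. That `prediction_value` is a raw score rather than a probability is an assumption based on its being larger than 1.

```python
# Values for the first explained row, copied from the output above.
baseline_prediction_value = 0.354058   # rounded in the displayed table
prediction_value = 4.205379            # rounded in the displayed table
top_feature_attributions = [
    {'feature': 'embedding', 'attribution': 3.8513212900579674},
]

# Additivity check: baseline + sum of attributions should equal the prediction value.
reconstructed = baseline_prediction_value + sum(
    a['attribution'] for a in top_feature_attributions
)
difference = abs(reconstructed - prediction_value)

print(f"reconstructed: {reconstructed:.6f}")
print(f"reported:      {prediction_value:.6f}")
print(f"difference:    {difference:.2e}")
```

The small remaining difference comes from the rounding of the displayed table values, not from the explanation itself.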