<!--- header table --->
<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/statmike/vertex-ai-mlops/blob/main/MLOps/Model%20Evaluation/model-evaluation-classification-multi-label.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo">
      <br>Run in<br>Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https%3A%2F%2Fraw.githubusercontent.com%2Fstatmike%2Fvertex-ai-mlops%2Fmain%2FMLOps%2FModel%2520Evaluation%2Fmodel-evaluation-classification-multi-label.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo">
      <br>Run in<br>Colab Enterprise
    </a>
  </td>      
  <td style="text-align: center">
    <a href="https://github.com/statmike/vertex-ai-mlops/blob/main/MLOps/Model%20Evaluation/model-evaluation-classification-multi-label.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo">
      <br>View on<br>GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/statmike/vertex-ai-mlops/main/MLOps/Model%20Evaluation/model-evaluation-classification-multi-label.ipynb">
      <img width="32px" src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo">
      <br>Open in<br>Vertex AI Workbench
    </a>
  </td>
</table>

# Evaluations For Multi-Label Classification Models

This workflow, part of our [MLOps](../readme.md) series that dives into [Model Evaluation](./readme.md), provides a comprehensive, end-to-end example of building and evaluating a machine learning model directly within Google Cloud's Vertex AI ecosystem.

Specifically, this guide will walk you through how to:

* **Prepare Data at Scale:** Begin by exploring and efficiently preparing a large public dataset sourced from **BigQuery**, demonstrating how to leverage both powerful **SQL** queries and the interactive **BigFrames API** (pandas interface) for machine learning data readiness.
* **Build Robust Models with Scikit-learn:** Construct a complete **Scikit-learn model pipeline** that seamlessly integrates additional feature engineering steps with the training of a classification model.
* **Register Models in Vertex AI:** Master the crucial MLOps step of saving your trained model and **registering it as a version within the Vertex AI Model Registry** for centralized management and version control.
* **Generate & Upload Custom Metrics:** Delve into preparing **custom, detailed evaluation metrics** using **Scikit-learn**, and then strategically **load these results to your versioned model in the Vertex AI Model Registry**. This includes showcasing how to generate and upload metrics for specific **data slices** to gain granular performance insights.
* **Review & Retrieve Evaluations:** Conclude by learning how to easily **review and programmatically retrieve these comprehensive evaluation results** directly from the Model Registry UI and via the **Vertex AI SDK**, empowering you to effectively track and compare model performance over time.
* **Generate and Work With Slices:** Create metrics for slices of the data, in this case the individual tags (labels) of the multi-label model, and learn to generate and upload these to model evaluations in the Vertex AI Model Registry.

**The Model:**

This workflow uses Stack Overflow posts as its data source. It builds a model that predicts which tags apply to a post (e.g., 'python', 'pandas', 'data-science'). Because a post can carry several tags at once, this is a multi-label classifier.

---
## Colab Setup

To run this notebook in Colab, run the cells in this section. Otherwise, skip this section.

This cell will authenticate to GCP (follow prompts in the popup).

In [1]:
PROJECT_ID = 'statmike-mlops-349915' # replace with project ID

In [2]:
try:
    import google.colab
    from google.colab import auth
    auth.authenticate_user()
    !gcloud config set project {PROJECT_ID}
except Exception:
    pass

---
## Installs

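The list `packages` contains tuples of an import name, an install (distribution) name, and an optional minimum version. If the import name is not found, the install name is used to install the package quietly for the current user. If a minimum version is given and the installed version is older, the package is upgraded.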

In [4]:
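# tuples of (import name, install name[, min_version])
packages = [
    ('bigframes', 'bigframes'),
    ('sklearn', 'scikit-learn'),
    ('numpy', 'numpy'),
    ('shap', 'shap'),
    ('google.cloud.aiplatform', 'google-cloud-aiplatform'),
    ('google.cloud.storage', 'google-cloud-storage'),
    ('google.cloud.bigquery', 'google-cloud-bigquery'),
    ('bigquery_magics', 'bigquery-magics'),
    ('matplotlib', 'matplotlib'),
    ('pandas', 'pandas')
]

import importlib.util      # provides find_spec
import importlib.metadata  # looks up installed versions by distribution (install) name

def is_importable(import_name):
    # find_spec imports parent packages for dotted names and raises if a parent is missing
    try:
        return importlib.util.find_spec(import_name) is not None
    except ModuleNotFoundError:
        return False

install = False
for package in packages:
    if not is_importable(package[0]):
        print(f'installing package {package[1]}')
        install = True
        !pip install {package[1]} -U -q --user
    elif len(package) == 3:
        # compare parsed versions, not strings ('10.0' < '9.0' is True as strings)
        # assumes the 'packaging' library is available, as it is in most environments
        from packaging.version import Version
        if Version(importlib.metadata.version(package[1])) < Version(package[2]):
            print(f'updating package {package[1]}')
            install = True
            !pip install {package[1]} -U -q --user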

### API Enablement

In [5]:
!gcloud services enable aiplatform.googleapis.com

### Restart Kernel (If Installs Occurred)

After a kernel restart, continue running from the next cell after this one.

In [6]:
if install:
    import IPython
    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)
    IPython.display.display(IPython.display.Markdown("""<div class=\"alert alert-block alert-warning\">
        <b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. The previous cells do not need to be run again⚠️</b>
        </div>"""))

---
## Setup

inputs:

In [7]:
project = !gcloud config get-value project
PROJECT_ID = project[0]
PROJECT_ID

'statmike-mlops-349915'

In [8]:
REGION = 'us-central1'
SERIES = 'mlops'
EXPERIMENT = 'evaluation-classification-multi-label'

# Set the name of GCS Bucket to read/write to
GCS_BUCKET = PROJECT_ID

# Data source for this workflow
BQ_SOURCE = 'bigquery-public-data.stackoverflow.posts_questions'

# make this the BigQuery Project / Dataset / Table prefix to store results
BQ_PROJECT = PROJECT_ID
BQ_DATASET = SERIES.replace('-', '_')
BQ_TABLE = EXPERIMENT
BQ_REGION = REGION[0:2] # use a multi region

packages:

In [10]:
import os
import joblib
import sklearn.metrics
import sklearn.ensemble
import sklearn.pipeline
import sklearn.compose
import sklearn.preprocessing
import sklearn.feature_extraction.text
import sklearn.linear_model
import sklearn.multioutput
import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import bigframes.pandas as bpd
from google.cloud import aiplatform
from google.cloud import bigquery
from google.cloud import storage

clients:

In [11]:
# vertex ai clients
aiplatform.init(project = PROJECT_ID, location = REGION, experiment = SERIES+'-'+EXPERIMENT)

# gcs storage client
gcs = storage.Client(project = PROJECT_ID)
bucket = gcs.bucket(GCS_BUCKET)

# bigquery client
bq = bigquery.Client(project = PROJECT_ID)

# bigframes setup
bpd.options.bigquery.project = PROJECT_ID

# bigquery cell magics load
%load_ext bigquery_magics

Parameters:

In [12]:
DIR = f"files/{EXPERIMENT}"

Environment:

In [13]:
if not os.path.exists(DIR):
    os.makedirs(DIR)

---
## Review Data Source: Stack Overflow Posts

This tutorial uses a sample of data from the following data source.

The source table is a BigQuery Public Dataset table of Stack Overflow questions. It contains the title and body text of each post along with its associated tags.

There are 23,020,127 observations, including 2,116,212 from the year 2017, which are the ones processed for use in this workflow.

### Review BigQuery table:

Use the [BigQuery DataFrames (BigFrames)](https://cloud.google.com/bigquery/docs/use-bigquery-dataframes) package, set up above, to treat BigQuery tables like dataframes. This has the advantage of running the computation within BigQuery rather than pulling the data locally.

In [14]:
source_data = bpd.read_gbq(BQ_SOURCE)
source_data.head()

| | id | title | body | accepted_answer_id | answer_count | comment_count | community_owned_date | creation_date | favorite_count | last_activity_date | last_edit_date | last_editor_display_name | last_editor_user_id | owner_display_name | owner_user_id | parent_id | post_type_id | score | tags | view_count |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 70258907 | Issue when using host file in docker container... | `<p>My issue is the &quot;auto-encryption&quot;...` | 70422908.0 | 1 | 0 |  | 2021-12-07 10:56:44.280000+00:00 |  | 2021-12-20 14:38:06.993000+00:00 | 2021-12-07 12:44:57.153000+00:00 |  | 835636.0 |  | 835636 |  | 1 | 0 | docker\|apache-karaf\|apache-felix\|jasypt | 126 |
| 1 | 60254980 | `I have click event attached to elements <li>, ...` | `<p>I have click event attached to elements <co...` |  | 2 | 2 |  | 2020-02-17 01:53:31.527000+00:00 |  | 2020-02-17 08:31:50.267000+00:00 | 2020-02-17 08:31:50.267000+00:00 |  | 4370109.0 |  | 3230529 |  | 1 | -2 | javascript\|jquery\|jquery-events | 659 |
| 2 | 47012267 | Make a better grid of points with python | `<p>I can make simple grid of points in a recta...` |  | 1 | 2 |  | 2017-10-30 09:51:57.940000+00:00 | 1.0 | 2017-10-30 10:19:21.197000+00:00 |  |  |  |  | 8855484 |  | 1 | 1 | numpy\|grid\|boundary | 5159 |
| 3 | 67073082 | Jetpack Compose beta cannot go edge-to-edge du... | `<p>I have upgraded Compose for my app from <co...` |  | 1 | 5 |  | 2021-04-13 10:26:27.813000+00:00 |  | 2021-07-13 11:18:35.677000+00:00 | 2021-04-13 12:25:11.390000+00:00 |  | 975887.0 |  | 975887 |  | 1 | 0 | android\|android-jetpack-compose\|android-immersive | 1280 |
| 4 | 5520689 | Analyzing Recursive Algorithms | `<p>I've often been slightly stumped by recursi...` | 5657512.0 | 2 | 4 |  | 2011-04-02 02:19:08.940000+00:00 | 1.0 | 2011-04-14 01:14:42.780000+00:00 | 2011-04-02 02:58:21.543000+00:00 |  | 33833.0 |  | 33833 |  | 1 | 1 | algorithm\|analysis | 1232 |


In [15]:
source_data.info()

<class 'bigframes.dataframe.DataFrame'>
Index: 23020127 entries, 0 to 23020126
Data columns (total 20 columns):
  #  Column                    Dtype
---  ------------------------  ------------------------------
  0  id                        Int64
  1  title                     string
  2  body                      string
  3  accepted_answer_id        Int64
  4  answer_count              Int64
  5  comment_count             Int64
  6  community_owned_date      timestamp[us, tz=UTC][pyarrow]
  7  creation_date             timestamp[us, tz=UTC][pyarrow]
  8  favorite_count            Int64
  9  last_activity_date        timestamp[us, tz=UTC][pyarrow]
 10  last_edit_date            timestamp[us, tz=UTC][pyarrow]
 11  last_editor_display_name  string
 12  last_editor_user_id       Int64
 13  owner_display_name        string
 14  owner_user_id             Int64
 15  parent_id                 string
 16  post_type_id              Int64
 17  score                     Int64
 18  tags         

In [16]:
source_data.describe(include = 'all')

| | id | title | body | accepted_answer_id | answer_count | comment_count | community_owned_date | creation_date | favorite_count | last_activity_date | last_edit_date | last_editor_display_name | last_editor_user_id | owner_display_name | owner_user_id | parent_id | post_type_id | score | tags | view_count |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 23020127.0 | 23020127.0 | 23020127.0 | 11755280.0 | 23020127.0 | 23020127.0 | 11355.0 | 23020127.0 | 5029531.0 | 23020127.0 | 12559922.0 | 265452.0 | 12361577.0 | 563534.0 | 22562265.0 | 0.0 | 23020127.0 | 23020127.0 | 23020127.0 | 23020127.0 |
| nunique |  | 23006006.0 | 23017995.0 |  |  |  |  |  |  |  |  | 53951.0 |  | 130869.0 |  | 0.0 |  |  | 8448317.0 |  |
| mean | 38214250.492806 |  |  | 34495650.183957 | 1.478016 | 1.985216 |  |  | 2.824621 |  |  |  | 3487890.377558 |  | 4898520.717387 |  | 1.0 | 2.214833 |  | 2827.689736 |
| std | 21099662.230776 |  |  | 21071281.966936 | 1.453571 | 2.662199 |  |  | 21.154099 |  |  |  | 4029319.918362 |  | 4691874.94353 |  | 0.0 | 25.754884 |  | 24555.661991 |
| min | 4.0 |  |  | 7.0 | 0.0 | 0.0 |  |  | 0.0 |  |  |  | -1.0 |  | 1.0 |  | 1.0 | -146.0 |  | 1.0 |
| 25% | 20082502.0 |  |  | 15560941.0 | 1.0 | 0.0 |  |  | 1.0 |  |  |  | 472792.0 |  | 1149526.0 |  | 1.0 | 0.0 |  | 97.0 |
| 50% | 38244069.0 |  |  | 32850998.0 | 1.0 | 1.0 |  |  | 1.0 |  |  |  | 1839439.0 |  | 3276418.0 |  | 1.0 | 0.0 |  | 348.0 |
| 75% | 56483609.0 |  |  | 52647556.0 | 2.0 | 3.0 |  |  | 2.0 |  |  |  | 5050401.0 |  | 7347345.0 |  | 1.0 | 2.0 |  | 1247.0 |
| max | 73842327.0 |  |  | 73842204.0 | 518.0 | 108.0 |  |  | 11784.0 |  |  |  | 20080937.0 |  | 20081043.0 |  | 1.0 | 26621.0 |  | 11649204.0 |


### Understand The Data Source With `ML.DESCRIBE_DATA`

Reviewing a few records, like above, gives a good sense of how the data is arranged. Before applying machine learning techniques, it is important to understand more about these raw columns. Are they ready to use as features in a model, or is some form of feature engineering needed first? The distribution of values is an important starting point for answering this.

While SQL could be used to look at the distributions, doing so would be time-consuming and would require different techniques for different data types: numerical, string, boolean, date, and time columns, the ARRAY and STRUCT versions of these, and arrays of structs.

To make this process fast and simple, the [`ML.DESCRIBE_DATA`](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-describe-data) function returns a single row for each column that describes its data distribution. The query below sets two options:
- `top_k`: return the 3 most frequent values for non-numeric columns such as strings and timestamps (default = 1)
- `num_quantiles`: return 4 quantiles for numerical columns, reported as 5 boundary values from the minimum to the maximum (default = 2)
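
To make the `top_k` and `num_quantiles` settings concrete, here is a rough local analogue in pandas on hypothetical toy data. It is a sketch, not how `ML.DESCRIBE_DATA` is implemented: BigQuery may compute quantiles approximately at this scale, while pandas computes them exactly with linear interpolation. The helper names `describe_quantiles` and `describe_top_k` are illustrative only. Matching the `answer_count` row in the output below (`[0.0, 1.0, 1.0, 2.0, 518.0]`), `num_quantiles = 4` yields five boundaries: the minimum, three interior cut points, and the maximum.

```python
import pandas as pd

# Hypothetical toy columns standing in for a numeric and a string column
toy_answer_count = pd.Series([0, 1, 1, 2, 3, 1, 0, 5])
toy_tags = pd.Series(["python", "java", "python", "sql", "python",
                      "java", "go", "python", "java", "sql"])

def describe_quantiles(series: pd.Series, num_quantiles: int) -> list:
    """Return num_quantiles + 1 boundaries: min, interior cut points, max."""
    probs = [i / num_quantiles for i in range(num_quantiles + 1)]
    return series.quantile(probs).tolist()

def describe_top_k(series: pd.Series, top_k: int) -> list:
    """Return the top_k most frequent values with their counts."""
    counts = series.value_counts().head(top_k)
    return [{"value": v, "count": int(c)} for v, c in counts.items()]

print(describe_quantiles(toy_answer_count, num_quantiles=4))
print(describe_top_k(toy_tags, top_k=3))
```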

In [17]:
%%bigquery
SELECT *
FROM ML.DESCRIBE_DATA(
    TABLE `bigquery-public-data.stackoverflow.posts_questions`,
    STRUCT(3 as top_k, 4 as num_quantiles)
)


| | name | num_rows | num_nulls | num_zeros | min | max | mean | stddev | median | quantiles | unique | avg_string_length | num_values | top_values | min_array_length | max_array_length | avg_array_length | total_array_length | array_length_quantiles | dimension |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | accepted_answer_id | 23020127 | 11264847 | 0.0 | 7 | 73842204 | 34495650.0 | 21071280.0 | 33019127.0 | [7.0, 15736431.0, 33030249.0, 52323255.0, 7384... |  |  | 11755280 | [] |  |  |  |  | [] |  |
| 1 | answer_count | 23020127 | 0 | 3316808.0 | 0 | 518 | 1.478016 | 1.453571 | 1.0 | [0.0, 1.0, 1.0, 2.0, 518.0] |  |  | 23020127 | [] |  |  |  |  | [] |  |
| 2 | body | 23020127 | 0 |  | `" runat="server" />\n \n\n<p>But this code ...` | `مرحبا بك\n {{authService.decodedToken?.uniq...` |  |  |  | [] | 22860569.0 | 1559.723116 | 23020127 | `[{'value': '<p>I'm responsible of finding a go...` |  |  |  |  | [] |  |
| 3 | comment_count | 23020127 | 0 | 9363819.0 | 0 | 108 | 1.985216 | 2.662199 | 1.0 | [0.0, 0.0, 1.0, 3.0, 108.0] |  |  | 23020127 | [] |  |  |  |  | [] |  |
| 4 | community_owned_date | 23020127 | 23008772 |  | 2008-08-12 04:59:35.017+00 | 2022-09-19 13:19:19.410+00 |  |  |  | [] | 11361.0 | 25.982387 | 11355 | [{'value': None, 'count': 23008772}, {'value':... |  |  |  |  | [] |  |
| 5 | creation_date | 23020127 | 0 |  | 2008-07-31 21:42:52.667+00 | 2022-09-25 05:56:32.863+00 |  |  |  | [] | 23161476.0 | 25.986729 | 23020127 | [{'value': '2013-11-07 05:34:24.390+00', 'coun... |  |  |  |  | [] |  |
| 6 | favorite_count | 23020127 | 17990596 | 725181.0 | 0 | 11784 | 2.824621 | 21.1541 | 1.0 | [0.0, 1.0, 1.0, 2.0, 11784.0] |  |  | 5029531 | [] |  |  |  |  | [] |  |
| 7 | id | 23020127 | 0 | 0.0 | 4 | 73842327 | 38214250.0 | 21099660.0 | 37654803.0 | [4.0, 20018398.0, 38164817.0, 56447641.0, 7384... |  |  | 23020127 | [] |  |  |  |  | [] |  |
| 8 | last_activity_date | 23020127 | 0 |  | 2008-09-04 12:50:25.060+00 | 2022-09-25 05:56:36.210+00 |  |  |  | [] | 22961997.0 | 25.986694 | 23020127 | [{'value': '2010-12-02 21:25:32.517+00', 'coun... |  |  |  |  | [] |  |
| 9 | last_edit_date | 23020127 | 10460205 |  | 2008-08-03 21:38:52.623+00 | 2022-09-25 05:56:30.490+00 |  |  |  | [] | 11831214.0 | 25.986712 | 12559922 | [{'value': None, 'count': 10460205}, {'value':... |  |  |  |  | [] |  |


Some observations:
- The columns `body`, `title`, `tags`, and `id` are always present, which makes them a great starting point for a model that predicts tags from the contents of a post.

---
## Prepare Data Source

The data preparation adds a column named `splits` that assigns rows to training (`TRAIN`), validation (`VALIDATE`), and testing (`TEST`) sets.

>These steps could be done locally at training time, but they are instead done in the source system (BigQuery in this case), which provides several advantages:
>
>-   **Single Source of Truth:** A single data preparation can benefit multiple model training jobs for different architectures or even different team members working on the same model. This ensures consistency and avoids duplication of effort.
>-   **Leverage BigQuery's Power:** BigQuery is highly optimized for large-scale data processing. Performing these operations directly in BigQuery leverages its distributed processing capabilities, making the preparation significantly faster and more efficient than local processing, especially for massive datasets.
>-   **Reduced Data Movement:** Preparing the data in BigQuery reduces the amount of data that needs to be moved out of BigQuery and into the training environment. This minimizes latency and potential bottlenecks associated with data transfer.
>-   **Data Versioning and Reproducibility:** By preparing the splits and unique ID in BigQuery, the specific dataset used for training can be easily tracked and versioned. This enhances the reproducibility of experiments and makes it easier to understand the provenance of the data used in a particular model.
>-   **Simplified Training Pipeline:** The training pipeline becomes simpler because it can directly read pre-split data from BigQuery, eliminating the need for complex splitting logic within the training code.
>-   **Pre-calculated Joins and Features:** BigQuery can be used to pre-calculate joins and engineer new features that are beneficial for the model. This can improve model performance and further reduce the workload during the training phase.
>
>**Further Considerations:**
>
>-   **Data Governance and Security:** BigQuery offers robust data governance and security features. Performing data preparation within BigQuery allows you to maintain control over access and ensure data quality.
>-   **Scalability:** This approach is highly scalable. As your dataset grows, BigQuery can handle the increased workload without requiring significant changes to your data preparation pipeline.
>-   **Cost Optimization:** While moving large amounts of data out of BigQuery can incur costs, performing the preparation steps within BigQuery and only extracting the necessary data for training can often be more cost-effective.
>
>By preparing the data in BigQuery, you create a streamlined, efficient, and reproducible workflow (pipeline) that leverages the strengths of the platform and sets your machine learning models up for success.


### Create/Recall Dataset

In [18]:
dataset = bigquery.Dataset(f"{BQ_PROJECT}.{BQ_DATASET}")
dataset.location = BQ_REGION
bq_dataset = bq.create_dataset(dataset, exists_ok = True)

### Create/Recall Table With Preparation For ML

Copy the data from the source while filtering rows and processing columns:
- Filter to a single year's records: 2017
- Add a `splits` column that randomly assigns rows to the "TRAIN", "VALIDATE", and "TEST" groups (80/10/10 within the sampled rows)
- Subset columns to `id`, `title`, `body`, and the target `tags`
- **NOTE:** This further samples the data by assigning splits to only about 10% of the rows; the remaining rows get a `NULL` split. This is strictly to keep local training manageable in this example.

In [21]:
job = bq.query(f"""
CREATE OR REPLACE TABLE `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE}` AS
WITH
    base_data AS (
        SELECT
            id,                                -- Unique identifier for the question
            title,                             -- Raw title text
            body,                              -- Raw body text
            tags,                               -- Multi-label target (pipe-separated string)
            ROW_NUMBER() OVER (ORDER BY RAND()) as rn -- simple row numbering for splitting
        FROM
            `{BQ_SOURCE}` -- Source table
        WHERE
            title IS NOT NULL AND body IS NOT NULL AND tags IS NOT NULL
            AND LENGTH(TRIM(title)) > 0 AND LENGTH(TRIM(body)) > 0 AND LENGTH(TRIM(tags)) > 0
            -- Filter out deleted or non-question posts if necessary (e.g., post_type_id = 1 for questions)
            -- For simplicity, we'll keep it minimal here.
            AND creation_date BETWEEN '2017-01-01 00:00:00+00:00' AND '2017-12-31 23:59:59+00:00' -- Focus on a specific year for manageability
    )
SELECT
    * EXCEPT(rn),
    CASE
        WHEN rn <= 0.08 * COUNT(*) OVER () THEN 'TRAIN'
        WHEN rn <= 0.09 * COUNT(*) OVER () THEN 'VALIDATE'
        WHEN rn <= 0.10 * COUNT(*) OVER () THEN 'TEST'
        ELSE Null -- records not sampled for using in training here
    END AS splits
FROM
    base_data
""")
job.result()
(job.ended-job.started).total_seconds()

13.147
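
The `CASE` expression in the query above ranks rows in a random order and cuts that ranking at 8%, 9%, and 10% of the row count. A local NumPy/pandas sketch of the same logic, on a hypothetical row count, can help confirm the expected 80/10/10 ratio within the sampled rows. The function `assign_splits` and its seed are assumptions for illustration; `ORDER BY RAND()` in BigQuery is not seeded, so its assignments differ from run to run.

```python
import numpy as np
import pandas as pd

def assign_splits(n_rows: int, seed: int = 42) -> pd.Series:
    """Mimic ROW_NUMBER() OVER (ORDER BY RAND()) plus the CASE thresholds."""
    rng = np.random.default_rng(seed)
    # 1-based random rank for each row, like ROW_NUMBER() over a random order
    rn = rng.permutation(n_rows) + 1
    splits = np.full(n_rows, None, dtype=object)
    # Assign from the widest threshold to the narrowest so the first CASE branch wins
    splits[rn <= 0.10 * n_rows] = "TEST"
    splits[rn <= 0.09 * n_rows] = "VALIDATE"
    splits[rn <= 0.08 * n_rows] = "TRAIN"
    return pd.Series(splits, name="splits")

toy_splits = assign_splits(10_000)
print(toy_splits.value_counts())  # counts only the sampled (non-null) rows
print(toy_splits.isna().mean())   # share of rows left unsampled
```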

In [22]:
source_data_prepared = bpd.read_gbq(f"{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE}", use_cache = False)
source_data_prepared.head()

| | id | title | body | tags | splits |
|---|---|---|---|---|---|
| 0 | 45949298 | When running my function my if statement is no... | `<p>I am trying to make a simple bubble sort, a...` | python\|arrays\|python-3.x\|if-statement\|bubble-sort |  |
| 1 | 41810643 | Finding a gps locations near another one based... | `<p>I have a<code>GpsLocation</code> model, thi...` | ruby-on-rails\|ruby-on-rails-4\|geocoding |  |
| 2 | 44769160 | Spring - Provide parsed JWT as Resource method... | `<p>In Spring framework, is there a way to pars...` | spring\|jwt |  |
| 3 | 42025994 | Scala slick retrieve data tables in parallel | `<p>I need to read data from two different tabl...` | multithreading\|scala\|jdbc\|slick | TRAIN |
| 4 | 42289360 | What is the correct way to test multiple andro... | `<p>I have a working "client-server" program , ...` | android\|android-studio\|android-emulator\|emulation |  |


### Review the number of records for each of the data splits:

In [23]:
bq.query(f"""
SELECT splits,
    count(*) as count,
    ROUND(
        COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2
    ) AS percentage
FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE}`
WHERE splits IS NOT NULL
GROUP BY splits
ORDER BY splits
""").to_dataframe()

| | splits | count | percentage |
|---|---|---|---|
| 0 | TEST | 21162 | 10.0 |
| 1 | TRAIN | 169296 | 80.0 |
| 2 | VALIDATE | 21163 | 10.0 |


### Further Data Preparation



#### Rows Assigned To A Split (TRAIN, TEST, VALIDATE)

Exclude the unsampled rows, where `splits` is NULL:

In [24]:
# only use rows assigned to a split:
source_data_prepared_filtered = source_data_prepared[source_data_prepared['splits'].notna()]

#### Group Columns By Type

In [25]:
# Feature Columns
string_feature_cols = [
    'title',
    'body'
]

target_col = 'tags'
split_col = 'splits'
id_col = 'id' # Keep id for later lookup if needed, but not as a feature

#### Check For Columns With NaNs

In [26]:
source_data_prepared_filtered[string_feature_cols + [target_col]].isna().sum()

title    0
body     0
tags     0
dtype: Int64

#### Set Up Pointers For Training

In [27]:
X_frame = source_data_prepared_filtered[string_feature_cols]
y_frame = source_data_prepared_filtered[target_col]
splits_frame = source_data_prepared_filtered[split_col]

print(f"There are {X_frame.shape[0]} rows across all splits for {X_frame.shape[1]} raw features.")
print(f"Confirming the rows for splits and the label are {splits_frame.shape[0]} and {y_frame.shape[0]} respectively.")

There are 211621 rows across all splits for 2 raw features.
Confirming the rows for splits and the label are 211621 and 211621 respectively.


---
## Train With Scikit-Learn

### Local Dataframes

Convert the BigFrames pointers to pandas DataFrames for the training, validation, and test objects:

In [28]:
X_train = X_frame[splits_frame == 'TRAIN'].to_pandas().reset_index(drop=True)
X_val = X_frame[splits_frame == 'VALIDATE'].to_pandas().reset_index(drop=True)
X_test = X_frame[splits_frame == 'TEST'].to_pandas().reset_index(drop=True)

y_train_raw = y_frame[splits_frame == 'TRAIN'].to_pandas().reset_index(drop=True)
y_val_raw = y_frame[splits_frame == 'VALIDATE'].to_pandas().reset_index(drop=True)
y_test_raw = y_frame[splits_frame == 'TEST'].to_pandas().reset_index(drop=True)

In [29]:
X_train.shape, y_train_raw.shape

((169296, 2), (169296,))

### Prepare Multi-Label Target

The target data needs to be turned into a sparse matrix where columns represent individual tags and values represent the presence (1) or absence (0) of each tag for the sample (row).

The `y_train_raw` data is used to build a list of the most common tags (those appearing at least 500 times in the training data) and then to fit a `MultiLabelBinarizer` that creates the sparse matrix. This learned representation is then applied to the validation and test splits, so all three splits share the same label columns.
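
The steps just described can be sketched end to end on a few hypothetical tag strings. The sketch assumes a frequency threshold of 2 instead of 500 because the toy data is tiny, and it uses a Python `set` for the membership test, which is much faster than searching a list when there are many tags. All `toy_` names are illustrative and chosen so they do not overwrite the notebook's variables.

```python
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical pipe-delimited tag strings for a train and a test split
toy_train_raw = pd.Series(["python|pandas", "python|numpy", "java", "pandas|python|csv"])
toy_test_raw = pd.Series(["python|flask", "java|spring"])

def to_lists(s: pd.Series) -> pd.Series:
    """Split pipe-delimited strings into lists, mapping missing/blank values to []."""
    return s.apply(lambda x: x.split("|") if pd.notna(x) and x.strip() != "" else [])

toy_train_lists = to_lists(toy_train_raw)
toy_test_lists = to_lists(toy_test_raw)

# Keep tags seen at least min_count times in the TRAINING data only
min_count = 2
toy_counts = pd.Series([t for tags in toy_train_lists for t in tags]).value_counts()
toy_frequent = set(toy_counts[toy_counts >= min_count].index)

toy_train_filtered = toy_train_lists.apply(lambda tags: [t for t in tags if t in toy_frequent])
toy_test_filtered = toy_test_lists.apply(lambda tags: [t for t in tags if t in toy_frequent])

toy_mlb = MultiLabelBinarizer(sparse_output=True)
Y_toy_train = toy_mlb.fit_transform(toy_train_filtered)  # learns the label columns
Y_toy_test = toy_mlb.transform(toy_test_filtered)        # reuses them; no unknown-label warnings

print(toy_mlb.classes_)
print(Y_toy_train.toarray())
print(Y_toy_test.toarray())
```

Note that a row whose tags are all infrequent (here, the `java` row) becomes an all-zero row rather than being dropped.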

In [30]:
# Convert pipe-delimited strings into lists of tags for each split
y_train_raw_lists = y_train_raw.apply(lambda x: x.split('|') if pd.notna(x) and x.strip() != '' else [])
y_val_raw_lists = y_val_raw.apply(lambda x: x.split('|') if pd.notna(x) and x.strip() != '' else [])
y_test_raw_lists = y_test_raw.apply(lambda x: x.split('|') if pd.notna(x) and x.strip() != '' else [])


# Count tag frequency from TRAINING DATA ONLY to identify frequent tags
all_tags_train = [tag for sublist in y_train_raw_lists for tag in sublist]
tag_counts_train = pd.Series(all_tags_train).value_counts()

# Identify most frequent tags based on training data (here, frequency >= 500)
frequent_tags_train = tag_counts_train[tag_counts_train >= 500].index.tolist()
frequent_tags_set = set(frequent_tags_train) # set lookups are much faster than scanning a list

# Filter the tag lists for all splits to only include frequent tags learned from training
y_train_raw_lists_filtered = y_train_raw_lists.apply(lambda tags_list: [tag for tag in tags_list if tag in frequent_tags_set])
y_val_raw_lists_filtered = y_val_raw_lists.apply(lambda tags_list: [tag for tag in tags_list if tag in frequent_tags_set])
y_test_raw_lists_filtered = y_test_raw_lists.apply(lambda tags_list: [tag for tag in tags_list if tag in frequent_tags_set])

In [31]:
# Initialize MultiLabelBinarizer to output sparse matrix
mlb = sklearn.preprocessing.MultiLabelBinarizer(sparse_output=True)

# Fit mlb ONLY on training data, then transform all splits
y_train_processed = mlb.fit_transform(y_train_raw_lists_filtered)
y_val_processed = mlb.transform(y_val_raw_lists_filtered)
y_test_processed = mlb.transform(y_test_raw_lists_filtered)

In [32]:
y_train_processed.shape

(169296, 127)

In [33]:
mlb.classes_.shape

(127,)

In [34]:
pd.DataFrame(y_train_processed[:5].toarray(), columns=mlb.classes_).head()

Unnamed: 0,.net,ajax,algorithm,amazon-web-services,android,android-studio,angular,angularjs,apache,apache-spark,...,visual-studio,vue.js,webpack,windows,winforms,wordpress,wpf,xamarin,xcode,xml
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


### Define A Preprocessor

Create a Scikit-learn preprocessor that converts the `title` and `body` text columns into TF-IDF features (up to 200 and 400 terms respectively, with English stop words removed) and drops any other columns:

In [35]:
preprocessor = sklearn.compose.ColumnTransformer(
    transformers = [
        # Apply TfidfVectorizer to string columns
        ('title_tfidf', sklearn.feature_extraction.text.TfidfVectorizer(max_features=200, stop_words='english'), 'title'),
        ('body_tfidf', sklearn.feature_extraction.text.TfidfVectorizer(max_features=400, stop_words='english'), 'body'),
    ],
    remainder = 'drop' # Drop any other columns not explicitly listed (if any)
)

### Define A Model Pipeline: Preprocessor + Model

Combine the preprocessor with the model type in a pipeline:

In [36]:
base_classifier = sklearn.linear_model.LogisticRegression(
    random_state=42,
    solver='liblinear', # supports both L1 and L2 penalties and works well on smaller datasets
    max_iter=100,       # increase if the solver does not converge
    penalty='l1',       # L1 drives uninformative coefficients to zero, useful with sparse TF-IDF features
    C=0.1               # inverse of regularization strength: smaller values regularize more (tune this)
)
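
In scikit-learn's `LogisticRegression`, `C` is the inverse of regularization strength, so `C=0.1` regularizes more than the default `C=1.0`. With `penalty='l1'`, stronger regularization drives more coefficients to exactly zero. A quick sketch on synthetic data from `make_classification` (the sizes and `C` values are illustrative assumptions) counts the zero coefficients at a few settings of `C`:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic data: 20 features, only 3 of them informative
X_syn, y_syn = make_classification(
    n_samples=200, n_features=20, n_informative=3, n_redundant=0, random_state=0
)

def count_zero_coefs(C: float) -> int:
    """Fit an L1-penalized model and count coefficients that are exactly zero."""
    clf = LogisticRegression(penalty="l1", solver="liblinear", C=C, random_state=42)
    clf.fit(X_syn, y_syn)
    return int(np.sum(clf.coef_ == 0))

for C in (0.01, 0.1, 10.0):
    print(f"C={C}: {count_zero_coefs(C)} of {X_syn.shape[1]} coefficients are zero")
```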

In [37]:
multi_label_classifier = sklearn.multioutput.MultiOutputClassifier(
    estimator=base_classifier,
    n_jobs=2 # Parallelize fitting of individual classifiers for each label
)
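
`MultiOutputClassifier` handles the multi-label target by fitting one independent clone of the base estimator per label column (here, one per tag), and `n_jobs` parallelizes those fits across labels. A small sketch with synthetic data from `make_multilabel_classification` (sizes are illustrative) shows one fitted estimator per label and a `predict` output with one column per label:

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

# Synthetic multi-label data: 100 samples, 10 features, 4 labels
X_ml, Y_ml = make_multilabel_classification(
    n_samples=100, n_features=10, n_classes=4, random_state=0
)

toy_multi = MultiOutputClassifier(
    estimator=LogisticRegression(solver="liblinear", penalty="l1", C=0.1),
    n_jobs=1,
)
toy_multi.fit(X_ml, Y_ml)

print(len(toy_multi.estimators_))     # one fitted classifier per label
print(toy_multi.predict(X_ml).shape)  # (n_samples, n_labels)
```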

In [38]:
model_pipeline = sklearn.pipeline.Pipeline(
    steps = [
        ('preprocessor', preprocessor),
        ('classifier', multi_label_classifier)
    ]
)

### Train/Fit The Model

Fit the pipeline on the training data:

In [39]:
X_train.shape, y_train_processed.shape

((169296, 2), (169296, 127))

In [40]:
model_pipeline.fit(X_train, y_train_processed.toarray())
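
After fitting, `predict_proba` on a `MultiOutputClassifier` returns a list with one `(n_samples, 2)` array per label rather than a single matrix. The sketch below turns that list into an `(n_samples, n_labels)` probability matrix, thresholds it, and maps rows back to tag names with `MultiLabelBinarizer.inverse_transform`. It uses a tiny hypothetical corpus and pipeline rather than the notebook's fitted objects, and the 0.5 threshold is an assumption, not a tuned value.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical toy corpus with two tags
toy_X = pd.DataFrame({"title": [
    "python list sort", "python dict keys", "sql join tables", "sql group by",
    "python pandas sql", "pandas dataframe sql",
]})
toy_tag_lists = [["python"], ["python"], ["sql"], ["sql"], ["python", "sql"], ["python", "sql"]]

toy_label_mlb = MultiLabelBinarizer()
toy_Y = toy_label_mlb.fit_transform(toy_tag_lists)

toy_pipeline = Pipeline(steps=[
    ("preprocessor", ColumnTransformer([("title_tfidf", TfidfVectorizer(), "title")])),
    ("classifier", MultiOutputClassifier(LogisticRegression(solver="liblinear"))),
])
toy_pipeline.fit(toy_X, toy_Y)

# One (n_samples, 2) array per label; column 1 is P(label present)
proba_list = toy_pipeline.predict_proba(toy_X)
proba = np.column_stack([p[:, 1] for p in proba_list])

predicted = (proba >= 0.5).astype(int)
print(proba.shape)
print(toy_label_mlb.inverse_transform(predicted))
```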