In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import sys
from datetime import datetime
from pathlib import Path

from google.cloud import aiplatform, storage
from kfp.dsl import Artifact, Dataset, Metrics, Model
from loguru import logger

# Make the parent of the working directory importable so the `src.*` imports below resolve.
# The membership guard keeps sys.path free of duplicate entries when this cell is re-run.
module_path = os.path.abspath("..")
if module_path not in sys.path:
    sys.path.append(module_path)

from src.base.utilities import generate_query, read_yaml
from src.utils.notebooks import patch_kfp

# Presumably adapts kfp so the components imported below (src.components.*) can be called as plain
# Python functions in this notebook; later cells call them directly with explicit output artifacts.
# It runs before those imports — presumably that order matters, so keep it here.
patch_kfp()

from src.components.bigquery import execute_query, bq_table_to_dataset
from src.components.data import get_data_version
from src.components.model import train_evaluate_model

# Drop every previously registered loguru handler, then log DEBUG and above to stderr.
logger.remove()
logger.add(sink=sys.stderr, level="DEBUG")

In [None]:
# Local-run parameters: the "data" section of the dev payload of the training pipeline.
with open("../src/pipelines/training/payloads/dev.json", "r", encoding="utf-8") as f:
    payload = json.load(f)["data"]

# Index os.environ directly rather than using .get(): a missing variable should fail here instead of
# becoming the literal "None" inside the BigQuery table names and GCS paths built further down.
project_id = os.environ["VERTEX_PROJECT_ID"]
dataset_id = payload["dataset_id"]
dataset_location = payload["dataset_location"]
data_version = payload["data_version"]
create_replace_tables = payload["create_replace_tables"]

In [None]:
# Pipeline configuration; later cells read its target_column, models_params, fit_args and
# data_processing_args entries.
params_yaml_path = "../src/pipelines/configuration/params.yaml"
config_params = read_yaml(params_yaml_path)

In [None]:
# SQL templates for the training pipeline, located relative to the parent of the working directory.
queries_folder = Path.cwd().parent.joinpath("src", "pipelines", "training", "queries")

In [None]:
# Resolve the payload's data version into the version used for the BigQuery dataset suffix below
# (what "resolve" means is defined by get_data_version — presumably it handles unset/"latest" values).
# The input is read from `payload` instead of the previous `data_version`, so re-running this cell
# always starts from the same value rather than from its own earlier output.
data_version = get_data_version(
    payload_data_version=payload["data_version"],
    project_id=project_id,
    dataset_id=dataset_id,
    dataset_location=dataset_location,
)

In [None]:
# Fully-qualified BigQuery names inside the versioned dataset `<project>.<dataset_id>_<data_version>`.
dataset_name = f"{project_id}.{dataset_id}_{data_version}"
preprocessed_table = f"{dataset_name}.preprocessed"
train_set_table, valid_set_table, test_set_table = (
    f"{dataset_name}.{split_suffix}" for split_suffix in ("training", "validation", "testing")
)

# Split fractions passed to the SQL template (presumably the remainder goes to training — see the .sql).
validation_fraction = 0.15
test_fraction = 0.15

train_valid_test_query = generate_query(
    queries_folder / "q_train_valid_test_split.sql",
    source_table=preprocessed_table,
    valid_size=validation_fraction,
    test_size=test_fraction,
    training_table=train_set_table,
    validation_table=valid_set_table,
    testing_table=test_set_table,
    create_replace_table=create_replace_tables,
)

# The job config is handed to the component as a JSON string.
query_job_config = json.dumps({"use_query_cache": True})

execute_query(
    query=train_valid_test_query,
    bq_client_project_id=project_id,
    query_job_config=query_job_config,
)

In [None]:
# VERTEX_PIPELINE_FILES_GCS_PATH is expected to look like "gs://<bucket>[/<path>]"; outputs of this local
# run go under "<path>/local_run" in that bucket.
pipeline_files_path = os.environ["VERTEX_PIPELINE_FILES_GCS_PATH"]
if pipeline_files_path.startswith("gs://"):
    # Only strip the scheme prefix, not any "gs://" that might appear later in the string.
    pipeline_files_path = pipeline_files_path[len("gs://"):]
# Strip surrounding slashes so a trailing "/" cannot create an empty "//" segment in object names, and use
# partition so a bucket-only path yields an empty base prefix instead of raising IndexError.
bucket, _, base_prefix = pipeline_files_path.strip("/").partition("/")
prefix = f"{base_prefix}/local_run" if base_prefix else "local_run"
logger.debug(f"GCS bucket: gs://{bucket} .")
logger.debug(f"GCS prefix: {prefix} .")

In [None]:
# Parquet exports of the three split tables, written under the local-run GCS prefix.
training_set = Dataset(
    name="training_set", uri=f"gs://{bucket}/{prefix}/training_set/"
)
validation_set = Dataset(
    name="validation_set", uri=f"gs://{bucket}/{prefix}/validation_set/"
)
test_set = Dataset(
    name="test_set", uri=f"gs://{bucket}/{prefix}/test_set/"
)

# One export per (output artifact, source table) pair instead of three copy-pasted calls.
# NOTE(review): skip_if_exists=True presumably reuses an earlier export on re-run — confirm in
# bq_table_to_dataset.
for output_dataset, source_table in (
    (training_set, train_set_table),
    (validation_set, valid_set_table),
    (test_set, test_set_table),
):
    _ = bq_table_to_dataset(
        bq_client_project_id=project_id,
        source_project_id=project_id,
        dataset_id=f"{dataset_id}_{data_version}",
        # "<project>.<dataset>.<table>" -> bare table name
        table_name=source_table.rsplit(".", 1)[1],
        dataset_location=dataset_location,
        file_pattern="file_*",
        extract_job_config=dict(destination_format="PARQUET"),
        skip_if_exists=True,
        # The arguments below should not be supplied when running the component on Vertex AI
        dataset=output_dataset,
    )

In [None]:
model_algorithm = "random_forest"
# Override the configured tree count for this local run (presumably to keep training fast).
# This edits config_params in place through the nested dict reference.
selected_model_params = config_params["models_params"][model_algorithm]
selected_model_params["n_estimators"] = 10

In [None]:
# Local directory the file artifacts below point at: `../artifacts` relative to the working directory.
artifacts_path = Path.cwd().parent / "artifacts"

# Output artifacts are passed explicitly because the component runs in-process here, not on Vertex AI.
# NOTE(review): model, validation_pr_curve and test_pr_curve all get the same directory `uri` (a Path,
# not a str). The upload cell later treats `<artifact>.uri` as a single file, so presumably
# train_evaluate_model replaces each uri with the file it writes — confirm, otherwise these collide.
model = Model(name="credit-card-frauds", uri=artifacts_path)
# Metrics artifacts carry no uri; the experiment-logging cell reads their `.metadata`.
train_metrics = Metrics(name="train_metrics")
validation_metrics = Metrics(name="validation_metrics")
test_metrics = Metrics(name="test_metrics")
validation_pr_curve = Artifact(name="validation_pr_curve", uri=artifacts_path)
test_pr_curve = Artifact(name="test_pr_curve", uri=artifacts_path)

# Train `model_algorithm` and produce metrics for the train/validation/test splits, using the settings
# loaded from params.yaml.
train_evaluate_model(
    training_data=training_set,
    validation_data=validation_set,
    test_data=test_set,
    target_column=config_params["target_column"],
    model_name=model_algorithm,
    models_params=config_params["models_params"],
    fit_args=config_params["fit_args"],
    data_processing_args=config_params["data_processing_args"],
    model_gcs_folder_path=None,  # presumably skips an in-component GCS upload; a later cell uploads instead
    # The arguments below should not be supplied when running the component on Vertex AI
    model=model,
    train_metrics=train_metrics,
    valid_metrics=validation_metrics,
    test_metrics=test_metrics,
    valid_pr_curve=validation_pr_curve,
    test_pr_curve=test_pr_curve,
)

In [None]:
storage_client = storage.Client(project=project_id)
gcs_bucket = storage_client.bucket(bucket)


def _upload_artifact_file(artifact, target_bucket, target_prefix):
    """Upload the local file at ``artifact.uri`` to ``<target_prefix>/<file name>`` in ``target_bucket``.

    Returns the uploaded blob so its GCS location can be recorded as experiment lineage later.
    """
    # Path(...).name understands the platform's separator; str(uri).split("/") returns the whole
    # path for Windows-style paths.
    local_file = Path(artifact.uri)
    blob = target_bucket.blob(f"{target_prefix}/{local_file.name}")
    blob.upload_from_filename(str(local_file))
    return blob


model_gcs = _upload_artifact_file(model, gcs_bucket, prefix)
validation_pr_curve_gcs = _upload_artifact_file(validation_pr_curve, gcs_bucket, prefix)
test_pr_curve_gcs = _upload_artifact_file(test_pr_curve, gcs_bucket, prefix)

In [None]:
# Create (or reuse) the Vertex AI experiment the run below is logged under. Reuse `project_id` so the
# experiment lives in the same project as the BigQuery/GCS resources used above.
experiment = aiplatform.Experiment.get_or_create(
    experiment_name="credit-card-frauds",
    project=project_id,
    location=os.environ.get("VERTEX_LOCATION"),
)

In [None]:
# Make the experiment the default context for the aiplatform.start_run / log_* calls below, in the same
# project (`project_id`) as the rest of this notebook.
# experiment_tensorboard=False presumably keeps a backing TensorBoard from being attached.
aiplatform.init(
    project=project_id,
    location=os.environ.get("VERTEX_LOCATION"),
    experiment=experiment.name,
    experiment_tensorboard=False,
)

In [None]:
# Unique, hyphenated run name, e.g. "notebook-execution-random-forest-20240101120000".
run_name = f"notebook-execution-{model_algorithm}-".replace("_", "-")
run_name += datetime.now().strftime("%Y%m%d%H%M%S")


def _suffixed_metrics(metrics, suffix):
    """Return ``metrics.metadata`` with ``suffix`` appended to every metric name."""
    return {name + suffix: value for name, value in metrics.metadata.items()}


def _gcs_uri(blob):
    """Return the ``gs://<bucket>/<object>`` URI of a google-cloud-storage blob."""
    return f"gs://{blob.bucket.name}/{blob.name}"


with aiplatform.start_run(run=run_name):
    # NOTE(review): log_params presumably accepts only str/int/float values — confirm that the configured
    # hyper-parameters contain no None/list entries.
    aiplatform.log_params(config_params["models_params"][model_algorithm])
    # Single call; the split suffix keeps the three metric sets from overwriting each other.
    aiplatform.log_metrics(
        {
            **_suffixed_metrics(train_metrics, "_train"),
            **_suffixed_metrics(validation_metrics, "_validation"),
            **_suffixed_metrics(test_metrics, "_test"),
        }
    )
    # Record lineage: the three dataset exports are inputs, the uploaded model and PR curves are outputs.
    with aiplatform.start_execution(
        schema_title="system.ContainerExecution",
        display_name="train-evaluate-model",
    ) as exc:
        training_set_aip = aiplatform.Artifact.create(
            uri=training_set.uri,
            schema_title="system.Dataset",
            display_name="training-set",
        )
        validation_set_aip = aiplatform.Artifact.create(
            uri=validation_set.uri,
            schema_title="system.Dataset",
            display_name="validation-set",
        )
        test_set_aip = aiplatform.Artifact.create(
            uri=test_set.uri,
            schema_title="system.Dataset",
            display_name="test-set",
        )
        exc.assign_input_artifacts([training_set_aip, validation_set_aip, test_set_aip])

        model_aip = aiplatform.Artifact.create(
            uri=_gcs_uri(model_gcs),
            schema_title="system.Model",
            display_name=f"{model_algorithm}-model".replace("_", "-"),
        )
        validation_pr_curve_aip = aiplatform.Artifact.create(
            uri=_gcs_uri(validation_pr_curve_gcs),
            schema_title="system.Artifact",
            display_name="validation-pr-curve",
        )
        test_pr_curve_aip = aiplatform.Artifact.create(
            uri=_gcs_uri(test_pr_curve_gcs),
            schema_title="system.Artifact",
            display_name="test-pr-curve",
        )
        exc.assign_output_artifacts(
            [model_aip, validation_pr_curve_aip, test_pr_curve_aip]
        )