In [21]:
# Render matplotlib figures inline, directly beneath the cell that creates them.
%matplotlib inline
# Load the autoreload extension; mode 2 reloads all imported modules before each
# cell executes, so edits to library code are picked up without a kernel restart.
# NOTE: keep comments on their own lines -- text after a line magic is passed to it.
%load_ext autoreload
%autoreload 2

In [22]:
from pathlib import Path
from typing import List

import pandas as pd
from artifact_core.libs.resource_spec.tabular.spec import TabularDataSpec
from artifact_experiment.libs.tracking.filesystem.client import FilesystemTrackingClient
from artifact_experiment.table_comparison.validation_plan import (
    TableComparisonArrayCollectionType,
    TableComparisonArrayType,
    TableComparisonPlotCollectionType,
    TableComparisonPlotType,
    TableComparisonScoreCollectionType,
    TableComparisonScoreType,
    TableComparisonValidationPlan,
)

In [23]:
# The project root is taken to be the parent of the directory the kernel was
# started in, so this notebook must be launched from its own folder.
artifact_core_root = Path.cwd().parent

# Locations of the two tables to compare, relative to the project root.
real_csv_path = artifact_core_root / "assets/real.csv"
synthetic_csv_path = artifact_core_root / "assets/synthetic.csv"

df_real = pd.read_csv(real_csv_path)
df_synthetic = pd.read_csv(synthetic_csv_path)

In [24]:
# Columns treated as continuous; every other column of the real table is
# handed to the spec as categorical.
ls_cts_features = ["Age", "RestingBP", "Cholesterol", "MaxHR", "Oldpeak"]
categorical_columns = [
    column for column in df_real.columns if column not in ls_cts_features
]

# Build the tabular spec from the real table and the two feature groups.
resource_spec = TabularDataSpec.from_df(
    df=df_real,
    ls_cts_features=ls_cts_features,
    ls_cat_features=categorical_columns,
)

In [25]:
class MyValidationPlan(TableComparisonValidationPlan):
    """Table-comparison validation plan used by this notebook.

    Each static hook returns the artifact types of one kind that the plan
    should include; an empty list means nothing of that kind is requested.
    (Presumably these hooks are read by ``TableComparisonValidationPlan`` when
    the plan is built -- confirm against the base class.)
    """

    @staticmethod
    def _get_score_types() -> List[TableComparisonScoreType]:
        """Return the requested scores: only ``MEAN_JS_DISTANCE``."""
        selected_scores = [TableComparisonScoreType.MEAN_JS_DISTANCE]
        return selected_scores

    @staticmethod
    def _get_array_types() -> List[TableComparisonArrayType]:
        """Return the requested arrays: none."""
        selected_arrays: List[TableComparisonArrayType] = []
        return selected_arrays

    @staticmethod
    def _get_plot_types() -> List[TableComparisonPlotType]:
        """Return the requested plots.

        Covers PDF, CDF, descriptive-stats comparison, PCA projection and
        t-SNE projection plots, in that order.
        """
        selected_plots = [
            TableComparisonPlotType.PDF_PLOT,
            TableComparisonPlotType.CDF_PLOT,
            TableComparisonPlotType.DESCRIPTIVE_STATS_COMPARISON_PLOT,
            TableComparisonPlotType.PCA_PROJECTION_PLOT,
            TableComparisonPlotType.TSNE_PROJECTION_PLOT,
        ]
        return selected_plots

    @staticmethod
    def _get_score_collection_types() -> List[TableComparisonScoreCollectionType]:
        """Return the requested score collections: only ``JS_DISTANCE``."""
        selected_score_collections = [TableComparisonScoreCollectionType.JS_DISTANCE]
        return selected_score_collections

    @staticmethod
    def _get_array_collection_types() -> List[TableComparisonArrayCollectionType]:
        """Return the requested array collections: none."""
        selected_array_collections: List[TableComparisonArrayCollectionType] = []
        return selected_array_collections

    @staticmethod
    def _get_plot_collection_types() -> List[TableComparisonPlotCollectionType]:
        """Return the requested plot collections: none."""
        selected_plot_collections: List[TableComparisonPlotCollectionType] = []
        return selected_plot_collections


# Instantiate the plan for the spec built above; no tracking client is passed here.
plan = MyValidationPlan.build(resource_spec=resource_spec)

In [27]:
# Run the untracked plan, comparing the synthetic table against the real one.
plan.execute(dataset_real=df_real, dataset_synthetic=df_synthetic)



In [28]:
# Display the scores held by the plan (a mapping from score name to value).
plan.scores

{'MEAN_JS_DISTANCE': 0.059110809178533986}

In [29]:
# Create a filesystem-backed tracking client for the experiment named "demo".
filesystem_tracker = FilesystemTrackingClient.build(experiment_id="demo")

In [40]:
# Display the directory associated with this tracking client's run.
# NOTE(review): the value is an absolute, machine-specific path -- consider
# clearing this output before sharing the notebook.
filesystem_tracker.run_dir

'C:\\Users\\hecto\\artifact_ml\\demo\\ba7397e7-fa67-44bb-8791-f89c9cc69d1b'

In [38]:
# Rebuild the plan, this time attaching the filesystem tracking client.
# This rebinds `plan`, replacing the untracked instance built earlier.
plan = MyValidationPlan.build(resource_spec=resource_spec, tracking_client=filesystem_tracker)

# Re-run the comparison with tracking attached.
# (Presumably artifacts are written under `filesystem_tracker.run_dir` -- confirm.)
plan.execute(dataset_real=df_real, dataset_synthetic=df_synthetic)

