In [1]:
# Reload edited project modules (e.g. dataval, gate) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from dataval.dataset import WeatherDataset
from dataval.train import CatBoostTrainer

import os
import matplotlib.pyplot as plt
import pandas as pd

from gate import summarize, detect_drift

  from .autonotebook import tqdm as notebook_tqdm


# GATE: Automatic Drift Detection

In this notebook, we use GATE, a recently proposed research technique, to automatically detect whether data partitions have drifted. The original GATE paper is available at https://arxiv.org/abs/2303.06094.

In [3]:
# Load the partitioned weather dataset from the current working directory.
# sample_frac=0.2 presumably keeps a 20% sample of rows — confirm in WeatherDataset.
DATA_DIR = os.path.join(os.getcwd(), "canonical-partitioned-dataset")
ds = WeatherDataset(DATA_DIR, sample_frac=0.2)

In [4]:
def train_and_test(train_df, test_df, catboost_hparams=None):
    """Fit a CatBoost model on one partition and score it on another.

    Uses the notebook-global ``ds`` to split each frame into features/label and
    to look up each frame's partition key for the printed log lines.

    Args:
        train_df: Partition to fit the model on.
        test_df: Partition to evaluate the fitted model on.
        catboost_hparams: Optional hyperparameter dict passed to ``CatBoostTrainer``.
            When omitted, uses depth 5, 250 iterations, learning rate 0.03 and
            an RMSE loss.

    Returns:
        ``(trainer, train_score, test_score)``, where both scores are the values
        returned by ``CatBoostTrainer.score`` on the respective split.
    """
    if catboost_hparams is None:
        catboost_hparams = {"depth": 5, "iterations": 250, "learning_rate": 0.03, "loss_function": "RMSE"}

    X_train, y_train = ds.split_feature_label(train_df)
    continual_t = CatBoostTrainer(catboost_hparams)
    continual_t.fit(X_train, y_train, verbose=False)

    # Score each split exactly once and reuse the value for both logging and the
    # return value (scoring re-runs inference over the whole split).
    # NOTE(review): the label says "MSE" while the loss is RMSE — confirm which
    # metric CatBoostTrainer.score actually returns.
    train_score = continual_t.score(X_train, y_train)
    print(f"Train MSE for partition {ds.get_partition_key(train_df)}: {train_score}")

    # Evaluate
    X_test, y_test = ds.split_feature_label(test_df)
    test_score = continual_t.score(X_test, y_test)
    print(f"Test MSE for partition {ds.get_partition_key(test_df)}: {test_score}")

    return continual_t, train_score, test_score

In [5]:
# Train on the first partition and hold out the second partition for testing.
train_key = ds.get_partition_keys()[0]
train_df = ds.load(train_key)
test_key = ds.get_partition_keys()[1]
test_df = ds.load(test_key)

## Iterate through corruptions

We'll iterate through the corruptions, run GATE on each corrupted test partition, and compute GATE's precision and recall at identifying the corrupted columns. As in the previous notebook, we'll log the results to wandb.

In [6]:
import numpy as np

# Seed the pseudo-partition assignment so summaries and drift scores are reproducible.
SEED = 42
rng = np.random.default_rng(SEED)

X_train, _ = ds.split_feature_label(train_df)
# Capture the feature list before the synthetic partition column is added below.
feature_columns = X_train.columns.to_list()

# detect_drift compares one test summary against a collection of training summaries.
# Only a single training partition is loaded, so instead of tagging every row with its
# real partition key (which would yield a single summary) we split the rows into nine
# random pseudo-partitions.
X_train["partition_key"] = rng.choice(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'], size=len(X_train))
train_summaries = summarize(X_train, partition_key="partition_key", columns=feature_columns)

corruption_results = {}
corruption_columns = {}

for corruption_name, (corrupted_test_df, corrupted_columns) in ds.iterate_corruptions(
    test_df, "cmc", corruption_rate=0.05
):
    corrupted_X_test, _ = ds.split_feature_label(corrupted_test_df)

    # Treat the whole corrupted test set as a single partition and take its one summary.
    corrupted_X_test["partition_key"] = "test"
    test_summary = summarize(corrupted_X_test, partition_key="partition_key", columns=feature_columns)[0]

    drift_result = detect_drift(test_summary, train_summaries, cluster=False)
    corruption_results[corruption_name] = drift_result
    # Ground-truth corrupted columns, used later to score GATE's precision/recall.
    corruption_columns[corruption_name] = corrupted_columns

In [7]:
# Show the drift report for the last corruption evaluated above. Looking it up in
# `corruption_results`, rather than reusing the loop variable leaked from the previous
# cell, makes explicit which corruption this report belongs to.
last_corruption_name = list(corruption_results)[-1]
print(f"Corruption: {last_corruption_name}")
print(corruption_results[last_corruption_name])

Drift score: 37.7838 (100.00% percentile)
Top drifted columns:
                  statistic  z-score
column                              
cmc_0_1_11_0            p95  2.84605
cmc_0_1_66_0            p95  2.84605
cmc_0_1_66_0_next       p95  2.84605
cmc_0_1_67_0_next       p95  2.84605
cmc_0_1_67_0_grad       p95  2.84605
cmc_0_1_66_0_grad       p95  2.84605
cmc_0_1_67_0            p95  2.84605
cmc_0_1_68_0_grad      mean  2.84605
cmc_available          mean  2.84605
cmc_0_1_68_0           mean  2.84605


In [10]:
# Send wandb alerts
import wandb
from wandb import AlertLevel

# A column counts as "flagged" by GATE when its |z-score| reaches this threshold.
Z_SCORE_THRESHOLD = 2.75


def precision_recall(flagged_columns, true_columns):
    """Return ``(precision, recall)`` of flagged columns against the truly corrupted ones.

    Uses the zero-division convention of scikit-learn's ``zero_division=0``:
    precision is 0.0 when nothing was flagged, and recall is 0.0 when there
    were no corrupted columns.
    """
    flagged = set(flagged_columns)
    expected = set(true_columns)
    hits = len(flagged & expected)
    precision = hits / len(flagged) if flagged else 0.0
    recall = hits / len(expected) if expected else 0.0
    return float(precision), float(recall)


run = wandb.init(project="ml-dataval-tutorial", tags=["GATE"])

metrics = []
try:
    for corruption_name, drift_result in corruption_results.items():
        drifted_results = drift_result.drifted_columns(limit=None)
        drifted_results = drifted_results[
            drifted_results["z-score"].abs() >= Z_SCORE_THRESHOLD
        ]

        true_columns = set(corruption_columns[corruption_name])
        found_columns = set(drifted_results.index.values)
        inter = found_columns & true_columns

        if len(drifted_results) > 0:
            wandb.log({corruption_name: wandb.Table(dataframe=drifted_results)})
            wandb.alert(
                title=f"Errors detected in {corruption_name} experiment",
                text=f"GATE found {len(inter)} of {len(true_columns)} anomalous columns for corruption {corruption_name}. GATE flagged {len(found_columns)} in total.",
                level=AlertLevel.WARN,
            )

        # Record every corruption, including ones where GATE flagged nothing, so a
        # complete miss shows up as zero recall instead of vanishing from the charts.
        precision, recall = precision_recall(found_columns, true_columns)
        metrics.append({"corruption_name": corruption_name, "precision": precision, "recall": recall})

    # Log precision and recall. Explicit columns keep the frame well-formed even if
    # `metrics` is empty.
    metric_df = pd.DataFrame(metrics, columns=["corruption_name", "precision", "recall"])
    metric_table = wandb.Table(dataframe=metric_df)
    wandb.log({"precision": wandb.plot.bar(metric_table, "corruption_name", "precision",
               title="Precision")})
    wandb.log({"recall": wandb.plot.bar(metric_table, "corruption_name", "recall",
               title="Recall")})
finally:
    # Close the run even if logging raises, so it is not left open in the kernel.
    wandb.finish()

In [11]:
# Per-corruption precision and recall of GATE, as computed and logged to wandb above.
metric_df

Unnamed: 0,corruption_name,precision,recall
0,corrupt_null,1.0,1.0
1,corrupt_nonnegative,1.0,0.868421
2,corrupt_typecheck,1.0,0.4
3,corrupt_units,1.0,0.821429
4,corrupt_average,1.0,0.964286
5,corrupt_pinned,1.0,0.553571


## Takeaways

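GATE appears to perform somewhat better than the approach from the previous notebook. It raised no false alarms: precision is 1.0 for every corruption. Recall ranges from 1.0 on `corrupt_null` down to 0.55 on `corrupt_pinned` and 0.4 on `corrupt_typecheck`. Still, no detector is perfect, and some corruption types remain hard to catch.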