![](https://raw.githubusercontent.com/wandb/wandb/508982e50e82c54cbf0dd464a9959fee0e1740ad/.github/wb-logo-lightbg.png)
<!--- @wandbcode{dataval-course-05} -->

In [None]:
from dataval.dataset import WeatherDataset
from dataval.train import CatBoostTrainer
from dataval import dataset_extensions

import os
import matplotlib.pyplot as plt
import pandas as pd

from gate import summarize, detect_drift

# GATE: Automatic Drift Detection

In this notebook, we leverage GATE, a new research technique to automatically detect whether partitions have drifted. The original GATE research paper is here: https://arxiv.org/abs/2303.06094

In [None]:
# Load dataset
#
# Point the dataset at the partitioned data directory under the current
# working directory. sample_frac=0.2 presumably subsamples the data to 20%
# of its rows -- confirm against WeatherDataset.
dataset_root = os.path.join(os.getcwd(), "canonical-paritioned-dataset")
ds = WeatherDataset(dataset_root, sample_frac=0.2)

In [None]:
# Load the first two partitions: the first becomes the training frame and
# the second the test frame that later cells corrupt.
first_key = ds.get_partition_keys()[0]
train_df = ds.load(first_key)
second_key = ds.get_partition_keys()[1]
test_df = ds.load(second_key)

## Iterate through corruptions

We'll iterate through the corruptions and compute the precision and recall of GATE for each one. As in the previous notebook, we'll log the results to wandb.

In [None]:
import numpy as np

# Fixed seed so the random pseudo-partition assignment below (and therefore
# the training summaries built from it) is identical on every run.
PARTITION_SEED = 0
# Labels of the pseudo-partitions the training rows are randomly split into.
PARTITION_LABELS = ["A", "B", "C", "D", "E", "F", "G", "H"]

X_train, _ = ds.split_feature_label(train_df)

# Both dicts are keyed by corruption name and consumed by the wandb cell:
#   corruption_results -> drift result returned by detect_drift
#   corruption_columns -> columns that were actually corrupted (ground truth)
corruption_results = {}
corruption_columns = {}

# Capture the feature names *before* the helper "partition_key" column is
# added, so that column is never summarized as a feature.
feature_columns = X_train.columns.to_list()

# detect_drift receives one test summary plus a collection of training
# summaries, so the single training frame is split at random into several
# pseudo-partitions. `.assign` returns a new frame rather than writing into
# whatever object split_feature_label handed back.
rng = np.random.default_rng(PARTITION_SEED)
X_train = X_train.assign(
    partition_key=rng.choice(PARTITION_LABELS, size=len(X_train))
)
train_summaries = summarize(
    X_train, partition_key="partition_key", columns=feature_columns
)

corruption_iter = ds.iterate_corruptions_by_feature(
    test_df, "cmc", corruption_rate=0.05
)
for corruption_name, (corrupted_test_df, corrupted_columns) in corruption_iter:
    corrupted_X_test, _ = ds.split_feature_label(corrupted_test_df)

    # The whole corrupted test frame is treated as one partition named "test",
    # so summarize() yields a single summary, taken with [0].
    corrupted_X_test = corrupted_X_test.assign(partition_key="test")
    test_summary = summarize(
        corrupted_X_test, partition_key="partition_key", columns=feature_columns
    )[0]

    corruption_results[corruption_name] = detect_drift(
        test_summary, train_summaries, cluster=True
    )
    corruption_columns[corruption_name] = corrupted_columns


In [None]:
# Send wandb alerts
import wandb
from wandb import AlertLevel

# A column is reported as drifted only if it meets both thresholds.
MIN_CLUSTER_Z_SCORE = 1
MIN_Z_SCORE = 2.5

# One {"corruption_name", "precision", "recall"} record per corruption.
metrics = []

# Using the run as a context manager finishes it even if logging raises
# part-way through, so no dangling run is left behind.
with wandb.init(project="ml-dataval-tutorial", tags=["GATE"]) as run:
    for corruption_name, drift_result in corruption_results.items():
        drifted_results = drift_result.drifted_columns(limit=None)
        drifted_results = drifted_results[
            (drifted_results["abs(z-score-cluster)"].abs() >= MIN_CLUSTER_Z_SCORE)
            & (drifted_results["z-score"].abs() >= MIN_Z_SCORE)
        ]

        # Ground truth vs. what GATE flagged; the result index holds the
        # flagged column names.
        expected_columns = set(corruption_columns[corruption_name])
        found_columns = set(drifted_results.index.values)
        inter = found_columns & expected_columns

        if not drifted_results.empty:
            table = wandb.Table(dataframe=drifted_results)
            wandb.log({corruption_name: table})

            wandb.alert(
                title=f"Errors detected in {corruption_name} experiment",
                text=(
                    f"GATE found {len(inter)} of {len(expected_columns)} anomalous columns "
                    f"for corruption {corruption_name}. "
                    f"GATE flagged {len(found_columns)} in total."
                ),
                level=AlertLevel.WARN,
            )
            print(f"GATE found {len(drifted_results)} drifted results in {corruption_name} experiment")

        # Record metrics for *every* corruption. Skipping the ones where GATE
        # flags nothing would drop its misses from the charts and overstate
        # recall. Ratios with an empty denominator are reported as 0.0.
        precision = len(inter) / len(found_columns) if found_columns else 0.0
        recall = len(inter) / len(expected_columns) if expected_columns else 0.0
        metrics.append(
            {"corruption_name": corruption_name, "precision": precision, "recall": recall}
        )

    # Log precision and recall. Explicit columns keep metric_df well-formed
    # even when there were no corruptions at all.
    metric_df = pd.DataFrame(metrics, columns=["corruption_name", "precision", "recall"])
    if not metric_df.empty:
        metric_table = wandb.Table(dataframe=metric_df)
        wandb.log(
            {"precision": wandb.plot.bar(metric_table, "corruption_name", "precision", title="Precision")}
        )
        wandb.log(
            {"recall": wandb.plot.bar(metric_table, "corruption_name", "recall", title="Recall")}
        )

In [None]:
# Display the per-corruption precision/recall table as this cell's output.
metric_df

## Takeaways

Looks like GATE performed a bit better! Still, it's impossible to be perfect...