![](https://raw.githubusercontent.com/wandb/wandb/508982e50e82c54cbf0dd464a9959fee0e1740ad/.github/wb-logo-lightbg.png)
<!--- @wandbcode{dataval-course-04} -->

In [None]:
from dataval.dataset import WeatherDataset
from dataval.train import CatBoostTrainer
from dataval import dataset_extensions

import os
import matplotlib.pyplot as plt
import modal
import pandas as pd

In [None]:
# Container image for the remote functions: Debian slim base, then the
# project's requirements, then the TFDV packages on top.
tfdv_packages = ["tensorflow-data-validation", "tensorflow_metadata", "protobuf==3.20.0"]
image = modal.Image.debian_slim()
image = image.pip_install_from_requirements("requirements.txt")
image = image.pip_install(tfdv_packages)

# Modal stub on which this notebook's remote functions are registered.
stub = modal.Stub("tfdv-tutorial", image=image)

# Drift Detection

Schema validation catches some, but not all, corruptions. In this notebook, we leverage TFDV's drift detection tool to see if all corruptions are identified.

In [None]:
# Load dataset

# The dataset directory is resolved relative to the current working directory.
dataset_dir = os.path.join(os.getcwd(), "canonical-paritioned-dataset")
ds = WeatherDataset(dataset_dir, sample_frac=0.2)

In [None]:
# Look up the partition keys once instead of once per partition. The first
# partition is used as the train set and the second as the test set.
partition_keys = ds.get_partition_keys()
train_df = ds.load(partition_keys[0])
test_df = ds.load(partition_keys[1])

## Check for skew between train and test partitions

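We use TFDV to infer the schema of the train partition and then compare the test partition's statistics against the train partition's to check for skew, i.e., distribution shift. A feature is flagged when the Jensen-Shannon divergence between its train and test distributions exceeds the configured threshold.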

In [None]:
@stub.function()
def check_skew(train_df, test_df, feature_columns, threshold=0.1):
    """Run TFDV skew detection between a train and a test DataFrame.

    A schema is inferred from the statistics of ``train_df``. Every feature
    named in ``feature_columns`` gets a Jensen-Shannon divergence threshold on
    its skew comparator, and the statistics of ``test_df`` are then validated
    against the train statistics as the serving statistics.

    Args:
        train_df: DataFrame used to infer the schema and as the baseline.
        test_df: DataFrame whose statistics are compared against ``train_df``.
        feature_columns: Names of the features to set a skew threshold on.
        threshold: Jensen-Shannon divergence threshold assigned to each
            listed feature's skew comparator. Defaults to 0.1.

    Returns:
        The DataFrame built by TFDV's ``get_anomalies_dataframe`` from the
        validation result.
    """
    # Imported inside the function body; TFDV is installed in the Modal image
    # (presumably so it is not required in the local environment -- confirm).
    import tensorflow_data_validation as tfdv

    train_stats = tfdv.generate_statistics_from_dataframe(train_df)
    schema = tfdv.infer_schema(statistics=train_stats)
    test_stats = tfdv.generate_statistics_from_dataframe(test_df)

    # Skew is only reported for features that have a comparator threshold set.
    for feature in feature_columns:
        comparator = tfdv.get_feature(schema, feature).skew_comparator
        comparator.jensen_shannon_divergence.threshold = threshold

    skew_anomalies = tfdv.validate_statistics(
        statistics=train_stats,
        schema=schema,
        serving_statistics=test_stats,
    )
    return tfdv.utils.display_util.get_anomalies_dataframe(skew_anomalies)

In [None]:
# Run on regular train and test data

with stub.run():
    # Only the feature frames are used; the second return value is discarded.
    X_train, _ = ds.split_feature_label(train_df)
    X_test, _ = ds.split_feature_label(test_df)
    feature_names = X_train.columns.values
    anomalies = check_skew.call(X_train, X_test, feature_names)

In [None]:
# Allow up to 100 characters per column when rendering the anomalies table.
pd.set_option("display.max_colwidth", 100)
anomalies

Wow, it looks like many alerts were triggered! It is unclear whether these alerts are meaningful, though, since the test performance is not much worse than the train performance. Also, how should we interpret the alerts?

## Iterate through corruptions

Let's check whether TFDV detects anomalies for each of the corruptions we introduced in the previous notebook.

In [None]:
# For every corruption, record the anomalies TFDV reports and the columns that
# were actually corrupted, both keyed by the corruption's name.
X_train, _ = ds.split_feature_label(train_df)
feature_names = X_train.columns.values
corruption_anomalies, corruption_columns = {}, {}

with stub.run():
    corruptions = ds.iterate_corruptions_by_feature(test_df, "cmc", corruption_rate=0.05)
    for corruption_name, (corrupted_test_df, corrupted_columns) in corruptions:
        corrupted_X_test, _ = ds.split_feature_label(corrupted_test_df)
        corruption_anomalies[corruption_name] = check_skew.call(
            X_train, corrupted_X_test, feature_names
        )
        corruption_columns[corruption_name] = corrupted_columns

In [None]:
# Send wandb alerts
import wandb
from wandb import AlertLevel

run = wandb.init(project="ml-dataval-tutorial", tags=["TFDV-drift"])

metrics = []

try:
    # The loop variable is deliberately not named `anomalies`, so the result
    # of the earlier train/test skew check is not overwritten.
    for corruption_name, anomalies_df in corruption_anomalies.items():
        # NOTE(review): corruptions for which TFDV reports nothing are skipped
        # entirely, so they do not show up in the precision/recall charts.
        if len(anomalies_df) == 0:
            continue

        wandb.log({corruption_name: wandb.Table(dataframe=anomalies_df)})

        # Drop the first and last character of every index label (presumably
        # the quotes TFDV puts around feature names -- confirm).
        found_columns = {label[1:-1] for label in anomalies_df.index.values}
        # Use a set so duplicates cannot skew the counts below.
        corrupted = set(corruption_columns[corruption_name])
        inter = found_columns & corrupted

        wandb.alert(
            title=f"Errors detected in {corruption_name} experiment",
            text=(
                f"TFDV found {len(inter)} of {len(corrupted)} anomalies "
                f"for corruption {corruption_name}. "
                f"TFDV flagged {len(found_columns)} in total."
            ),
            level=AlertLevel.WARN,
        )
        print(f"TFDV found {len(anomalies_df)} anomalies in {corruption_name} experiment")

        # found_columns is non-empty here because anomalies_df has rows.
        precision = len(inter) / len(found_columns)
        # Recall is reported as 0.0 when no columns were corrupted, to avoid
        # dividing by zero.
        recall = len(inter) / len(corrupted) if corrupted else 0.0
        metrics.append({"corruption_name": corruption_name, "precision": precision, "recall": recall})

    # Log precision and recall. Columns are named explicitly so the table
    # still has them when no corruption produced any anomalies.
    metric_df = pd.DataFrame(metrics, columns=["corruption_name", "precision", "recall"])
    metric_table = wandb.Table(dataframe=metric_df)
    wandb.log({"precision": wandb.plot.bar(metric_table, "corruption_name", "precision",
               title="Precision")})
    wandb.log({"recall": wandb.plot.bar(metric_table, "corruption_name", "recall",
               title="Recall")})
finally:
    # Close the run even if logging fails part-way through.
    wandb.finish()

In [None]:
# Display the precision/recall table computed from the corruption experiments.
metric_df

## Takeaways

It looks like TFDV didn't find all of the right anomalies, but it did find some of them! Raising alerts that are both precise and complete is very hard.