In [1]:
# Re-import edited local modules (e.g. dataval) automatically so changes are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from dataval.dataset import WeatherDataset
from dataval.train import CatBoostTrainer

import os
import matplotlib.pyplot as plt
import modal
import pandas as pd

In [3]:
# Container image for the remote Modal function: the project's requirements
# plus TFDV and its metadata protos.
# NOTE(review): the TFDV packages are unpinned; consider pinning versions.
image = modal.Image.debian_slim()
image = image.pip_install_from_requirements("requirements.txt")
image = image.pip_install(["tensorflow-data-validation", "tensorflow_metadata"])

# Modal stub that the remote function(s) below are registered on.
stub = modal.Stub("tfdv-tutorial", image=image)

# Drift Detection

Schema validation catches some, but not all, corruptions. In this notebook, we use TFDV's drift/skew detection to see whether it identifies all of the corruptions.

In [4]:
# Load the partitioned weather dataset from ./canonical-partitioned-dataset.
# sample_frac=0.2 is forwarded to WeatherDataset (presumably row subsampling;
# confirm in dataval.dataset).
ds = WeatherDataset(
    os.path.join(os.getcwd(), "canonical-partitioned-dataset"),
    sample_frac=0.2,
)

In [5]:
# The first partition is used for training and the second for testing.
train_df, test_df = (ds.load(ds.get_partition_keys()[idx]) for idx in (0, 1))

## Check for skew between train and test partitions

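We use TFDV to compute statistics and infer a schema from the train partition, then check the test partition for train/serving skew. For each feature, TFDV compares the two distributions using an approximate Jensen–Shannon divergence. It flags any feature whose divergence exceeds the configured threshold (0.1 here).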

In [6]:
@stub.function
def check_skew(train_df, test_df, feature_columns, threshold=0.1):
    """Flag features whose distribution differs between two dataframes.

    Runs remotely on Modal (registered via ``@stub.function``), inside the
    image that has TFDV installed.

    Statistics are generated for both frames and a schema is inferred from
    ``train_df`` only. Each feature in ``feature_columns`` gets an approximate
    Jensen-Shannon divergence skew comparator. ``test_df`` is then validated as
    the "serving" side of a train/serving skew check.

    Args:
        train_df: Reference dataframe; the schema is inferred from it.
        test_df: Dataframe compared against ``train_df``.
        feature_columns: Iterable of feature names to attach a skew comparator
            to. Each name must exist in the inferred schema; TFDV's
            ``get_feature`` raises for unknown features.
        threshold: Approximate Jensen-Shannon divergence above which a feature
            is reported. Defaults to 0.1, the value previously hard-coded.

    Returns:
        pandas.DataFrame of anomalies from TFDV's ``get_anomalies_dataframe``,
        indexed by feature name.
    """
    # Imported here because TFDV is only installed in the remote image.
    import tensorflow_data_validation as tfdv

    train_stats = tfdv.generate_statistics_from_dataframe(train_df)
    test_stats = tfdv.generate_statistics_from_dataframe(test_df)
    schema = tfdv.infer_schema(statistics=train_stats)

    for feature in feature_columns:
        comparator = tfdv.get_feature(schema, feature).skew_comparator
        comparator.jensen_shannon_divergence.threshold = threshold

    # serving_statistics (not previous_statistics) selects skew comparison.
    skew_anomalies = tfdv.validate_statistics(
        statistics=train_stats, schema=schema, serving_statistics=test_stats
    )
    return tfdv.utils.display_util.get_anomalies_dataframe(skew_anomalies)

In [7]:
# Baseline: compare the uncorrupted train and test partitions.
# The feature/label split is local work, so it is done before starting the app.
X_train, _ = ds.split_feature_label(train_df)
X_test, _ = ds.split_feature_label(test_df)

with stub.run():
    anomalies = check_skew.call(X_train, X_test, X_train.columns.values)

Output()

Output()

Output()

In [8]:
# Widen the column display so the long anomaly descriptions stay readable.
pd.set_option("display.max_colwidth", 100)
anomalies

Unnamed: 0_level_0,Anomaly short description,Anomaly long description
Feature name,Unnamed: 1_level_1,Unnamed: 2_level_1
'cmc_0_1_68_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 1 (up to six significa...
'wrf_rain',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.169053 (up to six si...
'cmc_0_1_7_0',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.102128 (up to six si...
'cmc_0_1_65_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.160259 (up to six si...
'topography_bathymetry',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.166572 (up to six si...
'cmc_precipitations',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.160258 (up to six si...
'cmc_0_1_67_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.769623 (up to six si...
'gfs_r_velocity',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.128119 (up to six si...
'wrf_snow',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.76025 (up to six sig...
'cmc_0_1_66_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.173947 (up to six si...


Many alerts were triggered, even though neither partition was deliberately corrupted! It is unclear whether these alerts are meaningful, since the model's test performance is not much worse than its train performance. It is also unclear how we would interpret or act on them.

## Iterate through corruptions

Next, we check whether TFDV detects anomalies for each of the corruptions from our previous notebook.

In [9]:
# Compare the clean training features against each corrupted copy of the test
# partition. For every corruption, record TFDV's anomalies and the columns that
# were actually corrupted (the ground truth for the precision/recall cell below).
X_train, _ = ds.split_feature_label(train_df)
corruption_anomalies = {}
corruption_columns = {}

with stub.run():
    for corruption_name, (corrupted_test_df, corrupted_columns) in ds.iterate_corruptions(
        test_df, "cmc", corruption_rate=0.05
    ):
        corrupted_X_test, _ = ds.split_feature_label(corrupted_test_df)
        corruption_anomalies[corruption_name] = check_skew.call(
            X_train, corrupted_X_test, X_train.columns.values
        )
        corruption_columns[corruption_name] = corrupted_columns

Output()

Output()

Output()

In [10]:
# Send wandb alerts and log per-corruption precision/recall of TFDV's flags
# against the columns each corruption actually touched.
import wandb
from wandb import AlertLevel

run = wandb.init(project="ml-dataval-tutorial", tags=["TFDV-drift"])

metrics = []

try:
    # The loop variable is deliberately NOT named `anomalies`. That name holds
    # the baseline result displayed earlier; overwriting it would make
    # re-running that display cell show stale data.
    for corruption_name, corruption_anomaly_df in corruption_anomalies.items():
        # Ground truth: the columns the corruption actually modified.
        true_columns = set(corruption_columns[corruption_name])
        # TFDV's anomalies dataframe wraps feature names in single quotes
        # (e.g. "'wrf_rain'"), so strip them to compare with raw column names.
        found_columns = {name[1:-1] for name in corruption_anomaly_df.index.values}
        hits = found_columns & true_columns

        if found_columns:
            wandb.log({corruption_name: wandb.Table(dataframe=corruption_anomaly_df)})
            wandb.alert(
                title=f"Errors detected in {corruption_name} experiment",
                text=(
                    f"TFDV found {len(hits)} of {len(true_columns)} anomalies for corruption "
                    f"{corruption_name}. TFDV flagged {len(found_columns)} in total."
                ),
                level=AlertLevel.WARN,
            )

        # Every corruption gets a metrics row, including ones TFDV missed
        # entirely (recall 0). Otherwise misses silently vanish from the charts.
        # An empty denominator yields 0.0 instead of raising ZeroDivisionError
        # (same as scikit-learn's zero_division=0).
        precision = len(hits) / len(found_columns) if found_columns else 0.0
        recall = len(hits) / len(true_columns) if true_columns else 0.0
        metrics.append(
            {"corruption_name": corruption_name, "precision": precision, "recall": recall}
        )

    # Fixed columns keep the bar plots valid even if there were no corruptions.
    metric_df = pd.DataFrame(metrics, columns=["corruption_name", "precision", "recall"])
    metric_table = wandb.Table(dataframe=metric_df)
    wandb.log({"precision": wandb.plot.bar(metric_table, "corruption_name", "precision",
               title="Precision")})
    wandb.log({"recall": wandb.plot.bar(metric_table, "corruption_name", "recall",
               title="Recall")})
finally:
    # Close the run even if logging fails partway, so a re-run starts clean.
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33msh_reya[0m ([33mnnprov[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# Per-corruption precision/recall of TFDV's skew flags versus the corrupted columns.
metric_df

Unnamed: 0,corruption_name,precision,recall
0,corrupt_null,0.526316,0.178571
1,corrupt_nonnegative,0.711111,0.842105
2,corrupt_typecheck,0.1,0.4
3,corrupt_units,0.82,0.732143
4,corrupt_average,0.852459,0.928571
5,corrupt_pinned,0.742857,0.464286


## Takeaways

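TFDV did not find all of the corrupted columns, but it flagged a nonzero fraction of them for every corruption. Recall ranges from about 0.18 for `corrupt_null` to 0.93 for `corrupt_average`. Alerting precisely on exactly the right columns is very hard.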