In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dataval.dataset import WeatherDataset
from dataval.train import Trainer

import os
import matplotlib.pyplot as plt
import modal
import pandas as pd

In [3]:
image = (
    modal.Image.debian_slim()
    .pip_install_from_requirements("requirements.txt")
    .pip_install(["tensorflow-data-validation", "tensorflow_metadata"])
)
stub = modal.Stub("tfdv-tutorial", image=image)

# Drift Detection

Schema validation catches some, but not all, corruptions. In this notebook, we use TFDV's drift detection to see whether it identifies all of them.

In [4]:
# Load dataset

ds = WeatherDataset(os.path.join(os.getcwd(), "canonical-partitioned-dataset"), sample_frac=0.2)

In [5]:
def train_and_test(train_df, test_df):
    X_train, y_train = ds.split_feature_label(train_df)

    catboost_hparams = {"depth": 5, "iterations": 250, "learning_rate": 0.03, "loss_function": "RMSE"}
    continual_t = Trainer(catboost_hparams)
    continual_t.fit(X_train, y_train, verbose=False)
    print(f"Train MSE for partition {ds.get_partition_key(train_df)}: {continual_t.score(X_train, y_train)}")

    # Evaluate
    X_test, y_test = ds.split_feature_label(test_df)
    print(f"Test MSE for partition {ds.get_partition_key(test_df)}: {continual_t.score(X_test, y_test)}")
    
    return continual_t

In [6]:
train_df = ds.load(ds.get_partition_keys()[0])
test_df = ds.load(ds.get_partition_keys()[1])

In [9]:
t = train_and_test(train_df, test_df)
t.get_feature_importance().head(5)

Train MSE for partition 2018_35: 4.074896379948685
Test MSE for partition 2018_36: 4.797844923928596


index,feature,importance
6,cmc_0_0_0_2_interpolated,19.153059
89,gfs_temperature_sea_interpolated,15.313095
87,gfs_temperature_sea,12.796032
109,wrf_t2_interpolated,11.163675
8,cmc_0_0_0_2,6.928703


## Check for skew between train and test partitions

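We use TFDV to infer a schema from the train partition, then compare the test partition against it for skew. Here, skew means distribution shift: TFDV flags any feature whose approximate Jensen-Shannon divergence between the two partitions exceeds the threshold we set (0.1).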
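To make the thresholded quantity concrete, here is a minimal sketch of a histogram-based Jensen-Shannon divergence for one numeric feature. The function name `js_divergence` and the `bins` parameter are illustrative assumptions, and TFDV's internal approximation may bin and smooth differently, so the values will not match its output exactly.

```python
import numpy as np

def js_divergence(train_values, serving_values, bins=10):
    """Histogram-based Jensen-Shannon divergence (base 2) between two samples.

    Illustrative only: not TFDV's implementation.
    """
    train_values = np.asarray(train_values, dtype=float)
    serving_values = np.asarray(serving_values, dtype=float)

    # Use shared bin edges so the two histograms are comparable
    lo = min(train_values.min(), serving_values.min())
    hi = max(train_values.max(), serving_values.max())
    if lo == hi:
        hi = lo + 1.0  # avoid zero-width bins when every value is identical
    edges = np.linspace(lo, hi, bins + 1)

    p, _ = np.histogram(train_values, bins=edges)
    q, _ = np.histogram(serving_values, bins=edges)
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)  # mixture distribution

    def kl(a, b):
        # KL divergence restricted to bins where a > 0 (b > 0 there as well)
        mask = a > 0
        return float(np.sum(a[mask] * np.log2(a[mask] / b[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

same = js_divergence([1, 2, 3, 4], [1, 2, 3, 4])
disjoint = js_divergence([0, 0, 0, 0], [10, 10, 10, 10])
print(same, disjoint)
```

With base-2 logarithms the divergence lies between 0 (identical histograms) and 1 (no overlap), which gives a sense of scale for a 0.1 threshold.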

In [19]:
@stub.function
def check_skew(train_df, test_df, feature_columns):
    import tensorflow_data_validation as tfdv

    # Compute summary statistics for each partition and infer a schema from train
    train_stats = tfdv.generate_statistics_from_dataframe(train_df)
    schema = tfdv.infer_schema(statistics=train_stats)
    test_stats = tfdv.generate_statistics_from_dataframe(test_df)

    # Flag a feature when its train/test Jensen-Shannon divergence exceeds 0.1
    for feature in feature_columns:
        tfdv.get_feature(schema, feature).skew_comparator.jensen_shannon_divergence.threshold = 0.1

    # Passing the test statistics as serving_statistics enables the skew comparison
    skew_anomalies = tfdv.validate_statistics(statistics=train_stats, schema=schema, serving_statistics=test_stats)
    anomalies_df = tfdv.utils.display_util.get_anomalies_dataframe(skew_anomalies)

    return anomalies_df

In [20]:
# Run on regular train and test data

with stub.run():
    X_train, _ = ds.split_feature_label(train_df)
    X_test, _ = ds.split_feature_label(test_df)
    anomalies = check_skew.call(X_train, X_test, X_train.columns.values)

In [26]:
pd.options.display.max_colwidth = 100
anomalies

Feature name,Anomaly short description,Anomaly long description
'cmc_0_1_65_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.160324 (up to six si...
'cmc_0_0_0_925',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.188829 (up to six si...
'cmc_0_0_0_2_interpolated',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.222345 (up to six si...
'cmc_0_0_0_2_next',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.224824 (up to six si...
'cmc_0_1_68_0_next',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.992944 (up to six si...
'cmc_0_1_67_0_next',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.818735 (up to six si...
'cmc_0_0_0_700',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.23262 (up to six sig...
'cmc_0_3_1_0',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.534912 (up to six si...
'gfs_2m_dewpoint_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.116569 (up to six si...
'cmc_0_1_68_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.998486 (up to six si...


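Wow, many alerts were triggered, even though nothing has been corrupted yet! It is unclear whether these alerts are meaningful, since test performance is only slightly worse than train performance (test MSE of about 4.80 versus train MSE of about 4.07). It is also unclear how to interpret them: each alert reports a divergence value for one feature, but not how much that shift matters to the model.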
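One way to start interpreting the alerts is to pull the divergence value out of each description and rank features by it. The sketch below assumes the descriptions follow the wording shown in the output above; `rank_by_divergence` is a hypothetical helper, and the mapping it takes could be built with something like `anomalies["Anomaly long description"].to_dict()`.

```python
import re

# Matches the number that follows "training and serving is" in a description
_JSD_RE = re.compile(
    r"training and serving is ([0-9]+(?:\.[0-9]+)?(?:[eE][-+]?[0-9]+)?)"
)

def rank_by_divergence(descriptions):
    """Return (feature, divergence) pairs sorted from largest to smallest.

    descriptions: mapping of feature name -> anomaly long description.
    Entries whose description does not match the expected wording are skipped.
    """
    scored = []
    for feature, text in descriptions.items():
        match = _JSD_RE.search(text)
        if match:
            # Index values come wrapped in single quotes; strip them
            scored.append((feature.strip("'"), float(match.group(1))))
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

# Three rows copied from the output above (descriptions are truncated there)
sample = {
    "'cmc_0_1_65_0_grad'": "The approximate Jensen-Shannon divergence between training and serving is 0.160324 (up to six si...",
    "'cmc_0_1_68_0_grad'": "The approximate Jensen-Shannon divergence between training and serving is 0.998486 (up to six si...",
    "'gfs_2m_dewpoint_grad'": "The approximate Jensen-Shannon divergence between training and serving is 0.116569 (up to six si...",
}
ranked = rank_by_divergence(sample)
for feature, value in ranked:
    print(f"{feature}: {value}")
```

Ranking separates features that barely cross the 0.1 threshold from those with nearly disjoint distributions, though a large divergence still does not by itself show that model performance is affected.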

## Iterate through corruptions

Next, we apply each corruption from the previous notebook to the test partition and check whether TFDV flags the corrupted columns.

In [29]:
X_train, _ = ds.split_feature_label(train_df)
corruption_anomalies = {}
corruption_columns = {}

with stub.run():
    for corruption_name, corruption_res in ds.iterate_corruptions(test_df, "cmc", corruption_rate=0.05):
        corrupted_test_df, corrupted_columns = corruption_res
        corrupted_X_test, _ = ds.split_feature_label(corrupted_test_df)
        corruption_anomalies[corruption_name] = check_skew.call(X_train, corrupted_X_test, X_train.columns.values)
        corruption_columns[corruption_name] = corrupted_columns

In [40]:
k = 5

for corruption_name, anomalies in corruption_anomalies.items():
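    # Feature names in the anomalies index are wrapped in single quotes; strip them
    found_columns = [a[1:-1] for a in anomalies.index.values]
    # Corrupted columns that TFDV also flagged
    inter = set(found_columns).intersection(set(corruption_columns[corruption_name]))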
    print(f"TFDV found {len(inter)} of {len(corruption_columns[corruption_name])} anomalies for corruption {corruption_name}. {len(set(found_columns))} were found in total. Displaying {k}:")
    if len(anomalies) > 0:
        display(anomalies.head(k))

TFDV found 10 of 56 anomalies for corruption corrupt_null. 19 were found in total. Displaying 5:


Feature name,Anomaly short description,Anomaly long description
'gfs_precipitations',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.245893 (up to six si...
'cmc_0_1_67_0_next',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.641023 (up to six si...
'wrf_graupel',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.769175 (up to six si...
'cmc_0_1_7_0',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.102191 (up to six si...
'gfs_r_velocity',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.128119 (up to six si...


TFDV found 32 of 38 anomalies for corruption corrupt_nonnegative. 45 were found in total. Displaying 5:


Feature name,Anomaly short description,Anomaly long description
'cmc_0_0_0_700',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.278185 (up to six si...
'cmc_0_1_68_0',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.99931 (up to six sig...
'cmc_0_1_65_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.160259 (up to six si...
'gfs_r_velocity',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.128119 (up to six si...
'cmc_0_0_0_2_next',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.196412 (up to six si...


TFDV found 2 of 5 anomalies for corruption corrupt_typecheck. 20 were found in total. Displaying 5:


Feature name,Anomaly short description,Anomaly long description
'cmc_0_1_66_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.173947 (up to six si...
'topography_bathymetry',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.166572 (up to six si...
'cmc_0_1_65_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.160259 (up to six si...
'cmc_0_1_67_0_next',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.641024 (up to six si...
'cmc_0_1_7_0',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.102128 (up to six si...


TFDV found 41 of 56 anomalies for corruption corrupt_units. 50 were found in total. Displaying 5:


Feature name,Anomaly short description,Anomaly long description
'cmc_0_3_5_500',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.352539 (up to six si...
'cmc_0_1_0_0',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.965299 (up to six si...
'topography_bathymetry',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.166572 (up to six si...
'cmc_0_1_65_0_next',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.133717 (up to six si...
'cmc_0_3_5_850',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.263825 (up to six si...


TFDV found 52 of 56 anomalies for corruption corrupt_average. 61 were found in total. Displaying 5:


Feature name,Anomaly short description,Anomaly long description
'cmc_0_2_3_1000',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.899013 (up to six si...
'gfs_2m_dewpoint_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.116569 (up to six si...
'cmc_0_2_3_500',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.841468 (up to six si...
'cmc_0_0_7_925',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.871904 (up to six si...
'cmc_0_2_2_500',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.82213 (up to six sig...


TFDV found 26 of 56 anomalies for corruption corrupt_pinned. 35 were found in total. Displaying 5:


Feature name,Anomaly short description,Anomaly long description
'cmc_0_1_67_0_next',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.818735 (up to six si...
'cmc_0_3_5_700',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.251076 (up to six si...
'cmc_0_1_68_0_grad',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.998486 (up to six si...
'climate_pressure',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.116659 (up to six si...
'cmc_0_0_0_2',High approximate Jensen-Shannon divergence between training and serving,The approximate Jensen-Shannon divergence between training and serving is 0.227713 (up to six si...


## Takeaways

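TFDV flagged some, but not all, of the corrupted columns: it caught as few as 10 of 56 (corrupt_null) and as many as 52 of 56 (corrupt_average). In every run it also flagged columns that were never corrupted (for example, 9 of the 19 alerts for corrupt_null). Alerting precisely is very hard.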
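To compare the runs on a common scale, the counts printed above can be turned into precision and recall. This is a small sketch: `precision_recall` and the `reported` dictionary are illustrative names, and the numbers are copied from the printed summaries rather than recomputed.

```python
def precision_recall(hits, corrupted, flagged):
    """hits: corrupted columns that were flagged; corrupted: all corrupted
    columns; flagged: all columns TFDV flagged."""
    precision = hits / flagged if flagged else 0.0
    recall = hits / corrupted if corrupted else 0.0
    return precision, recall

# (hits, corrupted, flagged) as printed for each corruption above
reported = {
    "corrupt_null": (10, 56, 19),
    "corrupt_nonnegative": (32, 38, 45),
    "corrupt_typecheck": (2, 5, 20),
    "corrupt_units": (41, 56, 50),
    "corrupt_average": (52, 56, 61),
    "corrupt_pinned": (26, 56, 35),
}

for name, counts in reported.items():
    precision, recall = precision_recall(*counts)
    print(f"{name}: precision={precision:.2f}, recall={recall:.2f}")
```

Precision here counts every flagged but uncorrupted column as a false alarm, including columns that were already flagged on the uncorrupted test partition.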