In [1]:
%load_ext autoreload
%autoreload 2

# Schema Validation

In this notebook, we use TensorFlow Data Validation (TFDV), part of TFX, to apply schema validation and check whether it detects any of the corruptions from the previous notebook. We'll log the results of schema validation to Weights & Biases (wandb).

You can set up wandb alerts here: https://docs.wandb.ai/guides/runs/alert

I use Modal because TFDV doesn't run on Apple M1 Macs. You can create a free account on Modal here: https://modal.com/signup. It comes with $10/month in credits, which should be more than enough to run the notebooks in this course. Once you have created an account, follow the "Getting Started" instructions on the homepage:

* Run `pip install modal-client` (also included in `requirements.txt` in this repo)
* Run `modal token new`, which will open a browser window and authenticate you with your account

Then you should be able to run this notebook!

In [2]:
from dataval.dataset import WeatherDataset
from dataval.train import CatBoostTrainer

import os
import matplotlib.pyplot as plt
import modal
import pandas as pd

import wandb
from wandb import AlertLevel

In [3]:
image = (
    modal.Image.debian_slim()
    .pip_install_from_requirements("requirements.txt")
    .pip_install(["tensorflow-data-validation", "tensorflow_metadata"])
)
stub = modal.Stub("tfdv-tutorial", image=image)

In [4]:
# Load dataset

ds = WeatherDataset(os.path.join(os.getcwd(), "canonical-partitioned-dataset"), sample_frac=0.2)

In [5]:
train_df = ds.load(ds.get_partition_keys()[0])
test_df = ds.load(ds.get_partition_keys()[1])

In [6]:
train_df

| | fact_time | fact_latitude | fact_longitude | fact_temperature | fact_cwsm_class | climate | topography_bathymetry | sun_elevation | climate_temperature | climate_pressure | ... | cmc_0_1_66_0_next | cmc_0_1_67_0_grad | cmc_0_1_67_0_next | cmc_0_1_68_0_grad | cmc_0_1_68_0_next | gfs_2m_dewpoint_grad | gfs_2m_dewpoint_next | gfs_total_clouds_cover_low_grad | gfs_total_clouds_cover_low_next | year_week |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2018-09-01 00:00:00 | -34.583333 | -68.400000 | 3.0 | 20.0 | dry | 740.0 | -21.623605 | 13.373571 | 700.525625 | ... | 2.7778 | 0.0 | 0.000005 | 0.0 | 0.0 | -0.269379 | -1.919379 | -17.0 | 33.0 | 2018_35 |
| 1 | 2018-09-01 00:00:00 | -1.650000 | 13.433333 | 23.0 | 20.0 | tropical | 430.0 | -75.638305 | 22.277857 | 729.286679 | ... | 0.0000 | 0.0 | 0.000000 | 0.0 | 0.0 | -0.299988 | 20.450006 | -6.0 | 2.0 | 2018_35 |
| 2 | 2018-09-01 00:00:00 | 35.533333 | 35.766667 | 25.0 | 0.0 | mild temperate | 2.0 | -35.902113 | 26.162143 | 751.406267 | ... | 0.0000 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.182770 | 22.032770 | 0.0 | 0.0 | 2018_35 |
| 3 | 2018-09-01 00:00:00 | 47.432201 | 0.727606 | 15.0 | 0.0 | mild temperate | 103.0 | -35.725373 | 16.347143 | 756.800746 | ... | 0.0000 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.000000 | 8.749994 | -1.0 | 0.0 | 2018_35 |
| 4 | 2018-09-01 00:00:00 | 43.427101 | -3.820010 | 18.0 | 0.0 | mild temperate | 5.0 | -39.615037 | 18.630714 | 758.808740 | ... | 0.0000 | 0.0 | 0.000000 | 0.0 | 0.0 | -0.799988 | 11.650018 | 0.0 | 0.0 | 2018_35 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 5684 | 2018-09-02 23:56:00 | 27.692600 | -97.291100 | 31.0 | 10.0 | dry | 6.0 | 22.706908 | 30.584286 | 760.531371 | ... | 0.0000 | 0.0 | 0.000000 | 0.0 | 0.0 | 1.221222 | 23.749994 | 0.0 | 0.0 | 2018_35 |
| 5685 | 2018-09-02 23:56:00 | 40.437401 | -104.633003 | 28.0 | 0.0 | dry | 1416.0 | 26.545844 | 28.351429 | 644.658950 | ... | 0.0000 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.100006 | 10.350000 | 0.0 | 0.0 | 2018_35 |
| 5686 | 2018-09-02 23:56:00 | 40.193501 | -76.763397 | 26.0 | 10.0 | mild temperate | 97.0 | 5.506977 | 27.213571 | 750.389714 | ... | 0.0000 | 0.0 | 0.000000 | 0.0 | 0.0 | -1.542908 | 22.607080 | 0.0 | 0.0 | 2018_35 |
| 5687 | 2018-09-02 23:56:00 | 32.898602 | -80.040497 | 28.0 | 10.0 | mild temperate | 10.0 | 7.724592 | 29.476492 | 762.621737 | ... | 0.0000 | 0.0 | 0.000000 | 0.0 | 0.0 | -0.700043 | 22.949976 | -1.0 | 1.0 | 2018_35 |


## Infer schema and check test data for errors

From the training dataframe, we infer a schema using TFDV, then use this schema to find anomalies in the test data. We first apply this to the original, uncorrupted dataframes.
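
To make the idea of "infer a schema, then validate against it" concrete, here is a minimal, simplified sketch in plain pandas. It is not TFDV's algorithm: `infer_simple_schema` and `validate_against_schema` are hypothetical helpers, and the checks they perform (numeric vs. non-numeric type, no missing values if training had none, non-numeric values restricted to those seen in training, dropped or new columns) are assumptions chosen to mirror the kinds of constraints a TFDV schema records.

```python
import pandas as pd


def infer_simple_schema(df: pd.DataFrame) -> dict:
    """Record a few per-column constraints from a reference (training) dataframe."""
    schema = {}
    for col in df.columns:
        series = df[col]
        entry = {
            # Coarse stand-in for a feature type: numeric vs. everything else
            "numeric": pd.api.types.is_numeric_dtype(series),
            # Treat a column as required if it had no missing values in training
            "required": not series.isna().any(),
        }
        if not entry["numeric"]:
            # For non-numeric columns, the values seen in training form the allowed domain
            entry["domain"] = set(series.dropna().unique())
        schema[col] = entry
    return schema


def validate_against_schema(df: pd.DataFrame, schema: dict) -> pd.DataFrame:
    """Return one row per (feature, anomaly) that violates the schema."""
    rows = []
    for col, entry in schema.items():
        if col not in df.columns:
            rows.append((col, "Column missing"))
            continue
        series = df[col]
        if pd.api.types.is_numeric_dtype(series) != entry["numeric"]:
            rows.append((col, "Unexpected type"))
        if entry["required"] and series.isna().any():
            rows.append((col, "Missing values"))
        if "domain" in entry and set(series.dropna().unique()) - entry["domain"]:
            rows.append((col, "Unexpected values"))
    # Columns that were not present when the schema was inferred
    rows.extend((col, "New column") for col in df.columns if col not in schema)
    return pd.DataFrame(rows, columns=["Feature name", "Anomaly"]).set_index("Feature name")


# Toy usage with made-up values
train_toy = pd.DataFrame({"temp": [3.0, 23.0, 25.0], "climate": ["dry", "tropical", "dry"]})
test_toy = pd.DataFrame({"temp": [15.0, None, 28.0], "climate": ["dry", "polar", "dry"]})
print(validate_against_schema(test_toy, infer_simple_schema(train_toy)))
```

One design difference worth noting: this sketch inspects raw rows, whereas TFDV validates summary statistics against the schema, which is why the cell below computes statistics for both dataframes first.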

In [7]:
@stub.function
def find_anomalies(train_df, test_df):
    # Import inside the function so TFDV only needs to be installed in the Modal image
    import tensorflow_data_validation as tfdv

    # Compute summary statistics on the training data and infer a schema from them
    train_stats = tfdv.generate_statistics_from_dataframe(train_df)
    schema = tfdv.infer_schema(statistics=train_stats)

    # Compute statistics on the test data and validate them against the training schema
    test_stats = tfdv.generate_statistics_from_dataframe(test_df)
    anomalies = tfdv.validate_statistics(statistics=test_stats, schema=schema)

    # Convert the Anomalies proto into a pandas dataframe indexed by feature name
    return tfdv.utils.display_util.get_anomalies_dataframe(anomalies)

In [8]:
with stub.run():
    X_train, _ = ds.split_feature_label(train_df)
    X_test, _ = ds.split_feature_label(test_df)
    anomalies = find_anomalies.call(X_train, X_test)

In [9]:
anomalies

| Feature name | Anomaly short description | Anomaly long description |
|---|---|---|


The anomalies table is empty, so schema validation found no anomalies in the uncorrupted test data.

## Iterate through corruptions

Next, we check whether TFDV detects anomalies for each of the corruptions from the previous notebook.
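
The corruptions themselves are implemented in `dataval` and are not shown here. As a rough illustration of the kind of change being tested, the sketch below nulls out a random fraction of rows in every column whose name starts with a given prefix. The function `corrupt_null_sketch`, the prefix matching, and the row-level sampling are assumptions for illustration, not the `dataval` implementation of `corrupt_null`.

```python
import numpy as np
import pandas as pd


def corrupt_null_sketch(df: pd.DataFrame, prefix: str, corruption_rate: float, seed: int = 0):
    """Hypothetical null corruption: blank out a random fraction of rows in prefixed columns."""
    rng = np.random.default_rng(seed)
    corrupted = df.copy()  # leave the input dataframe untouched
    columns = [c for c in df.columns if c.startswith(prefix)]
    n_rows = int(round(corruption_rate * len(df)))
    # Sample row positions once, so every affected column is nulled on the same rows
    rows = rng.choice(len(df), size=n_rows, replace=False)
    corrupted.loc[corrupted.index[rows], columns] = np.nan
    return corrupted, columns


# Toy usage: 100 rows, two "cmc" columns and one untouched categorical column
df = pd.DataFrame({"cmc_a": np.arange(100.0), "cmc_b": np.ones(100), "climate": ["dry"] * 100})
corrupted_df, cols = corrupt_null_sketch(df, "cmc", corruption_rate=0.05)
print(cols)                                      # ['cmc_a', 'cmc_b']
print(corrupted_df[cols].isna().sum().tolist())  # [5, 5]
```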

In [12]:
X_train, _ = ds.split_feature_label(train_df)
corruption_anomalies = {}

# Start the Modal app once and reuse it for every corruption
with stub.run():
    for corruption_name, corruption_res in ds.iterate_corruptions(test_df, "cmc", corruption_rate=0.05):
        # Each result holds the corrupted dataframe and the columns that were corrupted
        corrupted_test_df, corrupted_columns = corruption_res
        corrupted_X_test, _ = ds.split_feature_label(corrupted_test_df)
        corruption_anomalies[corruption_name] = find_anomalies.call(X_train, corrupted_X_test)

In [13]:
# Log each non-empty anomalies table to wandb and send an alert for it

run = wandb.init(project="ml-dataval-tutorial", tags=["TFDV-schema"])

for corruption_name, anomalies_df in corruption_anomalies.items():
    if len(anomalies_df) > 0:
        # reset_index() keeps the feature names, which live in the dataframe index
        table = wandb.Table(dataframe=anomalies_df.reset_index())
        wandb.log({corruption_name: table})

        wandb.alert(
            title=f"Errors detected in {corruption_name} experiment",
            text=f"Found {len(anomalies_df)} anomalies",
            level=AlertLevel.WARN,
        )

wandb.finish()
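
The decision of which experiments trigger an alert can also be factored into a pure function that is testable without a wandb run. This is a sketch: `build_alert_payloads` is a hypothetical helper, and listing the flagged feature names in the alert text is a suggestion rather than what the cell above does. The toy anomalies below use placeholder names and descriptions, not real TFDV output.

```python
import pandas as pd


def build_alert_payloads(corruption_anomalies: dict) -> list:
    """Return one alert payload per experiment whose anomalies dataframe is non-empty."""
    payloads = []
    for corruption_name, anomalies_df in corruption_anomalies.items():
        if anomalies_df.empty:
            continue  # nothing detected, so no alert
        # Feature names are stored in the index of the anomalies dataframe
        features = ", ".join(map(str, anomalies_df.index))
        payloads.append({
            "title": f"Errors detected in {corruption_name} experiment",
            "text": f"Found {len(anomalies_df)} anomalies in: {features}",
        })
    return payloads


# Toy usage with placeholder experiments and anomaly descriptions
results = {
    "corrupt_null": pd.DataFrame(
        {"Anomaly short description": ["example anomaly"]},
        index=pd.Index(["cmc_0_1_66_0_next"], name="Feature name"),
    ),
    "other_corruption": pd.DataFrame(columns=["Anomaly short description"]),
}
for payload in build_alert_payloads(results):
    print(payload["title"], "|", payload["text"])
```

Inside a run, each payload could then be sent with `wandb.alert(level=AlertLevel.WARN, **payload)`.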

## Takeaways

Only the `corrupt_null` corruption was flagged by schema validation; the other corruptions may need different validation techniques to be detected. On the other hand, every anomaly that schema validation reported corresponded to a real corruption (and the uncorrupted data produced none), so false-positive alerts are not a problem here.