In [1]:
%load_ext autoreload
%autoreload 2

In [54]:
from dataval.dataset import WeatherDataset
from dataval.train import Trainer

import os
import matplotlib.pyplot as plt
import modal
import pandas as pd

In [3]:
image = (
    modal.Image.debian_slim()
    .pip_install_from_requirements("requirements.txt")
    .pip_install(["tensorflow-data-validation", "tensorflow_metadata"])
)
stub = modal.Stub("tfdv-tutorial", image=image)

# Schema Validation

In this notebook, we use TensorFlow Data Validation (TFDV), the data validation library from TFX, to check whether schema validation detects any of the corruptions introduced in the previous notebook.

In [4]:
# Load dataset

ds = WeatherDataset(os.path.join(os.getcwd(), "canonical-partitioned-dataset"), sample_frac=0.2)

In [5]:
def train_and_test(train_df, test_df):
    X_train, y_train = ds.split_feature_label(train_df)

    catboost_hparams = {"depth": 5, "iterations": 250, "learning_rate": 0.03, "loss_function": "RMSE"}
    continual_t = Trainer(catboost_hparams)
    continual_t.fit(X_train, y_train, verbose=False)
    print(f"Train MSE for partition {ds.get_partition_key(train_df)}: {continual_t.score(X_train, y_train)}")

    # Evaluate
    X_test, y_test = ds.split_feature_label(test_df)
    print(f"Test MSE for partition {ds.get_partition_key(test_df)}: {continual_t.score(X_test, y_test)}")
    
    return continual_t

In [6]:
train_df = ds.load(ds.get_partition_keys()[0])
test_df = ds.load(ds.get_partition_keys()[1])

In [7]:
t = train_and_test(train_df, test_df)
t.get_feature_importances()

Train MSE for partition 2018_35: 4.074896379948685
Test MSE for partition 2018_36: 4.797844923928596


## Infer schema and check test data for errors

From the training dataframe, we infer a schema with TFDV. We then validate statistics computed on the test data against this schema to find anomalies. We start with the original, uncorrupted dataframes.

In [45]:
@stub.function
def find_anomalies(train_df, test_df):
    # Imported inside the function because these packages are installed in the Modal image.
    import tensorflow_data_validation as tfdv
    from google.protobuf.json_format import MessageToDict

    # Infer a schema from statistics computed on the training data.
    train_stats = tfdv.generate_statistics_from_dataframe(train_df)
    schema = tfdv.infer_schema(statistics=train_stats)

    # Validate the test data's statistics against that schema.
    test_stats = tfdv.generate_statistics_from_dataframe(test_df)
    # TODO: figure out how to run this
    # tfdv.visualize_statistics(
    #     lhs_statistics=test_stats, rhs_statistics=train_stats,
    #     lhs_name='TEST_DATASET', rhs_name='TRAIN_DATASET')
    anomalies = tfdv.validate_statistics(statistics=test_stats, schema=schema)

    # Return the anomalies as a dataframe indexed by feature name.
    anomalies_df = tfdv.utils.display_util.get_anomalies_dataframe(anomalies)
    # return MessageToDict(anomalies)
    return anomalies_df
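
To build intuition for what an inferred schema can and cannot catch, here is a toy stand-in written in plain Python. It is not TFDV's implementation or API: the functions `infer_toy_schema` and `find_toy_anomalies`, the feature names, and the values are all hypothetical. The sketch assumes a schema that records only each feature's type and whether it was always present in the training data.

```python
# Toy illustration only: NOT TFDV's implementation or API.
from typing import Any, Dict, List, Optional

Row = Dict[str, Optional[Any]]


def infer_toy_schema(rows: List[Row]) -> Dict[str, Dict[str, Any]]:
    """Record, per feature, the observed type and whether it was always present."""
    schema: Dict[str, Dict[str, Any]] = {}
    features = {name for row in rows for name in row}
    for name in sorted(features):
        values = [row.get(name) for row in rows]
        present = [v for v in values if v is not None]
        schema[name] = {
            "type": type(present[0]).__name__ if present else None,
            "required": len(present) == len(values),
        }
    return schema


def find_toy_anomalies(rows: List[Row], schema: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
    """Report features whose presence or type disagrees with the schema."""
    anomalies: Dict[str, str] = {}
    for name, spec in schema.items():
        values = [row.get(name) for row in rows]
        if spec["required"] and any(v is None for v in values):
            anomalies[name] = "missing values in a feature that was always present"
        elif any(v is not None and type(v).__name__ != spec["type"] for v in values):
            anomalies[name] = "unexpected type"
    return anomalies


# Hypothetical data: two training rows and three single-row test sets.
train_rows = [{"temp": 12.5, "humidity": 0.4}, {"temp": 14.0, "humidity": 0.5}]
clean_rows = [{"temp": 13.1, "humidity": 0.6}]
null_rows = [{"temp": None, "humidity": 0.6}]
scaled_rows = [{"temp": 13100.0, "humidity": 0.6}]  # wrong units, same type

toy_schema = infer_toy_schema(train_rows)
print(find_toy_anomalies(clean_rows, toy_schema))
print(find_toy_anomalies(null_rows, toy_schema))
print(find_toy_anomalies(scaled_rows, toy_schema))
```

In this toy version, the null value is reported while the rescaled value is not, because the check knows nothing about value ranges.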

In [46]:
with stub.run():
    X_train, _ = ds.split_feature_label(train_df)
    X_test, _ = ds.split_feature_label(test_df)
    anomalies = find_anomalies.call(X_train, X_test)

In [47]:
anomalies

Feature name    Anomaly short description    Anomaly long description
(no rows)


TFDV reports no anomalies for the uncorrupted data, which is the expected result.

## Iterate through corruptions

Next, we apply each corruption from the previous notebook to the test data and check whether TFDV flags any anomalies.

In [56]:
X_train, _ = ds.split_feature_label(train_df)
corruption_anomalies = {}

with stub.run():
    for corruption_name, corruption_res in ds.iterate_corruptions(test_df, "cmc", corruption_rate=0.05):
        corrupted_test_df, corrupted_columns = corruption_res
        corrupted_X_test, _ = ds.split_feature_label(corrupted_test_df)
        corruption_anomalies[corruption_name] = find_anomalies.call(X_train, corrupted_X_test)

In [59]:
for corruption_name, anomalies in corruption_anomalies.items():
    print(corruption_name)
    if len(anomalies) > 0:
        print(anomalies.head())

corrupt_null
                   Anomaly short description  \
Feature name                                   
'cmc_0_3_1_0'                Multiple errors   
'cmc_0_1_65_0'               Multiple errors   
'cmc_0_0_0_2_next'           Multiple errors   
'cmc_0_2_3_850'              Multiple errors   
'cmc_0_2_2_700'              Multiple errors   

                                             Anomaly long description  
Feature name                                                           
'cmc_0_3_1_0'       The feature has a shape, but it's not always p...  
'cmc_0_1_65_0'      The feature has a shape, but it's not always p...  
'cmc_0_0_0_2_next'  The feature has a shape, but it's not always p...  
'cmc_0_2_3_850'     The feature has a shape, but it's not always p...  
'cmc_0_2_2_700'     The feature has a shape, but it's not always p...  
corrupt_nonnegative
corrupt_typecheck
corrupt_units
corrupt_average
corrupt_pinned
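
The printout above can be condensed into two lists: corruptions that produced at least one anomaly and corruptions that produced none. The helper below is a minimal sketch; `split_by_detection` is a hypothetical name, and the stand-in data uses plain lists in place of the anomaly dataframes. It relies only on `len()`, so it should also accept the dataframes stored in `corruption_anomalies`.

```python
from typing import Dict, List, Sized, Tuple


def split_by_detection(results: Dict[str, Sized]) -> Tuple[List[str], List[str]]:
    """Return (flagged, missed) names; a result counts as flagged if it is non-empty."""
    flagged = [name for name, found in results.items() if len(found) > 0]
    missed = [name for name, found in results.items() if len(found) == 0]
    return flagged, missed


# Hypothetical stand-in results: lists instead of anomaly dataframes.
stand_in = {"corrupt_null": ["cmc_0_3_1_0"], "corrupt_units": [], "corrupt_pinned": []}
flagged, missed = split_by_detection(stand_in)
print("flagged:", flagged)
print("missed:", missed)
```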


## Takeaways

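Schema validation flagged only the `corrupt_null` corruption. The other five (`corrupt_nonnegative`, `corrupt_typecheck`, `corrupt_units`, `corrupt_average`, and `corrupt_pinned`) went undetected, though other validation techniques may catch them. On the positive side, schema validation raised no anomalies on the uncorrupted data, and every anomaly it did raise corresponded to a real corruption, so there is no false-positive problem here.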
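
As one hypothetical example of another validation technique, the sketch below compares the mean of a test feature against the training mean, measured in training standard deviations. It is not part of TFDV or of this notebook's pipeline. The function names, the threshold of 3, and the data are all assumptions chosen for illustration; the corrupted sample assumes a units-style error that rescales 5% of the values by 1000, which may differ from what `corrupt_units` actually does.

```python
import statistics
from typing import List


def mean_shift(train_values: List[float], test_values: List[float]) -> float:
    """Absolute difference in means, in units of the training standard deviation."""
    train_std = statistics.stdev(train_values)
    if train_std == 0:
        raise ValueError("training values are constant; the shift is undefined")
    return abs(statistics.mean(test_values) - statistics.mean(train_values)) / train_std


def is_shifted(train_values: List[float], test_values: List[float], threshold: float = 3.0) -> bool:
    """Flag the feature when the mean shift exceeds an (arbitrary) threshold."""
    return mean_shift(train_values, test_values) > threshold


# Hypothetical feature values: 20 training points and 20 test points.
train_values = [10.0, 11.0, 12.0, 13.0, 14.0] * 4
clean_test = [10.5, 11.5, 12.5, 13.5] * 5

# Assumed units-style corruption: 1 of 20 values (5%) is scaled by 1000.
corrupted_test = list(clean_test)
corrupted_test[0] *= 1000

print(is_shifted(train_values, clean_test))
print(is_shifted(train_values, corrupted_test))
```

A check like this trades precision for coverage: it can catch value-level corruptions that a type-and-presence schema misses, but a poorly chosen threshold may introduce the false positives that schema validation avoided here.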