In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# imported source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from dataval.dataset import WeatherDataset
from dataval.train import Trainer

import os
import matplotlib.pyplot as plt
import modal
import pandas as pd

In [12]:
# Container image for the remote functions, built up one layer at a time:
# Debian-slim base, then the project's requirements.txt, then TFDV on top.
image = modal.Image.debian_slim()
image = image.pip_install_from_requirements("requirements.txt")
image = image.pip_install(["tensorflow-data-validation"])

# Every function registered on this stub uses the image defined above.
stub = modal.Stub("tfdv-tutorial", image=image)

# Schema Validation

In this notebook, we apply schema validation from TensorFlow Data Validation (TFDV, part of TFX) to drop rows that might be corrupted. We reuse the corruptions from the previous notebook and measure how the validation improves MSE.

In [7]:
# Load the partitioned weather dataset that sits next to this notebook,
# passing sample_frac=0.2 to the loader.
dataset_dir = os.path.join(os.getcwd(), "canonical-partitioned-dataset")
ds = WeatherDataset(dataset_dir, sample_frac=0.2)

In [6]:
def train_and_test(train_df, test_df):
    """Fit a new ``Trainer`` on ``train_df`` and print its score on both frames.

    Uses the notebook-level ``ds`` to split each dataframe into features and
    labels and to look up the partition key shown in the printed messages.

    Args:
        train_df: Dataframe the model is fitted on.
        test_df: Dataframe the fitted model is scored on.

    Returns:
        The fitted ``Trainer`` instance.

    NOTE(review): the printed messages say "accuracy", but the model is
    configured with an RMSE loss -- confirm what ``Trainer.score`` returns.
    """
    features_train, labels_train = ds.split_feature_label(train_df)

    # Fixed CatBoost hyperparameters used for every fit in this notebook.
    hparams = dict(depth=5, iterations=250, learning_rate=0.03, loss_function="RMSE")
    model = Trainer(hparams)
    model.fit(features_train, labels_train, verbose=False)

    train_key = ds.get_partition_key(train_df)
    train_score = model.score(features_train, labels_train)
    print(f"Train accuracy for partition {train_key}: {train_score}")

    # Score the fitted model on the second dataframe.
    features_test, labels_test = ds.split_feature_label(test_df)
    test_key = ds.get_partition_key(test_df)
    test_score = model.score(features_test, labels_test)
    print(f"Test accuracy for partition {test_key}: {test_score}")

    return model

In [8]:
# The first partition key selects the training frame, the second the test frame.
train_key = ds.get_partition_keys()[0]
train_df = ds.load(train_key)
test_key = ds.get_partition_keys()[1]
test_df = ds.load(test_key)

## Infer schema and check test data for errors

From the training dataframe, we infer a schema using TFDV. We then use this schema to find anomalies in the test data. We apply this to the original dataframes first.

In [23]:
@stub.function
def gen_schema_from_dataframe(dataframe):
    """Infer a TFDV schema from ``dataframe`` and display it.

    Statistics are generated from the dataframe, a schema is inferred from
    those statistics, and the schema is rendered with ``tfdv.display_schema``.

    Args:
        dataframe: pandas DataFrame to infer the schema from.

    Returns:
        None. The schema is only displayed, not returned.
    """
    # Imported inside the function because the package is installed into the
    # Modal image; presumably it is not required in the local environment.
    import tensorflow_data_validation as tfdv

    stats = tfdv.generate_statistics_from_dataframe(dataframe)
    schema = tfdv.infer_schema(statistics=stats)
    # display_schema renders the schema itself and returns None, so wrapping
    # it in print() only appended a stray "None" line to the output.
    tfdv.display_schema(schema=schema)

In [24]:
# Open the stub's run context and invoke the schema function on the training frame.
# NOTE(review): `.call` is presumably Modal's remote invocation -- confirm
# against the installed Modal version.
with stub.run():
    gen_schema_from_dataframe.call(train_df)

In [27]:
@stub.function
def find_anomalies(train_df, test_df):
    """Validate ``test_df`` against a schema inferred from ``train_df``.

    A schema is inferred from statistics of ``train_df``; statistics of
    ``test_df`` are then validated against that schema and the resulting
    anomalies object is printed.

    Args:
        train_df: pandas DataFrame the schema is inferred from.
        test_df: pandas DataFrame whose statistics are validated.

    Returns:
        None. The anomalies are only printed.
    """
    import tensorflow_data_validation as tfdv

    reference_stats = tfdv.generate_statistics_from_dataframe(train_df)
    inferred_schema = tfdv.infer_schema(statistics=reference_stats)
    candidate_stats = tfdv.generate_statistics_from_dataframe(test_df)
    # For a side-by-side view of the two statistics objects,
    # tfdv.visualize_statistics(lhs_statistics=..., rhs_statistics=...) can be
    # called here; it is intentionally left out of this function.
    report = tfdv.validate_statistics(statistics=candidate_stats, schema=inferred_schema)
    print(report)

In [28]:
# Open the stub's run context and check the test frame against the schema
# inferred from the training frame.
# NOTE(review): `.call` is presumably Modal's remote invocation -- confirm
# against the installed Modal version.
with stub.run():
    find_anomalies.call(train_df, test_df)