# Data Preparation for TabPFN Notebooks

This notebook downloads and prepares all datasets used in the TabPFN demo notebooks. The datasets are stored as Delta tables in Unity Catalog, so the other notebooks can load them by table name.

**Datasets prepared:**
1. **Breast Cancer Wisconsin** - Binary classification dataset
2. **Iris** - Multi-class classification dataset
3. **California Housing** - Regression dataset
4. **Tourism Monthly (synthetic)** - Time series forecasting dataset, generated locally to resemble the Monash Tourism Monthly data

**Run this notebook once** to set up all the data before running the other notebooks.

## Compute Setup

We recommend running this notebook on **Serverless Compute** with the **Base Environment V4**.

To configure:
1. Click on the compute selector in the notebook toolbar
2. Select **Serverless**
3. Under Environment, choose **Base Environment V4**

Serverless compute provides fast startup times and automatic scaling, ideal for interactive notebook workflows.

## 1. Install Required Packages

In [None]:
%pip install scikit-learn pandas --quiet

In [None]:
dbutils.library.restartPython()

## 2. Configuration

Define the catalog and schema where datasets will be stored.

In [None]:
# Configure your catalog and schema
CATALOG = "tabpfn_databricks"
SCHEMA = "default"

# Create the catalog and schema if they don't exist
spark.sql(f"CREATE CATALOG IF NOT EXISTS {CATALOG}")
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.{SCHEMA}")
spark.sql(f"USE CATALOG {CATALOG}")
spark.sql(f"USE SCHEMA {SCHEMA}")

print(f"Using catalog: {CATALOG}")
print(f"Using schema: {SCHEMA}")
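The `USE CATALOG` and `USE SCHEMA` statements set the session's default namespace, which is why later cells can save and read tables by bare names such as `"breast_cancer"`. A session that has not run those statements (for example, another notebook) needs the full three-level name. The helper below is a hypothetical sketch for building that name; restricting identifiers to letters, digits, and underscores is a conservative assumption, not Unity Catalog's complete naming rule.

```python
import re

# Conservative identifier pattern (assumption): letters, digits, underscores only
_IDENTIFIER = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def fq_table_name(catalog: str, schema: str, table: str) -> str:
    """Hypothetical helper: return 'catalog.schema.table' after validating each part."""
    for part in (catalog, schema, table):
        if not _IDENTIFIER.match(part):
            raise ValueError(f"Unsupported identifier: {part!r}")
    return f"{catalog}.{schema}.{table}"


print(fq_table_name("tabpfn_databricks", "default", "iris"))
```

On Databricks you would then read a table independently of the session default, e.g. `spark.table(fq_table_name(CATALOG, SCHEMA, "iris"))`.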

## 3. Import Libraries

In [None]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer, load_iris, fetch_california_housing

## 4. Breast Cancer Dataset (Classification)

In [None]:
# Load Breast Cancer dataset
data = load_breast_cancer()

# Clean column names: replace spaces with underscores (spaces are invalid in Delta column names unless column mapping is enabled)
clean_columns = [col.replace(' ', '_') for col in data.feature_names]
df_breast_cancer = pd.DataFrame(data.data, columns=clean_columns)
df_breast_cancer["target"] = data.target

# Save to Delta table
spark.createDataFrame(df_breast_cancer).write.mode("overwrite").saveAsTable("breast_cancer")

print(f"Breast Cancer dataset saved to {CATALOG}.{SCHEMA}.breast_cancer")
print(f"Shape: {df_breast_cancer.shape}")
print(f"Classes: {data.target_names.tolist()}")
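The Breast Cancer and Iris cells each clean column names with their own chain of `str.replace` calls. A more general sketch, assuming the character set that Delta Lake rejects in column names when column mapping is not enabled (space, `,`, `;`, `{`, `}`, `(`, `)`, newline, tab, `=`), replaces any run of those characters with a single underscore and fails loudly if two names collide. `sanitize_column_names` is a hypothetical helper, not part of the original notebook.

```python
import re

# Characters Delta Lake rejects in column names unless column mapping is enabled
_INVALID_CHARS = re.compile(r"[ ,;{}()\n\t=]+")


def sanitize_column_names(names):
    """Hypothetical helper: replace runs of invalid characters with '_' and trim edge underscores."""
    cleaned = [_INVALID_CHARS.sub("_", name).strip("_") for name in names]
    # Two different source names could collapse to the same cleaned name
    if len(set(cleaned)) != len(cleaned):
        raise ValueError("Sanitizing produced duplicate column names")
    return cleaned


print(sanitize_column_names(["mean radius", "worst fractal dimension", "sepal length (cm)"]))
```

For these inputs it yields the same names as the notebook's own cleaning (`mean_radius`, `sepal_length_cm`), so it could replace both list comprehensions.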

## 5. Iris Dataset (Multi-class Classification)

In [None]:
# Load Iris dataset
iris = load_iris()

# Clean column names: replace spaces and drop parentheses (both are invalid in Delta column names unless column mapping is enabled)
clean_iris_cols = [col.replace(' ', '_').replace('(', '').replace(')', '') for col in iris.feature_names]
df_iris = pd.DataFrame(iris.data, columns=clean_iris_cols)
df_iris["target"] = iris.target

# Save to Delta table
spark.createDataFrame(df_iris).write.mode("overwrite").saveAsTable("iris")

print(f"Iris dataset saved to {CATALOG}.{SCHEMA}.iris")
print(f"Shape: {df_iris.shape}")
print(f"Classes: {iris.target_names.tolist()}")
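Both classification tables store only the integer-encoded `target`; the class names are printed but not saved. A minimal sketch for recovering readable labels after reading a table back, assuming the scikit-learn class order (index equals integer code): `["setosa", "versicolor", "virginica"]` for Iris, `["malignant", "benign"]` for Breast Cancer. The `target_codes` array is a stand-in for the stored column.

```python
import numpy as np
import pandas as pd

# Iris class names in scikit-learn order (position == integer code)
IRIS_CLASSES = ["setosa", "versicolor", "virginica"]

# Stand-in for the "target" column read back from the iris table
target_codes = np.array([0, 0, 1, 2, 2])

# Map integer codes to a categorical column of class names
labels = pd.Categorical.from_codes(target_codes, categories=IRIS_CLASSES)
print(list(labels))  # ['setosa', 'setosa', 'versicolor', 'virginica', 'virginica']
```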

## 6. California Housing Dataset (Regression)

In [None]:
# Load California Housing dataset
housing = fetch_california_housing()
df_housing = pd.DataFrame(housing.data, columns=housing.feature_names)
df_housing["target"] = housing.target

# Save to Delta table
spark.createDataFrame(df_housing).write.mode("overwrite").saveAsTable("california_housing")

print(f"California Housing dataset saved to {CATALOG}.{SCHEMA}.california_housing")
print(f"Shape: {df_housing.shape}")
print("Target: Median house value (in $100,000s)")
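Because the target is stored in units of $100,000, a value of `4.526` means $452,600. A minimal sketch of converting predictions or targets back to dollars; the series values below are illustrative stand-ins, not rows read from the table.

```python
import pandas as pd

# Illustrative target values in units of $100,000
targets = pd.Series([4.526, 3.585, 0.5])

# Convert to US dollars; rounding removes floating-point noise from the multiplication
target_usd = (targets * 100_000).round(2)
print(target_usd.tolist())
```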

## 7. Tourism Monthly Dataset (Synthetic Time Series)
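Written as a formula, the cell below generates each series $s$ at month index $t = 0, \dots, 119$ (with $t = 0$ corresponding to January 2010) as:

$$
y_{s,t} = \max\!\left(0,\; b_s + A_s \sin\!\left(\frac{2\pi (t - 3)}{12}\right) + \tau_s\, t + \varepsilon_{s,t}\right),
\qquad \varepsilon_{s,t} \sim \mathcal{N}\!\left(0,\, (0.1\, b_s)^2\right)
$$

where $b_s \sim \mathcal{U}(500, 5000)$ is the base level, $\tau_s \sim \mathcal{U}(-5, 15)$ the monthly trend, and $A_s = u_s\, b_s$ with $u_s \sim \mathcal{U}(0.2, 0.5)$ the seasonal amplitude. The sine term peaks where $(t - 3)/12 = 1/4 + k$, i.e. $t = 6 + 12k$, which is July of each year; this is why the code comment describes summer peaks.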

In [None]:
# Generate synthetic monthly time series data (similar to tourism data)
# instead of downloading the Monash Tourism Monthly dataset from Hugging Face,
# which avoids compatibility issues on Databricks

np.random.seed(42)
n_series = 50  # Number of time series
n_months = 120  # 10 years of monthly data

records = []
start_date = pd.Timestamp("2010-01-01")

for series_idx in range(n_series):
    # Generate realistic tourism-like patterns:
    # - Base level varies by series
    # - Seasonal pattern (yearly)
    # - Trend component
    # - Random noise
    
    base_level = np.random.uniform(500, 5000)
    trend = np.random.uniform(-5, 15)  # Monthly trend
    seasonal_amplitude = base_level * np.random.uniform(0.2, 0.5)
    noise_level = base_level * 0.1
    
    for month_idx in range(n_months):
        timestamp = start_date + pd.DateOffset(months=month_idx)
        
        # Seasonal component (peaks in summer months)
        seasonal = seasonal_amplitude * np.sin(2 * np.pi * (month_idx - 3) / 12)
        
        # Trend component
        trend_component = trend * month_idx
        
        # Random noise
        noise = np.random.normal(0, noise_level)
        
        # Combine components
        value = max(0, base_level + seasonal + trend_component + noise)
        
        records.append({
            "item_id": f"T{series_idx:06d}",
            "timestamp": timestamp,
            "target": float(value)
        })

df_tourism = pd.DataFrame(records)

# Save to Delta table
spark.createDataFrame(df_tourism).write.mode("overwrite").saveAsTable("tourism_monthly")

print(f"Tourism Monthly dataset saved to {CATALOG}.{SCHEMA}.tourism_monthly")
print(f"Number of time series: {n_series}")
print(f"Total records: {len(df_tourism)}")
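The nested loop above builds 6,000 dictionaries one at a time. The same generative model can be expressed with NumPy broadcasting, where per-series parameters have shape `(n_series, 1)` and the month index has shape `(n_months,)`. This is a sketch, not a drop-in replacement: `make_synthetic_series` is a hypothetical name, and it uses NumPy's `default_rng` Generator API, so the random draws (and therefore the values) differ from the `np.random.seed(42)` loop even with the same seed.

```python
import numpy as np
import pandas as pd


def make_synthetic_series(n_series=50, n_months=120, seed=42, start="2010-01-01"):
    """Hypothetical vectorized variant of the tourism-like generator (same model, different draws)."""
    rng = np.random.default_rng(seed)

    # Per-series parameters as column vectors so they broadcast across months
    base = rng.uniform(500, 5000, size=(n_series, 1))
    trend = rng.uniform(-5, 15, size=(n_series, 1))
    amplitude = base * rng.uniform(0.2, 0.5, size=(n_series, 1))

    t = np.arange(n_months)  # month index 0..n_months-1
    seasonal = amplitude * np.sin(2 * np.pi * (t - 3) / 12)
    noise = rng.normal(0.0, 1.0, size=(n_series, n_months)) * (0.1 * base)

    # Shape (n_series, n_months), clipped at zero like the loop version
    values = np.maximum(0.0, base + seasonal + trend * t + noise)

    # Row-major ravel: all months of series 0, then series 1, and so on
    timestamps = pd.date_range(start, periods=n_months, freq="MS")
    return pd.DataFrame({
        "item_id": np.repeat([f"T{i:06d}" for i in range(n_series)], n_months),
        "timestamp": np.tile(timestamps, n_series),
        "target": values.ravel(),
    })


df_vectorized = make_synthetic_series()
print(df_vectorized.shape)
```

The resulting frame has the same columns and long format (`item_id`, `timestamp`, `target`) as `df_tourism`, so it could be written with the same `saveAsTable` call.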

## 8. Verify All Tables

In [None]:
# List all tables in the schema
print(f"Tables in {CATALOG}.{SCHEMA}:")
display(spark.sql(f"SHOW TABLES IN {CATALOG}.{SCHEMA}"))

In [None]:
# Preview each table
print("Breast Cancer sample:")
display(spark.table("breast_cancer").limit(5))

print("\nIris sample:")
display(spark.table("iris").limit(5))

print("\nCalifornia Housing sample:")
display(spark.table("california_housing").limit(5))

print("\nTourism Monthly sample:")
display(spark.table("tourism_monthly").limit(10))
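Previewing rows confirms the tables exist, but not that they are complete. A minimal sketch of a shape check, assuming the sizes listed in the summary table (column counts include the `target` column, and the tourism table has 50 × 120 = 6,000 rows). `EXPECTED_SHAPES` and `find_shape_mismatches` are hypothetical names; the Spark call that would populate `actual` is shown only as a comment so the sketch runs outside Databricks.

```python
# Expected (rows, columns) per table; column counts include "target"
EXPECTED_SHAPES = {
    "breast_cancer": (569, 31),
    "iris": (150, 5),
    "california_housing": (20640, 9),
    "tourism_monthly": (6000, 3),
}


def find_shape_mismatches(actual, expected=EXPECTED_SHAPES):
    """Return {table: (actual_shape, expected_shape)} for missing or mis-sized tables."""
    mismatches = {}
    for table, want in expected.items():
        got = actual.get(table)  # None if the table was not measured
        if got != want:
            mismatches[table] = (got, want)
    return mismatches


# On Databricks, `actual` could be built as:
# actual = {name: (spark.table(name).count(), len(spark.table(name).columns))
#           for name in EXPECTED_SHAPES}
print(find_shape_mismatches(dict(EXPECTED_SHAPES)))  # {}
```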

## Summary

All datasets have been prepared and saved as Delta tables in `tabpfn_databricks.default` (or the catalog and schema you set in the Configuration section):

| Table | Description | Usage |
|-------|-------------|-------|
| `breast_cancer` | Binary classification (569 samples, 30 features) | 01_classification |
| `iris` | Multi-class classification (150 samples, 4 features) | 01_classification |
| `california_housing` | Regression (20,640 samples, 8 features) | 02_regression |
| `tourism_monthly` | Time series forecasting (50 synthetic series, 120 months each) | 04_time_series_forecasting |

**Next steps:** Run the individual notebooks (01-04) to explore TabPFN capabilities using these prepared datasets.
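Every classification and regression table stores its label in a `target` column next to the feature columns. A minimal sketch of splitting such a table into features and labels, assuming the downstream notebooks load tables into pandas (for example with `spark.table("iris").toPandas()`); `split_features_target` is a hypothetical helper, and a two-row stand-in frame replaces the real table so the sketch runs anywhere. It does not apply to `tourism_monthly`, whose long format (`item_id`, `timestamp`, `target`) is meant for forecasting.

```python
import pandas as pd


def split_features_target(df: pd.DataFrame, target_col: str = "target"):
    """Hypothetical helper: split a prepared table into feature matrix X and label vector y."""
    X = df.drop(columns=[target_col])
    y = df[target_col]
    return X, y


# Stand-in for spark.table("iris").toPandas()
demo = pd.DataFrame({
    "sepal_length_cm": [5.1, 7.0],
    "petal_width_cm": [0.2, 1.4],
    "target": [0, 1],
})
X, y = split_features_target(demo)
print(X.columns.tolist(), y.tolist())
```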