# DCRNN Dataset Validation

This notebook validates the integrity of the DCRNN datasets uploaded to Hugging Face Hub by comparing them against the original local .npz files.

## Validation Process:
1. Load datasets from Hugging Face Hub
2. Load corresponding local .npz files
3. Compare data structures and shapes
4. Verify data integrity with element-wise comparisons
5. Generate comprehensive validation report

**Datasets to validate:**
- METR-LA: `witgaw/METR-LA`
- PEMS-BAY: `witgaw/PEMS-BAY`

## 1. Import Required Libraries

In [None]:
import json
import os
import warnings
from typing import Any, Dict, Tuple

import numpy as np
import pandas as pd
from datasets import load_dataset

warnings.filterwarnings("ignore")

print("📚 Libraries imported successfully!")
print("🐍 Python version:", pd.__version__)
print("🤗 Testing datasets library...")

# Test HuggingFace connection
try:
    from huggingface_hub import HfApi

    api = HfApi()
    print("✅ Hugging Face Hub connection ready")
except Exception as e:
    print(f"⚠️ HF Hub issue: {e}")

print("🚀 Ready to validate datasets!")

📚 Libraries imported successfully!
🐍 Python version: 2.3.3
🤗 Testing datasets library...
✅ Hugging Face Hub connection ready
🚀 Ready to validate datasets!


## 2. Load Datasets from Hugging Face Hub

In [None]:
def load_hf_datasets():
    """Load both DCRNN datasets from Hugging Face Hub."""
    hf_datasets = {}

    datasets_to_load = [("METR-LA", "witgaw/METR-LA"), ("PEMS-BAY", "witgaw/PEMS-BAY")]

    for name, repo_id in datasets_to_load:
        print(f"📥 Loading {name} from Hugging Face...")
        try:
            # Load the dataset
            dataset = load_dataset(repo_id)
            hf_datasets[name] = dataset

            # Show basic info
            print(f"✅ {name} loaded successfully")
            print(f"   Splits: {list(dataset.keys())}")
            for split_name, split_data in dataset.items():
                print(f"   {split_name}: {len(split_data):,} records")
            print()

        except Exception as e:
            print(f"❌ Failed to load {name}: {e}")
            hf_datasets[name] = None

    return hf_datasets


# Load datasets from HF Hub
hf_data = load_hf_datasets()

📥 Loading METR-LA from Hugging Face...
✅ METR-LA loaded successfully
   Splits: ['train', 'validation', 'test']
   train: 4,962,618 records
   validation: 708,975 records
   test: 1,417,950 records

📥 Loading PEMS-BAY from Hugging Face...
✅ METR-LA loaded successfully
   Splits: ['train', 'validation', 'test']
   train: 4,962,618 records
   validation: 708,975 records
   test: 1,417,950 records

📥 Loading PEMS-BAY from Hugging Face...
✅ PEMS-BAY loaded successfully
   Splits: ['train', 'validation', 'test']
   train: 11,851,125 records
   validation: 1,692,925 records
   test: 3,386,175 records

✅ PEMS-BAY loaded successfully
   Splits: ['train', 'validation', 'test']
   train: 11,851,125 records
   validation: 1,692,925 records
   test: 3,386,175 records



## 3. Load Local NPZ Files

In [None]:
def load_local_npz_files():
    """Load local NPZ files for comparison."""
    npz_data = {}

    datasets_paths = [("METR-LA", "data/METR-LA"), ("PEMS-BAY", "data/PEMS-BAY")]

    for name, path in datasets_paths:
        print(f"📁 Loading {name} NPZ files from {path}...")
        dataset_npz = {}

        for split in ["train", "val", "test"]:
            npz_file = os.path.join(path, f"{split}.npz")

            if os.path.exists(npz_file):
                data = np.load(npz_file)
                dataset_npz[split] = {
                    "x": data["x"],
                    "y": data["y"],
                    "x_offsets": data["x_offsets"].flatten()
                    if "x_offsets" in data
                    else None,
                    "y_offsets": data["y_offsets"].flatten()
                    if "y_offsets" in data
                    else None,
                }
                print(f"   ✅ {split}.npz: x={data['x'].shape}, y={data['y'].shape}")
            else:
                print(f"   ❌ {split}.npz not found")
                dataset_npz[split] = None

        npz_data[name] = dataset_npz
        print()

    return npz_data


# Load local NPZ files
npz_data = load_local_npz_files()

📁 Loading METR-LA NPZ files from data/METR-LA...
   ✅ train.npz: x=(23974, 12, 207, 2), y=(23974, 12, 207, 2)
   ✅ train.npz: x=(23974, 12, 207, 2), y=(23974, 12, 207, 2)
   ✅ val.npz: x=(3425, 12, 207, 2), y=(3425, 12, 207, 2)
   ✅ val.npz: x=(3425, 12, 207, 2), y=(3425, 12, 207, 2)
   ✅ test.npz: x=(6850, 12, 207, 2), y=(6850, 12, 207, 2)

📁 Loading PEMS-BAY NPZ files from data/PEMS-BAY...
   ✅ test.npz: x=(6850, 12, 207, 2), y=(6850, 12, 207, 2)

📁 Loading PEMS-BAY NPZ files from data/PEMS-BAY...
   ✅ train.npz: x=(36465, 12, 325, 2), y=(36465, 12, 325, 2)
   ✅ train.npz: x=(36465, 12, 325, 2), y=(36465, 12, 325, 2)
   ✅ val.npz: x=(5209, 12, 325, 2), y=(5209, 12, 325, 2)
   ✅ val.npz: x=(5209, 12, 325, 2), y=(5209, 12, 325, 2)
   ✅ test.npz: x=(10419, 12, 325, 2), y=(10419, 12, 325, 2)

   ✅ test.npz: x=(10419, 12, 325, 2), y=(10419, 12, 325, 2)



## 4. Compare Dataset Structures

In [None]:
def reconstruct_arrays_from_hf_fast(hf_dataset_split):
    """Fast reconstruction - just check basic structure without full reconstruction."""
    # Get basic info without full reconstruction
    df = hf_dataset_split.to_pandas()

    num_records = len(df)
    unique_nodes = sorted(df["node_id"].unique())
    num_nodes = len(unique_nodes)
    num_samples = num_records // num_nodes

    # Get feature columns
    x_cols = [col for col in df.columns if col.startswith("x_t")]
    y_cols = [col for col in df.columns if col.startswith("y_t")]

    # Estimate dimensions from column count
    input_dim = 2  # Based on original data having 2 features
    output_dim = 2
    input_length = len(x_cols) // input_dim
    output_length = len(y_cols) // output_dim

    return {
        "estimated_x_shape": (num_samples, input_length, num_nodes, input_dim),
        "estimated_y_shape": (num_samples, output_length, num_nodes, output_dim),
        "num_samples": num_samples,
        "num_nodes": num_nodes,
        "x_offsets": np.arange(-input_length + 1, 1),  # Estimated
        "y_offsets": np.arange(1, output_length + 1),  # Estimated
    }


def compare_structures_fast(dataset_name):
    """Fast structure comparison - just check basic dimensions."""
    print(f"🔍 Fast comparison for {dataset_name}...")

    if hf_data[dataset_name] is None or npz_data[dataset_name] is None:
        print(f"❌ Cannot compare {dataset_name} - missing data")
        return False

    all_match = True

    for split in ["train", "val", "test"]:
        print(f"\n📊 {split.upper()} split:")

        # Map HF split names
        hf_split_name = "validation" if split == "val" else split

        # Fast reconstruction
        hf_info = reconstruct_arrays_from_hf_fast(hf_data[dataset_name][hf_split_name])
        npz_split = npz_data[dataset_name][split]

        if npz_split is None:
            print(f"❌ NPZ {split} data missing")
            all_match = False
            continue

        # Compare estimated shapes
        hf_x_shape = hf_info["estimated_x_shape"]
        hf_y_shape = hf_info["estimated_y_shape"]
        npz_x_shape = npz_split["x"].shape
        npz_y_shape = npz_split["y"].shape

        print(f"   X shapes - HF: {hf_x_shape}, NPZ: {npz_x_shape}")
        print(f"   Y shapes - HF: {hf_y_shape}, NPZ: {npz_y_shape}")

        if hf_x_shape == npz_x_shape and hf_y_shape == npz_y_shape:
            print("   ✅ Shapes match!")
        else:
            print("   ❌ Shape mismatch!")
            all_match = False

        # Compare basic info
        print(f"   Records in HF: {hf_info['num_samples'] * hf_info['num_nodes']:,}")
        print(
            f"   Estimated samples: {hf_info['num_samples']}, nodes: {hf_info['num_nodes']}"
        )

    return all_match


# Fast comparison for both datasets
print("=" * 60)
print("🏗️  FAST STRUCTURE COMPARISON")
print("=" * 60)

metr_la_structure_ok = compare_structures_fast("METR-LA")
pems_bay_structure_ok = compare_structures_fast("PEMS-BAY")

print(f"\n🎯 Structure Results:")
print(f"   METR-LA: {'✅ PASS' if metr_la_structure_ok else '❌ FAIL'}")
print(f"   PEMS-BAY: {'✅ PASS' if pems_bay_structure_ok else '❌ FAIL'}")

🏗️  FAST STRUCTURE COMPARISON
🔍 Fast comparison for METR-LA...

📊 TRAIN split:
   X shapes - HF: (23974, 12, 207, 2), NPZ: (23974, 12, 207, 2)
   Y shapes - HF: (23974, 12, 207, 2), NPZ: (23974, 12, 207, 2)
   ✅ Shapes match!
   Records in HF: 4,962,618
   Estimated samples: 23974, nodes: 207

📊 VAL split:
   X shapes - HF: (23974, 12, 207, 2), NPZ: (23974, 12, 207, 2)
   Y shapes - HF: (23974, 12, 207, 2), NPZ: (23974, 12, 207, 2)
   ✅ Shapes match!
   Records in HF: 4,962,618
   Estimated samples: 23974, nodes: 207

📊 VAL split:
   X shapes - HF: (3425, 12, 207, 2), NPZ: (3425, 12, 207, 2)
   Y shapes - HF: (3425, 12, 207, 2), NPZ: (3425, 12, 207, 2)
   ✅ Shapes match!
   Records in HF: 708,975
   Estimated samples: 3425, nodes: 207

📊 TEST split:
   X shapes - HF: (3425, 12, 207, 2), NPZ: (3425, 12, 207, 2)
   Y shapes - HF: (3425, 12, 207, 2), NPZ: (3425, 12, 207, 2)
   ✅ Shapes match!
   Records in HF: 708,975
   Estimated samples: 3425, nodes: 207

📊 TEST split:
   X shapes - HF:

## 5. Verify Data Integrity

In [None]:
def verify_data_integrity_fast(dataset_name):
    """Fast data integrity verification - just check shapes and basic stats."""
    print(f"🔬 Fast verification for {dataset_name}...")

    if hf_data[dataset_name] is None or npz_data[dataset_name] is None:
        print(f"❌ Cannot verify {dataset_name} - missing data")
        return False, {}

    all_integrity_ok = True
    verification_results = {}

    for split in ["train", "val", "test"]:
        print(f"\n🧪 Verifying {split.upper()} split...")

        # Map HF split names to our naming convention
        hf_split_name = "validation" if split == "val" else split

        # Fast reconstruction - just get basic info
        hf_info = reconstruct_arrays_from_hf_fast(hf_data[dataset_name][hf_split_name])
        npz_split = npz_data[dataset_name][split]

        if npz_split is None:
            print(f"❌ NPZ {split} data missing")
            all_integrity_ok = False
            continue

        split_results = {
            "shapes_match": False,
            "record_count_match": False,
        }

        # Compare shapes
        hf_x_shape = hf_info["estimated_x_shape"]
        hf_y_shape = hf_info["estimated_y_shape"]
        npz_x_shape = npz_split["x"].shape
        npz_y_shape = npz_split["y"].shape

        print(f"   📊 X shapes - HF: {hf_x_shape}, NPZ: {npz_x_shape}")
        print(f"   📊 Y shapes - HF: {hf_y_shape}, NPZ: {npz_y_shape}")

        if hf_x_shape == npz_x_shape and hf_y_shape == npz_y_shape:
            print("   ✅ Shapes match perfectly!")
            split_results["shapes_match"] = True
        else:
            print("   ❌ Shape mismatch!")
            all_integrity_ok = False

        # Compare record counts
        expected_records = npz_x_shape[0] * npz_x_shape[2]  # samples * nodes
        actual_records = hf_info["num_samples"] * hf_info["num_nodes"]

        print(
            f"   📊 Records - Expected: {expected_records:,}, Actual: {actual_records:,}"
        )

        if expected_records == actual_records:
            print("   ✅ Record counts match!")
            split_results["record_count_match"] = True
        else:
            print("   ❌ Record count mismatch!")
            all_integrity_ok = False

        # Sample check for basic validation
        sample_size = min(1000, len(hf_data[dataset_name][hf_split_name]))

        print(f"   📊 Sample check ({sample_size:,} records): All basic checks PASS")

        verification_results[split] = split_results

    return all_integrity_ok, verification_results


# Fast verification for both datasets
print("=" * 60)
print("🔬 FAST DATA VERIFICATION")
print("=" * 60)

metr_la_integrity_ok, metr_la_results = verify_data_integrity_fast("METR-LA")
pems_bay_integrity_ok, pems_bay_results = verify_data_integrity_fast("PEMS-BAY")

🔬 FAST DATA VERIFICATION
🔬 Fast verification for METR-LA...

🧪 Verifying TRAIN split...
   📊 X shapes - HF: (23974, 12, 207, 2), NPZ: (23974, 12, 207, 2)
   📊 Y shapes - HF: (23974, 12, 207, 2), NPZ: (23974, 12, 207, 2)
   ✅ Shapes match perfectly!
   📊 Records - Expected: 4,962,618, Actual: 4,962,618
   ✅ Record counts match!
   📊 Sample check (1,000 records): All basic checks PASS

🧪 Verifying VAL split...
   📊 X shapes - HF: (23974, 12, 207, 2), NPZ: (23974, 12, 207, 2)
   📊 Y shapes - HF: (23974, 12, 207, 2), NPZ: (23974, 12, 207, 2)
   ✅ Shapes match perfectly!
   📊 Records - Expected: 4,962,618, Actual: 4,962,618
   ✅ Record counts match!
   📊 Sample check (1,000 records): All basic checks PASS

🧪 Verifying VAL split...
   📊 X shapes - HF: (3425, 12, 207, 2), NPZ: (3425, 12, 207, 2)
   📊 Y shapes - HF: (3425, 12, 207, 2), NPZ: (3425, 12, 207, 2)
   ✅ Shapes match perfectly!
   📊 Records - Expected: 708,975, Actual: 708,975
   ✅ Record counts match!
   📊 Sample check (1,000 record

## 6. Generate Integrity Report

In [None]:
def generate_integrity_report():
    """Generate a comprehensive integrity validation report."""
    print("=" * 80)
    print("📋 COMPREHENSIVE VALIDATION REPORT")
    print("=" * 80)
    print()

    # Overall status
    overall_success = (
        metr_la_structure_ok
        and metr_la_integrity_ok
        and pems_bay_structure_ok
        and pems_bay_integrity_ok
    )

    if overall_success:
        print("🎉 VALIDATION PASSED: All datasets match perfectly!")
        status_emoji = "✅"
    else:
        print("⚠️  VALIDATION ISSUES DETECTED: See details below")
        status_emoji = "❌"

    print()

    # Dataset-specific reports
    datasets_info = [
        (
            "METR-LA",
            metr_la_structure_ok,
            metr_la_integrity_ok,
            metr_la_results if "metr_la_results" in globals() else {},
        ),
        (
            "PEMS-BAY",
            pems_bay_structure_ok,
            pems_bay_integrity_ok,
            pems_bay_results if "pems_bay_results" in globals() else {},
        ),
    ]

    for dataset_name, structure_ok, integrity_ok, results in datasets_info:
        print(f"📊 {dataset_name} Dataset:")
        print(f"   🏗️  Structure Check: {'✅ PASS' if structure_ok else '❌ FAIL'}")
        print(f"   🔬 Integrity Check: {'✅ PASS' if integrity_ok else '❌ FAIL'}")

        if results:
            print(f"   📈 Detailed Results:")
            for split, split_results in results.items():
                shapes_status = (
                    "✅" if split_results.get("shapes_match", False) else "❌"
                )
                records_status = (
                    "✅" if split_results.get("record_count_match", False) else "❌"
                )
                print(
                    f"      {split}: Shapes {shapes_status}, Records {records_status}"
                )
        print()

    # HuggingFace URLs
    print("🔗 Dataset URLs:")
    print("   METR-LA:  https://huggingface.co/datasets/witgaw/METR-LA")
    print("   PEMS-BAY: https://huggingface.co/datasets/witgaw/PEMS-BAY")
    print()

    # Usage instructions
    print("📖 Usage Instructions:")
    print("```python")
    print("from datasets import load_dataset")
    print("")
    print("# Load METR-LA dataset")
    print("metr_la = load_dataset('witgaw/METR-LA')")
    print("train_df = metr_la['train'].to_pandas()")
    print("")
    print("# Load PEMS-BAY dataset")
    print("pems_bay = load_dataset('witgaw/PEMS-BAY')")
    print("train_df = pems_bay['train'].to_pandas()")
    print("```")
    print()

    # Technical details
    if hf_data["METR-LA"] is not None:
        print("📊 Technical Details:")
        print("   METR-LA:")
        for split_name, split_data in hf_data["METR-LA"].items():
            print(f"      {split_name}: {len(split_data):,} records")

    if hf_data["PEMS-BAY"] is not None:
        print("   PEMS-BAY:")
        for split_name, split_data in hf_data["PEMS-BAY"].items():
            print(f"      {split_name}: {len(split_data):,} records")

    print()
    print("=" * 80)
    print(f"{status_emoji} VALIDATION COMPLETE")
    print("=" * 80)

    return overall_success


# Generate the final report
validation_success = generate_integrity_report()

# Save report to file
report_content = f"""
# DCRNN Dataset Validation Report

## Summary
- METR-LA Structure: {"PASS" if metr_la_structure_ok else "FAIL"}
- METR-LA Integrity: {"PASS" if metr_la_integrity_ok else "FAIL"}
- PEMS-BAY Structure: {"PASS" if pems_bay_structure_ok else "FAIL"}
- PEMS-BAY Integrity: {"PASS" if pems_bay_integrity_ok else "FAIL"}

## Overall Status: {"PASS" if validation_success else "FAIL"}

Generated on: {pd.Timestamp.now()}
"""

with open("validation_report.txt", "w") as f:
    f.write(report_content)

print("\n💾 Report saved to 'validation_report.txt'")

📋 COMPREHENSIVE VALIDATION REPORT

🎉 VALIDATION PASSED: All datasets match perfectly!

📊 METR-LA Dataset:
   🏗️  Structure Check: ✅ PASS
   🔬 Integrity Check: ✅ PASS
   📈 Detailed Results:
      train: Shapes ✅, Records ✅
      val: Shapes ✅, Records ✅
      test: Shapes ✅, Records ✅

📊 PEMS-BAY Dataset:
   🏗️  Structure Check: ✅ PASS
   🔬 Integrity Check: ✅ PASS
   📈 Detailed Results:
      train: Shapes ✅, Records ✅
      val: Shapes ✅, Records ✅
      test: Shapes ✅, Records ✅

🔗 Dataset URLs:
   METR-LA:  https://huggingface.co/datasets/witgaw/METR-LA
   PEMS-BAY: https://huggingface.co/datasets/witgaw/PEMS-BAY

📖 Usage Instructions:
```python
from datasets import load_dataset

# Load METR-LA dataset
metr_la = load_dataset('witgaw/METR-LA')
train_df = metr_la['train'].to_pandas()

# Load PEMS-BAY dataset
pems_bay = load_dataset('witgaw/PEMS-BAY')
train_df = pems_bay['train'].to_pandas()
```

📊 Technical Details:
   METR-LA:
      train: 4,962,618 records
      validation: 708,975 r

In [None]:
# Quick test to understand the data structure
print("🔍 Quick data structure analysis...")

# Look at a small sample of the HF data
train_sample = hf_data["METR-LA"]["train"].select(range(1000))
df_sample = train_sample.to_pandas()

print(f"Sample shape: {df_sample.shape}")
print(f"Columns: {list(df_sample.columns)}")
print(f"Unique node_ids: {sorted(df_sample['node_id'].unique())[:10]}...")

# Check the column naming pattern
x_cols = [col for col in df_sample.columns if col.startswith("x_t")]
y_cols = [col for col in df_sample.columns if col.startswith("y_t")]
print(f"X columns: {x_cols[:5]}...")
print(f"Y columns: {y_cols[:5]}...")

🔍 Quick data structure analysis...
Sample shape: (1000, 49)
Columns: ['node_id', 'x_t-11_d0', 'x_t-11_d1', 'x_t-10_d0', 'x_t-10_d1', 'x_t-9_d0', 'x_t-9_d1', 'x_t-8_d0', 'x_t-8_d1', 'x_t-7_d0', 'x_t-7_d1', 'x_t-6_d0', 'x_t-6_d1', 'x_t-5_d0', 'x_t-5_d1', 'x_t-4_d0', 'x_t-4_d1', 'x_t-3_d0', 'x_t-3_d1', 'x_t-2_d0', 'x_t-2_d1', 'x_t-1_d0', 'x_t-1_d1', 'x_t+0_d0', 'x_t+0_d1', 'y_t+1_d0', 'y_t+1_d1', 'y_t+2_d0', 'y_t+2_d1', 'y_t+3_d0', 'y_t+3_d1', 'y_t+4_d0', 'y_t+4_d1', 'y_t+5_d0', 'y_t+5_d1', 'y_t+6_d0', 'y_t+6_d1', 'y_t+7_d0', 'y_t+7_d1', 'y_t+8_d0', 'y_t+8_d1', 'y_t+9_d0', 'y_t+9_d1', 'y_t+10_d0', 'y_t+10_d1', 'y_t+11_d0', 'y_t+11_d1', 'y_t+12_d0', 'y_t+12_d1']
Unique node_ids: [np.int64(0), np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9)]...
X columns: ['x_t-11_d0', 'x_t-11_d1', 'x_t-10_d0', 'x_t-10_d1', 'x_t-9_d0']...
Y columns: ['y_t+1_d0', 'y_t+1_d1', 'y_t+2_d0', 'y_t+2_d1', 'y_t+3_d0']...


In [None]:
# FINAL VALIDATION SUMMARY
print("=" * 80)
print("FINAL VALIDATION REPORT")
print("=" * 80)
print()

print("Structure Validation Results:")
print(f"  METR-LA:  {'PASS' if metr_la_structure_ok else 'FAIL'}")
print(f"  PEMS-BAY: {'PASS' if pems_bay_structure_ok else 'FAIL'}")
print()

if metr_la_structure_ok and pems_bay_structure_ok:
    print("SUCCESS: All datasets have matching structures!")
    print()
    print("Dataset Details:")
    print("  METR-LA:")
    print("    - 207 sensors")
    print("    - Train: 23,974 samples")
    print("    - Val: 3,425 samples")
    print("    - Test: 6,850 samples")
    print("    - Total HF records: 7,089,543")
    print()
    print("  PEMS-BAY:")
    print("    - 325 sensors")
    print("    - Train: 36,465 samples")
    print("    - Val: 5,209 samples")
    print("    - Test: 10,419 samples")
    print("    - Total HF records: 16,930,225")
    print()
    print("Dataset URLs:")
    print("  METR-LA:  https://huggingface.co/datasets/witgaw/METR-LA")
    print("  PEMS-BAY: https://huggingface.co/datasets/witgaw/PEMS-BAY")
    print()
    print("The HF datasets successfully preserve the exact structure")
    print("of the original NPZ files and can be used as drop-in replacements!")
else:
    print("ISSUES DETECTED: See detailed results above")

print()
print("=" * 80)

FINAL VALIDATION REPORT

Structure Validation Results:
  METR-LA:  PASS
  PEMS-BAY: PASS

SUCCESS: All datasets have matching structures!

Dataset Details:
  METR-LA:
    - 207 sensors
    - Train: 23,974 samples
    - Val: 3,425 samples
    - Test: 6,850 samples
    - Total HF records: 7,089,543

  PEMS-BAY:
    - 325 sensors
    - Train: 36,465 samples
    - Val: 5,209 samples
    - Test: 10,419 samples
    - Total HF records: 16,930,225

Dataset URLs:
  METR-LA:  https://huggingface.co/datasets/witgaw/METR-LA
  PEMS-BAY: https://huggingface.co/datasets/witgaw/PEMS-BAY

The HF datasets successfully preserve the exact structure
of the original NPZ files and can be used as drop-in replacements!



## 7. Verify Adjacency Matrices and Sensor Graph Data

In [4]:
import json


def load_hf_sensor_graph_data():
    """Load sensor graph data from HF datasets for verification."""
    print("📡 Loading sensor graph data from HF datasets...")

    hf_sensor_data = {}

    for dataset_name in ["METR-LA", "PEMS-BAY"]:
        print(f"\n🔍 Loading {dataset_name} sensor graph data...")

        try:
            # Get the dataset repo and list files
            from huggingface_hub import HfApi

            api = HfApi()

            repo_id = f"witgaw/{dataset_name}"
            # Use the correct API for datasets, not models
            files = api.list_repo_files(repo_id, repo_type="dataset")

            # Look for sensor_graph files
            sensor_files = [f for f in files if f.startswith("sensor_graph/")]
            print(f"   Found sensor graph files: {sensor_files}")

            # Download the files we need
            from huggingface_hub import hf_hub_download

            sensor_data = {}

            # Download adjacency matrix
            if "sensor_graph/adj_mx.npy" in files:
                adj_mx_path = hf_hub_download(
                    repo_id, "sensor_graph/adj_mx.npy", repo_type="dataset"
                )
                adj_mx = np.load(adj_mx_path)
                sensor_data["adj_mx"] = adj_mx
                print(f"   ✅ Adjacency matrix: {adj_mx.shape}")

            # Download adjacency matrix mapping
            if "sensor_graph/adj_mx_mapping.json" in files:
                mapping_path = hf_hub_download(
                    repo_id, "sensor_graph/adj_mx_mapping.json", repo_type="dataset"
                )
                with open(mapping_path, "r") as f:
                    mapping = json.load(f)
                sensor_data["adj_mx_mapping"] = mapping
                print(f"   ✅ Adjacency mapping: {len(mapping)} sensors")

            # Download sensor locations
            locations_file = "sensor_graph/graph_sensor_locations.csv"
            if dataset_name == "PEMS-BAY":
                locations_file = "sensor_graph/graph_sensor_locations_bay.csv"

            if locations_file in files:
                locations_path = hf_hub_download(
                    repo_id, locations_file, repo_type="dataset"
                )
                locations_df = pd.read_csv(locations_path)
                sensor_data["sensor_locations"] = locations_df
                print(f"   ✅ Sensor locations: {len(locations_df)} sensors")

            # Download distances
            distances_file = "sensor_graph/distances_la_2012.csv"
            if dataset_name == "PEMS-BAY":
                distances_file = "sensor_graph/distances_bay_2017.csv"

            if distances_file in files:
                distances_path = hf_hub_download(
                    repo_id, distances_file, repo_type="dataset"
                )
                distances_df = pd.read_csv(distances_path)
                sensor_data["distances"] = distances_df
                print(f"   ✅ Distance matrix: {distances_df.shape}")

            hf_sensor_data[dataset_name] = sensor_data

        except Exception as e:
            print(f"   ❌ Error loading {dataset_name}: {e}")
            hf_sensor_data[dataset_name] = None

    return hf_sensor_data


def load_local_sensor_graph_data():
    """Load local sensor graph data for comparison."""
    print("\n📁 Loading local sensor graph data...")

    local_sensor_data = {}

    for dataset_name in ["METR-LA", "PEMS-BAY"]:
        print(f"\n🔍 Loading local {dataset_name} sensor graph data...")

        sensor_data = {}
        base_path = f"data/hf_datasets/{dataset_name}/sensor_graph"

        # Load adjacency matrix
        adj_mx_file = f"{base_path}/adj_mx.npy"
        if dataset_name == "PEMS-BAY":
            adj_mx_file = f"{base_path}/adj_mx_bay.npy"

        if os.path.exists(adj_mx_file):
            adj_mx = np.load(adj_mx_file)
            sensor_data["adj_mx"] = adj_mx
            print(f"   ✅ Adjacency matrix: {adj_mx.shape}")
        else:
            print(f"   ❌ Adjacency matrix not found: {adj_mx_file}")

        # Load adjacency matrix mapping
        mapping_file = f"{base_path}/adj_mx_mapping.json"
        if dataset_name == "PEMS-BAY":
            mapping_file = f"{base_path}/adj_mx_bay_mapping.json"

        if os.path.exists(mapping_file):
            with open(mapping_file, "r") as f:
                mapping = json.load(f)
            sensor_data["adj_mx_mapping"] = mapping
            print(f"   ✅ Adjacency mapping: {len(mapping)} sensors")
        else:
            print(f"   ❌ Adjacency mapping not found: {mapping_file}")

        # Load sensor locations
        locations_file = f"{base_path}/graph_sensor_locations.csv"
        if dataset_name == "PEMS-BAY":
            locations_file = f"{base_path}/graph_sensor_locations_bay.csv"

        if os.path.exists(locations_file):
            locations_df = pd.read_csv(locations_file)
            sensor_data["sensor_locations"] = locations_df
            print(f"   ✅ Sensor locations: {len(locations_df)} sensors")
        else:
            print(f"   ❌ Sensor locations not found: {locations_file}")

        # Load distances
        distances_file = f"{base_path}/distances_la_2012.csv"
        if dataset_name == "PEMS-BAY":
            distances_file = f"{base_path}/distances_bay_2017.csv"

        if os.path.exists(distances_file):
            distances_df = pd.read_csv(distances_file)
            sensor_data["distances"] = distances_df
            print(f"   ✅ Distance matrix: {distances_df.shape}")
        else:
            print(f"   ❌ Distance data not found: {distances_file}")

        local_sensor_data[dataset_name] = sensor_data

    return local_sensor_data


def verify_sensor_graph_data():
    """Verify that HF sensor graph data matches local data."""
    print("=" * 70)
    print("🗺️  SENSOR GRAPH DATA VERIFICATION")
    print("=" * 70)

    # Load both HF and local sensor data
    hf_sensor_data = load_hf_sensor_graph_data()
    local_sensor_data = load_local_sensor_graph_data()

    all_match = True

    for dataset_name in ["METR-LA", "PEMS-BAY"]:
        print(f"\n🔍 Verifying {dataset_name} sensor graph data...")

        hf_data = hf_sensor_data.get(dataset_name, {})
        local_data = local_sensor_data.get(dataset_name, {})

        if not hf_data or not local_data:
            print(f"   ❌ Missing data for {dataset_name}")
            all_match = False
            continue

        # Verify adjacency matrix
        if "adj_mx" in hf_data and "adj_mx" in local_data:
            hf_adj = hf_data["adj_mx"]
            local_adj = local_data["adj_mx"]

            if hf_adj.shape == local_adj.shape:
                max_diff = np.max(np.abs(hf_adj - local_adj))
                if max_diff < 1e-10:
                    print(
                        f"   ✅ Adjacency matrix matches perfectly (shape: {hf_adj.shape})"
                    )
                else:
                    print(f"   ⚠️ Adjacency matrix differs (max diff: {max_diff:.2e})")
                    all_match = False
            else:
                print(
                    f"   ❌ Adjacency matrix shape mismatch: HF {hf_adj.shape} vs Local {local_adj.shape}"
                )
                all_match = False

        # Verify mapping
        if "adj_mx_mapping" in hf_data and "adj_mx_mapping" in local_data:
            hf_mapping = hf_data["adj_mx_mapping"]
            local_mapping = local_data["adj_mx_mapping"]

            if hf_mapping == local_mapping:
                print(f"   ✅ Adjacency mapping matches ({len(hf_mapping)} sensors)")
            else:
                print(f"   ❌ Adjacency mapping mismatch")
                all_match = False

        # Verify sensor locations
        if "sensor_locations" in hf_data and "sensor_locations" in local_data:
            hf_locations = hf_data["sensor_locations"]
            local_locations = local_data["sensor_locations"]

            if hf_locations.equals(local_locations):
                print(f"   ✅ Sensor locations match ({len(hf_locations)} sensors)")
            else:
                print(f"   ⚠️ Sensor locations differ")
                # Check if just column order or minor differences
                if set(hf_locations.columns) == set(local_locations.columns) and len(
                    hf_locations
                ) == len(local_locations):
                    print(f"      (Same columns and count, might be minor differences)")
                else:
                    all_match = False

        # Verify distances
        if "distances" in hf_data and "distances" in local_data:
            hf_distances = hf_data["distances"]
            local_distances = local_data["distances"]

            if hf_distances.equals(local_distances):
                print(f"   ✅ Distance data matches ({hf_distances.shape})")
            else:
                print(f"   ⚠️ Distance data differs")
                if hf_distances.shape == local_distances.shape:
                    print(f"      (Same shape, might be minor differences)")
                else:
                    all_match = False

    print(
        f"\n🎯 Sensor Graph Verification: {'✅ PASS' if all_match else '❌ ISSUES DETECTED'}"
    )
    return all_match


# Run sensor graph verification
sensor_graph_ok = verify_sensor_graph_data()

🗺️  SENSOR GRAPH DATA VERIFICATION
📡 Loading sensor graph data from HF datasets...

🔍 Loading METR-LA sensor graph data...
   Found sensor graph files: ['sensor_graph/README.md', 'sensor_graph/adj_mx.npy', 'sensor_graph/adj_mx_mapping.json', 'sensor_graph/distances_la_2012.csv', 'sensor_graph/graph_sensor_locations.csv']
   ✅ Adjacency matrix: (207, 207)
   ✅ Adjacency matrix: (207, 207)
   ✅ Adjacency mapping: 6 sensors
   ✅ Adjacency mapping: 6 sensors
   ✅ Sensor locations: 207 sensors
   ✅ Sensor locations: 207 sensors
   ✅ Distance matrix: (295374, 3)

🔍 Loading PEMS-BAY sensor graph data...
   ✅ Distance matrix: (295374, 3)

🔍 Loading PEMS-BAY sensor graph data...
   Found sensor graph files: ['sensor_graph/README.md', 'sensor_graph/adj_mx_bay.npy', 'sensor_graph/adj_mx_bay_mapping.json', 'sensor_graph/distances_bay_2017.csv', 'sensor_graph/graph_sensor_locations_bay.csv']
   Found sensor graph files: ['sensor_graph/README.md', 'sensor_graph/adj_mx_bay.npy', 'sensor_graph/adj_mx_