In [1]:
import os
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

In [2]:
from src.tabpfn.model.causalFM4FD import PerFeatureTransformerCATE
from DATA_FD.causaldatasettestFD import create_test_data_loader

In [3]:
@torch.no_grad()
def eval_one_csv(model, device, csv_path, batch_size=1):
    """
    Evaluate on a single dataset.
    - Uses 80% as the training data and 20% as the test data for one dataset.
    """
    name = os.path.splitext(os.path.basename(csv_path))[0]
    loader = create_test_data_loader(csv_path, batch_size=batch_size, shuffle=False)

    se_list = []
    nsamples = 0

    for batch in loader: # one batch is one CSV file (load all the samples)
        x = batch['X'].to(device)
        a = batch['a'].to(device)
        y = batch['y'].to(device)
        m = batch['m'].to(device)  # factual
        ite = batch['ite'].to(device)  # ground-truth ITE

        # train test split 
        ratio = 0.8
        train_text_split = int(y.shape[0] * ratio)

        # Split data into train and test
        x_train = x[:train_text_split].squeeze(1)
        a_train = a[:train_text_split].squeeze(1)
        y_train = y[:train_text_split].squeeze(1)
        m_train = m[:train_text_split].squeeze(1)

        x_test = x[train_text_split:].squeeze(1)
        ite_test = ite[train_text_split:].squeeze()
        
        # print(f"Train samples: {train_text_split}, Test samples: {ite_test.shape[0]}")
        # print("ite_test.shape:", ite_test.shape)
        
        # Get model predictions using estimate_cate_fd method
        out = model.estimate_cate_fd(x_train, a_train, y_train, m_train, x_test)
        cate_test = out['cate']

        # Calculate error
        se = (cate_test - ite_test) ** 2
        se_list.append(se.cpu().numpy())
        nsamples += cate_test.numel()

    if nsamples == 0:
        return {'dataset': name, 'samples': 0, 'mse': np.nan, 'pehe': np.nan}

    se_all = np.concatenate(se_list)
    mse = float(np.mean(se_all))
    pehe = float(np.sqrt(mse))
    return {'dataset': name, 'samples': nsamples, 'mse': mse, 'pehe': pehe}


In [4]:
data_dir = "DATA_FD/frontdoor_TEST"
data_prefix = "frontdoor_test_dataset_"
num_datasets = 10
model_path = 'checkpoints_FD/best_model.pth'
batch_size = 1
gpu = 0
output_dir = 'test_results'

# Create output directory
os.makedirs(output_dir, exist_ok=True)

In [5]:
# Device
device = torch.device(f'cuda:{gpu}' if torch.cuda.is_available() and gpu >= 0 else 'cpu')
# print(f"Using device: {device}")

# Model
model = PerFeatureTransformerCATE().to(device)
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

ckpt = torch.load(model_path, map_location=device)
state = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
model.load_state_dict(state)
model.eval()
print("Model loaded successfully!")

Model loaded successfully!


In [6]:
# Evaluate each dataset one by one
rows = []
print("\n=== Per-dataset Test PEHE ===")
for i in tqdm(range(1, num_datasets + 1), desc="Datasets"):
    csv_path = os.path.join(data_dir, f"{data_prefix}{i}.csv")
    if not os.path.exists(csv_path):
        print(f"[Skip] Not found: {csv_path}")
        continue

    res = eval_one_csv(model, device, csv_path, batch_size=batch_size)
    rows.append(res)
    # Directly print the result for this CSV
    print(f"{res['dataset']}: PEHE={res['pehe']:.4f}  (MSE={res['mse']:.6f}, samples={res['samples']})")
if not rows:
    print("\nNo datasets evaluated. Nothing to save.")
else:
    # Summarize and save
    df = pd.DataFrame(rows).sort_values('dataset')
    summary_path = os.path.join(output_dir, 'summary.csv')
    df.to_csv(summary_path, index=False)
    
    # Display the dataframe
    display(df)
    
    # Print overall statistics
    valid = df['pehe'].dropna().to_numpy()
    if valid.size > 0:
        print("\n=== Overall ===")
        print(f"Avg PEHE: {valid.mean():.4f} ± {valid.std():.4f}")
        print(f"Best PEHE: {valid.min():.4f}  ({df.iloc[int(valid.argmin())]['dataset']})")
        print(f"Worst PEHE: {valid.max():.4f} ({df.iloc[int(valid.argmax())]['dataset']})")
    print(f"\nSaved summary to: {summary_path}")



=== Per-dataset Test PEHE ===


Datasets:   0%|                                          | 0/10 [00:00<?, ?it/s]

Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_1.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_1.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets:  10%|███▍                              | 1/10 [00:03<00:27,  3.04s/it]

frontdoor_test_dataset_1: PEHE=0.7358  (MSE=0.541353, samples=205)
Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_2.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_2.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets:  20%|██████▊                           | 2/10 [00:05<00:21,  2.70s/it]

frontdoor_test_dataset_2: PEHE=0.6783  (MSE=0.460074, samples=205)
Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_3.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_3.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets:  30%|██████████▏                       | 3/10 [00:07<00:17,  2.55s/it]

frontdoor_test_dataset_3: PEHE=1.5413  (MSE=2.375613, samples=205)
Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_4.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_4.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets:  40%|█████████████▌                    | 4/10 [00:10<00:14,  2.50s/it]

frontdoor_test_dataset_4: PEHE=1.0376  (MSE=1.076561, samples=205)
Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_5.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_5.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets:  50%|█████████████████                 | 5/10 [00:12<00:12,  2.54s/it]

frontdoor_test_dataset_5: PEHE=0.3033  (MSE=0.091996, samples=205)
Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_6.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_6.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets:  60%|████████████████████▍             | 6/10 [00:15<00:10,  2.51s/it]

frontdoor_test_dataset_6: PEHE=0.7274  (MSE=0.529088, samples=205)
Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_7.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_7.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets:  70%|███████████████████████▊          | 7/10 [00:17<00:07,  2.47s/it]

frontdoor_test_dataset_7: PEHE=0.9460  (MSE=0.894842, samples=205)
Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_8.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_8.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets:  80%|███████████████████████████▏      | 8/10 [00:20<00:04,  2.45s/it]

frontdoor_test_dataset_8: PEHE=0.9471  (MSE=0.897036, samples=205)
Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_9.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_9.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets:  90%|██████████████████████████████▌   | 9/10 [00:22<00:02,  2.46s/it]

frontdoor_test_dataset_9: PEHE=1.1463  (MSE=1.314030, samples=205)
Found 1 test CSV files to load:
  - DATA_FD/frontdoor_TEST/frontdoor_test_dataset_10.csv
Loaded 1024 rows from DATA_FD/frontdoor_TEST/frontdoor_test_dataset_10.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 1024


Datasets: 100%|█████████████████████████████████| 10/10 [00:25<00:00,  2.50s/it]

frontdoor_test_dataset_10: PEHE=0.4032  (MSE=0.162564, samples=205)





Unnamed: 0,dataset,samples,mse,pehe
0,frontdoor_test_dataset_1,205,0.541353,0.735767
9,frontdoor_test_dataset_10,205,0.162564,0.403192
1,frontdoor_test_dataset_2,205,0.460074,0.678287
2,frontdoor_test_dataset_3,205,2.375613,1.541302
3,frontdoor_test_dataset_4,205,1.076561,1.037574
4,frontdoor_test_dataset_5,205,0.091996,0.303309
5,frontdoor_test_dataset_6,205,0.529088,0.727384
6,frontdoor_test_dataset_7,205,0.894842,0.945961
7,frontdoor_test_dataset_8,205,0.897036,0.94712
8,frontdoor_test_dataset_9,205,1.31403,1.146311



=== Overall ===
Avg PEHE: 0.8466 ± 0.3429
Best PEHE: 0.3033  (frontdoor_test_dataset_5)
Worst PEHE: 1.5413 (frontdoor_test_dataset_3)

Saved summary to: test_results/summary.csv
