In [1]:
import os
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

In [2]:
from src.tabpfn.model.causalFM import PerFeatureTransformerCATE
from DATA_standard.causaldatasettest import create_test_data_loader

In [7]:
@torch.no_grad()
def eval_one_csv(model, device, csv_path, batch_size=1, train_ratio=0.8):
    """
    Evaluate the model on a single dataset and report its MSE / PEHE.

    Every batch produced by the loader is split along dim 0: the first
    ``train_ratio`` fraction is passed to the model as training data and the
    remainder is used as the test set on which the squared error between the
    predicted CATE and the ground-truth ITE is computed.

    Args:
        model: object exposing ``estimate_cate(x_train, a_train, y_train, x_test)``
            that returns a mapping with a ``'cate'`` tensor.
        device: torch device the batch tensors are moved to.
        csv_path: path of the CSV file; its stem is used as the dataset name.
        batch_size: forwarded to ``create_test_data_loader``.
        train_ratio: fraction (strictly between 0 and 1) of each batch used as
            training data. Defaults to 0.8 (80% train / 20% test).

    Returns:
        dict with keys ``'dataset'``, ``'samples'`` (number of test
        predictions), ``'mse'`` and ``'pehe'`` (square root of the MSE).
        ``'mse'`` and ``'pehe'`` are NaN when no test sample was evaluated.

    Raises:
        ValueError: if ``train_ratio`` is outside (0, 1), or if the model
            returns a different number of predictions than there are
            ground-truth ITE values.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}")

    name = os.path.splitext(os.path.basename(csv_path))[0]
    loader = create_test_data_loader(csv_path, batch_size=batch_size, shuffle=False)

    se_list = []
    nsamples = 0

    for batch in loader:  # one batch is one CSV file (load all the samples)
        x = batch['X'].to(device)
        a = batch['a'].to(device)
        y = batch['y'].to(device)
        ite = batch['ite'].to(device)  # ground-truth ITE

        # Number of leading rows used as training data.
        # NOTE(review): assumes dim 0 indexes samples and dim 1 is a singleton
        # axis (hence squeeze(1)) — confirm against the data loader.
        n_train = int(y.shape[0] * train_ratio)

        # Split data into train and test
        x_train = x[:n_train].squeeze(1)
        a_train = a[:n_train].squeeze(1)
        y_train = y[:n_train].squeeze(1)

        x_test = x[n_train:].squeeze(1)
        # Flatten to 1-D so the subtraction below is strictly element-wise.
        ite_test = ite[n_train:].reshape(-1)

        # Get model predictions using estimate_cate method
        out = model.estimate_cate(x_train, a_train, y_train, x_test)
        cate_test = out['cate'].reshape(-1)

        # Without this check, shapes such as (N, 1) vs (N,) would silently
        # broadcast to an (N, N) error matrix and corrupt the metric.
        if cate_test.numel() != ite_test.numel():
            raise ValueError(
                f"Prediction/target size mismatch for '{name}': "
                f"{cate_test.numel()} CATE values vs {ite_test.numel()} ITE values"
            )

        # Calculate error
        se = (cate_test - ite_test) ** 2
        se_list.append(se.cpu().numpy())
        nsamples += cate_test.numel()

    if nsamples == 0:
        return {'dataset': name, 'samples': 0, 'mse': np.nan, 'pehe': np.nan}

    se_all = np.concatenate(se_list)
    mse = float(np.mean(se_all))
    pehe = float(np.sqrt(mse))
    # Report the test-sample count here too, so both return paths share one schema.
    return {'dataset': name, 'samples': nsamples, 'mse': mse, 'pehe': pehe}


In [4]:
# Evaluation configuration: every tunable value used by the cells below.
data_dir = "DATA_standard"  # directory containing the evaluation CSV
data_prefix = "jobs_data"  # CSV file stem; the evaluated file is <data_dir>/<data_prefix>.csv
num_datasets = 1  # number of iterations of the evaluation loop
model_path = 'checkpoints_standard/best_model.pth'  # checkpoint file loaded into the model
batch_size = 1  # forwarded to the test data loader via eval_one_csv
gpu = 0  # CUDA device index; a negative value (or no CUDA) selects the CPU
output_dir = 'test_results'  # directory where summary.csv is written

# Create output directory (no error if it already exists)
os.makedirs(output_dir, exist_ok=True)

In [5]:
# Device: use the requested CUDA device when CUDA is available and the index
# is non-negative; otherwise fall back to the CPU.
device = torch.device(f'cuda:{gpu}' if torch.cuda.is_available() and gpu >= 0 else 'cpu')

# Fail fast: verify the checkpoint exists before paying the cost of building
# the model and moving it to the device.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

# Model
model = PerFeatureTransformerCATE().to(device)

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source (consider passing
# weights_only=True where the installed PyTorch version supports it).
ckpt = torch.load(model_path, map_location=device)

# Accept either a training checkpoint that wraps the weights under
# 'model_state_dict' or a bare state dict saved directly.
state = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
model.load_state_dict(state)
model.eval()  # switch the module to evaluation mode
print("Model loaded successfully!")

Model loaded successfully!


In [6]:
# Run the evaluation once per configured dataset, collecting one result row each.
# NOTE(review): the loop index `i` is not used when building `csv_path`, so with
# num_datasets > 1 the same CSV would be evaluated repeatedly — confirm the
# intended per-dataset file naming.
rows = []
print("\n=== Per-dataset Test PEHE ===")
for i in tqdm(range(1, num_datasets + 1), desc="Datasets"):
    csv_path = os.path.join(data_dir, f"{data_prefix}.csv")
    if os.path.exists(csv_path):
        res = eval_one_csv(model, device, csv_path, batch_size=batch_size)
        rows.append(res)
        # Report this CSV's metrics as soon as they are available
        print(f"{res['dataset']}: PEHE={res['pehe']:.4f}  (MSE={res['mse']:.6f})")
    else:
        print(f"[Skip] Not found: {csv_path}")

if rows:
    # Build the per-dataset table and persist it
    df = pd.DataFrame(rows).sort_values('dataset')
    summary_path = os.path.join(output_dir, 'summary.csv')
    df.to_csv(summary_path, index=False)

    # Rich display of the results table
    display(df)

    # Aggregate over datasets that produced a PEHE (NaN entries are dropped)
    valid = df['pehe'].dropna().to_numpy()
    if valid.size > 0:
        print("\n=== Overall ===")
        print(f"Avg PEHE: {valid.mean():.4f}")
    print(f"\nSaved summary to: {summary_path}")
else:
    print("\nNo datasets evaluated. Nothing to save.")



=== Per-dataset Test PEHE ===


Datasets:   0%|                                           | 0/1 [00:00<?, ?it/s]

Found 1 test CSV files to load:
  - DATA_standard/jobs_data.csv
Loaded 445 rows from DATA_standard/jobs_data.csv
Total number of test sequences (CSV files): 1
Each test sequence length: 445


Datasets: 100%|███████████████████████████████████| 1/1 [00:01<00:00,  1.77s/it]

jobs_data: PEHE=0.4782  (MSE=0.228690)





Unnamed: 0,dataset,samples,mse,pehe
0,jobs_data,89,0.22869,0.478215



=== Overall ===
Avg PEHE: 0.4782

Saved summary to: test_results/summary.csv
