<a href="https://colab.research.google.com/github/shaigilat/exploring-tabpfn/blob/main/explore_tabpfn_preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Dive into TabPFN's Preprocessing Pipeline


This notebook aims to demystify the internal preprocessing steps of TabPFN (Tabular Prior-data Fitted Network) by breaking down its `_initialize_dataset_preprocessing` method. We will walk through each transformation applied to the raw data, from validation and type fixing to categorical encoding, scaling, and feature shuffling, and observe how each step alters the dataset before it is fed into the neural network.

In [1]:
# @title Installation & Setup
# Install the official TabPFN package
!pip install tabpfn==6.0.6 -q
!pip install shapiq

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import shapiq
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Import TabPFN components
from tabpfn import TabPFNClassifier
from tabpfn.constants import ModelVersion
# Import the specific internal utility functions TabPFN uses
from tabpfn.utils import (
    validate_Xy_fit,
    infer_categorical_features,
    fix_dtypes,
    process_text_na_dataframe
)
from tabpfn.preprocessors.preprocessing_helpers import get_ordinal_encoder

# Set print options for cleaner output
torch.set_printoptions(sci_mode=False, precision=4, linewidth=120)
np.set_printoptions(suppress=True, precision=4)

print("✅ Installation complete. Ready to inspect TabPFN v2.")


## Step 0. Dataset Description: Titanic Survival Prediction

For this deep dive, we'll use the classic Titanic dataset, which contains information about passengers on the ill-fated RMS Titanic. The goal is to predict survival based on features like passenger class, sex, age, and embarkation point.

| Feature Name | Type    | NaNs | Description                                     |
|--------------|---------|------|-------------------------------------------------|
| `Pclass`     | int64   | 0    | Passenger Class (1st, 2nd, 3rd)                 |
| `Sex`        | object  | 0    | Gender (male, female)                           |
| `Age`        | float64 | 177  | Age in years                                    |
| `SibSp`      | int64   | 0    | Number of siblings/spouses aboard               |
| `Parch`      | int64   | 0    | Number of parents/children aboard               |
| `Fare`       | float64 | 0    | Passenger fare                                  |
| `Embarked`   | object  | 2    | Port of Embarkation (C, Q, S)                   |
| `Ticket`     | object  | 0    | Ticket number                                   |
| `Name`       | object  | 0    | Passenger's name                                |
| `Cabin`      | object  | 687  | Cabin number                                    |

**Target Variable:**
- `Survived`: Survival status (0 = No, 1 = Yes)

This dataset is ideal for demonstrating TabPFN's preprocessing due to its mix of numerical, categorical, and missing data.


In [2]:
# @title Raw Data Loading
import pandas as pd  # re-imported so this cell can also run on its own
from sklearn.model_selection import train_test_split

url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
print(f"Loading dataset from: {url}")
df = pd.read_csv(url)

# Select a mix of features (numeric, categorical, some with NaNs)
features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked', 'Ticket', 'Name', 'Cabin']
target = 'Survived'

X = df[features]
y = df[target]

# Split data into training and testing sets
X_raw, X_test, y_raw, y_test = train_test_split(X, y, test_size=0.1, random_state=42, stratify=y)

print("\nDataset Preview (Training Data):")
print(X_raw.head().to_string())
print(f"\nOriginal Shape: {df.shape}")
print(f"Training Data Shape: {X_raw.shape}")
print(f"Testing Data Shape: {X_test.shape}")
# Print the survival rate in the training and test splits
# (the :.1% format spec already appends the percent sign)
print(f"Percentage Survived in Train Data: {y_raw.sum() / len(y_raw):.1%}")
print(f"Percentage Survived in Test Data: {y_test.sum() / len(y_test):.1%}")
print("-" * 50 + "\n")

clf = TabPFNClassifier.create_default_for_version(
    ModelVersion.V2,
    device='cpu',
    n_estimators=4
)
# Fit on a small subset of the training data to initialize the model
clf.fit(X_raw.iloc[:50], y_raw.iloc[:50])

Loading dataset from: https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv

Dataset Preview (Training Data):
     Pclass     Sex   Age  SibSp  Parch     Fare Embarked      Ticket                               Name Cabin
86        3    male  16.0      1      3  34.3750        S  W./C. 6608             Ford, Mr. William Neal   NaN
329       1  female  16.0      0      1  57.9792        C      111361       Hippach, Miss. Jean Gertrude   B18
517       3    male   NaN      0      0  24.1500        Q      371110                  Ryan, Mr. Patrick   NaN
844       3    male  17.0      0      0   8.6625        S      315090                Culumovic, Mr. Jeso   NaN
408       3    male  21.0      0      0   7.7750        S      312992  Birkeland, Mr. Hans Martin Monsen   NaN

Original Shape: (891, 12)
Training Data Shape: (801, 10)
Testing Data Shape: (90, 10)
Percentage Survived in Train Data: 38.3%%
Percentage Survived in Test Data: 38.9%%
------------------------------


## Step 1. Initial Preprocessing: Breaking Down `_initialize_dataset_preprocessing`

TabPFN's `_initialize_dataset_preprocessing` method orchestrates the initial data preparation, setting up the foundation for all subsequent ensemble views. It performs several crucial steps:

1.  **Validation & Cleaning:** Ensures data integrity by checking for disallowed types, infinities, and basic structural requirements.
2.  **Target (Label) Encoding:** Transforms the target variable into a numerical format (0 to N-1 classes).
3.  **Categorical Inference & Type Fixing:** Identifies categorical features (if not provided) and casts them to a 'category' dtype, also handling object types.
4.  **Ordinal Encoding & String Vectorization:** Converts categorical and string columns into numerical representations suitable for the model, including handling missing string values.

In [30]:
# @title Step 1.0: Show Raw Data

def display_dataset_state(df, title, additional_info=None):
    """
    Displays the head, shape, null counts, and dtypes of a DataFrame.
    """
    print(f"\n{title}")
    print("=" * 50)

    # Ensure df is a pandas DataFrame for consistent handling
    if not isinstance(df, pd.DataFrame):
        df = pd.DataFrame(df)

    # 1. Basic Shape
    print(f"Shape: {df.shape}")

    # 2. Nulls & Dtypes Summary
    print("\nColumn Summary:")
    summary = pd.DataFrame({
        'Dtype': df.dtypes,
        'Null Count': df.isnull().sum(),
        'Null %': (df.isnull().sum() / len(df) * 100).round(2)
    })
    print(summary)

    # 3. Head
    print("\nHead of DataFrame:")
    print(df.head().to_string())

    # 4. Additional Info
    if additional_info:
        print("\nAdditional Info:")
        for key, value in additional_info.items():
            print(f"- {key}: {value}")

    print("=" * 50)

display_dataset_state(X_raw, "Raw Dataset")


Raw Dataset
Shape: (801, 10)

Column Summary:
            Dtype  Null Count  Null %
Pclass      int64           0    0.00
Sex        object           0    0.00
Age       float64         161   20.10
SibSp       int64           0    0.00
Parch       int64           0    0.00
Fare      float64           0    0.00
Embarked   object           2    0.25
Ticket     object           0    0.00
Name       object           0    0.00
Cabin      object         620   77.40

Head of DataFrame:
     Pclass     Sex   Age  SibSp  Parch     Fare Embarked      Ticket                               Name Cabin
86        3    male  16.0      1      3  34.3750        S  W./C. 6608             Ford, Mr. William Neal   NaN
329       1  female  16.0      0      1  57.9792        C      111361       Hippach, Miss. Jean Gertrude   B18
517       3    male   NaN      0      0  24.1500        Q      371110                  Ryan, Mr. Patrick   NaN
844       3    male  17.0      0      0   8.6625        S      315090  

In [None]:
# @title Steps 1.1-1.4: Run the full initial dataset preprocessing
# We use the internal method to reproduce the preprocessing exactly
# This returns the numpy array that TabPFN actually uses

print("--- Start Initial Dataset Preprocessing Steps ---")
print("==================================================")

# --- Step 1.1: Validate dataset ---
print("\n--- Step 1.1: Validate dataset ---")
print("  Description: Checks for disallowed types, infinite values, and ensures basic structure.")
print(f"Maximum number of samples for inference: {clf.inference_config_.MAX_NUMBER_OF_SAMPLES}")
print(f"Maximum number of features for inference: {clf.inference_config_.MAX_NUMBER_OF_FEATURES}")
print(f"Ignore pretraining limits: {clf.ignore_pretraining_limits}")
X_valid, y_valid, feature_names_in, n_features_in = validate_Xy_fit(
    X_raw,
    y_raw,
    estimator=clf,
    ensure_y_numeric=False,
    max_num_samples=clf.inference_config_.MAX_NUMBER_OF_SAMPLES,
    max_num_features=clf.inference_config_.MAX_NUMBER_OF_FEATURES,
    ignore_pretraining_limits=clf.ignore_pretraining_limits,
)
print("Status: Validation Successful (Data copied)")
display_dataset_state(X_valid, "Step 1.1 Output (After Validation)", additional_info={'Shape after validation': X_valid.shape})

# --- Step 1.2: Target encoding ---
print("\n--- Step 1.2: Target encoding ---")
print("  Description: Converts target classes (e.g., 'Yes'/'No' or 0/1) into strict integers 0..N-1.")
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y_valid)
# X_encoded is just X_valid copied at this stage for consistent variable naming
X_encoded = X_valid.copy()
print(f"y (After):  {y_encoded[:10]}")
print(f"Classes:    {label_encoder.classes_}")
display_dataset_state(pd.DataFrame(X_encoded), "Step 1.2 Output (X after target encoding)", additional_info={'y_encoded_head': y_encoded[:5]})

# --- Step 1.3: Identify Categorical variables and Fix Data Types ---
print("\n--- Step 1.3: Identify Categorical variables and Fix Data Types ---")
print("  Description: Guesses categorical columns and casts them to pandas 'category' dtype, also handling object types.")
print(f"Minimum number of samples required for categorical inference: {clf.inference_config_.MIN_NUMBER_SAMPLES_FOR_CATEGORICAL_INFERENCE}")
print(f"Maximum number of unique values for a feature to be considered categorical: {clf.inference_config_.MAX_UNIQUE_FOR_CATEGORICAL_FEATURES}")
print(f"Minimum number of unique values for a feature to be considered numerical: {clf.inference_config_.MIN_UNIQUE_FOR_NUMERICAL_FEATURES}")
cat_indices = infer_categorical_features(
    X_encoded,
    provided=clf.categorical_features_indices,
    min_samples_for_inference=clf.inference_config_.MIN_NUMBER_SAMPLES_FOR_CATEGORICAL_INFERENCE,
    max_unique_for_category=clf.inference_config_.MAX_UNIQUE_FOR_CATEGORICAL_FEATURES,
    min_unique_for_numerical=clf.inference_config_.MIN_UNIQUE_FOR_NUMERICAL_FEATURES,
)
X_fixed_dtypes = fix_dtypes(X_encoded, cat_indices=cat_indices)
y_fixed_dtypes = y_encoded.copy()
display_dataset_state(X_fixed_dtypes, "Step 1.3 Output (After Fixing Dtypes)", additional_info={'Inferred Categorical Indices': cat_indices})

# --- Step 1.4: Ordinal Encoding & Vectorization ---
print("\n--- Step 1.4: Ordinal Encoding & Vectorization ---")
print("  Description: Converts categorical and string columns into numerical representations suitable for the model, including handling missing string values.")
ord_encoder = get_ordinal_encoder()
X_fixed_strs = process_text_na_dataframe(X_fixed_dtypes, ord_encoder=ord_encoder, fit_encoder=True)
y_fixed_strs = y_fixed_dtypes.copy()
display_dataset_state(pd.DataFrame(X_fixed_strs), "Step 1.4 Output (After Ordinal Encoding)")
print("Action: Categorical and string columns (e.g. 'Sex', 'Embarked', 'Ticket', 'Name', 'Cabin') were ordinal encoded into integer codes, with missing strings handled along the way.")
print("Action: Encoded categorical columns are moved to the front; numeric columns follow.")

print("==================================================")
print("--- End Initial Dataset Preprocessing Steps ---")

--- Start Initial Dataset Preprocessing Steps ---

--- Step 1.1: Validate dataset ---
  Description: Checks for disallowed types, infinite values, and ensures basic structure.
Maximum number of samples for inference: 10000
Maximum number of features for inference: 500
Ignore pretraining limits: False
Status: Validation Successful (Data copied)

Step 1.1 Output (After Validation)
Shape: (801, 10)

Column Summary:
    Dtype  Null Count  Null %
0  object           0    0.00
1  object           0    0.00
2  object         161   20.10
3  object           0    0.00
4  object           0    0.00
5  object           0    0.00
6  object           2    0.25
7  object           0    0.00
8  object           0    0.00
9  object         620   77.40

Head of DataFrame:
   0       1     2  3  4        5  6           7                                  8    9
0  3    male  16.0  1  3   34.375  S  W./C. 6608             Ford, Mr. William Neal  NaN
1  1  female  16.0  0  1  57.9792  C      111361       H

Consider using a GPU or the tabpfn-client API: https://github.com/PriorLabs/tabpfn-client


In [39]:
# @title Run `_initialize_dataset_preprocessing` directly for comparison
# This is what the original _initialize_dataset_preprocessing returns after these steps
# (Note: The original _initialize_dataset_preprocessing might internally handle copying and numpy conversion.)
_, X_processed, y_processed = clf._initialize_dataset_preprocessing(X_raw, y_raw, rng=42)
display_dataset_state(pd.DataFrame(X_processed), "Combined Output from _initialize_dataset_preprocessing (for comparison)")


Combined Output from _initialize_dataset_preprocessing (for comparison)
Shape: (801, 10)

Column Summary:
     Dtype  Null Count  Null %
0  float64           0    0.00
1  float64           0    0.00
2  float64           0    0.00
3  float64           0    0.00
4  float64           0    0.00
5  float64           0    0.00
6  float64         163   20.35
7  float64           0    0.00
8  float64           0    0.00
9  float64         620   77.40

Head of DataFrame:
     0    1    2      3      4      5     6    7    8        9
0  2.0  1.0  2.0  620.0  227.0  136.0  16.0  1.0  3.0      NaN
1  0.0  0.0  0.0    7.0  325.0   16.0  16.0  0.0  1.0  57.9792
2  2.0  1.0  1.0  433.0  642.0  136.0   NaN  0.0  0.0      NaN
3  2.0  1.0  2.0  248.0  164.0  136.0  17.0  0.0  0.0      NaN
4  2.0  1.0  2.0  239.0   75.0  136.0  21.0  0.0  0.0      NaN




## Step 2. TabPFN's 'Ensemble Views' & Preprocessing

TabPFN enhances model robustness by generating multiple 'ensemble views' of the same dataset. Each view is a unique transformation of the input data, providing diverse perspectives to the neural network, akin to observing an object from different angles. This diversity prevents the model from relying on a single data representation and improves generalization.

### Preprocessing Pipeline for Each View:

Each view passes through four core steps before reaching the neural network; depending on its configuration, a view may also add polynomial features and a fingerprint feature:

1.  **Remove Constant Features:** Eliminates columns with no variance (all identical values) as they offer no predictive power and can disrupt some transformations.

2.  **Reshape Distributions:** Applies a scaling transform (e.g., quantile or power transformations) to numerical features. Because each view can use a different transform, this step is the main source of diversity between views.

3.  **Encode Categorical Features:** Further processes categorical variables, often by shuffling their assigned integer IDs. This randomization prevents the network from inferring meaning from arbitrary numerical order.

4.  **Shuffle Features:** Randomizes the column order, so the model learns from feature content rather than feature position, which improves robustness.
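The four core steps can be approximated with NumPy and scikit-learn. This is a simplified analogue, not TabPFN's code: TabPFN's named transforms (such as `quantile_uni_coarse` or `ordinal_very_common_categories_shuffled`) have more options, while here the toy data, the use of scikit-learn's `QuantileTransformer`, and the random permutation scheme are assumptions made for illustration.

```python
import numpy as np
from sklearn.preprocessing import QuantileTransformer

rng = np.random.default_rng(0)

# Toy view input: column 0 holds ordinal category codes, column 1 is constant,
# columns 2-3 are skewed numeric features
n = 200
X = np.column_stack([
    rng.integers(0, 3, n).astype(float),  # category codes 0..2
    np.full(n, 5.0),                      # constant column
    rng.exponential(1.0, n),              # skewed numeric
    rng.lognormal(0.0, 1.0, n),           # skewed numeric
])
cat_ix = [0]

# 1. Remove constant features (min == max) and remap categorical indices
keep = np.nanmax(X, axis=0) != np.nanmin(X, axis=0)
new_pos = np.cumsum(keep) - 1
X = X[:, keep]
cat_ix = [int(new_pos[i]) for i in cat_ix if keep[i]]

# 2. Reshape numeric distributions: map each numeric column onto a uniform [0, 1] scale
num_ix = [i for i in range(X.shape[1]) if i not in cat_ix]
qt = QuantileTransformer(n_quantiles=50, output_distribution="uniform")
X[:, num_ix] = qt.fit_transform(X[:, num_ix])

# 3. Shuffle categorical codes so their integer order carries no meaning
for i in cat_ix:
    n_cats = int(np.nanmax(X[:, i])) + 1
    perm = rng.permutation(n_cats)
    mask = ~np.isnan(X[:, i])
    X[mask, i] = perm[X[mask, i].astype(int)]

# 4. Shuffle column order and track where each categorical column moved
col_perm = rng.permutation(X.shape[1])
X = X[:, col_perm]
cat_ix = [int(np.flatnonzero(col_perm == i)[0]) for i in cat_ix]
print(X.shape, cat_ix)
```

Note that every step has to carry the categorical indices forward, which mirrors how each TabPFN step below returns both the transformed `X` and its updated `categorical_features`.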

In [31]:
# @title Step 2.0: Show preprocessed dataset, and select configuration

from tabpfn.preprocessors import (
    RemoveConstantFeaturesStep,
    ReshapeFeatureDistributionsStep,
    EncodeCategoricalFeaturesStep,
    ShuffleFeaturesStep,
    AddFingerprintFeaturesStep,
    NanHandlingPolynomialFeaturesStep
)

# --- Step 2.0: Setup & Configuration Selection ---
# 1. Get the global preprocessed data and the list of configurations for all views
ensemble_configs, X_processed, y_processed = clf._initialize_dataset_preprocessing(X_raw, y_raw, rng=42)
display_dataset_state(X_processed, "Step 1 Output (Global Preprocessed Data)")

# 2. Identify Categorical Indices
#    The first transformer [0] is always the 'ordinal encoder'.
#    The third element [2] contains the actual list of columns it selected.
input_cat_cols = clf.preprocessor_.transformers_[0][2]
#    Because Step 1 moves these columns to the front, their indices are 0 to N-1
cat_ix = list(range(len(input_cat_cols)))
print("\n\nEnsemble Configurations:\n" + "=" * 40)

for i, config in enumerate(ensemble_configs):
    preprocess_name = config.preprocess_config.name if config.preprocess_config else 'N/A'
    categorical_name = config.preprocess_config.categorical_name if config.preprocess_config else 'N/A'
    feature_shift_decoder = config.feature_shift_decoder
    poly_features = config.polynomial_features
    fingerprint = config.add_fingerprint_feature
    class_perm_display = 'N/A'
    if hasattr(config, 'class_permutation') and config.class_permutation is not None:
        class_perm_display = config.class_permutation.tolist()

    print(f"View {i+1}:\n"
          f"  - Polynomial Features: {poly_features}\n"
          f"  - Preprocessing: {preprocess_name}\n"
          f"  - Categorical Encoding: {categorical_name}\n"
          f"  - Fingerprint Feature: {fingerprint}\n"
          f"  - Feature Shuffle Method: {feature_shift_decoder}\n"
          f"  - Class Permutation: {class_perm_display}\n")

# 3. Select one specific view (estimator) to visualize
view_num = 0
curr_seed = 42
curr_config = ensemble_configs[view_num]
print(f"Selected View {view_num+1} for Detailed Walkthrough")


Step 1 Output (Global Preprocessed Data)
Shape: (801, 10)

Column Summary:
     Dtype  Null Count  Null %
0  float64           0    0.00
1  float64           0    0.00
2  float64           0    0.00
3  float64           0    0.00
4  float64           0    0.00
5  float64           0    0.00
6  float64         163   20.35
7  float64           0    0.00
8  float64           0    0.00
9  float64         620   77.40

Head of DataFrame:
     0    1    2      3      4      5     6    7    8        9
0  2.0  1.0  2.0  620.0  227.0  136.0  16.0  1.0  3.0      NaN
1  0.0  0.0  0.0    7.0  325.0   16.0  16.0  0.0  1.0  57.9792
2  2.0  1.0  1.0  433.0  642.0  136.0   NaN  0.0  0.0      NaN
3  2.0  1.0  2.0  248.0  164.0  136.0  17.0  0.0  0.0      NaN
4  2.0  1.0  2.0  239.0   75.0  136.0  21.0  0.0  0.0      NaN


Ensemble Configurations:
View 1:
  - Polynomial Features: no
  - Preprocessing: quantile_uni_coarse
  - Categorical Encoding: ordinal_very_common_categories_shuffled
  - Fingerprint F



In [32]:
# @title Step 2.1-2.6: Full Ensemble View Preprocessing Pipeline

print("--- Start Ensemble View Preprocessing Steps ---")
print("==================================================")

# --- Step 2.1: Polynomial Features ---
print("\n--- Step 2.1: Polynomial Features ---")
print("  Description: Generates interaction features (e.g., A*B) if enabled in config.")
poly_feats = curr_config.polynomial_features
print(f"  üî∏ Config for NanHandlingPolynomialFeaturesStep:")
print(f"    - Polynomial Features Setting: {poly_feats}")

# Determine if we run it (logic from src/tabpfn/preprocessing.py)
use_poly = poly_feats != "no"
max_poly = None if poly_feats == "all" else (poly_feats if isinstance(poly_feats, int) else None)

if use_poly:
    step_2_1 = NanHandlingPolynomialFeaturesStep(
        max_features=max_poly,
        random_state=curr_seed
    )
    # Input is X_processed from Step 1
    res_2_1 = step_2_1.fit_transform(X_processed, cat_ix)
    X_2_1 = res_2_1.X
    cats_ix_2_1 = res_2_1.categorical_features
    display_dataset_state(X_2_1, "Step 2.1 Output (After Polynomial Features)")
else:
    print("Action: Polynomial features disabled (Pass through).")
    X_2_1 = X_processed
    cats_ix_2_1 = cat_ix


# --- Step 2.2: Remove Constant Features ---
print("\n--- Step 2.2: Remove Constant Features ---")
print("  Description: Features with 0 variance (all same value) provide no info and break some scalers.")
step_2_2 = RemoveConstantFeaturesStep()
# Input is X_2_1
res_2_2 = step_2_2.fit_transform(X_2_1, cats_ix_2_1)
X_2_2 = res_2_2.X
cats_ix_2_2 = res_2_2.categorical_features
display_dataset_state(X_2_2, "Step 2.2 Output (After Removing Constant Features)")

if X_2_2.shape[1] < X_2_1.shape[1]:
    print("Action: Removed constant columns!")
else:
    print("Action: No constant columns found (Pass through).")


# --- Step 2.3: Reshape Distributions (The "View" Creator) ---
print("\n--- Step 2.3: Reshape Distributions (The 'View' Creator) ---")
print("  Description: This step applies various scaling (Quantile, Power, or None) to numerical features.")
transform_name = curr_config.preprocess_config.name
append_to_original = curr_config.preprocess_config.append_original
max_features_per_estimator = curr_config.preprocess_config.max_features_per_estimator
global_transformer_name = curr_config.preprocess_config.global_transformer_name
apply_to_categorical = (curr_config.preprocess_config.categorical_name == "numeric")

print(f"  üî∏ Config for ReshapeFeatureDistributionsStep:")
print(f"    - Transformation Type: {transform_name}")
print(f"    - Append Original Features: {append_to_original}")
print(f"    - Max Features per Estimator: {max_features_per_estimator}")

step_2_3 = ReshapeFeatureDistributionsStep(
    transform_name=transform_name,
    append_to_original=append_to_original,
    max_features_per_estimator=max_features_per_estimator,
    global_transformer_name=global_transformer_name,
    apply_to_categorical=apply_to_categorical,
    random_state=curr_seed
)
# Input is X_2_2
res_2_3 = step_2_3.fit_transform(X_2_2, cats_ix_2_2)
X_2_3 = res_2_3.X
cats_ix_2_3 = res_2_3.categorical_features
display_dataset_state(X_2_3, "Step 2.3 Output (After Scaling)")


# --- Step 2.4: Shuffle Categorical Encoding ---
print("\n--- Step 2.4: Shuffle Categorical Encoding ---")
print("  Description: Handles model-specific categorical encoding (often permuting integer IDs).")
cat_method = curr_config.preprocess_config.categorical_name
print(f"  üî∏ Config for EncodeCategoricalFeaturesStep:")
print(f"    - Categorical Transform Name: {cat_method}")

step_2_4 = EncodeCategoricalFeaturesStep(
    categorical_transform_name=cat_method,
    random_state=curr_seed
)
# Input is X_2_3
res_2_4 = step_2_4.fit_transform(X_2_3, cats_ix_2_3)
X_2_4 = res_2_4.X
cats_ix_2_4 = res_2_4.categorical_features
display_dataset_state(X_2_4, "Step 2.4 Output (After Categorical Encoding)")


# --- Step 2.5: Add Fingerprint Feature ---
print("\n--- Step 2.5: Add Fingerprint Feature ---")
print("  Description: Adds a per-row hash feature so the model can tell otherwise identical (duplicate) rows apart.")
add_fp = curr_config.add_fingerprint_feature
print(f"  üî∏ Config for AddFingerprintFeaturesStep:")
print(f"    - Fingerprint Enabled: {add_fp}")

if add_fp:
    step_2_5 = AddFingerprintFeaturesStep(random_state=curr_seed)
    # Input is X_2_4
    res_2_5 = step_2_5.fit_transform(X_2_4, cats_ix_2_4)
    X_2_5 = res_2_5.X
    cats_ix_2_5 = res_2_5.categorical_features
    display_dataset_state(X_2_5, "Step 2.5 Output (After Fingerprinting)")
    print("Action: Added a high-cardinality float column (hash of the row).")
else:
    print("Action: Fingerprinting disabled.")
    X_2_5 = X_2_4
    cats_ix_2_5 = cats_ix_2_4


# --- Step 2.6: Shuffle Column Order ---
print("\n--- Step 2.6: Shuffle Column Order ---")
print("  Description: Randomizes column order so the model isn't dependent on feature position.")
print(f"  üî∏ Config for ShuffleFeaturesStep:")
print(f"    - Shuffle Method: {curr_config.feature_shift_decoder}")
print(f"    - Shuffle Index: {curr_config.feature_shift_count}")

step_2_6 = ShuffleFeaturesStep(
    shuffle_method=curr_config.feature_shift_decoder,
    shuffle_index=curr_config.feature_shift_count,
    random_state=curr_seed
)
# Input is X_2_5
res_2_6 = step_2_6.fit_transform(X_2_5, cats_ix_2_5)
X_2_6 = res_2_6.X
cats_ix_2_6 = res_2_6.categorical_features
display_dataset_state(X_2_6, "Step 2.6 Output (After Column Shuffling)")

print(f"Shuffle Map: {step_2_6.index_permutation_}")
print("Action: The columns (including the fingerprint!) are now scrambled.")


print("\n==================================================")
print("--- End Ensemble View Preprocessing Steps ---")

# Final assignment
X_view = X_2_6

--- Start Ensemble View Preprocessing Steps ---

--- Step 2.1: Polynomial Features ---
  Description: Generates interaction features (e.g., A*B) if enabled in config.
  🔸 Config for NanHandlingPolynomialFeaturesStep:
    - Polynomial Features Setting: no
Action: Polynomial features disabled (Pass through).

--- Step 2.2: Remove Constant Features ---
  Description: Features with 0 variance (all same value) provide no info and break some scalers.

Step 2.2 Output (After Removing Constant Features)
Shape: (801, 10)

Column Summary:
     Dtype  Null Count  Null %
0  float64           0    0.00
1  float64           0    0.00
2  float64           0    0.00
3  float64           0    0.00
4  float64           0    0.00
5  float64           0    0.00
6  float64         163   20.35
7  float64           0    0.00
8  float64           0    0.00
9  float64         620   77.40

Head of DataFrame:
     0    1    2      3      4      5     6    7    8        9
0  2.0  1.0  2.0  620.0  227.0  136.0 

In [40]:
# @title Run the view's full pipeline in one call
# `to_pipeline` assembles this view's preprocessing steps into a single pipeline,
# so its output can be compared with the step-by-step run above

curr_pipeline = curr_config.to_pipeline(random_state=42)

res = curr_pipeline.fit_transform(X_processed, cat_ix)

X_view = res.X

display_dataset_state(X_view, "Step 2 Output - View 1")


Step 2 Output - View 1
Shape: (801, 20)

Column Summary:
      Dtype  Null Count  Null %
0   float64           0    0.00
1   float64         620   77.40
2   float64           0    0.00
3   float64           0    0.00
4   float64           0    0.00
5   float64         163   20.35
6   float64         163   20.35
7   float64           0    0.00
8   float64           0    0.00
9   float64           0    0.00
10  float64           0    0.00
11  float64           0    0.00
12  float64           0    0.00
13  float64           0    0.00
14  float64           0    0.00
15  float64           0    0.00
16  float64           0    0.00
17  float64           0    0.00
18  float64         620   77.40
19  float64           0    0.00

Head of DataFrame:
         0        1         2    3         4         5     6         7      8    9         10     11        12        13   14     15        16   17        18   19
0  3.760629      NaN  8.204553  1.0  0.987342  0.132911  16.0  0.557488  620.0  0.0  1.

## Step 3. Transform Dataset View to Embedding Tensor

Each ensemble view undergoes a sequence of transformations that convert the tabular data into high-dimensional embeddings for the neural network.

### 1. Input Reshaping & Flattening
The raw features are restructured into the format expected by the Transformer backbone.
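* **Initial Reshaping:** Convert to `(Sequence, Batch, Features)`, e.g. `(801, 1, 20)`. The code below builds the equivalent batch-first tensor `(1, 801, 20)`; both layouts lead to the same flattened result.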
* **Padding:** If features aren't divisible by the group size (2), zero-padding columns are added.
* **Grouping & Flattening:** Features are grouped, and the Batch dimension is merged with Groups to form an "Effective Batch."
    * *Transformation:* `(Seq, Batch, Feat) -> (Seq, Batch, Groups, GrpSize) -> (Seq, FlatBatch, GrpSize)`
    * *Example:* `(801, 1, 20) -> (801, 1, 10, 2) -> (801, 10, 2)`.
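The padding, grouping, and flattening arithmetic can be checked with plain NumPy reshapes. This sketch uses random data and assumes 19 input features (rather than 20) so that the padding branch actually runs; the group size of 2 comes from the text above.

```python
import numpy as np

seq_len, batch, n_feat, group_size = 801, 1, 19, 2  # 19 features so padding is needed
x = np.random.default_rng(0).normal(size=(seq_len, batch, n_feat))

# Padding: append zero columns until the feature count is divisible by group_size
pad = (-n_feat) % group_size
x = np.concatenate([x, np.zeros((seq_len, batch, pad))], axis=-1)

# Grouping: (Seq, Batch, Feat) -> (Seq, Batch, Groups, GrpSize);
# adjacent features 2g and 2g+1 end up in group g
n_groups = x.shape[-1] // group_size
x_grouped = x.reshape(seq_len, batch, n_groups, group_size)

# Flattening: merge Batch and Groups into one "effective batch"
x_flat = x_grouped.reshape(seq_len, batch * n_groups, group_size)
print(x.shape, x_grouped.shape, x_flat.shape)  # (801, 1, 20) (801, 1, 10, 2) (801, 10, 2)
```

A plain `reshape` works here because Batch and Groups are already adjacent axes, which is the same merge `einops.rearrange(..., "s b f n -> s (b f) n")` performs.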

### 2. Encoder Pipeline (Steps 3.1 - 3.6)
The encoder processes these feature groups through a sequence of steps:

* **Step 3.1: Remove Constant Features** (`RemoveEmptyFeaturesEncoderStep`)
    Drops features with zero variance across the sequence.
* **Step 3.2: NaN Handling** (`NanHandlingEncoderStep`)
    Imputes missing values with the feature mean and adds **NaN indicator columns** that mark where values were missing.
    * *Dimension Change:* Group size doubles from **2** to **4** (2 values + 2 indicators).
* **Step 3.3: Variable Feature Adjustment** (`VariableNumFeaturesEncoderStep`)
    Aligns data structure with model expectations for variable feature counts.
* **Step 3.4: Normalization** (`InputNormalizationEncoderStep`)
    Standardizes features (mean 0, variance 1) and clips extreme outliers.
* **Step 3.5: Secondary Adjustment** (`VariableNumFeaturesEncoderStep`)
    Ensures consistency after normalization.
* **Step 3.6: Linear Embedding** (`LinearInputEncoderStep`)
    Projects the processed groups into the model's hidden dimension.
    * *Transformation:* Linear layer maps size **4** to **192**.
    * *Final Shape:* `(801, 10, 192)`.

The result is a `(801, 10, 192)` tensor ready for the Transformer.
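Steps 3.2, 3.4, and 3.6 can be sketched in NumPy to show how the shapes evolve. This is an illustration under stated assumptions, not TabPFN's encoder: the indicators here are plain 0/1 flags (TabPFN's own indicator values may differ), the clip bound of 100 is arbitrary, the projection weights are random rather than learned, and steps 3.1, 3.3, and 3.5 are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, n_groups, group_size, d_model = 801, 10, 2, 192

# Grouped input as produced by the flattening step, with ~20% of entries missing
x = rng.normal(size=(seq_len, n_groups, group_size))
x[rng.random(x.shape) < 0.2] = np.nan

# Step 3.2 (sketch): 0/1 NaN indicators plus mean imputation along the sequence axis
nan_ind = np.isnan(x).astype(float)
x = np.where(np.isnan(x), np.nanmean(x, axis=0, keepdims=True), x)

# Step 3.4 (sketch): standardize each value channel over the sequence, then clip outliers
x = (x - x.mean(axis=0, keepdims=True)) / (x.std(axis=0, keepdims=True) + 1e-8)
x = np.clip(x, -100.0, 100.0)  # clip bound is an arbitrary assumption

# Values and indicators side by side: group width 2 -> 4
x = np.concatenate([x, nan_ind], axis=-1)

# Step 3.6 (sketch): a linear map from 4 inputs to d_model, with random (untrained) weights
W = rng.normal(scale=0.1, size=(2 * group_size, d_model))
b = np.zeros(d_model)
emb = x @ W + b
print(emb.shape)  # (801, 10, 192)
```

The matrix product acts on the last axis only, so each of the 10 feature groups is embedded independently with shared weights, which is why the group axis survives into the `(801, 10, 192)` output.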

In [43]:
# @title Step 3.0: Convert dataset to flat tensor

import torch
import torch.nn.functional as F
import einops

# Helper function to reconstruct the 2D view (Seq x Features)
def view_tensor(tensor, name):
    reconstructed = einops.rearrange(tensor, "s f n -> s (f n)")

    # Extract first 5 rows (Passengers) directly from the Sequence dim
    # e.g. shape (801, 20) -> slice -> (5, 20)
    df_data = reconstructed[:5, :].detach().numpy()

    df = pd.DataFrame(df_data)

    print(f"\n{name}")
    print(f"   Shape: {tensor.shape}")
    print(f"   View:  (Showing 5 rows x {df.shape[1]} cols)")
    print(df.to_string(index=False, float_format="%.4f"))

def view_embedding(tensor, name):
    # Shape: (Seq=801, Groups=10, Dim=192)

    # We flatten Groups and Dims to see the full "Row Vector"
    # s f d -> s (f d)
    # 801 x 10 x 192 -> 801 x 1920
    flattened = einops.rearrange(tensor, "s f d -> s (f d)")

    print(f"\n{name}")
    print(f"   Shape: {tensor.shape} (Seq, Groups, Dim)")
    print(f"   Flattened View: {flattened.shape} (Seq, Total_Embedding_Size)")

    # Show first 5 passengers, first 8 dimensions
    df = pd.DataFrame(flattened[:5, :].detach().numpy())
    df.columns = [f"Dim_{i}" for i in range(df.shape[1])]

    print("   Preview (First 8 dims of Group 0):")
    print(df.iloc[:, :8].to_string(index=False, float_format="%.4f"))
    print(f"   ... (+ {df.shape[1]-8} more columns) ...")


X_tensor = torch.as_tensor(X_view, dtype=torch.float32).unsqueeze(0)
model = clf.models_[0]
encoder_steps = model.encoder
n_features_per_group = model.features_per_group

print(f"Model requirement: Groups of {n_features_per_group} features")
print(f"\nInitial X_view shape: {X_view.shape} (Seq, Feat)")
print(f"X_tensor shape after unsqueeze(0): {X_tensor.shape} (Batch, Seq, Feat)")

# 1. PAD (feature count must be divisible by the group size)
print("\n--- Padding Step ---")
remainder = X_tensor.shape[-1] % n_features_per_group
if remainder > 0:
    padding = n_features_per_group - remainder
    print(f"⚠️ Padding with {padding} zero-column(s) to make features divisible by {n_features_per_group}...")
    X_padded = F.pad(X_tensor, (0, padding))
else:
    print(f"No Padding Needed, features already divisible by {n_features_per_group}")
    X_padded = X_tensor
print(f"X_padded shape: {X_padded.shape} (Batch, Seq, PaddedFeat)")

# ==========================================
# 2. FLATTEN
# ==========================================
print("\n--- Flattening Step ---")
# Current Shape: (Batch, Sequence, PaddedFeatures) -> (1, 801, 20)

# A. Group features: Rearrange into (Batch, Sequence, NumGroups, GroupSize)
# Example: (1, 801, 20) -> (1, 801, 10, 2)
X_grouped = einops.rearrange(
    X_padded,
    "b s (f n) -> b s f n",
    n=n_features_per_group
)
print(f"X_grouped shape (Batch, Seq, NumGroups, GroupSize): {X_grouped.shape}")

# B. Flatten Batch and Groups: Rearrange into (Sequence, FlatBatch, GroupSize)
# Here, 'FlatBatch' = Batch * NumGroups
# Example: (1, 801, 10, 2) -> (801, 1 * 10, 2) = (801, 10, 2)
X_flat = einops.rearrange(X_grouped, "b s f n -> s (b f) n")

current_inputs = {"main": X_flat}
single_eval_pos = 0 # Predict using the whole sequence statistics
original_batch = 1  # We now effectively have 1 batch

print(f"X_flat shape (Sequence, FlatBatch (Batch*NumGroups), GroupSize): {X_flat.shape}")
print(f"   (Sequence={X_flat.shape[0]}, FlatBatch (1*{X_grouped.shape[2]})={X_flat.shape[1]}, GroupSize={X_flat.shape[2]})")


Model requirement: Groups of 2 features

Initial X_view shape: (801, 20) (Seq, Feat)
X_tensor shape after unsqueeze(0): torch.Size([1, 801, 20]) (Batch, Seq, Feat)

--- Padding Step ---
No Padding Needed, features already divisible by 2
X_padded shape: torch.Size([1, 801, 20]) (Batch, Seq, PaddedFeat)

--- Flattening Step ---
X_grouped shape (Batch, Seq, NumGroups, GroupSize): torch.Size([1, 801, 10, 2])
X_flat shape (Sequence, FlatBatch (Batch*NumGroups), GroupSize): torch.Size([801, 10, 2])
   (Sequence=801, FlatBatch (1*10)=10, GroupSize=2)
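The two `einops` patterns above are a reshape and a transpose. The NumPy sketch below reproduces them on a toy array as a dependency-light check. It uses 5 rows instead of 801, and its values are `0..99` so that positions are easy to trace. The only assumption is NumPy's C-order reshape, which matches how `einops` splits `(f n)` with `f` as the outer factor.

```python
import numpy as np

batch, seq, n_groups, group_size = 1, 5, 10, 2
x = np.arange(batch * seq * n_groups * group_size, dtype=np.float32).reshape(
    batch, seq, n_groups * group_size
)  # (b, s, f*n)

# "b s (f n) -> b s f n": split the last axis; f is the outer (slower) factor.
grouped = x.reshape(batch, seq, n_groups, group_size)

# "b s f n -> s (b f) n": move seq to the front, then merge batch and groups.
flat = grouped.transpose(1, 0, 2, 3).reshape(seq, batch * n_groups, group_size)

print(flat.shape)  # (5, 10, 2)
# Group g of row s holds original columns 2*g and 2*g + 1 of that row.
assert np.array_equal(flat[0, 3], x[0, 0, 6:8])
```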


In [None]:
# @title Steps 3.1-3.6: Encoder Pipeline Walkthrough
# Initial input before encoder steps
view_tensor(X_flat, "RAW INPUT: X_flat")

# --- Step 3.1: RemoveEmptyFeaturesEncoderStep (Remove Constant Features) ---
step_empty = model.encoder[0]
print(f"\n--- Step 3.1: {step_empty.__class__.__name__} ---")
print(f"  Description: {step_empty.__doc__.strip().split('.')[0].strip() if step_empty.__doc__ else 'No description available.'}")
out_dict_step_3_1 = step_empty({"main": X_flat}, single_eval_pos=single_eval_pos)
view_tensor(out_dict_step_3_1["main"], "STEP 3.1 Output (Remove Empty Features)")

# --- Step 3.2: NanHandlingEncoderStep (NaN Handling) ---
step_nan = model.encoder[1]
print(f"\n--- Step 3.2: {step_nan.__class__.__name__} ---")
print(f"  Description: {step_nan.__doc__.strip().split('.')[0].strip() if step_nan.__doc__ else 'No description available.'}")
print(f"  - `keep_nans`: {step_nan.keep_nans} (If True, adds separate indicator columns for NaNs/Infs)")
out_dict_step_3_2 = step_nan(out_dict_step_3_1, single_eval_pos=single_eval_pos)
view_tensor(out_dict_step_3_2["main"], "STEP 3.2 Output (NaN Handling)")

# --- Step 3.3: VariableNumFeaturesEncoderStep (Variable Feature Adjustment) ---
step_var1 = model.encoder[2]
print(f"\n--- Step 3.3: {step_var1.__class__.__name__} ---")
print(f"  Description: {step_var1.__doc__.strip().split('.')[0].strip() if step_var1.__doc__ else 'No description available.'}")
print("  - This step adapts the feature groups to the fixed number of features the model expects.")
out_dict_step_3_3 = step_var1(out_dict_step_3_2, single_eval_pos=single_eval_pos)

# --- Step 3.4: InputNormalizationEncoderStep (Normalization) ---
step_norm = model.encoder[3]
print(f"\n--- Step 3.4: {step_norm.__class__.__name__} ---")
print(f"  Description: {step_norm.__doc__.strip().split('.')[0].strip() if step_norm.__doc__ else 'No description available.'}")
print(f"  - This step normalizes features to a standard scale (e.g., mean 0, variance 1).")

# Attribute names used by InputNormalizationEncoderStep in TabPFN v2
print(f"  - `normalize_x`: {step_norm.normalize_x} (If True, performs standard mean/std normalization.)")
print(f"  - `remove_outliers`: {step_norm.remove_outliers} (If True, outliers are clipped based on sigma.)")
if step_norm.remove_outliers:
    print(f"  - `remove_outliers_sigma`: {step_norm.remove_outliers_sigma}")

# Run the step first so the statistics printed below are fitted on this input
out_dict_step_3_4 = step_norm(out_dict_step_3_3, single_eval_pos=single_eval_pos)

# Normalization statistics are computed from the data, not learned by training
mean_for_norm = getattr(step_norm, "mean_for_normalization", None)
if mean_for_norm is not None:
    # Detach and flatten for clean printing
    mean_val = mean_for_norm.detach().cpu().numpy().flatten()
    print(f"  - `mean` (first 5 values): {mean_val[:5]} (Per-feature means computed from the data.)")

std_for_norm = getattr(step_norm, "std_for_normalization", None)
if std_for_norm is not None:
    std_val = std_for_norm.detach().cpu().numpy().flatten()
    print(f"  - `std` (first 5 values): {std_val[:5]} (Per-feature standard deviations computed from the data.)")

view_tensor(out_dict_step_3_4["main"], "STEP 3.4 Output (Normalization)")

# --- Step 3.5: VariableNumFeaturesEncoderStep (Secondary Adjustment) ---
step_var2 = model.encoder[4]
print(f"\n--- Step 3.5: {step_var2.__class__.__name__} ---")
print(f"  Description: {step_var2.__doc__.strip().split('.')[0].strip() if step_var2.__doc__ else 'No description available.'}")
print("  - This step re-applies the feature-count adjustment to the normalized values.")
out_dict_step_3_5 = step_var2(out_dict_step_3_4, single_eval_pos=single_eval_pos)

# --- Step 3.6: LinearInputEncoderStep (Linear Embedding) ---
step_embedding = model.encoder[5]
print(f"\n--- Step 3.6: {step_embedding.__class__.__name__} ---")
print(f"  Description: {step_embedding.__doc__.strip().split('.')[0].strip() if step_embedding.__doc__ else 'No description available.'}")
print(f"  - This step applies a linear transformation to map the processed feature groups into a higher-dimensional embedding space.")

# The wrapped nn.Linear module is accessed via `.layer`
print(f"  - `in_features`: {step_embedding.layer.in_features} (Input dimension: Size of feature group + NaN indicators [2+2])")
print(f"  - `out_features`: {step_embedding.layer.out_features} (Output dimension: Dimensionality of the learned embedding.)")
print(f"  - `weight` shape: {step_embedding.layer.weight.shape} (Learned weights matrix.)")

if step_embedding.layer.bias is not None:
    print(f"  - `bias` shape: {step_embedding.layer.bias.shape} (Learned bias vector.)")

out_dict_step_3_6 = step_embedding(out_dict_step_3_5, single_eval_pos=single_eval_pos)
view_embedding(out_dict_step_3_6["output"], "STEP 3.6 Output (Linear Embedding)")


RAW INPUT: X_flat
   Shape: torch.Size([801, 10, 2])
   View:  (Showing 5 rows x 20 cols)
     0       1      2      3      4      5       6      7        8      9       10       11     12      13     14       15      16     17     18     19
 3.7606     NaN 8.2046 1.0000 0.9873 0.1329 16.0000 0.5575 620.0000 0.0000  1.1015 136.0000 0.8038 -0.2178 2.0000 227.0000 -0.7573 1.0000    NaN 3.0000
 1.4647 57.9792 3.7324 0.0000 0.8228 0.1329 16.0000 0.8695   7.0000 2.0000 -2.5159  16.0000 0.0000  0.9315 0.0000 325.0000  0.9160 0.0000 0.5327 1.0000
-1.0227     NaN 8.0700 0.0000 0.0000    NaN     NaN 0.8100 433.0000 0.0000  0.9664 136.0000 0.0000  1.0247 1.0000 642.0000  0.7548 1.0000    NaN 0.0000
-0.0930     NaN 7.0441 0.0000 0.0000 0.1519 17.0000 0.7892 248.0000 0.0000  1.6840 136.0000 0.0000 -1.3137 2.0000 164.0000  1.2673 1.0000    NaN 0.0000
-0.3417     NaN 7.1577 0.0000 0.0000 0.2722 21.0000 0.6737 239.0000 0.0000  1.4180 136.0000 0.0000 -1.5731 2.0000  75.0000  0.9931 1.0000    NaN 0.00
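The raw input above still contains NaNs, for example in columns 1 and 18 of the first row. Step 3.2 is the step that handles them. The NumPy sketch below shows mean imputation plus one indicator per feature, which doubles a group of 2 into 4 values. It is a simplification built on a hypothetical helper, `impute_with_indicators`. The helper averages over all rows and uses 0/1 flags. The library step may differ in three ways: it may compute means on training rows only, it may use other indicator values (for example, to tell NaN apart from ±inf), and it may keep the indicators as a separate input that is only concatenated before the linear layer.

```python
import numpy as np


def impute_with_indicators(x: np.ndarray) -> np.ndarray:
    """Mean-impute NaNs per feature and append a 0/1 NaN flag per feature.

    x: (seq, groups, group_size) -> (seq, groups, 2 * group_size).
    Hypothetical helper for illustration, not the tabpfn implementation.
    """
    nan_mask = np.isnan(x)
    # Per-feature mean over the sequence axis, ignoring NaNs.
    counts = (~nan_mask).sum(axis=0, keepdims=True)
    sums = np.where(nan_mask, 0.0, x).sum(axis=0, keepdims=True)
    # Features that are entirely NaN fall back to a mean of 0.
    means = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
    filled = np.where(nan_mask, means, x)
    return np.concatenate([filled, nan_mask.astype(x.dtype)], axis=-1)


# Toy input: 3 rows, 1 group of 2 features, one NaN in each feature.
x = np.array([
    [[1.0, np.nan]],
    [[3.0, 4.0]],
    [[np.nan, 6.0]],
])
out = impute_with_indicators(x)
print(out.shape)  # (3, 1, 4)
print(out[:, 0, :])
```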

In [None]:
print(f"{X_embedded.shape=}")

X_embedded.shape=torch.Size([801, 10, 192])
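The `[801, 10, 192]` shape comes from a single linear layer applied to the last axis. Each group has 4 inputs: 2 feature values plus 2 NaN indicators, matching the `in_features` printed earlier. The layer maps these 4 inputs to 192 dimensions. Because the 10 groups were folded into the batch axis, every group shares the same weights. The NumPy sketch below uses random matrices as hypothetical stand-ins for the trained `step_embedding.layer.weight` and `.bias`. It follows the `torch.nn.Linear` convention `y = x @ W.T + b`.

```python
import numpy as np

rng = np.random.default_rng(0)
seq, n_groups, in_dim, emb_dim = 801, 10, 4, 192

# Toy input with the same shape as the encoder's linear-layer input.
x = rng.standard_normal((seq, n_groups, in_dim))
# Hypothetical stand-ins for the trained weight (out_features x in_features) and bias.
weight = rng.standard_normal((emb_dim, in_dim))
bias = rng.standard_normal(emb_dim)

# torch.nn.Linear convention: y = x @ W.T + b, applied along the last axis.
embedded = x @ weight.T + bias
print(embedded.shape)  # (801, 10, 192)

# All groups share one weight matrix: projecting group 3 on its own matches.
assert np.allclose(embedded[:, 3], x[:, 3] @ weight.T + bias)
```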


In [44]:
# @title Run the full encoder
# Calling model.encoder directly chains Steps 3.1-3.6 in one pass;
# it returns the embedded torch tensor rather than a numpy array

view_tensor(X_flat, "RAW INPUT: X_flat")
X_embedded = model.encoder({"main": X_flat}, single_eval_pos=0)
view_embedding(X_embedded, "STEP 3.6: LINEAR EMBEDDING")


RAW INPUT: X_flat
   Shape: torch.Size([801, 10, 2])
   View:  (Showing 5 rows x 20 cols)
     0       1      2      3      4      5       6      7        8      9       10       11     12      13     14       15      16     17     18     19
 3.7606     NaN 8.2046 1.0000 0.9873 0.1329 16.0000 0.5575 620.0000 0.0000  1.1015 136.0000 0.8038 -0.2178 2.0000 227.0000 -0.7573 1.0000    NaN 3.0000
 1.4647 57.9792 3.7324 0.0000 0.8228 0.1329 16.0000 0.8695   7.0000 2.0000 -2.5159  16.0000 0.0000  0.9315 0.0000 325.0000  0.9160 0.0000 0.5327 1.0000
-1.0227     NaN 8.0700 0.0000 0.0000    NaN     NaN 0.8100 433.0000 0.0000  0.9664 136.0000 0.0000  1.0247 1.0000 642.0000  0.7548 1.0000    NaN 0.0000
-0.0930     NaN 7.0441 0.0000 0.0000 0.1519 17.0000 0.7892 248.0000 0.0000  1.6840 136.0000 0.0000 -1.3137 2.0000 164.0000  1.2673 1.0000    NaN 0.0000
-0.3417     NaN 7.1577 0.0000 0.0000 0.2722 21.0000 0.6737 239.0000 0.0000  1.4180 136.0000 0.0000 -1.5731 2.0000  75.0000  0.9931 1.0000    NaN 0.00
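Raw magnitudes in the printout above range from about -2.5 to 642. Normalization is therefore the encoder step that most changes the scale of what reaches the linear layer. The sketch below z-scores each feature and then clips the result at ±sigma, using a hypothetical helper, `standardize_and_clip`. The value `sigma=4.0` is an arbitrary illustrative choice, not the value stored in `step_norm.remove_outliers_sigma`. The library's outlier handling may also differ in when it runs relative to scaling and in whether it clips hard or softly.

```python
import numpy as np


def standardize_and_clip(x: np.ndarray, sigma: float = 4.0, eps: float = 1e-8) -> np.ndarray:
    """Z-score each feature over the sequence axis, then clip to [-sigma, sigma].

    Hypothetical helper for illustration: hard clipping after standardization.
    """
    mean = np.nanmean(x, axis=0, keepdims=True)
    std = np.nanstd(x, axis=0, keepdims=True)
    z = (x - mean) / (std + eps)
    return np.clip(z, -sigma, sigma)


# Column 8 of the first five raw rows printed above.
col = np.array([620.0, 7.0, 433.0, 248.0, 239.0])
z = standardize_and_clip(col)
print(np.round(z, 3))

# A single extreme value among many zeros gets clipped to sigma.
spiky = np.zeros(100)
spiky[0] = 100.0
print(standardize_and_clip(spiky).max())  # 4.0
```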

In [76]:
# @title Show Steps for a single row
curr_i = 0  # index of the row to trace through every stage

row_raw = X_raw.iloc[[curr_i]].copy()
print("=== STEP 0: RAW DATA ===")
print(row_raw.T)
print("\n" + "="*50 + "\n")

ensemble_configs, X_processed, y_processed = clf._initialize_dataset_preprocessing(X_raw, y_raw, rng=42)
row_step1 = X_processed[curr_i]

print(f"=== STEP 1: GLOBAL PREPROCESSING OUTPUT ===")
print(f"Shape: {X_processed.shape}")
print(f"Data Type: {X_processed.dtype}")
print(f"Content (Float Matrix):\n{np.round(row_step1, 3)}")
print("\n" + "="*50 + "\n")


input_cat_cols = clf.preprocessor_.transformers_[0][2]
cat_ix = list(range(len(input_cat_cols)))
view_num = 0
curr_config = ensemble_configs[view_num]
curr_pipeline = curr_config.to_pipeline(random_state=42)
res = curr_pipeline.fit_transform(X_processed, cat_ix)
X_view = res.X
row_step2 = X_view[curr_i]

print(f"=== STEP 2: ENSEMBLE VIEW #0 OUTPUT ===")
print(f"Shape: {X_view.shape} (Notice extra fingerprint and transformation columns are added)")
print(f"Content (Randomized & Scaled):\n{np.round(row_step2, 3)}")
print("\n" + "="*50 + "\n")

row_step3_start = X_flat[curr_i]
print("=== STEP 3: LATENT EMBEDDING (Thinking Token) ===")
print("\n--- Start: from vector to tensor ---")
print(f"Shape: {X_flat.shape}")
print(f"Content (Tensor):\n{np.round(row_step3_start.detach().numpy(), 3)}")


X_embedded = model.encoder({"main": X_flat}, single_eval_pos=0)
row_step3_end = X_embedded[curr_i]

print("\n--- End: output of the linear projection ---")
print(f"Shape: {X_embedded.shape}")
print(f"Content (First 8 Embedding Dims):\n{np.round(row_step3_end[:,:8].detach().numpy(), 3)} ...")
print("\n" + "="*50 + "\n")


=== STEP 0: RAW DATA ===
                              86
Pclass                         3
Sex                         male
Age                         16.0
SibSp                          1
Parch                          3
Fare                      34.375
Embarked                       S
Ticket                W./C. 6608
Name      Ford, Mr. William Neal
Cabin                        NaN


=== STEP 1: GLOBAL PREPROCESSING OUTPUT ===
Shape: (801, 10)
Data Type: float64
Content (Float Matrix):
[  2.   1.   2. 620. 227. 136.  16.   1.   3.  nan]


=== STEP 2: ENSEMBLE VIEW #0 OUTPUT ===
Shape: (801, 20) (Notice extra fingerprint and transformation columns are added)
Content (Randomized & Scaled):
[  3.761     nan   8.205   1.      0.987   0.133  16.      0.557 620.
   0.      1.102 136.      0.804  -0.218   2.    227.     -0.757   1.
     nan   3.   ]


=== STEP 3: LATENT EMBEDDING (Thinking Token) ===

--- Start: from vector to tensor ---
Shape: torch.Size([801, 10, 2])
Content (Tensor):
te

Consider using a GPU or the tabpfn-client API: https://github.com/PriorLabs/tabpfn-client
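The STEP 3 tensor printout above is cut off. Because the batch size is 1 and no padding was needed, row 0 of `X_flat` should simply be row 0 of `X_view` split into 10 consecutive pairs. The NumPy check below uses the rounded values printed for STEP 2. It assumes no access to the live `X_flat`, so it confirms only the grouping layout, not the exact float values.

```python
import numpy as np

# Row 0 of the ensemble view (STEP 2 output above), as printed with 3 decimals.
row_view = np.array([
    3.761, np.nan, 8.205, 1.0, 0.987, 0.133, 16.0, 0.557, 620.0, 0.0,
    1.102, 136.0, 0.804, -0.218, 2.0, 227.0, -0.757, 1.0, np.nan, 3.0,
])

# With batch size 1 and no padding, a row of X_flat is the view row split into
# consecutive pairs: group g holds columns 2*g and 2*g + 1.
groups = row_view.reshape(-1, 2)
print(groups.shape)          # (10, 2)
print(groups[4].tolist())    # [620.0, 0.0]
```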
