# Supply Chain Risk Classification with TabPFN

This notebook demonstrates how to use **TabPFN** for classification tasks in retail/CPG supply chain planning.

TabPFN is a foundation model for tabular data. On small-to-medium datasets it has been shown to match or outperform traditional methods such as tuned gradient-boosted trees in a fraction of the time, because it needs no hyperparameter tuning and works out of the box.

**Use Cases Covered:**
1. **Supplier Delay Risk Prediction** (Binary Classification) - Predict which supplier deliveries will be delayed
2. **Material Shortage Prediction** (Multi-class Classification) - Predict material shortage risk levels
3. **Labor Shortage Prediction** (Multi-class Classification) - Predict workforce availability issues
4. **OTIF Risk Prediction** (Multi-class Classification) - Predict on-time-in-full delivery risk

**Business Value:**
- Enable proactive supply risk mitigation
- Optimize safety stock and expediting decisions
- Improve workforce planning and scheduling
- Improve on-time delivery performance
- Enhance customer service through proactive OTIF management

**Prerequisites:** Run `00_data_preparation` notebook first to set up the datasets.

**References:**
- [TabPFN Client GitHub](https://github.com/PriorLabs/tabpfn-client)
- [Prior Labs Documentation](https://docs.priorlabs.ai/)

## Compute Setup

We recommend running this notebook on **Serverless Compute** with the **Base Environment V4**.

To configure:
1. Click on the compute selector in the notebook toolbar
2. Select **Serverless**
3. Under Environment, choose **Base Environment V4**

Serverless compute provides fast startup times and automatic scaling, ideal for interactive notebook workflows.

## 1. Installation

In [None]:
%pip install tabpfn-client scikit-learn pandas matplotlib mlflow --quiet

In [None]:
dbutils.library.restartPython()

## 2. Authentication

The TabPFN client authenticates with an access token.

**Setting up Databricks Secrets (one-time setup):**

1. Create a secret scope using the Databricks CLI:
   ```bash
   databricks secrets create-scope tabpfn-client
   ```

2. Store your TabPFN token in the secret scope:
   ```bash
   databricks secrets put-secret tabpfn-client token
   ```

3. To obtain the token for step 2, run the following on a machine where you have already authenticated with `tabpfn_client`:
   ```python
   import tabpfn_client
   token = tabpfn_client.get_access_token()
   print(token)
   ```

In [None]:
import tabpfn_client

token = dbutils.secrets.get(scope="tabpfn-client", key="token")
tabpfn_client.set_access_token(token)

## 3. Configuration

Configure the catalog and schema where the prepared datasets are stored.

In [None]:
# Configure catalog and schema (must match 00_data_preparation)
CATALOG = "tabpfn_databricks"
SCHEMA = "default"

# MLflow experiment shared across all TabPFN notebooks.
# Defaults to a folder in your user workspace; change the path to customize it.
current_user = spark.sql("SELECT current_user()").collect()[0][0]
MLFLOW_EXPERIMENT_NAME = f"/Users/{current_user}/tabpfn-databricks"

spark.sql(f"USE CATALOG {CATALOG}")
spark.sql(f"USE SCHEMA {SCHEMA}")

## 4. Import Libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
import mlflow

from tabpfn_client import TabPFNClassifier

# Set MLflow experiment
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
print(f"MLflow experiment set to: {MLFLOW_EXPERIMENT_NAME}")

---
# Part 1: Supply Planning
---

## 5. Supplier Delay Risk Prediction (Binary Classification)

**Business Context:** Supply planners need to identify which incoming supplier deliveries are at risk of delay so they can:
- Expedite high-risk orders
- Adjust production schedules
- Communicate proactively with stakeholders

We'll use TabPFN to predict whether a supplier delivery will be delayed based on:
- Supplier characteristics (tier, country, reliability history)
- Order characteristics (quantity, value, lead time)
- External factors (port congestion, weather risk, peak season)

In [None]:
# Load the Supplier Delay Risk training dataset from Delta table
df_delay = spark.table("supplier_delay_risk_train").toPandas()

print(f"Dataset shape: {df_delay.shape}")
print(f"\nFeatures:")
print([col for col in df_delay.columns if col != 'is_delayed'])
print(f"\nTarget distribution (is_delayed):")
print(f"  On-time (0): {(df_delay['is_delayed'] == 0).sum()}")
print(f"  Delayed (1): {(df_delay['is_delayed'] == 1).sum()}")
print(f"  Delay rate: {df_delay['is_delayed'].mean():.1%}")

In [None]:
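# Prepare features: one-hot encode categorical columns.
# dtype=float keeps the dummy columns numeric; with pandas' default bool dummies,
# .values on the mixed frame would return an object-dtype array.
df_encoded = pd.get_dummies(df_delay, columns=['supplier_tier', 'supplier_country'], drop_first=True, dtype=float)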

# Separate features and target
feature_cols = [col for col in df_encoded.columns if col != 'is_delayed']
X = df_encoded[feature_cols].values
y = df_delay['is_delayed'].values

print(f"Feature matrix shape: {X.shape}")
print(f"Number of encoded features: {len(feature_cols)}")
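One-hot encoding with `drop_first=True` drops the alphabetically first category of each column (it becomes the baseline), and the set of dummy columns depends on which categories appear in the data. The minimal sketch below uses made-up toy values to show both effects. It also shows one common way to align a future scoring batch to the training columns with `DataFrame.reindex`. The helper `encode_like_training` is hypothetical and not part of the notebook.

```python
import pandas as pd

# Toy training data (illustrative values, not from the notebook's tables)
train = pd.DataFrame({
    "supplier_tier": ["Tier1", "Tier2", "Tier3", "Tier1"],
    "order_quantity": [100, 250, 80, 120],
})

# drop_first=True drops the first sorted category (Tier1), which becomes the baseline
train_encoded = pd.get_dummies(train, columns=["supplier_tier"], drop_first=True, dtype=float)
train_cols = list(train_encoded.columns)
print(train_cols)  # ['order_quantity', 'supplier_tier_Tier2', 'supplier_tier_Tier3']


def encode_like_training(df, categorical_cols, training_columns):
    """Hypothetical helper: encode a new batch and align it to the training columns.

    Dummies for categories unseen in training are dropped; training dummies that
    are missing from the batch are added as all-zero columns, in training order.
    """
    # No drop_first here: if the batch lacked the baseline category, drop_first
    # would silently drop a different column. reindex handles the alignment instead.
    encoded = pd.get_dummies(df, columns=categorical_cols, dtype=float)
    return encoded.reindex(columns=training_columns, fill_value=0.0)


# A scoring batch that lacks Tier3 and contains an unseen Tier4
new_batch = pd.DataFrame({"supplier_tier": ["Tier2", "Tier4"], "order_quantity": [90, 300]})
aligned = encode_like_training(new_batch, ["supplier_tier"], train_cols)
print(aligned)
```

Note the trade-off: an unseen category such as `Tier4` ends up with all dummies at zero, which the model reads as the baseline category. This may or may not be acceptable for a given use case.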

In [None]:
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Training set size: {len(X_train)}")
print(f"Test set size: {len(X_test)}")
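`train_test_split` shuffles rows before splitting, so the test set is not simply the last rows of the original DataFrame. The minimal sketch below uses synthetic data to show one way of keeping test predictions tied to their source rows. It splits a row-position array alongside `X` and `y`, so every array receives the same shuffle.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic frame: 'row_id' lets us check which source row each test example came from
rng = np.random.default_rng(0)
df = pd.DataFrame({"row_id": np.arange(20), "feature": rng.normal(size=20)})
df["target"] = np.tile([0, 1], 10)

X_demo = df[["row_id", "feature"]].to_numpy()
y_demo = df["target"].to_numpy()
positions = np.arange(len(df))

# Split the positions together with X and y: the same shuffle applies to all three
X_tr, X_te, y_tr, y_te, pos_tr, pos_te = train_test_split(
    X_demo, y_demo, positions, test_size=0.2, random_state=42, stratify=y_demo
)

# Row i of df_te corresponds to row i of X_te (and to row i of any predictions on X_te)
df_te = df.iloc[pos_te].reset_index(drop=True)

# A naive tail slice generally picks different rows
naive_tail = df.iloc[len(X_tr):].reset_index(drop=True)
print("aligned:", (df_te["row_id"].to_numpy() == X_te[:, 0]).all())
print("naive tail:", (naive_tail["row_id"].to_numpy() == X_te[:, 0]).all())
```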

In [None]:
# Initialize and train TabPFN classifier with MLflow logging
with mlflow.start_run(run_name="supplier_delay_risk_tabpfn"):
    # Log parameters
    mlflow.log_param("model_type", "TabPFNClassifier")
    mlflow.log_param("task", "supplier_delay_risk")
    mlflow.log_param("problem_type", "binary_classification")
    mlflow.log_param("test_size", 0.2)
    mlflow.log_param("n_features", X_train.shape[1])
    mlflow.log_param("train_samples", X_train.shape[0])
    mlflow.log_param("test_samples", X_test.shape[0])
    
    clf = TabPFNClassifier()
    clf.fit(X_train, y_train)

    # Make predictions
    y_pred = clf.predict(X_test)
    y_pred_proba = clf.predict_proba(X_test)

    # Evaluate performance
    accuracy = accuracy_score(y_test, y_pred)
    roc_auc = roc_auc_score(y_test, y_pred_proba[:, 1])
    
    # Log metrics
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("roc_auc", roc_auc)

    print(f"TabPFN Supplier Delay Risk Prediction Results:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  ROC AUC:  {roc_auc:.4f}")
    print(f"  MLflow Run ID: {mlflow.active_run().info.run_id}")

In [None]:
# Detailed classification report
print("Classification Report:")
print(classification_report(y_test, y_pred, target_names=['On-Time', 'Delayed']))
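`predict` typically picks the class with the highest probability, which for a binary problem amounts to a 0.5 cutoff on `predict_proba`. That cutoff may not match the planning objective. Suppose the goal is to catch at least 80% of delayed deliveries for expediting. A threshold can then be read off the precision-recall curve instead. The minimal sketch below runs on synthetic scores. The 80% target and the helper name `threshold_for_recall` are assumptions for illustration.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def threshold_for_recall(y_true, proba_pos, target_recall=0.8):
    """Hypothetical helper: strictest threshold whose recall still meets the target.

    Returns (threshold, precision, recall) at that operating point.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, proba_pos)
    # precision/recall have one extra final point (recall=0) with no threshold; drop it
    precision, recall = precision[:-1], recall[:-1]
    # thresholds increase while recall never increases, so the last index that
    # meets the target is the highest cutoff that still achieves it
    best = np.flatnonzero(recall >= target_recall)[-1]
    return thresholds[best], precision[best], recall[best]


# Synthetic data: delayed deliveries (label 1) tend to get higher scores
rng = np.random.default_rng(7)
y_demo = rng.integers(0, 2, size=500)
proba_demo = np.clip(0.3 + 0.35 * y_demo + rng.normal(0, 0.2, size=500), 0, 1)

t, p, r = threshold_for_recall(y_demo, proba_demo, target_recall=0.8)
print(f"threshold={t:.3f} precision={p:.3f} recall={r:.3f}")
flag_for_expedite = proba_demo >= t
```

In this notebook, `y_test` and `y_pred_proba[:, 1]` would take the place of the synthetic arrays. Ideally, pick the threshold on a separate validation split, because choosing it on the test set makes the reported metrics optimistic.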

In [None]:
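# Show the deliveries with the highest predicted delay probability.
# train_test_split shuffled the rows, so the test rows are NOT the tail of df_delay.
# Re-running the identical split on row positions (same test_size, random_state and
# stratify) reproduces the test indices in the same order as y_pred.
_, test_idx = train_test_split(
    np.arange(len(df_delay)), test_size=0.2, random_state=42, stratify=y
)
df_test = df_delay.iloc[test_idx].copy().reset_index(drop=True)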
df_test['delay_probability'] = y_pred_proba[:, 1]
df_test['predicted_delayed'] = y_pred

print("Top 10 Highest Risk Deliveries:")
high_risk = df_test.nlargest(10, 'delay_probability')[[
    'supplier_tier', 'supplier_country', 'contracted_lead_time_days',
    'historical_otd_rate', 'port_congestion_index', 'delay_probability', 'is_delayed'
]]
display(high_risk)

## 6. Material Shortage Prediction (Multi-class Classification)

**Business Context:** Material planners need to identify which materials are at risk of shortage to:
- Prioritize procurement actions
- Expedite critical orders
- Adjust production schedules

We'll predict shortage risk levels:
- **0 = No Risk**: Adequate inventory coverage
- **1 = At Risk**: Monitor closely, may need action
- **2 = Critical**: Immediate action required

In [None]:
# Load the Material Shortage training dataset from Delta table
df_shortage = spark.table("material_shortage_train").toPandas()

print(f"Dataset shape: {df_shortage.shape}")
print(f"\nTarget distribution (shortage_risk):")
shortage_labels = {0: 'No Risk', 1: 'At Risk', 2: 'Critical'}
for val, label in shortage_labels.items():
    count = (df_shortage['shortage_risk'] == val).sum()
    print(f"  {val} ({label}): {count} ({count/len(df_shortage):.1%})")

In [None]:
# Prepare features: one-hot encode categorical columns (dtype=float keeps X numeric)
df_shortage_encoded = pd.get_dummies(df_shortage,
                                      columns=['material_type', 'criticality_class'],
                                      drop_first=True, dtype=float)

# Separate features and target
shortage_feature_cols = [col for col in df_shortage_encoded.columns if col != 'shortage_risk']
X_shortage = df_shortage_encoded[shortage_feature_cols].values
y_shortage = df_shortage['shortage_risk'].values

print(f"Feature matrix shape: {X_shortage.shape}")

# Split the data
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
    X_shortage, y_shortage, test_size=0.3, random_state=42, stratify=y_shortage
)
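Stratifying on the target keeps each risk level's share roughly equal in the train and test sets. This matters when a class such as Critical is rare. The minimal sketch below uses made-up imbalanced labels (70% No Risk, 20% At Risk, 10% Critical) and the same `test_size=0.3` as above.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Made-up imbalanced labels: 70 No Risk, 20 At Risk, 10 Critical
y_demo = np.array([0] * 70 + [1] * 20 + [2] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

_, _, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.3, random_state=42, stratify=y_demo
)

# Share of rows per class in each split
train_share = np.bincount(y_tr, minlength=3) / len(y_tr)
test_share = np.bincount(y_te, minlength=3) / len(y_te)
print("train:", train_share.round(3))
print("test: ", test_share.round(3))
```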

In [None]:
# Train TabPFN on multi-class problem with MLflow logging
with mlflow.start_run(run_name="material_shortage_tabpfn"):
    # Log parameters
    mlflow.log_param("model_type", "TabPFNClassifier")
    mlflow.log_param("task", "material_shortage")
    mlflow.log_param("problem_type", "multiclass_classification")
    mlflow.log_param("n_classes", 3)
    mlflow.log_param("test_size", 0.3)
    mlflow.log_param("n_features", X_train_s.shape[1])
    mlflow.log_param("train_samples", X_train_s.shape[0])
    mlflow.log_param("test_samples", X_test_s.shape[0])
    
    clf_shortage = TabPFNClassifier()
    clf_shortage.fit(X_train_s, y_train_s)

    # Make predictions
    y_pred_shortage = clf_shortage.predict(X_test_s)
    y_pred_proba_shortage = clf_shortage.predict_proba(X_test_s)

    # Evaluate
    accuracy_shortage = accuracy_score(y_test_s, y_pred_shortage)
    
    # Log metrics
    mlflow.log_metric("accuracy", accuracy_shortage)
    
    print(f"Multi-class Classification Accuracy: {accuracy_shortage:.4f}")
    print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")

    print("\nClassification Report:")
    print(classification_report(y_test_s, y_pred_shortage, 
                                target_names=['No Risk', 'At Risk', 'Critical']))
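On imbalanced risk levels, accuracy alone can look good even when the Critical class is never predicted. The minimal sketch below uses made-up predictions to show two complementary scikit-learn metrics. The first is macro-averaged F1, which weights each class equally. The second is one-vs-rest ROC AUC computed from `predict_proba`. In this section, the same calls would take `y_test_s`, `y_pred_shortage`, and `y_pred_proba_shortage`.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

# Made-up labels: 8 No Risk, 1 At Risk, 1 Critical
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 2])
# Made-up class probabilities (rows sum to 1), columns ordered No Risk, At Risk, Critical
proba = np.array([
    [0.80, 0.10, 0.10], [0.70, 0.20, 0.10], [0.90, 0.05, 0.05], [0.60, 0.30, 0.10],
    [0.80, 0.15, 0.05], [0.75, 0.15, 0.10], [0.85, 0.10, 0.05], [0.70, 0.10, 0.20],
    [0.50, 0.40, 0.10], [0.45, 0.15, 0.40],
])
# Argmax decision: here No Risk has the largest probability in every row
y_pred = proba.argmax(axis=1)

acc = accuracy_score(y_true, y_pred)
macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
auc_ovr = roc_auc_score(y_true, proba, multi_class="ovr", average="macro")
print(f"accuracy={acc:.2f}  macro-F1={macro_f1:.2f}  OvR ROC AUC={auc_ovr:.2f}")
```

Here, accuracy is 0.8 and macro-F1 is low because argmax never picks At Risk or Critical. However, the probabilities still rank the rare rows perfectly (OvR ROC AUC of 1.0). This illustrates why the cells below rank materials by `critical_probability` rather than by the predicted label.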

In [None]:
# Confusion Matrix
cm = confusion_matrix(y_test_s, y_pred_shortage)

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
ax.figure.colorbar(im, ax=ax)

classes = ['No Risk', 'At Risk', 'Critical']
ax.set(xticks=np.arange(cm.shape[1]),
       yticks=np.arange(cm.shape[0]),
       xticklabels=classes, yticklabels=classes,
       title='Material Shortage Risk - Confusion Matrix',
       ylabel='Actual',
       xlabel='Predicted')

# Add text annotations
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, format(cm[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.show()
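The confusion-matrix plotting code is repeated for each use case in this notebook. The minimal sketch below is a reusable helper built on scikit-learn's `ConfusionMatrixDisplay`. It row-normalizes the matrix (`normalize="true"`), so the diagonal reads as per-class recall. Normalized values are easier to compare across imbalanced classes than raw counts. The function name `plot_risk_confusion_matrix` is hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix


def plot_risk_confusion_matrix(y_true, y_pred, class_names, title, cmap="Blues"):
    """Hypothetical helper: plot a row-normalized confusion matrix.

    Each row sums to 1, so the diagonal is per-class recall.
    Returns the normalized matrix and the figure.
    """
    labels = np.arange(len(class_names))
    cm_norm = confusion_matrix(y_true, y_pred, labels=labels, normalize="true")
    fig, ax = plt.subplots(figsize=(8, 6))
    ConfusionMatrixDisplay(cm_norm, display_labels=class_names).plot(
        ax=ax, cmap=cmap, values_format=".2f"
    )
    ax.set_title(title)
    fig.tight_layout()
    return cm_norm, fig


# Toy usage with made-up labels
y_true_demo = [0, 0, 1, 1, 2, 2]
y_pred_demo = [0, 1, 1, 1, 2, 0]
cm_norm, fig = plot_risk_confusion_matrix(
    y_true_demo, y_pred_demo, ["No Risk", "At Risk", "Critical"],
    "Toy example - row-normalized confusion matrix",
)
print(cm_norm.round(2))
plt.close(fig)  # in the notebook, call plt.show() instead to display it
```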

In [None]:
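# Identify critical materials requiring immediate attention.
# train_test_split shuffled the rows, so re-run the identical split on row positions
# (same test_size, random_state and stratify) to recover the test rows in prediction order.
_, test_idx_s = train_test_split(
    np.arange(len(df_shortage)), test_size=0.3, random_state=42, stratify=y_shortage
)
df_test_shortage = df_shortage.iloc[test_idx_s].copy().reset_index(drop=True)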
df_test_shortage['predicted_risk'] = y_pred_shortage
df_test_shortage['critical_probability'] = y_pred_proba_shortage[:, 2]  # Probability of Critical

print("Top 10 Materials with Highest Critical Risk Probability:")
critical_materials = df_test_shortage.nlargest(10, 'critical_probability')[[
    'material_type', 'criticality_class', 'current_stock_days',
    'num_active_suppliers', 'avg_supplier_reliability', 
    'critical_probability', 'shortage_risk'
]]
display(critical_materials)

---
# Part 2: Production Planning
---

## 7. Labor Shortage Prediction (Multi-class Classification)

**Business Context:** Production and HR planners need to anticipate workforce availability issues to:
- Schedule overtime or temporary staffing
- Adjust production schedules based on labor constraints
- Prioritize cross-training initiatives
- Improve hiring and retention strategies

We'll predict labor shortage risk levels:
- **0 = Adequate**: Sufficient workforce coverage
- **1 = At Risk**: Monitor closely, may need contingency plans
- **2 = Critical**: Immediate staffing action required

In [None]:
# Load the Labor Shortage training dataset from Delta table
df_labor = spark.table("labor_shortage_train").toPandas()

print(f"Dataset shape: {df_labor.shape}")
print(f"\nFeatures:")
print([col for col in df_labor.columns if col != 'labor_shortage_risk'])
print(f"\nTarget distribution (labor_shortage_risk):")
labor_labels = {0: 'Adequate', 1: 'At Risk', 2: 'Critical'}
for val, label in labor_labels.items():
    count = (df_labor['labor_shortage_risk'] == val).sum()
    print(f"  {val} ({label}): {count} ({count/len(df_labor):.1%})")

In [None]:
# Prepare features: one-hot encode categorical columns (dtype=float keeps X numeric)
df_labor_encoded = pd.get_dummies(df_labor,
                                   columns=['facility_type', 'facility_size', 'region'],
                                   drop_first=True, dtype=float)

# Separate features and target
labor_feature_cols = [col for col in df_labor_encoded.columns if col != 'labor_shortage_risk']
X_labor = df_labor_encoded[labor_feature_cols].values
y_labor = df_labor['labor_shortage_risk'].values

print(f"Feature matrix shape: {X_labor.shape}")

# Split the data
X_train_l, X_test_l, y_train_l, y_test_l = train_test_split(
    X_labor, y_labor, test_size=0.3, random_state=42, stratify=y_labor
)
print(f"Training set size: {len(X_train_l)}")
print(f"Test set size: {len(X_test_l)}")

In [None]:
# Train TabPFN on labor shortage prediction with MLflow logging
with mlflow.start_run(run_name="labor_shortage_tabpfn"):
    # Log parameters
    mlflow.log_param("model_type", "TabPFNClassifier")
    mlflow.log_param("task", "labor_shortage")
    mlflow.log_param("problem_type", "multiclass_classification")
    mlflow.log_param("n_classes", 3)
    mlflow.log_param("test_size", 0.3)
    mlflow.log_param("n_features", X_train_l.shape[1])
    mlflow.log_param("train_samples", X_train_l.shape[0])
    mlflow.log_param("test_samples", X_test_l.shape[0])
    
    clf_labor = TabPFNClassifier()
    clf_labor.fit(X_train_l, y_train_l)

    # Make predictions
    y_pred_labor = clf_labor.predict(X_test_l)
    y_pred_proba_labor = clf_labor.predict_proba(X_test_l)

    # Evaluate
    accuracy_labor = accuracy_score(y_test_l, y_pred_labor)
    
    # Log metrics
    mlflow.log_metric("accuracy", accuracy_labor)

    print(f"Multi-class Classification Accuracy: {accuracy_labor:.4f}")
    print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")

    print("\nClassification Report:")
    print(classification_report(y_test_l, y_pred_labor, 
                                target_names=['Adequate', 'At Risk', 'Critical']))

In [None]:
# Confusion Matrix for Labor Shortage
cm_labor = confusion_matrix(y_test_l, y_pred_labor)

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm_labor, interpolation='nearest', cmap='Oranges')
ax.figure.colorbar(im, ax=ax)

classes = ['Adequate', 'At Risk', 'Critical']
ax.set(xticks=np.arange(cm_labor.shape[1]),
       yticks=np.arange(cm_labor.shape[0]),
       xticklabels=classes, yticklabels=classes,
       title='Labor Shortage Risk - Confusion Matrix',
       ylabel='Actual',
       xlabel='Predicted')

thresh = cm_labor.max() / 2.
for i in range(cm_labor.shape[0]):
    for j in range(cm_labor.shape[1]):
        ax.text(j, i, format(cm_labor[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm_labor[i, j] > thresh else "black")

plt.tight_layout()
plt.show()

In [None]:
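# Identify facilities with the highest labor shortage risk.
# train_test_split shuffled the rows, so re-run the identical split on row positions
# (same test_size, random_state and stratify) to recover the test rows in prediction order.
_, test_idx_l = train_test_split(
    np.arange(len(df_labor)), test_size=0.3, random_state=42, stratify=y_labor
)
df_test_labor = df_labor.iloc[test_idx_l].copy().reset_index(drop=True)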
df_test_labor['predicted_risk'] = y_pred_labor
df_test_labor['critical_probability'] = y_pred_proba_labor[:, 2]  # Probability of Critical

print("Top 10 Facilities with Highest Critical Labor Shortage Risk:")
critical_facilities = df_test_labor.nlargest(10, 'critical_probability')[[
    'facility_type', 'region', 'current_headcount', 'headcount_ratio',
    'turnover_rate_monthly', 'open_positions', 'local_unemployment_rate',
    'critical_probability', 'labor_shortage_risk'
]]
display(critical_facilities)
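The three labor-risk levels are ordered. That makes the probability of being *at least* At Risk, `P(At Risk) + P(Critical)`, another possible ranking signal. A facility with a low chance of Critical but a high chance of At Risk may still deserve attention. The minimal sketch below uses made-up probabilities. The column name `p_at_least_at_risk` and the 0.5 escalation cutoff are assumptions for illustration.

```python
import numpy as np
import pandas as pd

# Made-up predict_proba output for 4 facilities; columns: Adequate, At Risk, Critical
proba = np.array([
    [0.70, 0.20, 0.10],
    [0.20, 0.70, 0.10],
    [0.55, 0.10, 0.35],
    [0.10, 0.30, 0.60],
])

scores = pd.DataFrame({
    "facility": ["A", "B", "C", "D"],
    "critical_probability": proba[:, 2],
    # Cumulative probability over the ordered levels: P(risk >= At Risk)
    "p_at_least_at_risk": proba[:, 1:].sum(axis=1),
})
# Hypothetical escalation rule: flag when P(risk >= At Risk) exceeds 0.5
scores["flag"] = scores["p_at_least_at_risk"] > 0.5
print(scores.sort_values("p_at_least_at_risk", ascending=False))
```

Facility B is flagged despite a Critical probability of only 0.10. Facility C has a higher Critical probability but is not flagged. Which ranking fits better depends on the cost of contingency staffing versus a missed shortage.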

In [None]:
# Analyze risk by facility type and region
df_test_labor['risk_label'] = df_test_labor['predicted_risk'].map({0: 'Adequate', 1: 'At Risk', 2: 'Critical'})

print("\nPredicted Labor Shortage Risk by Facility Type (% within each facility type):")
display((pd.crosstab(df_test_labor['facility_type'], df_test_labor['risk_label'], normalize='index') * 100).round(1))

print("\nPredicted Labor Shortage Risk by Region (% within each region):")
display((pd.crosstab(df_test_labor['region'], df_test_labor['risk_label'], normalize='index') * 100).round(1))

---
# Part 3: Distribution Planning
---

## 8. OTIF Risk Prediction (Multi-class Classification)

**Business Context:** Distribution and customer service teams need to identify orders at risk of not being delivered On-Time-In-Full (OTIF) to:
- Proactively communicate with customers about potential delays
- Prioritize orders for expedited processing
- Allocate resources to high-risk shipments
- Improve overall customer satisfaction and retention

We'll predict OTIF risk levels:
- **0 = Low Risk**: High confidence of successful OTIF delivery
- **1 = Medium Risk**: Monitor closely, may need intervention
- **2 = High Risk**: Immediate action required to prevent OTIF failure
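For reference, an order typically counts as OTIF only when it arrives by the requested date *and* with the full ordered quantity. The OTIF rate is the share of orders that meet both conditions. The minimal sketch below uses made-up order data. The column names are hypothetical and do not come from `otif_risk_train`. Exact OTIF definitions, such as delivery-window tolerances, vary by company.

```python
import pandas as pd

# Made-up orders (hypothetical columns, not the notebook's schema)
orders = pd.DataFrame({
    "order_id": [1, 2, 3, 4, 5],
    "requested_date": pd.to_datetime(["2024-03-01", "2024-03-01", "2024-03-02", "2024-03-03", "2024-03-04"]),
    "delivered_date": pd.to_datetime(["2024-03-01", "2024-03-02", "2024-03-02", "2024-03-03", "2024-03-03"]),
    "ordered_qty": [100, 50, 80, 60, 40],
    "delivered_qty": [100, 50, 75, 60, 40],
})

# On time: delivered on or before the requested date
orders["on_time"] = orders["delivered_date"] <= orders["requested_date"]
# In full: the whole ordered quantity was delivered
orders["in_full"] = orders["delivered_qty"] >= orders["ordered_qty"]
orders["otif"] = orders["on_time"] & orders["in_full"]

print(orders[["order_id", "on_time", "in_full", "otif"]])
print(f"OTIF rate: {orders['otif'].mean():.0%}")
```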

In [None]:
# Load the OTIF Risk training dataset from Delta table
df_otif = spark.table("otif_risk_train").toPandas()

print(f"Dataset shape: {df_otif.shape}")
print(f"\nFeatures:")
print([col for col in df_otif.columns if col != 'otif_risk'])
print(f"\nTarget distribution (otif_risk):")
otif_labels = {0: 'Low Risk', 1: 'Medium Risk', 2: 'High Risk'}
for val, label in otif_labels.items():
    count = (df_otif['otif_risk'] == val).sum()
    print(f"  {val} ({label}): {count} ({count/len(df_otif):.1%})")

In [None]:
# Prepare features: one-hot encode categorical columns (dtype=float keeps X numeric)
df_otif_encoded = pd.get_dummies(df_otif,
                                 columns=['order_type', 'order_size', 'customer_tier',
                                          'customer_order_frequency', 'fulfillment_source',
                                          'carrier_tier', 'order_day_of_week',
                                          'requested_delivery_window'],
                                 drop_first=True, dtype=float)

# Separate features and target
otif_feature_cols = [col for col in df_otif_encoded.columns if col != 'otif_risk']
X_otif = df_otif_encoded[otif_feature_cols].values
y_otif = df_otif['otif_risk'].values

print(f"Feature matrix shape: {X_otif.shape}")

# Split the data
X_train_o, X_test_o, y_train_o, y_test_o = train_test_split(
    X_otif, y_otif, test_size=0.3, random_state=42, stratify=y_otif
)
print(f"Training set size: {len(X_train_o)}")
print(f"Test set size: {len(X_test_o)}")

In [None]:
# Train TabPFN on OTIF risk prediction with MLflow logging
with mlflow.start_run(run_name="otif_risk_tabpfn"):
    # Log parameters
    mlflow.log_param("model_type", "TabPFNClassifier")
    mlflow.log_param("task", "otif_risk")
    mlflow.log_param("problem_type", "multiclass_classification")
    mlflow.log_param("n_classes", 3)
    mlflow.log_param("test_size", 0.3)
    mlflow.log_param("n_features", X_train_o.shape[1])
    mlflow.log_param("train_samples", X_train_o.shape[0])
    mlflow.log_param("test_samples", X_test_o.shape[0])
    
    clf_otif = TabPFNClassifier()
    clf_otif.fit(X_train_o, y_train_o)

    # Make predictions
    y_pred_otif = clf_otif.predict(X_test_o)
    y_pred_proba_otif = clf_otif.predict_proba(X_test_o)

    # Evaluate
    accuracy_otif = accuracy_score(y_test_o, y_pred_otif)
    
    # Log metrics
    mlflow.log_metric("accuracy", accuracy_otif)

    print(f"Multi-class Classification Accuracy: {accuracy_otif:.4f}")
    print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")

    print("\nClassification Report:")
    print(classification_report(y_test_o, y_pred_otif, 
                                target_names=['Low Risk', 'Medium Risk', 'High Risk']))

In [None]:
# Confusion Matrix for OTIF Risk
cm_otif = confusion_matrix(y_test_o, y_pred_otif)

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm_otif, interpolation='nearest', cmap='Greens')
ax.figure.colorbar(im, ax=ax)

classes = ['Low Risk', 'Medium Risk', 'High Risk']
ax.set(xticks=np.arange(cm_otif.shape[1]),
       yticks=np.arange(cm_otif.shape[0]),
       xticklabels=classes, yticklabels=classes,
       title='OTIF Risk Prediction - Confusion Matrix',
       ylabel='Actual',
       xlabel='Predicted')

thresh = cm_otif.max() / 2.
for i in range(cm_otif.shape[0]):
    for j in range(cm_otif.shape[1]):
        ax.text(j, i, format(cm_otif[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm_otif[i, j] > thresh else "black")

plt.tight_layout()
plt.show()

In [None]:
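# Identify high-risk orders requiring immediate attention.
# train_test_split shuffled the rows, so re-run the identical split on row positions
# (same test_size, random_state and stratify) to recover the test rows in prediction order.
_, test_idx_o = train_test_split(
    np.arange(len(df_otif)), test_size=0.3, random_state=42, stratify=y_otif
)
df_test_otif = df_otif.iloc[test_idx_o].copy().reset_index(drop=True)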
df_test_otif['predicted_risk'] = y_pred_otif
df_test_otif['high_risk_probability'] = y_pred_proba_otif[:, 2]  # Probability of High Risk

print("Top 10 Orders with Highest OTIF Failure Risk:")
high_risk_orders = df_test_otif.nlargest(10, 'high_risk_probability')[[
    'order_type', 'customer_tier', 'fulfillment_source', 'carrier_tier',
    'inventory_availability_rate', 'days_until_delivery',
    'high_risk_probability', 'otif_risk'
]]
display(high_risk_orders)

In [None]:
# Analyze OTIF risk by key factors
df_test_otif['risk_label'] = df_test_otif['predicted_risk'].map({0: 'Low Risk', 1: 'Medium Risk', 2: 'High Risk'})

print("Predicted OTIF Risk by Order Type (% within each order type):")
display((pd.crosstab(df_test_otif['order_type'], df_test_otif['risk_label'], normalize='index') * 100).round(1))

print("\nPredicted OTIF Risk by Customer Tier (% within each customer tier):")
display((pd.crosstab(df_test_otif['customer_tier'], df_test_otif['risk_label'], normalize='index') * 100).round(1))

## Summary

In this notebook, we demonstrated:

- **Binary Classification**: Supplier delay risk prediction with ROC AUC analysis
- **Multi-class Classification**: Material shortage risk levels (No Risk, At Risk, Critical)
- **Multi-class Classification**: Labor shortage prediction for workforce planning
- **Multi-class Classification**: OTIF risk prediction for distribution planning

**Key Takeaways:**
1. TabPFN provides strong performance without hyperparameter tuning
2. Probability outputs enable risk-based decision making
3. OTIF prediction enables proactive customer service management

**Next Steps:**
- Run `02_regression` notebook for price elasticity, promotion lift, and lead time prediction
- Explore threshold optimization for different business objectives
- Integrate predictions into supply planning, HR, and distribution workflows
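A simple starting point for threshold optimization is a cost-based rule. Suppose a missed delay (false negative) costs `C_FN` and an unnecessary expedite (false positive) costs `C_FP`. With well-calibrated probabilities, flagging a case with probability `p` has expected cost `(1 - p) * C_FP`, and not flagging it has expected cost `p * C_FN`. Flagging therefore pays off when `p > C_FP / (C_FP + C_FN)`. The minimal sketch below compares this closed-form threshold with an empirical sweep on synthetic, calibrated data. The cost values are made up.

```python
import numpy as np


def expected_cost(y_true, proba_pos, threshold, cost_fp, cost_fn):
    """Total cost of flagging every case with proba_pos >= threshold."""
    flagged = proba_pos >= threshold
    false_pos = np.sum(flagged & (y_true == 0))
    false_neg = np.sum(~flagged & (y_true == 1))
    return cost_fp * false_pos + cost_fn * false_neg


# Made-up costs: missing a delay is 4x as costly as an unnecessary expedite
COST_FP, COST_FN = 1.0, 4.0
closed_form = COST_FP / (COST_FP + COST_FN)

# Synthetic calibrated data: each label is 1 with probability equal to its score
rng = np.random.default_rng(0)
proba_demo = rng.uniform(0, 1, size=200_000)
y_demo = (rng.uniform(0, 1, size=200_000) < proba_demo).astype(int)

# Sweep candidate thresholds and keep the cheapest one
grid = np.linspace(0.0, 1.0, 101)
costs = [expected_cost(y_demo, proba_demo, t, COST_FP, COST_FN) for t in grid]
empirical = grid[int(np.argmin(costs))]
print(f"closed-form threshold={closed_form:.2f}, empirical best={empirical:.2f}")
```

TabPFN's probabilities are not guaranteed to be perfectly calibrated, so an empirical sweep on held-out data is the safer check. For example, the sweep could use `y_test` and `y_pred_proba[:, 1]` from the supplier delay section, ideally on a separate validation split.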