In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from plug.model import fit, cross_validate, save_artifacts, load_artifacts, predict

# Pin the PyTorch and NumPy global RNGs so repeated runs draw the same numbers.
torch.manual_seed(42)
np.random.seed(42)

# Synthetic 4-class problem; the 512 features simulate a transformer hidden size.
X, y = make_classification(
    n_samples=1000, n_features=512, n_informative=256, n_classes=4, random_state=42
)

# Hold out 20% of the samples for testing, stratified on the class label.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=42, test_size=0.2
)

# One call, one line per fact; the emitted text matches three separate prints.
print(
    f"Training data shape: {X_train.shape}",
    f"Training labels shape: {y_train.shape}",
    f"Number of classes: {len(np.unique(y))}",
    sep="\n",
)


  from .autonotebook import tqdm as notebook_tqdm


Training data shape: (800, 512)
Training labels shape: (800,)
Number of classes: 4


In [2]:
print("=== 1. Built-in MLP Model (String Specification) ===")
# Baseline: the string "mlp" asks fit() for the library's built-in MLP probe.
# fit() returns the trained model and a per-epoch history (a sequence of dicts
# that carry at least a 'val_metric' entry).
mlp_model, mlp_history = fit(
    X_train, 
    y_train,
    model="mlp",  # Built-in MLP using string
    num_classes=4,
    n_epochs=20,
    val_split=0.2,
    patience=5,
    learning_rate=1e-3,
    batch_size=128
)

print(f"Built-in MLP trained in {len(mlp_history)} epochs")
print(f"Final validation metric: {mlp_history[-1]['val_metric']:.4f}")
# fit() logs "Restored best weights" before returning, so the returned model
# corresponds to the best epoch, which is not necessarily the last one printed
# above. Report the best epoch as well. NOTE(review): max() assumes a
# higher-is-better metric, which matches the logged best/last values in this
# notebook; confirm if fit() is ever configured with a loss-like metric.
print(f"Best validation metric: {max(epoch['val_metric'] for epoch in mlp_history):.4f}")

print("\n=== 2. Custom Model Function ===")
def deep_probe(input_dim, num_classes, hidden_dim=256, num_layers=3, dropout=0.5):
    """Build a feed-forward probe containing exactly ``num_layers`` linear layers.

    The network is ``Linear(input_dim, hidden_dim)``, followed by
    ``num_layers - 2`` ``Linear(hidden_dim, hidden_dim)`` layers, followed by a
    final ``Linear(hidden_dim, num_classes)``. Every linear layer except the
    last is followed by ``ReLU`` and ``Dropout(dropout)``.

    Args:
        input_dim: Width of the input feature vectors.
        num_classes: Width of the output layer.
        hidden_dim: Width of every hidden layer.
        num_layers: Total number of linear layers, counting the input and
            output layers. Must be at least 2.
        dropout: Probability handed to each ``nn.Dropout``.

    Returns:
        An ``nn.Sequential`` that outputs unnormalised class scores (no softmax
        is applied).

    Raises:
        ValueError: If ``num_layers`` is smaller than 2. Such values used to be
            accepted and silently produced a 2-layer network, which did not
            match what the caller asked for.
    """
    if num_layers < 2:
        raise ValueError(
            f"num_layers must be >= 2 (input layer + output layer), got {num_layers}"
        )

    # Widths seen by the hidden stack: input_dim -> hidden_dim -> ... -> hidden_dim.
    # Consecutive pairs give the (in, out) sizes of each non-output linear layer.
    widths = [input_dim] + [hidden_dim] * (num_layers - 1)

    layers = []
    for in_features, out_features in zip(widths[:-1], widths[1:]):
        layers.extend([
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout)
        ])

    # Output layer: no activation or dropout after it.
    layers.append(nn.Linear(hidden_dim, num_classes))

    return nn.Sequential(*layers)

# Train the custom deep probe. The factory function is passed as `model`, and
# hidden_dim / num_layers / dropout name deep_probe's own parameters
# (presumably forwarded to the factory by fit()).
deep_model, deep_history = fit(
    X_train, y_train,
    model=deep_probe,  # Custom model function
    num_classes=4,
    n_epochs=20,
    val_split=0.2,
    patience=5,
    hidden_dim=128,
    num_layers=4,
    dropout=0.3
)

print(f"Custom deep probe trained in {len(deep_history)} epochs")
print(f"Final validation metric: {deep_history[-1]['val_metric']:.4f}")
# fit() logs "Restored best weights" before returning, so the returned model
# corresponds to the best epoch, which is not necessarily the last one printed
# above. Report the best epoch as well. NOTE(review): max() assumes a
# higher-is-better metric, which matches the logged best/last values in this
# notebook; confirm if fit() is ever configured with a loss-like metric.
print(f"Best validation metric: {max(epoch['val_metric'] for epoch in deep_history):.4f}")


[00:38:28] INFO: PlugClassifier • d=512 N=640 target=5.0 ⇒ fc1=16 params=9492 (≈14.8×N)


=== 1. Built-in MLP Model (String Specification) ===


Fit: 100%|█| 20/20 [00:00<00:00, 41.29it/s, train_loss=1.1957, val_metric=0.6426
[00:38:30] INFO: Restored best weights (val_metric 0.6426).


Built-in MLP trained in 20 epochs
Final validation metric: 0.6426

=== 2. Custom Model Function ===


Fit: 100%|█| 20/20 [00:00<00:00, 86.50it/s, train_loss=0.1963, val_metric=0.7686
[00:38:30] INFO: Restored best weights (val_metric 0.7686).


Custom deep probe trained in 20 epochs
Final validation metric: 0.7686


In [3]:
print("=== 3. Saving and Loading Models ===")
# Create the output directory before anything is written into it. The CSV
# exports further down go through pandas, which raises instead of creating a
# missing parent directory, so without this line a fresh run only works if
# save_artifacts happens to create the folder as a side effect.
Path("tutorial_outputs").mkdir(parents=True, exist_ok=True)

# Save the model with custom factory for reconstruction. The kwargs repeat the
# values the probe was trained with so the same architecture can be rebuilt.
weights_path, meta_path = save_artifacts(
    deep_model,
    path="tutorial_outputs/deep_probe_model",
    model_factory=deep_probe,
    model_kwargs={
        "hidden_dim": 128,
        "num_layers": 4,
        "dropout": 0.3
    },
    meta={"description": "Deep probe with 4 layers for classification"}
)

print(f"Model saved to: {weights_path}")

# Test loading the model back
loaded_model, loaded_meta = load_artifacts("tutorial_outputs/deep_probe_model", device="cpu")
print(f"Model loaded successfully: {loaded_meta['description']}")

print("\n=== 4. Making Predictions ===")
# Create test data: one row per held-out sample, indexed by a synthetic
# "sample_<i>" id that is written out as the first CSV column.
test_features_df = pd.DataFrame(X_test, columns=[f"feature_{i}" for i in range(X_test.shape[1])])
test_features_df.index = [f"sample_{i}" for i in range(len(test_features_df))]
test_features_df.to_csv("tutorial_outputs/test_features.csv")

# Ground-truth labels keyed by the same ids as the feature file.
ground_truth_df = pd.DataFrame({"id": test_features_df.index, "answer": y_test})
ground_truth_df.to_csv("tutorial_outputs/ground_truth.csv", index=False)

# Make predictions with output directory
predictions_df = predict(
    features="tutorial_outputs/test_features.csv",
    model_path="tutorial_outputs/deep_probe_model",
    output_csv="deep_probe_predictions.csv",
    response_csv="tutorial_outputs/ground_truth.csv",
    response_col="answer",
    device="cpu",
    batch_size=1024,
    output_dir="tutorial_outputs"
)

print(f"Predictions completed! Shape: {predictions_df.shape}")
print("Sample predictions:")
print(predictions_df.head())


[00:38:54] INFO: Artifacts saved → tutorial_outputs/deep_probe_model.pt (+ meta tutorial_outputs/deep_probe_model.json)
[00:38:54] INFO: Loaded model ← tutorial_outputs/deep_probe_model.pt (device=cpu)
[00:38:54] INFO: Loaded model ← tutorial_outputs/deep_probe_model.pt (device=cpu)
[00:38:54] INFO: Predictions → tutorial_outputs/deep_probe_predictions.csv
[00:38:54] INFO: Metrics → tutorial_outputs/deep_probe_predictions.metrics.json


=== 3. Saving and Loading Models ===
Model saved to: /data1/home/kivelsons/plug-generic-probe/tutorial_outputs/deep_probe_model.pt
Model loaded successfully: Deep probe with 4 layers for classification

=== 4. Making Predictions ===
Predictions completed! Shape: (200, 5)
Sample predictions:
         id  prob_class_0  prob_class_1  prob_class_2  prob_class_3
0  sample_0      0.007184      0.051992      0.937045      0.003780
1  sample_1      0.002558      0.435386      0.190167      0.371888
2  sample_2      0.013586      0.204972      0.002027      0.779415
3  sample_3      0.998933      0.000481      0.000297      0.000290
4  sample_4      0.222337      0.533917      0.014780      0.228966


In [4]:
# Section banner: marks where the custom attention-probe part of the cell output begins.
print("=== 5. Advanced Custom Architectures ===")

def attention_probe(input_dim, num_classes, hidden_dim=256, num_heads=8, dropout=0.1):
    """Factory for a probe that runs each feature vector through self-attention.

    Args:
        input_dim: Width of the incoming feature vectors.
        num_classes: Width of the final classification layer.
        hidden_dim: Embedding width used by the projection, the attention
            block and the layer norm; the classifier's hidden layer is half
            of this (integer division).
        num_heads: Number of heads given to ``nn.MultiheadAttention``.
        dropout: Rate shared by the attention block and both ``nn.Dropout``
            layers.

    Returns:
        A freshly constructed ``AttentionProbe`` module.
    """

    class AttentionProbe(nn.Module):
        """Linear projection, single-token self-attention with residual + LayerNorm, MLP head."""

        def __init__(self, input_dim, num_classes, hidden_dim, num_heads, dropout):
            super().__init__()
            head_width = hidden_dim // 2

            self.input_proj = nn.Linear(input_dim, hidden_dim)
            self.attention = nn.MultiheadAttention(
                embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True
            )
            self.norm = nn.LayerNorm(hidden_dim)
            self.dropout = nn.Dropout(dropout)
            self.classifier = nn.Sequential(
                nn.Linear(hidden_dim, head_width),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(head_width, num_classes),
            )

        def forward(self, x):
            # Embed the features, then add a length-1 sequence axis so every
            # sample is presented to the attention layer as a single token:
            # [batch, input_dim] -> [batch, hidden_dim] -> [batch, 1, hidden_dim]
            tokens = self.input_proj(x).unsqueeze(1)

            # Self-attention (query = key = value); the attention weights are discarded.
            attended, _ = self.attention(tokens, tokens, tokens)

            # Residual connection + layer norm, then drop the sequence axis again.
            pooled = self.norm(attended + tokens).squeeze(1)  # [batch, hidden_dim]

            return self.classifier(self.dropout(pooled))

    return AttentionProbe(input_dim, num_classes, hidden_dim, num_heads, dropout)

# Train attention-based probe. The factory function is passed as `model`, and
# hidden_dim / num_heads / dropout name attention_probe's own parameters
# (presumably forwarded to the factory by fit()).
attention_model, attention_history = fit(
    X_train, y_train,
    model=attention_probe,
    num_classes=4,
    n_epochs=15,
    val_split=0.2,
    patience=5,
    hidden_dim=256,
    num_heads=4,
    dropout=0.1,
    learning_rate=1e-3
)

print(f"Attention probe trained in {len(attention_history)} epochs")
print(f"Final validation metric: {attention_history[-1]['val_metric']:.4f}")
# With early stopping the last epoch is by construction not the best one, and
# fit() logs "Restored best weights" before returning. The value printed above
# therefore does not describe the returned model, so report the best epoch too.
# NOTE(review): max() assumes a higher-is-better metric, which matches the
# logged best/last values in this notebook; confirm if fit() is ever configured
# with a loss-like metric.
print(f"Best validation metric: {max(epoch['val_metric'] for epoch in attention_history):.4f}")

# Compare model complexities
def count_parameters(model):
    """Return the total element count over ``model``'s parameters that require gradients."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are left out of the tally.
        if param.requires_grad:
            total += param.numel()
    return total

# Trainable-parameter totals for the three probes trained in this notebook.
# A single print with a newline separator emits the same text as one print per line.
print(
    f"\nModel parameter counts:",
    f"Built-in MLP: {count_parameters(mlp_model):,}",
    f"Deep probe: {count_parameters(deep_model):,}",
    f"Attention probe: {count_parameters(attention_model):,}",
    sep="\n",
)


=== 5. Advanced Custom Architectures ===


Fit:  33%|▎| 5/15 [00:00<00:00, 47.77it/s, train_loss=0.1035, val_metric=0.7272][00:39:10] INFO: Early stopping at epoch 12 (best=0.7435)
Fit:  73%|▋| 11/15 [00:00<00:00, 50.41it/s, train_loss=0.1035, val_metric=0.7272
[00:39:10] INFO: Restored best weights (val_metric 0.7435).


Attention probe trained in 12 epochs
Final validation metric: 0.7272

Model parameter counts:
Built-in MLP: 9,492
Deep probe: 99,204
Attention probe: 428,420


In [5]:
print("=== 6. Cross-Validation Comparison ===")

# Candidates for the comparison: (label, keyword arguments that select and
# configure the model for cross_validate).
models_to_compare = [
    ("mlp", dict(model="mlp")),
    ("deep_probe", dict(model=deep_probe, hidden_dim=128, num_layers=3, dropout=0.3)),
    ("attention", dict(model=attention_probe, hidden_dim=128, num_heads=4, dropout=0.1)),
]

# Run every candidate with the same CV settings and keep one summary per label.
cv_results = {}
for model_name, model_config in models_to_compare:
    print(f"\nRunning CV for {model_name}...")

    preds, cv_summary = cross_validate(
        X_train,
        y_train,
        num_classes=4,
        n_splits=3,  # Fewer splits for demo
        n_epochs=10,
        patience=3,
        out_dir="tutorial_outputs/cv_results",
        run_name=f"{model_name}_cv",
        **model_config,
    )

    cv_results[model_name] = cv_summary
    print(f"{model_name} CV score: {cv_summary['overall_metric']:.4f}")

# One aligned line per model: metric name, overall score and total run time.
print("\n=== Cross-Validation Results Summary ===")
for model_name in cv_results:
    summary = cv_results[model_name]
    metric_name, overall_score, total_time = (
        summary['metric_name'],
        summary['overall_metric'],
        summary['sec_total'],
    )
    print(f"{model_name:15s} | {metric_name}: {overall_score:.4f} | Time: {total_time:.1f}s")


[00:39:18] INFO: Fold 1/3 - training …
[00:39:18] INFO: PlugClassifier • d=512 N=453 target=5.0 ⇒ fc1=16 params=9492 (≈21.0×N)
[00:39:18] INFO: Fold 1 ep   1 | train_roc_auc 0.5638 val_roc_auc 0.5275
[00:39:18] INFO: Fold 1 ep   2 | train_roc_auc 0.5955 val_roc_auc 0.5181
[00:39:18] INFO: Fold 1 ep   3 | train_roc_auc 0.6204 val_roc_auc 0.5219
[00:39:18] INFO: Fold 1 ep   4 | train_roc_auc 0.6442 val_roc_auc 0.5233
[00:39:18] INFO: Fold 1 ep   5 | train_roc_auc 0.6612 val_roc_auc 0.5346
[00:39:18] INFO: Fold 1 ep   6 | train_roc_auc 0.6798 val_roc_auc 0.5340
[00:39:18] INFO: Fold 1 ep   7 | train_roc_auc 0.6975 val_roc_auc 0.5406
[00:39:18] INFO: Fold 1 ep   8 | train_roc_auc 0.7190 val_roc_auc 0.5465
[00:39:19] INFO: Fold 1 ep   9 | train_roc_auc 0.7360 val_roc_auc 0.5598


=== 6. Cross-Validation Comparison ===

Running CV for mlp...


[00:39:19] INFO: Fold 1 ep  10 | train_roc_auc 0.7585 val_roc_auc 0.5640
[00:39:19] INFO: Fold 1 - restored best weights (val 0.5640)
[00:39:19] INFO: Fold 1 finished - roc_auc 0.5818
[00:39:19] INFO: Fold 2/3 - training …
[00:39:19] INFO: PlugClassifier • d=512 N=453 target=5.0 ⇒ fc1=16 params=9492 (≈21.0×N)
[00:39:19] INFO: Fold 2 ep   1 | train_roc_auc 0.5297 val_roc_auc 0.4625
[00:39:19] INFO: Fold 2 ep   2 | train_roc_auc 0.5799 val_roc_auc 0.4810
[00:39:19] INFO: Fold 2 ep   3 | train_roc_auc 0.6302 val_roc_auc 0.4908
[00:39:19] INFO: Fold 2 ep   4 | train_roc_auc 0.6682 val_roc_auc 0.5090
[00:39:19] INFO: Fold 2 ep   5 | train_roc_auc 0.7001 val_roc_auc 0.5181
[00:39:19] INFO: Fold 2 ep   6 | train_roc_auc 0.7286 val_roc_auc 0.5229
[00:39:19] INFO: Fold 2 ep   7 | train_roc_auc 0.7520 val_roc_auc 0.5202
[00:39:19] INFO: Fold 2 ep   8 | train_roc_auc 0.7729 val_roc_auc 0.5179
[00:39:19] INFO: Fold 2 ep   9 | train_roc_auc 0.7925 val_roc_auc 0.5181
[00:39:19] INFO: Fold 2 ep  10 |

mlp CV score: 0.5279

Running CV for deep_probe...


[00:39:19] INFO: Fold 2 ep   5 | train_roc_auc 0.9834 val_roc_auc 0.6929
[00:39:19] INFO: Fold 2 ep   6 | train_roc_auc 0.9939 val_roc_auc 0.7177
[00:39:19] INFO: Fold 2 ep   7 | train_roc_auc 0.9980 val_roc_auc 0.7194
[00:39:19] INFO: Fold 2 ep   8 | train_roc_auc 0.9994 val_roc_auc 0.7227
[00:39:19] INFO: Fold 2 ep   9 | train_roc_auc 0.9999 val_roc_auc 0.7342
[00:39:19] INFO: Fold 2 ep  10 | train_roc_auc 1.0000 val_roc_auc 0.7448
[00:39:19] INFO: Fold 2 - restored best weights (val 0.7448)
[00:39:19] INFO: Fold 2 finished - roc_auc 0.6861
[00:39:19] INFO: Fold 3/3 - training …
[00:39:19] INFO: Fold 3 ep   1 | train_roc_auc 0.7124 val_roc_auc 0.5418
[00:39:19] INFO: Fold 3 ep   2 | train_roc_auc 0.8599 val_roc_auc 0.5588
[00:39:19] INFO: Fold 3 ep   3 | train_roc_auc 0.9467 val_roc_auc 0.5891
[00:39:20] INFO: Fold 3 ep   4 | train_roc_auc 0.9805 val_roc_auc 0.6155
[00:39:20] INFO: Fold 3 ep   5 | train_roc_auc 0.9926 val_roc_auc 0.6336
[00:39:20] INFO: Fold 3 ep   6 | train_roc_auc 

deep_probe CV score: 0.6940

Running CV for attention...


[00:39:20] INFO: Fold 1 finished - roc_auc 0.6960
[00:39:20] INFO: Fold 2/3 - training …
[00:39:20] INFO: Fold 2 ep   1 | train_roc_auc 0.8171 val_roc_auc 0.6202
[00:39:20] INFO: Fold 2 ep   2 | train_roc_auc 0.9133 val_roc_auc 0.6671
[00:39:20] INFO: Fold 2 ep   3 | train_roc_auc 0.9552 val_roc_auc 0.7044
[00:39:20] INFO: Fold 2 ep   4 | train_roc_auc 0.9721 val_roc_auc 0.7337
[00:39:20] INFO: Fold 2 ep   5 | train_roc_auc 0.9783 val_roc_auc 0.7419
[00:39:20] INFO: Fold 2 ep   6 | train_roc_auc 0.9836 val_roc_auc 0.7492
[00:39:20] INFO: Fold 2 ep   7 | train_roc_auc 0.9888 val_roc_auc 0.7552
[00:39:20] INFO: Fold 2 ep   8 | train_roc_auc 0.9938 val_roc_auc 0.7602
[00:39:20] INFO: Fold 2 ep   9 | train_roc_auc 0.9973 val_roc_auc 0.7621
[00:39:20] INFO: Fold 2 ep  10 | train_roc_auc 0.9990 val_roc_auc 0.7583
[00:39:20] INFO: Fold 2 - restored best weights (val 0.7621)
[00:39:20] INFO: Fold 2 finished - roc_auc 0.6752
[00:39:20] INFO: Fold 3/3 - training …
[00:39:20] INFO: Fold 3 ep   1 

attention CV score: 0.6804

=== Cross-Validation Results Summary ===
mlp             | roc_auc: 0.5279 | Time: 0.7s
deep_probe      | roc_auc: 0.6940 | Time: 0.4s
attention       | roc_auc: 0.6804 | Time: 0.6s
