
## FT-Transformer (Feature Tokenizer + Transformer)
```plaintext
Input Features
│
├─ Feature Tokenizer (Embedding Layer)
│  │
│  ├─ Continuous Features → Linear Projection
│  └─ Categorical Features → Embedding Lookup
│
├─ Transformer Encoder (Multiple Layers)
│  │
│  ├─ Multi-Head Self-Attention
│  ├─ Layer Normalization & Feed-Forward
│  └─ Residual Connections
│
└─ Prediction Head (MLP for Regression/Classification)
```

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from collections import defaultdict
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, mean_squared_error
from tqdm import tqdm
from sklearn import datasets

In [2]:
class FeatureTokenizer(nn.Module):
    """
    Turn a batch of tabular features into a sequence of d_model-wide tokens.

    Output token layout along dim 1:
      [CLS], then one token projecting *all* numerical features together
      (if any), then one embedding token per categorical feature (if any).

    Args:
        num_numerical_features: Number of numerical input columns; 0 disables
            the numerical projection.
        cat_cardinalities: Number of distinct categories for each categorical
            column; an empty list disables the embeddings.
        d_model: Width of every produced token.
    """
    def __init__(self, num_numerical_features, cat_cardinalities, d_model):
        super().__init__()
        self.num_numerical = num_numerical_features
        self.cat_cardinalities = cat_cardinalities

        if num_numerical_features > 0:
            self.num_projection = nn.Linear(num_numerical_features, d_model)

        if cat_cardinalities:
            self.cat_embeddings = nn.ModuleList([
                nn.Embedding(cardinality, d_model)
                for cardinality in cat_cardinalities
            ])

        # Learnable [CLS] token, shared across the batch.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_token, mean=0.0, std=0.02)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x_num, x_cat=None):
        """
        Build the token sequence for one batch.

        Args:
            x_num: Float tensor of shape (batch, num_numerical_features).
                Ignored when the tokenizer has no numerical features, so an
                empty placeholder tensor is acceptable in that case.
            x_cat: Long tensor of shape (batch, n_categorical) holding
                category indices, or None.

        Returns:
            Layer-normalised tensor of shape (batch, n_tokens, d_model) with
            the CLS token at index 0 of dim 1.

        Raises:
            RuntimeError: If neither numerical nor categorical tokens could
                be built (there is nothing to attend over).
        """
        tokens = []

        if self.num_numerical > 0:
            # (batch, d_model) -> (batch, 1, d_model): a single numeric token.
            tokens.append(self.num_projection(x_num).unsqueeze(1))

        if x_cat is not None and self.cat_cardinalities:
            cat_tokens = [
                embed(x_cat[:, i]) for i, embed in enumerate(self.cat_embeddings)
            ]
            tokens.append(torch.stack(cat_tokens, dim=1))

        if not tokens:
            raise RuntimeError(
                "FeatureTokenizer received no numerical or categorical features"
            )

        tokens = torch.cat(tokens, dim=1)
        # Take the batch size from the tokens that were actually built: x_num
        # is only a placeholder when there are no numerical features, so its
        # first dimension cannot be trusted.
        cls_tokens = self.cls_token.expand(tokens.size(0), -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        return self.norm(tokens)

class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding for transformers (optional).

    Even feature indices carry sin(position * frequency) and odd indices the
    matching cosine. The table is precomputed once for ``max_len`` positions
    and stored as a non-trainable buffer.

    Args:
        d_model: Token width; both even and odd values are supported.
        max_len: Number of positions to precompute.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the frequency axis to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # leading batch axis so it broadcasts over x
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of the first ``x.size(1)`` positions to ``x``.

        ``x`` is indexed as (batch, sequence, d_model); the sequence length
        must not exceed the ``max_len`` given at construction.
        """
        x = x + self.pe[:, :x.size(1)]
        return x

In [3]:
def init_weights(module):
    """
    Initialise one sub-module in place; intended for ``nn.Module.apply``.

    - ``nn.Linear``: Kaiming-normal weights (ReLU gain), zero bias.
    - ``nn.Embedding``: normal weights with mean 0 and std 0.02.
    - ``nn.LayerNorm``: weight 1 and bias 0.

    Any other module type is left untouched. Parameters that are ``None``
    (e.g. a Linear without bias, or a LayerNorm built without affine
    parameters) are skipped instead of raising.
    """
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)

    elif isinstance(module, nn.LayerNorm):
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

In [4]:
class FTTransformer(nn.Module):
    """
    FT-Transformer for tabular data with mixed numerical and categorical features.

    Pipeline: FeatureTokenizer -> (optional) PositionalEncoding ->
    TransformerEncoder -> linear head applied to the CLS token.
    """

    def __init__(self, preprocessor, d_model=64, nhead=4, num_layers=3,
                 dim_feedforward=128, dropout=0.1, num_classes=None,
                 use_pos_encoding=False):
        """
        Args:
            preprocessor: Fitted DataFramePreprocessor instance
            d_model: Transformer model dimension
            nhead: Number of attention heads
            num_layers: Number of TransformerEncoder layers
            dim_feedforward: Dimension of feedforward network
            dropout: Dropout rate
            num_classes: Number of output classes (None for regression)
            use_pos_encoding: Whether to use positional encoding
        """
        super().__init__()
        self.preprocessor = preprocessor
        self.use_pos_encoding = use_pos_encoding
        self.num_classes = num_classes

        # Feature tokenizer
        self.feature_tokenizer = FeatureTokenizer(
            num_numerical_features=len(preprocessor.num_cols),
            cat_cardinalities=preprocessor.cat_cardinalities,
            d_model=d_model
        )

        # Sequence length produced by the tokenizer; only needed to size the
        # positional-encoding table.
        num_tokens = 1  # CLS token
        if len(preprocessor.num_cols) > 0:
            num_tokens += 1  # One token for all numerical features
        if preprocessor.cat_cardinalities:
            num_tokens += len(preprocessor.cat_cardinalities)

        # Optional positional encoding
        if self.use_pos_encoding:
            self.pos_encoder = PositionalEncoding(d_model, max_len=num_tokens)

        # Transformer encoder (batch_first: inputs are (batch, tokens, d_model))
        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output head: one logit per class, or a single value for regression
        if num_classes is not None:
            self.head = nn.Linear(d_model, num_classes)
        else:
            self.head = nn.Linear(d_model, 1)

        # Initialize all weights
        self.apply(init_weights)

    def forward_from_processed(self, x_num, x_cat):
        """
        Forward pass with already processed numerical and categorical tensors.

        The tensors must already be on the same device as the model.

        Returns:
            Raw logits of shape (batch, num_classes) for classification, or a
            (batch, 1) tensor for regression. No softmax/sigmoid is applied.
        """
        tokens = self.feature_tokenizer(x_num, x_cat)

        if self.use_pos_encoding:
            tokens = self.pos_encoder(tokens)

        encoded = self.transformer_encoder(tokens)
        cls_output = encoded[:, 0, :]  # CLS token summarises the whole row

        return self.head(cls_output)

    def forward(self, df):
        """
        Forward pass directly from a raw DataFrame.

        The stored preprocessor converts ``df`` into tensors, which are then
        moved to the device the model's parameters live on. Without that move
        a model that has been sent to an accelerator would fail with a device
        mismatch, because the preprocessor builds its tensors on the CPU.
        """
        x_num, x_cat = self.preprocessor.transform(df)
        device = next(self.parameters()).device
        x_num = x_num.to(device)
        if x_cat is not None:
            x_cat = x_cat.to(device)
        return self.forward_from_processed(x_num, x_cat)

In [5]:
class DataFramePreprocessor:
    """
    Preprocesses a pandas DataFrame for the FT-Transformer
    - Identifies numerical and categorical columns
    - Normalizes numerical features
    - Encodes categorical features
    - Handles missing values
    """
    def __init__(self):
        self.num_cols = []           # names of numerical feature columns
        self.cat_cols = []           # names of categorical feature columns
        self.cat_cardinalities = []  # distinct categories per entry of cat_cols
        self.num_scaler = StandardScaler()
        self.cat_encoders = defaultdict(LabelEncoder)  # one encoder per categorical column
        self.fitted = False

    def fit(self, df, target_col='target'):
        """
        Identify feature columns and fit the scaler and encoders on ``df``.

        Args:
            df: Training DataFrame.
            target_col: Column to exclude from the features. It may be absent
                from ``df`` (for example when fitting on a feature-only frame).

        Returns:
            self, so calls can be chained.
        """
        # errors='ignore' lets fitting work on frames without the target column.
        features = df.drop(columns=[target_col], errors='ignore')

        # Identify feature types
        self.num_cols = features.select_dtypes(include=['number']).columns.tolist()
        self.cat_cols = features.select_dtypes(exclude=['number']).columns.tolist()

        # Fit numerical scaler (missing values are replaced by 0 first)
        if self.num_cols:
            self.num_scaler.fit(df[self.num_cols].fillna(0).values)

        # Fit categorical encoders and get cardinalities
        self.cat_cardinalities = []
        for col in self.cat_cols:
            # Missing values become their own '__NA__' category
            series = df[col].fillna('__NA__').astype(str)
            self.cat_encoders[col].fit(series)
            self.cat_cardinalities.append(len(self.cat_encoders[col].classes_))

        self.fitted = True
        return self

    def transform(self, df):
        """
        Transform a DataFrame into processed numerical and categorical tensors.

        Returns:
            (x_num, x_cat): a float tensor with the scaled numerical columns
            and a long tensor with the encoded categorical columns. Either is
            an empty tensor (``torch.empty(0)``) when there are no columns of
            that kind.

        Raises:
            RuntimeError: If called before ``fit``.

        NOTE(review): category values not seen during ``fit`` are passed
        straight to the LabelEncoder, which is expected to reject them.
        """
        if not self.fitted:
            raise RuntimeError("Preprocessor must be fit before transforming data")

        # Process numerical features
        x_num = torch.empty(0)
        if self.num_cols:
            num_values = self.num_scaler.transform(df[self.num_cols].fillna(0).values)
            x_num = torch.FloatTensor(num_values)

        # Process categorical features
        x_cat = torch.empty(0)
        if self.cat_cols:
            cat_data = [
                self.cat_encoders[col].transform(df[col].fillna('__NA__').astype(str))
                for col in self.cat_cols
            ]
            x_cat = torch.LongTensor(np.column_stack(cat_data))

        return x_num, x_cat

    def fit_transform(self, df, target_col='target'):
        """Fit on ``df`` (excluding ``target_col``) and transform it in one step."""
        return self.fit(df, target_col=target_col).transform(df)

In [6]:
class TabularDataset(Dataset):
    """
    Dataset wrapping a pandas DataFrame that has already been split.

    The whole frame is pushed through ``preprocessor.transform`` once, at
    construction time, so item access is plain tensor indexing.

    Args:
        df: DataFrame containing the feature columns and the target column.
        preprocessor: Fitted object exposing ``transform(df) -> (x_num, x_cat)``.
        is_classification: Targets become int64 when True and float32 when False.
        target_col: Name of the target column (defaults to 'target').
    """
    def __init__(self, df, preprocessor, is_classification=True, target_col='target'):
        self.x_num, self.x_cat = preprocessor.transform(df)
        self.is_classification = is_classification
        # Long targets for classification losses, float targets for regression
        target_dtype = torch.long if is_classification else torch.float32
        self.y = torch.tensor(df[target_col].values, dtype=target_dtype)

        if self.x_num.nelement() == 0:
            # No numerical features: keep one (empty) row per sample so that
            # indexing and batching still work instead of raising IndexError.
            self.x_num = self.x_num.new_empty((len(self.y), 0))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(x_num, y)`` without categorical features, else ``(x_num, x_cat, y)``."""
        if self.x_cat.nelement() == 0:  # No categorical features
            return self.x_num[idx], self.y[idx]
        return self.x_num[idx], self.x_cat[idx], self.y[idx]

In [7]:
def _unpack_batch(batch, device):
    """Split a loader batch into ``(x_num, x_cat, y)`` and move it to ``device``.

    A 2-tuple batch is ``(x_num, y)`` (no categorical features, ``x_cat`` is
    returned as None); a 3-tuple batch is ``(x_num, x_cat, y)``.
    """
    if len(batch) == 2:  # Only numerical features
        x_num, y = batch
        x_cat = None
    else:  # Both numerical and categorical
        x_num, x_cat, y = batch
        x_cat = x_cat.to(device)
    return x_num.to(device), x_cat, y.to(device)


def _batch_loss(model, criterion, batch, device):
    """Run one batch through the model; return ``(loss, outputs, targets)``."""
    x_num, x_cat, y = _unpack_batch(batch, device)
    outputs = model.forward_from_processed(x_num, x_cat)
    if model.num_classes is not None:
        y = y.long()  # class-index targets must be int64
    elif outputs.dim() == y.dim() + 1 and outputs.size(-1) == 1:
        # A (batch, 1) regression output would broadcast against (batch,)
        # targets into a (batch, batch) loss; drop the trailing axis instead.
        outputs = outputs.squeeze(-1)
    return criterion(outputs, y), outputs, y


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=10, device='cpu'):
    """
    Training loop with validation.

    Args:
        model: Module exposing ``forward_from_processed(x_num, x_cat)`` and a
            ``num_classes`` attribute (None means regression).
        train_loader: Loader yielding ``(x_num, y)`` or ``(x_num, x_cat, y)``.
        val_loader: Loader with the same batch layout, used for validation.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer over the model parameters.
        scheduler: LR scheduler; it is stepped after every optimizer step,
            i.e. once per batch rather than once per epoch.
        epochs: Number of passes over ``train_loader``.
        device: Device that the model and every batch are moved to.

    Returns:
        ``(train_losses, val_losses, val_accuracies, model)``. The loss lists
        hold one sample-weighted mean per epoch. ``val_accuracies`` is a
        per-epoch list for classification and None for regression.
    """
    model.to(device)
    best_val_loss = float('inf')  # only used by the optional checkpointing below

    train_losses = []
    val_losses = []
    val_accuracies = []

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            optimizer.zero_grad()
            loss, _, y = _batch_loss(model, criterion, batch, device)
            loss.backward()
            optimizer.step()
            scheduler.step()
            # Weight by batch size so a smaller final batch is averaged correctly.
            train_loss += loss.item() * y.size(0)

        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for batch in val_loader:
                loss, outputs, y = _batch_loss(model, criterion, batch, device)
                val_loss += loss.item() * y.size(0)
                all_preds.append(outputs.cpu())
                all_targets.append(y.cpu())

        val_loss /= len(val_loader.dataset)
        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)
        val_losses.append(val_loss)

        # Calculate metrics
        if model.num_classes is not None:  # Classification
            preds = torch.argmax(all_preds, dim=1)
            accuracy = (preds == all_targets.long()).double().mean().item()
            val_accuracies.append(accuracy)
            print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {accuracy:.4f}")
        else:  # Regression
            # Computed with torch so it does not depend on the `squared=`
            # argument of sklearn's mean_squared_error.
            rmse = torch.sqrt(torch.mean((all_preds - all_targets) ** 2)).item()
            print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val RMSE: {rmse:.4f}")
            val_accuracies = None

        # Save best model
        # if val_loss < best_val_loss:
        #     best_val_loss = val_loss
        #     torch.save(model.state_dict(), 'best_model.pth')

    print("Training complete!")
    # model.load_state_dict(torch.load('best_model.pth'))  # Load best model
    return train_losses, val_losses, val_accuracies, model

def predictions(model, df):
    """
    Predict a class index for every row of ``df``.

    The model is called directly on ``df`` and is expected to return a
    (rows, classes) tensor of scores; the arg-max over dim 1 is stored in a
    new ``'prediction'`` column.

    Inference runs in eval mode (so dropout is disabled) and without gradient
    tracking; the model's previous train/eval mode is restored afterwards.

    Args:
        model: ``nn.Module`` that accepts a DataFrame.
        df: DataFrame to score. It is modified in place.

    Returns:
        The same ``df`` object, now carrying the ``'prediction'`` column.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            scores = model(df)
    finally:
        model.train(was_training)

    # argmax in torch, then hop to the CPU, so this also works when the model
    # lives on an accelerator.
    df['prediction'] = torch.argmax(scores, dim=1).cpu().numpy()
    return df

def confusion_matrix(prediction_df):
    """
    Count binary outcomes from the 'target' and 'prediction' columns.

    Only the labels 0 and 1 are counted; rows where either column holds any
    other value contribute to no cell.

    Returns:
        2x2 array laid out as ``[[TP, FP], [FN, TN]]``.
    """
    actual = prediction_df['target'].values
    predicted = prediction_df['prediction'].values

    actual_pos, actual_neg = actual == 1, actual == 0
    pred_pos, pred_neg = predicted == 1, predicted == 0

    true_pos = np.sum(actual_pos & pred_pos)
    false_pos = np.sum(actual_neg & pred_pos)
    false_neg = np.sum(actual_pos & pred_neg)
    true_neg = np.sum(actual_neg & pred_neg)
    return np.array([[true_pos, false_pos], [false_neg, true_neg]])

In [8]:
# Load Iris into a DataFrame (features plus a 'target' column) and hold out
# 20% of the rows as the final test split.
iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names).assign(target=iris.target)
df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)

In [9]:
# Carve a validation split out of the training portion
train_df, val_df = train_test_split(df_train, test_size=0.2, random_state=42)

# Fit the preprocessor on the training rows only, so validation statistics
# never leak into the scaler or encoders
preprocessor = DataFramePreprocessor()
preprocessor.fit(train_df)

batch_size = 4

# Training data: reshuffled every epoch
train_dataset = TabularDataset(train_df, preprocessor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation data: fixed order
val_dataset = TabularDataset(val_df, preprocessor)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [10]:
# Initialize model
model = FTTransformer(
    preprocessor=preprocessor,
    d_model=64,
    nhead=4,
    num_layers=3,
    dim_feedforward=128,
    dropout=0.1,
    # Multi-class classification: one logit per distinct label in the data
    # (three for Iris), rather than a hard-coded count that leaves an output
    # unit no sample can ever be labelled with.
    num_classes=int(df['target'].nunique()),
)

In [11]:
# List every trainable parameter tensor with its element count, then the sum.
trainable = [
    (name, param.numel())
    for name, param in model.named_parameters()
    if param.requires_grad
]
for name, count in trainable:
    print(f"{name:50s} {count:,}")
total = sum(count for _, count in trainable)

print(f"\nTotal trainable parameters: {total:,}")

feature_tokenizer.cls_token                        64
feature_tokenizer.num_projection.weight            256
feature_tokenizer.num_projection.bias              64
feature_tokenizer.norm.weight                      64
feature_tokenizer.norm.bias                        64
transformer_encoder.layers.0.self_attn.in_proj_weight 12,288
transformer_encoder.layers.0.self_attn.in_proj_bias 192
transformer_encoder.layers.0.self_attn.out_proj.weight 4,096
transformer_encoder.layers.0.self_attn.out_proj.bias 64
transformer_encoder.layers.0.linear1.weight        8,192
transformer_encoder.layers.0.linear1.bias          128
transformer_encoder.layers.0.linear2.weight        8,192
transformer_encoder.layers.0.linear2.bias          64
transformer_encoder.layers.0.norm1.weight          64
transformer_encoder.layers.0.norm1.bias            64
transformer_encoder.layers.0.norm2.weight          64
transformer_encoder.layers.0.norm2.bias            64
transformer_encoder.layers.1.self_attn.in_proj_weight 12

In [12]:
# Set up training
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.CrossEntropyLoss()  # For multi-class classification
# criterion = nn.MSELoss()  # For regression
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,          
    weight_decay=1e-2
)

# Linear warm-up: the learning rate starts at 0.1 * lr and ramps up over
# `total_iters` scheduler steps.
# NOTE(review): train_model() calls scheduler.step() after every batch, so
# this warm-up spans the first 5 batches, not the first 5 epochs - confirm
# that is the intended schedule.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    total_iters=5   
)
epochs = 10

In [13]:
# Train the model; the per-epoch loss/accuracy histories come back alongside
# the trained model itself.
train_loss, val_loss, val_acc, trained_model = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler,
    epochs=epochs, device=device,
)

Epoch 1/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 102.96it/s]


Epoch 1: Train Loss: 1.6307, Val Loss: 0.9043, Val Acc: 0.5417


Epoch 2/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 108.33it/s]


Epoch 2: Train Loss: 0.6482, Val Loss: 0.5082, Val Acc: 0.7500


Epoch 3/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 94.38it/s]


Epoch 3: Train Loss: 0.5239, Val Loss: 0.3989, Val Acc: 0.8333


Epoch 4/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 101.20it/s]


Epoch 4: Train Loss: 0.5252, Val Loss: 0.4085, Val Acc: 0.7917


Epoch 5/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 100.62it/s]


Epoch 5: Train Loss: 0.4392, Val Loss: 0.3747, Val Acc: 0.8750


Epoch 6/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 103.11it/s]


Epoch 6: Train Loss: 0.3781, Val Loss: 0.3502, Val Acc: 0.8750


Epoch 7/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 80.16it/s]


Epoch 7: Train Loss: 0.3721, Val Loss: 0.3362, Val Acc: 0.8750


Epoch 8/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 102.18it/s]


Epoch 8: Train Loss: 0.3921, Val Loss: 0.3542, Val Acc: 0.8750


Epoch 9/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 105.82it/s]


Epoch 9: Train Loss: 0.4591, Val Loss: 0.3397, Val Acc: 0.8750


Epoch 10/10 - Training: 100%|██████████| 24/24 [00:00<00:00, 104.59it/s]

Epoch 10: Train Loss: 0.4425, Val Loss: 0.2992, Val Acc: 0.9167
Training complete!





In [14]:
import plotly.graph_objects as go

epochs_range = list(range(1, len(train_loss) + 1))

# (series, legend label, axis): both losses share the left axis, accuracy
# goes on the secondary (right) axis.
curves = (
    (train_loss, 'Train Loss', 'y1'),
    (val_loss, 'Validation Loss', 'y1'),
    (val_acc, 'Validation Accuracy', 'y2'),
)

fig = go.Figure()
for series, label, axis in curves:
    fig.add_trace(go.Scatter(
        x=epochs_range,
        y=series,
        mode='lines+markers',
        name=label,
        yaxis=axis
    ))

loss_axis = dict(title='Loss', side='left')
accuracy_axis = dict(title='Accuracy', overlaying='y', side='right', range=[0, 1])

fig.update_layout(
    title='Training & Validation Loss with Validation Accuracy',
    xaxis_title='Epoch',
    yaxis=loss_axis,
    yaxis2=accuracy_axis,
    legend=dict(x=0.01, y=0.99),
    template='plotly_white'
)

fig.show()


In [15]:
# Score the held-out split; predictions() adds a 'prediction' column to
# df_test in place and returns that same frame.
prediction = predictions(trained_model, df_test)
# NOTE(review): confusion_matrix() only counts rows whose target and
# prediction are 0 or 1, so rows involving any other class label are left
# out of this 2x2 matrix - it is not a full multi-class confusion matrix.
confusion_matrix(prediction)

array([[ 7,  0],
       [ 0, 10]])