## TabM

In [None]:

import torch
import torch.nn as nn

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


# ===== LAYERS =====

class linear_BE(nn.Module):
    """
    BatchEnsemble linear layer: ``k`` ensemble members share a single weight
    matrix ``W`` and differ only through per-member input scaling ``R``,
    output scaling ``S`` and bias ``B``. A ReLU and dropout follow the affine map.

    Args:
        in_features: size of each member's input vector.
        out_features: size of each member's output vector.
        k: number of ensemble members.
        dropout_rate: dropout probability applied after the ReLU.
        initialize_to_1: if True (TabM), ``R`` and ``S`` start as all-ones;
            otherwise they are drawn from U(-1, 1).
    """

    def __init__(self, in_features: int, out_features: int, k=32, dropout_rate=0.1, initialize_to_1=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.k = k

        def uniform_param(*shape):
            # Zero-initialised parameter refilled in place from U(-1, 1).
            param = nn.Parameter(torch.zeros(shape))
            nn.init.uniform_(param, -1, 1)
            return param

        # The registration order R, S, W, B (and therefore the RNG draw order)
        # is kept so seeded runs stay reproducible.
        if not initialize_to_1:
            # The paper draws random ±1 values; this code samples U(-1, 1) instead.
            self.R = uniform_param(k, in_features)
            self.S = uniform_param(k, out_features)
        else:
            self.R = nn.Parameter(torch.ones(k, in_features))
            self.S = nn.Parameter(torch.ones(k, out_features))

        self.W = uniform_param(in_features, out_features)  # shared by every member
        self.B = uniform_param(k, out_features)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X: torch.Tensor):
        """
        Shapes:

        X: (batch_size, k, in_features)
        R: (k, in_features)
        W: (in_features, out_features)
        S: (k, out_features)
        B: (k, out_features)
        output: (batch_size, k, out_features)

        Formula:
        output = dropout(relu(((X * R) W) * S + B))
        """
        scaled_input = X * self.R
        projected = torch.einsum("nmi,ij->nmj", scaled_input, self.W)
        return self.dropout(self.relu(projected * self.S + self.B))

    def extra_repr(self):
        """Summary shown next to the layer name when the model is printed."""
        return "in_features={}, out_features={}".format(self.in_features, self.out_features)


class MLP_layer(nn.Module):
    """
    Plain fully-connected block: ``Linear -> ReLU -> Dropout``.

    ``nn.Linear`` acts on the last axis, so the block works both on
    (batch, in_features) and on (batch, k, in_features) inputs.

    Args:
        in_features: input width.
        out_features: output width.
        dropout_rate: dropout probability applied after the ReLU.
    """

    def __init__(self, in_features: int, out_features: int, dropout_rate=0.1):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features

        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X: torch.Tensor):
        """Apply the affine map, the ReLU, then dropout."""
        return self.dropout(self.relu(self.linear(X)))

    def extra_repr(self):
        """Summary shown next to the layer name when the model is printed."""
        return "in_features={}, out_features={}".format(self.in_features, self.out_features)


class MLPk_layer(nn.Module):
    """
    Ensemble of ``k`` fully independent linear layers (no weight sharing),
    each followed by ReLU and dropout.

    Args:
        in_features: input width of every member.
        out_features: output width of every member.
        k: number of ensemble members.
        dropout_rate: dropout probability applied after the ReLU.
    """

    def __init__(self, in_features: int, out_features: int, k=32, dropout_rate=0.1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Registered in the order W then B; both are drawn from U(-1, 1).
        per_member_shapes = (
            ("W", (k, in_features, out_features)),
            ("B", (k, out_features)),
        )
        for attr_name, shape in per_member_shapes:
            param = nn.Parameter(torch.zeros(shape))
            nn.init.uniform_(param, -1, 1)
            setattr(self, attr_name, param)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X: torch.Tensor):
        """
        Shapes:

        X: (batch_size, k, in_features)
        W: (k, in_features, out_features)
        B: (k, out_features)
        output: (batch_size, k, out_features)

        Formula:
        output = dropout(relu(X @ W + B)), one weight matrix per member
        """
        per_member = torch.einsum("nmi,mij->nmj", X, self.W)
        return self.dropout(self.relu(per_member + self.B))

    def extra_repr(self):
        """Summary shown next to the layer name when the model is printed."""
        return "in_features={}, out_features={}".format(self.in_features, self.out_features)


# ===== BACKBONES =====


class TabM_naive(nn.Module):
    """
    TabM-naive backbone: a stack of BatchEnsemble layers (``linear_BE``) whose
    R and S adapters are initialised randomly.

    Args:
        in_features: number of input features per ensemble member.
        hidden_sizes: output width of each successive layer (list or tuple).
        k: number of ensemble members.
        dropout_rate: dropout probability used in every layer.

    ``forward`` maps (batch, k, in_features) to (batch, k, hidden_sizes[-1]).
    """

    def __init__(self, in_features: int, hidden_sizes: list, k=32, dropout_rate=0.1):
        super().__init__()

        self.in_features = in_features
        self.hidden_sizes = hidden_sizes
        self.k = k

        # Unpacking (instead of `[in_features] + hidden_sizes`) also accepts tuples.
        layer_sizes = [in_features, *hidden_sizes]
        self.layers = nn.Sequential(*[
            linear_BE(n_in, n_out, k, dropout_rate)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ])

    def forward(self, X: torch.Tensor):
        return self.layers(X)


class TabM_mini(nn.Module):
    """
    TabM-mini backbone: a single per-member input scaling ``R`` followed by
    MLP layers that are shared by all ensemble members.

    Args:
        in_features: number of input features.
        hidden_sizes: output width of each successive layer (list or tuple).
        k: number of ensemble members (rows of ``R``).
        dropout_rate: dropout probability used in every layer.

    ``forward`` expects X shaped (batch, k, in_features), as produced by
    ``EnsembleModel``, and returns (batch, k, hidden_sizes[-1]).
    """

    def __init__(self, in_features: int, hidden_sizes: list, k=32, dropout_rate=0.1):
        super().__init__()

        self.k = k

        # Created before the layers; keep this order so seeded runs are reproducible.
        self.R = nn.Parameter(torch.randn(k, in_features))

        # Unpacking (instead of `[in_features] + hidden_sizes`) also accepts tuples.
        layer_sizes = [in_features, *hidden_sizes]
        self.layers = nn.Sequential(*[
            MLP_layer(n_in, n_out, dropout_rate)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ])

    def forward(self, X: torch.Tensor):
        # R (k, in_features) broadcasts over the batch axis; the MLP is shared.
        return self.layers(X * self.R)


class TabM(nn.Module):
    """
    TabM backbone: a stack of BatchEnsemble layers (``linear_BE``) whose R and
    S adapters start as all-ones (``initialize_to_1=True``).

    Args:
        in_features: number of input features per ensemble member.
        hidden_sizes: output width of each successive layer (list or tuple).
        k: number of ensemble members.
        dropout_rate: dropout probability used in every layer.

    ``forward`` maps (batch, k, in_features) to (batch, k, hidden_sizes[-1]).
    """

    def __init__(self, in_features: int, hidden_sizes: list, k=32, dropout_rate=0.1):
        super().__init__()

        self.k = k

        # Unpacking (instead of `[in_features] + hidden_sizes`) also accepts tuples.
        layer_sizes = [in_features, *hidden_sizes]
        self.layers = nn.Sequential(*[
            linear_BE(n_in, n_out, k, dropout_rate, initialize_to_1=True)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ])

    def forward(self, X: torch.Tensor):
        return self.layers(X)


class MLPk(nn.Module):
    """
    MLP^k backbone: ``k`` fully independent MLPs evaluated in parallel
    (a stack of ``MLPk_layer``).

    Args:
        in_features: number of input features per ensemble member.
        hidden_sizes: output width of each successive layer (list or tuple).
        k: number of ensemble members.
        dropout_rate: dropout probability used in every layer.

    ``forward`` maps (batch, k, in_features) to (batch, k, hidden_sizes[-1]).
    """

    def __init__(self, in_features: int, hidden_sizes: list, k=32, dropout_rate=0.1):
        super().__init__()

        # Unpacking (instead of `[in_features] + hidden_sizes`) also accepts tuples.
        layer_sizes = [in_features, *hidden_sizes]
        self.layers = nn.Sequential(*[
            MLPk_layer(n_in, n_out, k, dropout_rate)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ])

    def forward(self, X: torch.Tensor):
        return self.layers(X)

# ===== MODELS =====


class MLP(nn.Module):
    """
    Simple MLP model: hidden ``MLP_layer`` blocks (Linear -> ReLU -> Dropout)
    followed by a plain linear output layer.

    Args:
        in_features: input width.
        hidden_sizes: width of each hidden layer (list or tuple; may be empty).
        out_features: output width (logits or regression targets).
        dropout_rate: dropout probability used in the hidden layers.
    """

    def __init__(self, in_features: int, hidden_sizes: list, out_features: int, dropout_rate=0.1):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Hidden stack only: `out_features` must NOT be part of it. Otherwise the
        # last hidden->output step would go through ReLU (non-negative outputs)
        # and be followed by a spurious out->out linear layer.
        layer_sizes = [in_features, *hidden_sizes]
        hidden_layers = [
            MLP_layer(n_in, n_out, dropout_rate)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

        # Output head: affine map with no activation, so outputs can be negative.
        self.layers = nn.Sequential(*hidden_layers, nn.Linear(layer_sizes[-1], out_features))

    def forward(self, X: torch.Tensor):
        return self.layers(X)


class EnsembleModel(nn.Module):
    """
    Global ensemble model that:
    - takes batched input (batch, in_features);
    - clones it k times -> (batch, k, in_features);
    - passes it through a backbone (e.g. TabM, MLPk, ...) -> (batch, k, hidden_sizes[-1]);
    - applies one prediction head per member -> (batch, k, out_features);
    - if ``mean_over_heads`` is True, averages over the members -> (batch, out_features).

    Args:
        backbone: backbone *class* (not an instance); it is instantiated as
            ``backbone(in_features, hidden_sizes, k, dropout_rate)``.
        in_features: input width.
        hidden_sizes: backbone layer widths; the heads read ``hidden_sizes[-1]`` features.
        out_features: output width of each head.
        k: number of ensemble members.
        dropout_rate: forwarded to the backbone.
        mean_over_heads: average the member predictions (True) or return them all (False).
    """

    def __init__(self, backbone: type, in_features: int, hidden_sizes: list, out_features: int, k=32, dropout_rate=0.1, mean_over_heads=True):
        super().__init__()

        self.backbone = backbone(in_features, hidden_sizes, k, dropout_rate)
        self.in_features = in_features
        self.k = k

        self.mean_over_heads = mean_over_heads

        self.pred_heads = nn.ModuleList([nn.Linear(hidden_sizes[-1], out_features) for _ in range(k)])

    def forward(self, X: torch.Tensor):
        # Give every member its own copy of the input: (batch, k, in_features).
        X = X.unsqueeze(1).repeat(1, self.k, 1)

        X = self.backbone(X)

        # Head i only sees member i's features X[:, i].
        preds = torch.stack(
            [head(member) for head, member in zip(self.pred_heads, X.unbind(dim=1))],
            dim=1,
        )  # (batch, k, out_features)

        if self.mean_over_heads:
            preds = preds.mean(dim=1)

        return preds



## Pruning

In [None]:
import torch
import torch.nn as nn

class PrunableEnsembleModel(nn.Module):
    """
    Ensemble model with a pruning mechanism:
    - measures the individual loss of every ensemble member;
    - lets the worst-performing members be removed progressively.

    Unlike ``EnsembleModel``, ``forward`` does not average over members; it
    returns every member's prediction, shape (batch, k, out_features).

    Args:
        backbone: backbone *class*, instantiated as
            ``backbone(in_features, hidden_sizes, k, dropout_rate)``.
        in_features: input width.
        hidden_sizes: backbone layer widths; the heads read ``hidden_sizes[-1]`` features.
        out_features: output width of each head.
        k: initial number of ensemble members.
        dropout_rate: forwarded to the backbone.
    """

    def __init__(self, backbone: type, in_features: int, hidden_sizes: list, out_features: int, k=32, dropout_rate=0.1):
        super().__init__()
        self.backbone = backbone(in_features, hidden_sizes, k, dropout_rate)
        self.in_features = in_features
        self.hidden_sizes = hidden_sizes
        self.k = k
        self.out_features = out_features

        # One prediction head per member; outputs are NOT averaged here.
        self.pred_heads = nn.ModuleList([nn.Linear(hidden_sizes[-1], out_features) for _ in range(k)])

    def forward(self, X: torch.Tensor):
        """
        Return the concatenated predictions of every member.

        X: (batch, in_features) -> output: (batch, k, out_features)
        """
        # Give every member its own copy of the input.
        X = X.unsqueeze(1).repeat(1, self.k, 1)

        X = self.backbone(X)  # (batch, k, hidden_sizes[-1])

        preds = [head(X[:, i]) for i, head in enumerate(self.pred_heads)]
        return torch.stack(preds, dim=1)  # (batch, k, out_features)

    def _member_losses(self, preds: torch.Tensor, y: torch.Tensor, task_type: str) -> torch.Tensor:
        """Mean loss of each member over the batch, as a (k,) tensor."""
        batch = preds.shape[0]
        if task_type == "binary":
            target = y.float().reshape(batch, 1, -1).expand_as(preds)
            loss = nn.functional.binary_cross_entropy_with_logits(preds, target, reduction="none")
            return loss.mean(dim=(0, 2))
        if task_type == "multiclass":
            # cross_entropy wants classes on axis 1: (batch, C, k) vs targets (batch, k).
            target = y.reshape(-1, 1).expand(-1, preds.shape[1])
            loss = nn.functional.cross_entropy(preds.permute(0, 2, 1), target, reduction="none")
            return loss.mean(dim=0)
        # regression
        target = y.to(preds.dtype).reshape(batch, 1, -1).expand_as(preds)
        loss = nn.functional.mse_loss(preds, target, reduction="none")
        return loss.mean(dim=(0, 2))

    def _prune_backbone(self, keep_indices: list, old_k: int) -> None:
        """Keep (and re-order) only ``keep_indices`` in every per-member backbone parameter."""
        with torch.no_grad():
            for module in self.backbone.modules():
                for name in ("R", "S", "B", "W"):
                    param = getattr(module, name, None)
                    if not isinstance(param, nn.Parameter) or param.dim() == 0 or param.shape[0] != old_k:
                        continue
                    if name == "W" and param.dim() != 3:
                        continue  # a 2-D W (linear_BE) is shared by all members
                    setattr(module, name, nn.Parameter(param[keep_indices]))
                if isinstance(getattr(module, "k", None), int):
                    module.k = len(keep_indices)

    def prune(self, X: torch.Tensor, y: torch.Tensor, keep_ratio=0.5, task_type="binary"):
        """
        Remove the worst ensemble members.

        Each member's mean loss on ``(X, y)`` is computed in evaluation mode
        (dropout disabled). The previous train/eval mode is restored afterwards.
        The ``max(1, int(k * keep_ratio))`` members with the lowest loss are
        kept. Here ``k`` is the *current* number of members, so repeated calls
        compound. Kept heads and per-member backbone parameters are re-ordered
        by ascending loss.

        Per-member backbone parameters are found by name on every backbone
        submodule: ``R``, ``S``, ``B`` and 3-D ``W`` tensors whose first axis
        has length ``k``. A 2-D ``W`` (as in ``linear_BE``) is shared and left
        untouched. Integer ``k`` attributes of backbone submodules are updated.

        Pruned parameters are new ``nn.Parameter`` objects, so an optimizer
        built before pruning must be rebuilt.

        Args:
            X: input batch passed to ``forward``.
            y: targets. Class indices for 'multiclass', 0/1 labels for 'binary',
               values for 'regression'.
            keep_ratio: fraction in (0, 1] of the current members to keep.
            task_type: 'binary', 'multiclass' or 'regression'.

        Raises:
            ValueError: if ``keep_ratio`` is outside (0, 1] or ``task_type`` is unknown.
        """
        if not (0 < keep_ratio <= 1):
            raise ValueError("keep_ratio doit être un float entre 0 et 1.")
        if task_type not in ("binary", "multiclass", "regression"):
            raise ValueError("task_type must be one of 'binary', 'multiclass', or 'regression'.")

        # Rank members on deterministic predictions: dropout must be off.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                losses = self._member_losses(self(X), y, task_type).tolist()
        finally:
            self.train(was_training)

        old_k = self.k
        keep_count = max(1, int(old_k * keep_ratio))
        # Stable sort by ascending loss (ties keep the lower index first).
        keep_indices = sorted(range(old_k), key=losses.__getitem__)[:keep_count]

        self.pred_heads = nn.ModuleList([self.pred_heads[i] for i in keep_indices])
        self.k = keep_count
        self._prune_backbone(keep_indices, old_k)


In [None]:
def test_pruning(model, train_loader, keep_ratios, task_type="binary"):
    """
    Test le mécanisme de pruning.
    - Réduit progressivement les sous-modèles.
    - Mesure la précision globale après chaque étape.

    Args:
        model : Modèle prunable.
        train_loader : DataLoader pour les données de test.
        keep_ratios : Liste des ratios de sous-modèles à conserver.
        task_type : Type de tâche ('binary', 'multiclass', ou 'regression').
    """

    for keep_ratio in keep_ratios:
        print(f"\nPruning avec keep_ratio = {keep_ratio}")

        # Pruning (utiliser un batch pour sélectionner les sous-modèles)
        for X_batch, y_batch in train_loader:
            model.prune(X_batch, y_batch, keep_ratio, task_type=task_type)
            break

        # Évaluation des performances après pruning
        total_correct = 0
        total_samples = 0
        total_loss = 0.0

        with torch.no_grad():
            for X_batch, y_batch in train_loader:
                preds = model.forward(X_batch)  # Prédictions des sous-modèles restants

                if task_type == "binary":
                    preds = preds.mean(dim=1)  # Moyenne sur les sous-modèles
                    preds_binary = torch.sigmoid(preds)
                    preds_binary = (preds_binary > 0.5)
                    total_correct += (preds_binary == y_batch.reshape(-1,1)).sum().item()

                elif task_type == "multiclass":
                    preds = preds.mean(dim=1)  # Moyenne sur les sous-modèles
                    preds_class = preds.argmax(dim=1)  # Trouver la classe avec la probabilité la plus élevée
                    total_correct += (preds_class == y_batch).sum().item()
                elif task_type == "regression":
                    preds = preds.mean(dim=1)  # Moyenne sur les sous-modèles
                    total_loss += nn.functional.mse_loss(preds, y_batch).item()

                total_samples += y_batch.size(0)

        # Calcul des métriques
        if task_type in ["binary", "multiclass"]:
            accuracy = total_correct / total_samples
            print(f"Précision globale après pruning : {accuracy:.4f}")
        elif task_type == "regression":
            avg_loss = total_loss / total_samples
            print(f"Loss moyenne après pruning : {avg_loss:.4f}")



In [None]:
# import torch
# import torch.nn as nn
# import time

# def train_with_pruning(model, train_loader, keep_ratios, task_type="binary", epochs=10, learning_rate=0.001, dataset_name=""):
#     """
#     Entraîne un modèle avec pruning à la fin de chaque epoch.
#     - Réduit progressivement les sous-modèles après chaque epoch.
#     - Effectue une sélection des modèles basée sur les performances après chaque epoch.

#     Args:
#         model : Modèle prunable.
#         train_loader : DataLoader pour les données d'entraînement.
#         keep_ratios : Liste des ratios de sous-modèles à conserver pour chaque epoch.
#         task_type : Type de tâche ('binary', 'multiclass', ou 'regression').
#         epochs : Nombre d'epochs pour l'entraînement.
#         learning_rate : Taux d'apprentissage pour l'optimiseur.
#     """

#     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

#     # Variables pour le suivi du meilleur modèle
#     best_accuracy = 0
#     best_model_state = None
#     best_keep_ratio = None
#     best_epoch = None
#     start_time = time.time()  # Début du temps d'entraînement

#     for epoch in range(epochs):
#         model.train()

#         total_correct = 0
#         total_samples = 0
#         total_loss = 0.0

#         with torch.no_grad():
#             for X_batch, y_batch in train_loader:
#                 preds = model.forward(X_batch)  # Prédictions des sous-modèles restants

#                 if task_type == "binary":
#                     preds = preds.mean(dim=1)  # Moyenne sur les sous-modèles
#                     preds_binary = torch.sigmoid(preds)
#                     preds_binary = (preds_binary > 0.5)
#                     total_correct += (preds_binary == y_batch.reshape(-1, 1)).sum().item()

#                 elif task_type == "multiclass":
#                     preds = preds.mean(dim=1)  # Moyenne sur les sous-modèles
#                     preds_class = preds.argmax(dim=1)  # Trouver la classe avec la probabilité la plus élevée
#                     total_correct += (preds_class == y_batch).sum().item()
#                 elif task_type == "regression":
#                     preds = preds.mean(dim=1)  # Moyenne sur les sous-modèles
#                     total_loss += nn.functional.mse_loss(preds, y_batch).item()

#                 total_samples += y_batch.size(0)

#         # Calcul de la précision ou de la perte moyenne après chaque epoch
#         if task_type in ["binary", "multiclass"]:
#             accuracy = total_correct / total_samples
#             # print(f"Epoch {epoch+1}/{epochs}, Accuracy: {accuracy:.4f}")
#             if accuracy > best_accuracy:
#                 best_accuracy = accuracy
#                 best_model_state = model.state_dict()
#                 # Sauvegarder le keep_ratio correspondant à la meilleure performance
#                 best_keep_ratio = keep_ratios[0]  # initialisation avec un keep_ratio
#                 best_epoch = epoch + 1
#         elif task_type == "regression":
#             avg_loss = total_loss / total_samples
#             # print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

#         # Appliquer le pruning après chaque epoch avec un seul keep_ratio
#         for keep_ratio in keep_ratios:
#             # Appliquer pruning avec ce ratio
#             model.prune(X_batch, y_batch, keep_ratio, task_type=task_type)

#             # Vérifier la performance après pruning avec ce ratio
#             if task_type in ["binary", "multiclass"]:
#                 accuracy = total_correct / total_samples
#                 if accuracy > best_accuracy:
#                     best_accuracy = accuracy
#                     best_keep_ratio = keep_ratio  # Sauvegarder le keep_ratio pour lequel la performance est la meilleure

#     total_time = time.time() - start_time

#     # Sauvegarder le meilleur modèle
#     if best_model_state is not None:
#         model.load_state_dict(best_model_state)
#         print(f"\nOptimal model obtained with keep_ratio = {best_keep_ratio}, achieving an accuracy of {best_accuracy:.4f} in epoch {best_epoch}.")
#         print(f"Training time: {total_time:.2f} seconds.")

#         torch.save(model.state_dict(), "best_pruned_model_" + dataset_name + ".pth")


## Training

### Dataset "Breast cancer"

In [None]:
# Load the breast-cancer data and split it 80/20 with a fixed seed
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=42)

# Standardise with statistics from the training split only, and apply the same
# transform to the test split so both splits are on the same scale.
scaler = StandardScaler()
X_train = torch.tensor(scaler.fit_transform(X_train), dtype=torch.float32)
X_test = torch.tensor(scaler.transform(X_test), dtype=torch.float32)
# Targets as (n, 1) float tensors, matching one logit per sample for BCEWithLogitsLoss
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

# DataLoader
train_loader = DataLoader(list(zip(X_train, y_train)), batch_size=25, shuffle=True)

# Model: TabM backbone, three hidden layers, 32 ensemble members
layers = [64, 32, 16]
prunable_model = PrunableEnsembleModel(TabM, in_features=X_train.shape[1], hidden_sizes=layers, out_features=1, k=32)

# Pruning test: progressive reduction of the number of members
keep_ratios = [1, 0.75, 0.5, 0.25, 0.1]
test_pruning(prunable_model, train_loader, keep_ratios)



Pruning avec keep_ratio = 1
Précision globale après pruning : 0.4549

Pruning avec keep_ratio = 0.75
Précision globale après pruning : 0.6264

Pruning avec keep_ratio = 0.5
Précision globale après pruning : 0.6396

Pruning avec keep_ratio = 0.25
Précision globale après pruning : 0.8198

Pruning avec keep_ratio = 0.1
Précision globale après pruning : 0.6967


In [None]:
#

Garder le mode des meilleurs ratios à chaque epoch.

À chaque epoch, tester avec un ratio plus faible et un ratio plus élevé.

In [None]:
# # Entrainement sur plusieurs epochs
# train_with_pruning(prunable_model, train_loader, keep_ratios, epochs=30, learning_rate=0.001, dataset_name="cancer")


Optimal model obtained with keep_ratio = 1, achieving an accuracy of 0.7385 in epoch 28.
Training time: 0.85 seconds.


### Dataset "Adult income"

In [None]:
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import LabelEncoder

import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def get_adult_income_data(split=0.2, batch_size=32, seed=42, path="adult.csv"):
    """
    Load and prepare the data for the `adult income` classification task.

    - One-hot encodes the feature columns with ``pd.get_dummies``.
    - Label-encodes the ``income`` target column.
    - Splits train/test *before* scaling and fits the ``StandardScaler`` on the
      training split only, so test-set statistics do not leak into training.

    Args:
        split: fraction of rows used for the test split.
        batch_size: batch size of both loaders.
        seed: ``random_state`` of the train/test split.
        path: CSV file to read; it must contain an ``income`` column.

    Returns:
        (train_loader, test_loader, shape_x, shape_y). ``shape_x`` is the number
        of encoded features and ``shape_y`` is 1. The train loader shuffles;
        the test loader does not.
    """
    data = pd.read_csv(path)
    X = pd.get_dummies(data.drop(columns="income"), dtype=float)
    y = LabelEncoder().fit_transform(data["income"])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=split, random_state=seed)

    # Fit on the training rows only, then reuse the same statistics for the test rows.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    def make_loader(features, labels, shuffle):
        dataset = torch.utils.data.TensorDataset(
            torch.tensor(features).float(), torch.tensor(labels).long()
        )
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    train_loader = make_loader(X_train, y_train, shuffle=True)
    test_loader = make_loader(X_test, y_test, shuffle=False)

    shape_x = X_train.shape[1]
    shape_y = y_train.reshape(-1, 1).shape[1]
    return train_loader, test_loader, shape_x, shape_y

In [None]:
# Données
BATCH_SIZE = 32
train_loader, test_loader, shape_x, shape_y = get_adult_income_data(split=.2, batch_size=BATCH_SIZE, seed=42)

# Modèle
layers = [64, 32, 16]
prunable_model = PrunableEnsembleModel(TabM, in_features=shape_x, hidden_sizes=layers, out_features=1, k=32)

# Test du pruning
keep_ratios = [1, 0.75, 0.5, 0.25, 0.1]  # Réduction progressive du nombre de sous modèles
test_pruning(prunable_model, train_loader, keep_ratios, task_type="binary")



Pruning avec keep_ratio = 1
Précision globale après pruning : 0.5129

Pruning avec keep_ratio = 0.75
Précision globale après pruning : 0.5854

Pruning avec keep_ratio = 0.5
Précision globale après pruning : 0.7401

Pruning avec keep_ratio = 0.25
Précision globale après pruning : 0.6917

Pruning avec keep_ratio = 0.1
Précision globale après pruning : 0.6256


In [None]:
# # Entrainement sur plusieurs epochs
# train_with_pruning(prunable_model, train_loader, keep_ratios, epochs=15, learning_rate=0.001, dataset_name="adult")


Optimal model obtained with keep_ratio = 1, achieving an accuracy of 0.6302 in epoch 4.
Training time: 23.70 seconds.
