## TabM

In [None]:

import torch
import torch.nn as nn

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


# ===== LAYERS =====

class linear_BE(nn.Module):
    """
    BatchEnsemble linear layer shared by ``k`` ensemble members.

    All members share one weight matrix ``W`` but each owns its own input scaling
    ``R``, output scaling ``S`` and bias ``B``. The affine result goes through a
    ReLU and then dropout.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        k: number of ensemble members.
        dropout_rate: dropout probability applied after the ReLU.
        initialize_to_1: if True (TabM setting) R and S start as all-ones;
            otherwise they are drawn from uniform(-1, 1).
    """

    def __init__(self, in_features: int, out_features: int, k=32, dropout_rate=0.1, initialize_to_1=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.k = k

        def uniform_parameter(*shape):
            # Allocate with zeros, then fill in place from uniform(-1, 1).
            param = nn.Parameter(torch.zeros(shape))
            nn.init.uniform_(param, -1, 1)
            return param

        if not initialize_to_1:
            # NOTE(review): the TabM paper is described as drawing R and S as random +-1
            # signs, but this draws them from uniform(-1, 1) — confirm which is intended.
            self.R = uniform_parameter(k, in_features)
            self.S = uniform_parameter(k, out_features)
        else:
            self.R = nn.Parameter(torch.ones(k, in_features))
            self.S = nn.Parameter(torch.ones(k, out_features))

        # W is shared by every member; B holds one bias row per member.
        self.W = uniform_parameter(in_features, out_features)
        self.B = uniform_parameter(k, out_features)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X: torch.Tensor):
        """
        Apply the layer.

        Shapes:
            X:      (batch_size, k, in_features)
            R:      (k, in_features)
            W:      (in_features, out_features)
            S:      (k, out_features)
            B:      (k, out_features)
            output: (batch_size, k, out_features)

        Formula:
            output = dropout(relu(((X * R) @ W) * S + B))
        """
        projected = torch.einsum("bki,io->bko", X * self.R, self.W)
        return self.dropout(self.relu(projected * self.S + self.B))

    def extra_repr(self):
        """Extra text included in the module's repr (e.g. when printing the model)."""
        return f"in_features={self.in_features}, out_features={self.out_features}"


class MLP_layer(nn.Module):
    """
    Plain fully-connected block: Linear -> ReLU -> Dropout.

    Args:
        in_features: input size (last dimension).
        out_features: output size (last dimension).
        dropout_rate: dropout probability applied after the ReLU.
    """

    def __init__(self, in_features: int, out_features: int, dropout_rate=0.1):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features

        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X: torch.Tensor):
        """Map (..., in_features) -> (..., out_features); nn.Linear acts on the last dimension."""
        return self.dropout(self.relu(self.linear(X)))

    def extra_repr(self):
        """Extra text included in the module's repr."""
        return f"in_features={self.in_features}, out_features={self.out_features}"


class MLPk_layer(nn.Module):
    """
    k independent fully-connected layers evaluated in parallel (one per ensemble
    member), followed by ReLU and dropout.

    Args:
        in_features: input size (last dimension).
        out_features: output size (last dimension).
        k: number of ensemble members.
        dropout_rate: dropout probability applied after the ReLU.
    """

    def __init__(self, in_features: int, out_features: int, k=32, dropout_rate=0.1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Per-member weights (k, in, out) and biases (k, out), drawn from uniform(-1, 1).
        for name, shape in (("W", (k, in_features, out_features)), ("B", (k, out_features))):
            param = nn.Parameter(torch.zeros(shape))
            nn.init.uniform_(param, -1, 1)
            setattr(self, name, param)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X: torch.Tensor):
        """
        Apply every member's own affine map to its slice of the input.

        Shapes:
            X:      (batch_size, k, in_features)
            W:      (k, in_features, out_features)
            B:      (k, out_features)
            output: (batch_size, k, out_features)

        Formula:
            output = dropout(relu(X @ W + B))
        """
        affine = torch.einsum("bki,kio->bko", X, self.W) + self.B
        return self.dropout(self.relu(affine))

    def extra_repr(self):
        """Extra text included in the module's repr."""
        return f"in_features={self.in_features}, out_features={self.out_features}"


# ===== BACKBONES =====


class TabM_naive(nn.Module):
    """
    Backbone built from ``linear_BE`` layers whose R/S adapters are randomly
    initialised (``initialize_to_1=False``).

    Args:
        in_features: number of input features.
        hidden_sizes: sequence with the output size of each successive layer.
        k: number of ensemble members.
        dropout_rate: dropout probability inside each layer.

    Input (batch, k, in_features) -> output (batch, k, hidden_sizes[-1]).
    """

    def __init__(self, in_features: int, hidden_sizes: list, k=32, dropout_rate=0.1):
        super().__init__()

        self.in_features = in_features
        self.hidden_sizes = hidden_sizes
        self.k = k

        # list(...) so tuples are accepted as well as lists
        layer_sizes = [in_features] + list(hidden_sizes)

        # One linear_BE per consecutive pair of sizes
        layers = [
            linear_BE(n_in, n_out, k, dropout_rate)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, X: torch.Tensor):
        return self.layers(X)


class TabM_mini(nn.Module):
    """
    TabM-mini backbone: one per-member multiplicative input adapter ``R`` followed by
    a stack of ``MLP_layer`` blocks whose weights are shared by all k members.

    Args:
        in_features: number of input features.
        hidden_sizes: sequence with the output size of each successive layer.
        k: number of ensemble members.
        dropout_rate: dropout probability inside each layer.

    Input (batch, k, in_features) -> output (batch, k, hidden_sizes[-1]).

    NOTE(review): R is drawn from a standard normal (torch.randn); the TabM paper is
    understood to initialise this adapter with random +-1 signs — confirm which is intended.
    """

    def __init__(self, in_features: int, hidden_sizes: list, k=32, dropout_rate=0.1):
        super().__init__()

        self.k = k

        # Per-member scaling of the input features
        self.R = nn.Parameter(torch.randn(k, in_features))

        # list(...) so tuples are accepted as well as lists
        layer_sizes = [in_features] + list(hidden_sizes)

        layers = [
            MLP_layer(n_in, n_out, dropout_rate)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, X: torch.Tensor):
        # (batch, k, in_features) * (k, in_features) broadcasts one adapter per member
        return self.layers(X * self.R)


class TabM(nn.Module):
    """
    TabM backbone: a stack of ``linear_BE`` layers with R and S initialised to ones.

    With R = S = 1, members differ at initialisation only through their per-member
    biases B (and dropout). NOTE(review): confirm this matches the intended TabM init.

    Args:
        in_features: number of input features.
        hidden_sizes: sequence with the output size of each successive layer.
        k: number of ensemble members.
        dropout_rate: dropout probability inside each layer.

    Input (batch, k, in_features) -> output (batch, k, hidden_sizes[-1]).
    """

    def __init__(self, in_features: int, hidden_sizes: list, k=32, dropout_rate=0.1):
        super().__init__()

        self.k = k

        # list(...) so tuples are accepted as well as lists
        layer_sizes = [in_features] + list(hidden_sizes)

        layers = [
            linear_BE(n_in, n_out, k, dropout_rate, initialize_to_1=True)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, X: torch.Tensor):
        return self.layers(X)


class MLPk(nn.Module):
    """
    Backbone of k fully independent MLPs (one per member), built from ``MLPk_layer``.

    Args:
        in_features: number of input features.
        hidden_sizes: sequence with the output size of each successive layer.
        k: number of ensemble members.
        dropout_rate: dropout probability inside each layer.

    Input (batch, k, in_features) -> output (batch, k, hidden_sizes[-1]).
    """

    def __init__(self, in_features: int, hidden_sizes: list, k=32, dropout_rate=0.1):
        super().__init__()

        # list(...) so tuples are accepted as well as lists
        layer_sizes = [in_features] + list(hidden_sizes)

        layers = [
            MLPk_layer(n_in, n_out, k, dropout_rate)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, X: torch.Tensor):
        return self.layers(X)

# ===== MODELS =====


class MLP(nn.Module):
    """
    Simple MLP model.

    Hidden layers are ``MLP_layer`` blocks (Linear -> ReLU -> Dropout). The output
    layer is a plain ``nn.Linear`` with no activation, so predictions can take any
    real value (e.g. logits for BCEWithLogitsLoss).

    Args:
        in_features: number of input features.
        hidden_sizes: sequence of hidden layer sizes (may be empty).
        out_features: output size.
        dropout_rate: dropout probability of the hidden layers.
    """

    def __init__(self, in_features: int, hidden_sizes: list, out_features: int, dropout_rate=0.1):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        layer_sizes = [in_features] + list(hidden_sizes)
        hidden_layers = [
            MLP_layer(n_in, n_out, dropout_rate)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        # The output projection reads the last hidden size directly. out_features must
        # not be appended to layer_sizes: that would route the output through an extra
        # ReLU/Dropout block (clamping negative logits) before a redundant out->out Linear.
        self.layers = nn.Sequential(*hidden_layers, nn.Linear(layer_sizes[-1], out_features))

    def forward(self, X: torch.Tensor):
        """Map (..., in_features) -> (..., out_features)."""
        return self.layers(X)


class EnsembleModel(nn.Module):
    """
    Global ensemble model that:
    - takes batched input (batch, in_features)
    - replicates it k times -> (batch, k, in_features)
    - passes it through a backbone (e.g. TabM, MLPk, ...) -> (batch, k, hidden_sizes[-1])
    - applies one prediction head per member -> (batch, k, out_features)
    - averages over the heads when ``mean_over_heads`` -> (batch, out_features)

    Args:
        backbone: backbone CLASS (not an instance); it is instantiated as
            ``backbone(in_features, hidden_sizes, k, dropout_rate)``.
        in_features: number of input features.
        hidden_sizes: sequence of hidden sizes; the heads read ``hidden_sizes[-1]``.
        out_features: output size of each head.
        k: number of ensemble members.
        dropout_rate: forwarded to the backbone.
        mean_over_heads: if False, return the per-member predictions.
    """

    def __init__(self, backbone: type, in_features: int, hidden_sizes: list, out_features: int, k=32, dropout_rate=0.1, mean_over_heads=True):
        super().__init__()

        self.backbone = backbone(in_features, hidden_sizes, k, dropout_rate)
        self.in_features = in_features
        self.k = k

        self.mean_over_heads = mean_over_heads

        # One independent linear head per member
        self.pred_heads = nn.ModuleList([nn.Linear(hidden_sizes[-1], out_features) for _ in range(k)])

    def forward(self, X: torch.Tensor):
        # Replicate X to (batch, k, in_features): every member sees the same input
        X = X.unsqueeze(1).repeat(1, self.k, 1)

        hidden = self.backbone(X)  # (batch, k, hidden_sizes[-1])

        # Member i's head reads member i's hidden representation
        preds = torch.stack([head(hidden[:, i]) for i, head in enumerate(self.pred_heads)], dim=1)

        return preds.mean(dim=1) if self.mean_over_heads else preds



## Pruning

In [None]:
import torch
import torch.nn as nn

class PrunableEnsembleModel(nn.Module):
    """
    Ensemble model with a pruning mechanism.

    - ``forward`` returns the stacked prediction of every sub-model (no averaging,
      unlike ``EnsembleModel``): shape (batch, k, out_features).
    - ``prune`` measures each sub-model's loss on a batch and removes the worst ones,
      so sub-models can be dropped progressively.

    Args:
        backbone: backbone CLASS, instantiated as
            ``backbone(in_features, hidden_sizes, k, dropout_rate)``.
        in_features: number of input features.
        hidden_sizes: sequence of hidden sizes; the heads read ``hidden_sizes[-1]``.
        out_features: output size of each prediction head.
        k: initial number of sub-models.
        dropout_rate: forwarded to the backbone.
    """

    # Backbone parameter names that hold one slice per sub-model along dim 0
    # (linear_BE: R, S, B; MLPk_layer: B; TabM_mini backbone: R). A 3-D "W"
    # (MLPk_layer) is also per-member; a 2-D "W" (linear_BE) is shared.
    _MEMBER_PARAM_NAMES = ("R", "S", "B")

    def __init__(self, backbone: type, in_features: int, hidden_sizes: list, out_features: int, k=32, dropout_rate=0.1):
        super().__init__()
        self.backbone = backbone(in_features, hidden_sizes, k, dropout_rate)
        self.in_features = in_features
        self.hidden_sizes = hidden_sizes
        self.k = k
        self.out_features = out_features

        # One prediction head per sub-model (outputs are not averaged here)
        self.pred_heads = nn.ModuleList([nn.Linear(hidden_sizes[-1], out_features) for _ in range(k)])

    def forward(self, X: torch.Tensor):
        """
        Return every sub-model's prediction, stacked along dim 1.

        X: (batch, in_features) -> output: (batch, k, out_features)
        """
        # Replicate X once per sub-model
        X = X.unsqueeze(1).repeat(1, self.k, 1)

        X = self.backbone(X)  # (batch, k, hidden_sizes[-1])

        preds = [head(X[:, i]) for i, head in enumerate(self.pred_heads)]
        return torch.stack(preds, dim=1)  # (batch, k, out_features)

    def prune(self, X: torch.Tensor, y: torch.Tensor, keep_ratio=0.5, task_type="binary"):
        """
        Keep only the best-performing sub-models (in place).

        Each sub-model is scored by its mean loss on ``(X, y)``. The
        ``max(1, int(k * keep_ratio))`` sub-models with the lowest loss are kept,
        together with their prediction heads and their slices of the backbone's
        per-member parameters. Dropout is disabled while scoring, and the module's
        train/eval mode is restored afterwards.

        Args:
            X: input batch, (batch, in_features).
            y: targets — binary labels, class indices (multiclass) or values (regression).
            keep_ratio: fraction of the current sub-models to keep, in (0, 1].
            task_type: 'binary', 'multiclass' or 'regression'.

        Raises:
            ValueError: if keep_ratio is outside (0, 1] or task_type is unknown.

        Note:
            Pruned parameters are replaced by new ``nn.Parameter`` objects, so any
            optimizer built on the old parameters must be re-created afterwards.
        """
        if not (0 < keep_ratio <= 1):
            raise ValueError("keep_ratio doit être un float entre 0 et 1.")

        if task_type == "binary":
            criterion = nn.BCEWithLogitsLoss(reduction="none")
            y = y.float().view(-1, 1)
        elif task_type == "multiclass":
            criterion = nn.CrossEntropyLoss(reduction="none")
        elif task_type == "regression":
            criterion = nn.MSELoss(reduction="none")
        else:
            raise ValueError("task_type must be one of 'binary', 'multiclass', or 'regression'.")

        old_k = self.k
        was_training = self.training
        self.eval()  # rank sub-models without dropout noise
        try:
            with torch.no_grad():
                preds = self.forward(X)  # (batch, k, out_features)
                losses = []
                for i in range(old_k):
                    member_preds = preds[:, i, :]
                    if task_type == "multiclass":
                        loss = criterion(member_preds, y)
                    else:
                        loss = criterion(member_preds.reshape(-1, 1), y.reshape(-1, 1))
                    losses.append(loss.mean().item())  # mean loss of sub-model i
        finally:
            self.train(was_training)

        keep_count = max(1, int(old_k * keep_ratio))
        # Sub-model indices sorted by increasing loss; keep the best ones
        keep_indices = sorted(range(old_k), key=losses.__getitem__)[:keep_count]

        with torch.no_grad():
            self.pred_heads = nn.ModuleList([self.pred_heads[i] for i in keep_indices])
            self._prune_backbone(keep_indices, old_k)
        self.k = keep_count

    def _prune_backbone(self, keep_indices, old_k):
        """Slice every per-member backbone parameter to ``keep_indices`` and update ``k`` attributes."""
        for module in self.backbone.modules():
            for name, param in list(module.named_parameters(recurse=False)):
                if self._is_member_param(name, param, old_k):
                    sliced = param.detach()[keep_indices]  # advanced indexing copies
                    setattr(module, name, nn.Parameter(sliced, requires_grad=param.requires_grad))
            # Keep bookkeeping attributes (e.g. linear_BE.k, TabM.k) consistent with the new size
            if isinstance(getattr(module, "k", None), int) and module.k == old_k:
                module.k = len(keep_indices)

    @classmethod
    def _is_member_param(cls, name, param, old_k):
        """True if ``param`` stores one slice per sub-model along dim 0."""
        if param.dim() == 0 or param.shape[0] != old_k:
            return False
        return name in cls._MEMBER_PARAM_NAMES or (name == "W" and param.dim() == 3)


In [None]:
def test_pruning(model, train_loader, keep_ratios, task_type="binary"):
    """
    Exercise the pruning mechanism and report ensemble quality after each step.

    ``model`` is pruned in place and pruning is cumulative: each ``keep_ratio`` is
    relative to the sub-models left by the previous step. Sub-models are ranked on
    the first batch of ``train_loader``, and the (averaged) ensemble is then
    evaluated on all of ``train_loader``. Evaluation runs in eval mode (no dropout);
    the model's original train/eval mode is restored at the end.

    Args:
        model: prunable model (exposes ``prune`` and returns (batch, k, out_features)).
        train_loader: DataLoader used both to rank sub-models and to evaluate.
        keep_ratios: sequence of ratios of sub-models to keep.
        task_type: 'binary', 'multiclass' or 'regression'.
    """
    was_training = model.training

    for keep_ratio in keep_ratios:
        print(f"\nPruning avec keep_ratio = {keep_ratio}")

        # Deterministic scoring/evaluation: disable dropout
        model.eval()

        # Rank sub-models on a single batch
        first_batch = next(iter(train_loader), None)
        if first_batch is not None:
            X_prune, y_prune = first_batch
            model.prune(X_prune, y_prune, keep_ratio, task_type=task_type)

        total_correct = 0
        total_samples = 0
        total_loss = 0.0

        with torch.no_grad():
            for X_batch, y_batch in train_loader:
                preds = model(X_batch).mean(dim=1)  # average over the remaining sub-models

                if task_type == "binary":
                    preds_binary = torch.sigmoid(preds) > 0.5
                    total_correct += (preds_binary == y_batch.reshape(-1, 1)).sum().item()
                elif task_type == "multiclass":
                    preds_class = preds.argmax(dim=1)  # class with the highest logit
                    total_correct += (preds_class == y_batch).sum().item()
                elif task_type == "regression":
                    # Match the prediction shape to avoid (B,1) vs (B,) broadcasting,
                    # and weight by batch size so the final average is per sample.
                    target = y_batch.float().reshape(preds.shape)
                    total_loss += nn.functional.mse_loss(preds, target).item() * y_batch.size(0)

                total_samples += y_batch.size(0)

        if task_type in ["binary", "multiclass"]:
            accuracy = total_correct / total_samples
            print(f"Précision globale après pruning : {accuracy:.4f}")
        elif task_type == "regression":
            avg_loss = total_loss / total_samples
            print(f"Loss moyenne après pruning : {avg_loss:.4f}")

    model.train(was_training)


# The name starts with "test_": stop pytest from collecting this helper as a test.
test_pruning.__test__ = False



In [None]:
import copy  # Import copy module for deep copying

def _ensemble_loss(outputs, y_batch, criterion, task_type):
    """Average sub-model outputs over dim 1 and return ``(mean_preds, loss)`` for ``task_type``."""
    preds = outputs.mean(dim=1)  # (batch, k, out_features) -> (batch, out_features)
    if task_type == "binary":
        target = y_batch.float().reshape(-1, 1)
    elif task_type == "multiclass":
        # The loss needs the logits, not argmax'ed classes (argmax is not differentiable).
        target = y_batch.long()
    elif task_type == "regression":
        target = y_batch.float().reshape(preds.shape)
    else:
        raise ValueError("task_type must be one of 'binary', 'multiclass', or 'regression'.")
    return preds, criterion(preds, target)


def _evaluate_ensemble(model, loader, criterion, task_type):
    """Return ``(mean batch loss, accuracy)`` of ``model`` on ``loader``; accuracy is None for regression."""
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            preds, loss = _ensemble_loss(model(X_batch), y_batch, criterion, task_type)
            total_loss += loss.item()
            total_samples += y_batch.size(0)
            if task_type == "binary":
                predictions = (torch.sigmoid(preds) >= 0.5).float()
                total_correct += (predictions == y_batch.reshape(-1, 1)).sum().item()
            elif task_type == "multiclass":
                total_correct += (preds.argmax(dim=1) == y_batch).sum().item()
    avg_loss = total_loss / len(loader)
    accuracy = None if task_type == "regression" else total_correct / total_samples
    return avg_loss, accuracy


def _rebuild_optimizer(optimizer, model):
    """Create a new optimizer of the same class and hyper-parameters bound to ``model``'s parameters."""
    hyperparams = {
        key: value for key, value in optimizer.param_groups[0].items() if key in optimizer.defaults
    }
    return type(optimizer)(model.parameters(), **hyperparams)


def train_with_pruning(
    model, train_loader, test_loader, optimizer, criterion,
    epochs=10, keep_ratios=(1.0, 0.75, 0.5, 0.25), filepath="best_model.pth",
    dataset_name="", task_type="binary"
):
    """
    Train a prunable ensemble and select the best pruning level after every epoch.

    After each epoch, a copy of the model is pruned with every ratio in
    ``keep_ratios``; sub-models are ranked on the last training batch, with dropout
    off. Each pruned copy is evaluated on ``test_loader``. The copy with the lowest
    test loss seen so far becomes the best model. The next epoch continues from a
    fresh copy of it, with a new optimizer of the same type and hyper-parameters:
    pruning creates new parameter tensors, so the old optimizer would no longer
    update the model. Pruning therefore compounds across epochs.

    NOTE(review): model selection uses ``test_loader``; a separate validation split
    would keep the reported test score unbiased.

    Args:
        model: prunable model (``forward`` returns (batch, k, out_features); has ``prune``).
        train_loader: DataLoader for training data.
        test_loader: DataLoader used to evaluate pruned copies.
        optimizer: optimizer over ``model``'s parameters (rebuilt after each epoch).
        criterion: loss taking (averaged predictions, targets).
        epochs: number of training epochs.
        keep_ratios: sequence of ratios of sub-models to try keeping.
        filepath: currently unused; kept for API compatibility.
        dataset_name: currently unused; kept for API compatibility.
        task_type: 'binary', 'multiclass' or 'regression'.

    Returns:
        model: best model found (lowest test loss); the input model if ``epochs == 0``.
        best_config: ``{"k": ...}`` of that model ({} if no copy was selected).

    Raises:
        ValueError: on the first batch if ``task_type`` is unknown.
    """
    best_loss = float("inf")
    best_model = model
    best_config = {}

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        prune_X = prune_y = None

        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            optimizer.zero_grad()
            _, loss = _ensemble_loss(model(X_batch), y_batch, criterion, task_type)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # Keep a TRAINING batch for ranking sub-models (never a test batch)
            prune_X, prune_y = X_batch, y_batch

        avg_epoch_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Loss: {avg_epoch_loss}")

        # Try each pruning level on an independent copy of the current model
        for keep_ratio in keep_ratios:
            model_copy = copy.deepcopy(model)
            model_copy.eval()  # rank sub-models without dropout
            model_copy.prune(prune_X, prune_y, keep_ratio=keep_ratio, task_type=task_type)

            avg_loss, accuracy = _evaluate_ensemble(model_copy, test_loader, criterion, task_type)
            if accuracy is None:
                print(f"Pruning with keep_ratio={keep_ratio}, Test Loss: {avg_loss:.4f}")
            else:
                print(f"Pruning with keep_ratio={keep_ratio}, Accuracy: {accuracy:.4f}")

            if avg_loss < best_loss:
                best_loss = avg_loss
                best_model = model_copy
                best_config = {"k": model_copy.k}

        # Continue from a copy so further training cannot alter the best snapshot,
        # and bind a new optimizer to the parameters that are actually trained.
        model = copy.deepcopy(best_model)
        optimizer = _rebuild_optimizer(optimizer, model)

    return best_model, best_config


In [None]:
def save_model(model, filepath, config):
    """
    Save a PyTorch model's weights together with its construction config.

    Args:
        model (nn.Module): model to save; must expose ``backbone``, ``in_features``,
            ``hidden_sizes``, ``out_features`` and ``k`` (e.g. PrunableEnsembleModel).
        filepath (str): destination file.
        config (dict): extra configuration (e.g. ``k`` after pruning). It is NOT modified.

    Note:
        The backbone is stored as a class object, so reading the checkpoint requires
        ``torch.load(..., weights_only=False)``.
    """
    # Build a new dict instead of updating the caller's config in place.
    full_config = {
        **config,
        "backbone": model.backbone.__class__,  # class, so the model can be re-instantiated
        "in_features": model.in_features,
        "hidden_sizes": model.hidden_sizes,
        "out_features": model.out_features,
        "k": model.k,  # current (possibly pruned) number of sub-models
    }

    torch.save({
        "state_dict": model.state_dict(),
        "config": full_config
    }, filepath)
    print(f"Modèle et configuration sauvegardés dans {filepath}")

def load_model(filepath, model_class, default_config, map_location=None):
    """
    Load a PyTorch model saved by ``save_model``.

    Args:
        filepath (str): path of the saved checkpoint.
        model_class (type): model class to instantiate (e.g. PrunableEnsembleModel).
        default_config (dict): constructor kwargs used only when the checkpoint has
            no "config" entry.
        map_location: optional ``torch.load`` map_location (e.g. "cpu" to load a
            GPU checkpoint on a CPU-only machine).

    Returns:
        nn.Module: the model with its weights loaded, in evaluation mode.

    Security:
        The checkpoint stores a Python class, so it is read with full unpickling
        (``weights_only=False``). Only load checkpoints from trusted sources.
    """
    # weights_only=False must be explicit: newer PyTorch defaults to True, which
    # refuses to restore the backbone class pickled by save_model.
    checkpoint = torch.load(filepath, map_location=map_location, weights_only=False)
    state_dict = checkpoint["state_dict"]
    config = checkpoint.get("config", default_config)  # saved config, else the fallback

    # Rebuild the architecture (including the pruned k), then load the weights
    model = model_class(**config)
    model.load_state_dict(state_dict)

    model.eval()

    print(f"Modèle chargé depuis {filepath}")
    return model

## Training

### Dataset "Breast cancer"

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the breast-cancer dataset
data = load_breast_cancer()
X, y = data.data, data.target

# 80/20 train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardise the features; the scaler is fitted on the training split only
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert every split to float32 PyTorch tensors
X_train, y_train, X_test, y_test = (
    torch.tensor(arr, dtype=torch.float32) for arr in (X_train, y_train, X_test, y_test)
)

# DataLoaders: only the training data is shuffled
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [None]:
# Untrained TabM ensemble, used only to exercise the pruning mechanism
layers = [64, 32, 16]
prunable_model = PrunableEnsembleModel(
    TabM,
    in_features=X_train.shape[1],
    hidden_sizes=layers,
    out_features=1,
    k=32,
)

# Progressive reduction of the number of sub-models
keep_ratios = [1, 0.75, 0.5, 0.25, 0.1]
test_pruning(prunable_model, train_loader, keep_ratios)



Pruning avec keep_ratio = 1
Précision globale après pruning : 0.4066

Pruning avec keep_ratio = 0.75
Précision globale après pruning : 0.5934

Pruning avec keep_ratio = 0.5
Précision globale après pruning : 0.5956

Pruning avec keep_ratio = 0.25
Précision globale après pruning : 0.4989

Pruning avec keep_ratio = 0.1
Précision globale après pruning : 0.5934


#### Entraînement sur plusieurs epochs

In [None]:

# TabM ensemble with 32 sub-models and two hidden layers
model = PrunableEnsembleModel(TabM, in_features=X_train.shape[1], hidden_sizes=[64, 32], out_features=1, k=32)

# Adam optimiser + binary-classification loss on logits
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

dataset_name = "cancer"
filepath = "best_model.pth"

# Train with progressive pruning
trained_model, best_config = train_with_pruning(
    model, train_loader, test_loader, optimizer, criterion, epochs=10
)

# Save the selected model together with its configuration
model_name = f"{dataset_name}_best_model.pth"
save_model(trained_model, model_name, best_config)

Epoch 1/10: 100%|██████████| 15/15 [00:00<00:00, 71.72it/s]


Epoch 1, Loss: 0.556146643559138
Pruning with keep_ratio=1.0, Accuracy: 0.9737
Pruning with keep_ratio=0.75, Accuracy: 0.9649
Pruning with keep_ratio=0.5, Accuracy: 0.9561
Pruning with keep_ratio=0.25, Accuracy: 0.8947


Epoch 2/10: 100%|██████████| 15/15 [00:00<00:00, 215.87it/s]


Epoch 2, Loss: 0.35309741894404095
Pruning with keep_ratio=1.0, Accuracy: 0.8947
Pruning with keep_ratio=0.75, Accuracy: 0.9035
Pruning with keep_ratio=0.5, Accuracy: 0.8596
Pruning with keep_ratio=0.25, Accuracy: 0.8509


Epoch 3/10: 100%|██████████| 15/15 [00:00<00:00, 341.05it/s]


Epoch 3, Loss: 0.32457735439141594
Pruning with keep_ratio=1.0, Accuracy: 0.9035
Pruning with keep_ratio=0.75, Accuracy: 0.9298
Pruning with keep_ratio=0.5, Accuracy: 0.8684
Pruning with keep_ratio=0.25, Accuracy: 0.7368


Epoch 4/10: 100%|██████████| 15/15 [00:00<00:00, 375.45it/s]


Epoch 4, Loss: 0.2928622752428055
Pruning with keep_ratio=1.0, Accuracy: 0.9298
Pruning with keep_ratio=0.75, Accuracy: 0.9123
Pruning with keep_ratio=0.5, Accuracy: 0.9211
Pruning with keep_ratio=0.25, Accuracy: 0.8246


Epoch 5/10: 100%|██████████| 15/15 [00:00<00:00, 369.44it/s]


Epoch 5, Loss: 0.29867383738358816
Pruning with keep_ratio=1.0, Accuracy: 0.9298
Pruning with keep_ratio=0.75, Accuracy: 0.9035
Pruning with keep_ratio=0.5, Accuracy: 0.8772
Pruning with keep_ratio=0.25, Accuracy: 0.8246


Epoch 6/10: 100%|██████████| 15/15 [00:00<00:00, 359.50it/s]


Epoch 6, Loss: 0.2947486917177836
Pruning with keep_ratio=1.0, Accuracy: 0.9298
Pruning with keep_ratio=0.75, Accuracy: 0.8860
Pruning with keep_ratio=0.5, Accuracy: 0.8772
Pruning with keep_ratio=0.25, Accuracy: 0.8596


Epoch 7/10: 100%|██████████| 15/15 [00:00<00:00, 387.21it/s]


Epoch 7, Loss: 0.2886879285176595
Pruning with keep_ratio=1.0, Accuracy: 0.9298
Pruning with keep_ratio=0.75, Accuracy: 0.8860
Pruning with keep_ratio=0.5, Accuracy: 0.9211
Pruning with keep_ratio=0.25, Accuracy: 0.8596


Epoch 8/10: 100%|██████████| 15/15 [00:00<00:00, 370.83it/s]


Epoch 8, Loss: 0.3113703628381093
Pruning with keep_ratio=1.0, Accuracy: 0.9298
Pruning with keep_ratio=0.75, Accuracy: 0.9123
Pruning with keep_ratio=0.5, Accuracy: 0.9211
Pruning with keep_ratio=0.25, Accuracy: 0.8596


Epoch 9/10: 100%|██████████| 15/15 [00:00<00:00, 383.49it/s]


Epoch 9, Loss: 0.2977165271838506
Pruning with keep_ratio=1.0, Accuracy: 0.9298
Pruning with keep_ratio=0.75, Accuracy: 0.9123
Pruning with keep_ratio=0.5, Accuracy: 0.9211
Pruning with keep_ratio=0.25, Accuracy: 0.8596


Epoch 10/10: 100%|██████████| 15/15 [00:00<00:00, 359.65it/s]


Epoch 10, Loss: 0.3089119285345078
Pruning with keep_ratio=1.0, Accuracy: 0.9298
Pruning with keep_ratio=0.75, Accuracy: 0.9123
Pruning with keep_ratio=0.5, Accuracy: 0.9211
Pruning with keep_ratio=0.25, Accuracy: 0.8596
Modèle et configuration sauvegardés dans cancer_best_model.pth


In [None]:

# Fallback configuration, used by load_model only when the checkpoint has no "config"
# entry. It must describe the architecture actually trained above (hidden_sizes=[64, 32]).
default_config = {
    "backbone": TabM,
    "in_features": X_train.shape[1],
    "hidden_sizes": [64, 32],
    "out_features": 1,
    "dropout_rate": 0.1,
}
default_config.update(best_config)  # adds the pruned number of sub-models (k)

loaded_model = load_model(model_name, PrunableEnsembleModel, default_config)

# Evaluate the reloaded model on the test set
loaded_model.eval()
total_correct = 0
total_samples = 0
with torch.no_grad():
    for X_batch, y_batch in test_loader:
        outputs = loaded_model(X_batch)  # (batch, k, 1)
        preds = torch.sigmoid(outputs.mean(dim=1)) > 0.5
        total_correct += (preds == y_batch.unsqueeze(1)).sum().item()
        total_samples += y_batch.size(0)

accuracy = total_correct / total_samples
print(f"Test Accuracy du modèle chargé: {accuracy:.4f}")

Modèle chargé depuis cancer_best_model.pth
Test Accuracy du modèle chargé: 0.9298


  checkpoint = torch.load(filepath)  # Remove weights_only=True


### Dataset "Adult income"

In [None]:
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import LabelEncoder

import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def get_adult_income_data(split=0.2, batch_size=32, seed=42, csv_path="adult.csv"):
    """
    Load and prepare the `adult income` classification data.

    Categorical columns are one-hot encoded and the ``income`` target is
    label-encoded. Features are standardised with statistics from the training
    split only; fitting the scaler on all rows would leak test-set statistics.

    Args:
        split: fraction of rows held out for testing.
        batch_size: DataLoader batch size.
        seed: random_state for the train/test split.
        csv_path: CSV file containing the features and an ``income`` column.

    Returns:
        train_loader, test_loader: DataLoaders of (float features, long labels);
            only the training loader is shuffled.
        shape_x: number of features after one-hot encoding.
        shape_y: number of target columns (1 for this single-label target).
    """
    data = pd.read_csv(csv_path)
    X = pd.get_dummies(data.drop(columns="income"))
    y = LabelEncoder().fit_transform(data["income"])  # target

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=split, random_state=seed)

    # Fit on the training split only, then apply the same transform to the test split
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    def to_loader(features, labels, shuffle):
        dataset = TensorDataset(torch.tensor(features).float(), torch.tensor(labels).long())
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    train_loader = to_loader(X_train, y_train, shuffle=True)
    test_loader = to_loader(X_test, y_test, shuffle=False)

    shape_x = X_train.shape[1]
    shape_y = y_train.reshape(-1, 1).shape[1]
    return train_loader, test_loader, shape_x, shape_y


In [None]:
# Data
BATCH_SIZE = 32
train_loader, test_loader, shape_x, shape_y = get_adult_income_data(
    split=.2, batch_size=BATCH_SIZE, seed=42
)

# Untrained TabM ensemble with two hidden layers of 64 units
layers = [64, 64]
prunable_model = PrunableEnsembleModel(
    TabM, in_features=shape_x, hidden_sizes=layers, out_features=1, k=32
)

# Pruning test: progressively reduce the number of sub-models
keep_ratios = [1, 0.75, 0.5, 0.25, 0.1]
test_pruning(prunable_model, train_loader, keep_ratios, task_type="binary")



Pruning avec keep_ratio = 1
Précision globale après pruning : 0.3225

Pruning avec keep_ratio = 0.75
Précision globale après pruning : 0.6920

Pruning avec keep_ratio = 0.5
Précision globale après pruning : 0.7576

Pruning avec keep_ratio = 0.25
Précision globale après pruning : 0.7382

Pruning avec keep_ratio = 0.1
Précision globale après pruning : 0.6939


#### Entraînement sur plusieurs epochs

In [None]:
# Load and preprocess the adult-income data
train_loader, test_loader, shape_x, shape_y = get_adult_income_data(
    split=0.2, batch_size=32, seed=42
)

# Report the data dimensions
print(f"Nombre de caractéristiques (shape_x) : {shape_x}")
print(f"Nombre de classes (shape_y) : {shape_y}")

model = PrunableEnsembleModel(
    TabM,
    in_features=shape_x,
    hidden_sizes=[64, 32],
    out_features=shape_y,
    k=32,
)


Nombre de caractéristiques (shape_x) : 108
Nombre de classes (shape_y) : 1


In [None]:
# Binary-classification loss on logits, optimised with Adam
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Train with progressive pruning
trained_model, best_config = train_with_pruning(
    model, train_loader, test_loader, optimizer, criterion, epochs=10,
)

Epoch 1/10: 100%|██████████| 1222/1222 [00:19<00:00, 63.47it/s]


Epoch 1, Loss: 0.49272613635868956
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.7527
Pruning with keep_ratio=0.5, Accuracy: 0.7574
Pruning with keep_ratio=0.25, Accuracy: 0.7065


Epoch 2/10: 100%|██████████| 1222/1222 [00:14<00:00, 86.66it/s] 


Epoch 2, Loss: 0.4192813226248941
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.7532
Pruning with keep_ratio=0.5, Accuracy: 0.7025
Pruning with keep_ratio=0.25, Accuracy: 0.7311


Epoch 3/10: 100%|██████████| 1222/1222 [00:12<00:00, 97.05it/s] 


Epoch 3, Loss: 0.419344790401806
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.6934
Pruning with keep_ratio=0.5, Accuracy: 0.7169
Pruning with keep_ratio=0.25, Accuracy: 0.7488


Epoch 4/10: 100%|██████████| 1222/1222 [00:12<00:00, 98.34it/s] 


Epoch 4, Loss: 0.41885892923695756
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.7812
Pruning with keep_ratio=0.5, Accuracy: 0.7609
Pruning with keep_ratio=0.25, Accuracy: 0.7054


Epoch 5/10: 100%|██████████| 1222/1222 [00:12<00:00, 99.13it/s] 


Epoch 5, Loss: 0.41880364399007025
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.6944
Pruning with keep_ratio=0.5, Accuracy: 0.7608
Pruning with keep_ratio=0.25, Accuracy: 0.7535


Epoch 6/10: 100%|██████████| 1222/1222 [00:12<00:00, 98.00it/s] 


Epoch 6, Loss: 0.41804119725079075
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.7071
Pruning with keep_ratio=0.5, Accuracy: 0.7299
Pruning with keep_ratio=0.25, Accuracy: 0.7557


Epoch 7/10: 100%|██████████| 1222/1222 [00:12<00:00, 97.25it/s] 


Epoch 7, Loss: 0.4189207910052493
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.6540
Pruning with keep_ratio=0.5, Accuracy: 0.7172
Pruning with keep_ratio=0.25, Accuracy: 0.7710


Epoch 8/10: 100%|██████████| 1222/1222 [00:18<00:00, 66.84it/s]


Epoch 8, Loss: 0.421607683338042
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.7494
Pruning with keep_ratio=0.5, Accuracy: 0.7236
Pruning with keep_ratio=0.25, Accuracy: 0.7688


Epoch 9/10: 100%|██████████| 1222/1222 [00:12<00:00, 98.70it/s] 


Epoch 9, Loss: 0.4175804953674638
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.7534
Pruning with keep_ratio=0.5, Accuracy: 0.7235
Pruning with keep_ratio=0.25, Accuracy: 0.7327


Epoch 10/10: 100%|██████████| 1222/1222 [00:12<00:00, 99.70it/s] 


Epoch 10, Loss: 0.42010032670929076
Pruning with keep_ratio=1.0, Accuracy: 0.8249
Pruning with keep_ratio=0.75, Accuracy: 0.6887
Pruning with keep_ratio=0.5, Accuracy: 0.7696
Pruning with keep_ratio=0.25, Accuracy: 0.7244


In [None]:
# Save the selected model and its configuration
dataset_name = "adult"
model_name = f"{dataset_name}_best_model.pth"
save_model(trained_model, model_name, best_config)


Modèle et configuration sauvegardés dans adult_best_model.pth
