In [38]:
!python -m venv env
!source env/bin/activate  

!pip install boto3 pandas
!pip install networkx
!pip install typing
!pip install scikit-learn
!pip install torch
!pip install torch-geometric

#Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
import re
import math
import hashlib
from typing import Dict, List, Tuple, Optional, Any, Union
from dataclasses import dataclass, field
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.manifold import TSNE
import warnings
warnings.filterwarnings('ignore')

# Configuration: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"🚀 Device utilisé: {device}")

🚀 Device utilisé: cpu


Importation du JSON

In [16]:
import json

def json_to_tree(data, indent=1):
    """Print the nested key hierarchy of a decoded JSON value as an indented tree.

    Only dictionaries are walked: every key is printed as ``- key`` preceded by
    ``indent`` spaces, and its value is visited two columns deeper. Any other
    value (scalar, list, None) prints nothing and is not descended into.

    Args:
        data: Decoded JSON value to inspect.
        indent: Number of leading spaces for the keys of ``data``.
    """
    # Guard clause: leaves and lists contribute nothing to the tree.
    if not isinstance(data, dict):
        return

    pad = " " * indent
    for name in data:
        print(f"{pad}- {name}")
        json_to_tree(data[name], indent + 2)

# Load one exported workbook and print the hierarchy of its keys.
# The encoding is given explicitly because JSON files are UTF-8, whereas the
# default used by open() depends on the platform locale.
with open('0003d90ad249104a7ba0fb6bab08e8b9e70746e0cd2c3b30a006935b55f2a07b.json', encoding='utf-8') as file:
    data = json.load(file)

# The file is already closed here; printing does not need the handle.
json_to_tree(data)

    
    


 - styles
   - s0
     - bl
   - s1
     - bd
       - t
         - s
         - cl
           - rgb
       - b
         - s
         - cl
           - rgb
       - l
         - s
         - cl
           - rgb
       - r
         - s
         - cl
           - rgb
     - tb
   - s2
     - bd
       - t
         - s
         - cl
           - rgb
       - b
         - s
         - cl
           - rgb
       - l
         - s
         - cl
           - rgb
       - r
         - s
         - cl
           - rgb
     - tb
   - s3
     - bd
       - t
         - s
         - cl
           - rgb
       - b
         - s
         - cl
           - rgb
       - l
         - s
         - cl
           - rgb
       - r
         - s
         - cl
           - rgb
     - ht
     - tb
   - s4
     - bd
       - t
         - s
         - cl
           - rgb
   - s5
     - bd
       - t
         - s
         - cl
           - rgb
       - r
         - s
         - cl
           - rgb
   - s6
     - bd

Représentation du JSON en graphe :
- Création de la classe CellNode
- ExcelGraphBuilder

In [36]:
@dataclass
class FullCellInfo:
    """Flat record holding everything known about one spreadsheet cell.

    Groups the cell content, its location, its visual style and a few metadata
    flags. ``cell_type`` is always recomputed from ``raw_value`` and ``formula``
    after construction, so a value passed for it by the caller is overwritten.
    """
    # --- Content ---
    raw_value: Any = ""
    cell_type: int = 0  # 0=empty, 1=text, 2=number, 3=formula (derived in __post_init__)
    formula: str = ""

    # --- Location ---
    row: int = 0
    col: int = 0
    sheet_id: str = "default"

    # --- Font / text style ---
    style_id: str = "s0"
    bold: bool = False
    italic: bool = False
    underline: bool = False
    font_size: int = 11
    font_family: str = "Arial"

    # --- Colours ---
    text_color: str = "#000000"
    background_color: str = "#FFFFFF"

    # --- Borders (one style code per side, 0 = no border) ---
    border_top: int = 0
    border_bottom: int = 0
    border_left: int = 0
    border_right: int = 0
    border_color: str = "#000000"

    # --- Alignment ---
    horizontal_align: int = 0  # 0=general, 1=left, 2=center, 3=right
    vertical_align: int = 0    # 0=bottom, 1=middle, 2=top
    text_wrap: bool = False

    # --- Number format and metadata ---
    number_format: str = "General"
    is_merged: bool = False
    merge_range: Tuple[int, int, int, int] = (0, 0, 0, 0)
    has_comment: bool = False
    is_locked: bool = False
    is_hidden: bool = False

    def __post_init__(self):
        """Derive ``cell_type`` from the stored value and formula.

        Precedence: a numeric value wins, then a non-blank string, then a
        non-empty formula; anything else is an empty cell.
        """
        value = self.raw_value
        # NOTE(review): bool is a subclass of int, so True/False count as numbers here.
        if isinstance(value, (int, float)):
            inferred = 2
        elif isinstance(value, str) and value.strip():
            inferred = 1
        else:
            inferred = 3 if self.formula else 0
        self.cell_type = inferred

In [28]:
class ExcelParser:
    """Convert a workbook exported as JSON into a flat list of ``FullCellInfo``.

    The parser reads ``excel_data['styles']`` (style id -> style dict) and
    ``excel_data['sheets'][sheet_id]['cellData'][row][col]`` (cell dicts with the
    keys ``v`` value, ``f`` formula and ``s`` style reference). Missing or null
    sub-objects are treated as absent instead of raising.
    """

    # Style key of each border side -> name of the FullCellInfo field it fills.
    _BORDER_FIELDS = (
        ('t', 'border_top'),
        ('b', 'border_bottom'),
        ('l', 'border_left'),
        ('r', 'border_right'),
    )

    @staticmethod
    def parse_excel_json(excel_data: Dict) -> List["FullCellInfo"]:
        """Build one ``FullCellInfo`` per cell found in the workbook dictionary.

        Args:
            excel_data: Decoded workbook JSON with optional ``styles`` and
                ``sheets`` entries.

        Returns:
            The cells of all sheets, in the iteration order of the input dicts.

        Raises:
            ValueError: If a row or column key is not an integer string.
        """
        cells = []

        styles = excel_data.get('styles') or {}

        for sheet_id, sheet_info in (excel_data.get('sheets') or {}).items():
            cell_data = (sheet_info or {}).get('cellData') or {}

            for row_str, row_data in cell_data.items():
                row = int(row_str)
                for col_str, cell_info in (row_data or {}).items():
                    # A null / non-dict entry carries no cell information.
                    if not isinstance(cell_info, dict):
                        continue
                    col = int(col_str)

                    # The style reference is normally an id into `styles`. A dict
                    # cannot be used as a lookup key (it is unhashable), so it is
                    # treated as an inline style definition.
                    # NOTE(review): inline styles are an assumption about the
                    # export format -- confirm against the producer.
                    style_ref = cell_info.get('s', 's0')
                    if isinstance(style_ref, dict):
                        style_id = 'inline'
                        style = style_ref
                    else:
                        style_id = style_ref
                        style = styles.get(style_ref) or {}

                    cell = FullCellInfo(
                        raw_value=cell_info.get('v', ''),
                        formula=cell_info.get('f') or '',
                        row=row,
                        col=col,
                        sheet_id=sheet_id,
                        style_id=style_id,
                        **ExcelParser._parse_style(style)
                    )

                    cells.append(cell)

        return cells

    @staticmethod
    def _lookup(mapping, key, default):
        """Return ``mapping[key]``; ``default`` if the key is missing, the value is
        None, or ``mapping`` is not a dict."""
        if not isinstance(mapping, dict):
            return default
        found = mapping.get(key)
        return default if found is None else found

    @staticmethod
    def _parse_style(style: Dict) -> Dict:
        """Translate one style dict into ``FullCellInfo`` keyword arguments.

        Args:
            style: Style dictionary (may be empty or None).

        Returns:
            A dict with the font flags, font size, wrap flag, the four border
            codes and both alignment codes. ``text_color``, ``background_color``
            and ``border_color`` are present only when the style defines them,
            so the ``FullCellInfo`` defaults apply otherwise.
        """
        lookup = ExcelParser._lookup

        # NOTE(review): 'ul' and 'tb' are reduced to plain truthiness; if the
        # export stores them as objects or multi-valued codes this is too coarse.
        parsed = {
            'bold': bool(lookup(style, 'bl', 0)),
            'italic': bool(lookup(style, 'it', 0)),
            'underline': bool(lookup(style, 'ul', 0)),
            'font_size': lookup(style, 'fs', 11),
            'text_wrap': bool(lookup(style, 'tb', 0)),
        }

        # Colours
        text_color = lookup(style, 'cl', None)
        if isinstance(text_color, dict):
            parsed['text_color'] = lookup(text_color, 'rgb', '#000000')
        background = lookup(style, 'bg', None)
        if isinstance(background, dict):
            parsed['background_color'] = lookup(background, 'rgb', '#FFFFFF')

        # Borders: one style code per side, plus a single colour for the cell.
        borders = lookup(style, 'bd', {})
        for side, field_name in ExcelParser._BORDER_FIELDS:
            edge = lookup(borders, side, {})
            parsed[field_name] = lookup(edge, 's', 0)
            # FullCellInfo has only one border colour: keep the first one found
            # in top, bottom, left, right order.
            if 'border_color' not in parsed:
                edge_rgb = lookup(lookup(edge, 'cl', None), 'rgb', None)
                if edge_rgb is not None:
                    parsed['border_color'] = edge_rgb

        # Alignment
        parsed['horizontal_align'] = lookup(style, 'ht', 0)
        parsed['vertical_align'] = lookup(style, 'vt', 0)

        return parsed

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence of embeddings.

    The table ``pe`` of shape ``(max_len, d_model)`` is precomputed once and
    stored as a buffer: even columns hold ``sin(pos * w_k)``, odd columns hold
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``.

    The tensor layout is chosen explicitly with ``batch_first`` rather than
    guessed from the shape: the batch and sequence axes cannot be told apart by
    their sizes (a ``(1, seq, d_model)`` input looks like a sequence of length 1),
    which would give every token the encoding of position 0.

    Args:
        d_model: Size of the embedding dimension.
        max_len: Longest sequence the table can encode.
        batch_first: If True (default) 3-D inputs are ``(batch, seq, d_model)``;
            if False they are ``(seq, batch, d_model)``.
    """

    def __init__(self, d_model: int, max_len: int = 5000, batch_first: bool = True):
        super().__init__()
        self.batch_first = batch_first

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Buffer: follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of each sequence position.

        Args:
            x: ``(seq, d_model)`` for a single sequence, or a 3-D tensor laid out
                according to ``batch_first``.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 2-D or 3-D, or is longer than ``max_len``.
        """
        if x.dim() == 2:
            seq_axis = 0
        elif x.dim() == 3:
            seq_axis = 1 if self.batch_first else 0
        else:
            raise ValueError(f"expected a 2-D or 3-D tensor, got {x.dim()} dimensions")

        seq_len = x.size(seq_axis)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )

        encoding = self.pe[:seq_len, :]
        if x.dim() == 3:
            # Insert the batch axis so the table broadcasts over it.
            encoding = encoding.unsqueeze(0 if self.batch_first else 1)
        return x + encoding

In [None]:
class TransformerCellEmbedder(nn.Module):
    """
    Multi-modal Transformer that encodes one spreadsheet cell into one vector.

    Each cell becomes a short token sequence
    ``[CLS, text tokens..., STYLE, POSITION, TYPE, NUMERIC]``. The sequence is
    run through a Transformer encoder; the output at the CLS position, passed
    through a small MLP, is the cell embedding of size ``embedding_dim``.

    ``build_vocabulary`` must be called before ``forward``: it fills the word
    vocabulary, the colour/format hash tables and the numeric statistics used
    for normalisation. Those are plain Python attributes, not parameters or
    buffers, so they are not included in ``state_dict()``.
    """

    def __init__(self,
                 embedding_dim: int = 128,
                 vocab_size: int = 10000,
                 num_heads: int = 8,
                 num_layers: int = 4,
                 max_text_length: int = 20,
                 dropout: float = 0.1):
        """
        Args:
            embedding_dim: Size of every token and of the final cell embedding.
            vocab_size: Number of rows of the word embedding table (the four
                special tokens included).
            num_heads: Attention heads per encoder layer.
            num_layers: Number of encoder layers.
            max_text_length: Bound on the text part of a sequence; at most
                ``max_text_length - 2`` words are kept per cell.
            dropout: Dropout probability used in the projectors and the encoder.
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.max_text_length = max_text_length
        self.num_heads = num_heads
        self.num_layers = num_layers

        # === VOCABULARY AND LOOKUP TABLES ===
        self.string_to_id = {'<PAD>': 0, '<UNK>': 1, '<CLS>': 2, '<SEP>': 3}
        self.id_to_string = {v: k for k, v in self.string_to_id.items()}
        self.color_vocab = {}
        self.format_vocab = {}

        # === ONE EMBEDDING PATH PER MODALITY ===

        # 1. Text
        self.text_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.text_pos_encoding = PositionalEncoding(embedding_dim, max_text_length)

        # 2. Style (20 features -> embedding_dim)
        self.style_projector = nn.Sequential(
            nn.Linear(20, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # 3. Spatial position: row and column halves are concatenated, then projected
        self.spatial_row_emb = nn.Embedding(2000, embedding_dim // 2)  # rows 0..1999
        self.spatial_col_emb = nn.Embedding(100, embedding_dim // 2)   # columns 0..99
        self.spatial_projector = nn.Linear(embedding_dim, embedding_dim)

        # 4. Cell type (0=empty, 1=text, 2=number, 3=formula)
        self.type_embedding = nn.Embedding(4, embedding_dim)

        # 5. Numeric value
        self.numeric_projector = nn.Sequential(
            nn.Linear(6, embedding_dim),  # 6 numeric features
            nn.LayerNorm(embedding_dim),
            nn.Tanh()
        )

        # === MODALITY TOKENS ===
        # Learnable vectors that mark which kind of information a token carries.
        self.modality_tokens = nn.Parameter(torch.randn(6, embedding_dim) * 0.02)
        # 0: CLS, 1: TEXT, 2: STYLE, 3: POSITION, 4: TYPE, 5: NUMERIC

        # === TRANSFORMER ENCODER ===
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=embedding_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True  # pre-norm, for stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # === FINAL AGGREGATION ===
        self.final_aggregator = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, embedding_dim)
        )

        # === NUMERIC NORMALISATION STATISTICS (placeholders until fitted) ===
        self.numeric_stats = {
            'min': 0, 'max': 1, 'mean': 0, 'std': 1,
            'q25': 0, 'q75': 1, 'log_min': 0, 'log_max': 1
        }

        self.is_fitted = False

    def build_vocabulary(self, cells: List["FullCellInfo"], min_freq: int = 2):
        """Fit the word vocabulary, hash tables and numeric statistics on ``cells``.

        Words are whitespace-separated tokens of text cells (``cell_type == 1``).
        A word gets an id when it occurs at least ``min_freq`` times and the
        table still has room. Words that already have an id, including the
        special tokens, keep it, so the method can be called again safely.

        Args:
            cells: Cells to collect statistics from.
            min_freq: Minimum number of occurrences for a word to be kept.
        """
        print(f"🔨 Construction du vocabulaire à partir de {len(cells)} cellules...")

        text_counter = Counter()
        color_set = set()
        format_set = set()
        numeric_values = []

        for cell in cells:
            # Text
            if cell.cell_type == 1 and cell.raw_value:
                text_counter.update(str(cell.raw_value).strip().split())

            # Colours and formats
            color_set.update((cell.text_color, cell.background_color, cell.border_color))
            format_set.update((cell.number_format, cell.font_family))

            # Numeric values: keep only finite numbers
            if cell.cell_type == 2:
                try:
                    val = float(cell.raw_value)
                except (TypeError, ValueError, OverflowError):
                    continue
                if np.isfinite(val):
                    numeric_values.append(val)

        # Word vocabulary: new ids continue after the highest id already in use.
        next_id = max(self.string_to_id.values(), default=-1) + 1
        for word, freq in text_counter.items():
            if next_id >= self.vocab_size:
                break
            if freq < min_freq or word in self.string_to_id:
                continue
            self.string_to_id[word] = next_id
            self.id_to_string[next_id] = word
            next_id += 1
        covered = sum(1 for word in text_counter if word in self.string_to_id)

        # Colour and format tables (string -> stable hash)
        self.color_vocab = {color: self._hash_string(color) for color in color_set}
        self.format_vocab = {fmt: self._hash_string(fmt) for fmt in format_set}

        # Numeric statistics. `numeric_values` stays a list so that its
        # truthiness is well defined (an ndarray with several elements raises
        # when used in a boolean context).
        if numeric_values:
            values = np.array(numeric_values)
            # Non-positive values are clamped to 1e-8 before taking the log.
            log_values = np.log10(np.maximum(values, 1e-8))
            self.numeric_stats = {
                'min': float(np.min(values)),
                'max': float(np.max(values)),
                'mean': float(np.mean(values)),
                'std': float(np.std(values)),
                'q25': float(np.percentile(values, 25)),
                'q75': float(np.percentile(values, 75)),
                'log_min': float(log_values.min()),
                'log_max': float(log_values.max())
            }

        self.is_fitted = True

        print(f"✅ Vocabulaire construit:")
        print(f"   - Mots textuels: {len(self.string_to_id)} (coverage: {covered}/{len(text_counter)})")
        print(f"   - Couleurs uniques: {len(color_set)}")
        print(f"   - Formats uniques: {len(format_set)}")
        print(f"   - Valeurs numériques: {len(numeric_values)}")
        if numeric_values:
            print(f"   - Range numérique: [{self.numeric_stats['min']:.2f}, {self.numeric_stats['max']:.2f}]")

    def forward(self, cells_batch: List["FullCellInfo"]) -> torch.Tensor:
        """
        Encode a batch of cells into unified embeddings.

        Args:
            cells_batch: Cells to encode.

        Returns:
            Tensor of shape ``(len(cells_batch), embedding_dim)``; an empty
            ``(0, embedding_dim)`` tensor for an empty batch.

        Raises:
            ValueError: If ``build_vocabulary`` has not been called yet.
        """
        if not self.is_fitted:
            raise ValueError("Le modèle doit être entraîné avec build_vocabulary() d'abord")

        device = next(self.parameters()).device
        if len(cells_batch) == 0:
            return torch.empty(0, self.embedding_dim, device=device)

        # Step 1: one variable-length token sequence per cell
        sequences = [self._build_cell_sequence(cell) for cell in cells_batch]

        # Step 2: zero-pad to a common length -> (batch, seq_len, embedding_dim)
        lengths = torch.tensor([seq.size(0) for seq in sequences], device=device)
        batch_sequences = nn.utils.rnn.pad_sequence(sequences, batch_first=True)

        # Step 3: Transformer; True in the mask marks a padded slot to ignore
        slots = torch.arange(batch_sequences.size(1), device=device)
        src_key_padding_mask = slots.unsqueeze(0) >= lengths.unsqueeze(1)
        transformer_output = self.transformer(
            batch_sequences,
            src_key_padding_mask=src_key_padding_mask
        )  # (batch, seq_len, embedding_dim)

        # Step 4: the CLS slot (index 0) summarises the cell
        cls_embeddings = transformer_output[:, 0, :]
        return self.final_aggregator(cls_embeddings)

    def _build_cell_sequence(self, cell: "FullCellInfo") -> torch.Tensor:
        """Return the ``(seq_len, embedding_dim)`` token sequence of one cell."""
        tokens = [self.modality_tokens[0].unsqueeze(0)]  # CLS

        text_tokens = self._encode_text_tokens(cell)
        if text_tokens.size(0) > 0:
            tokens.append(text_tokens + self.modality_tokens[1].unsqueeze(0))

        # One token per remaining modality, each tagged with its modality vector.
        tokens.append((self._encode_style_token(cell) + self.modality_tokens[2]).unsqueeze(0))
        tokens.append((self._encode_position_token(cell) + self.modality_tokens[3]).unsqueeze(0))
        tokens.append((self._encode_type_token(cell) + self.modality_tokens[4]).unsqueeze(0))
        tokens.append((self._encode_numeric_token(cell) + self.modality_tokens[5]).unsqueeze(0))

        return torch.cat(tokens, dim=0)

    def _encode_text_tokens(self, cell: "FullCellInfo") -> torch.Tensor:
        """Embed the words of a text cell, followed by ``<SEP>``.

        Returns a ``(n_tokens, embedding_dim)`` tensor, or an empty
        ``(0, embedding_dim)`` tensor for non-text or empty cells.
        """
        device = next(self.parameters()).device

        if cell.cell_type != 1 or not cell.raw_value:
            return torch.empty(0, self.embedding_dim, device=device)

        # Keep at most max_text_length - 2 words (never a negative slice bound).
        max_words = max(self.max_text_length - 2, 0)
        words = str(cell.raw_value).strip().split()[:max_words]

        unk_id = self.string_to_id['<UNK>']
        token_ids = [self.string_to_id.get(word, unk_id) for word in words]
        token_ids.append(self.string_to_id['<SEP>'])

        embeddings = self.text_embedding(torch.tensor(token_ids, device=device))

        # Add the encoding of positions 0..n-1 straight from the precomputed
        # table: row i belongs to token i, with no guess about tensor layout.
        return embeddings + self.text_pos_encoding.pe[:embeddings.size(0)]

    def _encode_style_token(self, cell: "FullCellInfo") -> torch.Tensor:
        """Project the 20 scaled style features of a cell to one token."""
        device = next(self.parameters()).device

        features = [
            float(cell.bold),
            float(cell.italic),
            float(cell.underline),
            cell.font_size / 24.0,  # scaling
            self._normalize_hash(self.color_vocab.get(cell.text_color, 0)),
            self._normalize_hash(self.color_vocab.get(cell.background_color, 0)),
            self._normalize_hash(self.color_vocab.get(cell.border_color, 0)),
            cell.border_top / 8.0,
            cell.border_bottom / 8.0,
            cell.border_left / 8.0,
            cell.border_right / 8.0,
            cell.horizontal_align / 3.0,
            cell.vertical_align / 2.0,
            float(cell.text_wrap),
            float(cell.is_merged),
            float(cell.has_comment),
            float(cell.is_locked),
            float(cell.is_hidden),
            self._normalize_hash(self.format_vocab.get(cell.number_format, 0)),
            self._normalize_hash(self.format_vocab.get(cell.font_family, 0))
        ]

        features_tensor = torch.tensor(features, device=device, dtype=torch.float)
        return self.style_projector(features_tensor)

    def _encode_position_token(self, cell: "FullCellInfo") -> torch.Tensor:
        """Embed the (row, col) position; indices are clamped to the table sizes."""
        device = next(self.parameters()).device

        # Clamp on both sides: a negative index would otherwise be out of range.
        row_idx = min(max(cell.row, 0), self.spatial_row_emb.num_embeddings - 1)
        col_idx = min(max(cell.col, 0), self.spatial_col_emb.num_embeddings - 1)

        row_emb = self.spatial_row_emb(torch.tensor(row_idx, device=device))
        col_emb = self.spatial_col_emb(torch.tensor(col_idx, device=device))

        return self.spatial_projector(torch.cat([row_emb, col_emb]))

    def _encode_type_token(self, cell: "FullCellInfo") -> torch.Tensor:
        """Embed the cell type code."""
        device = next(self.parameters()).device
        return self.type_embedding(torch.tensor(cell.cell_type, device=device))

    def _encode_numeric_token(self, cell: "FullCellInfo") -> torch.Tensor:
        """Project the numeric features of a cell to one token.

        Non-numeric cells, and values that are not finite numbers, use an
        all-zero feature vector.
        """
        device = next(self.parameters()).device

        features = [0.0] * 6
        if cell.cell_type == 2:
            try:
                value = float(cell.raw_value)
            except (TypeError, ValueError, OverflowError):
                value = None
            if value is not None and np.isfinite(value):
                features = self._extract_numeric_features(value)

        features_tensor = torch.tensor(features, device=device, dtype=torch.float)
        return self.numeric_projector(features_tensor)

    def _extract_numeric_features(self, value: float) -> List[float]:
        """Return six features of a finite number, scaled with ``numeric_stats``.

        In order: clipped min-max position, squashed z-score, clipped position
        on the log10 scale (0 for non-positive values), positive flag, integer
        flag, and number of integer digits divided by 10 (capped at 1).
        """
        stats = self.numeric_stats

        features = []

        # 1. Min-max scaling
        if stats['max'] > stats['min']:
            minmax_norm = (value - stats['min']) / (stats['max'] - stats['min'])
            features.append(float(np.clip(minmax_norm, 0, 1)))
        else:
            features.append(0.5)

        # 2. Z-score squashed into (0, 1)
        if stats['std'] > 0:
            zscore = (value - stats['mean']) / stats['std']
            features.append(float(np.tanh(zscore / 3) * 0.5 + 0.5))
        else:
            features.append(0.5)

        # 3. Log scale
        if value > 0 and stats['log_max'] > stats['log_min']:
            log_val = np.log10(max(value, 1e-8))
            log_norm = (log_val - stats['log_min']) / (stats['log_max'] - stats['log_min'])
            features.append(float(np.clip(log_norm, 0, 1)))
        else:
            features.append(0.0)

        # 4. Sign and shape of the number
        features.append(1.0 if value > 0 else 0.0)            # positive
        features.append(1.0 if value == int(value) else 0.0)  # integer-valued
        features.append(min(len(str(int(abs(value)))) / 10.0, 1.0))  # digit count

        return features

    def _hash_string(self, s: str) -> int:
        """Return a hash of ``s`` that is stable across runs (first 32 bits of MD5)."""
        return int(hashlib.md5(s.encode()).hexdigest()[:8], 16)

    def _normalize_hash(self, hash_val: int) -> float:
        """Map a hash value to [0, 1) using its remainder modulo 10000."""
        return (hash_val % 10000) / 10000.0

In [None]:
class TransformerCellEmbedder(nn.Module):
    """
    Transformer multi-modal pour encoder les cellules Excel complètes.
    Chaque cellule → embedding 128D avec interactions cross-modales riches.
    """
    
    def __init__(self, 
                 embedding_dim: int = 128,
                 vocab_size: int = 10000,
                 num_heads: int = 8,
                 num_layers: int = 4,
                 max_text_length: int = 20,
                 dropout: float = 0.1):
        # Builds every sub-module of the embedder and initialises the unfitted
        # vocabulary state.
        #
        # embedding_dim:   size of each modality embedding and of the encoder model dimension
        # vocab_size:      number of rows of the word embedding table
        # num_heads:       attention heads per encoder layer
        # num_layers:      number of encoder layers
        # max_text_length: max_len of the text positional-encoding table
        # dropout:         dropout probability of the projectors and the encoder
        #
        # The sub-modules are created in a fixed order; keep it, since it sets the
        # order in which parameters are registered and randomly initialised.
        super().__init__()
        
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.max_text_length = max_text_length
        self.num_heads = num_heads
        self.num_layers = num_layers
        
        # === VOCABULARY AND LOOKUP TABLES ===
        # Ids 0-3 are reserved for the special tokens.
        self.string_to_id = {'<PAD>': 0, '<UNK>': 1, '<CLS>': 2, '<SEP>': 3}
        self.id_to_string = {v: k for k, v in self.string_to_id.items()}
        self.color_vocab = {}
        self.format_vocab = {}
        
        # === ONE EMBEDDING PATH PER MODALITY ===
        
        # 1. Text
        self.text_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.text_pos_encoding = PositionalEncoding(embedding_dim, max_text_length)
        
        # 2. Style (20 features -> embedding_dim)
        self.style_projector = nn.Sequential(
            nn.Linear(20, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        # 3. Spatial position
        # Row and column each get half of embedding_dim; the projector input is
        # embedding_dim wide.
        self.spatial_row_emb = nn.Embedding(2000, embedding_dim // 2)  # supports up to 2000 rows
        self.spatial_col_emb = nn.Embedding(100, embedding_dim // 2)   # supports up to 100 columns
        self.spatial_projector = nn.Linear(embedding_dim, embedding_dim)
        
        # 4. Cell type (4 possible codes)
        self.type_embedding = nn.Embedding(4, embedding_dim)
        
        # 5. Numeric values
        self.numeric_projector = nn.Sequential(
            nn.Linear(6, embedding_dim),  # 6 numeric features
            nn.LayerNorm(embedding_dim),
            nn.Tanh()
        )
        
        # === MODALITY TOKENS ===
        # Learnable tokens that identify each kind of information
        self.modality_tokens = nn.Parameter(torch.randn(6, embedding_dim) * 0.02)
        # 0: CLS, 1: TEXT, 2: STYLE, 3: POSITION, 4: TYPE, 5: NUMERIC
        
        # === TRANSFORMER ENCODER ===
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=embedding_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True  # pre-norm, for stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # === FINAL AGGREGATION ===
        self.final_aggregator = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, embedding_dim)
        )
        
        # === NUMERIC NORMALISATION STATISTICS ===
        # Placeholder values; build_vocabulary replaces them when numeric cells are seen.
        self.numeric_stats = {
            'min': 0, 'max': 1, 'mean': 0, 'std': 1,
            'q25': 0, 'q75': 1, 'log_min': 0, 'log_max': 1
        }
        
        # Set to True by build_vocabulary.
        self.is_fitted = False
        
    def build_vocabulary(self, cells: List["FullCellInfo"], min_freq: int = 2):
        """Fit the word vocabulary, hash tables and numeric statistics on ``cells``.

        Words are whitespace-separated tokens of text cells (``cell_type == 1``).
        A word gets an id when it occurs at least ``min_freq`` times and the
        table still has room. Words that already have an id, including the
        special tokens, keep it, so the method can be called again safely.

        Args:
            cells: Cells to collect statistics from.
            min_freq: Minimum number of occurrences for a word to be kept.
        """
        print(f"🔨 Construction du vocabulaire à partir de {len(cells)} cellules...")

        text_counter = Counter()
        color_set = set()
        format_set = set()
        numeric_values = []

        for cell in cells:
            # Text
            if cell.cell_type == 1 and cell.raw_value:
                text_counter.update(str(cell.raw_value).strip().split())

            # Colours and formats
            color_set.update((cell.text_color, cell.background_color, cell.border_color))
            format_set.update((cell.number_format, cell.font_family))

            # Numeric values: keep only finite numbers
            if cell.cell_type == 2:
                try:
                    val = float(cell.raw_value)
                except (TypeError, ValueError, OverflowError):
                    continue
                if np.isfinite(val):
                    numeric_values.append(val)

        # Word vocabulary: new ids continue after the highest id already in use.
        next_id = max(self.string_to_id.values(), default=-1) + 1
        for word, freq in text_counter.items():
            if next_id >= self.vocab_size:
                break
            if freq < min_freq or word in self.string_to_id:
                continue
            self.string_to_id[word] = next_id
            self.id_to_string[next_id] = word
            next_id += 1
        covered = sum(1 for word in text_counter if word in self.string_to_id)

        # Colour and format tables (string -> stable hash)
        self.color_vocab = {color: self._hash_string(color) for color in color_set}
        self.format_vocab = {fmt: self._hash_string(fmt) for fmt in format_set}

        # Numeric statistics. `numeric_values` stays a list so that its
        # truthiness is well defined (an ndarray with several elements raises
        # when used in a boolean context).
        if numeric_values:
            values = np.array(numeric_values)
            # Non-positive values are clamped to 1e-8 before taking the log.
            log_values = np.log10(np.maximum(values, 1e-8))
            self.numeric_stats = {
                'min': float(np.min(values)),
                'max': float(np.max(values)),
                'mean': float(np.mean(values)),
                'std': float(np.std(values)),
                'q25': float(np.percentile(values, 25)),
                'q75': float(np.percentile(values, 75)),
                'log_min': float(log_values.min()),
                'log_max': float(log_values.max())
            }

        self.is_fitted = True

        print(f"✅ Vocabulaire construit:")
        print(f"   - Mots textuels: {len(self.string_to_id)} (coverage: {covered}/{len(text_counter)})")
        print(f"   - Couleurs uniques: {len(color_set)}")
        print(f"   - Formats uniques: {len(format_set)}")
        print(f"   - Valeurs numériques: {len(numeric_values)}")
        if numeric_values:
            print(f"   - Range numérique: [{self.numeric_stats['min']:.2f}, {self.numeric_stats['max']:.2f}]")
    
    def forward(self, cells_batch: List[FullCellInfo]) -> torch.Tensor:
        """
        Main forward pass: cells -> unified embeddings.

        Each cell is turned into a token sequence
        ``[CLS, text tokens..., style, position, type, numeric]``. Every
        non-CLS token has its modality token added to it
        (``self.modality_tokens`` indices: 0 = CLS, 1 = text, 2 = style,
        3 = position, 4 = type, 5 = numeric). Sequences are zero-padded to the
        longest one in the batch, passed through ``self.transformer`` with a
        key-padding mask, and the output at the CLS position is fed to
        ``self.final_aggregator``.

        Args:
            cells_batch: Non-empty list of cells with all their information.

        Returns:
            Tensor with one row per input cell (described by the original
            author as ``(batch_size, embedding_dim)``): the unified embeddings.

        Raises:
            ValueError: If ``build_vocabulary()`` has not been run yet
                (``self.is_fitted`` is false) or if ``cells_batch`` is empty.
        """
        if not self.is_fitted:
            raise ValueError("Le modèle doit être entraîné avec build_vocabulary() d'abord")

        # Fail early with an explicit message instead of an opaque error from
        # max()/stacking on an empty list further down.
        if not cells_batch:
            raise ValueError("cells_batch est vide : aucun embedding à calculer")

        device = next(self.parameters()).device

        # === STEP 1: BUILD THE MULTI-MODAL SEQUENCES ===

        # Encoders that yield exactly one token per cell, paired with the index
        # of the modality token that is added to their output. Order matters:
        # it fixes the token order inside each sequence.
        single_token_encoders = (
            (self._encode_style_token, 2),
            (self._encode_position_token, 3),
            (self._encode_type_token, 4),
            (self._encode_numeric_token, 5),
        )

        all_sequences = []
        for cell in cells_batch:
            # The CLS token always opens the sequence.
            sequence_tokens = [self.modality_tokens[0].unsqueeze(0)]  # (1, embedding_dim)

            # Text tokens are optional: the text encoder returns 0 rows when
            # the cell has nothing to encode.
            text_tokens = self._encode_text_tokens(cell)
            if text_tokens.size(0) > 0:
                sequence_tokens.append(text_tokens + self.modality_tokens[1].unsqueeze(0))

            for encode, modality_idx in single_token_encoders:
                token = encode(cell) + self.modality_tokens[modality_idx]
                sequence_tokens.append(token.unsqueeze(0))

            all_sequences.append(torch.cat(sequence_tokens, dim=0))  # (seq_len, embedding_dim)

        # === STEP 2: PADDING AND BATCHING ===

        # Zero-pad every sequence to the longest one: (batch, max_seq_len, emb_dim).
        batch_sequences = nn.utils.rnn.pad_sequence(all_sequences, batch_first=True).to(device)

        # Padding mask (True = ignore, False = attend): a position is padding
        # when its index is >= the real length of that sequence.
        seq_lengths = torch.tensor([seq.size(0) for seq in all_sequences], device=device)
        positions = torch.arange(batch_sequences.size(1), device=device)
        src_key_padding_mask = positions.unsqueeze(0) >= seq_lengths.unsqueeze(1)  # (batch, max_seq_len)

        # === STEP 3: MULTI-MODAL TRANSFORMER ===

        # NOTE(review): the batch-first layout assumes self.transformer was
        # built with batch_first=True - confirm in __init__.
        transformer_output = self.transformer(
            batch_sequences,
            src_key_padding_mask=src_key_padding_mask
        )  # (batch_size, seq_len, embedding_dim)

        # === STEP 4: AGGREGATE INTO A UNIFIED EMBEDDING ===

        # Use the CLS position as the summary of the whole cell (BERT-style).
        cls_embeddings = transformer_output[:, 0, :]  # (batch_size, embedding_dim)

        # Final transformation.
        return self.final_aggregator(cls_embeddings)
    
    def _encode_text_tokens(self, cell: FullCellInfo) -> torch.Tensor:
        """Encode a text cell's words as position-encoded token embeddings.

        The raw value is split on whitespace, truncated to
        ``self.max_text_length - 2`` words, mapped to vocabulary ids
        (``<UNK>`` for unknown words) and terminated with ``<SEP>``. The ids
        go through ``self.text_embedding`` and then ``self.text_pos_encoding``.

        Args:
            cell: The cell to encode. Only cells with ``cell_type == 1`` and a
                truthy ``raw_value`` produce tokens.

        Returns:
            Tensor with one row per token (words + ``<SEP>``), or an empty
            ``(0, embedding_dim)`` tensor when there is no word to encode.
        """
        device = next(self.parameters()).device

        if cell.cell_type != 1 or not cell.raw_value:
            return torch.empty(0, self.embedding_dim, device=device)

        # Whitespace tokenisation; the budget leaves 2 slots free for special
        # tokens. Clamp at 0 so a tiny max_text_length cannot turn into a
        # negative slice bound (which would silently drop trailing words).
        max_words = max(self.max_text_length - 2, 0)
        words = str(cell.raw_value).split()[:max_words]

        # Nothing to encode (e.g. whitespace-only text): return no text tokens
        # rather than a lone <SEP>. This check must happen before <SEP> is
        # appended, otherwise it can never trigger.
        if not words:
            return torch.empty(0, self.embedding_dim, device=device)

        # Convert words to ids, falling back to <UNK>, then close with <SEP>.
        unk_id = self.string_to_id['<UNK>']
        token_ids = [self.string_to_id.get(word, unk_id) for word in words]
        token_ids.append(self.string_to_id['<SEP>'])

        # Embeddings
        token_ids_tensor = torch.tensor(token_ids, device=device)
        embeddings = self.text_embedding(token_ids_tensor)

        # Positional encoding is applied on a temporary leading batch axis.
        return self.text_pos_encoding(embeddings.unsqueeze(0)).squeeze(0)
    
    def _encode_style_token(self, cell: FullCellInfo) -> torch.Tensor:
        """Project the cell's 20 formatting features into one style token.

        The feature vector is assembled in a fixed order: bold/italic/underline
        flags, scaled font size, three colour lookups, four scaled border
        values, scaled horizontal and vertical alignment, five boolean flags,
        and two lookups in the format vocabulary. It is then passed through
        ``self.style_projector``.

        Args:
            cell: The cell whose style attributes are read.

        Returns:
            The output of ``self.style_projector`` for this cell's features.
        """
        param_device = next(self.parameters()).device

        def vocab_feature(vocab, attr_name):
            # Vocabulary id of the attribute value (0 when unknown), folded
            # into [0, 1) by _normalize_hash.
            return self._normalize_hash(vocab.get(getattr(cell, attr_name), 0))

        # Font emphasis flags.
        style_vector = [float(getattr(cell, flag)) for flag in ('bold', 'italic', 'underline')]

        # Font size, scaled by a fixed divisor.
        style_vector.append(cell.font_size / 24.0)

        # Colours, looked up in the colour vocabulary.
        style_vector.extend(
            vocab_feature(self.color_vocab, name)
            for name in ('text_color', 'background_color', 'border_color')
        )

        # Border values for each side, scaled by a fixed divisor.
        style_vector.extend(
            getattr(cell, name) / 8.0
            for name in ('border_top', 'border_bottom', 'border_left', 'border_right')
        )

        # Alignment codes, each scaled by its own divisor.
        style_vector.append(cell.horizontal_align / 3.0)
        style_vector.append(cell.vertical_align / 2.0)

        # Remaining boolean cell properties.
        style_vector.extend(
            float(getattr(cell, flag))
            for flag in ('text_wrap', 'is_merged', 'has_comment', 'is_locked', 'is_hidden')
        )

        # NOTE(review): font_family is looked up in format_vocab, the same
        # vocabulary as number_format - confirm this is intended.
        style_vector.extend(
            vocab_feature(self.format_vocab, name)
            for name in ('number_format', 'font_family')
        )

        packed = torch.tensor(style_vector, device=param_device, dtype=torch.float)
        return self.style_projector(packed)
    
    def _encode_position_token(self, cell: FullCellInfo) -> torch.Tensor:
        """Encode the cell's spatial position (row, column) as one token.

        The row index is capped at 1999 and the column index at 99 before
        being looked up in ``self.spatial_row_emb`` / ``self.spatial_col_emb``.
        The two results are concatenated and passed through
        ``self.spatial_projector``.

        Args:
            cell: The cell whose ``row`` and ``col`` are read.

        Returns:
            The output of ``self.spatial_projector`` for this position.
        """
        param_device = next(self.parameters()).device

        def lookup(table, index, upper_bound):
            # Cap the index, wrap it in a tensor on the model's device, then
            # look it up in the given embedding table.
            capped = torch.tensor(min(index, upper_bound), device=param_device)
            return table(capped)

        row_and_col = torch.cat([
            lookup(self.spatial_row_emb, cell.row, 1999),
            lookup(self.spatial_col_emb, cell.col, 99),
        ])
        return self.spatial_projector(row_and_col)
    
    def _encode_type_token(self, cell: FullCellInfo) -> torch.Tensor:
        """Encode the cell type: look ``cell.cell_type`` up in ``self.type_embedding``."""
        param_device = next(self.parameters()).device
        # The type code is wrapped in a tensor on the same device as the model.
        type_index = torch.tensor(cell.cell_type, device=param_device)
        return self.type_embedding(type_index)
    
    def _encode_numeric_token(self, cell: FullCellInfo) -> torch.Tensor:
        """Encode the cell's numeric value as one token.

        For cells with ``cell_type == 2`` the raw value is converted to a
        float and described by the 6 features of
        ``_extract_numeric_features``. Every other cell, and any cell whose
        value cannot be converted or featurised, gets an all-zero feature
        vector. The features are passed through ``self.numeric_projector``.

        Args:
            cell: The cell whose ``cell_type`` and ``raw_value`` are read.

        Returns:
            The output of ``self.numeric_projector`` for this cell's features.
        """
        device = next(self.parameters()).device

        # Neutral default: 6 zeros, the same length as the feature list built
        # by _extract_numeric_features.
        features = [0.0] * 6

        if cell.cell_type == 2:
            try:
                features = self._extract_numeric_features(float(cell.raw_value))
            except Exception:
                # Deliberate best-effort: an unparsable or unusable value keeps
                # the neutral features. Catching Exception (not a bare except)
                # lets KeyboardInterrupt / SystemExit propagate.
                features = [0.0] * 6

        features_tensor = torch.tensor(features, device=device, dtype=torch.float)
        return self.numeric_projector(features_tensor)
    
    def _extract_numeric_features(self, value: float) -> List[float]:
        """Extrait des features d'une valeur numérique"""
        stats = self.numeric_stats
        
        features = []
        
        # 1. Normalisation min-max
        if stats['max'] > stats['min']:
            minmax_norm = (value - stats['min']) / (stats['max'] - stats['min'])
            features.append(np.clip(minmax_norm, 0, 1))
        else:
            features.append(0.5)
        
        # 2. Z-score normalisé
        if stats['std'] > 0:
            zscore = (value - stats['mean']) / stats['std']
            zscore_norm = np.tanh(zscore / 3) * 0.5 + 0.5
            features.append(zscore_norm)
        else:
            features.append(0.5)
        
        # 3. Log-scale
        if value > 0 and stats['log_max'] > stats['log_min']:
            log_val = np.log10(max(value, 1e-8))
            log_norm = (log_val - stats['log_min']) / (stats['log_max'] - stats['log_min'])
            features.append(np.clip(log_norm, 0, 1))
        else:
            features.append(0.0)
        
        # 4. Signe et propriétés
        features.append(1.0 if value > 0 else 0.0)  # Positif
        features.append(1.0 if value == int(value) else 0.0)  # Entier
        features.append(min(len(str(int(abs(value)))) / 10.0, 1.0))  # Nb chiffres
        
        return features
    
    def _hash_string(self, s: str) -> int:
        """Hash stable d'une string"""
        return int(hashlib.md5(s.encode()).hexdigest()[:8], 16)
    
    def _normalize_hash(self, hash_val: int) -> float:
        """Normalise un hash en [0,1]"""
        return (hash_val % 10000) / 10000.0