In [14]:
import json
from dataclasses import dataclass
from datetime import datetime
import torch
import torch.nn as nn
import re
import numpy as np
from typing import List, Dict, Any, Optional, Union, Tuple, Set
from enum import Enum
import math

data_folder = "data/"
jsons_path = "data/*.json"

Création de la classe FullCellInfo récupérant l'ensmeble des informations d'une cellule 

In [15]:
@dataclass
class FullCellInfo:
    """Structure complète d'une cellule Excel avec les informations disponibles dans le JSON Univer"""
    # Contenu (disponible dans le JSON)
    raw_value: Any = ""
    cell_type: int = 0  # 1=text, 2=number, 3=formula (inféré)
    formula: str = ""   # "f" - formule si présente (avec = au début)
    
    # Position
    row: int = 0
    col: int = 0
    sheet_id: str = "default"
    sheet_name: str = ""
    
    # Style complet (basé sur le format Univer optimisé)
    style_id: str = ""  # Peut être vide si pas de style
    
    # Formatage de texte
    bold: bool = False          # "bl": 1
    italic: bool = False        # "it": 1  
    underline: bool = False     # "ul": {"s": 1}
    strike: bool = False        # "st": {"s": 1}
    font_size: float = 11.0     # "fs": size (défaut Calibri 11)
    font_family: str = "Calibri" # "ff": family
    
    # Couleurs
    text_color: str = "#000000"     # "cl": {"rgb": "#color"}
    background_color: str = "#FFFFFF"  # "bg": {"rgb": "#color"}
    
    # Bordures
    border_top: int = 0      # "bd": {"t": {"s": value, "cl": {"rgb": "#color"}}}
    border_bottom: int = 0   # "bd": {"b": {"s": value, "cl": {"rgb": "#color"}}}
    border_left: int = 0     # "bd": {"l": {"s": value, "cl": {"rgb": "#color"}}}
    border_right: int = 0    # "bd": {"r": {"s": value, "cl": {"rgb": "#color"}}}
    border_color: str = "#000000"
    
    # Alignement
    horizontal_align: int = 0  # "ht": 1=left, 2=center, 3=right, 4=justify
    vertical_align: int = 0    # "vt": 1=top, 2=center, 3=bottom
    text_wrap: bool = False    # "tb": 3
    
    # Rotation/Transformation
    text_rotation: int = 0     # "tr": {"a": angle, "v": 0}
    
    # Format de nombre
    number_format: str = "General"  # "n": {"pattern": "format"}
    
    # Fusion de cellules (disponible via mergeData)
    is_merged: bool = False
    merge_range: Tuple[int, int, int, int] = (0, 0, 0, 0)  # (startRow, endRow, startCol, endCol)
    
    # Métadonnées de feuille
    sheet_hidden: bool = False
    sheet_tab_color: str = ""
    sheet_zoom: float = 1.0
    sheet_show_gridlines: bool = True
    
    # Métadonnées de ligne/colonne
    row_height: Optional[float] = None
    row_hidden: bool = False
    col_width: Optional[int] = None
    col_hidden: bool = False
    
    # Volets figés
    freeze_start_row: int = -1
    freeze_start_col: int = -1
    
    def __post_init__(self):
        """Post-processing après création"""
        # Le type est défini par "t" dans le JSON ou inféré du contenu
        if self.cell_type == 0:  # Si pas de type défini
            if self.formula:
                self.cell_type = 3
            elif isinstance(self.raw_value, (int, float)):
                self.cell_type = 2
            elif isinstance(self.raw_value, str) and self.raw_value.strip():
                self.cell_type = 1


Classe pour générer les FullCellInfo à partir d'un json

In [16]:
class ExcelParser:
    """Parse les fichiers Excel JSON Univer en structures FullCellInfo"""
    
    @staticmethod
    def parse_excel_json(excel_data: Dict) -> List[FullCellInfo]:
        """Convertit JSON Excel Univer en liste de FullCellInfo"""
        cells = []
        
        styles = excel_data.get('styles', {})
        
        for sheet_id, sheet_info in excel_data.get('sheets', {}).items():
            sheet_name = sheet_info.get('name', "")
            cell_data = sheet_info.get('cellData', {})
            merge_data = sheet_info.get('mergeData', [])
            row_data = sheet_info.get('rowData', {})
            column_data = sheet_info.get('columnData', {})
            freeze_info = sheet_info.get('freeze', {})
            
            # Métadonnées de feuille
            sheet_hidden = bool(sheet_info.get('hidden', 0))
            sheet_tab_color = sheet_info.get('tabColor', "")
            sheet_zoom = sheet_info.get('zoomRatio', 1.0)
            sheet_show_gridlines = bool(sheet_info.get('showGridlines', 1))
            
            # Informations de volets figés
            freeze_start_row = freeze_info.get('startRow', -1)
            freeze_start_col = freeze_info.get('startColumn', -1)
            
            # Créer un mapping des cellules fusionnées
            merge_map = ExcelParser._create_merge_map(merge_data)
            
            for row_str, row_cells in cell_data.items():
                row = int(row_str)
                
                # Informations de ligne
                row_info = row_data.get(row_str, {})
                row_height = row_info.get('h')
                row_hidden = bool(row_info.get('hd', 0))
                
                for col_str, cell_info in row_cells.items():
                    col = int(col_str)
                    
                    # Informations de colonne
                    col_info = column_data.get(col_str, {})
                    col_width = col_info.get('w')
                    col_hidden = bool(col_info.get('hd', 0))
                    
                    # Extraire les informations de style
                    style_id = cell_info.get('s', '')
                    style = styles.get(style_id, {}) if style_id else {}
                    
                    # Vérifier si cette cellule fait partie d'une fusion
                    merge_info = merge_map.get((row, col), None)
                    
                    cell = FullCellInfo(
                        raw_value=cell_info.get('v', ''),
                        cell_type=cell_info.get('t', 0),
                        formula=cell_info.get('f', ''),
                        row=row,
                        col=col,
                        sheet_id=sheet_id,
                        sheet_name=sheet_name,
                        style_id=style_id,
                        is_merged=merge_info is not None,
                        merge_range=merge_info if merge_info else (0, 0, 0, 0),
                        sheet_hidden=sheet_hidden,
                        sheet_tab_color=sheet_tab_color,
                        sheet_zoom=sheet_zoom,
                        sheet_show_gridlines=sheet_show_gridlines,
                        row_height=row_height,
                        row_hidden=row_hidden,
                        col_width=col_width,
                        col_hidden=col_hidden,
                        freeze_start_row=freeze_start_row,
                        freeze_start_col=freeze_start_col,
                        **ExcelParser._parse_style(style)
                    )
                    
                    cells.append(cell)
        
        return cell
    
    @staticmethod
    def _create_merge_map(merge_data: List[Dict]) -> Dict[Tuple[int, int], Tuple[int, int, int, int]]:
        """Crée un mapping des cellules fusionnées"""
        merge_map = {}
        
        for merge in merge_data:
            start_row = merge['startRow']
            end_row = merge['endRow'] 
            start_col = merge['startColumn']
            end_col = merge['endColumn']
            
            # Marquer toutes les cellules dans cette plage comme fusionnées
            for row in range(start_row, end_row + 1):
                for col in range(start_col, end_col + 1):
                    merge_map[(row, col)] = (start_row, end_row, start_col, end_col)
        
        return merge_map
    
    @staticmethod
    def _parse_style(style: Dict) -> Dict:
        """Parse les informations de style basé sur le format Univer optimisé"""
        parsed = {
            'bold': bool(style.get('bl', 0)),
            'italic': bool(style.get('it', 0)),
            'font_size': float(style.get('fs', 11.0)),
            'font_family': style.get('ff', 'Calibri'),
            'text_wrap': bool(style.get('tb', 0) == 3),  # tb: 3 = wrap
        }
        
        # Underline - structure: "ul": {"s": 1}
        ul_info = style.get('ul', {})
        if isinstance(ul_info, dict):
            parsed['underline'] = bool(ul_info.get('s', 0))
        else:
            parsed['underline'] = bool(ul_info)
        
        # Strike - structure: "st": {"s": 1}
        st_info = style.get('st', {})
        if isinstance(st_info, dict):
            parsed['strike'] = bool(st_info.get('s', 0))
        else:
            parsed['strike'] = bool(st_info)
        
        # Couleurs
        if 'cl' in style:  # text color
            parsed['text_color'] = style['cl'].get('rgb', '#000000')
        if 'bg' in style:  # background color
            parsed['background_color'] = style['bg'].get('rgb', '#FFFFFF')
            
        # Bordures - structure: "bd": {"t": {"s": 8, "cl": {"rgb": "#000000"}}}
        borders = style.get('bd', {})
        parsed.update({
            'border_top': borders.get('t', {}).get('s', 0),
            'border_bottom': borders.get('b', {}).get('s', 0),
            'border_left': borders.get('l', {}).get('s', 0),
            'border_right': borders.get('r', {}).get('s', 0),
        })
        
        # Couleur de bordure (prendre la première trouvée)
        border_color = "#000000"
        for border_side in ['t', 'b', 'l', 'r']:
            if border_side in borders and 'cl' in borders[border_side]:
                border_color = borders[border_side]['cl'].get('rgb', '#000000')
                break
        parsed['border_color'] = border_color
        
        # Alignement (mapping exact du convertisseur)
        parsed.update({
            'horizontal_align': style.get('ht', 0),  # 1=left, 2=center, 3=right, 4=justify
            'vertical_align': style.get('vt', 0),    # 1=top, 2=center, 3=bottom
        })
        
        # Rotation/Transformation - "tr": {"a": angle, "v": 0}
        tr_info = style.get('tr', {})
        if isinstance(tr_info, dict):
            parsed['text_rotation'] = tr_info.get('a', 0)  # angle
        else:
            parsed['text_rotation'] = 0
        
        # Format de nombre - structure: "n": {"pattern": "format"}
        number_info = style.get('n', {})
        if isinstance(number_info, dict):
            parsed['number_format'] = number_info.get('pattern', 'General')
        else:
            parsed['number_format'] = 'General'
        
        return parsed

Tokenisation d'un classeur : [POS:1,0][TYPE:TEXT][STYLE:B,I,BG:#4470C4][VALUE:Map brief][MERGE:1,2,0,6][LAYOUT:RH:25.2,CW:150]

In [17]:
class ExcelDataProcessor:
    """Classe pour préparer les données Excel pour l'entraînement de transformers"""
    
    @staticmethod
    def cells_to_text_sequence(cells: List[FullCellInfo], include_empty_cells: bool = False) -> str:
        """Convertit une liste de cellules en séquence de texte pour l'entraînement"""
        sequences = []
        
        # Regrouper par feuille et position
        sheets = {}
        for cell in cells:
            if cell.sheet_name not in sheets:
                sheets[cell.sheet_name] = []
            sheets[cell.sheet_name].append(cell)
        
        for sheet_name, sheet_cells in sheets.items():
            # Trier par position (row, col)
            sheet_cells.sort(key=lambda c: (c.row, c.col))
            
            # Métadonnées de feuille enrichies
            sheet_meta = f"[SHEET:{sheet_name}"
            if sheet_cells:
                first_cell = sheet_cells[0]
                if first_cell.sheet_hidden:
                    sheet_meta += ",HIDDEN"
                if first_cell.sheet_tab_color:
                    sheet_meta += f",TAB:{first_cell.sheet_tab_color}"
                if first_cell.sheet_zoom != 1.0:
                    sheet_meta += f",ZOOM:{first_cell.sheet_zoom}"
                if not first_cell.sheet_show_gridlines:
                    sheet_meta += ",NO_GRID"
                if first_cell.freeze_start_row >= 0 or first_cell.freeze_start_col >= 0:
                    sheet_meta += f",FREEZE:{first_cell.freeze_start_row},{first_cell.freeze_start_col}"
            sheet_meta += "]"
            
            sheet_text = sheet_meta
            for cell in sheet_cells:
                if not include_empty_cells and not cell.raw_value and not cell.is_merged and not cell.style_id:
                    continue
                cell_repr = ExcelDataProcessor._cell_to_token(cell)
                sheet_text += f" {cell_repr}"
            
            sequences.append(sheet_text)
        
        return " [SHEET_END] ".join(sequences)
    
    @staticmethod
    def _cell_to_token(cell: FullCellInfo) -> str:
        """Convertit une cellule en token enrichi pour le transformer"""
        # Format: [POS:row,col][TYPE:t][STYLE:...][VALUE:...][MERGE:...][ROW/COL:...]
        pos = f"[POS:{cell.row},{cell.col}]"
        
        # Type de cellule
        type_map = {0: "EMPTY", 1: "TEXT", 2: "NUMBER", 3: "FORMULA"}
        cell_type = f"[TYPE:{type_map.get(cell.cell_type, 'UNKNOWN')}]"
        
        # Style simplifié pour le token
        style_parts = []
        if cell.bold: style_parts.append("B")
        if cell.italic: style_parts.append("I") 
        if cell.underline: style_parts.append("U")
        if cell.strike: style_parts.append("S")
        if cell.background_color != "#FFFFFF": 
            style_parts.append(f"BG:{cell.background_color}")
        if cell.text_rotation != 0: 
            style_parts.append(f"ROT:{cell.text_rotation}")
        if any([cell.border_top, cell.border_bottom, cell.border_left, cell.border_right]):
            style_parts.append("BORDER")
        if cell.horizontal_align != 0:
            align_map = {1: "LEFT", 2: "CENTER", 3: "RIGHT", 4: "JUSTIFY"}
            style_parts.append(f"ALIGN:{align_map.get(cell.horizontal_align, 'UNKNOWN')}")
        if cell.text_wrap:
            style_parts.append("WRAP")
        if cell.font_size != 11.0:
            style_parts.append(f"SIZE:{cell.font_size}")
        if cell.font_family != "Calibri":
            style_parts.append(f"FONT:{cell.font_family}")
        
        style = f"[STYLE:{','.join(style_parts)}]" if style_parts else "[STYLE:NONE]"
        
        # Valeur (formule ou valeur)
        if cell.formula:
            # Formule avec = au début selon le convertisseur
            formula_clean = cell.formula.lstrip('=')
            value = f"[FORMULA:={formula_clean}]"
        else:
            value = f"[VALUE:{cell.raw_value}]"
        
        # Ajout information de fusion
        merge = ""
        if cell.is_merged:
            sr, er, sc, ec = cell.merge_range
            merge = f"[MERGE:{sr},{er},{sc},{ec}]"
        
        # Informations de ligne/colonne si non-standard
        layout = ""
        layout_parts = []
        if cell.row_height is not None:
            layout_parts.append(f"RH:{cell.row_height}")
        if cell.row_hidden:
            layout_parts.append("ROW_HIDDEN")
        if cell.col_width is not None:
            layout_parts.append(f"CW:{cell.col_width}")
        if cell.col_hidden:
            layout_parts.append("COL_HIDDEN")
        
        if layout_parts:
            layout = f"[LAYOUT:{','.join(layout_parts)}]"
        
        return f"{pos}{cell_type}{style}{value}{merge}{layout}"
    
    @staticmethod
    def extract_workbook_metadata(cells: List[FullCellInfo]) -> Dict[str, Any]:
        """Extrait les métadonnées complètes du classeur"""
        if not cells:
            return {}
        
        sheet_metadata = {}
        
        for cell in cells:
            if cell.sheet_name not in sheet_metadata:
                sheet_metadata[cell.sheet_name] = {
                    'sheet_id': cell.sheet_id,
                    'hidden': cell.sheet_hidden,
                    'tab_color': cell.sheet_tab_color,
                    'zoom': cell.sheet_zoom,
                    'show_gridlines': cell.sheet_show_gridlines,
                    'freeze_panes': (cell.freeze_start_row, cell.freeze_start_col) if cell.freeze_start_row >= 0 or cell.freeze_start_col >= 0 else None,
                    'cell_count': 0,
                    'merged_count': 0,
                    'formula_count': 0,
                    'styles_used': set(),
                    'max_row': 0,
                    'max_col': 0,
                    'custom_row_heights': 0,
                    'custom_col_widths': 0,
                    'hidden_rows': 0,
                    'hidden_cols': 0
                }
            
            meta = sheet_metadata[cell.sheet_name]
            meta['cell_count'] += 1
            meta['max_row'] = max(meta['max_row'], cell.row)
            meta['max_col'] = max(meta['max_col'], cell.col)
            
            if cell.is_merged:
                meta['merged_count'] += 1
            if cell.formula:
                meta['formula_count'] += 1
            if cell.style_id:
                meta['styles_used'].add(cell.style_id)
            if cell.row_height is not None:
                meta['custom_row_heights'] += 1
            if cell.col_width is not None:
                meta['custom_col_widths'] += 1
            if cell.row_hidden:
                meta['hidden_rows'] += 1
            if cell.col_hidden:
                meta['hidden_cols'] += 1
        
        # Convertir les sets en listes pour la sérialisation
        for meta in sheet_metadata.values():
            meta['styles_used'] = list(meta['styles_used'])
        
        return {
            'sheets': sheet_metadata,
            'total_cells': len(cells),
            'total_sheets': len(sheet_metadata),
            'total_merged_cells': sum(meta['merged_count'] for meta in sheet_metadata.values()),
            'total_formulas': sum(meta['formula_count'] for meta in sheet_metadata.values()),
            'total_styles': len(set(cell.style_id for cell in cells if cell.style_id))
        }

Embedder depuis FullCellInfo : 

In [18]:
@dataclass
class EmbeddingConfig:
    """Configuration pour l'embedder Excel"""
    # Dimensions
    embedding_dim: int = 256
    position_embedding_dim: int = 32
    type_embedding_dim: int = 16
    
    # Vocabulaires fixes
    max_position: int = 10000  # Pour row/col
    max_value_length: int = 100
    max_font_size: int = 72
    
    # Couleurs (nombre de couleurs possibles)
    color_vocab_size: int = 1000
    
    # Alignements, bordures, etc.
    align_vocab_size: int = 5
    border_vocab_size: int = 10
    font_vocab_size: int = 50

class ExcelCellEmbedder(nn.Module):
    """Embedder direct pour FullCellInfo avec positions fixes"""
    
    def __init__(self, config: EmbeddingConfig):
        super().__init__()
        self.config = config
        
        # Position 0-1: Position et Type
        self.row_embedding = nn.Embedding(config.max_position, config.position_embedding_dim)
        self.col_embedding = nn.Embedding(config.max_position, config.position_embedding_dim)
        self.type_embedding = nn.Embedding(4, config.type_embedding_dim)  # EMPTY, TEXT, NUMBER, FORMULA
        
        # Position 2: Contenu (valeur/formule)
        self.value_encoder = ValueEncoder(config)
        
        # Position 3-6: Formatage booléen
        self.bool_embedding = nn.Embedding(2, 8)  # TRUE/FALSE -> 8 dim
        
        # Position 7-8: Police
        self.font_size_embedding = nn.Embedding(config.max_font_size + 1, 16)
        self.font_family_embedding = nn.Embedding(config.font_vocab_size, 32)
        
        # Position 9-10: Couleurs
        self.color_embedding = nn.Embedding(config.color_vocab_size, 24)
        
        # Position 11-14: Alignement
        self.align_h_embedding = nn.Embedding(config.align_vocab_size, 16)
        self.align_v_embedding = nn.Embedding(config.align_vocab_size, 16)
        self.rotation_embedding = nn.Embedding(361, 16)  # 0-360 degrés
        
        # Position 15-18: Bordures
        self.border_embedding = nn.Embedding(config.border_vocab_size, 16)
        
        # Position 19: Fusion
        self.merge_encoder = MergeEncoder(config)
        
        # Projection finale
        total_dim = self._calculate_total_dim()
        self.projection = nn.Linear(total_dim, config.embedding_dim)
        
        # Normalisation
        self.layer_norm = nn.LayerNorm(config.embedding_dim)
        
        # Vocabulaires pour la conversion
        self._build_vocabularies()
    
    def _calculate_total_dim(self) -> int:
        """Calcule la dimension totale avant projection"""
        return (
            2 * self.config.position_embedding_dim +  # row, col
            self.config.type_embedding_dim +          # type
            self.value_encoder.output_dim +           # value (maintenant 2048D)
            4 * 8 +                                   # 4 bool (bold, italic, underline, strike)
            16 + 32 +                                 # font_size, font_family
            2 * 24 +                                  # text_color, bg_color
            2 * 16 + 8 + 16 +                        # align_h, align_v, wrap, rotation
            4 * 16 +                                  # 4 borders
            self.merge_encoder.output_dim             # merge
        )
    
    def _build_vocabularies(self):
        """Construit les vocabulaires de conversion"""
        # Familles de police communes
        self.font_families = [
            "Calibri", "Arial", "Times New Roman", "Helvetica", "Verdana",
            "Georgia", "Courier New", "Tahoma", "Comic Sans MS", "Impact",
            # ... ajouter d'autres polices communes
        ]
        self.font_family_to_id = {font: i for i, font in enumerate(self.font_families)}
        
        # Styles de bordure
        self.border_styles = ["NONE", "THIN", "HAIR", "DOTTED", "DASHED", "DOUBLE", "MEDIUM", "THICK"]
        self.border_style_to_id = {style: i for i, style in enumerate(self.border_styles)}
    
    def forward(self, cells: Union['FullCellInfo', List['FullCellInfo']]) -> torch.Tensor:
        """
        Convertit FullCellInfo en embeddings
        
        Args:
            cells: Une cellule ou liste de cellules
            
        Returns:
            Tensor de shape [batch_size, embedding_dim] ou [embedding_dim]
        """
        if not isinstance(cells, list):
            cells = [cells]
        
        batch_size = len(cells)
        embeddings = []
        
        for cell in cells:
            # Position 0-1: Position spatiale
            row_emb = self.row_embedding(torch.clamp(torch.tensor(cell.row), 0, self.config.max_position - 1))
            col_emb = self.col_embedding(torch.clamp(torch.tensor(cell.col), 0, self.config.max_position - 1))
            
            # Position 1: Type de cellule
            type_emb = self.type_embedding(torch.tensor(cell.cell_type))
            
            # Position 2: Contenu
            value_emb = self.value_encoder(cell)
            
            # Position 3-6: Formatage booléen
            bold_emb = self.bool_embedding(torch.tensor(int(cell.bold)))
            italic_emb = self.bool_embedding(torch.tensor(int(cell.italic)))
            underline_emb = self.bool_embedding(torch.tensor(int(cell.underline)))
            strike_emb = self.bool_embedding(torch.tensor(int(cell.strike)))
            
            # Position 7-8: Police
            font_size = min(int(cell.font_size), self.config.max_font_size)
            font_size_emb = self.font_size_embedding(torch.tensor(font_size))
            
            font_id = self.font_family_to_id.get(cell.font_family, 0)
            font_family_emb = self.font_family_embedding(torch.tensor(font_id))
            
            # Position 9-10: Couleurs
            text_color_id = self._color_to_id(cell.text_color)
            bg_color_id = self._color_to_id(cell.background_color)
            text_color_emb = self.color_embedding(torch.tensor(text_color_id))
            bg_color_emb = self.color_embedding(torch.tensor(bg_color_id))
            
            # Position 11-14: Alignement
            align_h_emb = self.align_h_embedding(torch.tensor(cell.horizontal_align))
            align_v_emb = self.align_v_embedding(torch.tensor(cell.vertical_align))
            wrap_emb = self.bool_embedding(torch.tensor(int(cell.text_wrap)))
            rotation = min(abs(cell.text_rotation), 360)
            rotation_emb = self.rotation_embedding(torch.tensor(rotation))
            
            # Position 15-18: Bordures
            border_top_id = self._border_to_id(cell.border_top)
            border_bottom_id = self._border_to_id(cell.border_bottom)
            border_left_id = self._border_to_id(cell.border_left)
            border_right_id = self._border_to_id(cell.border_right)
            
            border_top_emb = self.border_embedding(torch.tensor(border_top_id))
            border_bottom_emb = self.border_embedding(torch.tensor(border_bottom_id))
            border_left_emb = self.border_embedding(torch.tensor(border_left_id))
            border_right_emb = self.border_embedding(torch.tensor(border_right_id))
            
            # Position 19: Fusion
            merge_emb = self.merge_encoder(cell)
            
            # Concaténer tous les embeddings
            cell_embedding = torch.cat([
                row_emb, col_emb, type_emb, value_emb,
                bold_emb, italic_emb, underline_emb, strike_emb,
                font_size_emb, font_family_emb,
                text_color_emb, bg_color_emb,
                align_h_emb, align_v_emb, wrap_emb, rotation_emb,
                border_top_emb, border_bottom_emb, border_left_emb, border_right_emb,
                merge_emb
            ], dim=0)
            
            embeddings.append(cell_embedding)
        
        # Stack et projeter
        batch_embeddings = torch.stack(embeddings)
        projected = self.projection(batch_embeddings)
        normalized = self.layer_norm(projected)
        
        return normalized.squeeze(0) if len(cells) == 1 else normalized
    
    def _color_to_id(self, color: str) -> int:
        """Convertit une couleur hex en ID"""
        if not color or color == "#FFFFFF":
            return 0  # Couleur par défaut
        
        # Convertir hex en entier et mapper
        try:
            hex_val = int(color.replace("#", ""), 16)
            return (hex_val % (self.config.color_vocab_size - 1)) + 1
        except:
            return 0
    
    def _border_to_id(self, border_style: int) -> int:
        """Convertit un style de bordure en ID"""
        # Mapping basé sur les styles Univer
        border_map = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 7: 5, 8: 6, 13: 7}
        return border_map.get(border_style, 0)
    
    def get_embedding_at_position(self, cell: 'FullCellInfo', position: int) -> torch.Tensor:
        """Récupère l'embedding d'une position spécifique"""
        full_embedding = self.forward(cell)
        
        # Mapping des positions aux dimensions dans l'embedding final
        position_slices = self._get_position_slices()
        
        if position in position_slices:
            start, end = position_slices[position]
            return full_embedding[start:end]
        else:
            raise ValueError(f"Position {position} not found")
    
    def _get_position_slices(self) -> Dict[int, Tuple[int, int]]:
        """Retourne les tranches d'embedding pour chaque position logique"""
        slices = {}
        offset = 0
        
        # Position 0: row
        slices[0] = (offset, offset + self.config.position_embedding_dim)
        offset += self.config.position_embedding_dim
        
        # Position 1: col  
        slices[1] = (offset, offset + self.config.position_embedding_dim)
        offset += self.config.position_embedding_dim
        
        # Position 2: type
        slices[2] = (offset, offset + self.config.type_embedding_dim)
        offset += self.config.type_embedding_dim
        
        # Position 3: value
        slices[3] = (offset, offset + self.value_encoder.output_dim)
        offset += self.value_encoder.output_dim
        
        # Et ainsi de suite...
        return slices

class ValueEncoder(nn.Module):
    """Encodeur spécialisé pour les valeurs de cellules avec gestion du contenu complexe"""
    
    def __init__(self, config: EmbeddingConfig):
        super().__init__()
        self.config = config
        self.max_tokens = 32  # Nombre max de tokens pour le contenu
        self.token_dim = 64   # Dimension par token
        self.output_dim = self.max_tokens * self.token_dim  # 32 * 64 = 2048D
        
        # Tokenizer simple (remplacer par un vrai tokenizer)
        self.vocab_size = 10000
        self.token_embedding = nn.Embedding(self.vocab_size, self.token_dim)
        
        # Embeddings spéciaux
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.number_token_id = 2
        self.formula_start_id = 3
        
        # Encodeur de position pour les tokens
        self.token_position_embedding = nn.Embedding(self.max_tokens, self.token_dim)
        
        # Classification du type de contenu
        self.content_type_embedding = nn.Embedding(4, self.token_dim)  # TEXT, NUMBER, FORMULA, EMPTY
        
    def forward(self, cell: 'FullCellInfo') -> torch.Tensor:
        """Encode la valeur en séquence de tokens"""
        if cell.formula:
            tokens = self._tokenize_formula(cell.formula)
            content_type = 3  # FORMULA
        elif cell.raw_value and cell.cell_type == 2:  # Number
            tokens = self._tokenize_number(cell.raw_value)
            content_type = 2  # NUMBER
        elif cell.raw_value and cell.cell_type == 1:  # Text
            tokens = self._tokenize_text(str(cell.raw_value))
            content_type = 1  # TEXT
        else:
            tokens = [self.pad_token_id] * self.max_tokens
            content_type = 0  # EMPTY
        
        # Padding/truncation à max_tokens
        if len(tokens) > self.max_tokens:
            tokens = tokens[:self.max_tokens]
        else:
            tokens.extend([self.pad_token_id] * (self.max_tokens - len(tokens)))
        
        # Convertir en embeddings
        token_ids = torch.tensor(tokens)
        token_embs = self.token_embedding(token_ids)  # [max_tokens, token_dim]
        
        # Ajouter encodage positionnel
        positions = torch.arange(self.max_tokens)
        pos_embs = self.token_position_embedding(positions)
        
        # Ajouter type de contenu à chaque token
        content_type_emb = self.content_type_embedding(torch.tensor(content_type))
        content_type_emb = content_type_emb.unsqueeze(0).expand(self.max_tokens, -1)
        
        # Combiner
        combined_embs = token_embs + pos_embs + content_type_emb  # [max_tokens, token_dim]
        
        # Aplatir pour la sortie
        return combined_embs.flatten()  # [max_tokens * token_dim]
    
    def _tokenize_text(self, text: str) -> List[int]:
        """Tokenise le texte (placeholder - utiliser un vrai tokenizer)"""
        # Simplification: split par mots et hash
        words = text.lower().split()[:self.max_tokens]
        tokens = []
        for word in words:
            token_id = (hash(word) % (self.vocab_size - 10)) + 10  # Éviter les tokens spéciaux
            tokens.append(token_id)
        return tokens
    
    def _tokenize_number(self, value: Any) -> List[int]:
        """Tokenise un nombre"""
        try:
            num_str = str(float(value))
            # Séparer en caractères pour une représentation fine
            tokens = [self.number_token_id]  # Token de début de nombre
            for char in num_str[:self.max_tokens-1]:
                if char.isdigit():
                    tokens.append(ord(char) - ord('0') + 4)  # Chiffres 0-9 → IDs 4-13
                elif char == '.':
                    tokens.append(14)  # Point décimal
                elif char == '-':
                    tokens.append(15)  # Signe négatif
                elif char == 'e' or char == 'E':
                    tokens.append(16)  # Notation scientifique
            return tokens
        except:
            return [self.unk_token_id]
    
    def _tokenize_formula(self, formula: str) -> List[int]:
        """Tokenise une formule Excel"""
        tokens = [self.formula_start_id]  # Token de début de formule
        
        # Simplification: tokenisation caractère par caractère pour les opérateurs
        # En réalité, utiliser un parser Excel
        i = 0
        while i < len(formula) and len(tokens) < self.max_tokens:
            char = formula[i]
            
            if char.isalpha():
                # Fonction ou référence de cellule
                word = ""
                while i < len(formula) and (formula[i].isalnum() or formula[i] in ".$"):
                    word += formula[i]
                    i += 1
                token_id = (hash(word.upper()) % (self.vocab_size - 100)) + 100
                tokens.append(token_id)
                continue
            elif char.isdigit():
                # Nombre dans la formule
                num = ""
                while i < len(formula) and (formula[i].isdigit() or formula[i] == '.'):
                    num += formula[i]
                    i += 1
                token_id = (hash(num) % (self.vocab_size - 200)) + 200
                tokens.append(token_id)
                continue
            else:
                # Opérateur ou symbole
                token_map = {
                    '+': 17, '-': 18, '*': 19, '/': 20, '=': 21, 
                    '(': 22, ')': 23, ',': 24, ':': 25, ';': 26,
                    '<': 27, '>': 28, '&': 29, '^': 30
                }
                tokens.append(token_map.get(char, self.unk_token_id))
            
            i += 1
        
        return tokens

class MergeEncoder(nn.Module):
    """Encodeur pour les informations de fusion"""
    
    def __init__(self, config: EmbeddingConfig):
        super().__init__()
        self.output_dim = 32
        
        # Embeddings pour les coordonnées de fusion
        self.merge_coord_embedding = nn.Embedding(config.max_position, 8)
        self.merge_projection = nn.Linear(4 * 8, self.output_dim)
        self.no_merge_embedding = nn.Parameter(torch.randn(self.output_dim))
    
    def forward(self, cell: 'FullCellInfo') -> torch.Tensor:
        """Encode les informations de fusion"""
        if not cell.is_merged:
            return self.no_merge_embedding
        
        sr, er, sc, ec = cell.merge_range
        
        # Limiter les coordonnées
        sr = min(sr, self.merge_coord_embedding.num_embeddings - 1)
        er = min(er, self.merge_coord_embedding.num_embeddings - 1)
        sc = min(sc, self.merge_coord_embedding.num_embeddings - 1)
        ec = min(ec, self.merge_coord_embedding.num_embeddings - 1)
        
        # Embeddings des coordonnées
        sr_emb = self.merge_coord_embedding(torch.tensor(sr))
        er_emb = self.merge_coord_embedding(torch.tensor(er))
        sc_emb = self.merge_coord_embedding(torch.tensor(sc))
        ec_emb = self.merge_coord_embedding(torch.tensor(ec))
        
        # Concaténer et projeter
        merge_vec = torch.cat([sr_emb, er_emb, sc_emb, ec_emb])
        return self.merge_projection(merge_vec)

class ExcelSheetEmbedder(nn.Module):
    """Embedder pour des feuilles entières"""
    
    def __init__(self, config: EmbeddingConfig):
        super().__init__()
        self.cell_embedder = ExcelCellEmbedder(config)
        self.position_encoder = PositionalEncoder(config.embedding_dim)
        
    def forward(self, cells: List['FullCellInfo'], max_cells: Optional[int] = None) -> torch.Tensor:
        """
        Embed une feuille entière
        
        Args:
            cells: Liste de cellules
            max_cells: Nombre maximum de cellules (pour padding)
            
        Returns:
            Tensor [num_cells, embedding_dim] ou [max_cells, embedding_dim]
        """
        # Trier par position
        sorted_cells = sorted(cells, key=lambda c: (c.row, c.col))
        
        if max_cells:
            # Padding ou troncature
            if len(sorted_cells) > max_cells:
                sorted_cells = sorted_cells[:max_cells]
            elif len(sorted_cells) < max_cells:
                # Créer des cellules vides pour le padding
                empty_cell = self._create_empty_cell()
                sorted_cells.extend([empty_cell] * (max_cells - len(sorted_cells)))
        
        # Embedder toutes les cellules
        cell_embeddings = self.cell_embedder(sorted_cells)
        
        # Ajouter l'encodage positionnel
        positioned_embeddings = self.position_encoder(cell_embeddings)
        
        return positioned_embeddings
    
    def _create_empty_cell(self) -> 'FullCellInfo':
        """Crée une cellule vide pour le padding"""
        # Retourner une cellule avec toutes les valeurs par défaut
        # Cette implémentation dépend de votre classe FullCellInfo
        pass

class PositionalEncoder(nn.Module):
    """Encodage positionnel pour les séquences de cellules"""
    
    def __init__(self, embedding_dim: int, max_length: int = 10000):
        super().__init__()
        
        pe = torch.zeros(max_length, embedding_dim)
        position = torch.arange(0, max_length).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() *
                           -(np.log(10000.0) / embedding_dim))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Ajoute l'encodage positionnel"""
        seq_len = x.size(0)
        return x + self.pe[:seq_len]

# Exemple d'utilisation
if __name__ == "__main__":
    # Configuration
    config = EmbeddingConfig(
        embedding_dim=256,
        position_embedding_dim=32
    )
    
    # Créer l'embedder
    embedder = ExcelCellEmbedder(config)
    
    # Simuler une cellule (remplacer par votre vraie classe FullCellInfo)
    from dataclasses import dataclass
    from typing import Tuple
    
    @dataclass 
    class MockFullCellInfo:
        raw_value: str = "Hello"
        cell_type: int = 1
        formula: str = ""
        row: int = 5
        col: int = 3
        bold: bool = True
        italic: bool = False
        underline: bool = False
        strike: bool = False
        font_size: float = 12.0
        font_family: str = "Arial"
        text_color: str = "#FF0000"
        background_color: str = "#FFFFFF"
        horizontal_align: int = 1
        vertical_align: int = 0
        text_wrap: bool = False
        text_rotation: int = 0
        border_top: int = 1
        border_bottom: int = 0
        border_left: int = 0
        border_right: int = 0
        is_merged: bool = False
        merge_range: Tuple[int, int, int, int] = (0, 0, 0, 0)
    
    # Créer une cellule test
    cell = MockFullCellInfo()
    
    # Obtenir l'embedding
    with torch.no_grad():
        embedding = embedder(cell)
        print(f"Embedding shape: {embedding.shape}")
        print(f"Embedding (premiers 10 éléments): {embedding[:10]}")
        
        # Test avec plusieurs cellules
        cells = [cell, cell, cell]
        batch_embedding = embedder(cells)
        print(f"Batch embedding shape: {batch_embedding.shape}")
        
        # Nombre de paramètres
        total_params = sum(p.numel() for p in embedder.parameters())
        print(f"Nombre total de paramètres: {total_params:,}")

Embedding shape: torch.Size([256])
Embedding (premiers 10 éléments): tensor([ 0.9847, -0.6439,  0.8268, -1.1029, -0.0665, -0.7135, -1.6502, -1.3894,
        -1.3489,  1.0317])
Batch embedding shape: torch.Size([3, 256])
Nombre total de paramètres: 2,013,552


Création du Graph Embedding

In [19]:
class EdgeType(Enum):
    """Types d'arêtes dans le graphe Excel"""
    # Relations spatiales
    SAME_ROW = "same_row"
    SAME_COL = "same_col"
    ADJACENT_RIGHT = "adjacent_right"
    ADJACENT_DOWN = "adjacent_down"
    ADJACENT_DIAG = "adjacent_diag"
    
    # Relations de distance
    NEAR_2 = "near_2"      # Distance 2
    NEAR_3 = "near_3"      # Distance 3
    NEAR_5 = "near_5"      # Distance 5
    FAR = "far"            # Distance > 5
    
    # Relations de dépendance
    FORMULA_REF = "formula_ref"           # A1 référence B1 dans une formule
    FORMULA_RANGE = "formula_range"       # Formule utilise une plage (A1:B5)
    FORMULA_INDIRECT = "formula_indirect" # Référence indirecte (INDIRECT, etc.)
    CROSS_SHEET_REF = "cross_sheet_ref"   # Référence entre feuilles (Sheet2!A1)
    
    # Relations structurelles
    MERGED_CELL = "merged_cell"      # Cellules fusionnées
    SAME_STYLE = "same_style"        # Même style appliqué
    SAME_VALUE_TYPE = "same_value_type"  # Même type de valeur
    
    # Relations de feuille
    SAME_SHEET = "same_sheet"        # Appartiennent à la même feuille
    CROSS_SHEET = "cross_sheet"      # Appartiennent à des feuilles différentes
    
    # Relations de plage
    RANGE_START = "range_start"      # Début de plage
    RANGE_END = "range_end"          # Fin de plage
    RANGE_MEMBER = "range_member"    # Membre d'une plage

@dataclass
class GraphEdge:
    """Représente une arête dans le graphe Excel"""
    source_idx: int           # Index de la cellule source
    target_idx: int           # Index de la cellule cible
    edge_type: EdgeType       # Type de relation
    weight: float = 1.0       # Poids de l'arête
    metadata: Dict[str, Any] = None  # Métadonnées additionnelles
    
    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}

@dataclass
class ExcelGraph:
    """Représente le graphe Excel avec cellules et arêtes"""
    cell_embeddings: torch.Tensor      # [num_cells, embedding_dim]
    edge_embeddings: torch.Tensor      # [num_edges, edge_embedding_dim]
    edge_indices: torch.Tensor         # [2, num_edges] - format PyG
    edge_types: List[EdgeType]         # Type de chaque arête
    edge_weights: torch.Tensor         # [num_edges] - poids des arêtes
    cell_positions: List[Tuple[int, int]]  # [(row, col)] positions originales
    
    @property
    def num_nodes(self) -> int:
        return self.cell_embeddings.size(0)
    
    @property
    def num_edges(self) -> int:
        return self.edge_embeddings.size(0)

class ExcelGraphBuilder:
    """Construit le graphe de relations entre cellules Excel"""
    
    def __init__(self, 
                 max_distance: int = 5,
                 include_spatial: bool = True,
                 include_formula_deps: bool = True,
                 include_structural: bool = True,
                 include_sheet_relations: bool = True,
                 same_style_threshold: float = 0.8,
                 cross_sheet_weight: float = 0.3):
        """
        Args:
            max_distance: Distance maximale pour les arêtes spatiales
            include_spatial: Inclure les relations spatiales
            include_formula_deps: Inclure les dépendances de formules
            include_structural: Inclure les relations structurelles
            include_sheet_relations: Inclure les relations de feuille
            same_style_threshold: Seuil pour considérer des styles similaires
            cross_sheet_weight: Poids pour les relations inter-feuilles
        """
        self.max_distance = max_distance
        self.include_spatial = include_spatial
        self.include_formula_deps = include_formula_deps
        self.include_structural = include_structural
        self.include_sheet_relations = include_sheet_relations
        self.same_style_threshold = same_style_threshold
        self.cross_sheet_weight = cross_sheet_weight
    
    def build_graph(self, cells: List['FullCellInfo']) -> Tuple[List[GraphEdge], Dict[Tuple[str, int, int], int]]:
        """
        Construit le graphe de relations entre cellules
        
        Args:
            cells: Liste des cellules Excel
            
        Returns:
            Tuple[edges, position_to_index_map] - position inclut maintenant sheet_name
        """
        edges = []
        
        # Créer un mapping (sheet, row, col) -> index
        pos_to_idx = {(cell.sheet_name, cell.row, cell.col): i for i, cell in enumerate(cells)}
        
        # 1. Relations de feuille (en premier pour établir le contexte)
        if self.include_sheet_relations:
            edges.extend(self._build_sheet_relation_edges(cells, pos_to_idx))
        
        # 2. Relations spatiales (seulement dans la même feuille)
        if self.include_spatial:
            edges.extend(self._build_spatial_edges(cells, pos_to_idx))
        
        # 3. Relations de dépendance (formules - incluant cross-sheet)
        if self.include_formula_deps:
            edges.extend(self._build_formula_dependency_edges(cells, pos_to_idx))
        
        # 4. Relations structurelles
        if self.include_structural:
            edges.extend(self._build_structural_edges(cells, pos_to_idx))
        
        return edges, pos_to_idx
    
    def _build_sheet_relation_edges(self, cells: List['FullCellInfo'], pos_to_idx: Dict) -> List[GraphEdge]:
        """Construit les arêtes basées sur l'appartenance aux feuilles"""
        edges = []
        
        # Grouper les cellules par feuille
        sheets = {}
        for i, cell in enumerate(cells):
            if cell.sheet_name not in sheets:
                sheets[cell.sheet_name] = []
            sheets[cell.sheet_name].append(i)
        
        # Relations intra-feuille (SAME_SHEET)
        for sheet_name, cell_indices in sheets.items():
            for i in range(len(cell_indices)):
                for j in range(i + 1, len(cell_indices)):
                    idx_i, idx_j = cell_indices[i], cell_indices[j]
                    edges.append(GraphEdge(
                        idx_i, idx_j, 
                        EdgeType.SAME_SHEET, 
                        weight=1.0,
                        metadata={'sheet_name': sheet_name}
                    ))
        
        # Relations inter-feuilles (CROSS_SHEET)
        sheet_names = list(sheets.keys())
        for i in range(len(sheet_names)):
            for j in range(i + 1, len(sheet_names)):
                sheet_i, sheet_j = sheet_names[i], sheet_names[j]
                cells_i, cells_j = sheets[sheet_i], sheets[sheet_j]
                
                # Connecter toutes les cellules entre feuilles différentes
                # (peut être coûteux - limiter si nécessaire)
                for cell_idx_i in cells_i[:10]:  # Limiter à 10 cellules par feuille
                    for cell_idx_j in cells_j[:10]:
                        edges.append(GraphEdge(
                            cell_idx_i, cell_idx_j,
                            EdgeType.CROSS_SHEET,
                            weight=self.cross_sheet_weight,
                            metadata={
                                'sheet_i': sheet_i,
                                'sheet_j': sheet_j
                            }
                        ))
        
        return edges
    
    def _build_spatial_edges(self, cells: List['FullCellInfo'], pos_to_idx: Dict) -> List[GraphEdge]:
        """Construit les arêtes basées sur les relations spatiales (dans la même feuille uniquement)"""
        edges = []
        
        # Grouper par feuille pour éviter les relations spatiales inter-feuilles
        sheets = {}
        for i, cell in enumerate(cells):
            if cell.sheet_name not in sheets:
                sheets[cell.sheet_name] = []
            sheets[cell.sheet_name].append((i, cell))
        
        # Construire les relations spatiales pour chaque feuille séparément
        for sheet_name, sheet_cells in sheets.items():
            for i, (idx_i, cell_i) in enumerate(sheet_cells):
                for j, (idx_j, cell_j) in enumerate(sheet_cells):
                    if i >= j:  # Éviter les doublons
                        continue
                    
                    row_i, col_i = cell_i.row, cell_i.col
                    row_j, col_j = cell_j.row, cell_j.col
                    
                    # Distance Manhattan
                    distance = abs(row_i - row_j) + abs(col_i - col_j)
                    
                    if distance > self.max_distance:
                        # Relation "far" (même feuille)
                        edges.append(GraphEdge(
                            idx_i, idx_j, EdgeType.FAR, 
                            weight=1.0/distance,
                            metadata={'sheet_name': sheet_name}
                        ))
                        continue
                    
                    # Relations spécifiques
                    if row_i == row_j and col_i != col_j:
                        # Même ligne
                        edges.append(GraphEdge(
                            idx_i, idx_j, EdgeType.SAME_ROW, 
                            weight=1.0,
                            metadata={'sheet_name': sheet_name, 'row': row_i}
                        ))
                        
                        # Adjacent horizontal
                        if abs(col_i - col_j) == 1:
                            edges.append(GraphEdge(idx_i, idx_j, EdgeType.ADJACENT_RIGHT, weight=1.0))
                    
                    elif col_i == col_j and row_i != row_j:
                        # Même colonne
                        edges.append(GraphEdge(
                            idx_i, idx_j, EdgeType.SAME_COL, 
                            weight=1.0,
                            metadata={'sheet_name': sheet_name, 'col': col_i}
                        ))
                        
                        # Adjacent vertical
                        if abs(row_i - row_j) == 1:
                            edges.append(GraphEdge(idx_i, idx_j, EdgeType.ADJACENT_DOWN, weight=1.0))
                    
                    elif distance == 2 and abs(row_i - row_j) == 1 and abs(col_i - col_j) == 1:
                        # Adjacent diagonal
                        edges.append(GraphEdge(idx_i, idx_j, EdgeType.ADJACENT_DIAG, weight=0.7))
                    
                    # Relations de distance
                    elif distance == 2:
                        edges.append(GraphEdge(idx_i, idx_j, EdgeType.NEAR_2, weight=0.5))
                    elif distance == 3:
                        edges.append(GraphEdge(idx_i, idx_j, EdgeType.NEAR_3, weight=0.3))
                    elif distance <= 5:
                        edges.append(GraphEdge(idx_i, idx_j, EdgeType.NEAR_5, weight=0.1))
        
        return edges
    
    def _build_formula_dependency_edges(self, cells: List['FullCellInfo'], pos_to_idx: Dict) -> List[GraphEdge]:
        """Construit les arêtes basées sur les dépendances de formules (incluant cross-sheet)"""
        edges = []
        
        for i, cell in enumerate(cells):
            if not cell.formula:
                continue
            
            # Parser les références dans la formule
            references = self._parse_formula_references(cell.formula, cell.sheet_name)
            
            for ref in references:
                if ref['type'] == 'cell':
                    # Référence à une cellule (même feuille ou autre)
                    target_pos = (ref['sheet'], ref['row'], ref['col'])
                    if target_pos in pos_to_idx:
                        j = pos_to_idx[target_pos]
                        edge_type = EdgeType.CROSS_SHEET_REF if ref['sheet'] != cell.sheet_name else EdgeType.FORMULA_REF
                        weight = 0.8 if ref['sheet'] != cell.sheet_name else 1.0
                        
                        edges.append(GraphEdge(
                            i, j, edge_type, 
                            weight=weight,
                            metadata={
                                'formula_ref': ref['text'],
                                'source_sheet': cell.sheet_name,
                                'target_sheet': ref['sheet']
                            }
                        ))
                
                elif ref['type'] == 'range':
                    # Référence à une plage
                    start_row, start_col = ref['start_row'], ref['start_col']
                    end_row, end_col = ref['end_row'], ref['end_col']
                    
                    for row in range(start_row, end_row + 1):
                        for col in range(start_col, end_col + 1):
                            target_pos = (ref['sheet'], row, col)
                            if target_pos in pos_to_idx:
                                j = pos_to_idx[target_pos]
                                edge_type = EdgeType.CROSS_SHEET_REF if ref['sheet'] != cell.sheet_name else EdgeType.FORMULA_RANGE
                                weight = 0.4 if ref['sheet'] != cell.sheet_name else 0.5
                                
                                edges.append(GraphEdge(
                                    i, j, edge_type,
                                    weight=weight,
                                    metadata={
                                        'formula_range': ref['text'],
                                        'source_sheet': cell.sheet_name,
                                        'target_sheet': ref['sheet']
                                    }
                                ))
                
                elif ref['type'] == 'indirect':
                    # Référence indirecte (plus complexe à analyser)
                    edges.append(GraphEdge(
                        i, i, EdgeType.FORMULA_INDIRECT, 
                        weight=0.3,
                        metadata={'indirect_ref': ref['text']}
                    ))
        
        return edges
    
    def _parse_formula_references(self, formula: str, current_sheet: str) -> List[Dict[str, Any]]:
        """Parse les références dans une formule Excel (incluant cross-sheet)"""
        references = []
        
        # Pattern pour les références cross-sheet (Sheet1!A1, 'Sheet Name'!A1)
        cross_sheet_pattern = r"(?:'([^']+)'|([^!\s]+))!(\$?[A-Z]+\$?\d+)"
        cross_sheet_matches = re.finditer(cross_sheet_pattern, formula.upper())
        
        for match in cross_sheet_matches:
            sheet_name_quoted, sheet_name_simple, cell_ref = match.groups()
            sheet_name = sheet_name_quoted or sheet_name_simple
            
            # Parser la référence de cellule
            cell_match = re.match(r'\$?([A-Z]+)\$?(\d+)', cell_ref)
            if cell_match:
                col_str, row_str = cell_match.groups()
                col = self._col_str_to_num(col_str)
                row = int(row_str) - 1
                
                references.append({
                    'type': 'cell',
                    'sheet': sheet_name,
                    'row': row,
                    'col': col,
                    'text': match.group(0)
                })
        
        # Pattern pour les plages cross-sheet (Sheet1!A1:B5)
        cross_sheet_range_pattern = r"(?:'([^']+)'|([^!\s]+))!(\$?[A-Z]+\$?\d+:\$?[A-Z]+\$?\d+)"
        cross_sheet_range_matches = re.finditer(cross_sheet_range_pattern, formula.upper())
        
        for match in cross_sheet_range_matches:
            sheet_name_quoted, sheet_name_simple, range_ref = match.groups()
            sheet_name = sheet_name_quoted or sheet_name_simple
            
            # Parser la plage
            range_match = re.match(r'\$?([A-Z]+)\$?(\d+):\$?([A-Z]+)\$?(\d+)', range_ref)
            if range_match:
                start_col_str, start_row_str, end_col_str, end_row_str = range_match.groups()
                start_col = self._col_str_to_num(start_col_str)
                start_row = int(start_row_str) - 1
                end_col = self._col_str_to_num(end_col_str)
                end_row = int(end_row_str) - 1
                
                references.append({
                    'type': 'range',
                    'sheet': sheet_name,
                    'start_row': start_row,
                    'start_col': start_col,
                    'end_row': end_row,
                    'end_col': end_col,
                    'text': match.group(0)
                })
        
        # Pattern pour les références locales (A1, A1:B5 sans nom de feuille)
        # Retirer d'abord les références cross-sheet pour éviter les doublons
        formula_local = formula.upper()
        for match in re.finditer(cross_sheet_pattern, formula_local):
            formula_local = formula_local.replace(match.group(0), "")
        for match in re.finditer(cross_sheet_range_pattern, formula_local):
            formula_local = formula_local.replace(match.group(0), "")
        
        # Références de cellules locales
        cell_pattern = r'\$?([A-Z]+)\$?(\d+)'
        cell_matches = re.finditer(cell_pattern, formula_local)
        
        for match in cell_matches:
            col_str, row_str = match.groups()
            col = self._col_str_to_num(col_str)
            row = int(row_str) - 1
            references.append({
                'type': 'cell',
                'sheet': current_sheet,
                'row': row,
                'col': col,
                'text': match.group(0)
            })
        
        # Plages locales
        range_pattern = r'\$?([A-Z]+)\$?(\d+):\$?([A-Z]+)\$?(\d+)'
        range_matches = re.finditer(range_pattern, formula_local)
        
        for match in range_matches:
            start_col_str, start_row_str, end_col_str, end_row_str = match.groups()
            start_col = self._col_str_to_num(start_col_str)
            start_row = int(start_row_str) - 1
            end_col = self._col_str_to_num(end_col_str)
            end_row = int(end_row_str) - 1
            
            references.append({
                'type': 'range',
                'sheet': current_sheet,
                'start_row': start_row,
                'start_col': start_col,
                'end_row': end_row,
                'end_col': end_col,
                'text': match.group(0)
            })
        
        # Fonctions indirectes
        indirect_pattern = r'INDIRECT\s*\([^)]+\)'
        indirect_matches = re.finditer(indirect_pattern, formula.upper())
        
        for match in indirect_matches:
            references.append({
                'type': 'indirect',
                'text': match.group(0)
            })
        
        return references
    
    def _col_str_to_num(self, col_str: str) -> int:
        """Convertit une colonne string (A, B, ..., AA, AB) en nombre"""
        result = 0
        for char in col_str:
            result = result * 26 + (ord(char) - ord('A') + 1)
        return result - 1  # Convert to 0-based
    
    def _build_structural_edges(self, cells: List['FullCellInfo'], pos_to_idx: Dict) -> List[GraphEdge]:
        """Construit les arêtes basées sur les relations structurelles"""
        edges = []
        
        for i, cell_i in enumerate(cells):
            for j, cell_j in enumerate(cells):
                if i >= j:
                    continue
                
                # Cellules fusionnées
                if cell_i.is_merged and cell_j.is_merged:
                    if cell_i.merge_range == cell_j.merge_range:
                        edges.append(GraphEdge(i, j, EdgeType.MERGED_CELL, weight=1.0))
                
                # Même style (simplifié - comparer style_id)
                if cell_i.style_id and cell_j.style_id and cell_i.style_id == cell_j.style_id:
                    edges.append(GraphEdge(i, j, EdgeType.SAME_STYLE, weight=0.8))
                
                # Même type de valeur
                if cell_i.cell_type == cell_j.cell_type and cell_i.cell_type != 0:  # Pas vide
                    edges.append(GraphEdge(i, j, EdgeType.SAME_VALUE_TYPE, weight=0.4))
        
        return edges

class EdgeEmbedder(nn.Module):
    """Embedder pour les arêtes du graphe"""
    
    def __init__(self, embedding_dim: int = 64):
        super().__init__()
        self.embedding_dim = embedding_dim
        
        # Embedding par type d'arête
        num_edge_types = len(EdgeType)
        self.edge_type_embedding = nn.Embedding(num_edge_types, embedding_dim)
        
        # Embedding pour les poids (discrétisés)
        self.weight_bins = 20
        self.weight_embedding = nn.Embedding(self.weight_bins, embedding_dim // 4)
        
        # Projection finale
        self.projection = nn.Linear(embedding_dim + embedding_dim // 4, embedding_dim)
        
        # Mapping des types vers des IDs
        self.edge_type_to_id = {edge_type: i for i, edge_type in enumerate(EdgeType)}
    
    def forward(self, edge_types: List[EdgeType], edge_weights: torch.Tensor) -> torch.Tensor:
        """
        Embed les arêtes
        
        Args:
            edge_types: Liste des types d'arêtes
            edge_weights: Poids des arêtes [num_edges]
            
        Returns:
            Embeddings des arêtes [num_edges, embedding_dim]
        """
        # Convertir types en IDs
        type_ids = torch.tensor([self.edge_type_to_id[et] for et in edge_types])
        type_embeddings = self.edge_type_embedding(type_ids)
        
        # Discrétiser les poids
        weight_bins = torch.clamp(
            (edge_weights * self.weight_bins).long(), 
            0, self.weight_bins - 1
        )
        weight_embeddings = self.weight_embedding(weight_bins)
        
        # Combiner
        combined = torch.cat([type_embeddings, weight_embeddings], dim=1)
        return self.projection(combined)

class ExcelGraphEmbedder(nn.Module):
    """Embedder complet pour les graphes Excel"""
    
    def __init__(self, cell_embedder: ExcelCellEmbedder, edge_embedding_dim: int = 64):
        super().__init__()
        self.cell_embedder = cell_embedder
        self.edge_embedder = EdgeEmbedder(edge_embedding_dim)
        self.graph_builder = ExcelGraphBuilder()
        
    def forward(self, cells: List['FullCellInfo']) -> ExcelGraph:
        """
        Convertit une liste de cellules en graphe embedé
        
        Args:
            cells: Liste des cellules Excel
            
        Returns:
            ExcelGraph avec embeddings
        """
        # 1. Embed les cellules
        cell_embeddings = self.cell_embedder(cells)
        
        # 2. Construire le graphe
        edges, pos_to_idx = self.graph_builder.build_graph(cells)
        
        if not edges:
            # Graphe vide
            return ExcelGraph(
                cell_embeddings=cell_embeddings,
                edge_embeddings=torch.empty(0, self.edge_embedder.embedding_dim),
                edge_indices=torch.empty(2, 0, dtype=torch.long),
                edge_types=[],
                edge_weights=torch.empty(0),
                cell_positions=[(cell.row, cell.col) for cell in cells]
            )
        
        # 3. Préparer les données d'arêtes
        edge_indices = torch.tensor([[e.source_idx, e.target_idx] for e in edges]).T
        edge_types = [e.edge_type for e in edges]
        edge_weights = torch.tensor([e.weight for e in edges])
        
        # 4. Embed les arêtes
        edge_embeddings = self.edge_embedder(edge_types, edge_weights)
        
        # 5. Créer le graphe final
        return ExcelGraph(
            cell_embeddings=cell_embeddings,
            edge_embeddings=edge_embeddings,
            edge_indices=edge_indices,
            edge_types=edge_types,
            edge_weights=edge_weights,
            cell_positions=[(cell.row, cell.col) for cell in cells]
        )
    
    def get_edge_statistics(self, graph: ExcelGraph) -> Dict[str, int]:
        """Retourne des statistiques sur les types d'arêtes"""
        stats = {}
        for edge_type in graph.edge_types:
            stats[edge_type.value] = stats.get(edge_type.value, 0) + 1
        return stats

# Exemple d'utilisation et tests
if __name__ == "__main__":
    # Simuler des cellules pour tester
    from dataclasses import dataclass
    from typing import Tuple
    
    @dataclass
    class MockFullCellInfo:
        raw_value: str = ""
        cell_type: int = 0
        formula: str = ""
        row: int = 0
        col: int = 0
        sheet_name: str = "Sheet1"
        style_id: str = ""
        is_merged: bool = False
        merge_range: Tuple[int, int, int, int] = (0, 0, 0, 0)
        # Autres attributs...
        bold: bool = False
        italic: bool = False
        underline: bool = False
        strike: bool = False
        font_size: float = 11.0
        font_family: str = "Calibri"
        text_color: str = "#000000"
        background_color: str = "#FFFFFF"
        horizontal_align: int = 0
        vertical_align: int = 0
        text_wrap: bool = False
        text_rotation: int = 0
        border_top: int = 0
        border_bottom: int = 0
        border_left: int = 0
        border_right: int = 0
    
    # Créer des cellules de test avec multiple feuilles
    cells = [
        # Feuille 1
        MockFullCellInfo(raw_value="A1", row=0, col=0, cell_type=1, sheet_name="Sheet1"),
        MockFullCellInfo(raw_value="10", row=0, col=1, cell_type=2, sheet_name="Sheet1"),
        MockFullCellInfo(raw_value="B1", row=1, col=0, cell_type=1, sheet_name="Sheet1"),
        MockFullCellInfo(formula="=A1+B1", row=1, col=1, cell_type=3, sheet_name="Sheet1"),
        
        # Feuille 2
        MockFullCellInfo(raw_value="Data", row=0, col=0, cell_type=1, sheet_name="Sheet2"),
        MockFullCellInfo(raw_value="100", row=0, col=1, cell_type=2, sheet_name="Sheet2"),
        MockFullCellInfo(formula="=Sheet1!A1*2", row=1, col=0, cell_type=3, sheet_name="Sheet2"),
        
        # Feuille 3 
        MockFullCellInfo(formula="=SUM(Sheet1!A1:B2)", row=0, col=0, cell_type=3, sheet_name="Sheet3"),
    ]
    
    # Créer le graph builder et tester
    builder = ExcelGraphBuilder(include_sheet_relations=True)
    edges, pos_to_idx = builder.build_graph(cells)
    
    print(f"Nombre de cellules: {len(cells)}")
    print(f"Nombre d'arêtes: {len(edges)}")
    print(f"Nombre de feuilles: {len(set(cell.sheet_name for cell in cells))}")
    
    print(f"\nPosition mapping (sheet, row, col) -> index:")
    for pos, idx in list(pos_to_idx.items())[:5]:
        print(f"  {pos} -> {idx}")
    
    print("\nArêtes par type:")
    edge_type_counts = {}
    for edge in edges:
        edge_type_counts[edge.edge_type.value] = edge_type_counts.get(edge.edge_type.value, 0) + 1
    
    for edge_type, count in sorted(edge_type_counts.items()):
        print(f"  {edge_type}: {count}")
    
    print("\nExemples d'arêtes de feuille:")
    sheet_edges = [e for e in edges if e.edge_type in [EdgeType.SAME_SHEET, EdgeType.CROSS_SHEET]]
    for edge in sheet_edges[:5]:
        source_cell = cells[edge.source_idx]
        target_cell = cells[edge.target_idx]
        print(f"  {edge.edge_type.value}: "
              f"{source_cell.sheet_name}!({source_cell.row},{source_cell.col}) -> "
              f"{target_cell.sheet_name}!({target_cell.row},{target_cell.col}) "
              f"[weight: {edge.weight:.2f}]")
    
    print("\nExemples de références cross-sheet:")
    cross_ref_edges = [e for e in edges if e.edge_type == EdgeType.CROSS_SHEET_REF]
    for edge in cross_ref_edges:
        source_cell = cells[edge.source_idx]
        target_cell = cells[edge.target_idx]
        metadata = edge.metadata or {}
        print(f"  {edge.edge_type.value}: "
              f"{source_cell.sheet_name}!({source_cell.row},{source_cell.col}) -> "
              f"{target_cell.sheet_name}!({target_cell.row},{target_cell.col}) "
              f"[formula: {metadata.get('formula_ref', 'N/A')}]")

Nombre de cellules: 8
Nombre d'arêtes: 52
Nombre de feuilles: 3

Position mapping (sheet, row, col) -> index:
  ('Sheet1', 0, 0) -> 0
  ('Sheet1', 0, 1) -> 1
  ('Sheet1', 1, 0) -> 2
  ('Sheet1', 1, 1) -> 3
  ('Sheet2', 0, 0) -> 4

Arêtes par type:
  adjacent_diag: 3
  adjacent_down: 3
  adjacent_right: 3
  cross_sheet: 19
  formula_ref: 2
  same_col: 3
  same_row: 3
  same_sheet: 9
  same_value_type: 7

Exemples d'arêtes de feuille:
  same_sheet: Sheet1!(0,0) -> Sheet1!(0,1) [weight: 1.00]
  same_sheet: Sheet1!(0,0) -> Sheet1!(1,0) [weight: 1.00]
  same_sheet: Sheet1!(0,0) -> Sheet1!(1,1) [weight: 1.00]
  same_sheet: Sheet1!(0,1) -> Sheet1!(1,0) [weight: 1.00]
  same_sheet: Sheet1!(0,1) -> Sheet1!(1,1) [weight: 1.00]

Exemples de références cross-sheet:


Fonction pour transformer un/des json en GraphEmbedded

In [22]:
class JSONToGraphTransformer:
    """Transforme un JSON Excel en GraphEmbedded"""
    
    def __init__(self, 
                 embedding_config: Optional['EmbeddingConfig'] = None,
                 max_cells_per_sheet: int = 1000,
                 include_empty_cells: bool = False,
                 graph_config: Optional[Dict[str, Any]] = None):
        """
        Args:
            embedding_config: Configuration pour l'embedder
            max_cells_per_sheet: Limite de cellules par feuille
            include_empty_cells: Inclure les cellules vides
            graph_config: Configuration pour le graph builder
        """
        # Configuration par défaut
        if embedding_config is None:
            embedding_config = EmbeddingConfig(
                embedding_dim=256,
                position_embedding_dim=32,
                max_position=10000,
                color_vocab_size=1000
            )
        
        if graph_config is None:
            graph_config = {
                'max_distance': 5,
                'include_spatial': True,
                'include_formula_deps': True,
                'include_structural': True,
                'include_sheet_relations': True,
                'cross_sheet_weight': 0.3
            }
        
        self.embedding_config = embedding_config
        self.max_cells_per_sheet = max_cells_per_sheet
        self.include_empty_cells = include_empty_cells
        self.graph_config = graph_config
        
        # Initialiser les composants
        self.cell_embedder = ExcelCellEmbedder(embedding_config)
        self.graph_embedder = ExcelGraphEmbedder(
            self.cell_embedder, 
            edge_embedding_dim=64
        )
        
        # Configurer le graph builder
        self.graph_embedder.graph_builder = ExcelGraphBuilder(**graph_config)
    
    def transform(self, 
                  json_data: Union[str, Dict[str, Any]], 
                  filter_sheets: Optional[List[str]] = None,
                  max_total_cells: Optional[int] = None) -> 'ExcelGraph':
        """
        Transforme un JSON Excel en ExcelGraph
        
        Args:
            json_data: JSON string ou dict contenant les données Excel
            filter_sheets: Liste des noms de feuilles à inclure (None = toutes)
            max_total_cells: Limite totale de cellules (None = pas de limite)
            
        Returns:
            ExcelGraph avec embeddings
        """
        # 1. Parser le JSON
        if isinstance(json_data, str):
            excel_data = json.loads(json_data)
        else:
            excel_data = json_data
        
        # 2. Extraire les cellules
        all_cells = self._extract_cells_from_json(excel_data, filter_sheets)
        
        # 3. Filtrer et limiter les cellules
        filtered_cells = self._filter_cells(all_cells, max_total_cells)
        
        # 4. Transformer en graphe embedé
        excel_graph = self.graph_embedder(filtered_cells)
        
        return excel_graph
    
    def _extract_cells_from_json(self, 
                                 excel_data: Dict[str, Any], 
                                 filter_sheets: Optional[List[str]] = None) -> List['FullCellInfo']:
        """Extrait les cellules du JSON Excel"""
        try:
            # Utiliser le parser existant
            all_cells = ExcelParser.parse_excel_json(excel_data)
            
            # S'assurer qu'on a une liste
            if not isinstance(all_cells, list):
                if all_cells is None:
                    return []
                # Si c'est un seul objet FullCellInfo, le mettre dans une liste
                return [all_cells]
            
            # Filtrer par feuilles si spécifié
            if filter_sheets:
                all_cells = [cell for cell in all_cells if cell.sheet_name in filter_sheets]
            
            return all_cells
            
        except Exception as e:
            print(f"Erreur lors du parsing JSON: {e}")
            return []
    
    def _filter_cells(self, 
                      cells: List['FullCellInfo'], 
                      max_total_cells: Optional[int] = None) -> List['FullCellInfo']:
        """Filtre et limite les cellules selon les critères"""
        # Vérifier que cells est bien une liste
        if not isinstance(cells, list):
            if cells is None:
                return []
            return [cells]  # Convertir un seul objet en liste
        
        if not cells:  # Liste vide
            return []
        
        filtered_cells = []
        
        # Grouper par feuille
        sheets = {}
        for cell in cells:
            if cell.sheet_name not in sheets:
                sheets[cell.sheet_name] = []
            sheets[cell.sheet_name].append(cell)
        
        # Filtrer par feuille
        for sheet_name, sheet_cells in sheets.items():
            # Filtrer les cellules vides si nécessaire
            if not self.include_empty_cells:
                sheet_cells = [
                    cell for cell in sheet_cells 
                    if cell.raw_value or cell.formula or cell.style_id or cell.is_merged
                ]
            
            # Trier par priorité (formules > valeurs > style > vides)
            sheet_cells.sort(key=self._cell_priority, reverse=True)
            
            # Limiter par feuille
            if len(sheet_cells) > self.max_cells_per_sheet:
                sheet_cells = sheet_cells[:self.max_cells_per_sheet]
            
            filtered_cells.extend(sheet_cells)
        
        # Limiter le total si spécifié
        if max_total_cells and len(filtered_cells) > max_total_cells:
            # Trier par priorité globale et prendre les plus importantes
            filtered_cells.sort(key=self._cell_priority, reverse=True)
            filtered_cells = filtered_cells[:max_total_cells]
        
        return filtered_cells
    
    def _cell_priority(self, cell: 'FullCellInfo') -> int:
        """Calcule la priorité d'une cellule pour le filtrage"""
        priority = 0
        
        # Formules ont la plus haute priorité
        if cell.formula:
            priority += 1000
        
        # Cellules avec valeurs
        if cell.raw_value:
            priority += 500
            
            # Bonus selon le type
            if cell.cell_type == 2:  # Nombre
                priority += 100
            elif cell.cell_type == 1:  # Texte
                priority += 50
        
        # Cellules avec style
        if cell.style_id and cell.style_id != "s0":
            priority += 200
        
        # Cellules fusionnées
        if cell.is_merged:
            priority += 150
        
        # Proximité du coin supérieur gauche (cellules importantes souvent en haut à gauche)
        distance_from_origin = cell.row + cell.col
        priority += max(0, 100 - distance_from_origin)
        
        return priority
    
    def transform_batch(self, 
                        json_files: List[Union[str, Dict[str, Any]]],
                        batch_size: int = 32) -> List['ExcelGraph']:
        """
        Transforme plusieurs JSON en batch
        
        Args:
            json_files: Liste de JSON (strings ou dicts)
            batch_size: Taille des batches pour l'embedding
            
        Returns:
            Liste d'ExcelGraph
        """
        graphs = []
        
        for i in range(0, len(json_files), batch_size):
            batch = json_files[i:i + batch_size]
            batch_graphs = []
            
            for json_data in batch:
                try:
                    graph = self.transform(json_data)
                    batch_graphs.append(graph)
                except Exception as e:
                    print(f"Erreur lors de la transformation: {e}")
                    # Créer un graphe vide en cas d'erreur
                    empty_graph = self._create_empty_graph()
                    batch_graphs.append(empty_graph)
            
            graphs.extend(batch_graphs)
        
        return graphs
    
    def _create_empty_graph(self) -> 'ExcelGraph':
        """Crée un graphe vide en cas d'erreur"""
        return ExcelGraph(
            cell_embeddings=torch.empty(0, self.embedding_config.embedding_dim),
            edge_embeddings=torch.empty(0, 64),
            edge_indices=torch.empty(2, 0, dtype=torch.long),
            edge_types=[],
            edge_weights=torch.empty(0),
            cell_positions=[]
        )
    
    def get_graph_statistics(self, graph: 'ExcelGraph') -> Dict[str, Any]:
        """Retourne des statistiques détaillées sur le graphe"""
        if graph.num_nodes == 0:
            return {'empty_graph': True}
        
        # Statistiques de base
        stats = {
            'num_nodes': graph.num_nodes,
            'num_edges': graph.num_edges,
            'avg_degree': graph.num_edges * 2 / graph.num_nodes if graph.num_nodes > 0 else 0
        }
        
        # Statistiques par type d'arête
        edge_type_counts = {}
        for edge_type in graph.edge_types:
            edge_type_counts[edge_type.value] = edge_type_counts.get(edge_type.value, 0) + 1
        stats['edge_types'] = edge_type_counts
        
        # Statistiques des feuilles
        sheets = set(pos[0] if isinstance(pos, tuple) and len(pos) >= 1 else "unknown" 
                    for pos in graph.cell_positions)
        stats['num_sheets'] = len(sheets)
        stats['sheet_names'] = list(sheets)
        
        # Statistiques des positions
        if graph.cell_positions:
            positions = [(pos[1], pos[2]) if isinstance(pos, tuple) and len(pos) >= 3 
                        else pos for pos in graph.cell_positions]
            if positions and all(isinstance(p, tuple) and len(p) == 2 for p in positions):
                rows = [p[0] for p in positions]
                cols = [p[1] for p in positions]
                stats['position_range'] = {
                    'min_row': min(rows),
                    'max_row': max(rows),
                    'min_col': min(cols),
                    'max_col': max(cols)
                }
        
        return stats
    
    def save_graph(self, graph: 'ExcelGraph', filepath: str):
        """Sauvegarde un graphe au format PyTorch"""
        torch.save({
            'cell_embeddings': graph.cell_embeddings,
            'edge_embeddings': graph.edge_embeddings,
            'edge_indices': graph.edge_indices,
            'edge_types': [et.value for et in graph.edge_types],
            'edge_weights': graph.edge_weights,
            'cell_positions': graph.cell_positions,
            'config': self.embedding_config,
            'graph_config': self.graph_config
        }, filepath)
    
    def load_graph(self, filepath: str) -> 'ExcelGraph':
        """Charge un graphe depuis un fichier"""
        data = torch.load(filepath)
        
        # Reconstituer les edge_types
        edge_types = [EdgeType(et) for et in data['edge_types']]
        
        return ExcelGraph(
            cell_embeddings=data['cell_embeddings'],
            edge_embeddings=data['edge_embeddings'],
            edge_indices=data['edge_indices'],
            edge_types=edge_types,
            edge_weights=data['edge_weights'],
            cell_positions=data['cell_positions']
        )

# Fonction utilitaire standalone
def json_to_excel_graph(json_data: Union[str, Dict[str, Any]], 
                        **kwargs) -> 'ExcelGraph':
    """
    Fonction utilitaire pour transformer rapidement un JSON en ExcelGraph
    
    Args:
        json_data: JSON Excel (string ou dict)
        **kwargs: Arguments pour JSONToGraphTransformer
        
    Returns:
        ExcelGraph
    """
    transformer = JSONToGraphTransformer(**kwargs)
    return transformer.transform(json_data)

# Exemple d'utilisation
if __name__ == "__main__":
    # JSON d'exemple (format Univer)
    example_json = {
        "styles": {
            "s0": {"fs": 12.0},
            "s1": {"fs": 16.0, "bl": 1, "cl": {"rgb": "#FFFFFF"}, "bg": {"rgb": "#4470C4"}}
        },
        "sheets": {
            "sheet1": {
                "id": "sheet1",
                "name": "Sheet1",
                "hidden": 0,
                "rowCount": 10,
                "columnCount": 10,
                "mergeData": [],
                "cellData": {
                    "0": {
                        "0": {"v": "Hello", "t": 1, "s": "s0"},
                        "1": {"v": "World", "t": 1, "s": "s1"}
                    },
                    "1": {
                        "0": {"v": 42, "t": 2, "s": "s0"},
                        "1": {"f": "=A1&B1", "t": 3, "s": "s0"}
                    }
                }
            },
            "sheet2": {
                "id": "sheet2", 
                "name": "Sheet2",
                "hidden": 0,
                "rowCount": 5,
                "columnCount": 5,
                "mergeData": [],
                "cellData": {
                    "0": {
                        "0": {"v": "Data", "t": 1, "s": "s0"},
                        "1": {"f": "=Sheet1!A1", "t": 3, "s": "s0"}
                    }
                }
            }
        }
    }
    
    # Créer le transformer
    transformer = JSONToGraphTransformer(
        max_cells_per_sheet=100,
        include_empty_cells=False
    )
    
    # Transformer le JSON
    excel_graph = transformer.transform(example_json)
    
    # Afficher les statistiques
    stats = transformer.get_graph_statistics(excel_graph)
    print("Statistiques du graphe:")
    print(f"  Nombre de nœuds: {stats['num_nodes']}")
    print(f"  Nombre d'arêtes: {stats['num_edges']}")
    print(f"  Degré moyen: {stats['avg_degree']:.2f}")
    print(f"  Nombre de feuilles: {stats['num_sheets']}")
    print(f"  Feuilles: {stats['sheet_names']}")
    
    if 'edge_types' in stats:
        print("  Types d'arêtes:")
        for edge_type, count in stats['edge_types'].items():
            print(f"    {edge_type}: {count}")
    
    print(f"\nDimensions des embeddings:")
    print(f"  Cellules: {excel_graph.cell_embeddings.shape}")
    print(f"  Arêtes: {excel_graph.edge_embeddings.shape}")
    print(f"  Indices d'arêtes: {excel_graph.edge_indices.shape}")
    
    # Test avec une liste de JSON
    json_list = [example_json, example_json]
    graphs = transformer.transform_batch(json_list)
    print(f"\nBatch transformé: {len(graphs)} graphes")
    
    # Sauvegarde (optionnel)
    # transformer.save_graph(excel_graph, "example_graph.pt")
    
    # Test de la fonction utilitaire
    quick_graph = json_to_excel_graph(example_json, max_cells_per_sheet=50)
    print(f"\nGraphe rapide: {quick_graph.num_nodes} nœuds, {quick_graph.num_edges} arêtes")

Statistiques du graphe:
  Nombre de nœuds: 256
  Nombre d'arêtes: 0
  Degré moyen: 0.00
  Nombre de feuilles: 1
  Feuilles: [0]
  Types d'arêtes:

Dimensions des embeddings:
  Cellules: torch.Size([256])
  Arêtes: torch.Size([0, 64])
  Indices d'arêtes: torch.Size([2, 0])

Batch transformé: 2 graphes

Graphe rapide: 256 nœuds, 0 arêtes
