<a href="https://colab.research.google.com/github/saito1111/Speaker-diarization-/blob/main/Enhanced_Diarization_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🎙️ Enhanced Multi-Channel Speaker Diarization Training

**Goal:** Train an advanced speaker diarization system on the AMI corpus  
**Architecture:** Multi-scale TCN with attention, speaker classification, and optimized memory management

---

## 📋 Table of Contents
1. [Environment Setup](#setup)
2. [Downloading the AMI Corpus](#data)
3. [Data Preparation](#preprocessing)
4. [Model and Configuration](#model)
5. [Training](#training)
6. [Evaluation](#evaluation)
7. [Saving and Visualizations](#results)

## 🔧 1. Environment Setup

Install Conda and the dependencies needed for training.

In [None]:
# Install Conda on Google Colab
!wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!conda --version

import sys
sys.path.append('/usr/local/lib/python3.10/site-packages')

In [None]:
# Create the conda environment for diarization
!conda create -n diarization python=3.9 -y
# Note: each `!` command runs in its own subshell, so this activation does not
# persist; the installs below target the environment explicitly with `-n`.
!conda activate diarization

# Install the main dependencies via conda
!conda install -n diarization pytorch torchaudio cudatoolkit=11.8 -c pytorch -c nvidia -y
!conda install -n diarization numpy scipy scikit-learn matplotlib seaborn pandas -y

# Put the environment's binaries first on PATH for later shell commands
# (the notebook kernel itself keeps running Colab's Python)
import os
os.environ['CONDA_DEFAULT_ENV'] = 'diarization'
os.environ['PATH'] = '/usr/local/envs/diarization/bin:' + os.environ['PATH']

In [None]:
# Install the specialized dependencies
!pip install wandb optuna tqdm psutil
!pip install librosa soundfile
!pip install speechbrain

# Verify the installations
import torch
import torchaudio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import warnings
warnings.filterwarnings('ignore')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

### 📂 Cloning the Repository and Setup

In [None]:
# Clone the project from GitHub
import os
from pathlib import Path

# Check whether the repository is already present
current_dir = Path.cwd()
repo_name = "Speaker-diarization-"

print(f"📁 Current directory: {current_dir}")

# If we are not already inside the repository, clone it
if not (current_dir / "src").exists():
    print("📥 Cloning the repository...")
    !git clone https://github.com/saito1111/Speaker-diarization-.git
    
    # Check that the clone succeeded
    if Path(repo_name).exists():
        print(f"✅ Repository cloned successfully into {repo_name}/")
        %cd {repo_name}
        print(f"📂 Changed directory to: {Path.cwd()}")
    else:
        print("❌ ERROR: Cloning failed!")
        print("🔧 Check your internet connection and the repository URL")
        raise RuntimeError("Repository clone failed")
else:
    print("✅ Repository already present")

# Check the project structure
print("\n📋 Checking the project structure:")
print("📁 Project root:")
!ls -la

print("\n📂 Contents of the src/ directory:")
if Path("src").exists():
    !ls -la src/
    print("✅ src directory found")
else:
    print("❌ ERROR: src directory missing!")
    raise FileNotFoundError("The src directory does not exist after cloning")

print("\n🎉 Project structure verified successfully!")

In [None]:
# Path setup and imports
import sys
sys.path.append('./src')

# Create the required directories
!mkdir -p data/ami_corpus/audio
!mkdir -p data/ami_corpus/annotations
!mkdir -p models/checkpoints
!mkdir -p results/logs
!mkdir -p results/figures

# Global variables
DATA_DIR = Path('./data/ami_corpus')
AUDIO_DIR = DATA_DIR / 'audio'
ANNOTATION_DIR = DATA_DIR / 'annotations'
MODEL_DIR = Path('./models/checkpoints')
RESULTS_DIR = Path('./results')

print("Directories configured:")
print(f"- Audio: {AUDIO_DIR}")
print(f"- Annotations: {ANNOTATION_DIR}")
print(f"- Models: {MODEL_DIR}")
print(f"- Results: {RESULTS_DIR}")

## 📊 2. Downloading and Preparing the AMI Corpus

The AMI corpus contains meeting recordings with time-stamped speaker annotations.

In [None]:
# Download the AMI corpus with verification and persistent storage
# Official URL: https://groups.inf.ed.ac.uk/ami/corpus/

import urllib.request
import zipfile
from tqdm import tqdm
import hashlib
import shutil
from pathlib import Path

class AMICorpusDownloader:
    """Downloads and verifies the AMI corpus."""
    
    def __init__(self, base_dir):
        self.base_dir = Path(base_dir)
        self.audio_dir = self.base_dir / "ami_audio"
        self.annotation_dir = self.base_dir / "ami_annotations"
        self.download_dir = self.base_dir / "downloads"
        
        # Create the directories
        for dir_path in [self.audio_dir, self.annotation_dir, self.download_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)
    
    def download_with_progress_bar(self, url, filename, description="Download"):
        """Download a file and display a progress bar."""
        print(f"🔽 {description}: {url}")
        
        class ProgressBar:
            def __init__(self):
                self.pbar = None
            
            def __call__(self, block_num, block_size, total_size):
                # Test against None explicitly instead of relying on the bar's truthiness
                if self.pbar is None:
                    self.pbar = tqdm(total=total_size, unit='B', unit_scale=True, desc=description)
                
                downloaded = block_num * block_size
                if downloaded < total_size:
                    self.pbar.update(block_size)
                else:
                    self.pbar.close()
        
        urllib.request.urlretrieve(url, filename, ProgressBar())
        print(f"✅ Downloaded: {filename}")
    
    def verify_file_integrity(self, filepath, expected_size=None, expected_hash=None):
        """Check the integrity of a downloaded file."""
        if not filepath.exists():
            return False, "File does not exist"
        
        file_size = filepath.stat().st_size
        if expected_size and file_size != expected_size:
            return False, f"Wrong size: {file_size} vs {expected_size}"
        
        if expected_hash:
            with open(filepath, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()
            if file_hash != expected_hash:
                return False, f"Hash mismatch: {file_hash} vs {expected_hash}"
        
        return True, "File is valid"
    
    def check_existing_files(self):
        """Check whether the AMI files are already present and valid."""
        print("🔍 Checking existing files...")
        
        # Essential audio files
        required_audio = [
            "ES2002a.Headset-0.wav", "ES2002a.Headset-1.wav", 
            "ES2002a.Headset-2.wav", "ES2002a.Headset-3.wav",
            "ES2002b.Headset-0.wav", "ES2002b.Headset-1.wav",
            "ES2002c.Headset-0.wav", "ES2002c.Headset-1.wav",
            "ES2002d.Headset-0.wav", "ES2002d.Headset-1.wav"
        ]
        
        existing_audio = []
        for audio_file in required_audio:
            audio_path = self.audio_dir / audio_file
            if audio_path.exists() and audio_path.stat().st_size > 1000000:  # >1MB
                existing_audio.append(audio_file)
        
        # Annotation files (recursive search, matching the one used after extraction)
        annotation_files = list(self.annotation_dir.rglob("*.rttm"))
        
        print(f"📁 Audio files found: {len(existing_audio)}/{len(required_audio)}")
        print(f"📁 Annotation files found: {len(annotation_files)}")
        
        return {
            'audio_complete': len(existing_audio) >= 8,  # At least 8 audio files
            'annotations_complete': len(annotation_files) >= 10,
            'existing_audio': existing_audio,
            'existing_annotations': annotation_files
        }
    
    def download_ami_audio_files(self):
        """Download the essential AMI audio files."""
        print("🎵 Downloading AMI audio files...")
        
        # URLs for the AMI audio files (example for ES2002).
        # Note: only 6 files are listed, fewer than the 8 that check_existing_files() requires.
        base_url = "https://groups.inf.ed.ac.uk/ami/AMICorpusAnnotations/amicorpus"
        audio_urls = {
            "ES2002a.Headset-0.wav": f"{base_url}/ES2002a/audio/ES2002a.Headset-0.wav",
            "ES2002a.Headset-1.wav": f"{base_url}/ES2002a/audio/ES2002a.Headset-1.wav",
            "ES2002a.Headset-2.wav": f"{base_url}/ES2002a/audio/ES2002a.Headset-2.wav",
            "ES2002a.Headset-3.wav": f"{base_url}/ES2002a/audio/ES2002a.Headset-3.wav",
            "ES2002b.Headset-0.wav": f"{base_url}/ES2002b/audio/ES2002b.Headset-0.wav",
            "ES2002b.Headset-1.wav": f"{base_url}/ES2002b/audio/ES2002b.Headset-1.wav",
        }
        
        downloaded_files = []
        
        for filename, url in audio_urls.items():
            audio_path = self.audio_dir / filename
            
            if audio_path.exists() and audio_path.stat().st_size > 1000000:
                print(f"✅ Already present: {filename}")
                downloaded_files.append(filename)
                continue
            
            try:
                self.download_with_progress_bar(url, str(audio_path), f"Audio {filename}")
                
                # Check the size
                if audio_path.stat().st_size > 1000000:
                    downloaded_files.append(filename)
                    print(f"✅ Downloaded successfully: {filename}")
                else:
                    print(f"⚠️ File too small: {filename}")
                    
            except Exception as e:
                print(f"❌ Error for {filename}: {e}")
        
        return downloaded_files
    
    def download_ami_annotations(self):
        """Download the AMI annotations."""
        print("📝 Downloading AMI annotations...")
        
        annotation_url = "https://groups.inf.ed.ac.uk/ami/AMICorpusAnnotations/ami_public_manual_1.6.2.zip"
        zip_path = self.download_dir / "ami_annotations.zip"
        
        if not zip_path.exists():
            try:
                self.download_with_progress_bar(
                    annotation_url, 
                    str(zip_path), 
                    "AMI annotations"
                )
            except Exception as e:
                print(f"❌ Annotation download error: {e}")
                return False
        
        # Extraction
        try:
            print("📦 Extracting annotations...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.annotation_dir)
            
            # Look for the extracted RTTM files
            rttm_files = list(self.annotation_dir.rglob("*.rttm"))
            print(f"✅ Extraction complete: {len(rttm_files)} RTTM files found")
            return True
            
        except Exception as e:
            print(f"❌ Extraction error: {e}")
            return False

# Initialize the AMI downloader
ami_downloader = AMICorpusDownloader(DATA_DIR)

# Check the existing files
existing_status = ami_downloader.check_existing_files()

print("\n" + "="*60)
print("📊 AMI CORPUS STATUS")
print("="*60)

if existing_status['audio_complete'] and existing_status['annotations_complete']:
    print("🎉 Complete AMI corpus already present!")
    print(f"   📁 Audio: {len(existing_status['existing_audio'])} files")
    print(f"   📁 Annotations: {len(existing_status['existing_annotations'])} files")
    print("\n✅ Using the existing files...")
    
else:
    print("⚠️ AMI corpus incomplete - download required")
    
    # Download the missing files
    if not existing_status['audio_complete']:
        print("\n🔽 Downloading audio files...")
        downloaded_audio = ami_downloader.download_ami_audio_files()
        print(f"✅ Audio downloaded: {len(downloaded_audio)} files")
    
    if not existing_status['annotations_complete']:
        print("\n🔽 Downloading annotations...")
        ami_downloader.download_ami_annotations()

# Final report
final_status = ami_downloader.check_existing_files()
print("\n" + "="*60)
print("📈 FINAL REPORT")
print("="*60)
print(f"🎵 Audio files: {len(final_status['existing_audio'])} available")
print(f"📝 Annotation files: {len(final_status['existing_annotations'])} available")

if final_status['audio_complete']:
    print("✅ Audio corpus ready for training!")
else:
    print("⚠️ Audio corpus incomplete - demo mode enabled")

# Set the paths used in the rest of the notebook
AUDIO_PATH = ami_downloader.audio_dir
ANNOTATION_PATH = ami_downloader.annotation_dir

print("\n📂 Paths configured:")
print(f"   🎵 Audio: {AUDIO_PATH}")
print(f"   📝 Annotations: {ANNOTATION_PATH}")
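
`verify_file_integrity` hashes a file by reading it fully into memory, which can be costly for large audio files. As a minimal sketch (the helper name `md5_of_file` and the 1 MiB default chunk size are illustrative assumptions, not part of the project), the same digest can be computed in fixed-size chunks:

```python
import hashlib
import tempfile
from pathlib import Path


def md5_of_file(filepath, chunk_size=1024 * 1024):
    """Return the MD5 hex digest of a file, reading it chunk by chunk."""
    digest = hashlib.md5()
    with open(filepath, "rb") as f:
        # iter() with a sentinel stops when read() returns empty bytes (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Small self-check on a temporary file (hypothetical data, not an AMI file)
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.bin"
    sample.write_bytes(b"abc" * 1000)
    chunked = md5_of_file(sample, chunk_size=64)
    whole = hashlib.md5(sample.read_bytes()).hexdigest()
    print(chunked == whole)
```

The chunked digest is identical to the one-shot digest, so such a helper could stand in for the `hashlib.md5(f.read())` call without changing the expected hashes.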

In [None]:
# Final validation and data preparation for training

def validate_ami_corpus():
    """Check that the AMI corpus is ready for training."""
    print("🔍 Final validation of the AMI corpus...")
    
    audio_files = list(AUDIO_PATH.glob("*.wav"))
    annotation_files = list(ANNOTATION_PATH.rglob("*.rttm"))
    
    print("\n📊 Final inventory:")
    print(f"   🎵 Audio files: {len(audio_files)}")
    print(f"   📝 RTTM files: {len(annotation_files)}")
    
    # Show a few examples
    if audio_files:
        print("\n🎵 Audio examples:")
        for audio_file in audio_files[:5]:
            size_mb = audio_file.stat().st_size / (1024*1024)
            print(f"   - {audio_file.name} ({size_mb:.1f} MB)")
    
    if annotation_files:
        print("\n📝 Annotation examples:")
        for rttm_file in annotation_files[:5]:
            # Use a context manager so the file handle is closed
            with open(rttm_file) as f:
                lines = len(f.readlines())
            print(f"   - {rttm_file.name} ({lines} segments)")
    
    # Overall status
    corpus_ready = len(audio_files) >= 4 and len(annotation_files) >= 4
    
    if corpus_ready:
        print("\n✅ AMI corpus validated and ready!")
        print(f"🚀 Ready for training with {len(audio_files)} audio files")
    else:
        print("\n⚠️ Corpus incomplete - demo mode recommended")
    
    return {
        'ready': corpus_ready,
        'audio_count': len(audio_files),
        'annotation_count': len(annotation_files),
        'audio_files': audio_files,
        'annotation_files': annotation_files
    }

# Validate the corpus
corpus_status = validate_ami_corpus()

print(f"\n{'='*60}")
print(f"🎯 AMI CORPUS: {'READY' if corpus_status['ready'] else 'PARTIAL'}")
print(f"{'='*60}")

# Set the final paths for training
if corpus_status['ready']:
    print("✅ Using the real AMI corpus")
    FINAL_AUDIO_DIR = AUDIO_PATH
    FINAL_ANNOTATION_DIR = ANNOTATION_PATH
else:
    print("⚠️ Falling back to demo data")
    # Keep any partial files if they exist
    FINAL_AUDIO_DIR = AUDIO_PATH if corpus_status['audio_count'] > 0 else DATA_DIR / "demo_audio"
    FINAL_ANNOTATION_DIR = ANNOTATION_PATH if corpus_status['annotation_count'] > 0 else DATA_DIR / "demo_annotations"

print("\n📂 Final configuration:")
print(f"   🎵 Audio: {FINAL_AUDIO_DIR}")
print(f"   📝 Annotations: {FINAL_ANNOTATION_DIR}")
print(f"   📊 Mode: {'Production (AMI)' if corpus_status['ready'] else 'Demo/Partial'}")
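
The validation cell counts the lines of each RTTM file as segments. In the RTTM format, a `SPEAKER` line carries space-separated fields, of which the second is the file ID, the fourth the onset, the fifth the duration, and the eighth the speaker name. A minimal parsing sketch follows; the function names and the sample lines are hypothetical and written only for illustration:

```python
from collections import defaultdict


def parse_rttm_lines(lines):
    """Parse RTTM SPEAKER lines into (file_id, start, end, speaker) tuples."""
    segments = []
    for line in lines:
        fields = line.split()
        # Skip blank lines and record types other than SPEAKER
        if len(fields) < 8 or fields[0] != "SPEAKER":
            continue
        file_id = fields[1]
        onset = float(fields[3])
        duration = float(fields[4])
        speaker = fields[7]
        segments.append((file_id, onset, onset + duration, speaker))
    return segments


def speech_time_per_speaker(segments):
    """Sum the speaking time of each speaker, in seconds."""
    totals = defaultdict(float)
    for _, start, end, speaker in segments:
        totals[speaker] += end - start
    return dict(totals)


# Hypothetical example lines, written for illustration (not taken from AMI)
example = [
    "SPEAKER ES2002a 1 0.50 2.00 <NA> <NA> spk_A <NA> <NA>",
    "SPEAKER ES2002a 1 2.25 1.50 <NA> <NA> spk_B <NA> <NA>",
    "SPEAKER ES2002a 1 4.00 1.00 <NA> <NA> spk_A <NA> <NA>",
]
segments = parse_rttm_lines(example)
totals = speech_time_per_speaker(segments)
print(totals)
```

Per-speaker totals like these give a quicker sanity check of an annotation file than a raw line count.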

### 📊 Data Split (Train/Validation/Test)

Random split of the AMI corpus into training, validation, and test sets, with a fixed seed for reproducibility.

In [None]:
def create_ami_splits(audio_dir, rttm_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """
    Split the AMI corpus into training, validation, and test sets.
    
    Args:
        audio_dir: Directory containing the audio files
        rttm_dir: Directory containing the RTTM annotations
        train_ratio: Fraction used for training
        val_ratio: Fraction used for validation
        test_ratio: Fraction used for testing
    
    Returns:
        dict: Lists of files for each split
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1.0"
    
    # List all available files
    audio_path = Path(audio_dir)
    rttm_path = Path(rttm_dir)
    
    # Find the audio files
    audio_extensions = ['.wav', '.pt', '.flac', '.mp3']
    audio_files = []
    for ext in audio_extensions:
        audio_files.extend(list(audio_path.glob(f'*{ext}')))
    
    # Keep only audio files that have a matching RTTM file
    valid_pairs = []
    for audio_file in audio_files:
        base_name = audio_file.stem
        rttm_file = rttm_path / f"{base_name}.rttm"
        
        if rttm_file.exists():
            valid_pairs.append({
                'base_name': base_name,
                'audio_file': str(audio_file),
                'rttm_file': str(rttm_file)
            })
    
    print(f"📊 Found {len(valid_pairs)} valid audio-RTTM pairs")
    
    if len(valid_pairs) == 0:
        raise ValueError("No valid audio-RTTM pair found!")
    
    # Shuffle and split
    import random
    random.seed(42)  # For reproducibility
    random.shuffle(valid_pairs)
    
    n_total = len(valid_pairs)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    
    train_files = valid_pairs[:n_train]
    val_files = valid_pairs[n_train:n_train + n_val]
    test_files = valid_pairs[n_train + n_val:]
    
    splits = {
        'train': train_files,
        'validation': val_files,
        'test': test_files
    }
    
    # Statistics
    print("\n📈 Data split:")
    for split_name, files in splits.items():
        print(f"   {split_name:>10}: {len(files):>3} files ({len(files)/n_total*100:.1f}%)")
    
    # Save the splits
    splits_file = DATA_DIR / 'data_splits.json'
    with open(splits_file, 'w') as f:
        json.dump(splits, f, indent=2)
    
    print(f"\n💾 Splits saved to: {splits_file}")
    return splits

# Create the splits from the directories selected during validation
# (audio_dir and rttm_dir were previously undefined at this point)
audio_dir = FINAL_AUDIO_DIR
rttm_dir = FINAL_ANNOTATION_DIR
data_splits = create_ami_splits(audio_dir, rttm_dir, 
                               train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)

# Inspect the contents
print("\n🔍 Example files per split:")
for split_name, files in data_splits.items():
    if files:
        print(f"\n{split_name.upper()}:")
        for i, file_info in enumerate(files[:3]):  # Show the first 3
            print(f"  {i+1}. {file_info['base_name']}")
        if len(files) > 3:
            print(f"  ... and {len(files)-3} more")
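
The split above shuffles individual files, so headset channels of the same meeting (for example `ES2002a.Headset-0` and `ES2002a.Headset-1`) can end up in different splits. One possible alternative, shown here only as a sketch, assigns whole meetings to a split. The helper `split_by_meeting` is hypothetical, and it assumes the meeting ID is the part of the file name before the first dot:

```python
import random
from collections import defaultdict


def split_by_meeting(base_names, train_ratio=0.7, val_ratio=0.15, seed=42):
    """Assign whole meetings to train/validation/test so channels never straddle splits."""
    # Assumption: the meeting ID is the part of the file name before the first dot
    by_meeting = defaultdict(list)
    for name in base_names:
        by_meeting[name.split(".")[0]].append(name)

    meetings = sorted(by_meeting)
    random.Random(seed).shuffle(meetings)  # local RNG keeps the global state untouched

    n_train = int(len(meetings) * train_ratio)
    n_val = int(len(meetings) * val_ratio)
    groups = {
        "train": meetings[:n_train],
        "validation": meetings[n_train:n_train + n_val],
        "test": meetings[n_train + n_val:],
    }
    return {
        split: [name for meeting in ids for name in by_meeting[meeting]]
        for split, ids in groups.items()
    }


# Hypothetical file names following the AMI headset naming pattern
names = [f"M{m:02d}.Headset-{c}" for m in range(10) for c in range(4)]
meeting_splits = split_by_meeting(names)
print({k: len(v) for k, v in meeting_splits.items()})
```

Grouping by meeting keeps the same conversation out of both training and evaluation, at the cost of split sizes that follow the ratios only at the meeting level.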

## 🛠️ 3. Data Preparation and Feature Extraction

Extraction of the multi-channel features: LPS, IPD, AF
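
As a rough illustration of two of these features, the sketch below computes a log power spectrum (the usual reading of LPS) and an inter-channel phase difference (the usual reading of IPD) with NumPy. It is not the project's extractor: the function names, the Hann window, and the synthetic two-channel signal are all assumptions, and the angle feature (AF) is left out because it depends on the microphone array geometry.

```python
import numpy as np


def stft(signal, n_fft=512, hop_length=256):
    """Short-time Fourier transform with a Hann window; returns [n_fft // 2 + 1, frames]."""
    window = np.hanning(n_fft)
    n_frames = 1 + (len(signal) - n_fft) // hop_length
    frames = np.stack([
        signal[i * hop_length:i * hop_length + n_fft] * window
        for i in range(n_frames)
    ])
    return np.fft.rfft(frames, axis=1).T


def lps(spec, eps=1e-8):
    """Log power spectrum of one channel."""
    return np.log(np.abs(spec) ** 2 + eps)


def ipd(spec_a, spec_b):
    """Inter-channel phase difference, wrapped to (-pi, pi]."""
    return np.angle(spec_a * np.conj(spec_b))


# Two synthetic channels: a 440 Hz tone, the second circularly delayed by 4 samples
sr = 16000
t = np.arange(sr) / sr
ch0 = np.sin(2 * np.pi * 440 * t)
ch1 = np.roll(ch0, 4)

spec0, spec1 = stft(ch0), stft(ch1)
features = np.concatenate([lps(spec0), ipd(spec0, spec1)], axis=0)
print(features.shape)
```

Each feature block has one row per frequency bin (257 for `n_fft=512`), which is why the blocks can be stacked along the first axis.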

In [None]:
# Path setup and import of the project modules
import sys
import os
import importlib
from pathlib import Path

# Add the src directory to the Python path
current_dir = Path.cwd()
src_path = current_dir / 'src'

print(f"📁 Current directory: {current_dir}")
print(f"📂 src directory: {src_path}")
print(f"📂 src exists: {src_path.exists()}")

# Check that the src directory exists
if not src_path.exists():
    print("❌ ERROR: The 'src' directory does not exist!")
    print("🔧 Make sure you have cloned the full repository with:")
    print("   !git clone https://github.com/saito1111/Speaker-diarization-.git")
    print("   %cd Speaker-diarization-")
    raise FileNotFoundError("'src' directory missing. Clone the repository first.")

# List the files in src for verification
py_files = list(src_path.glob("*.py"))
print(f"📄 Python files found: {[f.name for f in py_files]}")

# Map each required module to the name imported from it
required_imports = {
    'tcn_diarization_model': 'DiarizationTCN',
    'metrics': 'DiarizationMetrics',
    'dataset': 'DiarizationDataset',
    'diarization_losses': 'MultiTaskDiarizationLoss',
    'improved_trainer': 'ImprovedDiarizationTrainer',
    'optimized_dataloader': 'create_optimized_dataloaders',
    'optimized_dataset': 'OptimizedDiarizationDataset',
}
required_modules = [f"{name}.py" for name in required_imports]

# Check that ALL required modules exist
missing_modules = [m for m in required_modules if not (src_path / m).exists()]

if missing_modules:
    print(f"❌ ERROR: Modules missing from src/: {missing_modules}")
    print("🔧 Check that all files are present in the repository")
    raise FileNotFoundError(f"Missing modules: {missing_modules}")

# Add src to the path
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

print(f"✅ src directory added to the path: {src_path}")

# Import ALL project modules (required, no fallback)
print("\n🔄 Importing ALL project modules...")

for module_name, symbol in required_imports.items():
    try:
        # Equivalent to `from <module_name> import <symbol>` at notebook level
        globals()[symbol] = getattr(importlib.import_module(module_name), symbol)
        print(f"✅ {module_name} imported")
    except (ImportError, AttributeError) as e:
        print(f"❌ CRITICAL ERROR: Could not import {module_name}")
        print(f"   Error: {e}")
        print(f"🔧 Check the contents of {module_name}.py")
        raise ImportError(f"Module {module_name} is required") from e

print("\n🎉 ALL MODULES IMPORTED SUCCESSFULLY!")
print("📋 Available for training:")
for symbol in required_imports.values():
    print(f"   - {symbol}: ✅")

print("\n🚀 Ready to train with ALL the original models!")

In [None]:
# Required utility classes (not present in the project modules)
import numpy as np
import torch
import psutil
import gc

# MemoryMonitor tracks RAM and GPU memory usage
class MemoryMonitor:
    def __init__(self):
        self.process = psutil.Process()
        
    def get_memory_info(self):
        """Return RAM and GPU memory usage."""
        # RAM
        ram_info = psutil.virtual_memory()
        ram_percent = ram_info.percent
        
        # GPU
        gpu_percent = 0
        gpu_memory_used = 0
        gpu_memory_total = 0
        
        if torch.cuda.is_available():
            gpu_memory_used = torch.cuda.memory_allocated(0)
            gpu_memory_total = torch.cuda.get_device_properties(0).total_memory
            gpu_percent = (gpu_memory_used / gpu_memory_total) * 100
        
        return {
            'ram_percent': ram_percent,
            'gpu_percent': gpu_percent,
            'gpu_memory_used': gpu_memory_used,
            'gpu_memory_total': gpu_memory_total
        }
    
    def cleanup(self):
        """Free unused memory."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

print("✅ Utility classes created (MemoryMonitor)")
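
`MemoryMonitor.get_memory_info()` reports usage as a percentage on a 0-100 scale, whereas a threshold such as 0.85 is a fraction, so the two need to be put on the same scale before comparing. The sketch below shows one hypothetical policy (halving the batch size when usage crosses the threshold). The function `adapt_batch_size` and the simulated readings are assumptions; the project's own adaptive logic lives in its modules and may work differently.

```python
def adapt_batch_size(batch_size, used_percent, threshold=0.85, min_batch_size=1):
    """Halve the batch size when memory use exceeds the threshold (hypothetical policy)."""
    # used_percent is on a 0-100 scale, threshold is a fraction in [0, 1]
    if used_percent / 100.0 > threshold:
        return max(min_batch_size, batch_size // 2)
    return batch_size


# Simulated memory readings (percent), standing in for MemoryMonitor.get_memory_info()
batch_size = 8
history = []
for reading in [60.0, 88.0, 91.0, 70.0, 97.0, 99.0]:
    batch_size = adapt_batch_size(batch_size, reading)
    history.append(batch_size)
print(history)
```

This toy policy never grows the batch size back; a real one would probably also call `cleanup()` before shrinking.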

In [None]:
# Feature extraction test
print("🧪 Testing feature extraction...")

# Create a feature extractor.
# Note: AudioFeatureExtractor is not imported in any earlier cell; import it from
# the project module that defines it before running this cell.
feature_extractor = AudioFeatureExtractor(
    sample_rate=16000,
    n_fft=512,
    hop_length=256
)

# Create dummy audio data (8 channels)
n_channels = 8
duration = 4.0  # 4 seconds
sample_rate = 16000
n_samples = int(duration * sample_rate)

# Simulate multi-channel audio with a tone plus noise
waveforms = []
for ch in range(n_channels):
    # Base signal + noise
    base_signal = np.sin(2 * np.pi * 440 * np.linspace(0, duration, n_samples))  # 440 Hz
    noise = np.random.normal(0, 0.1, n_samples)
    # Add a small time offset to simulate spatialization
    delay_samples = int(0.001 * ch * sample_rate)  # 1 ms of delay per channel
    delayed_signal = np.zeros(n_samples)
    if delay_samples < n_samples:
        delayed_signal[delay_samples:] = base_signal[:n_samples-delay_samples]
    
    final_signal = delayed_signal + noise
    waveforms.append(torch.tensor(final_signal, dtype=torch.float32))

print(f"📊 Audio generated: {n_channels} channels, {duration}s, {sample_rate} Hz")

# Extraction test
features = feature_extractor.extract_features(waveforms)
print(f"✅ Features extracted: {features.shape}")
print(f"   - Expected dimensions: [771, ~{int(duration * sample_rate / 256)}]")
print(f"   - Actual dimensions: {list(features.shape)}")

# Analyze the features
print("\n📈 Feature statistics:")
print(f"   - Min: {features.min():.3f}")
print(f"   - Max: {features.max():.3f}")
print(f"   - Mean: {features.mean():.3f}")
print(f"   - Std: {features.std():.3f}")

# Visualize the features
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 8))

# Spectrogram-style views of the features
plt.subplot(2, 2, 1)
plt.imshow(features.numpy()[:100, :], aspect='auto', origin='lower')
plt.title('LPS features (first part)')
plt.xlabel('Time (frames)')
plt.ylabel('Frequency')
plt.colorbar()

plt.subplot(2, 2, 2)
plt.imshow(features.numpy()[257:357, :], aspect='auto', origin='lower')
plt.title('IPD features (first part)')
plt.xlabel('Time (frames)')
plt.ylabel('Frequency')
plt.colorbar()

plt.subplot(2, 2, 3)
plt.imshow(features.numpy()[500:600, :], aspect='auto', origin='lower')
plt.title('AF features (first part)')
plt.xlabel('Time (frames)')
plt.ylabel('Frequency')
plt.colorbar()

plt.subplot(2, 2, 4)
plt.plot(features.numpy().mean(axis=0))
plt.title('Mean feature value per frame')
plt.xlabel('Time (frames)')
plt.ylabel('Mean value')

plt.tight_layout()

# Create the results directory if it does not exist
import os
from pathlib import Path
RESULTS_DIR = Path('./results')
RESULTS_DIR.mkdir(exist_ok=True)

plt.savefig(RESULTS_DIR / 'feature_extraction_test.png', dpi=150, bbox_inches='tight')
plt.show()

print("✅ Feature extraction test complete!")
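
The expected dimensions printed by this test follow from the STFT parameters. The short calculation below spells out the arithmetic; whether the extractor pads the signal at both ends (which would add one frame) is an assumption, since its implementation is not shown here.

```python
n_fft = 512
hop_length = 256
sample_rate = 16000
duration = 4.0

# A real-input FFT of size n_fft yields n_fft // 2 + 1 frequency bins
n_bins = n_fft // 2 + 1

# 771 equals three blocks of n_bins rows
feature_dim = 3 * n_bins

n_samples = int(duration * sample_rate)
frames_approx = n_samples // hop_length          # the "~250" in the expected dimensions
frames_centered = n_samples // hop_length + 1    # if the STFT pads both ends (assumption)

print(n_bins, feature_dim, frames_approx, frames_centered)
```

A frame count of 250 or 251 is therefore consistent with the expected shape, depending on the padding convention.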

## 🧠 4. Model Configuration and Training

Configuration tuned for the AMI corpus and for the memory limits of Colab.

In [None]:
# Configuration tuned for the AMI corpus
config = {
    # === MODEL ===
    'model': {
        'input_dim': 771,  # LPS (257) + IPD (257*4) + AF (257*4) = 2313 → 771 after aggregation
        'hidden_channels': [256, 256, 256, 512, 512],  # Multi-scale TCN architecture
        'kernel_size': 3,
        'num_speakers': 4,  # AMI meetings typically have 3-4 speakers
        'dropout': 0.2,
        'use_attention': True,  # Self-attention for long-term dependencies
        'use_speaker_classifier': True,  # Speaker classification
        'embedding_dim': 256
    },
    
    # === LOSS FUNCTION ===
    'loss': {
        'type': 'multitask',
        'vad_weight': 1.0,
        'osd_weight': 1.0,
        'consistency_weight': 0.1,
        'use_pit': True,  # Permutation Invariant Training
        'use_focal': True,  # Focal loss for imbalanced data
        'focal_gamma': 2.0,
        'num_speakers': 4
    },
    
    # === OPTIMIZER ===
    'optimizer': {
        'type': 'adamw',
        'lr': 1e-3,  # Initial learning rate
        'weight_decay': 1e-4,
        'betas': (0.9, 0.999)
    },
    
    # === LR SCHEDULER ===
    'scheduler': {
        'type': 'onecycle',  # OneCycleLR for fast convergence
        'steps_per_epoch': 100,  # Updated automatically later
        'pct_start': 0.3  # 30% ramp-up, 70% decay
    },
    
    # === TRAINING ===
    'training': {
        'epochs': 50,  # Reduced for Colab
        'batch_size': 8,  # Sized for Colab memory
        'num_workers': 2  # Fewer workers to avoid memory problems
    },
    
    # === DATA ===
    'data': {
        'segment_duration': 4.0,  # 4-second segments
        'sample_rate': 16000,
        'train_split': 0.7,
        'max_segments': 1000  # Limit for Colab
    },
    
    # === ADVANCED OPTIMIZATIONS ===
    'accumulation_steps': 4,  # Effective batch = 8*4 = 32
    'use_amp': True,  # Mixed precision
    'grad_clip_norm': 1.0,
    'patience': 10,
    'save_every': 5,
    
    # === MONITORING ===
    'use_wandb': True,  # Weights & Biases (optional)
    'project_name': 'ami-speaker-diarization',
    'memory_threshold': 0.85,  # Colab memory management
    'adaptive_batch': True,
    'speaker_loss_weight': 0.5,
    
    # === PATHS ===
    'save_dir': str(MODEL_DIR),
    'results_dir': str(RESULTS_DIR)
}

print("⚙️ Configuration created with the following parameters:")
print(f"   - Architecture: TCN {config['model']['hidden_channels']}")
print(f"   - Batch size: {config['training']['batch_size']} (effective: {config['training']['batch_size'] * config['accumulation_steps']})")
print(f"   - Epochs: {config['training']['epochs']}")
print(f"   - Learning rate: {config['optimizer']['lr']}")
print(f"   - Mixed precision: {config['use_amp']}")
print(f"   - Speaker classification: {config['model']['use_speaker_classifier']}")

# Save the configuration
config_file = MODEL_DIR / 'training_config.json'
with open(config_file, 'w') as f:
    json.dump(config, f, indent=2)

print(f"\n💾 Configuration saved: {config_file}")
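
The `use_pit` flag refers to Permutation Invariant Training: because the order of the model's speaker outputs is arbitrary, the loss is evaluated under every assignment of output channels to reference speakers and the smallest value is kept. The pure-Python sketch below illustrates the idea with a binary cross-entropy on frame-level activity. It is not the project's `MultiTaskDiarizationLoss`; the function names and the toy data are assumptions.

```python
import math
from itertools import permutations


def bce(pred, target, eps=1e-7):
    """Mean binary cross-entropy between two equal-length sequences."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(pred)


def pit_loss(preds, targets):
    """Return (best_loss, best_permutation) over all speaker orderings."""
    n_speakers = len(targets)
    best_loss, best_perm = float("inf"), None
    for perm in permutations(range(n_speakers)):
        # perm[i] is the prediction channel matched to reference speaker i
        loss = sum(bce(preds[perm[i]], targets[i]) for i in range(n_speakers)) / n_speakers
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


# Toy example: 2 speakers, 4 frames; the predictions are correct but in swapped order
targets = [[1, 1, 0, 0], [0, 0, 1, 1]]
preds = [[0.1, 0.2, 0.9, 0.8], [0.9, 0.8, 0.1, 0.2]]
loss, perm = pit_loss(preds, targets)
print(perm, round(loss, 3))
```

With 4 speakers there are 24 permutations to evaluate, which is still cheap; the exhaustive search only becomes a concern for larger speaker counts.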

In [None]:
# Configuration de Weights & Biases (optionnel)
import wandb

use_wandb = config.get('use_wandb', False)

if use_wandb:
    try:
        # Log in to wandb (requires a free account)
        wandb.login()
        
        # Initialize the project
        wandb.init(
            project=config['project_name'],
            config=config,
            name=f"ami-tcn-{torch.cuda.get_device_name(0).replace(' ', '-') if torch.cuda.is_available() else 'cpu'}",
            tags=['ami-corpus', 'tcn', 'multi-channel', 'colab'],
            notes="Training on the AMI corpus with the improved TCN architecture"
        )
        
        print("✅ Weights & Biases configured!")
        print(f"📊 Dashboard: {wandb.run.url}")
        
    except Exception as e:
        print(f"⚠️ wandb error: {e}")
        print("📈 Training without wandb monitoring...")
        config['use_wandb'] = False
else:
    print("📈 Training without wandb monitoring (disabled in config)")

## 🚀 5. Model Training

Training with all optimizations enabled: memory management, mixed precision, and gradient accumulation.
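
To make the gradient-accumulation idea concrete, here is a minimal pure-Python sketch. It is not the code of `ImprovedDiarizationTrainer`, whose internals are not shown in this notebook; the function names and the toy data are hypothetical. It fits a one-parameter linear model and checks that accumulating gradients over 4 micro-batches of 8, each scaled by 1/4, gives the same gradient as a single batch of 32, which is what the `accumulation_steps: 4` setting relies on.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch_size, accumulation_steps):
    """Accumulate gradients over several micro-batches before one update.

    Each micro-batch gradient is divided by accumulation_steps, the same
    scaling that is usually applied to the loss before backpropagation.
    """
    total = 0.0
    for step in range(accumulation_steps):
        start = step * micro_batch_size
        end = start + micro_batch_size
        total += mse_grad(w, xs[start:end], ys[start:end]) / accumulation_steps
    return total


# Toy data (hypothetical): 32 samples of y = 3x
xs = [i / 10.0 for i in range(32)]
ys = [3.0 * x for x in xs]
w = 0.5

full = mse_grad(w, xs, ys)                  # one batch of 32
accum = accumulated_grad(w, xs, ys, 8, 4)   # 4 micro-batches of 8
print(f"full-batch grad:  {full:.6f}")
print(f"accumulated grad: {accum:.6f}")
```

The two values agree up to floating-point rounding, so the optimizer sees the gradient of a batch of 32 while only 8 samples are held in memory at a time.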

In [None]:
# Create the optimized DataLoaders
print("🔄 Creating the DataLoaders with your optimized modules...")

# Use the imported create_optimized_dataloaders function
train_loader, val_loader = create_optimized_dataloaders(
    audio_dir=audio_dir,
    rttm_dir=rttm_dir,
    batch_size=config['training']['batch_size'],
    train_split=config['data']['train_split'],
    num_workers=config['training']['num_workers'],
    segment_duration=config['data']['segment_duration'],
    sample_rate=config['data']['sample_rate'],
    max_segments=config['data']['max_segments'],
    memory_threshold=config['memory_threshold'],
    adaptive_batch=config['adaptive_batch'],
    accumulation_steps=config['accumulation_steps']
)

print("✅ DataLoaders created successfully!")
print(f"   - Train batches: {len(train_loader)}")
print(f"   - Validation batches: {len(val_loader)}")

# Update the configuration with the actual number of steps per epoch
config['scheduler']['steps_per_epoch'] = len(train_loader)

# Test one batch
print("\n🧪 Testing one training batch...")
for batch_idx, batch in enumerate(train_loader):
    print(f"   Batch {batch_idx}:")
    print(f"     - Features: {batch['features'].shape}")
    print(f"     - VAD labels: {batch['vad_labels'].shape}")
    print(f"     - OSD labels: {batch['osd_labels'].shape}")
    
    # Check the dimensions
    assert batch['features'].shape[1] == 771, f"Unexpected feature dimension: {batch['features'].shape[1]} != 771"
    assert batch['vad_labels'].shape[-1] == 4, f"Unexpected number of speakers: {batch['vad_labels'].shape[-1]} != 4"
    
    print("     ✅ Dimensions are correct!")
    break

print("✅ Optimized DataLoaders configured and tested successfully!")

In [None]:
# Initialize the advanced trainer
print("🧠 Initializing the trainer with your optimized modules...")

# Create the trainer with all improvements (REQUIRED - no fallback)
trainer = ImprovedDiarizationTrainer(config)

print("✅ Trainer initialized successfully!")
print(f"   - Model: {trainer.model.get_num_params():,} parameters")
print(f"   - Device: {trainer.device}")
print(f"   - Mixed precision: {trainer.use_amp}")
print(f"   - Gradient accumulation: {trainer.accumulation_steps}")

# Forward-pass test
print("\n🧪 Testing the model...")
model = trainer.model
model.eval()

with torch.no_grad():
    # Test with a batch of random data
    test_input = torch.randn(2, 771, 250).to(trainer.device)
    
    # Simple forward pass
    vad_out, osd_out = model(test_input)
    print(f"   Simple forward: VAD {vad_out.shape}, OSD {osd_out.shape}")
    
    # Forward pass with embeddings, if available
    try:
        vad_out, osd_out, embeddings, speaker_logits = model(test_input, return_embeddings=True)
        print(f"   Full forward: VAD {vad_out.shape}, OSD {osd_out.shape}")
        print(f"                 Embeddings {embeddings.shape}, Speaker {speaker_logits.shape}")
    except Exception:
        # Catch Exception rather than using a bare except, so Ctrl-C still interrupts
        print("   Forward pass with embeddings unavailable (expected)")

print("✅ Model works correctly!")
print("🚀 Optimized trainer ready for training!")

In [None]:
# Start training
print("🚀 STARTING TRAINING")
print("=" * 50)

# Memory monitoring
memory_monitor = MemoryMonitor()
initial_memory = memory_monitor.get_memory_info()

print("💾 Initial memory:")
print(f"   - RAM: {initial_memory['ram_percent']:.1f}%")
print(f"   - GPU: {initial_memory['gpu_percent']:.1f}%")

# Training configuration
num_epochs = config['training']['epochs']
save_every = config.get('save_every', 5)

# Metric history
train_losses = []
val_losses = []
best_val_loss = float('inf')

print("\n📋 Training configuration:")
print(f"   - Epochs: {num_epochs}")
print(f"   - Batch size: {config['training']['batch_size']}")
print(f"   - Learning rate: {config['optimizer']['lr']}")
print(f"   - Checkpoint every {save_every} epochs")

print("\n" + "="*50)
print("TRAINING START")
print("="*50)

import time
from datetime import datetime

# Main training loop
start_time = time.time()

try:
    for epoch in range(num_epochs):
        epoch_start = time.time()
        
        print(f"\n🔄 Epoch {epoch+1}/{num_epochs} - {datetime.now().strftime('%H:%M:%S')}")
        print("-" * 40)
        
        # Training phase
        if hasattr(trainer, 'train_epoch'):
            # Use the advanced trainer
            train_metrics = trainer.train_epoch(train_loader)
            train_loss = train_metrics['total_loss']
        else:
            # Simple training fallback (no AMP, no gradient accumulation)
            trainer.model.train()
            total_loss = 0
            num_batches = 0
            
            for batch_idx, batch in enumerate(train_loader):
                features = batch['features'].to(trainer.device)
                vad_labels = batch['vad_labels'].to(trainer.device)
                osd_labels = batch['osd_labels'].to(trainer.device)
                
                trainer.optimizer.zero_grad()
                
                vad_pred, osd_pred = trainer.model(features)
                loss_dict = trainer.criterion(vad_pred, osd_pred, vad_labels, osd_labels)
                loss = loss_dict['total_loss']
                
                loss.backward()
                trainer.optimizer.step()
                
                total_loss += loss.item()
                num_batches += 1
                
                if batch_idx % 20 == 0:
                    print(f"   Batch {batch_idx}/{len(train_loader)}: Loss {loss.item():.4f}")
            
            train_loss = total_loss / num_batches
        
        train_losses.append(train_loss)
        
        # Validation phase
        print("\n📊 Validation...")
        trainer.model.eval()
        val_loss = 0
        num_val_batches = 0
        
        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(trainer.device)
                vad_labels = batch['vad_labels'].to(trainer.device)
                osd_labels = batch['osd_labels'].to(trainer.device)
                
                vad_pred, osd_pred = trainer.model(features)
                loss_dict = trainer.criterion(vad_pred, osd_pred, vad_labels, osd_labels)
                val_loss += loss_dict['total_loss'].item()
                num_val_batches += 1
        
        val_loss = val_loss / num_val_batches if num_val_batches > 0 else 0
        val_losses.append(val_loss)
        
        # Memory monitoring
        current_memory = memory_monitor.get_memory_info()
        
        # Epoch summary
        epoch_time = time.time() - epoch_start
        print(f"\n📈 Epoch {epoch+1} results:")
        print(f"   - Train Loss: {train_loss:.4f}")
        print(f"   - Val Loss: {val_loss:.4f}")
        print(f"   - Time: {epoch_time:.1f}s")
        print(f"   - GPU memory: {current_memory['gpu_percent']:.1f}%")
        
        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = MODEL_DIR / 'best_model.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': trainer.model.state_dict(),
                'optimizer_state_dict': trainer.optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': config
            }, best_model_path)
            print(f"   ✅ New best model saved! (Loss: {val_loss:.4f})")
        
        # Periodic checkpoint
        if (epoch + 1) % save_every == 0:
            checkpoint_path = MODEL_DIR / f'checkpoint_epoch_{epoch+1}.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': trainer.model.state_dict(),
                'optimizer_state_dict': trainer.optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_losses': train_losses,
                'val_losses': val_losses,
                'config': config
            }, checkpoint_path)
            print(f"   💾 Checkpoint saved: {checkpoint_path.name}")
        
        # wandb logging
        if config.get('use_wandb', False) and 'wandb' in locals():
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
                'gpu_memory_percent': current_memory['gpu_percent'],
                'epoch_time': epoch_time
            })
        
        # Memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

except KeyboardInterrupt:
    print("\n⏹️ Training interrupted by the user")
except Exception as e:
    print(f"\n❌ Error during training: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    total_time = time.time() - start_time
    print(f"\n⏱️ Total training time: {total_time/60:.1f} minutes")
    print(f"🎯 Best validation loss: {best_val_loss:.4f}")
    
    # Close wandb
    if config.get('use_wandb', False) and 'wandb' in locals():
        wandb.finish()
    
print("\n" + "="*50)
print("TRAINING FINISHED")
print("="*50)

## 📊 6. Evaluation and Metrics

Full evaluation with standard diarization metrics.
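
As a point of reference for the DER figure reported below, here is a simplified frame-level sketch of the usual definition: missed speech, false alarms, and speaker confusion, divided by the total reference speech. It assumes the speaker columns of the reference and the hypothesis are already aligned and applies no forgiveness collar. The notebook's `DiarizationMetrics` class may compute DER differently; the function name and the toy data are hypothetical.

```python
def frame_level_der(reference, hypothesis):
    """Simplified frame-level Diarization Error Rate, in percent.

    reference, hypothesis: lists of frames, each frame a list of 0/1
    speaker activities. Speakers are assumed to be aligned already
    (no permutation search) and no collar is applied.
    """
    miss = false_alarm = confusion = total_ref = 0
    for ref_frame, hyp_frame in zip(reference, hypothesis):
        n_ref = sum(ref_frame)   # speakers active in the reference
        n_hyp = sum(hyp_frame)   # speakers active in the hypothesis
        n_correct = sum(r and h for r, h in zip(ref_frame, hyp_frame))
        miss += max(0, n_ref - n_hyp)
        false_alarm += max(0, n_hyp - n_ref)
        confusion += min(n_ref, n_hyp) - n_correct
        total_ref += n_ref
    if total_ref == 0:
        return 0.0
    return 100.0 * (miss + false_alarm + confusion) / total_ref


# Toy example with 2 speakers and 5 frames (hypothetical data)
reference = [[1, 0], [1, 0], [1, 1], [0, 0], [0, 1]]
hypothesis = [[1, 0], [0, 1], [1, 0], [1, 0], [0, 1]]
der = frame_level_der(reference, hypothesis)
print(f"DER: {der:.1f}%")
```

In this toy example, frame 2 is a confusion, frame 3 a miss, and frame 4 a false alarm, so 3 errors over 5 reference speaker-frames.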

In [None]:
# Full model evaluation
print("📊 MODEL EVALUATION")
print("=" * 40)

# Load the best model
best_model_path = MODEL_DIR / 'best_model.pth'

if best_model_path.exists():
    print(f"📂 Loading the best model: {best_model_path}")
    checkpoint = torch.load(best_model_path, map_location=trainer.device)
    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"   - Epoch: {checkpoint['epoch']}")
    print(f"   - Val Loss: {checkpoint['val_loss']:.4f}")
else:
    print("⚠️ No saved model found, using the current model")

# Initialize the metrics
metrics_computer = DiarizationMetrics(num_speakers=config['model']['num_speakers'])

# Evaluate on the validation set
print("\n🧪 Evaluating on the validation set...")
trainer.model.eval()

all_vad_preds = []
all_vad_targets = []
all_osd_preds = []
all_osd_targets = []

eval_loss = 0
num_eval_batches = 0

with torch.no_grad():
    for batch_idx, batch in enumerate(val_loader):
        features = batch['features'].to(trainer.device)
        vad_labels = batch['vad_labels'].to(trainer.device)
        osd_labels = batch['osd_labels'].to(trainer.device)
        
        # Predictions
        if hasattr(trainer.model, 'use_speaker_classifier') and trainer.model.use_speaker_classifier:
            vad_pred, osd_pred, embeddings, speaker_logits = trainer.model(features, return_embeddings=True)
        else:
            vad_pred, osd_pred = trainer.model(features)
        
        # Loss
        loss_dict = trainer.criterion(vad_pred, osd_pred, vad_labels, osd_labels)
        eval_loss += loss_dict['total_loss'].item()
        num_eval_batches += 1
        
        # Collect for metrics
        all_vad_preds.append(vad_pred.cpu())
        all_vad_targets.append(vad_labels.cpu())
        all_osd_preds.append(osd_pred.cpu())
        all_osd_targets.append(osd_labels.cpu())
        
        if batch_idx % 10 == 0:
            print(f"   Batch {batch_idx}/{len(val_loader)} evaluated")

# Compute detailed metrics
print("\n📈 Computing detailed metrics...")

vad_preds = torch.cat(all_vad_preds, dim=0)
vad_targets = torch.cat(all_vad_targets, dim=0)
osd_preds = torch.cat(all_osd_preds, dim=0)
osd_targets = torch.cat(all_osd_targets, dim=0)

print(f"   - Evaluated data: {vad_preds.shape[0]} samples")
print(f"   - Total duration: {vad_preds.shape[0] * vad_preds.shape[1] * 0.02:.1f} seconds")

# Main metrics
metrics = metrics_computer.compute_metrics(vad_preds, osd_preds, vad_targets, osd_targets)

print("\n🎯 EVALUATION RESULTS")
print("=" * 30)
print(f"📊 Final loss: {eval_loss / num_eval_batches:.4f}")
print(f"📊 DER (Diarization Error Rate): {metrics.get('der', 0):.2f}%")
print(f"📊 Overall F1 score: {metrics.get('f1_score', 0):.3f}")
print(f"📊 Frame precision: {metrics.get('frame_precision', 0):.3f}")
print(f"📊 Frame recall: {metrics.get('frame_recall', 0):.3f}")
print(f"📊 Jaccard index: {metrics.get('jaccard_index', 0):.3f}")

# OSD metrics
if 'osd_precision' in metrics:
    print("\n🔀 Overlap detection (OSD):")
    print(f"   - OSD precision: {metrics['osd_precision']:.3f}")
    print(f"   - OSD recall: {metrics['osd_recall']:.3f}")
    print(f"   - OSD F1: {metrics['osd_f1']:.3f}")

# Per-speaker metrics
print("\n👥 Per-speaker metrics:")
for spk in range(config['model']['num_speakers']):
    if f'speaker_{spk}_f1' in metrics:
        print(f"   Speaker {spk}: F1={metrics[f'speaker_{spk}_f1']:.3f}, "
              f"P={metrics[f'speaker_{spk}_precision']:.3f}, "
              f"R={metrics[f'speaker_{spk}_recall']:.3f}")

# Save the metrics
metrics_file = RESULTS_DIR / 'evaluation_metrics.json'
with open(metrics_file, 'w') as f:
    # Convert scalar tensors to Python numbers for JSON
    json_metrics = {k: (v.item() if torch.is_tensor(v) else v) for k, v in metrics.items()}
    json_metrics['eval_loss'] = eval_loss / num_eval_batches
    json.dump(json_metrics, f, indent=2)

print(f"\n💾 Metrics saved: {metrics_file}")

## 📈 7. Visualizations and Analysis

Plots and visualizations of the results.

In [None]:
# Result visualizations
print("📊 Generating visualizations...")

# matplotlib configuration
plt.style.use('seaborn-v0_8')
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 12

# === 1. Training curves ===
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Training Results - Speaker Diarization TCN', fontsize=16, fontweight='bold')

# Loss curve
axes[0, 0].plot(train_losses, label='Train Loss', color='blue', linewidth=2)
axes[0, 0].plot(val_losses, label='Validation Loss', color='red', linewidth=2)
axes[0, 0].set_title('Loss Over Epochs')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Zoom in on the last epochs
if len(train_losses) > 10:
    start_idx = max(0, len(train_losses) - 20)
    axes[0, 1].plot(range(start_idx, len(train_losses)), train_losses[start_idx:], 
                   label='Train Loss', color='blue', linewidth=2)
    axes[0, 1].plot(range(start_idx, len(val_losses)), val_losses[start_idx:], 
                   label='Validation Loss', color='red', linewidth=2)
    axes[0, 1].set_title('Convergence (Last Epochs)')
else:
    axes[0, 1].plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    axes[0, 1].plot(val_losses, label='Validation Loss', color='red', linewidth=2)
    axes[0, 1].set_title('Loss Over Epochs (All Epochs)')

axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# === 2. Prediction analysis ===
# Take one sample for visualization
sample_idx = 0
sample_vad_pred = vad_preds[sample_idx].numpy()  # [time, speakers]
sample_vad_target = vad_targets[sample_idx].numpy()
sample_osd_pred = osd_preds[sample_idx].numpy()  # [time]
sample_osd_target = osd_targets[sample_idx].numpy()

# Speaker activity (predictions vs ground truth)
time_frames = np.arange(len(sample_vad_pred)) * 0.02  # Convert frames to seconds

# VAD subplot
axes[1, 0].imshow(sample_vad_pred.T, aspect='auto', origin='lower', 
                 extent=[0, len(sample_vad_pred)*0.02, 0, 4], 
                 cmap='Blues', alpha=0.7)
axes[1, 0].imshow(sample_vad_target.T, aspect='auto', origin='lower',
                 extent=[0, len(sample_vad_target)*0.02, 0, 4],
                 cmap='Reds', alpha=0.5)
axes[1, 0].set_title('VAD Activity: Prediction (Blue) vs Ground Truth (Red)')
axes[1, 0].set_xlabel('Time (s)')
axes[1, 0].set_ylabel('Speaker ID')
axes[1, 0].set_yticks(range(4))

# OSD subplot
axes[1, 1].plot(time_frames, sample_osd_pred, label='OSD prediction', 
               color='blue', linewidth=2, alpha=0.8)
axes[1, 1].plot(time_frames, sample_osd_target, label='OSD ground truth', 
               color='red', linewidth=2, alpha=0.6)
axes[1, 1].set_title('Overlap Detection (OSD)')
axes[1, 1].set_xlabel('Time (s)')
axes[1, 1].set_ylabel('Overlap probability')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_ylim(0, 1)

plt.tight_layout()
training_plot_path = RESULTS_DIR / 'training_results.png'
plt.savefig(training_plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Training plot saved: {training_plot_path}")

In [None]:
# === 3. Confusion matrix for speaker classification ===
if hasattr(trainer.model, 'use_speaker_classifier') and trainer.model.use_speaker_classifier:
    print("\n📊 Analyzing speaker classification...")
    
    # Collect the classification predictions
    all_speaker_preds = []
    all_speaker_targets = []
    
    trainer.model.eval()
    with torch.no_grad():
        for batch in val_loader:
            features = batch['features'].to(trainer.device)
            vad_labels = batch['vad_labels'].to(trainer.device)
            
            try:
                vad_pred, osd_pred, embeddings, speaker_logits = trainer.model(features, return_embeddings=True)
                
                # Derive speaker labels from the VAD labels
                speaker_targets = torch.argmax(vad_labels.sum(dim=1), dim=1)  # Most active speaker
                speaker_preds = torch.argmax(speaker_logits, dim=1)
                
                all_speaker_preds.extend(speaker_preds.cpu().numpy())
                all_speaker_targets.extend(speaker_targets.cpu().numpy())
                
            except Exception as e:
                print(f"   Error in a batch: {e}")
                continue
    
    if all_speaker_preds and all_speaker_targets:
        from sklearn.metrics import confusion_matrix, classification_report
        
        # Fix the label set so the matrix stays 4x4 even if a speaker never appears
        speaker_ids = list(range(4))
        
        # Confusion matrix
        cm = confusion_matrix(all_speaker_targets, all_speaker_preds, labels=speaker_ids)
        
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                   xticklabels=[f'Pred {i}' for i in speaker_ids],
                   yticklabels=[f'True {i}' for i in speaker_ids])
        plt.title('Confusion Matrix - Speaker Classification', fontsize=14, fontweight='bold')
        plt.xlabel('Prediction')
        plt.ylabel('Ground Truth')
        
        confusion_path = RESULTS_DIR / 'speaker_confusion_matrix.png'
        plt.savefig(confusion_path, dpi=300, bbox_inches='tight')
        plt.show()
        
        # Classification report
        print("\n📋 Classification report:")
        print(classification_report(all_speaker_targets, all_speaker_preds,
                                  labels=speaker_ids,
                                  target_names=[f'Speaker {i}' for i in speaker_ids],
                                  digits=3, zero_division=0))
        
        print(f"✅ Confusion matrix saved: {confusion_path}")
    else:
        print("⚠️ Not enough data for the confusion matrix")
else:
    print("⚠️ Speaker classifier not enabled")

In [None]:
# === 4. Final summary and comparisons ===
print("\n" + "="*60)
print("📊 FINAL TRAINING SUMMARY")
print("="*60)

# Build a complete summary
final_summary = {
    'Configuration': {
        'Architecture': f"TCN {config['model']['hidden_channels']}",
        'Parameters': f"{sum(p.numel() for p in trainer.model.parameters()):,}",
        'Batch Size': config['training']['batch_size'],
        'Epochs': len(train_losses),
        'Learning Rate': config['optimizer']['lr'],
        'Device': str(trainer.device)
    },
    'Final Results': {
        'Train Loss': f"{train_losses[-1]:.4f}" if train_losses else "N/A",
        'Val Loss': f"{val_losses[-1]:.4f}" if val_losses else "N/A",
        'Best Val Loss': f"{best_val_loss:.4f}",
        'DER': f"{metrics.get('der', 0):.2f}%",
        'F1 Score': f"{metrics.get('f1_score', 0):.3f}",
        'Frame Precision': f"{metrics.get('frame_precision', 0):.3f}",
        'Frame Recall': f"{metrics.get('frame_recall', 0):.3f}"
    },
    'Generated Files': {
        'Best Model': str(best_model_path) if best_model_path.exists() else "Not saved",
        'Metrics': str(metrics_file),
        'Plots': str(training_plot_path),
        'Configuration': str(config_file)
    }
}

# Print the summary
for section, items in final_summary.items():
    print(f"\n🔹 {section}:")
    for key, value in items.items():
        print(f"   {key:.<25} {value}")

# Save the summary
summary_file = RESULTS_DIR / 'training_summary.json'
with open(summary_file, 'w') as f:
    json.dump(final_summary, f, indent=2)

print(f"\n💾 Full summary saved: {summary_file}")

# === 5. Recommendations for improvement ===
print("\n" + "="*60)
print("💡 RECOMMENDATIONS TO IMPROVE PERFORMANCE")
print("="*60)

current_der = metrics.get('der', 100)
current_f1 = metrics.get('f1_score', 0)

recommendations = []

if current_der > 25:
    recommendations.append("🔧 High DER: increase the number of epochs or lower the learning rate")
if current_f1 < 0.7:
    recommendations.append("🔧 Low F1: try focal loss with a higher gamma")
if len(train_losses) < 20:
    recommendations.append("⏰ Short training run: increase the number of epochs")
if config['training']['batch_size'] < 16:
    recommendations.append("📦 Small batch size: increase it if possible to improve stability")

recommendations.extend([
    "📊 Use more AMI data if available",
    "🎯 Tune the hyperparameters with Optuna",
    "🔄 Try model ensembling",
    "📈 Implement cross-validation",
    "🧠 Test different attention architectures"
])

for i, rec in enumerate(recommendations[:7], 1):
    print(f"{i}. {rec}")

print("\n" + "="*60)
print("🎉 TRAINING COMPLETED SUCCESSFULLY!")
print("="*60)
print(f"📁 All results are saved in: {RESULTS_DIR}")
print(f"🧠 Best model available in: {MODEL_DIR}")

if config.get('use_wandb', False):
    print("📊 Detailed logs available on Weights & Biases")

## 🚀 Next Steps

### 📝 To Keep Improving:

1. **📊 Data**: Download the full AMI corpus
2. **⚙️ Hyperparameters**: Optimize with Optuna
3. **🎯 Architecture**: Test different model sizes
4. **📈 Ensembles**: Combine several models
5. **🔄 Post-processing**: Improve the final segmentation
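
For the post-processing step, one common starting point is to threshold the per-frame activity probabilities and drop segments that are too short. The sketch below is one possible approach, not something this notebook implements. The function name, the 0.5 threshold, and the 0.1 s minimum duration are hypothetical; the 0.02 s frame step matches the value used in the plots above.

```python
def frames_to_segments(probs, threshold=0.5, frame_duration=0.02, min_duration=0.1):
    """Turn per-frame activity probabilities into (start, end) segments in seconds.

    Runs of frames at or above the threshold become segments; runs shorter
    than min_duration are discarded. All defaults are illustrative.
    """
    # Compare lengths in whole frames to avoid floating-point edge cases
    min_frames = max(1, round(min_duration / frame_duration))
    segments = []
    start = None
    # A -inf sentinel closes a segment that is still open at the end
    for i, p in enumerate(list(probs) + [float("-inf")]):
        active = p >= threshold
        if active and start is None:
            start = i
        elif not active and start is not None:
            if i - start >= min_frames:
                segments.append((start * frame_duration, i * frame_duration))
            start = None
    return segments


# Toy probabilities for one speaker (hypothetical data)
probs = [0.1, 0.9, 0.8, 0.7, 0.9, 0.9, 0.2, 0.6, 0.1,
         0.8, 0.9, 0.9, 0.9, 0.9, 0.9]
segments = frames_to_segments(probs)
for seg_start, seg_end in segments:
    print(f"{seg_start:.2f}s - {seg_end:.2f}s")
```

The isolated active frame at index 7 is removed because it lasts a single frame. Smoothing the probabilities or merging segments separated by short gaps are natural next refinements.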

### 💾 Generated Files:
- `models/checkpoints/best_model.pth` - Best model
- `results/evaluation_metrics.json` - Detailed metrics
- `results/training_results.png` - Training plots
- `results/training_summary.json` - Full summary

### 🎯 Performance Targets:
- **DER < 20%** on the AMI corpus (state of the art: ~15-18%)
- **F1 > 0.8** for activity detection
- **Real-time** inference

---

**🎉 Congratulations! You have successfully trained a modern diarization system with all the advanced optimizations!**