In [6]:
import sys
import numpy as np
from frechet_audio_distance import FrechetAudioDistance
from pathlib import Path
import pandas as pd
import warnings
import shutil
import tempfile
warnings.filterwarnings('ignore')

scripts_dir = Path(r'C:\Users\sahan\OneDrive\Documents\Python\MUSHRA').parent
sys.path.insert(0, str(scripts_dir))

from batch_bss_eval import extract_model_name

In [7]:
def collect_files_by_model(references_dir, estimated_dir):
    
    references_dir = Path(references_dir)
    estimated_dir = Path(estimated_dir)
    
    ref_files = sorted(references_dir.glob('*.wav'))
    
    model_files = {}
    
    print(f"Found {len(ref_files)} reference files")
    print("\nCollecting estimated files by model...")
    
    for ref_file in ref_files:
        ref_stem = ref_file.stem
        estimates_folder = estimated_dir/ref_stem
        
        if not estimates_folder.exists():
            print(f' No folder found for {ref_stem}')
            continue
        
        estimated_files = sorted(estimates_folder.glob('*.wav'))
        
        for est_file in estimated_files:
            model_name = extract_model_name(est_file)
            
            if model_name not in model_files:
                model_files[model_name] = []
            
            model_files[model_name].append(est_file)
    
    print("\nFiles collected:")
    print(f"  References: {len(ref_files)} files")
    for model, files in model_files.items():
        print(f"  {model}: {len(files)} files")
    
    return ref_files, model_files


In [None]:
def calculate_fad_scores(references_dir, estimated_dir, sample_rate=16000, 
                        model_name="vggish"):
    """
    Calculate FAD scores for each model
    
    Args:
        references_dir: Path to reference audio files
        estimated_dir: Path to estimated sources
        sample_rate: Sample rate for FAD (default 16kHz)
        model_name: Embedding model ("vggish", "pann", or "clap")
    
    Returns:
        DataFrame with FAD scores for each model
    """
    
    print("="*60)
    print("Fréchet Audio Distance (FAD) Calculation")
    print("="*60)
    print(f"Using embedding model: {model_name}")
    print(f"Sample rate: {sample_rate} Hz\n")
    
    # Collect files organized by model
    ref_files, model_files = collect_files_by_model(references_dir, estimated_dir)
    
    if not ref_files:
        print("❌ No reference files found!")
        return pd.DataFrame()
    
    if not model_files:
        print("❌ No estimated files found!")
        return pd.DataFrame()
    
    # Initialize FAD calculator
    print(f"\nInitializing FAD with {model_name} embeddings...")
    frechet = FrechetAudioDistance(
        model_name=model_name,
        sample_rate=sample_rate,
        #use_pca=False,
        #use_activation=False,
        verbose=True
    )
    
    # Calculate FAD for each model
    results = []
    
    print("\nCalculating FAD scores...\n")
    
    with tempfile.TemporaryDirectory() as ref_temp_dir:
        print('Preparing reference files...')
        for ref_file in ref_files:
            shutil.copy2(ref_file, ref_temp_dir)
            
    
        for model, est_files in sorted(model_files.items()):
            print(f"Processing {model}...")
            print(f"  Comparing {len(ref_files)} references vs {len(est_files)} estimates")
            
            try:
                
                with tempfile.TemporaryDirectory() as est_temp_dir:
                    
                    for est_file in est_files:
                        shutil.copy2(est_file, est_temp_dir)
                        
                
                    # Calculate FAD score
                    fad_score, test = frechet.score(
                        ref_temp_dir,
                        est_temp_dir,
                        #dtype="float32"
                    )
                    
                    results.append({
                        'model': model,
                        'FAD': fad_score,
                        'n_references': len(ref_files),
                        'n_estimates': len(est_files)
                    })
                    
                    print(f"  ✅ FAD Score: {fad_score:.4f}\n")
                    
            except Exception as e:
                print(f"  ❌ Error: {e}\n")
                results.append({
                    'model': model,
                    'FAD': None,
                    'n_references': len(ref_files),
                    'n_estimates': len(est_files)
                })
        
        # Create DataFrame
        results_df = pd.DataFrame(results)
        
        return results_df

In [17]:
"""Main execution"""
upper_dir = 'C:\\Users\\sahan\\OneDrive\\Documents\\Python\\MUSHRA\\'
# Configuration
references_dir = upper_dir + "references"
estimated_dir = upper_dir + "estimated_sources"
output_csv = upper_dir + "results/fad_scores.csv"

# FAD parameters
sample_rate = 16000  # Standard for FAD
embedding_model = "pann"  # Options: "vggish", "pann", "clap"

# Calculate FAD scores
results_df = calculate_fad_scores(
    references_dir=references_dir,
    estimated_dir=estimated_dir,
    sample_rate=sample_rate,
    model_name=embedding_model
)

if results_df.empty:
    print("No results to save!")
    

else:

    # Create results directory if needed
    output_path = Path(output_csv)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Save results
    results_df.to_csv(output_csv, index=False)
    print(f"\n✅ Results saved to {output_csv}")

    # Print summary
    print("\n" + "="*60)
    print("FAD Scores Summary (Lower is Better)")
    print("="*60)

    # Sort by FAD score
    sorted_df = results_df.sort_values('FAD')

    for _, row in sorted_df.iterrows():
        if row['FAD'] is not None:
            print(f"{row['model']:15s}: {row['FAD']:8.4f}")
        else:
            print(f"{row['model']:15s}: {'ERROR':>8s}")

    print("\nInterpretation:")
    print("  FAD < 5    : Excellent quality")
    print("  FAD 5-15   : Good quality")
    print("  FAD 15-30  : Moderate quality")
    print("  FAD > 30   : Poor quality")

Fréchet Audio Distance (FAD) Calculation
Using embedding model: pann
Sample rate: 16000 Hz

Found 6 reference files

Collecting estimated files by model...

Files collected:
  References: 6 files
  anchor: 6 files
  dv2: 6 files
  htdemucs: 6 files
  spleeter: 6 files

Initializing FAD with pann embeddings...
[Frechet Audio Distance] Using device: cpu

Calculating FAD scores...

Preparing reference files...
Processing anchor...
  Comparing 6 references vs 6 estimates


  0%|          | 0/6 [00:00<?, ?it/s]

[Frechet Audio Distance] Loading audio from C:\Users\sahan\AppData\Local\Temp\tmpa21gd5uz...


100%|██████████| 6/6 [00:00<00:00,  8.40it/s]
 17%|█▋        | 1/6 [00:00<00:01,  3.04it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 33%|███▎      | 2/6 [00:00<00:01,  3.78it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 50%|█████     | 3/6 [00:00<00:01,  2.93it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 67%|██████▋   | 4/6 [00:01<00:00,  2.88it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 83%|████████▎ | 5/6 [00:01<00:00,  3.05it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


100%|██████████| 6/6 [00:02<00:00,  2.99it/s]


[Frechet Audio Distance] Embedding shape: torch.Size([2048])


  0%|          | 0/6 [00:00<?, ?it/s]

[Frechet Audio Distance] Loading audio from C:\Users\sahan\AppData\Local\Temp\tmpm_duoi3a...


100%|██████████| 6/6 [00:01<00:00,  5.14it/s]
 17%|█▋        | 1/6 [00:00<00:01,  3.41it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 33%|███▎      | 2/6 [00:00<00:01,  3.24it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 50%|█████     | 3/6 [00:00<00:01,  2.90it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 67%|██████▋   | 4/6 [00:01<00:00,  2.64it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 83%|████████▎ | 5/6 [00:01<00:00,  2.36it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


100%|██████████| 6/6 [00:02<00:00,  2.28it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])
[Frechet Audio Distance] An error occurred: not enough values to unpack (expected 2, got 1)
  ✅ FAD Score: -1.0000

Processing dv2...
  Comparing 6 references vs 6 estimates



  0%|          | 0/6 [00:00<?, ?it/s]

[Frechet Audio Distance] Loading audio from C:\Users\sahan\AppData\Local\Temp\tmpa21gd5uz...


100%|██████████| 6/6 [00:00<00:00,  8.16it/s]
 33%|███▎      | 2/6 [00:00<00:00,  5.13it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])
[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 50%|█████     | 3/6 [00:00<00:00,  3.98it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 67%|██████▋   | 4/6 [00:01<00:00,  3.80it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 83%|████████▎ | 5/6 [00:01<00:00,  3.97it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


100%|██████████| 6/6 [00:01<00:00,  3.88it/s]


[Frechet Audio Distance] Embedding shape: torch.Size([2048])


  0%|          | 0/6 [00:00<?, ?it/s]

[Frechet Audio Distance] Loading audio from C:\Users\sahan\AppData\Local\Temp\tmpww35oj3y...


100%|██████████| 6/6 [00:01<00:00,  4.11it/s]
 33%|███▎      | 2/6 [00:00<00:00,  5.17it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])
[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 50%|█████     | 3/6 [00:00<00:00,  3.71it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 67%|██████▋   | 4/6 [00:01<00:00,  3.50it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 83%|████████▎ | 5/6 [00:01<00:00,  3.67it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


100%|██████████| 6/6 [00:01<00:00,  3.65it/s]


[Frechet Audio Distance] Embedding shape: torch.Size([2048])
[Frechet Audio Distance] An error occurred: not enough values to unpack (expected 2, got 1)
  ✅ FAD Score: -1.0000

Processing htdemucs...
  Comparing 6 references vs 6 estimates


  0%|          | 0/6 [00:00<?, ?it/s]

[Frechet Audio Distance] Loading audio from C:\Users\sahan\AppData\Local\Temp\tmpa21gd5uz...


100%|██████████| 6/6 [00:00<00:00,  7.22it/s]
 17%|█▋        | 1/6 [00:00<00:01,  2.53it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 33%|███▎      | 2/6 [00:00<00:01,  2.87it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 50%|█████     | 3/6 [00:01<00:01,  2.23it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 67%|██████▋   | 4/6 [00:01<00:00,  2.44it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 83%|████████▎ | 5/6 [00:01<00:00,  2.89it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


100%|██████████| 6/6 [00:02<00:00,  2.72it/s]


[Frechet Audio Distance] Embedding shape: torch.Size([2048])


  0%|          | 0/6 [00:00<?, ?it/s]

[Frechet Audio Distance] Loading audio from C:\Users\sahan\AppData\Local\Temp\tmpic3xet3g...


100%|██████████| 6/6 [00:01<00:00,  4.15it/s]
 33%|███▎      | 2/6 [00:00<00:00,  4.23it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])
[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 50%|█████     | 3/6 [00:00<00:00,  3.30it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 67%|██████▋   | 4/6 [00:01<00:00,  3.07it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 83%|████████▎ | 5/6 [00:01<00:00,  3.18it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


100%|██████████| 6/6 [00:01<00:00,  3.15it/s]


[Frechet Audio Distance] Embedding shape: torch.Size([2048])
[Frechet Audio Distance] An error occurred: not enough values to unpack (expected 2, got 1)
  ✅ FAD Score: -1.0000

Processing spleeter...
  Comparing 6 references vs 6 estimates


  0%|          | 0/6 [00:00<?, ?it/s]

[Frechet Audio Distance] Loading audio from C:\Users\sahan\AppData\Local\Temp\tmpa21gd5uz...


100%|██████████| 6/6 [00:00<00:00, 12.65it/s]
 17%|█▋        | 1/6 [00:00<00:04,  1.16it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 33%|███▎      | 2/6 [00:01<00:02,  1.56it/s]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 50%|█████     | 3/6 [00:14<00:19,  6.54s/it]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 67%|██████▋   | 4/6 [00:22<00:13,  7.00s/it]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 83%|████████▎ | 5/6 [00:23<00:04,  4.84s/it]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


100%|██████████| 6/6 [00:24<00:00,  4.15s/it]


[Frechet Audio Distance] Embedding shape: torch.Size([2048])


  0%|          | 0/6 [00:00<?, ?it/s]

[Frechet Audio Distance] Loading audio from C:\Users\sahan\AppData\Local\Temp\tmppp9u015q...


100%|██████████| 6/6 [00:03<00:00,  1.88it/s]
 17%|█▋        | 1/6 [00:01<00:07,  1.41s/it]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 33%|███▎      | 2/6 [00:02<00:04,  1.18s/it]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 50%|█████     | 3/6 [00:04<00:04,  1.60s/it]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 67%|██████▋   | 4/6 [00:10<00:06,  3.47s/it]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


 83%|████████▎ | 5/6 [00:12<00:02,  2.80s/it]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])


100%|██████████| 6/6 [00:22<00:00,  3.80s/it]

[Frechet Audio Distance] Embedding shape: torch.Size([2048])
[Frechet Audio Distance] An error occurred: not enough values to unpack (expected 2, got 1)
  ✅ FAD Score: -1.0000


✅ Results saved to C:\Users\sahan\OneDrive\Documents\Python\MUSHRA\results/fad_scores.csv

FAD Scores Summary (Lower is Better)
anchor         :  -1.0000
dv2            :  -1.0000
htdemucs       :  -1.0000
spleeter       :  -1.0000

Interpretation:
  FAD < 5    : Excellent quality
  FAD 5-15   : Good quality
  FAD 15-30  : Moderate quality
  FAD > 30   : Poor quality



