In [None]:
import sys
from pathlib import Path

# add project root to path
# The root is taken to be two directories above the kernel's working directory.
# NOTE(review): this relies on the notebook being launched from its own
# folder — confirm for other launch locations.
project_root = Path.cwd().parent.parent
_project_root_entry = str(project_root)
# Drop a previous entry before inserting so that re-running this cell keeps
# the root at the front of sys.path without stacking duplicates.
if _project_root_entry in sys.path:
    sys.path.remove(_project_root_entry)
sys.path.insert(0, _project_root_entry)

import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
import seaborn as sns

print(f"project root: {project_root}")

In [None]:
# configuration: every output location hangs off the project-level results folder
RESULTS_DIR = project_root / 'results'
FIGURES_DIR, TABLES_DIR = (RESULTS_DIR / leaf for leaf in ('figures', 'tables'))

# create both output folders up front (no-op when they already exist)
for output_dir in (FIGURES_DIR, TABLES_DIR):
    output_dir.mkdir(parents=True, exist_ok=True)

# figure export settings
DPI = 300
SAVE_FORMAT = 'pdf'  # pdf for vector, png for raster

In [None]:
from src.utils import set_publication_style, FigureGenerator, PALETTES

# call the style hook imported from src.utils
set_publication_style()

# list every palette name together with how many colors it holds
print("available color palettes:")
palette_lines = [
    f"  {palette_name}: {len(palette_colors)} colors"
    for palette_name, palette_colors in PALETTES.items()
]
for palette_line in palette_lines:
    print(palette_line)

## Load Results

In [None]:
# load aggregated results
aggregated_file = RESULTS_DIR / 'aggregated_results.json'
hypothesis_file = RESULTS_DIR / 'hypothesis_results.json'


def _read_json(path):
    """Parse and return the JSON document stored at ``path``."""
    # JSON is UTF-8 by specification; pin the encoding so that loading does
    # not depend on the platform's locale default.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# `aggregated` / `hypothesis` stay None when their file is absent; the cells
# below test them for truthiness before using them.
if aggregated_file.exists():
    aggregated = _read_json(aggregated_file)
    print("loaded aggregated results")
else:
    print("run 02_results_analysis.ipynb first")
    aggregated = None

if hypothesis_file.exists():
    hypothesis = _read_json(hypothesis_file)
    print("loaded hypothesis results")
else:
    hypothesis = None

In [None]:
# create figure generator
# FigureGenerator is imported from src.utils; it receives the figures
# directory as a plain string and is reused by the figure cells below.
fig_gen = FigureGenerator(output_dir=str(FIGURES_DIR))

## Figure 1: Model Architecture and Experimental Overview

Panel A: Wav2Vec2 architecture diagram  
Panel B: Probing classifier setup  
Panel C: Activation patching methodology

In [None]:
def figure_methodology():
    """Create methodology overview figure.

    Draws a 1x3 schematic on a 12x4 inch figure:

    * Panel A: the feature encoder (CNN) box followed by three stacked
      transformer layer groups, connected by arrows.
    * Panel B: layer activations feeding a linear probe whose output is the
      PD/HC label.
    * Panel C: a PD sample and an HC sample combined into a patched sample,
      with the measured quantity written underneath.

    Colors come from ``PALETTES['categorical']``; indices 0-4 are used, so
    that palette must provide at least five entries.

    Returns:
        matplotlib.figure.Figure: the assembled figure. Nothing is saved here.
    """
    fig = plt.figure(figsize=(12, 4))
    gs = GridSpec(1, 3, figure=fig, wspace=0.3)

    colors = PALETTES['categorical']

    def draw_arrow(ax, start, end):
        """Draw a plain black arrow on ``ax`` pointing from ``start`` to ``end``."""
        ax.annotate('', xy=end, xytext=start,
                    arrowprops=dict(arrowstyle='->', color='black', lw=1.5))

    # panel a: wav2vec2 architecture
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.set_xlim(0, 10)
    ax1.set_ylim(0, 10)

    # draw feature encoder
    encoder_bottom = 7
    ax1.add_patch(mpatches.Rectangle((1, encoder_bottom), 8, 2, facecolor=colors[0], alpha=0.7))
    ax1.text(5, 8, 'Feature Encoder\n(CNN)', ha='center', va='center', fontsize=9)

    # draw transformer layers, top to bottom; each block is (bottom y, label)
    block_height = 1.2
    layer_blocks = [(5, 'Layers 1-4'), (3.5, 'Layers 5-8'), (2, 'Layers 9-12')]
    upstream_bottom = encoder_bottom
    for i, (y, label) in enumerate(layer_blocks):
        ax1.add_patch(mpatches.Rectangle((1, y), 8, block_height,
                                         facecolor=colors[1], alpha=0.6 + i*0.1))
        ax1.text(5, y + block_height / 2, label, ha='center', va='center', fontsize=8)
        # connect the bottom edge of the block above to the top edge of this
        # one, so every gap (including the last) carries a full-length arrow
        draw_arrow(ax1, (5, upstream_bottom), (5, y + block_height))
        upstream_bottom = y

    ax1.set_title('A. Wav2Vec2 Architecture', fontweight='bold', loc='left')
    ax1.axis('off')

    # panel b: probing setup
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.set_xlim(0, 10)
    ax2.set_ylim(0, 10)

    # activation extraction
    ax2.add_patch(mpatches.FancyBboxPatch((0.5, 7), 4, 2, boxstyle='round,pad=0.1',
                                          facecolor=colors[2], alpha=0.7))
    ax2.text(2.5, 8, 'Layer\nActivations', ha='center', va='center', fontsize=9)

    # linear probe
    ax2.add_patch(mpatches.FancyBboxPatch((5.5, 7), 4, 2, boxstyle='round,pad=0.1',
                                          facecolor=colors[3], alpha=0.7))
    ax2.text(7.5, 8, 'Linear\nProbe', ha='center', va='center', fontsize=9)

    # activations -> probe
    draw_arrow(ax2, (4.7, 8), (5.3, 8))

    # output
    ax2.add_patch(mpatches.Circle((5, 3), 1.5, facecolor=colors[0], alpha=0.7))
    ax2.text(5, 3, 'PD/HC', ha='center', va='center', fontsize=10, fontweight='bold')

    # probe -> predicted label
    draw_arrow(ax2, (7.5, 6.8), (5, 4.7))

    ax2.set_title('B. Probing Classifier', fontweight='bold', loc='left')
    ax2.axis('off')

    # panel c: patching
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.set_xlim(0, 10)
    ax3.set_ylim(0, 10)

    # source (pd)
    ax3.add_patch(mpatches.Rectangle((0.5, 6.5), 2.5, 3, facecolor=colors[1], alpha=0.7))
    ax3.text(1.75, 8, 'PD\nSample', ha='center', va='center', fontsize=9)

    # target (hc)
    ax3.add_patch(mpatches.Rectangle((3.5, 6.5), 2.5, 3, facecolor=colors[0], alpha=0.7))
    ax3.text(4.75, 8, 'HC\nSample', ha='center', va='center', fontsize=9)

    # patched
    ax3.add_patch(mpatches.Rectangle((7, 6.5), 2.5, 3, facecolor=colors[4], alpha=0.7))
    ax3.text(8.25, 8, 'Patched', ha='center', va='center', fontsize=9)

    # arrow from the sample boxes to the patched box, labelled with the swap
    draw_arrow(ax3, (6.2, 8), (6.8, 8))
    ax3.text(6.5, 8.7, 'Swap\nLayer L', ha='center', va='center', fontsize=8)

    # measure
    ax3.text(5, 3, 'Measure:', ha='center', va='center', fontsize=10)
    ax3.text(5, 2, r'$\Delta$logit recovery', ha='center', va='center', fontsize=10)

    ax3.set_title('C. Activation Patching', fontweight='bold', loc='left')
    ax3.axis('off')

    plt.tight_layout()
    return fig

# build the methodology figure, write it to the figures folder, then display it
methodology_filename = f'figure_methodology.{SAVE_FORMAT}'
fig = figure_methodology()
fig.savefig(FIGURES_DIR / methodology_filename, dpi=DPI, bbox_inches='tight')
plt.show()
print(f"saved: {methodology_filename}")

## Figure 2: Probing and Patching Results

Panel A: Layer-wise probing accuracy  
Panel B: Layer-wise patching recovery  
Panel C: Probing vs patching correlation

In [None]:
# overview figure: only drawn when aggregated results were loaded
if not aggregated:
    print("run 02_results_analysis.ipynb first to generate aggregated results")
else:
    # raw per-layer sections, keyed by layer number as a string
    probing = aggregated.get('probing', {})
    patching = aggregated.get('patching', {})

    # keep numeric keys only and re-key by integer layer; either the
    # '*_score' or the short stat names are accepted, with 0 as the fallback
    probing_dict = {
        int(layer_key): {
            'mean': stats.get('mean_score', stats.get('mean', 0)),
            'std': stats.get('std_score', stats.get('std', 0)),
        }
        for layer_key, stats in probing.items()
        if layer_key.isdigit()
    }

    patching_dict = {
        int(layer_key): {
            'mean_recovery': stats.get('mean_recovery', 0),
            'std_recovery': stats.get('std_recovery', 0),
        }
        for layer_key, stats in patching.items()
        if layer_key.isdigit()
    }

    if not (probing_dict and patching_dict):
        print("no probing/patching data available")
    else:
        # NOTE(review): model_accuracy is a hard-coded constant here —
        # confirm it matches the evaluated model.
        fig = fig_gen.figure_1_overview(
            probing_dict,
            patching_dict,
            model_accuracy=0.85
        )
        plt.show()

## Figure 3: Clinical Feature Encoding

In [None]:
# clinical encoding figure: real results when present, synthetic demo otherwise
has_clinical = bool(aggregated) and 'clinical' in aggregated

if not has_clinical:
    print("creating demo clinical encoding figure...")

    # synthetic inputs: five feature names across layers 1..12
    features = ['jitter', 'shimmer', 'hnr', 'f0_mean', 'f0_std']
    layers = list(range(1, 13))
    early_peaking = ('jitter', 'shimmer')

    # fixed seed so the demo figure is reproducible
    np.random.seed(42)
    demo_data = {}
    for feat in features:
        # jitter/shimmer peak at layer 3, every other feature at layer 7
        peak = 3 if feat in early_peaking else 7

        per_layer = {}
        for layer in layers:
            # exponential fall-off with distance from the peak layer
            base = 0.5 * np.exp(-0.3 * abs(layer - peak))
            # draw order matters for reproducibility: mean noise first, then std
            mean_noise = np.random.uniform(-0.05, 0.05)
            spread = np.random.uniform(0.02, 0.08)
            per_layer[layer] = {'mean': base + mean_noise, 'std': spread}
        demo_data[feat] = per_layer

    fig = fig_gen.figure_2_clinical_encoding(demo_data)
    plt.show()
    print("(demo data - replace with real results)")
else:
    clinical = aggregated['clinical']
    fig = fig_gen.figure_2_clinical_encoding(clinical)
    plt.show()

## Figure 4: Hypothesis Testing Summary

In [None]:
# hypothesis summary figure: fall back to hand-written demo values when no
# hypothesis results were loaded
using_demo_hypothesis = not hypothesis

if using_demo_hypothesis:
    print("creating demo hypothesis summary...")

    demo_hypothesis = {
        'hypothesis_1': dict(
            supported=True,
            phonatory_early_fraction=0.75,
            prosodic_middle_fraction=0.80,
        ),
        'hypothesis_2': dict(
            supported=True,
            probing_patching_correlation=dict(spearman_r=0.85, spearman_p=0.0003),
        ),
        'hypothesis_3': dict(
            supported=True,
            cross_dataset_mean=0.78,
            generalization_gap=0.08,
        ),
    }
    summary_input = demo_hypothesis
else:
    summary_input = hypothesis

fig = fig_gen.figure_3_hypothesis_summary(summary_input)
plt.show()

if using_demo_hypothesis:
    print("(demo data - replace with real results)")

## Figure 5: Cross-Dataset Generalization Matrix

In [None]:
# cross-dataset matrix: real results when present, otherwise a 3x3 demo
has_cross_dataset = bool(aggregated) and 'cross_dataset' in aggregated

if not has_cross_dataset:
    print("creating demo cross-dataset matrix...")

    datasets = ['Italian', 'mDVR-KCL', 'UCI Oxford']

    # one row of demo scores per outer dataset, columns in `datasets` order
    demo_scores = [
        [0.92, 0.78, 0.72],
        [0.76, 0.89, 0.74],
        [0.71, 0.73, 0.88],
    ]
    demo_cross = {
        outer_name: dict(zip(datasets, score_row))
        for outer_name, score_row in zip(datasets, demo_scores)
    }

    fig = fig_gen.figure_4_cross_dataset(demo_cross)
    plt.show()
    print("(demo data - replace with real results)")
else:
    cross = aggregated['cross_dataset']
    fig = fig_gen.figure_4_cross_dataset(cross)
    plt.show()

## Generate All Figures

In [None]:
# batch generate all figures; needs both result files to have been loaded
if not (aggregated and hypothesis):
    print("run 02_results_analysis.ipynb first")
else:
    # pull each section out of the aggregated results, defaulting to {}
    # NOTE(review): these sections are passed through as loaded (layer keys
    # not converted to int) — confirm generate_all_figures expects that.
    probing, patching, clinical, cross = (
        aggregated.get(section, {})
        for section in ('probing', 'patching', 'clinical', 'cross_dataset')
    )

    saved = fig_gen.generate_all_figures(
        probing_results=probing,
        patching_results=patching,
        clinical_results=clinical,
        hypothesis_results=hypothesis,
        cross_dataset_results=cross
    )

    print(f"\ngenerated {len(saved)} figures:")
    for figure_path in saved:
        print(f"  - {figure_path}")

## Supplementary Tables

In [None]:
def create_results_table(probing_data, patching_data, save_path=None):
    """Create summary results table.

    Builds one row per layer found in either input, sorted by layer.

    Args:
        probing_data: Mapping of layer -> stats dict. The mean is read from
            'mean' (falling back to 'mean_score') and the spread from 'std'
            (falling back to 'std_score'), so both the converted and the raw
            aggregated layouts are accepted.
        patching_data: Mapping of layer -> stats dict with 'mean_recovery'
            and 'std_recovery'.
        save_path: Optional path. When given, the table is also written
            there as LaTeX (UTF-8 encoded, since cells contain '±').

    Returns:
        pandas.DataFrame with columns 'Layer', 'Probing Acc' and
        'Patching Recovery'. A cell is '-' when the layer has no entry (or no
        mean) in the corresponding input; the columns are present even when
        both inputs are empty.
    """
    def _first_present(stats, *keys):
        """Return the first non-None value among ``keys`` in ``stats``, else None."""
        for key in keys:
            value = stats.get(key)
            if value is not None:
                return value
        return None

    def _format_cell(mean, std):
        """Render 'mean ± std' with three decimals; '-' when no mean is available."""
        if mean is None:
            return '-'
        if std is None:
            return f"{mean:.3f}"
        return f"{mean:.3f} ± {std:.3f}"

    columns = ['Layer', 'Probing Acc', 'Patching Recovery']
    rows = []

    for layer in sorted(set(probing_data) | set(patching_data)):
        # a layer missing from one input yields an empty stats dict -> '-'
        probe_stats = probing_data.get(layer, {})
        patch_stats = patching_data.get(layer, {})
        rows.append({
            'Layer': layer,
            'Probing Acc': _format_cell(
                _first_present(probe_stats, 'mean', 'mean_score'),
                _first_present(probe_stats, 'std', 'std_score'),
            ),
            'Patching Recovery': _format_cell(
                _first_present(patch_stats, 'mean_recovery'),
                _first_present(patch_stats, 'std_recovery'),
            ),
        })

    # explicit columns keep the table shape stable when there are no rows
    df = pd.DataFrame(rows, columns=columns)

    if save_path:
        # save as latex
        latex = df.to_latex(index=False, caption='Layer-wise probing and patching results',
                            label='tab:layer_results')
        # '±' is non-ASCII: pin the encoding instead of using the locale default
        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(latex)
        print(f"saved: {save_path}")

    return df

# example: build the layer table from the aggregated results and save it as LaTeX
if aggregated:
    # keep only entries whose key is a layer number and re-key them by int;
    # distinct names avoid rebinding the `probing` / `patching` names that
    # earlier cells use for the raw sections
    probing_by_layer = {int(k): v for k, v in aggregated.get('probing', {}).items() if k.isdigit()}
    patching_by_layer = {int(k): v for k, v in aggregated.get('patching', {}).items() if k.isdigit()}

    if probing_by_layer or patching_by_layer:
        df = create_results_table(probing_by_layer, patching_by_layer, TABLES_DIR / 'layer_results.tex')
        display(df)
    else:
        # say so explicitly instead of finishing silently with no table
        print("no probing/patching data available")
else:
    print("no aggregated results")

In [None]:
# closing status: report completion and where each kind of output lives
print("figure generation complete!")
for output_kind, output_location in (("figures", FIGURES_DIR), ("tables", TABLES_DIR)):
    print(f"{output_kind} saved to: {output_location}")