
# 3. Visualization and Interpretation

This notebook reads the outputs from the training notebook and shows:
- global concept importance
- fake vs real concept summaries
- per-user important concepts (test speakers)
- plots / heatmaps

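This notebook's code is self-contained (it does not import the `.py` script), but it requires the files exported by the training notebook to already exist in `ANALYSIS_DIR`.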


In [None]:

from pathlib import Path
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Locations used by every later cell: the exported analysis artifacts are read
# from ANALYSIS_DIR and the figures produced by this notebook go to PLOTS_DIR.
PROJECT_ROOT = Path('/home/SpeakerRec/BioVoice')
ANALYSIS_DIR = PROJECT_ROOT.joinpath(
    'data', 'tcav', 'logreg_concept_analysis', 'stage4_spoofwrapper_pospct'
)
PLOTS_DIR = ANALYSIS_DIR.joinpath('plots')

# Create the plot directory (and any missing parents); no error if it already exists.
PLOTS_DIR.mkdir(exist_ok=True, parents=True)

print('ANALYSIS_DIR =', ANALYSIS_DIR)


In [None]:

# Load exported analysis files
# Short key -> file expected inside ANALYSIS_DIR. The keys are only used within
# this cell to look the paths up again below.
paths = {
    key: ANALYSIS_DIR / filename
    for key, filename in (
        ('meta', 'run_metadata.json'),
        ('coef', 'global_concept_coefficients.csv'),
        ('class_summary', 'classwise_concept_summary.csv'),
        ('preds', 'test_predictions.csv'),
        ('user_contrib', 'test_user_mean_contributions.csv'),
        ('top_user', 'test_user_top_concepts.csv'),
    )
}

# Fail on the first missing artifact, naming both the key and the full path.
for k, p in paths.items():
    assert p.exists(), f'Missing file for {k}: {p}'

# Run metadata is JSON; everything else is a CSV table.
with paths['meta'].open(encoding='utf-8') as meta_file:
    run_meta = json.load(meta_file)
coef_df = pd.read_csv(paths['coef'])
class_summary_df = pd.read_csv(paths['class_summary'])
pred_df = pd.read_csv(paths['preds'])
user_contrib = pd.read_csv(paths['user_contrib'])
top_user_df = pd.read_csv(paths['top_user'])

print('Loaded files successfully')

speaker_split = run_meta['speaker_split']
print('Train speakers:', speaker_split['train_speakers'])
print('Test speakers :', speaker_split['test_speakers'])

print('Metrics:')
metrics = run_meta['metrics']
# The classification report is printed on its own after the JSON dump of the
# remaining metrics (presumably it is a preformatted text block -- confirm).
other_metrics = {}
for metric_name, metric_value in metrics.items():
    if metric_name != 'classification_report':
        other_metrics[metric_name] = metric_value
print(json.dumps(other_metrics, indent=2))
print(metrics['classification_report'])


In [None]:

# Tables: global importance + class-wise differences
print('Top global coefficients (absolute):')
# NOTE(review): this shows the first 20 rows in stored order; presumably the CSV
# is already ordered by |coefficient| -- confirm against the training notebook.
display(coef_df.head(20))

# The same difference column ranked both ways: largest fake-minus-real gap first,
# then the most negative gap first. Each table shows 15 rows.
for heading, ascending in (
    ('Top concepts higher in fake (mean difference fake-real):', False),
    ('Top concepts higher in real (mean difference fake-real most negative):', True),
):
    print(heading)
    ranked = class_summary_df.sort_values('mean_diff_fake_minus_real', ascending=ascending)
    display(ranked.head(15))


In [None]:

# Table: per-user important concepts on test speakers
# list_type = top_fake_supporting / top_real_supporting
#
# Speaker ids are compared as strings on both sides. The loop values come from
# the string-cast column, so filtering on the raw column would match nothing
# whenever read_csv inferred a numeric dtype for 'speaker_id'.
# NOTE(review): 'true label' contains a space, unlike the other column names used
# in this notebook -- confirm it matches the CSV header.
speaker_ids = top_user_df['speaker_id'].astype(str)
for spk in sorted(speaker_ids.unique().tolist()):
    print('\n===', spk, '===')
    display(top_user_df[speaker_ids == spk].sort_values(['true label', 'list_type', 'rank']))


In [None]:

# Plot 1: Global coefficient bar chart (top 20 by absolute value)
# Rank by |coefficient| explicitly so the "top 20" does not depend on the row
# order of the CSV. The sort is stable, so a file that is already ordered by
# absolute value keeps exactly its stored order (NaN coefficients sort last).
abs_rank = np.argsort(-coef_df['coefficient'].abs().to_numpy(), kind='stable')
# Reverse so the largest bar ends up at the top of the horizontal bar chart.
plot_df = coef_df.iloc[abs_rank].head(20).iloc[::-1]

# Red for 'fake', blue for 'real', grey for any other direction value.
direction_colors = {'fake': '#b2182b', 'real': '#2166ac'}
colors = [direction_colors.get(d, '#666666') for d in plot_df['direction']]

# Figure height grows with the number of bars, with a minimum of 6 inches.
fig, ax = plt.subplots(figsize=(10, max(6, 0.35 * len(plot_df))))
ax.barh(plot_df['feature'], plot_df['coefficient'], color=colors)
ax.axvline(0, color='black', linewidth=1)
ax.set_title('Global Logistic Coefficients (positive=fake, negative=real)')
ax.set_xlabel('Coefficient')
ax.set_ylabel('Feature')
fig.tight_layout()

global_plot_path = PLOTS_DIR / 'global_coefficients_top20.png'
fig.savefig(global_plot_path, dpi=150)
plt.show()
print('Saved:', global_plot_path)


In [None]:

# Plot 2: Per-user heatmap (fake samples only, mean contributions)
# NOTE(review): 'true label' contains a space, unlike the other column names used
# in this notebook -- confirm it matches the CSV header.
fake_user = user_contrib[user_contrib['true label'] == 1].copy()
if fake_user.empty:
    print('No fake samples in test user contributions.')
else:
    # Compare speaker ids as strings on both sides: the CSV dtype is inferred by
    # read_csv while the metadata ids come from JSON, so the raw types can differ
    # and a direct membership test would then silently drop every speaker.
    fake_user['speaker_id'] = fake_user['speaker_id'].astype(str)
    heat = fake_user.pivot_table(index='speaker_id', columns='feature', values='mean_contribution', aggfunc='mean')
    test_speakers = run_meta['speaker_split']['test_speakers']
    # Keep the row order of the metadata's test-speaker list; ids without fake rows are skipped.
    speaker_order = [str(s) for s in test_speakers]
    heat = heat.loc[[s for s in speaker_order if s in heat.index]]
    # Speaker/feature pairs with no rows are shown as zero contribution.
    heat = heat.fillna(0)

    if heat.empty:
        # Nothing to draw; avoid saving a blank figure.
        print('No test speakers from run metadata found in test user contributions.')
    else:
        values = heat.to_numpy(dtype=float)
        # Centre the diverging colormap on zero so the neutral colour means
        # "no contribution" and the two hues correspond to the two signs.
        abs_max = float(np.abs(values).max())
        if not np.isfinite(abs_max) or abs_max == 0:
            abs_max = 1.0  # fall back to a fixed range when the data gives no usable scale

        # Width grows with the number of features, height with the number of speakers.
        fig, ax = plt.subplots(figsize=(max(10, 0.45 * heat.shape[1]), max(4, 0.6 * heat.shape[0])))
        im = ax.imshow(values, aspect='auto', cmap='coolwarm', vmin=-abs_max, vmax=abs_max)
        fig.colorbar(im, ax=ax, label='Mean contribution')
        ax.set_xticks(range(heat.shape[1]))
        ax.set_xticklabels(heat.columns, rotation=90, fontsize=8)
        ax.set_yticks(range(heat.shape[0]))
        ax.set_yticklabels(heat.index, fontsize=9)
        ax.set_xlabel('Feature')
        ax.set_ylabel('Speaker')
        ax.set_title('Per-user Mean Concept Contributions (Fake Samples Only)')
        fig.tight_layout()

        heatmap_path = PLOTS_DIR / 'user_heatmap_fake_samples.png'
        fig.savefig(heatmap_path, dpi=150)
        plt.show()
        print('Saved:', heatmap_path)


In [None]:

# Optional: summary of which concepts recur most often in top-k lists across users
# Count the rows of the per-user top-concept table for every (list_type, concept) pair.
# NOTE(review): this groups on 'concept' while the other tables use 'feature' --
# confirm the column name in test_user_top_concepts.csv.
pair_counts = top_user_df.groupby(['list_type', 'concept']).size()

# Within each list_type: most frequent first, ties broken by concept name.
freq = pair_counts.reset_index(name='count').sort_values(
    by=['list_type', 'count', 'concept'],
    ascending=[True, False, True],
)
display(freq)

freq_csv_path = ANALYSIS_DIR / 'top_concept_frequency_across_users.csv'
freq.to_csv(freq_csv_path, index=False)
print('Saved:', freq_csv_path)
