# DAC Embeddings with Concatenation Pooling

This notebook uses **concatenation** instead of **averaging** for pooling DAC embeddings.

## Strategy:
- Extract codebook embeddings: `[batch, n_codebooks=12, time, dim=8]`
- **Mean pool across time**: `[batch, 12, 8]`
- **Concatenate codebooks**: `[batch, 96]` (12 × 8 = 96D)

**Hypothesis**: Concatenation keeps each codebook's features separate, so it preserves information that averaging discards and should improve clustering.

In [None]:
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, davies_bouldin_score
from tqdm import tqdm

# Import our utilities
from dac_utils import DACProcessor, SpeechCommandsLoader

print("Imports successful!")

## Step 1: Initialize DAC Model

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using device: {device}")

# DAC model wrapper from the local dac_utils module (16 kHz variant).
dac_processor = DACProcessor(model_type="16khz", device=device)

## Step 2: Custom Extraction Function with Concatenation

In [None]:
def extract_concatenated_embedding(dac_processor, audio_path):
    """
    Extract a fixed-size DAC embedding using the concatenation strategy.

    Codebook embeddings are mean-pooled over time, then the codebooks are
    concatenated in codebook-major order (all dims of codebook 0, then all
    dims of codebook 1, ...) instead of being averaged together.

    Args:
        dac_processor: object exposing ``encode_audio(path)`` (returning a dict
            with a ``'codes'`` tensor) and ``get_codebook_embeddings(codes)``
            (returning a torch tensor).
        audio_path: path to the audio file to encode.

    Returns:
        1-D numpy array of length ``n_codebooks * codebook_dim``
        (12 codebooks × 8D = 96D for the 16 kHz model used in this notebook).

    Raises:
        ValueError: if the codebook embeddings are not a single-item 4-D batch
            ``[1, n_codebooks, T, dim]`` or have no time frames (mean pooling
            an empty time axis would otherwise silently yield NaNs).
    """
    # Feature extraction only: no autograd graph is needed.
    with torch.no_grad():
        encoded = dac_processor.encode_audio(audio_path)
        codes = encoded['codes']  # expected [1, n_codebooks, T]

        # Per-codebook embeddings, expected [1, n_codebooks, T, dim]
        embeddings = dac_processor.get_codebook_embeddings(codes)

    if embeddings.dim() != 4 or embeddings.shape[0] != 1 or embeddings.shape[2] == 0:
        raise ValueError(
            "Expected codebook embeddings of shape [1, n_codebooks, T>0, dim], "
            f"got {tuple(embeddings.shape)} for {audio_path}"
        )

    # Mean pool across the time dimension -> [1, n_codebooks, dim]
    time_pooled = embeddings.mean(dim=2)

    # Concatenate all codebooks into one flat vector -> [n_codebooks * dim]
    concatenated = time_pooled.reshape(-1)

    return concatenated.detach().cpu().numpy()


print("Custom extraction function defined!")

## Step 3: Test on Single Sample

In [None]:
# Smoke-test the extractor on a single 'zero' recording before the full run.
loader = SpeechCommandsLoader()
file_paths, labels = loader.load_word_samples(['zero'], samples_per_word=1)

if len(file_paths) == 0:
    print("No audio files found!")
else:
    test_file = file_paths[0]
    print(f"Testing with: {test_file}\n")

    embedding = extract_concatenated_embedding(dac_processor, test_file)

    print(f"Embedding shape: {embedding.shape}  # Should be (96,)")
    print(f"\nFirst 10 values: {embedding[:10]}")
    print(f"\nStatistics:")
    # Summary statistics, printed in the same order and format as one line each.
    for stat_name, stat_fn in (("Min", np.min), ("Max", np.max), ("Mean", np.mean), ("Std", np.std)):
        print(f"  {stat_name}: {stat_fn(embedding):.4f}")

## Step 4: Load Dataset - 5 Words, 10 Samples Each

In [None]:
# Five target words with ten recordings each drive every plot below.
words = ['zero', 'one', 'two', 'yes', 'no']
samples_per_word = 10

# Collect the audio file paths together with their word labels.
file_paths, file_labels = loader.load_word_samples(words, samples_per_word=samples_per_word)

print(f"Total samples: {len(file_paths)}")
print(f"\nLabel distribution:")
label_counts = {word: file_labels.count(word) for word in words}
for word, count in label_counts.items():
    print(f"  {word}: {count} samples")

## Step 5: Extract Concatenated Embeddings for All Samples

In [None]:
# Extract one concatenated embedding per file. Files that fail are reported
# and skipped so a single bad recording does not abort the whole run.
embeddings_list = []
valid_labels = []
failed_files = []

for file_path, label in tqdm(zip(file_paths, file_labels), total=len(file_paths), desc="Extracting embeddings"):
    try:
        embedding = extract_concatenated_embedding(dac_processor, file_path)
    except Exception as e:
        # tqdm.write keeps the progress bar intact, unlike print().
        tqdm.write(f"Error processing {file_path}: {e}")
        failed_files.append(file_path)
        continue
    embeddings_list.append(embedding)
    valid_labels.append(label)

if not embeddings_list:
    # Without this, np.array([]) has shape (0,) and shape[1] fails cryptically.
    raise RuntimeError(
        f"No embeddings could be extracted from {len(file_paths)} files; "
        "see the errors above."
    )

# np.stack requires equal-length vectors, so inconsistent extraction fails
# loudly instead of producing a ragged object array.
embeddings = np.stack(embeddings_list)
print(f"\nFinal embeddings shape: {embeddings.shape}  # Expected ({len(file_paths)}, 96)")
if failed_files:
    print(f"Skipped {len(failed_files)} file(s) that could not be processed.")
print(f"Embedding dimension: {embeddings.shape[1]}D (12 codebooks × 8D)")

## Step 6: PCA Visualizations

In [None]:
# PCA - 2D projection of the concatenated embeddings
pca_2d = PCA(n_components=2, random_state=42)
pca_2d_result = pca_2d.fit_transform(embeddings)
variance_2d = pca_2d.explained_variance_ratio_.sum()

print(f"PCA 2D variance explained: {variance_2d:.2%}")

# One color per word, reused by every plot in this notebook. The palette is
# cycled, so a `words` list longer than the palette repeats colors instead of
# raising IndexError.
palette = px.colors.qualitative.Plotly
color_map = {word: palette[i % len(palette)] for i, word in enumerate(words)}

# Label array for vectorized per-word boolean masks.
label_array = np.asarray(valid_labels)

# 2D Plot: one trace per word so the legend can toggle words individually.
fig_2d = go.Figure()

for word in words:
    mask = label_array == word
    fig_2d.add_trace(go.Scatter(
        x=pca_2d_result[mask, 0],
        y=pca_2d_result[mask, 1],
        mode='markers',
        name=word,
        marker=dict(size=12, color=color_map[word], opacity=0.7, line=dict(width=0.5, color='white'))
    ))

fig_2d.update_layout(
    title=f'PCA 2D: DAC Concatenated Embeddings (96D) - Variance: {variance_2d:.1%}',
    xaxis_title='PC 1',
    yaxis_title='PC 2',
    width=900,
    height=700,
    xaxis=dict(scaleanchor='y', scaleratio=1),  # equal scale on both PCs
    template='plotly_white'
)

fig_2d.write_html('dac_concat_pca_2d.html')
fig_2d.show()

print("Saved: dac_concat_pca_2d.html")

In [None]:
# PCA - 3D: project the concatenated embeddings onto their top three principal components
pca_3d = PCA(n_components=3, random_state=42)
pca_3d_result = pca_3d.fit_transform(embeddings)
variance_3d = pca_3d.explained_variance_ratio_.sum()
print(f"PCA 3D variance explained: {variance_3d:.2%}")

# Build one 3D scatter trace per word, colored consistently with the 2D plot.
traces_3d = []
for word in words:
    word_mask = np.array([lbl == word for lbl in valid_labels])
    points = pca_3d_result[word_mask]
    traces_3d.append(go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='markers',
        name=word,
        marker=dict(size=6, color=color_map[word], opacity=0.7, line=dict(width=0)),
    ))
fig_3d = go.Figure(data=traces_3d)

scene_3d = dict(
    xaxis_title='PC 1',
    yaxis_title='PC 2',
    zaxis_title='PC 3',
    aspectmode='cube',
    camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
)
fig_3d.update_layout(
    title=f'PCA 3D: DAC Concatenated Embeddings (96D) - Variance: {variance_3d:.1%}',
    scene=scene_3d,
    width=900,
    height=700,
    template='plotly_white',
)

fig_3d.write_html('dac_concat_pca_3d.html')
fig_3d.show()
print("Saved: dac_concat_pca_3d.html")

## Step 7: t-SNE Visualizations

In [None]:
# t-SNE - 2D
# sklearn requires perplexity < n_samples, so the usual 30 is capped at n - 1.
n_points = len(embeddings)
perplexity = min(30, n_points - 1)
tsne_2d = TSNE(n_components=2, random_state=42, metric='cosine', perplexity=perplexity)
tsne_2d_result = tsne_2d.fit_transform(embeddings)
print(f"t-SNE 2D completed with perplexity={perplexity}")

# One trace per word, styled like the PCA 2D plot.
fig_tsne_2d = go.Figure()
for word in words:
    selected = np.array([lbl == word for lbl in valid_labels])
    coords = tsne_2d_result[selected]
    fig_tsne_2d.add_trace(go.Scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        mode='markers',
        name=word,
        marker=dict(size=12, color=color_map[word], opacity=0.7,
                    line=dict(width=0.5, color='white')),
    ))

tsne_2d_layout = dict(
    title=f't-SNE 2D: DAC Concatenated Embeddings (96D) - Perplexity={perplexity}',
    xaxis_title='t-SNE 1',
    yaxis_title='t-SNE 2',
    width=900,
    height=700,
    xaxis=dict(scaleanchor='y', scaleratio=1),
    template='plotly_white',
)
fig_tsne_2d.update_layout(**tsne_2d_layout)

fig_tsne_2d.write_html('dac_concat_tsne_2d.html')
fig_tsne_2d.show()
print("Saved: dac_concat_tsne_2d.html")

In [None]:
# t-SNE - 3D, reusing the perplexity chosen for the 2D run
tsne_3d = TSNE(n_components=3, random_state=42, metric='cosine', perplexity=perplexity)
tsne_3d_result = tsne_3d.fit_transform(embeddings)
print(f"t-SNE 3D completed with perplexity={perplexity}")

# One 3D trace per word, styled like the PCA 3D plot.
tsne_3d_traces = []
for word in words:
    selected = np.array([lbl == word for lbl in valid_labels])
    coords = tsne_3d_result[selected]
    tsne_3d_traces.append(go.Scatter3d(
        x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
        mode='markers',
        name=word,
        marker=dict(size=6, color=color_map[word], opacity=0.7, line=dict(width=0)),
    ))
fig_tsne_3d = go.Figure(data=tsne_3d_traces)

fig_tsne_3d.update_layout(
    title=f't-SNE 3D: DAC Concatenated Embeddings (96D) - Perplexity={perplexity}',
    scene=dict(
        xaxis_title='t-SNE 1',
        yaxis_title='t-SNE 2',
        zaxis_title='t-SNE 3',
        aspectmode='cube',
        camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
    ),
    width=900,
    height=700,
    template='plotly_white',
)

fig_tsne_3d.write_html('dac_concat_tsne_3d.html')
fig_tsne_3d.show()
print("Saved: dac_concat_tsne_3d.html")

## Step 8: Combined Dashboard View

In [None]:
# 2x2 dashboard: 2D projections in the left column, 3D projections on the right.
fig = make_subplots(
    rows=2, cols=2,
    specs=[
        [{'type': 'scatter'}, {'type': 'scatter3d'}],
        [{'type': 'scatter'}, {'type': 'scatter3d'}]
    ],
    subplot_titles=(
        f'PCA 2D (Var: {variance_2d:.1%})',
        f'PCA 3D (Var: {variance_3d:.1%})',
        f't-SNE 2D (Perp: {perplexity})',
        f't-SNE 3D (Perp: {perplexity})'
    ),
    vertical_spacing=0.12,
    horizontal_spacing=0.1
)

label_array = np.asarray(valid_labels)

# (projected coordinates, row, col) for each panel, in the order traces are added.
panels = [
    (pca_2d_result, 1, 1),
    (pca_3d_result, 1, 2),
    (tsne_2d_result, 2, 1),
    (tsne_3d_result, 2, 2),
]

for panel_idx, (coords, row, col) in enumerate(panels):
    is_3d = coords.shape[1] == 3
    for word in words:
        mask = label_array == word
        trace_kwargs = dict(
            mode='markers',
            name=word,
            marker=dict(size=5 if is_3d else 8, color=color_map[word], opacity=0.7),
            # Only the first panel adds legend entries; legendgroup ties each
            # word's traces together across all four panels.
            showlegend=(panel_idx == 0),
            legendgroup=word,
        )
        if is_3d:
            trace = go.Scatter3d(x=coords[mask, 0], y=coords[mask, 1], z=coords[mask, 2], **trace_kwargs)
        else:
            trace = go.Scatter(x=coords[mask, 0], y=coords[mask, 1], **trace_kwargs)
        fig.add_trace(trace, row=row, col=col)

fig.update_layout(
    title_text='DAC Concatenated Embeddings (96D) - Complete Visualization',
    title_x=0.5,
    title_font_size=20,
    width=1400,
    height=1000,
    showlegend=True,
    template='plotly_white'
)

# make_subplots numbers cartesian axes only for 2D ('xy') cells; the 3D cells
# become scene/scene2. The PCA 2D panel therefore owns xaxis/yaxis and the
# t-SNE 2D panel owns xaxis2/yaxis2 -- 'y2' (not 'y3') is its y-axis.
fig.update_xaxes(title_text='PC 1', row=1, col=1, scaleanchor='y', scaleratio=1)
fig.update_yaxes(title_text='PC 2', row=1, col=1)
fig.update_xaxes(title_text='t-SNE 1', row=2, col=1, scaleanchor='y2', scaleratio=1)
fig.update_yaxes(title_text='t-SNE 2', row=2, col=1)

fig.update_scenes(aspectmode='cube', row=1, col=2)
fig.update_scenes(aspectmode='cube', row=2, col=2)

fig.write_html('dac_concat_complete.html')
fig.show()

print("Saved: dac_concat_complete.html")

## Step 9: Clustering Quality Metrics

In [None]:
# Map word labels to integer class ids for the sklearn clustering metrics.
label_to_idx = {word: i for i, word in enumerate(words)}
numeric_labels = np.array([label_to_idx[label] for label in valid_labels])

# Reference scores from the earlier mean-pooling-across-codebooks (8D) run.
# NOTE: they come from a separate run, so the comparison is only meaningful
# if that run used the same words and samples.
BASELINE_SIL_8D = -0.0731
BASELINE_DB_8D = 4.99


def report_clustering_scores(title, points, show_ranges=False):
    """Print and return (silhouette, Davies-Bouldin) of `points` under the true word labels."""
    sil = silhouette_score(points, numeric_labels)
    db = davies_bouldin_score(points, numeric_labels)
    print(f"\n{title}:")
    if show_ranges:
        print(f"  Silhouette Score: {sil:.4f}  (higher is better, range: -1 to 1)")
        print(f"  Davies-Bouldin Score: {db:.4f}  (lower is better, >0)")
    else:
        print(f"  Silhouette Score: {sil:.4f}")
        print(f"  Davies-Bouldin Score: {db:.4f}")
    return sil, db


print("=" * 60)
print("CLUSTERING QUALITY METRICS - CONCATENATION (96D)")
print("=" * 60)

n_classes = len(np.unique(numeric_labels))
# Both metrics require 2 <= n_classes <= n_samples - 1. The check is done once
# and applies to every representation, since they share the same labels.
scores_available = 2 <= n_classes <= len(numeric_labels) - 1

if scores_available:
    sil_orig, db_orig = report_clustering_scores("Original Embeddings (96D)", embeddings, show_ranges=True)
    sil_pca2d, db_pca2d = report_clustering_scores("PCA 2D", pca_2d_result)
    sil_pca3d, db_pca3d = report_clustering_scores("PCA 3D", pca_3d_result)
    # t-SNE distorts global distances, so its scores are indicative only.
    sil_tsne2d, db_tsne2d = report_clustering_scores("t-SNE 2D", tsne_2d_result)
    sil_tsne3d, db_tsne3d = report_clustering_scores("t-SNE 3D", tsne_3d_result)
else:
    print(
        f"\nSkipping metrics: need 2 <= classes <= samples - 1, "
        f"got {n_classes} classes for {len(numeric_labels)} samples."
    )

print("\n" + "=" * 60)
print("\n📊 COMPARISON WITH AVERAGING (8D):")
print("Previous results (mean pooling across codebooks):")
print(f"  Original (8D): Silhouette = {BASELINE_SIL_8D:.4f}, Davies-Bouldin = {BASELINE_DB_8D:.2f}")
if scores_available:
    # Report the measured outcome rather than asserting an expected improvement.
    sil_better = sil_orig > BASELINE_SIL_8D
    db_better = db_orig < BASELINE_DB_8D
    print("\nConcatenation (96D), measured above:")
    print(f"  Original (96D): Silhouette = {sil_orig:.4f}, Davies-Bouldin = {db_orig:.2f}")
    print(f"  {'✅' if sil_better else '❌'} Silhouette "
          f"{'improved' if sil_better else 'did not improve'} ({sil_orig - BASELINE_SIL_8D:+.4f})")
    print(f"  {'✅' if db_better else '❌'} Davies-Bouldin "
          f"{'improved' if db_better else 'did not improve'} ({db_orig - BASELINE_DB_8D:+.2f})")
print("=" * 60)

## Summary

This notebook demonstrated DAC embeddings with **concatenation pooling**:

### **Strategy**:
- Extract codebook embeddings: `[batch, 12, time, 8]`
- Mean pool time: `[batch, 12, 8]`
- **Concatenate codebooks**: `[batch, 96]` instead of averaging to `[batch, 8]`

### **Key Points**:
- **Dimensionality**: 96D vs 8D (12× larger)
- **Information preservation**: Each codebook's features are kept separate
- **Expected result (hypothesis)**: Better clustering quality than simple averaging — confirm against the metrics printed in Step 9

### **Comparison**:
Compare the clustering metrics above with those from the original (averaging) notebook:
- If the Silhouette Score increased (less negative, or positive): ✅ better separation
- If the Davies-Bouldin Score decreased: ✅ better clustering

### **Next Steps**:
1. Try using the **latent representation `z`** (1024D) for even better results
2. Compare all three approaches: 8D (averaging) vs 96D (concat) vs 1024D (latent)
3. Test with more diverse words or more samples per word