We use PWCCA (projection-weighted canonical correlation analysis) to find the most representative layer of the fine-tuned DistilHuBERT model. \
https://github.com/ankitapasad/layerwise-analysis \
MFA env: mfa_env

In [18]:
# Presumably frees cached GPU memory (inferred from the helper's name — confirm).
# NOTE(review): clear_gpu_cache is not defined in any visible cell; it likely
# comes from the `from DHuBERT_utils import *` star import, which lives in a
# later cell. On a fresh kernel, run the imports cell before this one.
clear_gpu_cache()

In [15]:
from transformers import HubertForCTC, AutoProcessor
from DHuBERT_utils import *
from datasets import load_from_disk
import soundfile as sf
import os

In [4]:
# Load the fine-tuned DistilHuBERT CTC model and its processor, then run a
# dummy forward pass to see how many hidden states the model exposes.

# Single source of truth for the checkpoint location (previously repeated twice).
# NOTE(review): absolute, machine-specific path — consider moving it to a config cell.
MODEL_DIR = "/scratch/pippalin2/jupyter/GMM-DistilHuBERT/checkpoints_distilhubert_asr/final_model"

# Use the GPU when one is available; otherwise fall back to CPU so the cell
# still runs instead of failing on the hard-coded 'cuda' device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = HubertForCTC.from_pretrained(MODEL_DIR).to(DEVICE)
processor = AutoProcessor.from_pretrained(MODEL_DIR)

# Dummy input: a batch of one waveform with 16000 random samples
waveform = torch.randn(1, 16000).to(DEVICE)

# Forward pass with hidden states (no gradients needed for inspection)
with torch.no_grad():
    outputs = model(input_values=waveform, output_hidden_states=True)

# Hidden states: one tensor per layer plus the input embeddings
hidden_states = outputs.hidden_states  # each of shape (batch_size, time_steps, hidden_dim)

print(f"# of layers (incl. input): {len(hidden_states)}")
print(f"Shape of one layer: {hidden_states[1].shape}")  # skip index 0 if you want encoder layers only

# of layers (incl. input): 3
Shape of one layer: torch.Size([1, 49, 768])


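DistilHuBERT has 7 CNN layers and 2 transformer layers, so `output_hidden_states=True` returns 3 hidden states (the input embeddings plus one per transformer layer), matching the output above. We apply PWCCA to the transformer-layer representations.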

### 1. Extract Hidden Representation

In [5]:
# Work on a subset of the data: the first N_SAMPLES rows (or every row when
# the dataset is smaller than that).
# NOTE(review): `data` is not defined in any visible cell — presumably it is
# loaded with `load_from_disk`; confirm it exists before this cell runs.
N_SAMPLES = 500
# Clamp to the dataset size so selecting from a shorter dataset does not
# request out-of-range row indices.
small_data = data.select(range(min(N_SAMPLES, len(data))))

In [6]:
def _segment_pwcca(reps, segment_key, label_key):
    """Score one label type (phone or word) for a single layer.

    For every item in ``reps`` the item's "layer_output" and the value stored
    under ``segment_key`` are passed to ``pool_segment_features``; the pooled
    results are concatenated, as are the labels stored under ``label_key``.
    The labels go through ``prepare_onehot_labels`` and the returned value is
    whatever ``compute_pwcca_similarity`` gives for (features, one-hot labels).
    """
    pooled_features = np.concatenate(
        [pool_segment_features(item["layer_output"], item[segment_key]) for item in reps]
    )
    stacked_labels = np.concatenate([item[label_key] for item in reps])
    return compute_pwcca_similarity(pooled_features, prepare_onehot_labels(stacked_labels))


# One score per layer index, filled in layer order.
cca_phone_scores, cca_word_scores = [], []

# NOTE(review): the saved output of this cell ends in KeyError: 'phone_segments',
# so the items being read apparently lack that key — confirm the dataset
# carries the phone/word segment and label fields before running.
for layer in range(3):
    layer_reps = extract_layer_representations(model, processor, small_data, layer)

    # CCA-phone first, then CCA-word
    cca_phone_scores.append(_segment_pwcca(layer_reps, "phone_segments", "phone_labels"))
    cca_word_scores.append(_segment_pwcca(layer_reps, "word_segments", "word_labels"))




Extracting layer 0 representations:   0%|          | 0/500 [00:00<?, ? examples/s]

KeyError: 'phone_segments'

In [None]:
# Plot the layer-wise PWCCA scores for the phone and word labels.
# The x values are derived from the score lists instead of a hard-coded
# range(4): the scoring loop iterates over range(3), so a fixed 4-point x-axis
# would not match the 3 collected scores and matplotlib would raise on the
# mismatched x/y lengths.
phone_layers = range(len(cca_phone_scores))
word_layers = range(len(cca_word_scores))

plt.plot(phone_layers, cca_phone_scores, label='CCA-phone')
plt.plot(word_layers, cca_word_scores, label='CCA-word')
# Layer indices are integers, so show one tick per layer.
plt.xticks(list(range(max(len(cca_phone_scores), len(cca_word_scores)))))
plt.xlabel("Layer")
plt.ylabel("PWCCA Similarity")
plt.legend()
plt.title("Layer-wise PWCCA Scores (DistilHuBERT)")
plt.show()
