In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive so the files under MyDrive are reachable.
# force_remount=True re-initiates the authorization process.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Peek at the header and the first row of the original (pipe-separated) CSV.
!cat /content/drive/MyDrive/SSL_models/emotions/updated.csv | head -n 2

filename|wavpath|transcription
Calm-1|/home/self/workspace/professional/Talks-classes/IIIDh/datasets/emotions/calm/wavs/Calm-1.wav|I bet that is really neat!


In [None]:
# Replace the original local dataset root in every line with the mounted Drive
# folder and write the result to updated_gc.csv (the source CSV is left unchanged).
!sed 's|/home/self/workspace/professional/Talks-classes/IIIDh/datasets/emotions/|/content/drive/MyDrive/SSL_models/emotions/|g' /content/drive/MyDrive/SSL_models/emotions/updated.csv > /content/drive/MyDrive/SSL_models/emotions/updated_gc.csv

In [None]:
# Check the rewritten CSV: the wavpath column should now point into /content/drive.
!cat /content/drive/MyDrive/SSL_models/emotions/updated_gc.csv | head -n 2

filename|wavpath|transcription
Calm-1|/content/drive/MyDrive/SSL_models/emotions/calm/wavs/Calm-1.wav|I bet that is really neat!


In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import librosa
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

from transformers import Wav2Vec2Model, HubertModel, WavLMModel
from sentence_transformers import SentenceTransformer
from sklearn.cross_decomposition import CCA

# ------------------------------
# 1. Helper classes & functions
# ------------------------------

class SSLModelExtractor:
    """
    Thin wrapper around a Hugging Face self-supervised speech encoder
    (wav2vec2 / HuBERT / WavLM). Raw waveforms are fed to the model directly
    (no Processor), and per-layer, time-averaged embeddings are exposed.
    """
    def __init__(self, model_name, device='cpu'):
        """
        Load `model_name`, move it to `device` and put it in eval mode.

        The checkpoint family is picked from the name, e.g.:
         - "facebook/wav2vec2-base"
         - "facebook/hubert-base-ls960"
         - "microsoft/wavlm-base-plus"
        """
        self.device = device
        self.model = self.load_model(model_name).to(device)
        self.model.eval()

    def load_model(self, model_name):
        """
        Instantiate the model class whose family tag ("wav2vec2", "hubert",
        "wavlm") occurs in `model_name` (case-insensitive), with hidden-state
        output enabled. Raises ValueError when no tag matches.
        """
        lowered = model_name.lower()
        if "wav2vec2" in lowered:
            return Wav2Vec2Model.from_pretrained(model_name, output_hidden_states=True)
        if "hubert" in lowered:
            return HubertModel.from_pretrained(model_name, output_hidden_states=True)
        if "wavlm" in lowered:
            return WavLMModel.from_pretrained(model_name, output_hidden_states=True)
        raise ValueError(f"Unsupported model name: {model_name}")

    def process_audio(self, audio, sampling_rate=16000):
        """
        Wrap a 1D waveform as the model's input dict.

        The samples become a float32 tensor of shape (1, time) on
        `self.device`. Nothing is normalized or padded, and `sampling_rate`
        is accepted but not used here.
        """
        waveform = torch.tensor(audio, dtype=torch.float32)
        batched = waveform.unsqueeze(0).to(self.device)
        return {"input_values": batched}

    def get_cnn_layer_embeddings(self, inputs):
        """
        Feed the waveform through `model.feature_extractor.conv_layers` one
        layer at a time and mean-pool every intermediate output over time.

        Returns a list of numpy vectors, one per conv layer, each of shape
        (channels,).
        """
        # Conv layers consume (batch, channel, time): add the channel axis.
        signal = inputs["input_values"].unsqueeze(1)
        pooled_per_layer = []
        with torch.no_grad():  # inference only, no gradient tracking
            for conv_block in self.model.feature_extractor.conv_layers:
                signal = conv_block(signal)  # (1, channels, new_time)
                pooled = signal.mean(dim=-1).squeeze(dim=0)  # (channels,)
                pooled_per_layer.append(pooled.cpu().numpy())
        return pooled_per_layer

    def extract_transformer_embeddings_all_layers(self, inputs):
        """
        Run a gradient-free forward pass and return `outputs.hidden_states`.

        The rest of this class treats hidden_states[0] as the feature-encoder
        output and hidden_states[1:] as the transformer layer outputs.
        """
        with torch.no_grad():
            return self.model(**inputs).hidden_states

    def get_mean_pooled_layer_embeddings(self, hidden_states, skip_feature_encoder=True):
        """
        Average every hidden state over its time axis (dim 1) and return the
        results as a list of numpy vectors, one per layer.

        With skip_feature_encoder=True the first entry of `hidden_states` is
        dropped before pooling.
        """
        selected = hidden_states[1:] if skip_feature_encoder else hidden_states
        # Each state is (1, time, hidden_dim); pooling leaves (hidden_dim,).
        return [state.mean(dim=1).squeeze(dim=0).cpu().numpy() for state in selected]

def extract_mel_features(wav_path, sample_rate=16000, n_mels=80,
                         frame_length=0.025, frame_stride=0.010):
    """
    Compute a time-averaged log-Mel vector of shape (n_mels,) for one .wav file.

    The audio is loaded at `sample_rate`; `frame_length` and `frame_stride`
    are in seconds and set the analysis window and hop. A missing file, or any
    exception while processing, is reported with print() and yields an
    all-zero vector of length `n_mels` instead of raising.
    """
    try:
        if os.path.exists(wav_path):
            waveform, rate = librosa.load(wav_path, sr=sample_rate)
            # Convert the window/hop durations from seconds to samples.
            window = int(rate * frame_length)
            hop = int(rate * frame_stride)

            power_mel = librosa.feature.melspectrogram(
                y=waveform,
                sr=rate,
                n_fft=window,
                hop_length=hop,
                win_length=window,
                n_mels=n_mels,
            )
            # Average the dB-scaled spectrogram over the frame axis.
            return np.mean(librosa.power_to_db(power_mel), axis=1)

        # Missing file: warn and fall back to zeros.
        print(f"Warning: File not found - {wav_path}")
        return np.zeros(n_mels)
    except Exception as e:
        print(f"Error processing file {wav_path}: {e}")
        return np.zeros(n_mels)  # same zero fallback for any processing error

def batch_sentence_embeddings(sentences, model, batch_size=1, device='cuda'):
    """
    Encode `sentences` with `model.encode` (a SentenceTransformer-style API)
    in batches of `batch_size` on `device`.

    The tensor returned by the model is detached, moved to the CPU and
    returned as a numpy array, one row per sentence.
    """
    encoded = model.encode(
        sentences,
        batch_size=batch_size,
        convert_to_tensor=True,
        device=device,
    )
    on_cpu = encoded.detach().cpu()
    return on_cpu.numpy()

def run_cca(X, Y, n_components=1):
    """
    Fit CCA on the paired matrices X and Y (both (num_samples, num_features))
    and project the same data onto the canonical components.

    Only the first component is scored: the return value is the Pearson
    correlation between the first canonical variate of X and that of Y.
    """
    model = CCA(n_components=n_components)
    model.fit(X, Y)
    x_scores, y_scores = model.transform(X, Y)

    # Only the first canonical pair is reported, whatever n_components is.
    first_x, first_y = x_scores[:, 0], y_scores[:, 0]
    return np.corrcoef(first_x, first_y)[0, 1]

def compute_cca_across_layers(features, model_key, feature_type='mel', include_cnn=True):
    """
    Layer-wise CCA between one model's embeddings and a reference feature.

    features[model_key] holds, per utterance, a list with one embedding per
    layer. For every layer index the embeddings of all utterances are stacked
    into a (num_utts, hidden_dim) matrix and scored against the reference
    matrix with run_cca.

    feature_type selects the reference: 'mel' -> features['mel_features'],
    'sentence' -> features['sentence_embedding']; anything else raises
    ValueError.

    include_cnn is accepted but not used by this function: every layer stored
    under model_key is scored.

    Returns a list of correlation coefficients, one per layer.
    """
    if feature_type not in ('mel', 'sentence'):
        raise ValueError("feature_type must be 'mel' or 'sentence'")
    reference_key = 'mel_features' if feature_type == 'mel' else 'sentence_embedding'
    ref_matrix = np.vstack(features[reference_key])  # (num_utts, ref_dim)

    per_utterance = features[model_key]
    # The first utterance defines how many layers are scored.
    layer_count = len(per_utterance[0])

    scores = []
    for layer_idx in range(layer_count):
        # Same layer from every utterance -> (num_utts, hidden_dim)
        stacked = np.vstack([utt_layers[layer_idx] for utt_layers in per_utterance])
        scores.append(run_cca(ref_matrix, stacked))
    return scores

def plot_cca_scores(cca_scores_dict, title, layer_prefix):
    """
    Plot layer-wise CCA correlation curves for multiple models in one figure.

    cca_scores_dict: {model_name: [corr_layer1, corr_layer2, ...], ...}
    title: figure title.
    layer_prefix: tick-label prefix, e.g. 'C' for CNN layers or 'T' for
        Transformer layers (ticks read "<prefix>1", "<prefix>2", ...).

    The x ticks span the longest curve, so no model's layers are left
    unlabeled when curves differ in length, and an empty dict produces an
    empty figure instead of failing on undefined loop variables.
    """
    plt.figure(figsize=(10, 6))

    for model_name, scores in cca_scores_dict.items():
        plt.plot(range(len(scores)), scores, marker='o', label=model_name)

    # Derive the ticks from all curves rather than from whichever model the
    # loop happened to visit last.
    num_ticks = max((len(scores) for scores in cca_scores_dict.values()), default=0)
    tick_labels = [f"{layer_prefix}{i+1}" for i in range(num_ticks)]

    plt.xticks(range(num_ticks), tick_labels, rotation=45)
    plt.xlabel('Layer')
    plt.ylabel('CCA similarity')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_combined_cca_scores(cnn_scores_dict, trans_scores_dict, title):
    """
    Plot CNN and Transformer layer-wise CCA curves for every model in one figure.

    cnn_scores_dict: {model_name: [corr_cnn_layer1, ...], ...}
    trans_scores_dict: {model_name: [corr_trans_layer1, ...], ...}
        Must contain every model_name present in cnn_scores_dict.
    title: figure title.

    For each model the CNN scores are followed by the Transformer scores on a
    shared x axis. Tick labels read C1..Cn, T1..Tm, where n and m are the
    largest CNN / Transformer layer counts over all models, so the function
    works for any set of model names.
    """
    plt.figure(figsize=(14, 7))

    num_cnn = 0
    num_trans = 0
    for model_name, cnn_scores in cnn_scores_dict.items():
        # list() makes "+" concatenate even when scores arrive as arrays
        # (array + array would add element-wise or fail on a length mismatch).
        cnn_scores = list(cnn_scores)
        trans_scores = list(trans_scores_dict[model_name])
        combined_scores = cnn_scores + trans_scores

        num_cnn = max(num_cnn, len(cnn_scores))
        num_trans = max(num_trans, len(trans_scores))

        plt.plot(range(len(combined_scores)), combined_scores,
                 marker='o', linestyle='-', label=model_name)

    # Tick labels: C1, C2, ..., T1, T2, ... (the ticks assume all models share
    # the same CNN depth; with differing depths the C/T boundary is only exact
    # for the deepest CNN stack).
    all_labels = (
        [f'C{i+1}' for i in range(num_cnn)] +
        [f'T{i+1}' for i in range(num_trans)]
    )
    plt.xticks(range(len(all_labels)), all_labels, rotation=45)

    plt.xlabel('Layer (CNN + Transformer)')
    plt.ylabel('CCA Similarity')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# ------------------------------
# 2. Main pipeline (both CNN and transformer layers)
# ------------------------------
def _extract_ssl_layer_embeddings(extractor, audio, sr):
    """
    Run one SSLModelExtractor on a waveform and return
    (cnn_layers, transformer_layers): two lists of mean-pooled per-layer
    embeddings. The first hidden state is skipped for the transformer list.
    """
    inputs = extractor.process_audio(audio, sr)
    cnn_layers = extractor.get_cnn_layer_embeddings(inputs)
    hidden_states = extractor.extract_transformer_embeddings_all_layers(inputs)
    transformer_layers = extractor.get_mean_pooled_layer_embeddings(
        hidden_states, skip_feature_encoder=True)
    return cnn_layers, transformer_layers


def _cca_scores_per_model(features, model_labels, layer_kind, feature_type):
    """
    Compute layer-wise CCA scores for every model against one reference feature.
    Returns {plot_label: [score per layer]} in the order of `model_labels`.
    """
    return {
        label: compute_cca_across_layers(features, f"{prefix}_{layer_kind}",
                                         feature_type=feature_type)
        for prefix, label in model_labels.items()
    }


def main(csv_file="/content/drive/MyDrive/SSL_models/emotions/updated_gc.csv", device=None):
    """
    End-to-end pipeline: extract features for every utterance listed in
    `csv_file`, then compare each SSL model layer with two references via CCA.

    csv_file: pipe-separated CSV with columns filename|wavpath|transcription.
    device: torch device string; defaults to 'cuda' when available, else 'cpu'.

    Steps:
      1. Read the CSV and load the three SSL models and the sentence encoder.
      2. For each utterance extract log-Mel features, the sentence embedding,
         and mean-pooled CNN / transformer layer embeddings of each SSL model.
         Utterances whose audio cannot be processed are skipped entirely so
         every feature list stays aligned and no placeholder vectors enter
         the CCA.
      3. Pickle the features to features_full.pkl (before the analysis, so a
         failure in CCA or plotting does not discard the extraction work).
      4. Run and plot layer-wise CCA against the Mel features and against the
         sentence embeddings.

    Returns a dict of {model_label: [score per layer]} dicts under the keys
    'cnn' and 'transformer' (Mel reference) and 'cnn_sentence' and
    'transformer_sentence' (sentence-embedding reference).

    Raises ValueError if no utterance could be processed.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Using device:", device)

    # 1) Read CSV (with columns: filename|wavpath|transcription)
    df = pd.read_csv(csv_file, sep='|')
    print(f"Loaded {len(df)} utterances from {csv_file}")

    # Feature-key prefix -> checkpoint, and feature-key prefix -> plot label.
    checkpoints = {
        'wav2vec2': "facebook/wav2vec2-base",
        'hubert': "facebook/hubert-base-ls960",
        'wavlm': "microsoft/wavlm-base",
    }
    model_labels = {'wav2vec2': 'Wav2Vec2', 'hubert': 'HuBERT', 'wavlm': 'WavLM'}

    # 2) Feature store: one aligned list per key, one entry per kept utterance.
    features = {
        'filename': [],
        'mel_features': [],          # common acoustic reference
        'sentence_embedding': [],    # common textual reference
    }
    for prefix in checkpoints:
        features[f'{prefix}_cnn_layers'] = []
    for prefix in checkpoints:
        features[f'{prefix}_transformer_layers'] = []

    # 3) Instantiate SSL model extractors once (to avoid reloading per utterance)
    print("Loading SSL models...")
    extractors = {
        prefix: SSLModelExtractor(checkpoint, device=device)
        for prefix, checkpoint in checkpoints.items()
    }

    # Missing transcriptions are read as NaN; encode them as empty strings.
    transcriptions = df['transcription'].fillna('').astype(str).tolist()

    print("Loading sentence embedding model")
    bert_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    # Compute all sentence embeddings at once
    sentence_embs = batch_sentence_embeddings(transcriptions,
                                              bert_model,
                                              batch_size=16,
                                              device=device)

    # 4) Process each utterance
    print("Processing audio files...")
    skipped = []
    rows = zip(df['filename'], df['wavpath'])
    # `pos` is the row position, which is what indexes sentence_embs
    # regardless of the DataFrame's index labels.
    for pos, (filename, wav_path) in enumerate(tqdm(rows, total=len(df), desc="Utterances")):
        try:
            audio, sr = librosa.load(wav_path, sr=16000)
            per_model = {
                prefix: _extract_ssl_layer_embeddings(extractor, audio, sr)
                for prefix, extractor in extractors.items()
            }
        except Exception as e:
            # Skip the utterance instead of storing placeholder vectors: zero
            # vectors of a guessed size break the later np.vstack when the
            # real layer widths differ, and would bias the CCA otherwise.
            print(f"Error processing {filename}: {e}")
            skipped.append(filename)
            continue

        # Append to every list only after all extractions succeeded.
        features['filename'].append(filename)
        features['mel_features'].append(extract_mel_features(wav_path))
        features['sentence_embedding'].append(sentence_embs[pos])
        for prefix, (cnn_layers, transformer_layers) in per_model.items():
            features[f'{prefix}_cnn_layers'].append(cnn_layers)
            features[f'{prefix}_transformer_layers'].append(transformer_layers)

    if skipped:
        print(f"Skipped {len(skipped)} of {len(df)} utterances due to errors")
    if not features['filename']:
        raise ValueError(f"No utterances could be processed from {csv_file}")

    # 5) Save the features before the analysis so they survive a later failure
    with open("features_full.pkl", "wb") as f:
        pickle.dump(features, f)
    print("Saved full features to features_full.pkl")

    ##################################################################################
    # Mel filterbank - CCA
    ##################################################################################
    print("Computing CCA for CNN layers...")
    cca_scores_cnn = _cca_scores_per_model(features, model_labels, 'cnn_layers', 'mel')

    print("Computing CCA for Transformer layers...")
    cca_scores_trans = _cca_scores_per_model(features, model_labels, 'transformer_layers', 'mel')

    plot_combined_cca_scores(cnn_scores_dict=cca_scores_cnn,
                             trans_scores_dict=cca_scores_trans,
                             title="CCA: Mel Filterbank vs. All Layers")

    ##################################################################################
    # Sentence embeddings - CCA
    ##################################################################################
    print("Computing CCA for CNN layers...")
    cca_scores_cnn_s = _cca_scores_per_model(features, model_labels, 'cnn_layers', 'sentence')

    print("Computing CCA for Transformer layers...")
    cca_scores_trans_s = _cca_scores_per_model(features, model_labels, 'transformer_layers', 'sentence')

    plot_combined_cca_scores(cnn_scores_dict=cca_scores_cnn_s,
                             trans_scores_dict=cca_scores_trans_s,
                             title="CCA: Sentence embeddings vs. All Layers")

    # Return the computed CCA scores for further analysis if needed
    return {
        'cnn': cca_scores_cnn,
        'transformer': cca_scores_trans,
        'cnn_sentence': cca_scores_cnn_s,
        'transformer_sentence': cca_scores_trans_s
    }

# Entry point: run the full pipeline with the default CSV path and
# automatic device selection.
if __name__ == "__main__":
    main()



Using device: cuda
Loaded 2836 utterances from /content/drive/MyDrive/SSL_models/emotions/updated_gc.csv
Loading SSL models...




Loading sentence embedding model


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Processing audio files...


Utterances:  21%|██▏       | 604/2836 [09:20<34:30,  1.08it/s]


KeyboardInterrupt: 