# K-Means Training on DistilHuBERT Layer Features
Trains a 500-cluster MiniBatchKMeans model on the concatenated outputs of the first three layers (indices 0–2) of the fine-tuned DistilHuBERT model from step 1.

In [3]:
# Housekeeping cell: reclaim memory left over from earlier work in this kernel
# before the model and dataset are loaded.
import gc
import sys

import torch

# Bare expression in the middle of a cell: it is evaluated but not displayed,
# because only a cell's last expression is echoed.
sys.executable

gc.collect()              # run a full garbage-collection pass over Python objects
torch.cuda.empty_cache()  # ask PyTorch's CUDA caching allocator to release unused cached memory

In [4]:
import os
import torch
import numpy as np
from datasets import load_dataset, Audio,load_dataset, DownloadConfig
from transformers import AutoProcessor, AutoModel
from sklearn.cluster import MiniBatchKMeans
from train_classifier import extract_all_layer_features

In [5]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The processor and the model are both restored from the same checkpoint directory.
processor = AutoProcessor.from_pretrained("/mnt/scratch/pippalin2/jupyter/GMM-DistilHuBERT/checkpoints_distilhubert_asr/final_model")
model = AutoModel.from_pretrained("/mnt/scratch/pippalin2/jupyter/GMM-DistilHuBERT/checkpoints_distilhubert_asr/final_model")
model = model.to(device)  # move the weights onto the selected device

In [6]:
# Settings for feature extraction and clustering
layers = [0, 1, 2]   # layer indices passed to extract_all_layer_features
num_clusters = 500   # number of K-Means centroids to learn
batch_size = 32      # audio samples gathered before each K-Means update
max_batches = 1000   # upper bound on processed batches (caps training at batch_size * max_batches samples)

In [7]:
# Build the dataset from the audio files under train-other-500 using the
# generic "audiofolder" loader, and keep only its "train" split.
dataset = load_dataset(
    path="audiofolder",
    data_dir="/mnt/scratch/pippalin2/jupyter/GMM-DistilHuBERT/data/LibriSpeech/train-other-500",
)["train"]

Resolving data files:   0%|          | 0/151472 [00:00<?, ?it/s]

In [8]:
# Declare the "audio" column as 16 kHz audio so decoded samples match the
# sampling_rate=16000 passed to the processor later on.
dataset = dataset.cast_column(
    column="audio",
    feature=Audio(sampling_rate=16000),
)

#### Optional: smoke test on a small subset (disabled)


In [None]:
# Disabled smoke test, kept as a bare string literal so none of it runs as code.
# What the quoted code would do if re-enabled: load a single sub-folder
# (3885/1193) of train-other-500, extract layer 0-2 features for every
# utterance, fit a 500-cluster MiniBatchKMeans on all frames at once, and
# predict cluster ids for one utterance.
# NOTE(review): the final predict step reuses `waveform` from the last loop
# iteration, so it is evaluated on data the model was just fitted on.
# NOTE(review): as the last expression of the cell, executing this cell would
# echo the whole string as the cell output.
'''
dataset = load_dataset(
    "audiofolder",
    data_dir="/mnt/scratch/pippalin2/jupyter/GMM-DistilHuBERT/data/LibriSpeech/train-other-500/3885/1193"
)["train"]  # 👈 get the 'train' split from the dataset dict

print("Loaded dataset:", len(dataset))
print("Example:", dataset[0])

dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

all_features = []

for example in dataset:  # e.g., loaded using "audiofolder"
    waveform = processor(example["audio"]["array"], sampling_rate=16000, return_tensors="pt").input_values.to(device)
    feats = extract_all_layer_features(model, waveform, layers=[0, 1, 2])
    all_features.append(feats.squeeze(0).cpu().numpy())  # [T, D]

# Combine for clustering
all_feats = np.vstack(all_features)  # [total_T, D]
kmeans = MiniBatchKMeans(n_clusters=500, batch_size=2048)
kmeans.fit(all_feats)

print("Cluster centers shape:", kmeans.cluster_centers_.shape)  # Should be [500, D]
test_feats = extract_all_layer_features(model, waveform, layers=[0,1,2]).squeeze(0).cpu().numpy()
cluster_ids = kmeans.predict(test_feats)  # shape [T]
print(cluster_ids[:20])
'''

In [11]:
# Quick sanity check: confirm the container type and look at one decoded audio record.
print("Type of dataset:", type(dataset))
first_audio = dataset[0]["audio"]  # indexing the first example yields its audio entry
print("Sample audio entry:", first_audio)


Type of dataset: <class 'datasets.arrow_dataset.Dataset'>
Sample audio entry: {'path': '/mnt/scratch/pippalin2/jupyter/GMM-DistilHuBERT/data/LibriSpeech/train-other-500/1006/135212/1006-135212-0000.flac', 'array': array([0.04141235, 0.04229736, 0.04345703, ..., 0.03570557, 0.03594971,
       0.03588867]), 'sampling_rate': 16000}


In [13]:
# Train K-Means incrementally: each outer iteration extracts frame-level
# features for `batch_size` utterances and feeds them to the model through
# partial_fit, so the full feature matrix never has to be held in memory.
#
# random_state is fixed so centroid initialisation and mini-batch sampling are
# reproducible across kernel restarts.
kmeans = MiniBatchKMeans(n_clusters=num_clusters, batch_size=2048, random_state=0)

# NOTE(review): batches are taken in dataset order and training stops after
# `max_batches`, so only the first batch_size * max_batches utterances are
# seen. If the dataset is ordered by speaker this may bias the centroids;
# consider shuffling beforehand. TODO confirm.
batch_count = 0
for i in range(0, len(dataset), batch_size):
    batch = dataset.select(range(i, min(i + batch_size, len(dataset))))
    feats_batch = []

    # Inference only: no autograd graph is needed for feature extraction.
    # (extract_all_layer_features may already disable gradients; nesting is harmless.)
    with torch.no_grad():
        for example in batch:
            audio = example["audio"]
            input_values = processor(audio["array"], sampling_rate=16000, return_tensors="pt").input_values.to(device)
            features = extract_all_layer_features(model, input_values, layers=layers)
            # Drop the leading batch dimension and move the frames to host memory.
            feats_batch.append(features.squeeze(0).cpu().numpy())

    all_feats = np.vstack(feats_batch)  # shape [total_T, D]

    # Update the centroids exactly once per batch. The previous version called
    # partial_fit twice on the same features, doubling the clustering cost and
    # applying two update steps for every batch.
    kmeans.partial_fit(all_feats)

    batch_count += 1
    print(f"Processed batch {batch_count}, features shape: {all_feats.shape}")

    # max_batches is an optional safeguard; None means "process the whole dataset".
    if max_batches is not None and batch_count >= max_batches:
        print("Reached max batch limit. Stopping early.")
        break

Processed batch 1, features shape: (21580, 2304)
Processed batch 2, features shape: (20856, 2304)
Processed batch 3, features shape: (18893, 2304)
Processed batch 4, features shape: (21753, 2304)
Processed batch 5, features shape: (21415, 2304)
Processed batch 6, features shape: (20193, 2304)
Processed batch 7, features shape: (14600, 2304)
Processed batch 8, features shape: (15208, 2304)
Processed batch 9, features shape: (17387, 2304)
Processed batch 10, features shape: (14246, 2304)
Processed batch 11, features shape: (15522, 2304)
Processed batch 12, features shape: (17825, 2304)
Processed batch 13, features shape: (19143, 2304)
Processed batch 14, features shape: (18272, 2304)
Processed batch 15, features shape: (18724, 2304)
Processed batch 16, features shape: (18361, 2304)
Processed batch 17, features shape: (19963, 2304)
Processed batch 18, features shape: (17888, 2304)
Processed batch 19, features shape: (15600, 2304)
Processed batch 20, features shape: (16079, 2304)
Processed

In [15]:
# Save model
save_path = "/mnt/scratch/pippalin2/jupyter/GMM-DistilHuBERT/script/train_classifier/kmeans_model_500.pkl"
from joblib import dump
dump(kmeans, save_path)
print(f"KMeans model saved to {save_path}")

KMeans model saved to /mnt/scratch/pippalin2/jupyter/GMM-DistilHuBERT/script/train_classifier/kmeans_model_500.pkl
