### Federated Fine-tuning for ASR
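* Centralized: fine-tune DistilHuBERT on the full training set for the downstream ASR task.
* FL: Simulate a federated ASR fine-tuning scenario in which multiple clients (standing in for speakers; here they are random, unevenly sized partitions of the training data) fine-tune a shared model on their local speech data. Client updates are aggregated with FedAvg; FedOpt is a possible alternative aggregator.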

In [1]:
import torch
import copy
from datasets import load_from_disk
from transformers import (
    AutoProcessor,
    AutoModelForCTC,
    TrainingArguments,
    Trainer
)
from typing import List, Dict
from collections import defaultdict
from torch.utils.data import Dataset
import numpy as np


In [2]:
# Load the processed dataset from disk and hold out 10% of it for evaluation.
# NOTE: no seed is passed to train_test_split, so the split differs per run.
dataset = load_from_disk(
    "/scratch/pippalin2/jupyter/GMM-DistilHuBERT/processed_dataset"
).train_test_split(test_size=0.1)
train_dataset, eval_dataset = dataset["train"], dataset["test"]

Loading dataset from disk:   0%|          | 0/47 [00:00<?, ?it/s]

In [3]:
# Echo the training split so the notebook shows its features and row count.
train_dataset

Dataset({
    features: ['input_values', 'labels'],
    num_rows: 25670
})

#### Simulating FL setting: clients

In [7]:
def split_into_clients_nonuniform(dataset, num_clients, min_frac, max_frac, seed=None):
    """Randomly partition ``dataset`` into ``num_clients`` unevenly sized shards.

    A relative share is drawn for every client from ``uniform(min_frac,
    max_frac)``; the shares are normalised to sum to one and converted to
    integer shard sizes. Samples lost to integer truncation are given to the
    first client so that every sample is assigned exactly once.

    Args:
        dataset: Object supporting ``len()`` and ``select(list_of_indices)``.
        num_clients: Number of shards to create; must be at least 1.
        min_frac: Lower bound of the per-client share drawn before
            normalisation.
        max_frac: Upper bound of the per-client share drawn before
            normalisation.
        seed: Optional seed for a private random generator, making the
            partition reproducible. When ``None`` the global ``np.random``
            state is used.

    Returns:
        A list with ``num_clients`` entries, each the result of
        ``dataset.select`` on a disjoint set of indices. A shard can be empty
        when the dataset is small relative to ``num_clients``.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1, or if the fraction
            bounds are negative or both zero.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if min(min_frac, max_frac) < 0 or max(min_frac, max_frac) <= 0:
        # Negative or all-zero shares would normalise to NaN / negative sizes.
        raise ValueError(
            "min_frac and max_frac must be non-negative and not both zero, "
            f"got min_frac={min_frac}, max_frac={max_frac}"
        )

    # Fall back to the module-level generator so np.random.seed() still works.
    rng = np.random if seed is None else np.random.RandomState(seed)
    size = len(dataset)

    proportions = rng.uniform(min_frac, max_frac, size=num_clients)
    proportions = proportions / proportions.sum()
    sizes = (proportions * size).astype(int)

    # astype(int) truncates, so the sizes can fall short of the total;
    # hand the remainder to the first client.
    sizes[0] += size - sizes.sum()

    indices = np.arange(size)
    rng.shuffle(indices)

    # Consecutive slices of the shuffled index array form the client shards.
    boundaries = np.concatenate(([0], np.cumsum(sizes)))
    return [
        dataset.select(indices[start:end].tolist())
        for start, end in zip(boundaries[:-1], boundaries[1:])
    ]
# Partition the training split into 20 unevenly sized client shards.
client_datasets = split_into_clients_nonuniform(
    train_dataset,
    num_clients=20,
    min_frac=0.01,
    max_frac=0.1,
)

In [None]:
# Processor loaded from a local directory; its feature extractor and tokenizer
# are used by the data collator.
processor = AutoProcessor.from_pretrained("/scratch/pippalin2/jupyter/GMM-DistilHuBERT/processor")

# The CTC output layer must match the tokenizer used to encode the labels:
# size it to the tokenizer's vocabulary and use the tokenizer's pad token as
# the CTC blank. Without these overrides the head is built from the
# checkpoint's default config values.
# NOTE(review): presumably "ntu-spml/distilhubert" ships no trained CTC head,
# so the head is freshly initialised at this size - confirm.
base_model = AutoModelForCTC.from_pretrained(
    "ntu-spml/distilhubert",
    vocab_size=len(processor.tokenizer),
    pad_token_id=processor.tokenizer.pad_token_id,
).to("cuda")

class DataCollatorCTCWithPadding:
    """Collate examples into one padded batch with CTC-style labels.

    Audio inputs are padded by the processor's feature extractor and label
    ids by its tokenizer. Padded label positions are overwritten with -100.
    """

    def __init__(self, processor, padding=True):
        # processor must expose `.feature_extractor.pad` and `.tokenizer.pad`.
        self.processor = processor
        # Forwarded unchanged as the `padding` argument of both pad() calls.
        self.padding = padding

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Pad a list of ``{"input_values", "labels"}`` examples into a batch."""
        # Inputs and labels are padded by different components, so split them.
        audio_items = [{"input_values": example["input_values"]} for example in features]
        label_items = [{"input_ids": example["labels"]} for example in features]

        pad_options = {"padding": self.padding, "return_tensors": "pt"}
        batch = self.processor.feature_extractor.pad(audio_items, **pad_options)
        padded_labels = self.processor.tokenizer.pad(label_items, **pad_options)

        # Positions where the attention mask is not 1 are padding; mark them -100.
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)
        return batch

# One collator instance is reused for every client's local training run.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# ========== 5. Local Training ==========

def local_finetune(model, dataset, processor, collator, output_dir):
    """Run one local training epoch on a client's data and return the weights.

    Args:
        model: Model handed to the ``Trainer``.
        dataset: The client's training dataset.
        processor: Passed to the ``Trainer`` as its ``tokenizer``.
        collator: Callable that turns a list of examples into a batch.
        output_dir: Directory given to ``TrainingArguments`` for outputs.

    Returns:
        ``model.state_dict()`` taken after training.
    """
    # Fixed local hyper-parameters: one epoch, batch size 2, lr 1e-4, fp16,
    # and no reporting integrations.
    hyperparams = dict(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        num_train_epochs=1,
        logging_steps=10,
        save_steps=5000,
        learning_rate=1e-4,
        fp16=True,
        report_to="none",
    )
    Trainer(
        model=model,
        args=TrainingArguments(**hyperparams),
        train_dataset=dataset,
        tokenizer=processor,
        data_collator=collator,
    ).train()
    return model.state_dict()

# ========== 6. FedAvg Aggregation ==========

def fed_avg(state_dicts: List[Dict], client_weights=None):
    """Average client state dicts entry by entry (FedAvg aggregation).

    Args:
        state_dicts: One state dict per client. All must share the keys of
            the first one. The inputs are not modified.
        client_weights: Optional sequence with one non-negative number per
            client (for example the client's sample count). Each client then
            contributes in proportion to its weight. ``None`` (the default)
            gives every client the same weight, i.e. a plain mean.

    Returns:
        A mapping of the same type as ``state_dicts[0]`` holding the averaged
        tensors. Every entry keeps the dtype and device it has in the first
        state dict; non-floating-point entries are averaged in float32 and
        cast back, which truncates.

    Raises:
        ValueError: If ``state_dicts`` is empty, if ``client_weights`` has a
            different length than ``state_dicts``, or if the weights are
            negative or sum to zero.
    """
    if not state_dicts:
        raise ValueError("fed_avg requires at least one client state dict")

    if client_weights is None:
        client_weights = [1] * len(state_dicts)
    elif len(client_weights) != len(state_dicts):
        raise ValueError(
            f"got {len(client_weights)} client weights for "
            f"{len(state_dicts)} state dicts"
        )
    total_weight = sum(client_weights)
    if total_weight <= 0 or any(w < 0 for w in client_weights):
        raise ValueError("client_weights must be non-negative and sum to a positive value")

    reference = state_dicts[0]
    # A shallow copy keeps the container type; every value is replaced below,
    # so the caller's tensors are never written to.
    avg_dict = copy.copy(reference)
    for key, ref_tensor in reference.items():
        # Integer/bool entries cannot hold a mean, so accumulate them in float.
        acc_dtype = ref_tensor.dtype if ref_tensor.is_floating_point() else torch.float32
        accumulated = torch.zeros_like(ref_tensor, dtype=acc_dtype)
        for weight, state in zip(client_weights, state_dicts):
            accumulated += state[key].to(device=accumulated.device, dtype=acc_dtype) * weight
        avg_dict[key] = (accumulated / total_weight).to(ref_tensor.dtype)
    return avg_dict

# ========== 7. Federated Learning Loop ==========

# Work on a copy so base_model keeps its initial weights.
global_model = copy.deepcopy(base_model)

for round_num in range(3):
    print(f"Round {round_num + 1}")
    weights = []
    for i, client_data in enumerate(client_datasets):
        print(f"  Client {i+1}")
        # Every client starts the round from the current global weights.
        local_model = copy.deepcopy(global_model)
        state = local_finetune(local_model, client_data, processor, data_collator, f"./client{i}_round{round_num}")
        # Park the client's weights on the CPU. Otherwise every client's
        # state dict would keep its local model's GPU tensors alive until the
        # end of the round.
        for name in list(state):
            state[name] = state[name].cpu()
        weights.append(state)
    # NOTE(review): this call gives every client equal influence regardless of
    # len(client_data), although the shards are unevenly sized. Consider
    # weighting the average by sample count.
    avg_weights = fed_avg(weights)
    # load_state_dict copies the averaged values into the model's own tensors.
    global_model.load_state_dict(avg_weights)

# ========== 8. Save and Evaluate ==========

# Write the aggregated model and the processor into the same directory.
export_dir = "./federated_distilhubert_asr"
for artifact in (global_model, processor):
    artifact.save_pretrained(export_dir)
