# In [None]:
# Differentially Private Federated ASR with DP-FedAvg (using f̃_c estimator)
# Based on: McMahan et al., 2018 (ICLR) and user's ASR FedAvg setup

import torch
import random
import numpy as np
import copy
from transformers import TrainingArguments, Trainer
from torch.nn.utils import clip_grad_norm_
from collections import defaultdict
from opacus.accountants import RDPAccountant
import matplotlib.pyplot as plt
import os

# --- Configurable Constants ---
# Fraction of the client pool sampled in each DP round; also multiplies Wmin
# to form the lower bound of the aggregation denominator.
q = 0.05
# L2 clipping bound passed to flat_clip; the Gaussian noise stddev is
# proportional to it.
S = 15.0
# Minimum total weight: q * Wmin is the floor of the aggregation denominator.
Wmin = 20
# Number of communication rounds run by each experiment (DP and baseline).
rounds = 10
# Noise multipliers to sweep. Each value is passed as noise_std, and the
# stddev of the Gaussian noise actually added is noise_std * S.
noise_levels = [0.006, 0.012, 0.024]

# --- Placeholder: Client Data Setup & Helper Functions ---
# Assume: client_datasets and eval_dataset already loaded
# Assume: global_model, processor, data_collator are initialized

# --- Clipping function: Flat Clip (Eq. 1 from paper) ---
def flat_clip(delta_k, S):
    """Project ``delta_k`` onto the L2 ball of radius ``S``.

    Args:
        delta_k: Tensor holding an update.
        S: Maximum allowed L2 norm.

    Returns:
        ``delta_k`` itself when its norm does not exceed ``S``; otherwise a
        new tensor pointing in the same direction with norm ``S``.
    """
    l2 = torch.norm(delta_k)
    # Guard clause: anything not strictly outside the ball is returned as is.
    if not l2 > S:
        return delta_k
    shrink = S / l2
    return delta_k * shrink

# --- DP-FedAvg update estimator (f̃_c) ---
def dp_aggregate_f_c(updates, weights, sample_rate=None, min_weight=None):
    """Return the weighted average of client updates with a clamped denominator.

    Computes ``sum_k(w_k * u_k) / max(sample_rate * min_weight, sum_k(w_k))``.

    Args:
        updates: One update per client. Each update is either a single value
            that supports ``*`` and ``+`` with numbers (for example a tensor)
            or a dict mapping names to such values. Dict updates are averaged
            key by key, using the keys of the first update.
        weights: One numeric weight per update.
        sample_rate: Sampling rate for the denominator floor. Defaults to the
            module-level ``q``.
        min_weight: Minimum total weight for the denominator floor. Defaults
            to the module-level ``Wmin``.

    Returns:
        The aggregate, shaped like a single update (a dict for dict updates).
        With no updates the result is ``0 / denominator``.

    Raises:
        ValueError: If ``updates`` and ``weights`` have different lengths.
    """
    updates = list(updates)
    weights = list(weights)
    if len(updates) != len(weights):
        # zip() would silently drop the surplus while W still counted every
        # weight, producing a wrongly scaled average.
        raise ValueError(
            f"got {len(updates)} updates but {len(weights)} weights"
        )

    rate = q if sample_rate is None else sample_rate
    floor_weight = Wmin if min_weight is None else min_weight
    denom = max(rate * floor_weight, sum(weights))

    if updates and isinstance(updates[0], dict):
        # Dicts do not support scalar multiplication, so combine per key.
        return {
            name: sum(w * u[name] for u, w in zip(updates, weights)) / denom
            for name in updates[0]
        }
    return sum(w * u for u, w in zip(updates, weights)) / denom

# --- Train loop with DP noise ---
def train_dp_fedavg(global_model, noise_std, label):
    """Run DP-FedAvg from a copy of ``global_model`` and return per-round WER.

    Each round samples clients, fine-tunes a copy of the current model on each
    one, clips every client's whole update to a joint L2 norm of at most ``S``,
    averages the clipped updates with ``dp_aggregate_f_c``, adds Gaussian noise
    with standard deviation ``noise_std * S`` and applies the noisy average to
    the model.

    Args:
        global_model: Starting model. It is deep-copied and not modified.
        noise_std: Noise multiplier; the added noise has stddev
            ``noise_std * S``.
        label: Run label. Currently unused; kept so existing callers work.

    Returns:
        List with the ``eval_wer`` metric measured after each round.
    """
    model = copy.deepcopy(global_model)
    # NOTE(review): the accountant is created but never stepped, so this
    # function does not track or report a privacy budget.
    accountant = RDPAccountant()
    accs = []  # WER after each round

    # int() truncation yields 0 clients for small pools, which would leave
    # nothing to aggregate; always train at least one client per round.
    clients_per_round = max(1, int(q * len(client_datasets)))

    for rnd in range(rounds):
        selected = random.sample(client_datasets, clients_per_round)
        updates = []
        weights = []

        for client_data in selected:
            local_model = copy.deepcopy(model)
            state_before = copy.deepcopy(local_model.state_dict())

            state_after, _ = local_finetune(local_model, client_data, processor, data_collator, compute_metrics, "tmp")
            # Non-floating-point entries cannot be norm-clipped or receive
            # Gaussian noise, so only floating-point entries form the update.
            delta = {
                k: state_after[k] - before
                for k, before in state_before.items()
                if torch.is_floating_point(before)
            }
            # The clipping bound applies to the whole update, not to each
            # tensor separately; per-tensor clipping would allow a total norm
            # well above S.
            updates.append(_clip_update_to_norm(delta, S))
            weights.append(1.0)  # Assume equal weights for now

        # Aggregate one parameter at a time. The denominator depends only on
        # the weights, so it is identical for every key.
        avg_update = {
            k: dp_aggregate_f_c([u[k] for u in updates], weights)
            for k in updates[0]
        }
        noise = {k: torch.normal(0.0, noise_std * S, size=v.size()).to(v.device) for k, v in avg_update.items()}
        noisy_update = {k: avg_update[k] + noise[k] for k in avg_update}

        new_state = model.state_dict()
        # Only keys that are part of the update are touched; other entries
        # keep their current values.
        for k, step in noisy_update.items():
            new_state[k] += step
        model.load_state_dict(new_state)

        # Evaluate
        metrics = evaluate_global_model(model, eval_dataset)
        accs.append(metrics['eval_wer'])
        print(f"Round {rnd + 1}, Noise: {noise_std}, WER: {metrics['eval_wer']:.3f}")

    return accs


def _clip_update_to_norm(delta, clip_norm):
    """Scale a dict of tensors so that its joint L2 norm is at most ``clip_norm``."""
    total_sq = sum(float(torch.sum(t.float() ** 2)) for t in delta.values())
    total_norm = total_sq ** 0.5
    if total_norm > clip_norm:
        factor = clip_norm / total_norm
        return {name: t * factor for name, t in delta.items()}
    return delta

# --- Run DP FedAvg for Different Noise Settings ---
# Maps a curve name to its list of per-round WER values.
results = {}
for sigma in noise_levels:
    # One full DP-FedAvg run per noise level, each started from global_model.
    run_label = f"DP-{sigma}"
    results[f"DP-sigma={sigma}"] = train_dp_fedavg(global_model, sigma, label=run_label)

# --- Baseline (non-private) ---
baseline_metrics = []
global_model = copy.deepcopy(base_model)
# random.sample raises ValueError when k exceeds the population size, so never
# request more clients than are available.
baseline_clients_per_round = min(5, len(client_datasets))
for round_num in range(rounds):
    selected_clients = random.sample(client_datasets, k=baseline_clients_per_round)
    # The first value returned by local_finetune is averaged and then loaded
    # with load_state_dict below, i.e. it is a full state dict, not a delta.
    client_states = []
    for client_data in selected_clients:
        local_model = copy.deepcopy(global_model)
        client_state, _ = local_finetune(local_model, client_data, processor, data_collator, compute_metrics, "tmp")
        client_states.append(client_state)

    avg_state = fed_avg(client_states)
    global_model.load_state_dict(avg_state)
    metrics = evaluate_global_model(global_model, eval_dataset)
    baseline_metrics.append(metrics['eval_wer'])
    print(f"Baseline Round {round_num + 1}, WER: {metrics['eval_wer']:.3f}")

results["Baseline"] = baseline_metrics

# --- Plotting Noise Scheduling & Baseline Comparison ---
# Draw one WER curve per entry in results on a shared axis, then save and show.
fig, ax = plt.subplots(figsize=(10, 6))
round_axis = range(1, rounds + 1)
for curve_name, wer_curve in results.items():
    ax.plot(round_axis, wer_curve, label=curve_name)
ax.set_xlabel("Communication Round")
ax.set_ylabel("WER")
ax.set_title("Noise Scheduling vs Baseline in DP-FedAvg")
ax.legend()
ax.grid(True)
fig.savefig("dp_fedavg_wer_comparison.png")
plt.show()
