In [None]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("quadeer15sh/lfw-facial-recognition")

print("Path to dataset files:", path)

In [2]:
# FedAvg with ResNet-18 + Pairwise Additive Masking Secure Aggregation (simulation)
# Fixes the gradient-size mismatch during reconstruction by ensuring the captured
# per-parameter signal is ordered to match model.parameters() (parameters only,
# excluding buffers like running_mean/running_var).
#
# Usage: adjust DATA_ROOT and FOLDERS to point to your dataset.
# Note: This is a simulation of secure aggregation for research/experiments.
# For production, use a vetted secure aggregation protocol (Bonawitz et al.) and proper key exchange.
#
# Requirements: torch, torchvision, numpy, pillow, scikit-image, tqdm, sklearn

import os
import random
import math
import json
import hashlib
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics.pairwise import cosine_similarity

# -----------------------------
# Experiment hyperparameters
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Federated learning variables
NUM_CLIENTS = 50
CLIENTS_PER_ROUND = 10
ROUNDS = 100
LOCAL_EPOCHS = 1
LOCAL_BATCH_SIZE = 1   # keep small for strong leakage baseline (can be increased)

CLIENT_LR = 0.01
MOMENTUM = 0.9

# Model / data variables
IMAGE_SIZE = 224
INPUT_CHANNELS = 3
NUM_CLASSES = None

# Attack variables (IDLG with L-BFGS)
ATTACK_USE_LBFGS = True
ATTACK_LBFGS_MAX_ITER = 300
ATTACK_LBFGS_LINESEARCH = True
ATTACK_INIT = "random"
ATTACK_TV_WEIGHT = 1e-4
ATTACK_GRAD_MATCH_WEIGHT = 1.0
ATTACK_COSINE_THRESHOLD = 0.6
ATTACK_EVERY_N_ROUNDS = 1

RECON_SHAPE = (1, INPUT_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

PRINT_EVERY = 5
SAVE_RECON_DIR = "./reconstructions_secureagg_fixed"
os.makedirs(SAVE_RECON_DIR, exist_ok=True)

# Determinism
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# -----------------------------
# Dataset utilities
# -----------------------------
DATA_ROOT = "/kaggle/input/lfw-facial-recognition/Face Recognition"
FOLDERS = ["Faces"]

transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

def gather_image_paths(root, folders):
    paths = []
    for f in folders:
        folder = os.path.join(root, f)
        if not os.path.isdir(folder):
            continue
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                paths.append(os.path.join(folder, fname))
    return sorted(paths)

all_image_paths = gather_image_paths(DATA_ROOT, FOLDERS)
if len(all_image_paths) == 0:
    raise RuntimeError(f"No images found in {FOLDERS} under {DATA_ROOT}. Check paths.")

def identity_from_filename(path):
    name = os.path.basename(path)
    base = os.path.splitext(name)[0]
    parts = base.split("_")
    if len(parts) >= 2:
        return "_".join(parts[:-1])
    return base

id_to_paths = defaultdict(list)
for p in all_image_paths:
    id_ = identity_from_filename(p)
    id_to_paths[id_].append(p)

valid_id_to_paths = {k: v for k, v in id_to_paths.items() if len(v) >= 2}
if len(valid_id_to_paths) == 0:
    raise RuntimeError("No identities with >=2 images found. Check filename format or dataset.")

all_valid_paths = []
all_valid_labels = []
unique_ids = sorted(valid_id_to_paths.keys())
label_map = {id_: idx for idx, id_ in enumerate(unique_ids)}
for id_, paths in valid_id_to_paths.items():
    for p in paths:
        all_valid_paths.append(p)
        all_valid_labels.append(label_map[id_])

NUM_CLASSES = len(unique_ids)
print(f"Found {len(all_valid_paths)} images across {NUM_CLASSES} identities (>=2 images each).")

# -----------------------------
# PyTorch Dataset and client partitioning
# -----------------------------
class FaceImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = int(self.labels[idx])
        return img, label, p

full_dataset = FaceImageDataset(all_valid_paths, all_valid_labels, transform=transform)

# Partition identities across clients (non-IID)
client_id_to_indices = {c: [] for c in range(NUM_CLIENTS)}
identities = list(unique_ids)
random.shuffle(identities)
for i, id_ in enumerate(identities):
    client_id = i % NUM_CLIENTS
    for p in valid_id_to_paths[id_]:
        idx = all_valid_paths.index(p)
        client_id_to_indices[client_id].append(idx)

num_nonempty = sum(1 for c in client_id_to_indices if len(client_id_to_indices[c]) > 0)
print(f"Assigned identities to {num_nonempty} non-empty clients out of {NUM_CLIENTS} total clients.")

client_loaders = {}
for c in range(NUM_CLIENTS):
    inds = client_id_to_indices[c]
    if len(inds) == 0:
        client_loaders[c] = None
        continue
    sub_paths = [all_valid_paths[i] for i in inds]
    sub_labels = [all_valid_labels[i] for i in inds]
    ds = FaceImageDataset(sub_paths, sub_labels, transform=transform)
    loader = DataLoader(ds, batch_size=LOCAL_BATCH_SIZE, shuffle=True)
    client_loaders[c] = loader

# Held-out test loader (one image per identity)
test_paths = []
test_labels = []
for id_, paths in valid_id_to_paths.items():
    p = paths[-1]
    test_paths.append(p)
    test_labels.append(label_map[id_])
test_ds = FaceImageDataset(test_paths, test_labels, transform=transform)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

# -----------------------------
# Model: ResNet-18 wrapper
# -----------------------------
class ResNet18Wrapper(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        in_features = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone
        self.classifier = nn.Linear(in_features, num_classes)
    def forward(self, x, return_features=False):
        feats = self.backbone(x)
        logits = self.classifier(feats)
        if return_features:
            return feats, logits
        return logits

global_model = ResNet18Wrapper(NUM_CLASSES, pretrained=True).to(DEVICE)

# Store global_state as float tensors
global_state = {k: v.cpu().to(torch.float32) for k, v in global_model.state_dict().items()}

def model_from_state(state_dict):
    m = ResNet18Wrapper(NUM_CLASSES, pretrained=False).to(DEVICE)
    sd = {k: v.to(DEVICE) for k, v in state_dict.items()}
    m.load_state_dict(sd)
    return m

# -----------------------------
# Utility: ensure image tensors have shape (B,C,H,W)
# -----------------------------
def ensure_batch_image_shape(x):
    if isinstance(x, torch.Tensor):
        if x.dim() == 3:
            return x.unsqueeze(0)
        if x.dim() == 5 and x.size(1) == 1:
            return x.squeeze(1)
        if x.dim() == 5 and x.size(2) == 1:
            return x.squeeze(2)
        return x
    else:
        return x

# -----------------------------
# Mask generation utilities (deterministic PRNG from seed)
# -----------------------------
def seed_from_pair(i, j, round_idx):
    s = f"{min(i,j)}_{max(i,j)}_{round_idx}"
    h = hashlib.sha256(s.encode("utf-8")).digest()
    seed = int.from_bytes(h[:8], byteorder="big", signed=False)
    return seed

def generate_mask_for_state(state_template, seed):
    gen = torch.Generator()
    gen.manual_seed(seed & ((1 << 63) - 1))
    mask = {}
    for k, v in state_template.items():
        m = torch.randn(v.shape, generator=gen, dtype=torch.float32)
        mask[k] = m
    return mask

def add_state_dicts(a, b):
    return {k: (a[k] + b[k]) for k in a.keys()}

def sub_state_dicts(a, b):
    return {k: (a[k] - b[k]) for k in a.keys()}

def zero_state_like(template):
    return {k: torch.zeros_like(v) for k, v in template.items()}

# -----------------------------
# Client local training (no DP) -> produce delta
# -----------------------------
def get_model_copy_for_client(state_dict):
    return model_from_state(state_dict)

def client_local_update_compute_delta(client_id, global_state):
    loader = client_loaders[client_id]
    if loader is None:
        return None, None, None

    local_model = get_model_copy_for_client(global_state)
    local_model.train()
    opt = torch.optim.SGD(local_model.parameters(), lr=CLIENT_LR, momentum=MOMENTUM)
    captured_label = None
    captured_image = None

    loss_fn = nn.CrossEntropyLoss(reduction='mean')

    for epoch in range(LOCAL_EPOCHS):
        for x, y, pth in loader:
            x = ensure_batch_image_shape(x)
            if x.dim() == 5 and x.size(1) == 1:
                x = x.squeeze(1)
            x = x.to(DEVICE)
            if not isinstance(y, torch.Tensor):
                y = torch.tensor(y, dtype=torch.long)
            else:
                y = y.long()
            y = y.to(DEVICE)

            opt.zero_grad()
            logits = local_model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            if captured_label is None:
                if y.dim() == 0:
                    captured_label = y.detach().cpu().clone().unsqueeze(0)
                else:
                    captured_label = y.detach().cpu().clone().view(-1)[0:1]
                captured_image = x.detach().cpu().clone()[0:1]
            opt.step()
            if LOCAL_EPOCHS == 1:
                break

    new_state = {k: v.cpu().to(torch.float32) for k, v in local_model.state_dict().items()}
    delta = {k: new_state[k] - global_state[k] for k in new_state.keys()}
    return delta, captured_label, captured_image

# -----------------------------
# Secure aggregation: client-side masking
# -----------------------------
def client_mask_delta(client_id, delta_state, participants, round_idx):
    masked = {k: delta_state[k].clone() for k in delta_state.keys()}
    mask_sum = zero_state_like(delta_state)
    for j in participants:
        if j == client_id:
            continue
        seed = seed_from_pair(client_id, j, round_idx)
        pair_mask = generate_mask_for_state(delta_state, seed)
        if client_id < j:
            for k in masked.keys():
                masked[k] = masked[k] + pair_mask[k]
                mask_sum[k] = mask_sum[k] + pair_mask[k]
        else:
            for k in masked.keys():
                masked[k] = masked[k] - pair_mask[k]
                mask_sum[k] = mask_sum[k] - pair_mask[k]
    return masked, mask_sum

# -----------------------------
# Server aggregation with mask cancellation (simulation)
# -----------------------------
def server_aggregate_masked(masked_deltas_dict, participants, round_idx, dropout_list=None):
    template = next(iter(masked_deltas_dict.values()))
    sum_masked = zero_state_like(template)
    for cid, md in masked_deltas_dict.items():
        for k in sum_masked.keys():
            sum_masked[k] = sum_masked[k] + md[k]

    present_clients = list(masked_deltas_dict.keys())
    present_mask_sums = zero_state_like(template)
    for i in present_clients:
        mask_sum_i = zero_state_like(template)
        for j in participants:
            if j == i:
                continue
            seed = seed_from_pair(i, j, round_idx)
            pair_mask = generate_mask_for_state(template, seed)
            if i < j:
                for k in mask_sum_i.keys():
                    mask_sum_i[k] = mask_sum_i[k] + pair_mask[k]
            else:
                for k in mask_sum_i.keys():
                    mask_sum_i[k] = mask_sum_i[k] - pair_mask[k]
        for k in present_mask_sums.keys():
            present_mask_sums[k] = present_mask_sums[k] + mask_sum_i[k]

    sum_present_deltas = {k: (sum_masked[k] - present_mask_sums[k]) for k in sum_masked.keys()}
    return sum_present_deltas

# -----------------------------
# Reconstruction utilities (unchanged)
# -----------------------------
def flatten_grads(grad_list):
    vecs = []
    for g in grad_list:
        if g is None:
            continue
        if not g.is_cuda:
            g = g.to(DEVICE)
        vecs.append(g.view(-1))
    if len(vecs) == 0:
        return torch.tensor([], device=DEVICE)
    return torch.cat(vecs)

def tv_loss(img):
    return torch.mean(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])) + \
           torch.mean(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]))

def assert_shapes_for_recon(x_init, observed_label, observed_grads):
    if not isinstance(x_init, torch.Tensor):
        raise RuntimeError("x_init must be a torch.Tensor")
    if x_init.dim() != 4 or x_init.size(0) != 1:
        raise RuntimeError(f"Reconstruction input must be (1,C,H,W), got {x_init.shape}")
    if observed_label is None:
        raise RuntimeError("Observed label is None")
    if observed_label.dim() != 1 or observed_label.size(0) != 1:
        raise RuntimeError(f"Observed label must be shape (1,), got {observed_label.shape}")
    flat = flatten_grads(observed_grads)
    if flat.numel() == 0:
        raise RuntimeError("Observed grads flattened to zero length; check grad capture")

def reconstruct_idlg_lbfgs(model_state_for_attack, observed_grads_cpu, observed_label_cpu, iters=ATTACK_LBFGS_MAX_ITER):
    if observed_label_cpu is None:
        raise ValueError("No observed label provided for reconstruction")
    if observed_label_cpu.dim() == 0:
        observed_label_cpu = observed_label_cpu.unsqueeze(0).long()
    else:
        observed_label_cpu = observed_label_cpu.view(-1)[0:1].long()

    model = model_from_state(model_state_for_attack)
    model.eval()

    obs = [g.to(DEVICE) if g is not None else None for g in observed_grads_cpu]
    obs_flat = flatten_grads(obs).detach()

    if ATTACK_INIT == "random":
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)
    else:
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)

    try:
        assert_shapes_for_recon(x_hat, observed_label_cpu, observed_grads_cpu)
    except Exception as e:
        raise RuntimeError(f"Pre-reconstruction shape check failed: {e}")

    optimizer = torch.optim.LBFGS([x_hat], max_iter=iters, line_search_fn='strong_wolfe' if ATTACK_LBFGS_LINESEARCH else None)

    def closure():
        optimizer.zero_grad()
        x_in = ensure_batch_image_shape(x_hat)
        logits = model(x_in)
        loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
        grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
        grads_hat_flat = flatten_grads(grads_hat)
        if obs_flat.numel() == 0 or grads_hat_flat.numel() == 0:
            loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            loss.backward()
            return loss
        loss_match = F.mse_loss(grads_hat_flat, obs_flat)
        loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
        loss.backward()
        return loss

    try:
        optimizer.step(closure)
    except Exception as e:
        print("LBFGS failed, falling back to Adam:", e)
        opt_adam = torch.optim.Adam([x_hat], lr=0.05)
        for _ in range(500):
            opt_adam.zero_grad()
            x_in = ensure_batch_image_shape(x_hat)
            logits = model(x_in)
            loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
            grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
            grads_hat_flat = flatten_grads(grads_hat)
            loss_match = F.mse_loss(grads_hat_flat, obs_flat)
            loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
            loss.backward()
            opt_adam.step()
            with torch.no_grad():
                x_hat.clamp_(-1, 1)

    with torch.no_grad():
        x_rec = x_hat.clamp_(-1, 1).detach().cpu()
    return x_rec

# -----------------------------
# Evaluation metrics (unchanged)
# -----------------------------
def denorm(img_tensor):
    img = img_tensor.clone()
    img = (img * 0.5) + 0.5
    return img.clamp(0,1)

def compute_ssim(img1, img2):
    a = img1.squeeze(0).permute(1,2,0).numpy()
    b = img2.squeeze(0).permute(1,2,0).numpy()
    s = 0.0
    for ch in range(a.shape[2]):
        s += ssim(a[:,:,ch], b[:,:,ch], data_range=1.0)
    return s / a.shape[2]

def compute_cosine_similarity_feature(model_state, img1, img2):
    model = model_from_state(model_state)
    model.eval()
    with torch.no_grad():
        f1, _ = model(ensure_batch_image_shape(img1).to(DEVICE), return_features=True)
        f2, _ = model(ensure_batch_image_shape(img2).to(DEVICE), return_features=True)
    f1 = f1.cpu().numpy()
    f2 = f2.cpu().numpy()
    return float(cosine_similarity(f1, f2)[0,0])

def evaluate_accuracy(model_state, dataloader):
    model = model_from_state(model_state)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y, _ in dataloader:
            x = ensure_batch_image_shape(x).to(DEVICE)
            if not isinstance(y, torch.Tensor):
                y = torch.tensor(y, dtype=torch.long)
            else:
                y = y.long()
            y = y.to(DEVICE)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total if total > 0 else 0.0

# -----------------------------
# Training loop with secure aggregation (fixed capture ordering)
# -----------------------------
attack_logs = []
global_state = {k: v.cpu().to(torch.float32) for k, v in global_model.state_dict().items()}

for rnd in range(ROUNDS):
    available_clients = [c for c in range(NUM_CLIENTS) if client_loaders[c] is not None]
    sampled = random.sample(available_clients, min(CLIENTS_PER_ROUND, len(available_clients)))
    sampled = sorted(sampled)  # deterministic ordering for mask derivation
    client_masked_deltas = {}
    client_mask_sums = {}
    client_sizes = []

    victim_client = sampled[0] if len(sampled) > 0 else None
    captured_grad_for_attack = None
    captured_label = None
    captured_image = None
    model_state_sent_to_client = None

    # Each client computes local delta
    for c in sampled:
        model_state_sent_to_client = {k: v.clone() for k, v in global_state.items()}
        delta, cap_label, cap_image = client_local_update_compute_delta(c, model_state_sent_to_client)
        if delta is None:
            continue

        # For attack simulation: capture victim's uncloaked delta in the same order as model.parameters()
        if c == victim_client:
            # Build a model instance to get parameter names and order
            tmp_model = model_from_state(model_state_sent_to_client)
            param_names = [name for name, _ in tmp_model.named_parameters()]  # ordered
            # state_dict keys include parameters and buffers; we select only parameter keys matching param_names
            # state_dict keys for parameters are exactly the param_names
            captured_grad_for_attack = []
            for pname in param_names:
                if pname in delta:
                    captured_grad_for_attack.append(delta[pname].clone())
                else:
                    # fallback: if key missing, append None
                    captured_grad_for_attack.append(None)
            captured_label = cap_label
            captured_image = cap_image

        masked_delta, mask_sum = client_mask_delta(c, delta, sampled, rnd)
        client_masked_deltas[c] = masked_delta
        client_mask_sums[c] = mask_sum
        size = len(client_loaders[c].dataset) if client_loaders[c] is not None else 0
        client_sizes.append(size)

    if len(client_masked_deltas) == 0:
        print("No client updates this round.")
        continue

    aggregated_delta = server_aggregate_masked(client_masked_deltas, sampled, rnd, dropout_list=None)

    total_samples = sum(client_sizes) if sum(client_sizes) > 0 else len(client_masked_deltas)
    avg_delta = {k: (aggregated_delta[k] / float(total_samples)) for k in aggregated_delta.keys()}

    new_global_state = {}
    for k in global_state.keys():
        new_global_state[k] = (global_state[k] + avg_delta[k]).cpu()
    global_state = new_global_state

    if rnd % PRINT_EVERY == 0 or rnd == ROUNDS - 1:
        acc = evaluate_accuracy(global_state, test_loader)
        print(f"Round {rnd:03d}/{ROUNDS:03d} - Test accuracy: {acc:.4f} - sampled {len(sampled)} clients")

    # Attack evaluation: reconstruct from the captured per-parameter delta (parameters-only, ordered)
    if (rnd % ATTACK_EVERY_N_ROUNDS == 0) and (captured_grad_for_attack is not None):
        try:
            # captured_grad_for_attack is a list of tensors (or None) matching model.parameters() order
            reconstructed = reconstruct_idlg_lbfgs(model_state_sent_to_client, captured_grad_for_attack, captured_label, iters=ATTACK_LBFGS_MAX_ITER)
        except Exception as e:
            print("Reconstruction failed:", e)
            reconstructed = None

        if reconstructed is not None:
            rec_den = denorm(reconstructed)
            true_den = denorm(captured_image.cpu())

            ssim_val = compute_ssim(rec_den, true_den)
            cos_val = compute_cosine_similarity_feature(model_state_sent_to_client, rec_den.unsqueeze(0), true_den.unsqueeze(0))
            success = float(cos_val >= ATTACK_COSINE_THRESHOLD)

            def tensor_to_pil(img_tensor, fname):
                arr = (img_tensor.squeeze(0).permute(1,2,0).numpy() * 255).astype(np.uint8)
                Image.fromarray(arr).save(fname)

            fname_rec = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_rec.png")
            fname_true = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_true.png")
            try:
                tensor_to_pil(rec_den, fname_rec)
                tensor_to_pil(true_den, fname_true)
            except Exception:
                np.save(fname_rec + ".npy", rec_den.numpy())
                np.save(fname_true + ".npy", true_den.numpy())

            log_entry = {
                "round": int(rnd),
                "victim_client": int(victim_client),
                "ssim": float(ssim_val),
                "cosine": float(cos_val),
                "success": float(success),
                "recon_path": fname_rec,
                "true_path": fname_true,
                "label": int(captured_label.item())
            }
            attack_logs.append(log_entry)
            print(f"[Attack-sim] Round {rnd} client {victim_client}: SSIM={ssim_val:.4f}, Cos={cos_val:.4f}, Success={success}")

# -----------------------------
# Final aggregated attack metrics and save logs
# -----------------------------
total_attacks = len(attack_logs)
if total_attacks > 0:
    avg_ssim = sum([e["ssim"] for e in attack_logs]) / total_attacks
    avg_cos = sum([e["cosine"] for e in attack_logs]) / total_attacks
    success_rate = sum([e["success"] for e in attack_logs]) / total_attacks
else:
    avg_ssim = avg_cos = success_rate = 0.0

print("=== Attack summary (simulated single-client signal) ===")
print(f"Attacks run: {total_attacks}")
print(f"Avg SSIM: {avg_ssim:.4f}")
print(f"Avg Cosine: {avg_cos:.4f}")
print(f"Success rate (cos>={ATTACK_COSINE_THRESHOLD}): {success_rate:.3f}")

with open(os.path.join(SAVE_RECON_DIR, "attack_logs_secureagg_fixed.json"), "w") as f:
    json.dump(attack_logs, f, indent=2)

print("Reconstruction images and logs saved to:", SAVE_RECON_DIR)


Found 9164 images across 1680 identities (>=2 images each).
Assigned identities to 50 non-empty clients out of 50 total clients.
Round 000/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 0 client 6: SSIM=0.0060, Cos=0.2803, Success=0.0




[Attack-sim] Round 1 client 4: SSIM=0.0057, Cos=0.3170, Success=0.0




[Attack-sim] Round 2 client 9: SSIM=0.0041, Cos=0.3057, Success=0.0




[Attack-sim] Round 3 client 3: SSIM=0.0047, Cos=0.3486, Success=0.0




[Attack-sim] Round 4 client 2: SSIM=0.0034, Cos=0.3801, Success=0.0




Round 005/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 5 client 2: SSIM=0.0062, Cos=0.2690, Success=0.0




[Attack-sim] Round 6 client 3: SSIM=0.0048, Cos=0.3135, Success=0.0




[Attack-sim] Round 7 client 2: SSIM=0.0055, Cos=0.3151, Success=0.0




[Attack-sim] Round 8 client 10: SSIM=0.0068, Cos=0.2793, Success=0.0




[Attack-sim] Round 9 client 3: SSIM=0.0062, Cos=0.3168, Success=0.0




Round 010/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 10 client 6: SSIM=0.0023, Cos=0.3247, Success=0.0




[Attack-sim] Round 11 client 5: SSIM=0.0026, Cos=0.3594, Success=0.0




[Attack-sim] Round 12 client 5: SSIM=0.0044, Cos=0.2824, Success=0.0




[Attack-sim] Round 13 client 1: SSIM=0.0047, Cos=0.3391, Success=0.0




[Attack-sim] Round 14 client 4: SSIM=0.0042, Cos=0.3796, Success=0.0




Round 015/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 15 client 2: SSIM=0.0054, Cos=0.3388, Success=0.0




[Attack-sim] Round 16 client 2: SSIM=0.0059, Cos=0.3134, Success=0.0




[Attack-sim] Round 17 client 14: SSIM=0.0034, Cos=0.3371, Success=0.0




[Attack-sim] Round 18 client 12: SSIM=0.0049, Cos=0.2764, Success=0.0




[Attack-sim] Round 19 client 8: SSIM=0.0063, Cos=0.2390, Success=0.0




Round 020/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 20 client 3: SSIM=0.0047, Cos=0.3072, Success=0.0




[Attack-sim] Round 21 client 5: SSIM=0.0059, Cos=0.3318, Success=0.0




[Attack-sim] Round 22 client 1: SSIM=0.0045, Cos=0.2919, Success=0.0




[Attack-sim] Round 23 client 1: SSIM=0.0065, Cos=0.2863, Success=0.0




[Attack-sim] Round 24 client 0: SSIM=0.0032, Cos=0.3200, Success=0.0




Round 025/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 25 client 1: SSIM=0.0053, Cos=0.2868, Success=0.0




[Attack-sim] Round 26 client 3: SSIM=0.0041, Cos=0.3018, Success=0.0




[Attack-sim] Round 27 client 1: SSIM=0.0055, Cos=0.3525, Success=0.0




[Attack-sim] Round 28 client 0: SSIM=0.0048, Cos=0.2961, Success=0.0




[Attack-sim] Round 29 client 1: SSIM=0.0046, Cos=0.3537, Success=0.0




Round 030/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 30 client 1: SSIM=0.0059, Cos=0.2910, Success=0.0




[Attack-sim] Round 31 client 2: SSIM=0.0062, Cos=0.3229, Success=0.0




[Attack-sim] Round 32 client 8: SSIM=0.0062, Cos=0.3084, Success=0.0




[Attack-sim] Round 33 client 2: SSIM=0.0025, Cos=0.3207, Success=0.0




[Attack-sim] Round 34 client 2: SSIM=0.0059, Cos=0.2620, Success=0.0




Round 035/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 35 client 3: SSIM=0.0032, Cos=0.2774, Success=0.0




[Attack-sim] Round 36 client 5: SSIM=0.0048, Cos=0.2708, Success=0.0




[Attack-sim] Round 37 client 5: SSIM=0.0035, Cos=0.3265, Success=0.0




[Attack-sim] Round 38 client 5: SSIM=0.0034, Cos=0.2957, Success=0.0




[Attack-sim] Round 39 client 2: SSIM=0.0050, Cos=0.3262, Success=0.0




Round 040/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 40 client 1: SSIM=0.0054, Cos=0.3265, Success=0.0




[Attack-sim] Round 41 client 2: SSIM=0.0037, Cos=0.3052, Success=0.0




[Attack-sim] Round 42 client 2: SSIM=0.0042, Cos=0.3483, Success=0.0




[Attack-sim] Round 43 client 4: SSIM=0.0050, Cos=0.2701, Success=0.0




[Attack-sim] Round 44 client 1: SSIM=0.0045, Cos=0.3737, Success=0.0




Round 045/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 45 client 7: SSIM=0.0043, Cos=0.3454, Success=0.0




[Attack-sim] Round 46 client 4: SSIM=0.0063, Cos=0.3080, Success=0.0




[Attack-sim] Round 47 client 10: SSIM=0.0059, Cos=0.3211, Success=0.0




[Attack-sim] Round 48 client 1: SSIM=0.0051, Cos=0.3265, Success=0.0




[Attack-sim] Round 49 client 4: SSIM=0.0071, Cos=0.2930, Success=0.0




Round 050/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 50 client 4: SSIM=0.0030, Cos=0.3213, Success=0.0




[Attack-sim] Round 51 client 8: SSIM=0.0058, Cos=0.2621, Success=0.0




[Attack-sim] Round 52 client 3: SSIM=0.0059, Cos=0.3142, Success=0.0




[Attack-sim] Round 53 client 5: SSIM=0.0054, Cos=0.3345, Success=0.0




[Attack-sim] Round 54 client 5: SSIM=0.0041, Cos=0.3001, Success=0.0




Round 055/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 55 client 2: SSIM=0.0049, Cos=0.2927, Success=0.0




[Attack-sim] Round 56 client 1: SSIM=0.0051, Cos=0.2992, Success=0.0




[Attack-sim] Round 57 client 1: SSIM=0.0039, Cos=0.3341, Success=0.0




[Attack-sim] Round 58 client 3: SSIM=0.0028, Cos=0.3310, Success=0.0




[Attack-sim] Round 59 client 9: SSIM=0.0038, Cos=0.3522, Success=0.0




Round 060/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 60 client 6: SSIM=0.0057, Cos=0.2907, Success=0.0




[Attack-sim] Round 61 client 3: SSIM=0.0055, Cos=0.3647, Success=0.0




[Attack-sim] Round 62 client 10: SSIM=0.0042, Cos=0.2740, Success=0.0




[Attack-sim] Round 63 client 1: SSIM=0.0043, Cos=0.3294, Success=0.0




[Attack-sim] Round 64 client 3: SSIM=0.0050, Cos=0.3058, Success=0.0




Round 065/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 65 client 0: SSIM=0.0056, Cos=0.2743, Success=0.0




[Attack-sim] Round 66 client 3: SSIM=0.0031, Cos=0.3152, Success=0.0




[Attack-sim] Round 67 client 6: SSIM=0.0044, Cos=0.3081, Success=0.0




[Attack-sim] Round 68 client 9: SSIM=0.0050, Cos=0.3186, Success=0.0




[Attack-sim] Round 69 client 2: SSIM=0.0028, Cos=0.3620, Success=0.0




Round 070/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 70 client 8: SSIM=0.0069, Cos=0.3258, Success=0.0




[Attack-sim] Round 71 client 1: SSIM=0.0059, Cos=0.3027, Success=0.0




[Attack-sim] Round 72 client 1: SSIM=0.0063, Cos=0.2630, Success=0.0




[Attack-sim] Round 73 client 3: SSIM=0.0032, Cos=0.3249, Success=0.0




[Attack-sim] Round 74 client 2: SSIM=0.0045, Cos=0.3057, Success=0.0




Round 075/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 75 client 1: SSIM=0.0047, Cos=0.2817, Success=0.0




[Attack-sim] Round 76 client 3: SSIM=0.0053, Cos=0.3649, Success=0.0




[Attack-sim] Round 77 client 11: SSIM=0.0056, Cos=0.2796, Success=0.0




[Attack-sim] Round 78 client 12: SSIM=0.0064, Cos=0.3581, Success=0.0




[Attack-sim] Round 79 client 13: SSIM=0.0038, Cos=0.2887, Success=0.0




Round 080/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 80 client 7: SSIM=0.0040, Cos=0.2798, Success=0.0




[Attack-sim] Round 81 client 2: SSIM=0.0060, Cos=0.3304, Success=0.0




[Attack-sim] Round 82 client 6: SSIM=0.0053, Cos=0.4162, Success=0.0




[Attack-sim] Round 83 client 3: SSIM=0.0060, Cos=0.3572, Success=0.0




[Attack-sim] Round 84 client 0: SSIM=0.0033, Cos=0.2820, Success=0.0




Round 085/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 85 client 10: SSIM=0.0037, Cos=0.3566, Success=0.0




[Attack-sim] Round 86 client 2: SSIM=0.0065, Cos=0.2722, Success=0.0




[Attack-sim] Round 87 client 0: SSIM=0.0041, Cos=0.2869, Success=0.0




[Attack-sim] Round 88 client 6: SSIM=0.0052, Cos=0.3660, Success=0.0




[Attack-sim] Round 89 client 7: SSIM=0.0051, Cos=0.3589, Success=0.0




Round 090/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 90 client 0: SSIM=0.0053, Cos=0.2675, Success=0.0




[Attack-sim] Round 91 client 5: SSIM=0.0057, Cos=0.3780, Success=0.0




[Attack-sim] Round 92 client 0: SSIM=0.0062, Cos=0.3221, Success=0.0




[Attack-sim] Round 93 client 0: SSIM=0.0053, Cos=0.3174, Success=0.0




[Attack-sim] Round 94 client 3: SSIM=0.0060, Cos=0.2533, Success=0.0




Round 095/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 95 client 8: SSIM=0.0046, Cos=0.2942, Success=0.0




[Attack-sim] Round 96 client 13: SSIM=0.0039, Cos=0.2782, Success=0.0




[Attack-sim] Round 97 client 4: SSIM=0.0055, Cos=0.2642, Success=0.0




[Attack-sim] Round 98 client 3: SSIM=0.0043, Cos=0.3234, Success=0.0




Round 099/100 - Test accuracy: 0.0012 - sampled 10 clients
[Attack-sim] Round 99 client 5: SSIM=0.0050, Cos=0.3287, Success=0.0
=== Attack summary (simulated single-client signal) ===
Attacks run: 100
Avg SSIM: 0.0049
Avg Cosine: 0.3138
Success rate (cos>=0.6): 0.000
Reconstruction images and logs saved to: ./reconstructions_secureagg_fixed


# 5 client

In [4]:
# FedAvg with ResNet-18 + Pairwise Additive Masking Secure Aggregation (simulation)
# Fixes the gradient-size mismatch during reconstruction by ensuring the captured
# per-parameter signal is ordered to match model.parameters() (parameters only,
# excluding buffers like running_mean/running_var).
#
# Usage: adjust DATA_ROOT and FOLDERS to point to your dataset.
# Note: This is a simulation of secure aggregation for research/experiments.
# For production, use a vetted secure aggregation protocol (Bonawitz et al.) and proper key exchange.
#
# Requirements: torch, torchvision, numpy, pillow, scikit-image, tqdm, sklearn

import os
import random
import math
import json
import hashlib
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics.pairwise import cosine_similarity

# -----------------------------
# Experiment hyperparameters
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Federated learning variables
NUM_CLIENTS = 5
CLIENTS_PER_ROUND = 10
ROUNDS = 10
LOCAL_EPOCHS = 1
LOCAL_BATCH_SIZE = 1   # keep small for strong leakage baseline (can be increased)

CLIENT_LR = 0.01
MOMENTUM = 0.9

# Model / data variables
IMAGE_SIZE = 224
INPUT_CHANNELS = 3
NUM_CLASSES = None

# Attack variables (IDLG with L-BFGS)
ATTACK_USE_LBFGS = True
ATTACK_LBFGS_MAX_ITER = 300
ATTACK_LBFGS_LINESEARCH = True
ATTACK_INIT = "random"
ATTACK_TV_WEIGHT = 1e-4
ATTACK_GRAD_MATCH_WEIGHT = 1.0
ATTACK_COSINE_THRESHOLD = 0.6
ATTACK_EVERY_N_ROUNDS = 1

RECON_SHAPE = (1, INPUT_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

PRINT_EVERY = 5
SAVE_RECON_DIR = "./reconstructions_secureagg_fixed"
os.makedirs(SAVE_RECON_DIR, exist_ok=True)

# Determinism
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# -----------------------------
# Dataset utilities
# -----------------------------
DATA_ROOT = "/kaggle/input/lfw-facial-recognition/Face Recognition"
FOLDERS = ["Faces"]

transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

def gather_image_paths(root, folders):
    paths = []
    for f in folders:
        folder = os.path.join(root, f)
        if not os.path.isdir(folder):
            continue
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                paths.append(os.path.join(folder, fname))
    return sorted(paths)

all_image_paths = gather_image_paths(DATA_ROOT, FOLDERS)
if len(all_image_paths) == 0:
    raise RuntimeError(f"No images found in {FOLDERS} under {DATA_ROOT}. Check paths.")

def identity_from_filename(path):
    name = os.path.basename(path)
    base = os.path.splitext(name)[0]
    parts = base.split("_")
    if len(parts) >= 2:
        return "_".join(parts[:-1])
    return base

id_to_paths = defaultdict(list)
for p in all_image_paths:
    id_ = identity_from_filename(p)
    id_to_paths[id_].append(p)

valid_id_to_paths = {k: v for k, v in id_to_paths.items() if len(v) >= 2}
if len(valid_id_to_paths) == 0:
    raise RuntimeError("No identities with >=2 images found. Check filename format or dataset.")

all_valid_paths = []
all_valid_labels = []
unique_ids = sorted(valid_id_to_paths.keys())
label_map = {id_: idx for idx, id_ in enumerate(unique_ids)}
for id_, paths in valid_id_to_paths.items():
    for p in paths:
        all_valid_paths.append(p)
        all_valid_labels.append(label_map[id_])

NUM_CLASSES = len(unique_ids)
print(f"Found {len(all_valid_paths)} images across {NUM_CLASSES} identities (>=2 images each).")

# -----------------------------
# PyTorch Dataset and client partitioning
# -----------------------------
class FaceImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = int(self.labels[idx])
        return img, label, p

full_dataset = FaceImageDataset(all_valid_paths, all_valid_labels, transform=transform)

# Partition identities across clients (non-IID)
client_id_to_indices = {c: [] for c in range(NUM_CLIENTS)}
identities = list(unique_ids)
random.shuffle(identities)
for i, id_ in enumerate(identities):
    client_id = i % NUM_CLIENTS
    for p in valid_id_to_paths[id_]:
        idx = all_valid_paths.index(p)
        client_id_to_indices[client_id].append(idx)

num_nonempty = sum(1 for c in client_id_to_indices if len(client_id_to_indices[c]) > 0)
print(f"Assigned identities to {num_nonempty} non-empty clients out of {NUM_CLIENTS} total clients.")

client_loaders = {}
for c in range(NUM_CLIENTS):
    inds = client_id_to_indices[c]
    if len(inds) == 0:
        client_loaders[c] = None
        continue
    sub_paths = [all_valid_paths[i] for i in inds]
    sub_labels = [all_valid_labels[i] for i in inds]
    ds = FaceImageDataset(sub_paths, sub_labels, transform=transform)
    loader = DataLoader(ds, batch_size=LOCAL_BATCH_SIZE, shuffle=True)
    client_loaders[c] = loader

# Held-out test loader (one image per identity)
test_paths = []
test_labels = []
for id_, paths in valid_id_to_paths.items():
    p = paths[-1]
    test_paths.append(p)
    test_labels.append(label_map[id_])
test_ds = FaceImageDataset(test_paths, test_labels, transform=transform)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

# -----------------------------
# Model: ResNet-18 wrapper
# -----------------------------
class ResNet18Wrapper(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        in_features = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone
        self.classifier = nn.Linear(in_features, num_classes)
    def forward(self, x, return_features=False):
        feats = self.backbone(x)
        logits = self.classifier(feats)
        if return_features:
            return feats, logits
        return logits

global_model = ResNet18Wrapper(NUM_CLASSES, pretrained=True).to(DEVICE)

# Store global_state as float tensors
global_state = {k: v.cpu().to(torch.float32) for k, v in global_model.state_dict().items()}

def model_from_state(state_dict):
    m = ResNet18Wrapper(NUM_CLASSES, pretrained=False).to(DEVICE)
    sd = {k: v.to(DEVICE) for k, v in state_dict.items()}
    m.load_state_dict(sd)
    return m

# -----------------------------
# Utility: ensure image tensors have shape (B,C,H,W)
# -----------------------------
def ensure_batch_image_shape(x):
    if isinstance(x, torch.Tensor):
        if x.dim() == 3:
            return x.unsqueeze(0)
        if x.dim() == 5 and x.size(1) == 1:
            return x.squeeze(1)
        if x.dim() == 5 and x.size(2) == 1:
            return x.squeeze(2)
        return x
    else:
        return x

# -----------------------------
# Mask generation utilities (deterministic PRNG from seed)
# -----------------------------
def seed_from_pair(i, j, round_idx):
    s = f"{min(i,j)}_{max(i,j)}_{round_idx}"
    h = hashlib.sha256(s.encode("utf-8")).digest()
    seed = int.from_bytes(h[:8], byteorder="big", signed=False)
    return seed

def generate_mask_for_state(state_template, seed):
    gen = torch.Generator()
    gen.manual_seed(seed & ((1 << 63) - 1))
    mask = {}
    for k, v in state_template.items():
        m = torch.randn(v.shape, generator=gen, dtype=torch.float32)
        mask[k] = m
    return mask

def add_state_dicts(a, b):
    return {k: (a[k] + b[k]) for k in a.keys()}

def sub_state_dicts(a, b):
    return {k: (a[k] - b[k]) for k in a.keys()}

def zero_state_like(template):
    return {k: torch.zeros_like(v) for k, v in template.items()}

# -----------------------------
# Client local training (no DP) -> produce delta
# -----------------------------
def get_model_copy_for_client(state_dict):
    return model_from_state(state_dict)

def client_local_update_compute_delta(client_id, global_state):
    loader = client_loaders[client_id]
    if loader is None:
        return None, None, None

    local_model = get_model_copy_for_client(global_state)
    local_model.train()
    opt = torch.optim.SGD(local_model.parameters(), lr=CLIENT_LR, momentum=MOMENTUM)
    captured_label = None
    captured_image = None

    loss_fn = nn.CrossEntropyLoss(reduction='mean')

    for epoch in range(LOCAL_EPOCHS):
        for x, y, pth in loader:
            x = ensure_batch_image_shape(x)
            if x.dim() == 5 and x.size(1) == 1:
                x = x.squeeze(1)
            x = x.to(DEVICE)
            if not isinstance(y, torch.Tensor):
                y = torch.tensor(y, dtype=torch.long)
            else:
                y = y.long()
            y = y.to(DEVICE)

            opt.zero_grad()
            logits = local_model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            if captured_label is None:
                if y.dim() == 0:
                    captured_label = y.detach().cpu().clone().unsqueeze(0)
                else:
                    captured_label = y.detach().cpu().clone().view(-1)[0:1]
                captured_image = x.detach().cpu().clone()[0:1]
            opt.step()
            if LOCAL_EPOCHS == 1:
                break

    new_state = {k: v.cpu().to(torch.float32) for k, v in local_model.state_dict().items()}
    delta = {k: new_state[k] - global_state[k] for k in new_state.keys()}
    return delta, captured_label, captured_image

# -----------------------------
# Secure aggregation: client-side masking
# -----------------------------
def client_mask_delta(client_id, delta_state, participants, round_idx):
    masked = {k: delta_state[k].clone() for k in delta_state.keys()}
    mask_sum = zero_state_like(delta_state)
    for j in participants:
        if j == client_id:
            continue
        seed = seed_from_pair(client_id, j, round_idx)
        pair_mask = generate_mask_for_state(delta_state, seed)
        if client_id < j:
            for k in masked.keys():
                masked[k] = masked[k] + pair_mask[k]
                mask_sum[k] = mask_sum[k] + pair_mask[k]
        else:
            for k in masked.keys():
                masked[k] = masked[k] - pair_mask[k]
                mask_sum[k] = mask_sum[k] - pair_mask[k]
    return masked, mask_sum

# -----------------------------
# Server aggregation with mask cancellation (simulation)
# -----------------------------
def server_aggregate_masked(masked_deltas_dict, participants, round_idx, dropout_list=None):
    template = next(iter(masked_deltas_dict.values()))
    sum_masked = zero_state_like(template)
    for cid, md in masked_deltas_dict.items():
        for k in sum_masked.keys():
            sum_masked[k] = sum_masked[k] + md[k]

    present_clients = list(masked_deltas_dict.keys())
    present_mask_sums = zero_state_like(template)
    for i in present_clients:
        mask_sum_i = zero_state_like(template)
        for j in participants:
            if j == i:
                continue
            seed = seed_from_pair(i, j, round_idx)
            pair_mask = generate_mask_for_state(template, seed)
            if i < j:
                for k in mask_sum_i.keys():
                    mask_sum_i[k] = mask_sum_i[k] + pair_mask[k]
            else:
                for k in mask_sum_i.keys():
                    mask_sum_i[k] = mask_sum_i[k] - pair_mask[k]
        for k in present_mask_sums.keys():
            present_mask_sums[k] = present_mask_sums[k] + mask_sum_i[k]

    sum_present_deltas = {k: (sum_masked[k] - present_mask_sums[k]) for k in sum_masked.keys()}
    return sum_present_deltas

# -----------------------------
# Reconstruction utilities (unchanged)
# -----------------------------
def flatten_grads(grad_list):
    vecs = []
    for g in grad_list:
        if g is None:
            continue
        if not g.is_cuda:
            g = g.to(DEVICE)
        vecs.append(g.view(-1))
    if len(vecs) == 0:
        return torch.tensor([], device=DEVICE)
    return torch.cat(vecs)

def tv_loss(img):
    return torch.mean(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])) + \
           torch.mean(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]))

def assert_shapes_for_recon(x_init, observed_label, observed_grads):
    if not isinstance(x_init, torch.Tensor):
        raise RuntimeError("x_init must be a torch.Tensor")
    if x_init.dim() != 4 or x_init.size(0) != 1:
        raise RuntimeError(f"Reconstruction input must be (1,C,H,W), got {x_init.shape}")
    if observed_label is None:
        raise RuntimeError("Observed label is None")
    if observed_label.dim() != 1 or observed_label.size(0) != 1:
        raise RuntimeError(f"Observed label must be shape (1,), got {observed_label.shape}")
    flat = flatten_grads(observed_grads)
    if flat.numel() == 0:
        raise RuntimeError("Observed grads flattened to zero length; check grad capture")

def reconstruct_idlg_lbfgs(model_state_for_attack, observed_grads_cpu, observed_label_cpu, iters=ATTACK_LBFGS_MAX_ITER):
    if observed_label_cpu is None:
        raise ValueError("No observed label provided for reconstruction")
    if observed_label_cpu.dim() == 0:
        observed_label_cpu = observed_label_cpu.unsqueeze(0).long()
    else:
        observed_label_cpu = observed_label_cpu.view(-1)[0:1].long()

    model = model_from_state(model_state_for_attack)
    model.eval()

    obs = [g.to(DEVICE) if g is not None else None for g in observed_grads_cpu]
    obs_flat = flatten_grads(obs).detach()

    if ATTACK_INIT == "random":
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)
    else:
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)

    try:
        assert_shapes_for_recon(x_hat, observed_label_cpu, observed_grads_cpu)
    except Exception as e:
        raise RuntimeError(f"Pre-reconstruction shape check failed: {e}")

    optimizer = torch.optim.LBFGS([x_hat], max_iter=iters, line_search_fn='strong_wolfe' if ATTACK_LBFGS_LINESEARCH else None)

    def closure():
        optimizer.zero_grad()
        x_in = ensure_batch_image_shape(x_hat)
        logits = model(x_in)
        loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
        grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
        grads_hat_flat = flatten_grads(grads_hat)
        if obs_flat.numel() == 0 or grads_hat_flat.numel() == 0:
            loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            loss.backward()
            return loss
        loss_match = F.mse_loss(grads_hat_flat, obs_flat)
        loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
        loss.backward()
        return loss

    try:
        optimizer.step(closure)
    except Exception as e:
        print("LBFGS failed, falling back to Adam:", e)
        opt_adam = torch.optim.Adam([x_hat], lr=0.05)
        for _ in range(500):
            opt_adam.zero_grad()
            x_in = ensure_batch_image_shape(x_hat)
            logits = model(x_in)
            loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
            grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
            grads_hat_flat = flatten_grads(grads_hat)
            loss_match = F.mse_loss(grads_hat_flat, obs_flat)
            loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
            loss.backward()
            opt_adam.step()
            with torch.no_grad():
                x_hat.clamp_(-1, 1)

    with torch.no_grad():
        x_rec = x_hat.clamp_(-1, 1).detach().cpu()
    return x_rec

# -----------------------------
# Evaluation metrics (unchanged)
# -----------------------------
def denorm(img_tensor):
    img = img_tensor.clone()
    img = (img * 0.5) + 0.5
    return img.clamp(0,1)

def compute_ssim(img1, img2):
    a = img1.squeeze(0).permute(1,2,0).numpy()
    b = img2.squeeze(0).permute(1,2,0).numpy()
    s = 0.0
    for ch in range(a.shape[2]):
        s += ssim(a[:,:,ch], b[:,:,ch], data_range=1.0)
    return s / a.shape[2]

def compute_cosine_similarity_feature(model_state, img1, img2):
    model = model_from_state(model_state)
    model.eval()
    with torch.no_grad():
        f1, _ = model(ensure_batch_image_shape(img1).to(DEVICE), return_features=True)
        f2, _ = model(ensure_batch_image_shape(img2).to(DEVICE), return_features=True)
    f1 = f1.cpu().numpy()
    f2 = f2.cpu().numpy()
    return float(cosine_similarity(f1, f2)[0,0])

def evaluate_accuracy(model_state, dataloader):
    model = model_from_state(model_state)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y, _ in dataloader:
            x = ensure_batch_image_shape(x).to(DEVICE)
            if not isinstance(y, torch.Tensor):
                y = torch.tensor(y, dtype=torch.long)
            else:
                y = y.long()
            y = y.to(DEVICE)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total if total > 0 else 0.0

# -----------------------------
# Training loop with secure aggregation (fixed capture ordering)
# -----------------------------
attack_logs = []
global_state = {k: v.cpu().to(torch.float32) for k, v in global_model.state_dict().items()}

for rnd in range(ROUNDS):
    available_clients = [c for c in range(NUM_CLIENTS) if client_loaders[c] is not None]
    sampled = random.sample(available_clients, min(CLIENTS_PER_ROUND, len(available_clients)))
    sampled = sorted(sampled)  # deterministic ordering for mask derivation
    client_masked_deltas = {}
    client_mask_sums = {}
    client_sizes = []

    victim_client = sampled[0] if len(sampled) > 0 else None
    captured_grad_for_attack = None
    captured_label = None
    captured_image = None
    model_state_sent_to_client = None

    # Each client computes local delta
    for c in sampled:
        model_state_sent_to_client = {k: v.clone() for k, v in global_state.items()}
        delta, cap_label, cap_image = client_local_update_compute_delta(c, model_state_sent_to_client)
        if delta is None:
            continue

        # For attack simulation: capture victim's uncloaked delta in the same order as model.parameters()
        if c == victim_client:
            # Build a model instance to get parameter names and order
            tmp_model = model_from_state(model_state_sent_to_client)
            param_names = [name for name, _ in tmp_model.named_parameters()]  # ordered
            # state_dict keys include parameters and buffers; we select only parameter keys matching param_names
            # state_dict keys for parameters are exactly the param_names
            captured_grad_for_attack = []
            for pname in param_names:
                if pname in delta:
                    captured_grad_for_attack.append(delta[pname].clone())
                else:
                    # fallback: if key missing, append None
                    captured_grad_for_attack.append(None)
            captured_label = cap_label
            captured_image = cap_image

        masked_delta, mask_sum = client_mask_delta(c, delta, sampled, rnd)
        client_masked_deltas[c] = masked_delta
        client_mask_sums[c] = mask_sum
        size = len(client_loaders[c].dataset) if client_loaders[c] is not None else 0
        client_sizes.append(size)

    if len(client_masked_deltas) == 0:
        print("No client updates this round.")
        continue

    aggregated_delta = server_aggregate_masked(client_masked_deltas, sampled, rnd, dropout_list=None)

    total_samples = sum(client_sizes) if sum(client_sizes) > 0 else len(client_masked_deltas)
    avg_delta = {k: (aggregated_delta[k] / float(total_samples)) for k in aggregated_delta.keys()}

    new_global_state = {}
    for k in global_state.keys():
        new_global_state[k] = (global_state[k] + avg_delta[k]).cpu()
    global_state = new_global_state

    if rnd % PRINT_EVERY == 0 or rnd == ROUNDS - 1:
        acc = evaluate_accuracy(global_state, test_loader)
        print(f"Round {rnd:03d}/{ROUNDS:03d} - Test accuracy: {acc:.4f} - sampled {len(sampled)} clients")

    # Attack evaluation: reconstruct from the captured per-parameter delta (parameters-only, ordered)
    if (rnd % ATTACK_EVERY_N_ROUNDS == 0) and (captured_grad_for_attack is not None):
        try:
            # captured_grad_for_attack is a list of tensors (or None) matching model.parameters() order
            reconstructed = reconstruct_idlg_lbfgs(model_state_sent_to_client, captured_grad_for_attack, captured_label, iters=ATTACK_LBFGS_MAX_ITER)
        except Exception as e:
            print("Reconstruction failed:", e)
            reconstructed = None

        if reconstructed is not None:
            rec_den = denorm(reconstructed)
            true_den = denorm(captured_image.cpu())

            ssim_val = compute_ssim(rec_den, true_den)
            cos_val = compute_cosine_similarity_feature(model_state_sent_to_client, rec_den.unsqueeze(0), true_den.unsqueeze(0))
            success = float(cos_val >= ATTACK_COSINE_THRESHOLD)

            def tensor_to_pil(img_tensor, fname):
                arr = (img_tensor.squeeze(0).permute(1,2,0).numpy() * 255).astype(np.uint8)
                Image.fromarray(arr).save(fname)

            fname_rec = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_rec.png")
            fname_true = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_true.png")
            try:
                tensor_to_pil(rec_den, fname_rec)
                tensor_to_pil(true_den, fname_true)
            except Exception:
                np.save(fname_rec + ".npy", rec_den.numpy())
                np.save(fname_true + ".npy", true_den.numpy())

            log_entry = {
                "round": int(rnd),
                "victim_client": int(victim_client),
                "ssim": float(ssim_val),
                "cosine": float(cos_val),
                "success": float(success),
                "recon_path": fname_rec,
                "true_path": fname_true,
                "label": int(captured_label.item())
            }
            attack_logs.append(log_entry)
            print(f"[Attack-sim] Round {rnd} client {victim_client}: SSIM={ssim_val:.4f}, Cos={cos_val:.4f}, Success={success}")

# -----------------------------
# Final aggregated attack metrics and save logs
# -----------------------------
total_attacks = len(attack_logs)
if total_attacks > 0:
    avg_ssim = sum([e["ssim"] for e in attack_logs]) / total_attacks
    avg_cos = sum([e["cosine"] for e in attack_logs]) / total_attacks
    success_rate = sum([e["success"] for e in attack_logs]) / total_attacks
else:
    avg_ssim = avg_cos = success_rate = 0.0

print("=== Attack summary (simulated single-client signal) ===")
print(f"Attacks run: {total_attacks}")
print(f"Avg SSIM: {avg_ssim:.4f}")
print(f"Avg Cosine: {avg_cos:.4f}")
print(f"Success rate (cos>={ATTACK_COSINE_THRESHOLD}): {success_rate:.3f}")

with open(os.path.join(SAVE_RECON_DIR, "attack_logs_secureagg_fixed.json"), "w") as f:
    json.dump(attack_logs, f, indent=2)

print("Reconstruction images and logs saved to:", SAVE_RECON_DIR)


Found 9164 images across 1680 identities (>=2 images each).
Assigned identities to 5 non-empty clients out of 5 total clients.




Round 000/010 - Test accuracy: 0.0012 - sampled 5 clients
[Attack-sim] Round 0 client 0: SSIM=0.0048, Cos=0.3067, Success=0.0




[Attack-sim] Round 1 client 0: SSIM=0.0036, Cos=0.3165, Success=0.0




[Attack-sim] Round 2 client 0: SSIM=0.0052, Cos=0.3191, Success=0.0




[Attack-sim] Round 3 client 0: SSIM=0.0050, Cos=0.3392, Success=0.0




[Attack-sim] Round 4 client 0: SSIM=0.0035, Cos=0.2987, Success=0.0




Round 005/010 - Test accuracy: 0.0012 - sampled 5 clients
[Attack-sim] Round 5 client 0: SSIM=0.0059, Cos=0.2528, Success=0.0




[Attack-sim] Round 6 client 0: SSIM=0.0055, Cos=0.3249, Success=0.0




[Attack-sim] Round 7 client 0: SSIM=0.0041, Cos=0.2952, Success=0.0




[Attack-sim] Round 8 client 0: SSIM=0.0055, Cos=0.2845, Success=0.0




Round 009/010 - Test accuracy: 0.0012 - sampled 5 clients
[Attack-sim] Round 9 client 0: SSIM=0.0037, Cos=0.3205, Success=0.0
=== Attack summary (simulated single-client signal) ===
Attacks run: 10
Avg SSIM: 0.0047
Avg Cosine: 0.3058
Success rate (cos>=0.6): 0.000
Reconstruction images and logs saved to: ./reconstructions_secureagg_fixed
