In [2]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("quadeer15sh/lfw-facial-recognition")

print("Path to dataset files:", path)

Mounting files to /kaggle/input/lfw-facial-recognition...
Path to dataset files: /kaggle/input/lfw-facial-recognition


# FL with No Protection (Fed Avg) with 50 Client

In [1]:
# Single-cell runnable script: FedAvg (no protection) with LeNet on a local face dataset
# and IDLG-style reconstruction attack during training using L-BFGS.
#
# This version includes fixes for:
# - Ensuring captured label/image are single-example (shape (1,) and (1,C,H,W))
# - Defensive shape checks to avoid batch-size mismatches during reconstruction
# - Robust handling of unexpected batch shapes from DataLoader
#
# Requirements: torch, torchvision, numpy, pillow, scikit-image, tqdm, sklearn

import os
import random
import math
import json
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics.pairwise import cosine_similarity

# -----------------------------
# Experiment hyperparameters
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Federated learning variables
NUM_CLIENTS = 50                 # total clients
CLIENTS_PER_ROUND = 10           # sampled clients per round
ROUNDS = 100                     # total federated rounds
LOCAL_EPOCHS = 1                 # local epochs per client (E)
LOCAL_BATCH_SIZE = 1             # batch size on client (B) -> 1 for strong leakage
CLIENT_LR = 0.01                 # client optimizer lr
MOMENTUM = 0.9

# Model / data variables
IMAGE_SIZE = 64                  # resize images to 64x64
INPUT_CHANNELS = 3
NUM_CLASSES = None               # set after dataset scan

# Attack variables (IDLG with L-BFGS)
ATTACK_USE_LBFGS = True
ATTACK_LBFGS_MAX_ITER = 300
ATTACK_LBFGS_LINESEARCH = True
ATTACK_INIT = "random"           # "random" or "mean_face"
ATTACK_TV_WEIGHT = 1e-4
ATTACK_GRAD_MATCH_WEIGHT = 1.0
ATTACK_COSINE_THRESHOLD = 0.8    # threshold for attack success (cosine similarity)
ATTACK_EVERY_N_ROUNDS = 5        # attack frequency

# Reconstruction image shape
RECON_SHAPE = (1, INPUT_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

# Logging / misc
PRINT_EVERY = 5
SAVE_RECON_DIR = "./reconstructions"
os.makedirs(SAVE_RECON_DIR, exist_ok=True)

# Determinism
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# -----------------------------
# Dataset utilities
# -----------------------------
# Set your dataset path here (example: Kaggle path provided earlier)
DATA_ROOT = "/kaggle/input/lfw-facial-recognition/Face Recognition"
# Choose which subfolders to include
FOLDERS = ["Faces"]  # or ["detected faces", "Faces"]

# Image transforms (normalize to [-1,1] as used in reconstruction)
transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),                       # [0,1]
    T.Normalize(mean=[0.5]*3, std=[0.5]*3)  # -> [-1,1]
])

# Helper: gather all image file paths from the chosen folders
def gather_image_paths(root, folders):
    paths = []
    for f in folders:
        folder = os.path.join(root, f)
        if not os.path.isdir(folder):
            continue
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                paths.append(os.path.join(folder, fname))
    return sorted(paths)

all_image_paths = gather_image_paths(DATA_ROOT, FOLDERS)
if len(all_image_paths) == 0:
    raise RuntimeError(f"No images found in {FOLDERS} under {DATA_ROOT}. Check paths.")

# Identity extraction: everything before the last underscore (AJ_Cook_0001 -> AJ_Cook)
def identity_from_filename(path):
    name = os.path.basename(path)
    base = os.path.splitext(name)[0]
    parts = base.split("_")
    if len(parts) >= 2:
        return "_".join(parts[:-1])
    return base

# Build mapping identity -> list of image paths
id_to_paths = defaultdict(list)
for p in all_image_paths:
    id_ = identity_from_filename(p)
    id_to_paths[id_].append(p)

# Filter identities with at least 2 images (so we can have train/test per identity)
valid_id_to_paths = {k: v for k, v in id_to_paths.items() if len(v) >= 2}
if len(valid_id_to_paths) == 0:
    raise RuntimeError("No identities with >=2 images found. Check filename format or dataset.")

# Build a flat dataset list and label mapping
all_valid_paths = []
all_valid_labels = []
unique_ids = sorted(valid_id_to_paths.keys())
label_map = {id_: idx for idx, id_ in enumerate(unique_ids)}
for id_, paths in valid_id_to_paths.items():
    for p in paths:
        all_valid_paths.append(p)
        all_valid_labels.append(label_map[id_])

NUM_CLASSES = len(unique_ids)
print(f"Found {len(all_valid_paths)} images across {NUM_CLASSES} identities (>=2 images each).")

# -----------------------------
# PyTorch Dataset and client partitioning
# -----------------------------
class FaceImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert("RGB")
        if self.transform:
            img = self.transform(img)  # should be (C,H,W)
        label = self.labels[idx]
        return img, label, p  # return path for debugging

# Create a dataset object for convenience (we will create per-client subsets)
full_dataset = FaceImageDataset(all_valid_paths, all_valid_labels, transform=transform)

# Partition identities across clients (non-IID: each client gets images of a few identities)
client_id_to_indices = {c: [] for c in range(NUM_CLIENTS)}
identities = list(unique_ids)
random.shuffle(identities)
# Round-robin assign identities to clients
for i, id_ in enumerate(identities):
    client_id = i % NUM_CLIENTS
    for p in valid_id_to_paths[id_]:
        idx = all_valid_paths.index(p)
        client_id_to_indices[client_id].append(idx)

num_nonempty = sum(1 for c in client_id_to_indices if len(client_id_to_indices[c]) > 0)
print(f"Assigned identities to {num_nonempty} non-empty clients out of {NUM_CLIENTS} total clients.")

# Build client dataloaders
client_loaders = {}
for c in range(NUM_CLIENTS):
    inds = client_id_to_indices[c]
    if len(inds) == 0:
        client_loaders[c] = None
        continue
    sub_paths = [all_valid_paths[i] for i in inds]
    sub_labels = [all_valid_labels[i] for i in inds]
    ds = FaceImageDataset(sub_paths, sub_labels, transform=transform)
    loader = DataLoader(ds, batch_size=LOCAL_BATCH_SIZE, shuffle=True)
    client_loaders[c] = loader

# Build a small held-out test loader from each identity (take one image per identity)
test_paths = []
test_labels = []
for id_, paths in valid_id_to_paths.items():
    p = paths[-1]
    test_paths.append(p)
    test_labels.append(label_map[id_])
test_ds = FaceImageDataset(test_paths, test_labels, transform=transform)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

# -----------------------------
# Model: small LeNet-like CNN
# -----------------------------
class LeNetSmall(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(INPUT_CHANNELS, 6, kernel_size=5)
        self.pool = nn.AvgPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        with torch.no_grad():
            dummy = torch.zeros(1, INPUT_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
            x = self.pool(F.relu(self.conv1(dummy)))
            x = self.pool(F.relu(self.conv2(x)))
            flat = x.view(1, -1).shape[1]
        self.fc1 = nn.Linear(flat, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x, return_features=False):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        feat = F.relu(self.fc1(x))
        feat = F.relu(self.fc2(feat))
        logits = self.fc3(feat)
        if return_features:
            return feat, logits
        return logits

# Instantiate global model
global_model = LeNetSmall(NUM_CLASSES).to(DEVICE)
global_state = {k: v.cpu() for k, v in global_model.state_dict().items()}

# Helper to create model from state dict
def model_from_state(state_dict):
    m = LeNetSmall(NUM_CLASSES).to(DEVICE)
    sd = {k: v.to(DEVICE) for k, v in state_dict.items()}
    m.load_state_dict(sd)
    return m

# -----------------------------
# Utility: ensure image tensors have shape (B,C,H,W)
# -----------------------------
def ensure_batch_image_shape(x):
    # x may be (C,H,W), (1,C,H,W), (B,C,H,W), or (B,1,C,H,W) etc.
    # Return a tensor with shape (B,C,H,W)
    if isinstance(x, torch.Tensor):
        if x.dim() == 3:
            return x.unsqueeze(0)
        if x.dim() == 5 and x.size(1) == 1:
            # shape (B,1,C,H,W) -> squeeze dim1
            return x.squeeze(1)
        if x.dim() == 5 and x.size(2) == 1:
            # shape (B,C,1,H,W) unlikely, try to reshape
            return x.squeeze(2)
        return x
    else:
        return x

# -----------------------------
# FedAvg client update (single-step) with gradient capture
# -----------------------------
def get_model_copy_for_client(state_dict):
    return model_from_state(state_dict)

def client_update(client_id, global_state, return_grad_snapshot=False):
    loader = client_loaders[client_id]
    if loader is None:
        return None, None, None, None
    local_model = get_model_copy_for_client(global_state)
    opt = torch.optim.SGD(local_model.parameters(), lr=CLIENT_LR, momentum=MOMENTUM)
    local_model.train()
    grad_snapshot = None
    captured_label = None
    captured_image = None
    for epoch in range(LOCAL_EPOCHS):
        for x, y, pth in loader:
            # x: (B,C,H,W) normally; guard against extra dims
            x = ensure_batch_image_shape(x)
            # Defensive: if DataLoader produced shape (B,1,C,H,W) or similar, squeeze
            if x.dim() == 5 and x.size(1) == 1:
                x = x.squeeze(1)
            x = x.to(DEVICE)
            y = torch.tensor(y) if not isinstance(y, torch.Tensor) else y
            y = y.to(DEVICE)
            opt.zero_grad()
            logits = local_model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            # Capture per-parameter gradients BEFORE optimizer.step (this is what attacker sees)
            if return_grad_snapshot and grad_snapshot is None:
                # capture per-parameter gradients (CPU)
                grad_snapshot = [p.grad.detach().cpu().clone() if p.grad is not None else None for p in local_model.parameters()]

                # Ensure we capture a single example (first in the batch)
                # y may be shape (B,) or (B,1); take first element and keep batch dim (1,)
                if y.dim() == 0:
                    captured_label = y.detach().cpu().clone().unsqueeze(0)   # scalar -> (1,)
                else:
                    captured_label = y.detach().cpu().clone().view(-1)[0:1]  # (1,)

                # captured_image: ensure shape (1,C,H,W)
                captured_image = x.detach().cpu().clone()[0:1]               # keep first example only

                # Debug print (optional)
                # print(f"[DEBUG] client {client_id} captured shapes: x {x.shape}, y {y.shape}")

            opt.step()
            # For strong leakage baseline, break after first batch
            if LOCAL_EPOCHS == 1:
                break
    # compute delta = local_params - global_params (on CPU)
    new_state = {k: v.cpu() for k, v in local_model.state_dict().items()}
    delta = {k: new_state[k] - global_state[k] for k in new_state.keys()}
    return delta, grad_snapshot, captured_label, captured_image

# -----------------------------
# Server aggregation (FedAvg)
# -----------------------------
def server_aggregate(global_state, client_deltas, client_sizes):
    total = sum(client_sizes) if sum(client_sizes) > 0 else 1
    agg = {k: torch.zeros_like(v) for k, v in global_state.items()}
    for delta, size in zip(client_deltas, client_sizes):
        if delta is None:
            continue
        weight = size / total
        for k in agg.keys():
            agg[k] += delta[k] * weight
    new_state = {k: global_state[k] + agg[k] for k in global_state.keys()}
    return new_state

# -----------------------------
# Reconstruction utilities (IDLG with L-BFGS)
# -----------------------------
def flatten_grads(grad_list):
    vecs = []
    for g in grad_list:
        if g is None:
            continue
        # ensure on DEVICE
        if not g.is_cuda:
            g = g.to(DEVICE)
        vecs.append(g.view(-1))
    if len(vecs) == 0:
        return torch.tensor([], device=DEVICE)
    return torch.cat(vecs)

def tv_loss(img):
    return torch.mean(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])) + \
           torch.mean(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]))

def assert_shapes_for_recon(x_init, observed_label, observed_grads):
    # x_init should be (1,C,H,W)
    if not isinstance(x_init, torch.Tensor):
        raise RuntimeError("x_init must be a torch.Tensor")
    if x_init.dim() != 4 or x_init.size(0) != 1:
        raise RuntimeError(f"Reconstruction input must be (1,C,H,W), got {x_init.shape}")
    # observed_label should be (1,)
    if observed_label is None:
        raise RuntimeError("Observed label is None")
    if observed_label.dim() != 1 or observed_label.size(0) != 1:
        raise RuntimeError(f"Observed label must be shape (1,), got {observed_label.shape}")
    # grads: ensure flattened length > 0
    flat = flatten_grads(observed_grads)
    if flat.numel() == 0:
        raise RuntimeError("Observed grads flattened to zero length; check grad capture")

def reconstruct_idlg_lbfgs(model_state_for_attack, observed_grads_cpu, observed_label_cpu, iters=ATTACK_LBFGS_MAX_ITER):
    # Ensure observed_label_cpu is shape (1,)
    if observed_label_cpu is None:
        raise ValueError("No observed label provided for reconstruction")
    if observed_label_cpu.dim() == 0:
        observed_label_cpu = observed_label_cpu.unsqueeze(0)
    else:
        observed_label_cpu = observed_label_cpu.view(-1)[0:1]

    # Build a model instance loaded with the pre-update state (on DEVICE)
    model = model_from_state(model_state_for_attack)
    model.eval()

    # Move observed grads to DEVICE and flatten
    obs = [g.to(DEVICE) if g is not None else None for g in observed_grads_cpu]
    obs_flat = flatten_grads(obs).detach()

    # Initialize x_hat in the same normalized space as training ([-1,1])
    if ATTACK_INIT == "random":
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)
    else:
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)

    # Defensive shape assertion
    try:
        assert_shapes_for_recon(x_hat, observed_label_cpu, observed_grads_cpu)
    except Exception as e:
        raise RuntimeError(f"Pre-reconstruction shape check failed: {e}")

    # Use LBFGS optimizer
    optimizer = torch.optim.LBFGS([x_hat], max_iter=iters, line_search_fn='strong_wolfe' if ATTACK_LBFGS_LINESEARCH else None)

    def closure():
        optimizer.zero_grad()
        x_in = ensure_batch_image_shape(x_hat)
        logits = model(x_in)
        loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
        grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
        grads_hat_flat = flatten_grads(grads_hat)
        if obs_flat.numel() == 0 or grads_hat_flat.numel() == 0:
            loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            loss.backward()
            return loss
        loss_match = F.mse_loss(grads_hat_flat, obs_flat)
        loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
        loss.backward()
        return loss

    try:
        optimizer.step(closure)
    except Exception as e:
        # LBFGS sometimes raises on line search; fallback to Adam for a while
        print("LBFGS failed, falling back to Adam:", e)
        opt_adam = torch.optim.Adam([x_hat], lr=0.05)
        for _ in range(500):
            opt_adam.zero_grad()
            x_in = ensure_batch_image_shape(x_hat)
            logits = model(x_in)
            loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
            grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
            grads_hat_flat = flatten_grads(grads_hat)
            loss_match = F.mse_loss(grads_hat_flat, obs_flat)
            loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
            loss.backward()
            opt_adam.step()
            with torch.no_grad():
                x_hat.clamp_(-1, 1)

    with torch.no_grad():
        x_rec = x_hat.clamp_(-1, 1).detach().cpu()
    return x_rec

# -----------------------------
# Evaluation metrics
# -----------------------------
def denorm(img_tensor):
    img = img_tensor.clone()
    img = (img * 0.5) + 0.5
    return img.clamp(0,1)

def compute_ssim(img1, img2):
    a = img1.squeeze(0).permute(1,2,0).numpy()
    b = img2.squeeze(0).permute(1,2,0).numpy()
    s = 0.0
    for ch in range(a.shape[2]):
        s += ssim(a[:,:,ch], b[:,:,ch], data_range=1.0)
    return s / a.shape[2]

def compute_cosine_similarity_feature(model_state, img1, img2):
    model = model_from_state(model_state)
    model.eval()
    with torch.no_grad():
        f1, _ = model(ensure_batch_image_shape(img1).to(DEVICE), return_features=True)
        f2, _ = model(ensure_batch_image_shape(img2).to(DEVICE), return_features=True)
    f1 = f1.cpu().numpy()
    f2 = f2.cpu().numpy()
    return float(cosine_similarity(f1, f2)[0,0])

def evaluate_accuracy(model_state, dataloader):
    model = model_from_state(model_state)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y, _ in dataloader:
            x = ensure_batch_image_shape(x).to(DEVICE)
            y = y.to(DEVICE)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total if total > 0 else 0.0

# -----------------------------
# Training loop with attack during training
# -----------------------------
attack_logs = []
global_state = {k: v.cpu() for k, v in global_model.state_dict().items()}

for rnd in range(ROUNDS):
    available_clients = [c for c in range(NUM_CLIENTS) if client_loaders[c] is not None]
    sampled = random.sample(available_clients, min(CLIENTS_PER_ROUND, len(available_clients)))
    client_deltas = []
    client_sizes = []

    victim_client = sampled[0] if len(sampled) > 0 else None
    captured_grad = None
    captured_label = None
    captured_image = None
    model_state_sent_to_client = None

    for c in sampled:
        model_state_sent_to_client = {k: v.clone() for k, v in global_state.items()}
        if c == victim_client:
            delta, grad_snapshot, cap_label, cap_image = client_update(c, model_state_sent_to_client, return_grad_snapshot=True)
            captured_grad = grad_snapshot
            captured_label = cap_label
            captured_image = cap_image
        else:
            delta, _, _, _ = client_update(c, model_state_sent_to_client, return_grad_snapshot=False)
        client_deltas.append(delta)
        size = len(client_loaders[c].dataset) if client_loaders[c] is not None else 0
        client_sizes.append(size)

    global_state = server_aggregate(global_state, client_deltas, client_sizes)

    if rnd % PRINT_EVERY == 0 or rnd == ROUNDS - 1:
        acc = evaluate_accuracy(global_state, test_loader)
        print(f"Round {rnd:03d}/{ROUNDS:03d} - Test accuracy: {acc:.4f} - sampled {len(sampled)} clients")

    if (rnd % ATTACK_EVERY_N_ROUNDS == 0) and (captured_grad is not None):
        # Ensure captured_label and captured_image are present and shaped correctly
        try:
            if captured_label is None or captured_image is None:
                raise RuntimeError("Captured label/image missing for attack")
            # Force shapes: label (1,), image (1,C,H,W)
            if captured_label.dim() == 0:
                captured_label = captured_label.unsqueeze(0)
            else:
                captured_label = captured_label.view(-1)[0:1]
            if captured_image.dim() == 3:
                captured_image = captured_image.unsqueeze(0)
            elif captured_image.dim() == 5 and captured_image.size(1) == 1:
                captured_image = captured_image.squeeze(1)
            captured_image = captured_image[0:1]  # ensure single example
        except Exception as e:
            print("Captured data shape error, skipping attack this round:", e)
            continue

        try:
            reconstructed = reconstruct_idlg_lbfgs(model_state_sent_to_client, captured_grad, captured_label, iters=ATTACK_LBFGS_MAX_ITER)
        except Exception as e:
            print("Reconstruction failed:", e)
            reconstructed = None

        if reconstructed is not None:
            rec_den = denorm(reconstructed)
            true_den = denorm(captured_image.cpu())

            ssim_val = compute_ssim(rec_den, true_den)
            cos_val = compute_cosine_similarity_feature(model_state_sent_to_client, rec_den.unsqueeze(0), true_den.unsqueeze(0))
            success = float(cos_val >= ATTACK_COSINE_THRESHOLD)

            def tensor_to_pil(img_tensor, fname):
                arr = (img_tensor.squeeze(0).permute(1,2,0).numpy() * 255).astype(np.uint8)
                Image.fromarray(arr).save(fname)

            fname_rec = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_rec.png")
            fname_true = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_true.png")
            try:
                tensor_to_pil(rec_den, fname_rec)
                tensor_to_pil(true_den, fname_true)
            except Exception:
                np.save(fname_rec + ".npy", rec_den.numpy())
                np.save(fname_true + ".npy", true_den.numpy())

            log_entry = {
                "round": int(rnd),
                "victim_client": int(victim_client),
                "ssim": float(ssim_val),
                "cosine": float(cos_val),
                "success": float(success),
                "recon_path": fname_rec,
                "true_path": fname_true,
                "label": int(captured_label.item())
            }
            attack_logs.append(log_entry)
            print(f"[Attack] Round {rnd} client {victim_client}: SSIM={ssim_val:.4f}, Cos={cos_val:.4f}, Success={success}")

# -----------------------------
# Final aggregated attack metrics and save logs
# -----------------------------
total_attacks = len(attack_logs)
if total_attacks > 0:
    avg_ssim = sum([e["ssim"] for e in attack_logs]) / total_attacks
    avg_cos = sum([e["cosine"] for e in attack_logs]) / total_attacks
    success_rate = sum([e["success"] for e in attack_logs]) / total_attacks
else:
    avg_ssim = avg_cos = success_rate = 0.0

print("=== Attack summary ===")
print(f"Attacks run: {total_attacks}")
print(f"Avg SSIM: {avg_ssim:.4f}")
print(f"Avg Cosine: {avg_cos:.4f}")
print(f"Success rate (cos>={ATTACK_COSINE_THRESHOLD}): {success_rate:.3f}")

with open(os.path.join(SAVE_RECON_DIR, "attack_logs.json"), "w") as f:
    json.dump(attack_logs, f, indent=2)

print("Reconstruction images and logs saved to:", SAVE_RECON_DIR)


Found 9164 images across 1680 identities (>=2 images each).
Assigned identities to 50 non-empty clients out of 50 total clients.
Round 000/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 0 client 46: SSIM=0.0032, Cos=0.9888, Success=1.0
[Attack] Round 1 client 9: SSIM=0.0098, Cos=0.9913, Success=1.0
[Attack] Round 2 client 28: SSIM=0.0048, Cos=0.9900, Success=1.0
[Attack] Round 3 client 28: SSIM=0.0028, Cos=0.9952, Success=1.0
[Attack] Round 4 client 32: SSIM=0.0089, Cos=0.9968, Success=1.0
Round 005/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 5 client 41: SSIM=0.0122, Cos=0.9879, Success=1.0
[Attack] Round 6 client 28: SSIM=0.0005, Cos=0.9910, Success=1.0
[Attack] Round 7 client 28: SSIM=0.0033, Cos=0.9856, Success=1.0
[Attack] Round 8 client 45: SSIM=0.0038, Cos=0.9920, Success=1.0
[Attack] Round 9 client 24: SSIM=0.0126, Cos=0.9878, Success=1.0
Round 010/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 10 client 21: SSIM=0.0030, Cos

# FL with No Protection (Fed Avg) with 100 Client

In [3]:
# Single-cell runnable script: FedAvg (no protection) with LeNet on a local face dataset
# and IDLG-style reconstruction attack during training using L-BFGS.
#
# This version includes fixes for:
# - Ensuring captured label/image are single-example (shape (1,) and (1,C,H,W))
# - Defensive shape checks to avoid batch-size mismatches during reconstruction
# - Robust handling of unexpected batch shapes from DataLoader
#
# Requirements: torch, torchvision, numpy, pillow, scikit-image, tqdm, sklearn

import os
import random
import math
import json
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics.pairwise import cosine_similarity

# -----------------------------
# Experiment hyperparameters
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Federated learning variables
NUM_CLIENTS = 100                 # total clients
CLIENTS_PER_ROUND = 10           # sampled clients per round
ROUNDS = 100                     # total federated rounds
LOCAL_EPOCHS = 1                 # local epochs per client (E)
LOCAL_BATCH_SIZE = 1             # batch size on client (B) -> 1 for strong leakage
CLIENT_LR = 0.01                 # client optimizer lr
MOMENTUM = 0.9

# Model / data variables
IMAGE_SIZE = 64                  # resize images to 64x64
INPUT_CHANNELS = 3
NUM_CLASSES = None               # set after dataset scan

# Attack variables (IDLG with L-BFGS)
ATTACK_USE_LBFGS = True
ATTACK_LBFGS_MAX_ITER = 300
ATTACK_LBFGS_LINESEARCH = True
ATTACK_INIT = "random"           # "random" or "mean_face"
ATTACK_TV_WEIGHT = 1e-4
ATTACK_GRAD_MATCH_WEIGHT = 1.0
ATTACK_COSINE_THRESHOLD = 0.8    # threshold for attack success (cosine similarity)
ATTACK_EVERY_N_ROUNDS = 5        # attack frequency

# Reconstruction image shape
RECON_SHAPE = (1, INPUT_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

# Logging / misc
PRINT_EVERY = 5
SAVE_RECON_DIR = "./reconstructions"
os.makedirs(SAVE_RECON_DIR, exist_ok=True)

# Determinism
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# -----------------------------
# Dataset utilities
# -----------------------------
# Set your dataset path here (example: Kaggle path provided earlier)
DATA_ROOT = "/kaggle/input/lfw-facial-recognition/Face Recognition"
# Choose which subfolders to include
FOLDERS = ["Faces"]  # or ["detected faces", "Faces"]

# Image transforms (normalize to [-1,1] as used in reconstruction)
transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),                       # [0,1]
    T.Normalize(mean=[0.5]*3, std=[0.5]*3)  # -> [-1,1]
])

# Helper: gather all image file paths from the chosen folders
def gather_image_paths(root, folders):
    paths = []
    for f in folders:
        folder = os.path.join(root, f)
        if not os.path.isdir(folder):
            continue
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                paths.append(os.path.join(folder, fname))
    return sorted(paths)

all_image_paths = gather_image_paths(DATA_ROOT, FOLDERS)
if len(all_image_paths) == 0:
    raise RuntimeError(f"No images found in {FOLDERS} under {DATA_ROOT}. Check paths.")

# Identity extraction: everything before the last underscore (AJ_Cook_0001 -> AJ_Cook)
def identity_from_filename(path):
    name = os.path.basename(path)
    base = os.path.splitext(name)[0]
    parts = base.split("_")
    if len(parts) >= 2:
        return "_".join(parts[:-1])
    return base

# Build mapping identity -> list of image paths
id_to_paths = defaultdict(list)
for p in all_image_paths:
    id_ = identity_from_filename(p)
    id_to_paths[id_].append(p)

# Filter identities with at least 2 images (so we can have train/test per identity)
valid_id_to_paths = {k: v for k, v in id_to_paths.items() if len(v) >= 2}
if len(valid_id_to_paths) == 0:
    raise RuntimeError("No identities with >=2 images found. Check filename format or dataset.")

# Build a flat dataset list and label mapping
all_valid_paths = []
all_valid_labels = []
unique_ids = sorted(valid_id_to_paths.keys())
label_map = {id_: idx for idx, id_ in enumerate(unique_ids)}
for id_, paths in valid_id_to_paths.items():
    for p in paths:
        all_valid_paths.append(p)
        all_valid_labels.append(label_map[id_])

NUM_CLASSES = len(unique_ids)
print(f"Found {len(all_valid_paths)} images across {NUM_CLASSES} identities (>=2 images each).")

# -----------------------------
# PyTorch Dataset and client partitioning
# -----------------------------
class FaceImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert("RGB")
        if self.transform:
            img = self.transform(img)  # should be (C,H,W)
        label = self.labels[idx]
        return img, label, p  # return path for debugging

# Create a dataset object for convenience (we will create per-client subsets)
full_dataset = FaceImageDataset(all_valid_paths, all_valid_labels, transform=transform)

# Partition identities across clients (non-IID: each client gets images of a few identities)
client_id_to_indices = {c: [] for c in range(NUM_CLIENTS)}
identities = list(unique_ids)
random.shuffle(identities)
# Round-robin assign identities to clients
for i, id_ in enumerate(identities):
    client_id = i % NUM_CLIENTS
    for p in valid_id_to_paths[id_]:
        idx = all_valid_paths.index(p)
        client_id_to_indices[client_id].append(idx)

num_nonempty = sum(1 for c in client_id_to_indices if len(client_id_to_indices[c]) > 0)
print(f"Assigned identities to {num_nonempty} non-empty clients out of {NUM_CLIENTS} total clients.")

# Build client dataloaders
client_loaders = {}
for c in range(NUM_CLIENTS):
    inds = client_id_to_indices[c]
    if len(inds) == 0:
        client_loaders[c] = None
        continue
    sub_paths = [all_valid_paths[i] for i in inds]
    sub_labels = [all_valid_labels[i] for i in inds]
    ds = FaceImageDataset(sub_paths, sub_labels, transform=transform)
    loader = DataLoader(ds, batch_size=LOCAL_BATCH_SIZE, shuffle=True)
    client_loaders[c] = loader

# Build a small held-out test loader from each identity (take one image per identity)
test_paths = []
test_labels = []
for id_, paths in valid_id_to_paths.items():
    p = paths[-1]
    test_paths.append(p)
    test_labels.append(label_map[id_])
test_ds = FaceImageDataset(test_paths, test_labels, transform=transform)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

# -----------------------------
# Model: small LeNet-like CNN
# -----------------------------
class LeNetSmall(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(INPUT_CHANNELS, 6, kernel_size=5)
        self.pool = nn.AvgPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        with torch.no_grad():
            dummy = torch.zeros(1, INPUT_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
            x = self.pool(F.relu(self.conv1(dummy)))
            x = self.pool(F.relu(self.conv2(x)))
            flat = x.view(1, -1).shape[1]
        self.fc1 = nn.Linear(flat, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x, return_features=False):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        feat = F.relu(self.fc1(x))
        feat = F.relu(self.fc2(feat))
        logits = self.fc3(feat)
        if return_features:
            return feat, logits
        return logits

# Instantiate global model
global_model = LeNetSmall(NUM_CLASSES).to(DEVICE)
global_state = {k: v.cpu() for k, v in global_model.state_dict().items()}

# Helper to create model from state dict
def model_from_state(state_dict):
    m = LeNetSmall(NUM_CLASSES).to(DEVICE)
    sd = {k: v.to(DEVICE) for k, v in state_dict.items()}
    m.load_state_dict(sd)
    return m

# -----------------------------
# Utility: ensure image tensors have shape (B,C,H,W)
# -----------------------------
def ensure_batch_image_shape(x):
    # x may be (C,H,W), (1,C,H,W), (B,C,H,W), or (B,1,C,H,W) etc.
    # Return a tensor with shape (B,C,H,W)
    if isinstance(x, torch.Tensor):
        if x.dim() == 3:
            return x.unsqueeze(0)
        if x.dim() == 5 and x.size(1) == 1:
            # shape (B,1,C,H,W) -> squeeze dim1
            return x.squeeze(1)
        if x.dim() == 5 and x.size(2) == 1:
            # shape (B,C,1,H,W) unlikely, try to reshape
            return x.squeeze(2)
        return x
    else:
        return x

# -----------------------------
# FedAvg client update (single-step) with gradient capture
# -----------------------------
def get_model_copy_for_client(state_dict):
    return model_from_state(state_dict)

def client_update(client_id, global_state, return_grad_snapshot=False):
    loader = client_loaders[client_id]
    if loader is None:
        return None, None, None, None
    local_model = get_model_copy_for_client(global_state)
    opt = torch.optim.SGD(local_model.parameters(), lr=CLIENT_LR, momentum=MOMENTUM)
    local_model.train()
    grad_snapshot = None
    captured_label = None
    captured_image = None
    for epoch in range(LOCAL_EPOCHS):
        for x, y, pth in loader:
            # x: (B,C,H,W) normally; guard against extra dims
            x = ensure_batch_image_shape(x)
            # Defensive: if DataLoader produced shape (B,1,C,H,W) or similar, squeeze
            if x.dim() == 5 and x.size(1) == 1:
                x = x.squeeze(1)
            x = x.to(DEVICE)
            y = torch.tensor(y) if not isinstance(y, torch.Tensor) else y
            y = y.to(DEVICE)
            opt.zero_grad()
            logits = local_model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            # Capture per-parameter gradients BEFORE optimizer.step (this is what attacker sees)
            if return_grad_snapshot and grad_snapshot is None:
                # capture per-parameter gradients (CPU)
                grad_snapshot = [p.grad.detach().cpu().clone() if p.grad is not None else None for p in local_model.parameters()]

                # Ensure we capture a single example (first in the batch)
                # y may be shape (B,) or (B,1); take first element and keep batch dim (1,)
                if y.dim() == 0:
                    captured_label = y.detach().cpu().clone().unsqueeze(0)   # scalar -> (1,)
                else:
                    captured_label = y.detach().cpu().clone().view(-1)[0:1]  # (1,)

                # captured_image: ensure shape (1,C,H,W)
                captured_image = x.detach().cpu().clone()[0:1]               # keep first example only

                # Debug print (optional)
                # print(f"[DEBUG] client {client_id} captured shapes: x {x.shape}, y {y.shape}")

            opt.step()
            # For strong leakage baseline, break after first batch
            if LOCAL_EPOCHS == 1:
                break
    # compute delta = local_params - global_params (on CPU)
    new_state = {k: v.cpu() for k, v in local_model.state_dict().items()}
    delta = {k: new_state[k] - global_state[k] for k in new_state.keys()}
    return delta, grad_snapshot, captured_label, captured_image

# -----------------------------
# Server aggregation (FedAvg)
# -----------------------------
def server_aggregate(global_state, client_deltas, client_sizes):
    total = sum(client_sizes) if sum(client_sizes) > 0 else 1
    agg = {k: torch.zeros_like(v) for k, v in global_state.items()}
    for delta, size in zip(client_deltas, client_sizes):
        if delta is None:
            continue
        weight = size / total
        for k in agg.keys():
            agg[k] += delta[k] * weight
    new_state = {k: global_state[k] + agg[k] for k in global_state.keys()}
    return new_state

# -----------------------------
# Reconstruction utilities (IDLG with L-BFGS)
# -----------------------------
def flatten_grads(grad_list):
    vecs = []
    for g in grad_list:
        if g is None:
            continue
        # ensure on DEVICE
        if not g.is_cuda:
            g = g.to(DEVICE)
        vecs.append(g.view(-1))
    if len(vecs) == 0:
        return torch.tensor([], device=DEVICE)
    return torch.cat(vecs)

def tv_loss(img):
    return torch.mean(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])) + \
           torch.mean(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]))

def assert_shapes_for_recon(x_init, observed_label, observed_grads):
    # x_init should be (1,C,H,W)
    if not isinstance(x_init, torch.Tensor):
        raise RuntimeError("x_init must be a torch.Tensor")
    if x_init.dim() != 4 or x_init.size(0) != 1:
        raise RuntimeError(f"Reconstruction input must be (1,C,H,W), got {x_init.shape}")
    # observed_label should be (1,)
    if observed_label is None:
        raise RuntimeError("Observed label is None")
    if observed_label.dim() != 1 or observed_label.size(0) != 1:
        raise RuntimeError(f"Observed label must be shape (1,), got {observed_label.shape}")
    # grads: ensure flattened length > 0
    flat = flatten_grads(observed_grads)
    if flat.numel() == 0:
        raise RuntimeError("Observed grads flattened to zero length; check grad capture")

def reconstruct_idlg_lbfgs(model_state_for_attack, observed_grads_cpu, observed_label_cpu, iters=ATTACK_LBFGS_MAX_ITER):
    # Ensure observed_label_cpu is shape (1,)
    if observed_label_cpu is None:
        raise ValueError("No observed label provided for reconstruction")
    if observed_label_cpu.dim() == 0:
        observed_label_cpu = observed_label_cpu.unsqueeze(0)
    else:
        observed_label_cpu = observed_label_cpu.view(-1)[0:1]

    # Build a model instance loaded with the pre-update state (on DEVICE)
    model = model_from_state(model_state_for_attack)
    model.eval()

    # Move observed grads to DEVICE and flatten
    obs = [g.to(DEVICE) if g is not None else None for g in observed_grads_cpu]
    obs_flat = flatten_grads(obs).detach()

    # Initialize x_hat in the same normalized space as training ([-1,1])
    if ATTACK_INIT == "random":
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)
    else:
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)

    # Defensive shape assertion
    try:
        assert_shapes_for_recon(x_hat, observed_label_cpu, observed_grads_cpu)
    except Exception as e:
        raise RuntimeError(f"Pre-reconstruction shape check failed: {e}")

    # Use LBFGS optimizer
    optimizer = torch.optim.LBFGS([x_hat], max_iter=iters, line_search_fn='strong_wolfe' if ATTACK_LBFGS_LINESEARCH else None)

    def closure():
        optimizer.zero_grad()
        x_in = ensure_batch_image_shape(x_hat)
        logits = model(x_in)
        loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
        grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
        grads_hat_flat = flatten_grads(grads_hat)
        if obs_flat.numel() == 0 or grads_hat_flat.numel() == 0:
            loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            loss.backward()
            return loss
        loss_match = F.mse_loss(grads_hat_flat, obs_flat)
        loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
        loss.backward()
        return loss

    try:
        optimizer.step(closure)
    except Exception as e:
        # LBFGS sometimes raises on line search; fallback to Adam for a while
        print("LBFGS failed, falling back to Adam:", e)
        opt_adam = torch.optim.Adam([x_hat], lr=0.05)
        for _ in range(500):
            opt_adam.zero_grad()
            x_in = ensure_batch_image_shape(x_hat)
            logits = model(x_in)
            loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
            grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
            grads_hat_flat = flatten_grads(grads_hat)
            loss_match = F.mse_loss(grads_hat_flat, obs_flat)
            loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
            loss.backward()
            opt_adam.step()
            with torch.no_grad():
                x_hat.clamp_(-1, 1)

    with torch.no_grad():
        x_rec = x_hat.clamp_(-1, 1).detach().cpu()
    return x_rec

# -----------------------------
# Evaluation metrics
# -----------------------------
def denorm(img_tensor):
    img = img_tensor.clone()
    img = (img * 0.5) + 0.5
    return img.clamp(0,1)

def compute_ssim(img1, img2):
    a = img1.squeeze(0).permute(1,2,0).numpy()
    b = img2.squeeze(0).permute(1,2,0).numpy()
    s = 0.0
    for ch in range(a.shape[2]):
        s += ssim(a[:,:,ch], b[:,:,ch], data_range=1.0)
    return s / a.shape[2]

def compute_cosine_similarity_feature(model_state, img1, img2):
    model = model_from_state(model_state)
    model.eval()
    with torch.no_grad():
        f1, _ = model(ensure_batch_image_shape(img1).to(DEVICE), return_features=True)
        f2, _ = model(ensure_batch_image_shape(img2).to(DEVICE), return_features=True)
    f1 = f1.cpu().numpy()
    f2 = f2.cpu().numpy()
    return float(cosine_similarity(f1, f2)[0,0])

def evaluate_accuracy(model_state, dataloader):
    model = model_from_state(model_state)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y, _ in dataloader:
            x = ensure_batch_image_shape(x).to(DEVICE)
            y = y.to(DEVICE)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total if total > 0 else 0.0

# -----------------------------
# Training loop with attack during training
# -----------------------------
attack_logs = []
global_state = {k: v.cpu() for k, v in global_model.state_dict().items()}

for rnd in range(ROUNDS):
    available_clients = [c for c in range(NUM_CLIENTS) if client_loaders[c] is not None]
    sampled = random.sample(available_clients, min(CLIENTS_PER_ROUND, len(available_clients)))
    client_deltas = []
    client_sizes = []

    victim_client = sampled[0] if len(sampled) > 0 else None
    captured_grad = None
    captured_label = None
    captured_image = None
    model_state_sent_to_client = None

    for c in sampled:
        model_state_sent_to_client = {k: v.clone() for k, v in global_state.items()}
        if c == victim_client:
            delta, grad_snapshot, cap_label, cap_image = client_update(c, model_state_sent_to_client, return_grad_snapshot=True)
            captured_grad = grad_snapshot
            captured_label = cap_label
            captured_image = cap_image
        else:
            delta, _, _, _ = client_update(c, model_state_sent_to_client, return_grad_snapshot=False)
        client_deltas.append(delta)
        size = len(client_loaders[c].dataset) if client_loaders[c] is not None else 0
        client_sizes.append(size)

    global_state = server_aggregate(global_state, client_deltas, client_sizes)

    if rnd % PRINT_EVERY == 0 or rnd == ROUNDS - 1:
        acc = evaluate_accuracy(global_state, test_loader)
        print(f"Round {rnd:03d}/{ROUNDS:03d} - Test accuracy: {acc:.4f} - sampled {len(sampled)} clients")

    if (rnd % ATTACK_EVERY_N_ROUNDS == 0) and (captured_grad is not None):
        # Ensure captured_label and captured_image are present and shaped correctly
        try:
            if captured_label is None or captured_image is None:
                raise RuntimeError("Captured label/image missing for attack")
            # Force shapes: label (1,), image (1,C,H,W)
            if captured_label.dim() == 0:
                captured_label = captured_label.unsqueeze(0)
            else:
                captured_label = captured_label.view(-1)[0:1]
            if captured_image.dim() == 3:
                captured_image = captured_image.unsqueeze(0)
            elif captured_image.dim() == 5 and captured_image.size(1) == 1:
                captured_image = captured_image.squeeze(1)
            captured_image = captured_image[0:1]  # ensure single example
        except Exception as e:
            print("Captured data shape error, skipping attack this round:", e)
            continue

        try:
            reconstructed = reconstruct_idlg_lbfgs(model_state_sent_to_client, captured_grad, captured_label, iters=ATTACK_LBFGS_MAX_ITER)
        except Exception as e:
            print("Reconstruction failed:", e)
            reconstructed = None

        if reconstructed is not None:
            rec_den = denorm(reconstructed)
            true_den = denorm(captured_image.cpu())

            ssim_val = compute_ssim(rec_den, true_den)
            cos_val = compute_cosine_similarity_feature(model_state_sent_to_client, rec_den.unsqueeze(0), true_den.unsqueeze(0))
            success = float(cos_val >= ATTACK_COSINE_THRESHOLD)

            def tensor_to_pil(img_tensor, fname):
                arr = (img_tensor.squeeze(0).permute(1,2,0).numpy() * 255).astype(np.uint8)
                Image.fromarray(arr).save(fname)

            fname_rec = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_rec.png")
            fname_true = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_true.png")
            try:
                tensor_to_pil(rec_den, fname_rec)
                tensor_to_pil(true_den, fname_true)
            except Exception:
                np.save(fname_rec + ".npy", rec_den.numpy())
                np.save(fname_true + ".npy", true_den.numpy())

            log_entry = {
                "round": int(rnd),
                "victim_client": int(victim_client),
                "ssim": float(ssim_val),
                "cosine": float(cos_val),
                "success": float(success),
                "recon_path": fname_rec,
                "true_path": fname_true,
                "label": int(captured_label.item())
            }
            attack_logs.append(log_entry)
            print(f"[Attack] Round {rnd} client {victim_client}: SSIM={ssim_val:.4f}, Cos={cos_val:.4f}, Success={success}")

# -----------------------------
# Final aggregated attack metrics and save logs
# -----------------------------
total_attacks = len(attack_logs)
if total_attacks > 0:
    avg_ssim = sum([e["ssim"] for e in attack_logs]) / total_attacks
    avg_cos = sum([e["cosine"] for e in attack_logs]) / total_attacks
    success_rate = sum([e["success"] for e in attack_logs]) / total_attacks
else:
    avg_ssim = avg_cos = success_rate = 0.0

print("=== Attack summary ===")
print(f"Attacks run: {total_attacks}")
print(f"Avg SSIM: {avg_ssim:.4f}")
print(f"Avg Cosine: {avg_cos:.4f}")
print(f"Success rate (cos>={ATTACK_COSINE_THRESHOLD}): {success_rate:.3f}")

with open(os.path.join(SAVE_RECON_DIR, "attack_logs.json"), "w") as f:
    json.dump(attack_logs, f, indent=2)

print("Reconstruction images and logs saved to:", SAVE_RECON_DIR)


Found 9164 images across 1680 identities (>=2 images each).
Assigned identities to 100 non-empty clients out of 100 total clients.
Round 000/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 0 client 93: SSIM=0.0161, Cos=0.9953, Success=1.0
Round 005/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 5 client 9: SSIM=0.0117, Cos=0.9905, Success=1.0
Round 010/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 10 client 42: SSIM=0.0002, Cos=0.9952, Success=1.0
Round 015/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 15 client 28: SSIM=0.0078, Cos=0.9911, Success=1.0
Round 020/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 20 client 59: SSIM=0.0091, Cos=0.9939, Success=1.0
Round 025/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 25 client 42: SSIM=0.0083, Cos=0.9947, Success=1.0
Round 030/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 30 client 30: SSIM=-0.0007, Cos=0.9921, Success=

# Diffrential Privacy - SGD in FL ( 50 Client)

In [1]:
# Single-cell runnable script: FedAvg (no protection) with ResNet-18 backbone on a local face dataset
# and IDLG-style reconstruction attack during training using L-BFGS.
#
# This updated version fixes the RuntimeError:
#   "result type Float can't be cast to the desired output type Long"
# by ensuring model state tensors used for aggregation are float dtype
# (BatchNorm's num_batches_tracked is Long and caused dtype mismatch).
#
# Requirements: torch, torchvision, numpy, pillow, scikit-image, tqdm, sklearn

import os
import random
import math
import json
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics.pairwise import cosine_similarity

# -----------------------------
# Experiment hyperparameters
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Federated learning variables
NUM_CLIENTS = 50
CLIENTS_PER_ROUND = 10
ROUNDS = 100
LOCAL_EPOCHS = 1
LOCAL_BATCH_SIZE = 1
CLIENT_LR = 0.01
MOMENTUM = 0.9

# Model / data variables
IMAGE_SIZE = 224
INPUT_CHANNELS = 3
NUM_CLASSES = None

# Attack variables (IDLG with L-BFGS)
ATTACK_USE_LBFGS = True
ATTACK_LBFGS_MAX_ITER = 300
ATTACK_LBFGS_LINESEARCH = True
ATTACK_INIT = "random"
ATTACK_TV_WEIGHT = 1e-4
ATTACK_GRAD_MATCH_WEIGHT = 1.0
ATTACK_COSINE_THRESHOLD = 0.5
ATTACK_EVERY_N_ROUNDS = 1

RECON_SHAPE = (1, INPUT_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

PRINT_EVERY = 5
SAVE_RECON_DIR = "./reconstructions"
os.makedirs(SAVE_RECON_DIR, exist_ok=True)

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# -----------------------------
# Dataset utilities
# -----------------------------
DATA_ROOT = "/kaggle/input/lfw-facial-recognition/Face Recognition"
FOLDERS = ["Faces"]

transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

def gather_image_paths(root, folders):
    paths = []
    for f in folders:
        folder = os.path.join(root, f)
        if not os.path.isdir(folder):
            continue
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                paths.append(os.path.join(folder, fname))
    return sorted(paths)

all_image_paths = gather_image_paths(DATA_ROOT, FOLDERS)
if len(all_image_paths) == 0:
    raise RuntimeError(f"No images found in {FOLDERS} under {DATA_ROOT}. Check paths.")

def identity_from_filename(path):
    name = os.path.basename(path)
    base = os.path.splitext(name)[0]
    parts = base.split("_")
    if len(parts) >= 2:
        return "_".join(parts[:-1])
    return base

id_to_paths = defaultdict(list)
for p in all_image_paths:
    id_ = identity_from_filename(p)
    id_to_paths[id_].append(p)

valid_id_to_paths = {k: v for k, v in id_to_paths.items() if len(v) >= 2}
if len(valid_id_to_paths) == 0:
    raise RuntimeError("No identities with >=2 images found. Check filename format or dataset.")

all_valid_paths = []
all_valid_labels = []
unique_ids = sorted(valid_id_to_paths.keys())
label_map = {id_: idx for idx, id_ in enumerate(unique_ids)}
for id_, paths in valid_id_to_paths.items():
    for p in paths:
        all_valid_paths.append(p)
        all_valid_labels.append(label_map[id_])

NUM_CLASSES = len(unique_ids)
print(f"Found {len(all_valid_paths)} images across {NUM_CLASSES} identities (>=2 images each).")

# -----------------------------
# PyTorch Dataset and client partitioning
# -----------------------------
class FaceImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = int(self.labels[idx])
        return img, label, p

full_dataset = FaceImageDataset(all_valid_paths, all_valid_labels, transform=transform)

client_id_to_indices = {c: [] for c in range(NUM_CLIENTS)}
identities = list(unique_ids)
random.shuffle(identities)
for i, id_ in enumerate(identities):
    client_id = i % NUM_CLIENTS
    for p in valid_id_to_paths[id_]:
        idx = all_valid_paths.index(p)
        client_id_to_indices[client_id].append(idx)

num_nonempty = sum(1 for c in client_id_to_indices if len(client_id_to_indices[c]) > 0)
print(f"Assigned identities to {num_nonempty} non-empty clients out of {NUM_CLIENTS} total clients.")

client_loaders = {}
for c in range(NUM_CLIENTS):
    inds = client_id_to_indices[c]
    if len(inds) == 0:
        client_loaders[c] = None
        continue
    sub_paths = [all_valid_paths[i] for i in inds]
    sub_labels = [all_valid_labels[i] for i in inds]
    ds = FaceImageDataset(sub_paths, sub_labels, transform=transform)
    loader = DataLoader(ds, batch_size=LOCAL_BATCH_SIZE, shuffle=True)
    client_loaders[c] = loader

test_paths = []
test_labels = []
for id_, paths in valid_id_to_paths.items():
    p = paths[-1]
    test_paths.append(p)
    test_labels.append(label_map[id_])
test_ds = FaceImageDataset(test_paths, test_labels, transform=transform)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

# -----------------------------
# Model: ResNet-18 wrapper
# -----------------------------
class ResNet18Wrapper(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        in_features = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone
        self.classifier = nn.Linear(in_features, num_classes)
    def forward(self, x, return_features=False):
        feats = self.backbone(x)
        logits = self.classifier(feats)
        if return_features:
            return feats, logits
        return logits

global_model = ResNet18Wrapper(NUM_CLASSES, pretrained=True).to(DEVICE)

# IMPORTANT: store global_state as float tensors to avoid dtype mismatch (BatchNorm num_batches_tracked is Long)
global_state = {k: v.cpu().to(torch.float32) for k, v in global_model.state_dict().items()}

def model_from_state(state_dict):
    m = ResNet18Wrapper(NUM_CLASSES, pretrained=False).to(DEVICE)
    # state_dict may be float; ensure mapping to model dtype
    sd = {}
    for k, v in state_dict.items():
        # cast to model device dtype
        sd[k] = v.to(DEVICE)
    m.load_state_dict(sd)
    return m

# -----------------------------
# Utility: ensure image tensors have shape (B,C,H,W)
# -----------------------------
def ensure_batch_image_shape(x):
    if isinstance(x, torch.Tensor):
        if x.dim() == 3:
            return x.unsqueeze(0)
        if x.dim() == 5 and x.size(1) == 1:
            return x.squeeze(1)
        if x.dim() == 5 and x.size(2) == 1:
            return x.squeeze(2)
        return x
    else:
        return x

# -----------------------------
# FedAvg client update (single-step) with gradient capture
# -----------------------------
def get_model_copy_for_client(state_dict):
    return model_from_state(state_dict)

def client_update(client_id, global_state, return_grad_snapshot=False):
    loader = client_loaders[client_id]
    if loader is None:
        return None, None, None, None
    local_model = get_model_copy_for_client(global_state)
    opt = torch.optim.SGD(local_model.parameters(), lr=CLIENT_LR, momentum=MOMENTUM)
    local_model.train()
    grad_snapshot = None
    captured_label = None
    captured_image = None
    for epoch in range(LOCAL_EPOCHS):
        for x, y, pth in loader:
            x = ensure_batch_image_shape(x)
            if x.dim() == 5 and x.size(1) == 1:
                x = x.squeeze(1)
            x = x.to(DEVICE)
            if not isinstance(y, torch.Tensor):
                y = torch.tensor(y, dtype=torch.long)
            else:
                y = y.long()
            y = y.to(DEVICE)
            opt.zero_grad()
            logits = local_model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            if return_grad_snapshot and grad_snapshot is None:
                grad_snapshot = [p.grad.detach().cpu().clone() if p.grad is not None else None for p in local_model.parameters()]
                if y.dim() == 0:
                    captured_label = y.detach().cpu().clone().unsqueeze(0).long()
                else:
                    captured_label = y.detach().cpu().clone().view(-1)[0:1].long()
                captured_image = x.detach().cpu().clone()[0:1]
            opt.step()
            if LOCAL_EPOCHS == 1:
                break
    new_state = {k: v.cpu().to(torch.float32) for k, v in local_model.state_dict().items()}
    delta = {k: new_state[k] - global_state[k] for k in new_state.keys()}
    return delta, grad_snapshot, captured_label, captured_image

# -----------------------------
# Server aggregation (FedAvg) - ensure aggregation uses float dtype
# -----------------------------
def server_aggregate(global_state, client_deltas, client_sizes):
    total = sum(client_sizes) if sum(client_sizes) > 0 else 1
    # create agg with float dtype explicitly
    agg = {k: torch.zeros_like(global_state[k], dtype=torch.float32) for k in global_state.keys()}
    for delta, size in zip(client_deltas, client_sizes):
        if delta is None:
            continue
        weight = float(size) / float(total)
        for k in agg.keys():
            # ensure delta[k] is float
            d = delta[k].to(torch.float32)
            agg[k] = agg[k] + d * weight
    # apply aggregation and keep float dtype
    new_state = {k: (global_state[k].to(torch.float32) + agg[k].to(global_state[k].device)).cpu() for k in global_state.keys()}
    return new_state

# -----------------------------
# Reconstruction utilities (IDLG with L-BFGS)
# -----------------------------
def flatten_grads(grad_list):
    vecs = []
    for g in grad_list:
        if g is None:
            continue
        if not g.is_cuda:
            g = g.to(DEVICE)
        vecs.append(g.view(-1))
    if len(vecs) == 0:
        return torch.tensor([], device=DEVICE)
    return torch.cat(vecs)

def tv_loss(img):
    return torch.mean(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])) + \
           torch.mean(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]))

def assert_shapes_for_recon(x_init, observed_label, observed_grads):
    if not isinstance(x_init, torch.Tensor):
        raise RuntimeError("x_init must be a torch.Tensor")
    if x_init.dim() != 4 or x_init.size(0) != 1:
        raise RuntimeError(f"Reconstruction input must be (1,C,H,W), got {x_init.shape}")
    if observed_label is None:
        raise RuntimeError("Observed label is None")
    if observed_label.dim() != 1 or observed_label.size(0) != 1:
        raise RuntimeError(f"Observed label must be shape (1,), got {observed_label.shape}")
    flat = flatten_grads(observed_grads)
    if flat.numel() == 0:
        raise RuntimeError("Observed grads flattened to zero length; check grad capture")

def reconstruct_idlg_lbfgs(model_state_for_attack, observed_grads_cpu, observed_label_cpu, iters=ATTACK_LBFGS_MAX_ITER):
    if observed_label_cpu is None:
        raise ValueError("No observed label provided for reconstruction")
    if observed_label_cpu.dim() == 0:
        observed_label_cpu = observed_label_cpu.unsqueeze(0).long()
    else:
        observed_label_cpu = observed_label_cpu.view(-1)[0:1].long()

    model = model_from_state(model_state_for_attack)
    model.eval()

    obs = [g.to(DEVICE) if g is not None else None for g in observed_grads_cpu]
    obs_flat = flatten_grads(obs).detach()

    if ATTACK_INIT == "random":
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)
    else:
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)

    try:
        assert_shapes_for_recon(x_hat, observed_label_cpu, observed_grads_cpu)
    except Exception as e:
        raise RuntimeError(f"Pre-reconstruction shape check failed: {e}")

    optimizer = torch.optim.LBFGS([x_hat], max_iter=iters, line_search_fn='strong_wolfe' if ATTACK_LBFGS_LINESEARCH else None)

    def closure():
        optimizer.zero_grad()
        x_in = ensure_batch_image_shape(x_hat)
        logits = model(x_in)
        loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
        grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
        grads_hat_flat = flatten_grads(grads_hat)
        if obs_flat.numel() == 0 or grads_hat_flat.numel() == 0:
            loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            loss.backward()
            return loss
        loss_match = F.mse_loss(grads_hat_flat, obs_flat)
        loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
        loss.backward()
        return loss

    try:
        optimizer.step(closure)
    except Exception as e:
        print("LBFGS failed, falling back to Adam:", e)
        opt_adam = torch.optim.Adam([x_hat], lr=0.05)
        for _ in range(500):
            opt_adam.zero_grad()
            x_in = ensure_batch_image_shape(x_hat)
            logits = model(x_in)
            loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
            grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
            grads_hat_flat = flatten_grads(grads_hat)
            loss_match = F.mse_loss(grads_hat_flat, obs_flat)
            loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
            loss.backward()
            opt_adam.step()
            with torch.no_grad():
                x_hat.clamp_(-1, 1)

    with torch.no_grad():
        x_rec = x_hat.clamp_(-1, 1).detach().cpu()
    return x_rec

# -----------------------------
# Evaluation metrics
# -----------------------------
def denorm(img_tensor):
    img = img_tensor.clone()
    img = (img * 0.5) + 0.5
    return img.clamp(0,1)

def compute_ssim(img1, img2):
    a = img1.squeeze(0).permute(1,2,0).numpy()
    b = img2.squeeze(0).permute(1,2,0).numpy()
    s = 0.0
    for ch in range(a.shape[2]):
        s += ssim(a[:,:,ch], b[:,:,ch], data_range=1.0)
    return s / a.shape[2]

def compute_cosine_similarity_feature(model_state, img1, img2):
    model = model_from_state(model_state)
    model.eval()
    with torch.no_grad():
        f1, _ = model(ensure_batch_image_shape(img1).to(DEVICE), return_features=True)
        f2, _ = model(ensure_batch_image_shape(img2).to(DEVICE), return_features=True)
    f1 = f1.cpu().numpy()
    f2 = f2.cpu().numpy()
    return float(cosine_similarity(f1, f2)[0,0])

def evaluate_accuracy(model_state, dataloader):
    model = model_from_state(model_state)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y, _ in dataloader:
            x = ensure_batch_image_shape(x).to(DEVICE)
            if not isinstance(y, torch.Tensor):
                y = torch.tensor(y, dtype=torch.long)
            else:
                y = y.long()
            y = y.to(DEVICE)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total if total > 0 else 0.0

# -----------------------------
# Training loop with attack during training
# -----------------------------
attack_logs = []
global_state = {k: v.cpu().to(torch.float32) for k, v in global_model.state_dict().items()}

for rnd in range(ROUNDS):
    available_clients = [c for c in range(NUM_CLIENTS) if client_loaders[c] is not None]
    sampled = random.sample(available_clients, min(CLIENTS_PER_ROUND, len(available_clients)))
    client_deltas = []
    client_sizes = []

    victim_client = sampled[0] if len(sampled) > 0 else None
    captured_grad = None
    captured_label = None
    captured_image = None
    model_state_sent_to_client = None

    for c in sampled:
        model_state_sent_to_client = {k: v.clone() for k, v in global_state.items()}
        if c == victim_client:
            delta, grad_snapshot, cap_label, cap_image = client_update(c, model_state_sent_to_client, return_grad_snapshot=True)
            captured_grad = grad_snapshot
            captured_label = cap_label
            captured_image = cap_image
        else:
            delta, _, _, _ = client_update(c, model_state_sent_to_client, return_grad_snapshot=False)
        client_deltas.append(delta)
        size = len(client_loaders[c].dataset) if client_loaders[c] is not None else 0
        client_sizes.append(size)

    global_state = server_aggregate(global_state, client_deltas, client_sizes)

    if rnd % PRINT_EVERY == 0 or rnd == ROUNDS - 1:
        acc = evaluate_accuracy(global_state, test_loader)
        print(f"Round {rnd:03d}/{ROUNDS:03d} - Test accuracy: {acc:.4f} - sampled {len(sampled)} clients")

    if (rnd % ATTACK_EVERY_N_ROUNDS == 0) and (captured_grad is not None):
        try:
            if captured_label is None or captured_image is None:
                raise RuntimeError("Captured label/image missing for attack")
            if captured_label.dim() == 0:
                captured_label = captured_label.unsqueeze(0).long()
            else:
                captured_label = captured_label.view(-1)[0:1].long()
            if captured_image.dim() == 3:
                captured_image = captured_image.unsqueeze(0)
            elif captured_image.dim() == 5 and captured_image.size(1) == 1:
                captured_image = captured_image.squeeze(1)
            captured_image = captured_image[0:1]
        except Exception as e:
            print("Captured data shape error, skipping attack this round:", e)
            continue

        try:
            reconstructed = reconstruct_idlg_lbfgs(model_state_sent_to_client, captured_grad, captured_label, iters=ATTACK_LBFGS_MAX_ITER)
        except Exception as e:
            print("Reconstruction failed:", e)
            reconstructed = None

        if reconstructed is not None:
            rec_den = denorm(reconstructed)
            true_den = denorm(captured_image.cpu())

            ssim_val = compute_ssim(rec_den, true_den)
            cos_val = compute_cosine_similarity_feature(model_state_sent_to_client, rec_den.unsqueeze(0), true_den.unsqueeze(0))
            success = float(cos_val >= ATTACK_COSINE_THRESHOLD)

            def tensor_to_pil(img_tensor, fname):
                arr = (img_tensor.squeeze(0).permute(1,2,0).numpy() * 255).astype(np.uint8)
                Image.fromarray(arr).save(fname)

            fname_rec = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_rec.png")
            fname_true = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_true.png")
            try:
                tensor_to_pil(rec_den, fname_rec)
                tensor_to_pil(true_den, fname_true)
            except Exception:
                np.save(fname_rec + ".npy", rec_den.numpy())
                np.save(fname_true + ".npy", true_den.numpy())

            log_entry = {
                "round": int(rnd),
                "victim_client": int(victim_client),
                "ssim": float(ssim_val),
                "cosine": float(cos_val),
                "success": float(success),
                "recon_path": fname_rec,
                "true_path": fname_true,
                "label": int(captured_label.item())
            }
            attack_logs.append(log_entry)
            print(f"[Attack] Round {rnd} client {victim_client}: SSIM={ssim_val:.4f}, Cos={cos_val:.4f}, Success={success}")

# -----------------------------
# Final aggregated attack metrics and save logs
# -----------------------------
total_attacks = len(attack_logs)
if total_attacks > 0:
    avg_ssim = sum([e["ssim"] for e in attack_logs]) / total_attacks
    avg_cos = sum([e["cosine"] for e in attack_logs]) / total_attacks
    success_rate = sum([e["success"] for e in attack_logs]) / total_attacks
else:
    avg_ssim = avg_cos = success_rate = 0.0

print("=== Attack summary ===")
print(f"Attacks run: {total_attacks}")
print(f"Avg SSIM: {avg_ssim:.4f}")
print(f"Avg Cosine: {avg_cos:.4f}")
print(f"Success rate (cos>={ATTACK_COSINE_THRESHOLD}): {success_rate:.3f}")

with open(os.path.join(SAVE_RECON_DIR, "attack_logs.json"), "w") as f:
    json.dump(attack_logs, f, indent=2)

print("Reconstruction images and logs saved to:", SAVE_RECON_DIR)


Found 9164 images across 1680 identities (>=2 images each).
Assigned identities to 50 non-empty clients out of 50 total clients.
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 226MB/s]


Round 000/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 0 client 46: SSIM=0.0056, Cos=0.3156, Success=0.0




[Attack] Round 1 client 9: SSIM=0.0031, Cos=0.3703, Success=0.0




[Attack] Round 2 client 28: SSIM=0.0057, Cos=0.3647, Success=0.0




[Attack] Round 3 client 28: SSIM=0.0057, Cos=0.2810, Success=0.0




[Attack] Round 4 client 32: SSIM=0.0068, Cos=0.3263, Success=0.0




Round 005/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 5 client 41: SSIM=0.0053, Cos=0.3426, Success=0.0




[Attack] Round 6 client 28: SSIM=0.0052, Cos=0.3413, Success=0.0




[Attack] Round 7 client 28: SSIM=0.0052, Cos=0.3948, Success=0.0




[Attack] Round 8 client 45: SSIM=0.0049, Cos=0.3738, Success=0.0




[Attack] Round 9 client 24: SSIM=0.0046, Cos=0.3476, Success=0.0




Round 010/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 10 client 21: SSIM=0.0045, Cos=0.3905, Success=0.0




[Attack] Round 11 client 21: SSIM=0.0055, Cos=0.3863, Success=0.0




[Attack] Round 12 client 38: SSIM=0.0035, Cos=0.3365, Success=0.0




[Attack] Round 13 client 5: SSIM=0.0064, Cos=0.4566, Success=0.0




[Attack] Round 14 client 13: SSIM=0.0050, Cos=0.4625, Success=0.0




Round 015/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 15 client 32: SSIM=0.0044, Cos=0.4232, Success=0.0




[Attack] Round 16 client 24: SSIM=0.0056, Cos=0.4031, Success=0.0




[Attack] Round 17 client 23: SSIM=0.0032, Cos=0.4910, Success=0.0




[Attack] Round 18 client 15: SSIM=0.0066, Cos=0.4440, Success=0.0




[Attack] Round 19 client 49: SSIM=0.0053, Cos=0.3974, Success=0.0




Round 020/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 20 client 3: SSIM=0.0037, Cos=0.4575, Success=0.0




[Attack] Round 21 client 48: SSIM=0.0059, Cos=0.4948, Success=0.0




[Attack] Round 22 client 22: SSIM=0.0052, Cos=0.4654, Success=0.0




[Attack] Round 23 client 23: SSIM=0.0048, Cos=0.4932, Success=0.0




[Attack] Round 24 client 18: SSIM=0.0041, Cos=0.4818, Success=0.0




Round 025/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 25 client 1: SSIM=0.0040, Cos=0.4744, Success=0.0




[Attack] Round 26 client 15: SSIM=0.0057, Cos=0.4708, Success=0.0




[Attack] Round 27 client 7: SSIM=0.0053, Cos=0.5394, Success=1.0




[Attack] Round 28 client 18: SSIM=0.0038, Cos=0.5371, Success=1.0




[Attack] Round 29 client 33: SSIM=0.0048, Cos=0.5344, Success=1.0




Round 030/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 30 client 30: SSIM=0.0045, Cos=0.5054, Success=1.0




[Attack] Round 31 client 15: SSIM=0.0058, Cos=0.4601, Success=0.0




[Attack] Round 32 client 34: SSIM=0.0044, Cos=0.5142, Success=1.0




[Attack] Round 33 client 4: SSIM=0.0047, Cos=0.4393, Success=0.0




[Attack] Round 34 client 19: SSIM=0.0065, Cos=0.5475, Success=1.0




Round 035/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 35 client 22: SSIM=0.0045, Cos=0.4900, Success=0.0




[Attack] Round 36 client 24: SSIM=0.0053, Cos=0.4959, Success=0.0




[Attack] Round 37 client 46: SSIM=0.0055, Cos=0.4803, Success=0.0




[Attack] Round 38 client 5: SSIM=0.0033, Cos=0.5489, Success=1.0




[Attack] Round 39 client 22: SSIM=0.0046, Cos=0.5410, Success=1.0




Round 040/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 40 client 4: SSIM=0.0016, Cos=0.5413, Success=1.0




[Attack] Round 41 client 38: SSIM=0.0040, Cos=0.4679, Success=0.0




[Attack] Round 42 client 24: SSIM=0.0045, Cos=0.4396, Success=0.0




[Attack] Round 43 client 48: SSIM=0.0054, Cos=0.5479, Success=1.0




[Attack] Round 44 client 15: SSIM=0.0054, Cos=0.5379, Success=1.0




Round 045/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 45 client 49: SSIM=0.0023, Cos=0.4901, Success=0.0




[Attack] Round 46 client 8: SSIM=0.0043, Cos=0.4548, Success=0.0




[Attack] Round 47 client 35: SSIM=0.0053, Cos=0.5102, Success=1.0




[Attack] Round 48 client 14: SSIM=0.0043, Cos=0.5153, Success=1.0




[Attack] Round 49 client 23: SSIM=0.0059, Cos=0.4770, Success=0.0




Round 050/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 50 client 33: SSIM=0.0060, Cos=0.4809, Success=0.0




[Attack] Round 51 client 29: SSIM=0.0051, Cos=0.4768, Success=0.0




[Attack] Round 52 client 27: SSIM=0.0058, Cos=0.5069, Success=1.0




[Attack] Round 53 client 45: SSIM=0.0051, Cos=0.5061, Success=1.0




[Attack] Round 54 client 17: SSIM=0.0038, Cos=0.4208, Success=0.0




Round 055/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 55 client 37: SSIM=0.0048, Cos=0.5605, Success=1.0




[Attack] Round 56 client 1: SSIM=0.0046, Cos=0.4573, Success=0.0




[Attack] Round 57 client 39: SSIM=0.0036, Cos=0.4498, Success=0.0




[Attack] Round 58 client 15: SSIM=0.0043, Cos=0.4823, Success=0.0




[Attack] Round 59 client 30: SSIM=0.0064, Cos=0.6003, Success=1.0




Round 060/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 60 client 8: SSIM=0.0053, Cos=0.5296, Success=1.0




[Attack] Round 61 client 17: SSIM=0.0042, Cos=0.5466, Success=1.0




[Attack] Round 62 client 15: SSIM=0.0058, Cos=0.6035, Success=1.0




[Attack] Round 63 client 31: SSIM=0.0049, Cos=0.5179, Success=1.0




[Attack] Round 64 client 37: SSIM=0.0054, Cos=0.4794, Success=0.0




Round 065/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 65 client 34: SSIM=0.0035, Cos=0.5080, Success=1.0




[Attack] Round 66 client 25: SSIM=0.0036, Cos=0.4842, Success=0.0




[Attack] Round 67 client 32: SSIM=0.0032, Cos=0.5146, Success=1.0




[Attack] Round 68 client 25: SSIM=0.0072, Cos=0.5199, Success=1.0




[Attack] Round 69 client 32: SSIM=0.0055, Cos=0.4124, Success=0.0




Round 070/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 70 client 38: SSIM=0.0047, Cos=0.5250, Success=1.0




[Attack] Round 71 client 37: SSIM=0.0047, Cos=0.5169, Success=1.0




[Attack] Round 72 client 23: SSIM=0.0070, Cos=0.4670, Success=0.0




[Attack] Round 73 client 30: SSIM=0.0052, Cos=0.5279, Success=1.0




[Attack] Round 74 client 33: SSIM=0.0048, Cos=0.4852, Success=0.0




Round 075/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 75 client 44: SSIM=0.0058, Cos=0.4938, Success=0.0




[Attack] Round 76 client 26: SSIM=0.0057, Cos=0.5046, Success=1.0




[Attack] Round 77 client 42: SSIM=0.0048, Cos=0.5230, Success=1.0




[Attack] Round 78 client 22: SSIM=0.0062, Cos=0.5143, Success=1.0




[Attack] Round 79 client 15: SSIM=0.0055, Cos=0.4985, Success=0.0




Round 080/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 80 client 30: SSIM=0.0065, Cos=0.5643, Success=1.0




[Attack] Round 81 client 25: SSIM=0.0066, Cos=0.5350, Success=1.0




[Attack] Round 82 client 47: SSIM=0.0041, Cos=0.5252, Success=1.0




[Attack] Round 83 client 39: SSIM=0.0055, Cos=0.5134, Success=1.0




[Attack] Round 84 client 14: SSIM=0.0053, Cos=0.5399, Success=1.0




Round 085/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 85 client 10: SSIM=0.0050, Cos=0.5398, Success=1.0




[Attack] Round 86 client 41: SSIM=0.0055, Cos=0.5471, Success=1.0




[Attack] Round 87 client 11: SSIM=0.0045, Cos=0.4899, Success=0.0




[Attack] Round 88 client 38: SSIM=0.0061, Cos=0.5625, Success=1.0




[Attack] Round 89 client 32: SSIM=0.0058, Cos=0.4614, Success=0.0




Round 090/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 90 client 45: SSIM=0.0059, Cos=0.5006, Success=1.0




[Attack] Round 91 client 40: SSIM=0.0053, Cos=0.5083, Success=1.0




[Attack] Round 92 client 8: SSIM=0.0052, Cos=0.5274, Success=1.0




[Attack] Round 93 client 21: SSIM=0.0059, Cos=0.5380, Success=1.0




[Attack] Round 94 client 3: SSIM=0.0045, Cos=0.5617, Success=1.0




Round 095/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 95 client 40: SSIM=0.0035, Cos=0.5691, Success=1.0




[Attack] Round 96 client 47: SSIM=0.0050, Cos=0.5797, Success=1.0




[Attack] Round 97 client 43: SSIM=0.0054, Cos=0.5255, Success=1.0




[Attack] Round 98 client 7: SSIM=0.0037, Cos=0.4732, Success=0.0




Round 099/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 99 client 33: SSIM=0.0055, Cos=0.5008, Success=1.0
=== Attack summary ===
Attacks run: 100
Avg SSIM: 0.0050
Avg Cosine: 0.4828
Success rate (cos>=0.5): 0.470
Reconstruction images and logs saved to: ./reconstructions


# Diffrential Privacy - SGD ( 10 client)

In [2]:
# Single-cell runnable script: FedAvg with ResNet-18 backbone + DP-SGD on clients,
# and IDLG-style reconstruction attack during training using L-BFGS.
#
# Notes:
# - This implements a simple, explicit DP-SGD on each client by computing per-sample
#   gradients (via a loop over microbatches / samples), clipping each sample gradient
#   to CLIP_NORM, averaging, adding Gaussian noise, and applying the noisy gradient
#   update to the local model parameters.
# - This is a straightforward, library-free DP implementation intended for clarity.
#   For production-scale experiments, use Opacus or functorch for efficiency.
#
# Requirements: torch, torchvision, numpy, pillow, scikit-image, tqdm, sklearn
# Adjust DATA_ROOT and FOLDERS to your dataset layout.

import os
import random
import math
import json
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics.pairwise import cosine_similarity

# -----------------------------
# Experiment hyperparameters
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Federated learning variables
NUM_CLIENTS = 10
CLIENTS_PER_ROUND = 10
ROUNDS = 100
LOCAL_EPOCHS = 1

# LOCAL_BATCH_SIZE is the client batch size used for computing per-sample grads.
# For DP-SGD it's common to use batch sizes >= 8 to amortize noise; set as needed.
LOCAL_BATCH_SIZE = 8

# Client optimizer hyperparameters (used only for non-DP fallback updates)
CLIENT_LR = 0.01
MOMENTUM = 0.9

# Model / data variables
IMAGE_SIZE = 224
INPUT_CHANNELS = 3
NUM_CLASSES = None  # set after dataset scan

# DP-SGD variables
DP_ENABLED = True
CLIP_NORM = 1.0            # C
NOISE_MULTIPLIER = 1.0     # sigma
# If you want to compute epsilon, use a privacy accountant (not included here)

# Attack variables (IDLG with L-BFGS)
ATTACK_USE_LBFGS = True
ATTACK_LBFGS_MAX_ITER = 300
ATTACK_LBFGS_LINESEARCH = True
ATTACK_INIT = "random"
ATTACK_TV_WEIGHT = 1e-4
ATTACK_GRAD_MATCH_WEIGHT = 1.0
ATTACK_COSINE_THRESHOLD = 0.6
ATTACK_EVERY_N_ROUNDS = 1

RECON_SHAPE = (1, INPUT_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

PRINT_EVERY = 5
SAVE_RECON_DIR = "./reconstructions"
os.makedirs(SAVE_RECON_DIR, exist_ok=True)

# Determinism
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# -----------------------------
# Dataset utilities
# -----------------------------
# Set your dataset path here
DATA_ROOT = "/kaggle/input/lfw-facial-recognition/Face Recognition"
FOLDERS = ["Faces"]  # or ["detected faces", "Faces"]

transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    # Keep normalized to [-1,1] for reconstruction compatibility
    T.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

def gather_image_paths(root, folders):
    paths = []
    for f in folders:
        folder = os.path.join(root, f)
        if not os.path.isdir(folder):
            continue
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                paths.append(os.path.join(folder, fname))
    return sorted(paths)

all_image_paths = gather_image_paths(DATA_ROOT, FOLDERS)
if len(all_image_paths) == 0:
    raise RuntimeError(f"No images found in {FOLDERS} under {DATA_ROOT}. Check paths.")

def identity_from_filename(path):
    name = os.path.basename(path)
    base = os.path.splitext(name)[0]
    parts = base.split("_")
    if len(parts) >= 2:
        return "_".join(parts[:-1])
    return base

id_to_paths = defaultdict(list)
for p in all_image_paths:
    id_ = identity_from_filename(p)
    id_to_paths[id_].append(p)

valid_id_to_paths = {k: v for k, v in id_to_paths.items() if len(v) >= 2}
if len(valid_id_to_paths) == 0:
    raise RuntimeError("No identities with >=2 images found. Check filename format or dataset.")

all_valid_paths = []
all_valid_labels = []
unique_ids = sorted(valid_id_to_paths.keys())
label_map = {id_: idx for idx, id_ in enumerate(unique_ids)}
for id_, paths in valid_id_to_paths.items():
    for p in paths:
        all_valid_paths.append(p)
        all_valid_labels.append(label_map[id_])

NUM_CLASSES = len(unique_ids)
print(f"Found {len(all_valid_paths)} images across {NUM_CLASSES} identities (>=2 images each).")

# -----------------------------
# PyTorch Dataset and client partitioning
# -----------------------------
class FaceImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = int(self.labels[idx])
        return img, label, p

full_dataset = FaceImageDataset(all_valid_paths, all_valid_labels, transform=transform)

# Partition identities across clients (non-IID)
client_id_to_indices = {c: [] for c in range(NUM_CLIENTS)}
identities = list(unique_ids)
random.shuffle(identities)
for i, id_ in enumerate(identities):
    client_id = i % NUM_CLIENTS
    for p in valid_id_to_paths[id_]:
        idx = all_valid_paths.index(p)
        client_id_to_indices[client_id].append(idx)

num_nonempty = sum(1 for c in client_id_to_indices if len(client_id_to_indices[c]) > 0)
print(f"Assigned identities to {num_nonempty} non-empty clients out of {NUM_CLIENTS} total clients.")

client_loaders = {}
for c in range(NUM_CLIENTS):
    inds = client_id_to_indices[c]
    if len(inds) == 0:
        client_loaders[c] = None
        continue
    sub_paths = [all_valid_paths[i] for i in inds]
    sub_labels = [all_valid_labels[i] for i in inds]
    ds = FaceImageDataset(sub_paths, sub_labels, transform=transform)
    loader = DataLoader(ds, batch_size=LOCAL_BATCH_SIZE, shuffle=True)
    client_loaders[c] = loader

# Held-out test loader (one image per identity)
test_paths = []
test_labels = []
for id_, paths in valid_id_to_paths.items():
    p = paths[-1]
    test_paths.append(p)
    test_labels.append(label_map[id_])
test_ds = FaceImageDataset(test_paths, test_labels, transform=transform)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

# -----------------------------
# Model: ResNet-18 wrapper
# -----------------------------
class ResNet18Wrapper(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        in_features = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone
        self.classifier = nn.Linear(in_features, num_classes)
    def forward(self, x, return_features=False):
        feats = self.backbone(x)
        logits = self.classifier(feats)
        if return_features:
            return feats, logits
        return logits

global_model = ResNet18Wrapper(NUM_CLASSES, pretrained=True).to(DEVICE)

# Store global_state as float tensors to avoid dtype mismatch
global_state = {k: v.cpu().to(torch.float32) for k, v in global_model.state_dict().items()}

def model_from_state(state_dict):
    m = ResNet18Wrapper(NUM_CLASSES, pretrained=False).to(DEVICE)
    sd = {k: v.to(DEVICE) for k, v in state_dict.items()}
    m.load_state_dict(sd)
    return m

# -----------------------------
# Utility: ensure image tensors have shape (B,C,H,W)
# -----------------------------
def ensure_batch_image_shape(x):
    if isinstance(x, torch.Tensor):
        if x.dim() == 3:
            return x.unsqueeze(0)
        if x.dim() == 5 and x.size(1) == 1:
            return x.squeeze(1)
        if x.dim() == 5 and x.size(2) == 1:
            return x.squeeze(2)
        return x
    else:
        return x

# -----------------------------
# DP helpers: per-sample gradient computation, clipping, noise
# -----------------------------
def compute_per_sample_grads(model, loss_fn, x_batch, y_batch):
    """
    Compute per-sample gradients for a batch by looping over samples.
    Returns a list (len = batch_size) of lists of gradients (one list per sample),
    where each gradient list matches model.parameters() order.
    """
    model.zero_grad()
    per_sample_grads = []
    batch_size = x_batch.size(0)
    for i in range(batch_size):
        xi = x_batch[i:i+1]
        yi = y_batch[i:i+1]
        logits = model(xi)
        loss = loss_fn(logits, yi)
        grads = torch.autograd.grad(loss, list(model.parameters()), retain_graph=False, create_graph=False)
        # grads is a tuple of tensors (one per param)
        per_sample_grads.append([g.detach().clone() if g is not None else None for g in grads])
    return per_sample_grads

def clip_and_aggregate(per_sample_grads, clip_norm):
    """
    per_sample_grads: list of samples, each is list of param grads
    Returns aggregated (averaged) clipped gradients (list matching model.parameters()).
    """
    batch_size = len(per_sample_grads)
    # compute norms per sample
    norms = []
    for grads in per_sample_grads:
        total_sq = 0.0
        for g in grads:
            if g is None:
                continue
            total_sq += (g.view(-1).float() ** 2).sum().item()
        norms.append(math.sqrt(total_sq))
    # clip and sum
    summed = [torch.zeros_like(g) if g is not None else None for g in per_sample_grads[0]]
    for i, grads in enumerate(per_sample_grads):
        norm = norms[i]
        clip_factor = 1.0
        if norm > clip_norm and norm > 0:
            clip_factor = clip_norm / (norm + 1e-12)
        for j, g in enumerate(grads):
            if g is None:
                continue
            g_clipped = g * clip_factor
            if summed[j] is None:
                summed[j] = g_clipped.clone()
            else:
                summed[j] += g_clipped
    # average
    avg = [ (s / float(batch_size)) if s is not None else None for s in summed ]
    return avg

def add_noise_to_aggregated(aggregated_grads, clip_norm, noise_multiplier, batch_size):
    """
    Add Gaussian noise to each parameter's aggregated gradient.
    Noise std per parameter = noise_multiplier * clip_norm / batch_size
    """
    noisy = []
    std = noise_multiplier * clip_norm / float(max(1, batch_size))
    for g in aggregated_grads:
        if g is None:
            noisy.append(None)
            continue
        noise = torch.normal(mean=0.0, std=std, size=g.shape, device=g.device, dtype=g.dtype)
        noisy.append(g + noise)
    return noisy

# -----------------------------
# FedAvg client update with DP-SGD
# -----------------------------
def get_model_copy_for_client(state_dict):
    return model_from_state(state_dict)

def client_update(client_id, global_state, return_grad_snapshot=False):
    """
    Performs local training on client data.
    If DP_ENABLED: compute per-sample grads, clip, add noise, and apply noisy update.
    Returns: delta (state diff), optionally grad_snapshot (per-parameter grads for attack),
             captured_label (1,), captured_image (1,C,H,W)
    """
    loader = client_loaders[client_id]
    if loader is None:
        return None, None, None, None

    local_model = get_model_copy_for_client(global_state)
    local_model.train()

    # We'll perform LOCAL_EPOCHS over the client's loader
    grad_snapshot = None
    captured_label = None
    captured_image = None

    loss_fn = nn.CrossEntropyLoss(reduction='mean')

    for epoch in range(LOCAL_EPOCHS):
        for x, y, pth in loader:
            # ensure shapes
            x = ensure_batch_image_shape(x)
            if x.dim() == 5 and x.size(1) == 1:
                x = x.squeeze(1)
            # move to device
            x = x.to(DEVICE)
            # ensure labels are LongTensor
            if not isinstance(y, torch.Tensor):
                y = torch.tensor(y, dtype=torch.long)
            else:
                y = y.long()
            y = y.to(DEVICE)

            if DP_ENABLED:
                # Compute per-sample gradients (on DEVICE)
                per_sample_grads = compute_per_sample_grads(local_model, loss_fn, x, y)
                # Clip and aggregate
                aggregated = clip_and_aggregate(per_sample_grads, CLIP_NORM)
                # Add Gaussian noise
                noisy_agg = add_noise_to_aggregated(aggregated, CLIP_NORM, NOISE_MULTIPLIER, x.size(0))
                # Apply noisy aggregated gradients as a gradient descent step:
                # p.data = p.data - lr * noisy_grad
                with torch.no_grad():
                    for p, g_noisy in zip(local_model.parameters(), noisy_agg):
                        if g_noisy is None:
                            continue
                        # ensure dtype/device
                        g_noisy = g_noisy.to(p.device).to(p.dtype)
                        p.data = p.data - CLIENT_LR * g_noisy
                # For attack simulation: capture the (noisy or unclipped?) gradient snapshot.
                # Attacker typically sees raw per-parameter gradients or parameter delta.
                # We'll capture the aggregated (pre-noise) clipped gradient as the observed signal
                if return_grad_snapshot and grad_snapshot is None:
                    # store aggregated (clipped, pre-noise) grads on CPU
                    grad_snapshot = [g.detach().cpu().clone() if g is not None else None for g in aggregated]
                    # capture first sample label and image
                    if y.dim() == 0:
                        captured_label = y.detach().cpu().clone().unsqueeze(0).long()
                    else:
                        captured_label = y.detach().cpu().clone().view(-1)[0:1].long()
                    captured_image = x.detach().cpu().clone()[0:1]
            else:
                # Non-DP baseline: standard SGD step (per-batch)
                opt = torch.optim.SGD(local_model.parameters(), lr=CLIENT_LR, momentum=MOMENTUM)
                opt.zero_grad()
                logits = local_model(x)
                loss = loss_fn(logits, y)
                loss.backward()
                if return_grad_snapshot and grad_snapshot is None:
                    grad_snapshot = [p.grad.detach().cpu().clone() if p.grad is not None else None for p in local_model.parameters()]
                    if y.dim() == 0:
                        captured_label = y.detach().cpu().clone().unsqueeze(0).long()
                    else:
                        captured_label = y.detach().cpu().clone().view(-1)[0:1].long()
                    captured_image = x.detach().cpu().clone()[0:1]
                opt.step()

            # break after first batch if LOCAL_EPOCHS==1 to simulate single-step leakage
            if LOCAL_EPOCHS == 1:
                break

    # compute delta = local_params - global_params (on CPU, float)
    new_state = {k: v.cpu().to(torch.float32) for k, v in local_model.state_dict().items()}
    delta = {k: new_state[k] - global_state[k] for k in new_state.keys()}
    return delta, grad_snapshot, captured_label, captured_image

# -----------------------------
# Server aggregation (FedAvg) - ensure float dtype
# -----------------------------
def server_aggregate(global_state, client_deltas, client_sizes):
    total = sum(client_sizes) if sum(client_sizes) > 0 else 1
    agg = {k: torch.zeros_like(global_state[k], dtype=torch.float32) for k in global_state.keys()}
    for delta, size in zip(client_deltas, client_sizes):
        if delta is None:
            continue
        weight = float(size) / float(total)
        for k in agg.keys():
            d = delta[k].to(torch.float32)
            agg[k] = agg[k] + d * weight
    new_state = {k: (global_state[k].to(torch.float32) + agg[k]).cpu() for k in global_state.keys()}
    return new_state

# -----------------------------
# Reconstruction utilities (IDLG with L-BFGS)
# -----------------------------
def flatten_grads(grad_list):
    vecs = []
    for g in grad_list:
        if g is None:
            continue
        if not g.is_cuda:
            g = g.to(DEVICE)
        vecs.append(g.view(-1))
    if len(vecs) == 0:
        return torch.tensor([], device=DEVICE)
    return torch.cat(vecs)

def tv_loss(img):
    return torch.mean(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])) + \
           torch.mean(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]))

def assert_shapes_for_recon(x_init, observed_label, observed_grads):
    if not isinstance(x_init, torch.Tensor):
        raise RuntimeError("x_init must be a torch.Tensor")
    if x_init.dim() != 4 or x_init.size(0) != 1:
        raise RuntimeError(f"Reconstruction input must be (1,C,H,W), got {x_init.shape}")
    if observed_label is None:
        raise RuntimeError("Observed label is None")
    if observed_label.dim() != 1 or observed_label.size(0) != 1:
        raise RuntimeError(f"Observed label must be shape (1,), got {observed_label.shape}")
    flat = flatten_grads(observed_grads)
    if flat.numel() == 0:
        raise RuntimeError("Observed grads flattened to zero length; check grad capture")

def reconstruct_idlg_lbfgs(model_state_for_attack, observed_grads_cpu, observed_label_cpu, iters=ATTACK_LBFGS_MAX_ITER):
    if observed_label_cpu is None:
        raise ValueError("No observed label provided for reconstruction")
    if observed_label_cpu.dim() == 0:
        observed_label_cpu = observed_label_cpu.unsqueeze(0).long()
    else:
        observed_label_cpu = observed_label_cpu.view(-1)[0:1].long()

    model = model_from_state(model_state_for_attack)
    model.eval()

    obs = [g.to(DEVICE) if g is not None else None for g in observed_grads_cpu]
    obs_flat = flatten_grads(obs).detach()

    if ATTACK_INIT == "random":
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)
    else:
        x_hat = torch.randn(RECON_SHAPE, device=DEVICE, requires_grad=True)

    try:
        assert_shapes_for_recon(x_hat, observed_label_cpu, observed_grads_cpu)
    except Exception as e:
        raise RuntimeError(f"Pre-reconstruction shape check failed: {e}")

    optimizer = torch.optim.LBFGS([x_hat], max_iter=iters, line_search_fn='strong_wolfe' if ATTACK_LBFGS_LINESEARCH else None)

    def closure():
        optimizer.zero_grad()
        x_in = ensure_batch_image_shape(x_hat)
        logits = model(x_in)
        loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
        grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
        grads_hat_flat = flatten_grads(grads_hat)
        if obs_flat.numel() == 0 or grads_hat_flat.numel() == 0:
            loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            loss.backward()
            return loss
        loss_match = F.mse_loss(grads_hat_flat, obs_flat)
        loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
        loss.backward()
        return loss

    try:
        optimizer.step(closure)
    except Exception as e:
        print("LBFGS failed, falling back to Adam:", e)
        opt_adam = torch.optim.Adam([x_hat], lr=0.05)
        for _ in range(500):
            opt_adam.zero_grad()
            x_in = ensure_batch_image_shape(x_hat)
            logits = model(x_in)
            loss_cls = F.cross_entropy(logits, observed_label_cpu.to(DEVICE))
            grads_hat = torch.autograd.grad(loss_cls, list(model.parameters()), create_graph=True)
            grads_hat_flat = flatten_grads(grads_hat)
            loss_match = F.mse_loss(grads_hat_flat, obs_flat)
            loss = ATTACK_GRAD_MATCH_WEIGHT * loss_match + ATTACK_TV_WEIGHT * tv_loss(x_hat)
            loss.backward()
            opt_adam.step()
            with torch.no_grad():
                x_hat.clamp_(-1, 1)

    with torch.no_grad():
        x_rec = x_hat.clamp_(-1, 1).detach().cpu()
    return x_rec

# -----------------------------
# Evaluation metrics
# -----------------------------
def denorm(img_tensor):
    img = img_tensor.clone()
    img = (img * 0.5) + 0.5
    return img.clamp(0,1)

def compute_ssim(img1, img2):
    a = img1.squeeze(0).permute(1,2,0).numpy()
    b = img2.squeeze(0).permute(1,2,0).numpy()
    s = 0.0
    for ch in range(a.shape[2]):
        s += ssim(a[:,:,ch], b[:,:,ch], data_range=1.0)
    return s / a.shape[2]

def compute_cosine_similarity_feature(model_state, img1, img2):
    model = model_from_state(model_state)
    model.eval()
    with torch.no_grad():
        f1, _ = model(ensure_batch_image_shape(img1).to(DEVICE), return_features=True)
        f2, _ = model(ensure_batch_image_shape(img2).to(DEVICE), return_features=True)
    f1 = f1.cpu().numpy()
    f2 = f2.cpu().numpy()
    return float(cosine_similarity(f1, f2)[0,0])

def evaluate_accuracy(model_state, dataloader):
    model = model_from_state(model_state)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y, _ in dataloader:
            x = ensure_batch_image_shape(x).to(DEVICE)
            if not isinstance(y, torch.Tensor):
                y = torch.tensor(y, dtype=torch.long)
            else:
                y = y.long()
            y = y.to(DEVICE)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total if total > 0 else 0.0

# -----------------------------
# Training loop with DP-SGD on clients and attack during training
# -----------------------------
attack_logs = []
global_state = {k: v.cpu().to(torch.float32) for k, v in global_model.state_dict().items()}

for rnd in range(ROUNDS):
    available_clients = [c for c in range(NUM_CLIENTS) if client_loaders[c] is not None]
    sampled = random.sample(available_clients, min(CLIENTS_PER_ROUND, len(available_clients)))
    client_deltas = []
    client_sizes = []

    victim_client = sampled[0] if len(sampled) > 0 else None
    captured_grad = None
    captured_label = None
    captured_image = None
    model_state_sent_to_client = None

    for c in sampled:
        model_state_sent_to_client = {k: v.clone() for k, v in global_state.items()}
        if c == victim_client:
            delta, grad_snapshot, cap_label, cap_image = client_update(c, model_state_sent_to_client, return_grad_snapshot=True)
            captured_grad = grad_snapshot
            captured_label = cap_label
            captured_image = cap_image
        else:
            delta, _, _, _ = client_update(c, model_state_sent_to_client, return_grad_snapshot=False)
        client_deltas.append(delta)
        size = len(client_loaders[c].dataset) if client_loaders[c] is not None else 0
        client_sizes.append(size)

    # Aggregate and update global model
    global_state = server_aggregate(global_state, client_deltas, client_sizes)

    # Periodic evaluation
    if rnd % PRINT_EVERY == 0 or rnd == ROUNDS - 1:
        acc = evaluate_accuracy(global_state, test_loader)
        print(f"Round {rnd:03d}/{ROUNDS:03d} - Test accuracy: {acc:.4f} - sampled {len(sampled)} clients")

    # Attack during training (using pre-update model_state_sent_to_client)
    if (rnd % ATTACK_EVERY_N_ROUNDS == 0) and (captured_grad is not None):
        try:
            if captured_label is None or captured_image is None:
                raise RuntimeError("Captured label/image missing for attack")
            # Force shapes and dtype
            if captured_label.dim() == 0:
                captured_label = captured_label.unsqueeze(0).long()
            else:
                captured_label = captured_label.view(-1)[0:1].long()
            if captured_image.dim() == 3:
                captured_image = captured_image.unsqueeze(0)
            elif captured_image.dim() == 5 and captured_image.size(1) == 1:
                captured_image = captured_image.squeeze(1)
            captured_image = captured_image[0:1]
        except Exception as e:
            print("Captured data shape error, skipping attack this round:", e)
            continue

        try:
            reconstructed = reconstruct_idlg_lbfgs(model_state_sent_to_client, captured_grad, captured_label, iters=ATTACK_LBFGS_MAX_ITER)
        except Exception as e:
            print("Reconstruction failed:", e)
            reconstructed = None

        if reconstructed is not None:
            rec_den = denorm(reconstructed)
            true_den = denorm(captured_image.cpu())

            ssim_val = compute_ssim(rec_den, true_den)
            cos_val = compute_cosine_similarity_feature(model_state_sent_to_client, rec_den.unsqueeze(0), true_den.unsqueeze(0))
            success = float(cos_val >= ATTACK_COSINE_THRESHOLD)

            def tensor_to_pil(img_tensor, fname):
                arr = (img_tensor.squeeze(0).permute(1,2,0).numpy() * 255).astype(np.uint8)
                Image.fromarray(arr).save(fname)

            fname_rec = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_rec.png")
            fname_true = os.path.join(SAVE_RECON_DIR, f"round{rnd:03d}_client{victim_client}_true.png")
            try:
                tensor_to_pil(rec_den, fname_rec)
                tensor_to_pil(true_den, fname_true)
            except Exception:
                np.save(fname_rec + ".npy", rec_den.numpy())
                np.save(fname_true + ".npy", true_den.numpy())

            log_entry = {
                "round": int(rnd),
                "victim_client": int(victim_client),
                "ssim": float(ssim_val),
                "cosine": float(cos_val),
                "success": float(success),
                "recon_path": fname_rec,
                "true_path": fname_true,
                "label": int(captured_label.item())
            }
            attack_logs.append(log_entry)
            print(f"[Attack] Round {rnd} client {victim_client}: SSIM={ssim_val:.4f}, Cos={cos_val:.4f}, Success={success}")

# -----------------------------
# Final aggregated attack metrics and save logs
# -----------------------------
total_attacks = len(attack_logs)
if total_attacks > 0:
    avg_ssim = sum([e["ssim"] for e in attack_logs]) / total_attacks
    avg_cos = sum([e["cosine"] for e in attack_logs]) / total_attacks
    success_rate = sum([e["success"] for e in attack_logs]) / total_attacks
else:
    avg_ssim = avg_cos = success_rate = 0.0

print("=== Attack summary ===")
print(f"Attacks run: {total_attacks}")
print(f"Avg SSIM: {avg_ssim:.4f}")
print(f"Avg Cosine: {avg_cos:.4f}")
print(f"Success rate (cos>={ATTACK_COSINE_THRESHOLD}): {success_rate:.3f}")

with open(os.path.join(SAVE_RECON_DIR, "attack_logs.json"), "w") as f:
    json.dump(attack_logs, f, indent=2)

print("Reconstruction images and logs saved to:", SAVE_RECON_DIR)


Found 9164 images across 1680 identities (>=2 images each).
Assigned identities to 10 non-empty clients out of 10 total clients.




Round 000/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 0 client 2: SSIM=0.0031, Cos=0.2944, Success=0.0




[Attack] Round 1 client 2: SSIM=0.0047, Cos=0.3689, Success=0.0




[Attack] Round 2 client 2: SSIM=0.0054, Cos=0.4416, Success=0.0




[Attack] Round 3 client 8: SSIM=0.0049, Cos=0.5508, Success=0.0




[Attack] Round 4 client 1: SSIM=0.0052, Cos=0.4950, Success=0.0




Round 005/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 5 client 9: SSIM=0.0055, Cos=0.5267, Success=0.0




[Attack] Round 6 client 0: SSIM=0.0029, Cos=0.4846, Success=0.0




[Attack] Round 7 client 6: SSIM=0.0061, Cos=0.5028, Success=0.0




[Attack] Round 8 client 8: SSIM=0.0046, Cos=0.5319, Success=0.0




[Attack] Round 9 client 5: SSIM=0.0035, Cos=0.4898, Success=0.0




Round 010/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 10 client 2: SSIM=0.0061, Cos=0.5258, Success=0.0




[Attack] Round 11 client 2: SSIM=0.0074, Cos=0.4399, Success=0.0




[Attack] Round 12 client 4: SSIM=0.0052, Cos=0.5089, Success=0.0




[Attack] Round 13 client 2: SSIM=0.0046, Cos=0.5245, Success=0.0




[Attack] Round 14 client 9: SSIM=0.0051, Cos=0.5786, Success=0.0




Round 015/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 15 client 9: SSIM=0.0046, Cos=0.6031, Success=1.0




[Attack] Round 16 client 4: SSIM=0.0051, Cos=0.5244, Success=0.0




[Attack] Round 17 client 1: SSIM=0.0061, Cos=0.5913, Success=0.0




[Attack] Round 18 client 5: SSIM=0.0036, Cos=0.5184, Success=0.0




[Attack] Round 19 client 5: SSIM=0.0045, Cos=0.5120, Success=0.0




Round 020/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 20 client 7: SSIM=0.0057, Cos=0.5622, Success=0.0




[Attack] Round 21 client 2: SSIM=0.0061, Cos=0.5138, Success=0.0




[Attack] Round 22 client 6: SSIM=0.0039, Cos=0.5253, Success=0.0




[Attack] Round 23 client 4: SSIM=0.0060, Cos=0.4875, Success=0.0




[Attack] Round 24 client 1: SSIM=0.0056, Cos=0.5296, Success=0.0




Round 025/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 25 client 5: SSIM=0.0052, Cos=0.6271, Success=1.0




[Attack] Round 26 client 8: SSIM=0.0043, Cos=0.4901, Success=0.0




[Attack] Round 27 client 1: SSIM=0.0052, Cos=0.5376, Success=0.0




[Attack] Round 28 client 8: SSIM=0.0058, Cos=0.4896, Success=0.0




[Attack] Round 29 client 4: SSIM=0.0041, Cos=0.5177, Success=0.0




Round 030/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 30 client 7: SSIM=0.0007, Cos=0.5762, Success=0.0




[Attack] Round 31 client 5: SSIM=0.0053, Cos=0.4850, Success=0.0




[Attack] Round 32 client 7: SSIM=0.0057, Cos=0.5638, Success=0.0




[Attack] Round 33 client 6: SSIM=0.0048, Cos=0.5375, Success=0.0




[Attack] Round 34 client 3: SSIM=0.0042, Cos=0.5656, Success=0.0




Round 035/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 35 client 5: SSIM=0.0037, Cos=0.5382, Success=0.0




[Attack] Round 36 client 3: SSIM=0.0054, Cos=0.5534, Success=0.0




[Attack] Round 37 client 4: SSIM=0.0076, Cos=0.5228, Success=0.0




[Attack] Round 38 client 7: SSIM=0.0057, Cos=0.5949, Success=0.0




[Attack] Round 39 client 4: SSIM=0.0051, Cos=0.5113, Success=0.0




Round 040/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 40 client 3: SSIM=0.0047, Cos=0.5481, Success=0.0




[Attack] Round 41 client 1: SSIM=0.0042, Cos=0.5849, Success=0.0




[Attack] Round 42 client 5: SSIM=0.0043, Cos=0.5480, Success=0.0




[Attack] Round 43 client 6: SSIM=0.0056, Cos=0.6056, Success=1.0




[Attack] Round 44 client 5: SSIM=0.0043, Cos=0.6503, Success=1.0




Round 045/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 45 client 4: SSIM=0.0048, Cos=0.5748, Success=0.0




[Attack] Round 46 client 0: SSIM=0.0042, Cos=0.4983, Success=0.0




[Attack] Round 47 client 7: SSIM=0.0026, Cos=0.5854, Success=0.0




[Attack] Round 48 client 7: SSIM=0.0033, Cos=0.5467, Success=0.0




[Attack] Round 49 client 9: SSIM=0.0033, Cos=0.5166, Success=0.0




Round 050/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 50 client 8: SSIM=0.0047, Cos=0.4716, Success=0.0




[Attack] Round 51 client 2: SSIM=0.0058, Cos=0.5466, Success=0.0




[Attack] Round 52 client 9: SSIM=0.0073, Cos=0.5684, Success=0.0




[Attack] Round 53 client 0: SSIM=0.0068, Cos=0.5401, Success=0.0




[Attack] Round 54 client 0: SSIM=0.0047, Cos=0.5352, Success=0.0




Round 055/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 55 client 6: SSIM=0.0041, Cos=0.4894, Success=0.0




[Attack] Round 56 client 9: SSIM=0.0058, Cos=0.5364, Success=0.0




[Attack] Round 57 client 9: SSIM=0.0055, Cos=0.5642, Success=0.0




[Attack] Round 58 client 3: SSIM=0.0050, Cos=0.5318, Success=0.0




[Attack] Round 59 client 8: SSIM=0.0035, Cos=0.5439, Success=0.0




Round 060/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 60 client 6: SSIM=0.0051, Cos=0.5325, Success=0.0




[Attack] Round 61 client 0: SSIM=0.0057, Cos=0.5604, Success=0.0




[Attack] Round 62 client 4: SSIM=0.0026, Cos=0.5916, Success=0.0




[Attack] Round 63 client 9: SSIM=0.0044, Cos=0.5876, Success=0.0




[Attack] Round 64 client 6: SSIM=0.0036, Cos=0.5325, Success=0.0




Round 065/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 65 client 7: SSIM=0.0061, Cos=0.5788, Success=0.0




[Attack] Round 66 client 9: SSIM=0.0041, Cos=0.5807, Success=0.0




[Attack] Round 67 client 3: SSIM=0.0059, Cos=0.5360, Success=0.0




[Attack] Round 68 client 6: SSIM=0.0051, Cos=0.5571, Success=0.0




[Attack] Round 69 client 7: SSIM=0.0059, Cos=0.5426, Success=0.0




Round 070/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 70 client 3: SSIM=0.0044, Cos=0.5063, Success=0.0




[Attack] Round 71 client 1: SSIM=0.0044, Cos=0.4990, Success=0.0




[Attack] Round 72 client 5: SSIM=0.0065, Cos=0.5731, Success=0.0




[Attack] Round 73 client 6: SSIM=0.0045, Cos=0.5468, Success=0.0




[Attack] Round 74 client 0: SSIM=0.0035, Cos=0.5231, Success=0.0




Round 075/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 75 client 3: SSIM=0.0069, Cos=0.6310, Success=1.0




[Attack] Round 76 client 5: SSIM=0.0071, Cos=0.5947, Success=0.0




[Attack] Round 77 client 6: SSIM=0.0037, Cos=0.6143, Success=1.0




[Attack] Round 78 client 3: SSIM=0.0065, Cos=0.5223, Success=0.0




[Attack] Round 79 client 0: SSIM=0.0057, Cos=0.5074, Success=0.0




Round 080/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 80 client 5: SSIM=0.0028, Cos=0.5148, Success=0.0




[Attack] Round 81 client 8: SSIM=0.0037, Cos=0.5862, Success=0.0




[Attack] Round 82 client 1: SSIM=0.0036, Cos=0.5231, Success=0.0




[Attack] Round 83 client 5: SSIM=0.0032, Cos=0.5277, Success=0.0




[Attack] Round 84 client 7: SSIM=0.0056, Cos=0.6189, Success=1.0




Round 085/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 85 client 6: SSIM=0.0041, Cos=0.5095, Success=0.0




[Attack] Round 86 client 1: SSIM=0.0045, Cos=0.5971, Success=0.0




[Attack] Round 87 client 6: SSIM=0.0059, Cos=0.5666, Success=0.0




[Attack] Round 88 client 1: SSIM=0.0046, Cos=0.5790, Success=0.0




[Attack] Round 89 client 8: SSIM=0.0052, Cos=0.5266, Success=0.0




Round 090/100 - Test accuracy: 0.0006 - sampled 10 clients
[Attack] Round 90 client 3: SSIM=0.0059, Cos=0.5182, Success=0.0




[Attack] Round 91 client 1: SSIM=0.0057, Cos=0.5320, Success=0.0




[Attack] Round 92 client 7: SSIM=0.0055, Cos=0.4583, Success=0.0




[Attack] Round 93 client 3: SSIM=0.0034, Cos=0.5852, Success=0.0




[Attack] Round 94 client 8: SSIM=0.0046, Cos=0.5443, Success=0.0




Round 095/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 95 client 6: SSIM=0.0045, Cos=0.5368, Success=0.0




[Attack] Round 96 client 2: SSIM=0.0043, Cos=0.5288, Success=0.0




[Attack] Round 97 client 3: SSIM=0.0041, Cos=0.5438, Success=0.0




[Attack] Round 98 client 9: SSIM=0.0052, Cos=0.5243, Success=0.0




Round 099/100 - Test accuracy: 0.0000 - sampled 10 clients
[Attack] Round 99 client 4: SSIM=0.0063, Cos=0.5614, Success=0.0
=== Attack summary ===
Attacks run: 100
Avg SSIM: 0.0049
Avg Cosine: 0.5373
Success rate (cos>=0.6): 0.070
Reconstruction images and logs saved to: ./reconstructions
