**1. Title & Imports**

In [None]:
# 📌 AdaFLUX-LoRA Similarity & Divergence Analysis Notebook
# --------------------------------------------------------
# This notebook evaluates how similar different client models are using:
# ✓ Jensen–Shannon Divergence
# ✓ Cosine similarity of model logits
# ✓ Cluster vs Non-cluster model performance

import os, sys, json
import numpy as np
import torch
from pathlib import Path
from itertools import combinations
import seaborn as sns
import matplotlib.pyplot as plt

# Import project modules (adjust path if needed)
# The root is resolved relative to the kernel's current working directory.
PROJECT_ROOT = Path("..").resolve()
# Guarded so that re-running this cell does not pile up duplicate sys.path entries.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from server.router import FLUXRouter
from client_adaflux_lora import collect_descriptor_vector, FlowerClient
from visualize_clusters import plot_flux_embeddings

# HuggingFace utilities
from torch.nn.functional import softmax, log_softmax


**2. Load Saved FL State (Cluster + Params + Assignments)**

In [None]:
# ✔ Loads server-side cluster assignments and checkpoint paths

CHECKPOINT_DIR = Path("../checkpoints")
CLUSTER_META_PATH = Path("../results/cluster_assignments.json")

# Fail early with an actionable message instead of a bare open() error.
if not CLUSTER_META_PATH.exists():
    raise FileNotFoundError("Cluster metadata not found. Run federated training first.")

# Decode explicitly as UTF-8 so the platform's default codec is never used for the JSON file.
with open(CLUSTER_META_PATH, "r", encoding="utf-8") as f:
    cluster_info = json.load(f)

# Invert the client -> cluster mapping into cluster -> [client ids].
# Both sides are normalised to strings, matching the string client ids used as
# keys when the models are restored later in the notebook.
cluster_assignments = cluster_info["client_to_cluster"]
clusters = {}
for cid, cl in cluster_assignments.items():
    clusters.setdefault(str(cl), []).append(str(cid))

print("📌 Loaded Clusters:", clusters)


**3. Restore Trained Models (Cluster & Global Baselines)**

In [None]:
# Load a fresh model for evaluation
from models.model_loader import load_base_model  # You already have this in client init

device = "cuda" if torch.cuda.is_available() else "cpu"


def _restore_model(ckpt_path):
    """Build a fresh base model, overlay the tensors stored in ``ckpt_path``, and
    return it on ``device`` in eval mode.

    The checkpoint is expected to be a mapping with two parallel sequences:
    ``"keys"`` (state-dict entry names) and ``"tensors"`` (their values).
    Only the listed entries are overwritten; every other weight keeps the
    value produced by ``load_base_model()``.

    Raises:
        ValueError: if ``"keys"`` and ``"tensors"`` differ in length.
        KeyError: if the checkpoint names an entry the model does not have.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints produced by your own training runs.
    state = torch.load(ckpt_path, map_location=device)
    keys, tensors = state["keys"], state["tensors"]
    if len(keys) != len(tensors):
        # zip() would silently drop the surplus entries otherwise.
        raise ValueError(
            f"Checkpoint {ckpt_path!r} has {len(keys)} keys but {len(tensors)} tensors."
        )

    model = load_base_model()

    # Inject LoRA params
    # as_tensor accepts ndarrays, lists and tensors alike (no copy-construct warning).
    overlay = {k: torch.as_tensor(t) for k, t in zip(keys, tensors)}
    # strict=False because the checkpoint lists only a subset of the entries;
    # load_state_dict still rejects shape mismatches instead of broadcasting.
    outcome = model.load_state_dict(overlay, strict=False)
    if outcome.unexpected_keys:
        raise KeyError(
            f"Checkpoint {ckpt_path!r} contains keys unknown to the model: "
            f"{list(outcome.unexpected_keys)}"
        )

    # eval() disables dropout etc. so that logits are deterministic during comparison.
    return model.to(device).eval()


client_models = {
    str(cid): _restore_model(ckpt)
    for cid, ckpt in cluster_info["client_checkpoints"].items()
}

# Load global baseline model (now also in eval mode, consistent with the client models)
GLOBAL_CKPT = cluster_info["global_checkpoint"]
global_model = _restore_model(GLOBAL_CKPT)

print(f"Loaded: {len(client_models)} client models + global baseline.")


**4. JS Divergence Function (Cleaned & Efficient)**

In [None]:
def _infer_device(model):
    """Return the device holding ``model``'s first parameter (CPU if it has none)."""
    first_param = next(iter(model.parameters()), None)
    return first_param.device if first_param is not None else torch.device("cpu")


@torch.no_grad()
def js_divergence_matrix(models: dict, dataloader, topk=50, device=None):
    """Pairwise Jensen–Shannon divergence between the output distributions of models.

    For every unordered pair of models, each batch is fed to both models, the
    logits are soft-maxed over the last dimension, and the JS sum is evaluated
    on the ``topk`` entries with the largest mixture probability. The gathered
    probabilities are *not* renormalised, so each value is the partial JS sum
    over the selected entries. Values are averaged over all positions seen in
    the loader (weighted by position count, so a short final batch does not
    get extra weight).

    Args:
        models: mapping ``id -> model``; calling ``model(**inputs)`` must return
            an object exposing ``.logits``.
        dataloader: re-iterable of batches whose first element is a mapping of
            keyword name -> tensor. It is iterated once per model pair.
        topk: number of highest-mixture-probability entries kept per position;
            clamped to the size of the last logits dimension.
        device: where to move the inputs. ``None`` (default) uses the device of
            the first model of each pair; both models of a pair are assumed to
            live on the same device.

    Returns:
        ``(ids, matrix)`` where ``ids`` is the list of model ids in input order
        and ``matrix`` is a symmetric ``(C, C)`` NumPy array with a zero
        diagonal. An entry is NaN when the loader yields no batches.
    """
    # NOTE(review): padded positions, if any, are averaged in like real tokens —
    # confirm whether the attention mask should weight the mean.
    eps = 1e-9  # keeps log() finite when a probability underflows to zero
    ids = list(models.keys())
    C = len(ids)
    matrix = np.zeros((C, C))

    for (i, ci), (j, cj) in combinations(list(enumerate(ids)), 2):
        target = device if device is not None else _infer_device(models[ci])
        js_total, n_positions = 0.0, 0

        for batch in dataloader:
            inputs = {k: v.to(target) for k, v in batch[0].items()}
            Pi = softmax(models[ci](**inputs).logits, dim=-1)
            Pj = softmax(models[cj](**inputs).logits, dim=-1)

            # Top-k reduce for speed; never ask for more entries than exist,
            # otherwise torch.topk raises (e.g. small classification heads).
            k = min(topk, Pi.size(-1))
            idx = torch.topk((Pi + Pj) / 2, k, dim=-1).indices
            Pi = torch.gather(Pi, -1, idx)
            Pj = torch.gather(Pj, -1, idx)

            M = 0.5 * (Pi + Pj)
            log_M = torch.log(M + eps)
            js_per_position = 0.5 * (
                torch.sum(Pi * (torch.log(Pi + eps) - log_M), dim=-1)
                + torch.sum(Pj * (torch.log(Pj + eps) - log_M), dim=-1)
            )

            # Accumulate sum and count so the final mean is over positions,
            # not a mean of per-batch means.
            js_total += js_per_position.sum().item()
            n_positions += js_per_position.numel()

        pair_js = js_total / n_positions if n_positions else float("nan")
        matrix[i, j] = matrix[j, i] = pair_js

    return ids, matrix


**5. Run Experiments (Clustered vs Non-clustered)**

In [None]:
# Borrow the evaluation split of the first restored client as the shared probe set
example_client = list(client_models)[0]
probe_loader = FlowerClient.load_eval_loader(cid=int(example_client))

# 1) Pairwise JS divergence between every pair of client models
ids, full_js = js_divergence_matrix(client_models, probe_loader)


def _divergence_to_global(client_id):
    """JS divergence between one client model and the global baseline on the probe set."""
    pair = {client_id: client_models[client_id], "global": global_model}
    _, pair_matrix = js_divergence_matrix(pair, probe_loader)
    # The pair matrix is 2x2; its off-diagonal entry is the client-vs-global value.
    return pair_matrix[0, 1]


# 2) One client-vs-global divergence per client, in the same order as `ids`
global_compare = [_divergence_to_global(client_id) for client_id in ids]

print("JS Divergence vs Global:\n", dict(zip(ids, global_compare)))


**6. Visualization**

In [None]:
sns.set(font_scale=1.1)

# Heatmap of the pairwise client-vs-client divergence matrix.
fig_pairwise, ax_pairwise = plt.subplots(figsize=(10, 8))
sns.heatmap(
    full_js,
    annot=True,
    fmt=".03f",
    xticklabels=ids,
    yticklabels=ids,
    cmap="viridis",
    ax=ax_pairwise,
)
# Title previously contained a mis-decoded en dash; restored to the intended text.
ax_pairwise.set_title("Jensen–Shannon Divergence Between Client Models")
plt.show()

# Bar chart of each client's divergence from the global baseline model.
fig_global, ax_global = plt.subplots(figsize=(10, 4))
sns.barplot(x=ids, y=global_compare, ax=ax_global)
ax_global.set_ylabel("JS Divergence to Global Model")
ax_global.set_title("Similarity of Each Client to Global Model")
plt.show()
