In [None]:
import pickle
import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import concurrent.futures

In [None]:
# Define the directory containing the .bin files
directory_path = (
    "/home/zep/fl_powerpropagation/outputs/2024-06-11/18-26-01/results/parameters"
)

# Collect every parameter dump in the directory, sorted by file name, and
# drop the first entry of that ordering.
# Slicing (instead of `pop(0)`) keeps this cell from raising IndexError when
# the glob matches nothing; an empty list simply stays empty.
# NOTE(review): `sorted` orders names lexicographically, so un-padded numeric
# suffixes (e.g. "_10" vs "_2") would not come out in numeric order — confirm
# the file naming scheme if the order is meant to follow training rounds.
bin_files = sorted(glob.glob(f"{directory_path}/parameters*.bin"))[1:]


def load_binary_file(filepath):
    """Read the whole file at ``filepath`` as a flat array of float32 values.

    The file is opened in binary mode and handed to ``np.fromfile``; no header
    or shape information is parsed, so the result is one-dimensional.

    Assumes the weights were stored as raw float32 values.
    """
    with open(filepath, "rb") as handle:
        return np.fromfile(handle, dtype=np.float32)


# Load the model parameters from binary files and derive one mask per file.
# A mask is True wherever the stored value is non-zero.
MAX_MASKS = 11  # same cut-off as the previous `i > 10` check: keep 11 files

masks = []
for file in bin_files:
    if len(masks) >= MAX_MASKS:
        break
    try:
        model_state = load_binary_file(file)
    except Exception as e:
        # Best-effort loading: report the unreadable file and move on.
        print(f"Error loading {file}: {e}")
        continue
    # Vectorised comparison: one boolean array per file instead of a Python
    # list holding one boolean object per weight.
    masks.append(model_state != 0)
    print(f"Loaded {len(masks) - 1}")

In [None]:
# Normalise every mask to a single flat boolean array.
# Wrapping each individual weight in its own 0-d array (as was done before)
# allocates one ndarray object per parameter; a single array per file holds
# the same booleans and iterates to the same values.
# `np.asarray` does not copy when a mask is already a boolean array.
masks = [np.asarray(mask, dtype=bool) for mask in masks]

In [None]:
# # Function to compute overlap percentage between two masks
# def compute_overlap_percentage(mask1, mask2):
#     # total_weights = np.sum([np.sum(m1) for m1 in mask1])  # Total number of weights
#     n_mask1 = np.sum([np.sum(m1) for m1 in mask1])  # count number of weight for each mask
#     n_mask2 = np.sum([np.sum(m2) for m2 in mask2])
#     overlap_weights = np.sum([np.sum(m1 & m2) for m1, m2 in zip(mask1, mask2)])  # Overlapping weights
#     return 100 - (overlap_weights / max(n_mask1, n_mask2)) * 100  # Percentage of overlap

# # Compute overlap percentage matrix
# num_masks = len(masks)
# overlap_matrix = np.zeros((num_masks, num_masks))

# for i in range(num_masks):
#     for j in range(num_masks):
#         overlap_matrix[i, j] = compute_overlap_percentage(masks[i], masks[j])

In [None]:
def compute_overlap_percentage(mask1, mask2):
    """Return the percentage of active weights that two masks do NOT share.

    Despite the name, the value is ``100 - overlap%``: identical masks give 0
    and masks with no common active weight give 100. The overlap is measured
    against the larger of the two active-weight counts.

    Args:
        mask1: A boolean ndarray, or an iterable of boolean arrays/scalars.
        mask2: Same kind of object as ``mask1``.

    Returns:
        ``100 - 100 * |mask1 & mask2| / max(|mask1|, |mask2|)``, or ``0.0``
        when neither mask has any active weight (previously this case
        produced NaN through a 0/0 division).
    """
    both_arrays_same_shape = (
        isinstance(mask1, np.ndarray)
        and isinstance(mask2, np.ndarray)
        and mask1.shape == mask2.shape
    )
    if both_arrays_same_shape:
        # Whole-array reductions give the same totals as summing piece by
        # piece, without a Python-level loop over the elements.
        n_mask1 = np.sum(mask1)
        n_mask2 = np.sum(mask2)
        overlap_weights = np.sum(mask1 & mask2)
    else:
        # Generic path: pieces are paired positionally; `zip` stops at the
        # shorter of the two inputs.
        n_mask1 = np.sum([np.sum(m1) for m1 in mask1])
        n_mask2 = np.sum([np.sum(m2) for m2 in mask2])
        overlap_weights = np.sum([np.sum(m1 & m2) for m1, m2 in zip(mask1, mask2)])

    largest = max(n_mask1, n_mask2)
    if largest == 0:
        # Two masks without any active weight are identical: nothing differs.
        return 0.0
    return 100 - (overlap_weights / largest) * 100


# Compute overlap percentage matrix: cell (i, j) holds the score returned by
# compute_overlap_percentage for masks i and j.
num_masks = len(masks)
overlap_matrix = np.zeros((num_masks, num_masks))


def compute_overlap(i, j):
    """Store the score for masks ``i`` and ``j`` in ``overlap_matrix[i, j]``."""
    overlap_matrix[i, j] = compute_overlap_percentage(masks[i], masks[j])


# Use multi-threading to parallelize the computation.
# Every task writes to its own (i, j) cell, so no two tasks touch the same
# matrix entry.
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [
        executor.submit(compute_overlap, row, col)
        for row in range(num_masks)
        for col in range(num_masks)
    ]

    # `concurrent.futures.wait` only blocks; it never re-raises what a worker
    # raised, so a failed task would silently leave a 0 in the matrix.
    # Calling `result()` surfaces the first failure instead.
    for future in concurrent.futures.as_completed(futures):
        future.result()

In [None]:
# Plot heatmap of `overlap_matrix`, one row/column per loaded mask.
fig, ax = plt.subplots(figsize=(12, 10))
round_labels = range(num_masks)
sns.heatmap(
    overlap_matrix,
    annot=False,
    fmt=".2f",
    cmap="viridis",
    xticklabels=round_labels,
    yticklabels=round_labels,
    ax=ax,
)
ax.set_xlabel("Round")
ax.set_ylabel("Round")
ax.set_title("Overlap Percentage Heatmap Between Rounds")
fig.tight_layout()
plt.show()