In [None]:
# Notebook setup: build the dataset, the data loaders and the model that the
# following cells rely on.  create_dataset_train_ssl, get_data_loaders,
# SwaVModel, batch_size and classes are not defined in this section and must
# already exist in the kernel.
# NOTE(review): presumably this is the self-supervised training dataset
# configured for SwaV -- confirm against create_dataset_train_ssl.
dataset_train_ssl = create_dataset_train_ssl(SwaVModel)
# get_data_loaders returns three loaders; only `data_loader` and
# `dataloader_train_kNN` are used in the visible cells.
data_loader, dataloader_train_kNN, dataloader_test = get_data_loaders(
    batch_size=batch_size, dataset_train_ssl=dataset_train_ssl
)
# `model` is the object later passed to generate_embeddings.
model = SwaVModel(dataloader_train_kNN, classes)



In [None]:
import torch
from scipy.stats import entropy
from itertools import combinations
import numpy as np
from sklearn.neighbors import KernelDensity
from itertools import combinations
from scipy.special import comb
# Assuming you have a 'model' and a 'data_loader' already defined

def generate_embeddings(model, data_loader):
    """Run every view of every batch through ``model`` and collect the outputs.

    The model is switched to evaluation mode (and left in that mode), and the
    forward passes are executed with gradient tracking disabled.

    Args:
        model: Module that is called as ``model(view)`` and returns a tensor.
        data_loader: Iterable yielding 3-tuples.  The first element is an
            iterable of views; the remaining two elements are ignored.

    Returns:
        numpy.ndarray: For each batch the per-view outputs are stacked on a
        new leading axis; the batches are then concatenated along axis 1.
    """
    model.eval()
    per_batch_outputs = []
    with torch.no_grad():
        for views, _, _ in data_loader:
            # One forward pass per view, stacked on a new leading (view) axis.
            stacked_views = torch.stack([model(single_view) for single_view in views])
            per_batch_outputs.append(stacked_views.cpu().numpy())
    # Axis 0 indexes the views, so batches are joined along axis 1.
    return np.concatenate(per_batch_outputs, axis=1)




# Function to calculate the entropy using Kernel Density Estimation
def kde_entropy(embedding, bandwidth=0.1):
    """Estimate the entropy of a sample with a Gaussian kernel density estimate.

    The data are standardised column-wise, a Gaussian KDE is fitted to them and
    then evaluated at the very same points; the estimate is the negative mean
    of those log densities.

    Args:
        embedding: Array-like of shape ``(n_samples,)`` or
            ``(n_samples, n_features)``.  A 1-D input is treated as
            ``n_samples`` observations of a single feature.
        bandwidth: Kernel bandwidth handed to ``KernelDensity``; it applies to
            the standardised data.

    Returns:
        The negative mean of the clamped log densities.  Because positive log
        densities are clamped to 0 before averaging, the result is never
        negative.
    """
    embedding = np.asarray(embedding)
    if embedding.ndim == 1:
        embedding = embedding.reshape(-1, 1)  # KDE expects a 2-D (samples, features) array

    # Standardise each column.  A constant column has zero standard deviation;
    # dividing by it would produce NaN and make KernelDensity.fit fail, so such
    # columns are only centred (scale 1) and end up as all zeros.
    centered = embedding - np.mean(embedding, axis=0)
    scale = np.std(embedding, axis=0)
    scale = np.where(scale > 0, scale, 1.0)
    embedding = centered / scale

    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
    kde.fit(embedding)
    log_dens = kde.score_samples(embedding)

    # Clamp positive log densities to 0 so the returned value stays
    # non-negative.  NOTE: this biases the estimate whenever the fitted density
    # exceeds 1 at a sample point.
    log_dens = np.minimum(log_dens, 0)

    return -np.mean(log_dens)


def calculate_joint_entropy(embeddings, bandwidth=0.1):
    """Estimate the joint entropy of several views with ``kde_entropy``.

    Each view is flattened to a 1-D vector of observations and the views are
    placed side by side as columns, so row ``k`` of the resulting
    ``(n_samples, n_views)`` array is the k-th observation of every view
    together.  (Stacking the views vertically instead would pool all values
    into a single 1-D sample and measure the entropy of that pooled marginal,
    not a joint entropy.)

    Args:
        embeddings: Non-empty sequence of array-likes, one per view.  All views
            must contain the same number of values.
        bandwidth: Kernel bandwidth forwarded to ``kde_entropy``.

    Returns:
        The value returned by ``kde_entropy`` for the joint sample.

    Raises:
        ValueError: If ``embeddings`` is empty or the views differ in length,
            since paired observations are required for a joint distribution.
    """
    views = [np.asarray(view).reshape(-1) for view in embeddings]
    if not views:
        raise ValueError("calculate_joint_entropy needs at least one view")
    if len({view.shape[0] for view in views}) != 1:
        raise ValueError("all views must contain the same number of values")

    # One column per view -> each row is one joint observation.
    joint_samples = np.column_stack(views)
    return kde_entropy(joint_samples, bandwidth)

def calculate_mutual_information(embeddings, bandwidth=0.1):
    """Average pairwise mutual information between views.

    For every unordered pair of views ``(X, Y)`` the estimate
    ``H(X) + H(Y) - H(X, Y)`` is computed with ``kde_entropy`` and the results
    are averaged over all pairs.  The joint term uses a two-column sample in
    which each row pairs the k-th value of ``X`` with the k-th value of ``Y``.
    (Stacking the two views vertically would pool them into one 1-D sample and
    not yield a joint entropy.)

    Args:
        embeddings: Sequence of array-likes, one per view.  Each view is
            flattened; all views must contain the same number of values.
        bandwidth: Kernel bandwidth forwarded to ``kde_entropy``.

    Returns:
        The mean of the pairwise estimates, or NaN when fewer than two views
        are supplied (there is no pair to average over).

    Raises:
        ValueError: If two views differ in length (raised by
            ``numpy.column_stack``).
    """
    views = [np.asarray(view).reshape(-1) for view in embeddings]
    n_views = len(views)
    if n_views < 2:
        # No pairs exist; make the undefined result explicit instead of
        # relying on a 0/0 division.
        return float("nan")

    # Each marginal entropy is needed for several pairs -> compute it once.
    marginal_entropies = [kde_entropy(view.reshape(-1, 1), bandwidth) for view in views]

    pairwise_mi = []
    for i, j in combinations(range(n_views), 2):
        joint_samples = np.column_stack((views[i], views[j]))
        joint_entropy = kde_entropy(joint_samples, bandwidth)
        pairwise_mi.append(marginal_entropies[i] + marginal_entropies[j] - joint_entropy)

    return sum(pairwise_mi) / len(pairwise_mi)



from scipy.special import comb
from tqdm import tqdm
# Cap on the number of samples that are analysed.
MAX_SAMPLES = 400

# One [MI, mean entropy, joint entropy] row is collected per analysed sample.
results = []
embeddings = generate_embeddings(model, data_loader)

# Never index past the samples that were actually embedded: axis 1 of
# `embeddings` is the sample axis, and it may hold fewer than MAX_SAMPLES items.
n_samples = min(MAX_SAMPLES, embeddings.shape[1])

for sample_idx in tqdm(range(n_samples)):
    # All views of the current sample, one row per view.
    sample_embeddings = embeddings[:, sample_idx, :]

    # Average the views element-wise to get the sample's mean embedding
    mean_embedding = np.mean(sample_embeddings, axis=0)

    # Entropy of the mean embedding
    mean_entropy = kde_entropy(mean_embedding.reshape(-1, 1))

    # One column vector per view for the joint entropy and MI calculations
    reshaped_embeddings = [sample_embeddings[i, :].reshape(-1, 1) for i in range(sample_embeddings.shape[0])]

    # Joint entropy of all views of the current sample
    joint_entropy = calculate_joint_entropy(reshaped_embeddings)

    # Mutual information between the views of the current sample
    mi = calculate_mutual_information(reshaped_embeddings)

    results.append([mi, mean_entropy, joint_entropy])

# results_array has shape (samples, 3); the columns are MI, Mean Entropy, Joint Entropy
results_array = np.array(results)



In [None]:
import numpy as np
from scipy import stats

# Requires `results_array` from the previous cell, with columns
# 0 = MI, 1 = mean entropy, 2 = joint entropy.
mi_values = results_array[:, 0]
mean_entropy_values = results_array[:, 1]
joint_entropy_values = results_array[:, 2]

# A sample is "satisfactory" when MI <= mean entropy <= joint entropy.
ordering_holds = (mi_values <= mean_entropy_values) & (mean_entropy_values <= joint_entropy_values)

# Proportion that counts as 'most' of the data
threshold = 0.85

satisfactory_samples = np.sum(ordering_holds)
total_samples = results_array.shape[0]
proportion_satisfactory = satisfactory_samples / total_samples

# One-sided binomial test: is the true proportion greater than the threshold?
test_result = stats.binomtest(
    k=satisfactory_samples,
    n=total_samples,
    p=threshold,
    alternative='greater',
)

# Report the outcome
print(f"Number of samples satisfying the condition: {satisfactory_samples} out of {total_samples}")
print(f"Proportion of satisfactory samples: {proportion_satisfactory:.2f}")
print(f"Binomial test p-value for threshold {threshold}: {test_result.pvalue:.4f}")
