In [None]:
import torch
import numpy as np
from sklearn.cluster import SpectralClustering
from scipy.spatial.distance import cosine
import sys
sys.path.append('/home/cindy2000_sh/TaskVectorBasis/L&S/language')
sys.path.append('/home/cindy2000_sh/TaskVectorBasis/L&S/language/src')
import pickle

def load_checkpoint(path):
    """Load a checkpoint from ``path``.

    ``torch.load`` is tried first. If it raises an ordinary exception the
    file is read with the plain ``pickle`` module instead.

    Args:
        path: Filesystem path of the checkpoint file.

    Returns:
        Whatever object was serialized in the file.

    Raises:
        Exception: Whatever ``pickle.load`` raises when the fallback also
            fails (the ``torch.load`` error is kept as its context).

    Note:
        Both loaders unpickle arbitrary objects and can execute code embedded
        in the file. Only use this on checkpoints from a trusted source.
    """
    try:
        return torch.load(path)
    except Exception:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate to the user.
        # The context manager guarantees the file handle is closed.
        with open(path, 'rb') as checkpoint_file:
            return pickle.load(checkpoint_file)

def compute_task_vector(trainable_params, finetuned_checkpoint, pretrained_checkpoint, mask):
    """Build the (optionally masked) task vector ``finetuned - pretrained``.

    Parameters are visited in the order of the pretrained ``state_dict``.
    Only names contained in ``trainable_params`` are kept. The k-th kept
    parameter is multiplied element-wise by ``mask[k]`` while masks remain;
    once the masks run out the remaining differences are stored unmasked.

    Args:
        trainable_params: Container supporting ``name in trainable_params``;
            selects which state-dict entries are part of the task vector.
        finetuned_checkpoint: Object exposing ``state_dict()`` with the
            fine-tuned weights.
        pretrained_checkpoint: Object exposing ``state_dict()`` with the
            pretrained weights.
        mask: Sequence of tensors, one per selected parameter (in order).
            It is only read: the caller's sequence is never modified, so
            tuples and other immutable sequences are accepted.

    Returns:
        dict mapping parameter name to its difference tensor.

    Raises:
        KeyError: If a selected parameter is missing from the fine-tuned
            ``state_dict``.
    """
    pretrained_state_dict = pretrained_checkpoint.state_dict()
    finetuned_state_dict = finetuned_checkpoint.state_dict()

    task_vector = {}
    mask_index = 0
    for name, param_pre in pretrained_state_dict.items():
        if name not in trainable_params:
            continue

        param_ft = finetuned_state_dict[name]
        delta = param_ft - param_pre

        if mask_index < len(mask):
            # Move the mask to the parameter's device in a local variable
            # instead of writing it back into the caller's sequence.
            local_mask = mask[mask_index].to(param_ft.device)
            delta = delta * local_mask
            mask_index += 1

        task_vector[name] = delta

    return task_vector

def flatten_task_vector(task_vector):
    """Collapse a task vector into one 1-D tensor.

    Each value of ``task_vector`` is flattened and the pieces are
    concatenated in the mapping's iteration order.

    Args:
        task_vector: Mapping whose values are tensors.

    Returns:
        A single 1-D tensor holding all entries back to back.
    """
    pieces = [tensor.flatten() for tensor in task_vector.values()]
    return torch.cat(pieces)

def compute_cosine_similarity_matrix(task_vectors):
    """Return the symmetric matrix of pairwise cosine similarities.

    Args:
        task_vectors: Sequence of 1-D vectors. Torch tensors are detached and
            copied to the CPU before conversion, so tensors that require grad
            or live on an accelerator are accepted. Anything else is passed
            through ``np.asarray``.

    Returns:
        ``np.ndarray`` of shape ``(num_tasks, num_tasks)`` where entry
        ``[i, j]`` is ``1 - cosine_distance(task_vectors[i], task_vectors[j])``.
    """
    # Convert every vector exactly once instead of on each inner iteration.
    arrays = [
        vec.detach().cpu().numpy() if torch.is_tensor(vec) else np.asarray(vec)
        for vec in task_vectors
    ]

    num_tasks = len(arrays)
    similarity_matrix = np.zeros((num_tasks, num_tasks))

    for i in range(num_tasks):
        # Only the upper triangle (incl. diagonal) is computed; the result is
        # mirrored because cosine similarity is symmetric.
        for j in range(i, num_tasks):
            # scipy's ``cosine`` is a distance, hence the ``1 -``.
            similarity = 1 - cosine(arrays[i], arrays[j])
            similarity_matrix[i, j] = similarity
            similarity_matrix[j, i] = similarity

    return similarity_matrix

def compute_laplacian(similarity_matrix):
    """Return ``D - S`` for a similarity matrix ``S``.

    ``D`` is the diagonal matrix whose entries are the row sums of ``S``.

    Args:
        similarity_matrix: 2-D ``np.ndarray``.

    Returns:
        ``np.ndarray`` with the same shape as the input.
    """
    row_sums = similarity_matrix.sum(axis=1)
    return np.diag(row_sums) - similarity_matrix

def spectral_clustering(similarity_matrix, num_clusters):
    """Cluster tasks from a precomputed similarity matrix.

    The eigenvalues of the Laplacian built by ``compute_laplacian`` are
    printed and used for an eigengap estimate of the cluster count: the
    position of the largest difference between consecutive eigenvalues,
    plus one.

    Args:
        similarity_matrix: Square ``np.ndarray`` used both to build the
            Laplacian and as the precomputed affinity for scikit-learn.
        num_clusters: Number of clusters to fit, or ``None`` to use the
            eigengap estimate.

    Returns:
        Tuple ``(labels, optimal_k)``: the labels returned by
        ``SpectralClustering.fit_predict`` and the eigengap estimate.
    """
    spectrum, _ = np.linalg.eigh(compute_laplacian(similarity_matrix))
    print(spectrum)

    # Eigengap heuristic: index of the widest gap between neighbours, +1.
    optimal_k = np.argmax(np.diff(spectrum)) + 1

    n_clusters = optimal_k if num_clusters is None else num_clusters

    clusterer = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', random_state=42)
    return clusterer.fit_predict(similarity_matrix), optimal_k


In [None]:

# --- Configuration: tasks, backbone and checkpoint locations ---------------
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']
model_name = 'ViT-B-32'
root = '/data/common/task-arithmetic'
checkpoint_dir = root+'/task_vectors_checkpoints/'+model_name
pretrained_path = checkpoint_dir+'/zeroshot.pt'
finetuned_paths = [checkpoint_dir+'/'+dataset_name+'/finetuned.pt' for dataset_name in exam_datasets]

# --- Pretrained model and the subset of parameters that is compared --------
pretrained_model = torch.load(pretrained_path)
pretrained_model_dic = pretrained_model.state_dict()
# State-dict entries listed here are left out of the task vectors.
frozen = ["model.positional_embedding", "model.text_projection", "model.logit_scale", "model.token_embedding.weight", "model.ln_final.weight", "model.ln_final.bias"]
trainable_params = {
    param_name: param_value
    for param_name, param_value in pretrained_model_dic.items()
    if param_name not in frozen
}

# One mask file per dataset, in the same order as ``exam_datasets``.
masks = [torch.load(f'/home/cindy2000_sh/Localize-and-Stitch/1e-2_vision/{dataset_name}_mask.pt') for dataset_name in exam_datasets]

# --- One flattened, masked task vector per fine-tuned checkpoint -----------
pretrained_checkpoint = pretrained_model
task_vectors = []

for path, task_mask in zip(finetuned_paths, masks):
    finetuned_checkpoint = load_checkpoint(path)
    task_vector = compute_task_vector(trainable_params, finetuned_checkpoint, pretrained_checkpoint, task_mask)
    task_vectors.append(flatten_task_vector(task_vector))

similarity_matrix = compute_cosine_similarity_matrix(task_vectors)

# --- Spectral clustering of the tasks --------------------------------------
# Set to an integer to force that many clusters, or to None to use the
# eigengap estimate computed inside ``spectral_clustering``.
num_clusters = 2
labels, optimal_k = spectral_clustering(similarity_matrix, num_clusters)

print(exam_datasets)
print(f"Cluster labels: {labels}")
print(f"Optimal number of clusters (based on eigenvalue gap): {optimal_k}")
