In [None]:
import sys
sys.path.append('/home/cindy2000_sh/TsakVectorBasis')
from pathlib import Path

import json
from src.task_vectors import NonLinearTaskVector, LinearizedTaskVector
import numpy as np

import torch

Utils

In [None]:
def load_task_vector(model, task, pretrained_ckpt_name, finetuning_mode):
    """Build the task vector of one fine-tuned task from its checkpoints.

    Args:
        model: Architecture name used as a checkpoint sub-directory (e.g. 'ViT-L-14').
        task: Task/dataset name; its checkpoints live under '{task}Val'.
        pretrained_ckpt_name: Tag used in the 'checkpoints_{...}' directory name.
        finetuning_mode: 'standard' -> NonLinearTaskVector(zeroshot.pt, {task}Val/finetuned.pt);
            any other value -> LinearizedTaskVector built from the per-task
            linear_zeroshot.pt / linear_finetuned.pt.

    Returns:
        The constructed task-vector object.
    """
    # NOTE(review): absolute, user-specific root; consider making it configurable.
    model_dir = f'/home/cindy2000_sh/ntk-llm/tangent_task_arithmetic/checkpoints_{pretrained_ckpt_name}/{model}'
    task_dir = f'{model_dir}/{task}Val'

    if finetuning_mode != 'standard':
        # Linearized fine-tuning keeps its own zero-shot checkpoint inside each task folder.
        return LinearizedTaskVector(f'{task_dir}/linear_zeroshot.pt', f'{task_dir}/linear_finetuned.pt')

    return NonLinearTaskVector(f'{model_dir}/zeroshot.pt', f'{task_dir}/finetuned.pt')

def flatten_task_vector(task_vector):
    """Concatenate every tensor of a task vector into one 1-D numpy array.

    Args:
        task_vector: Object exposing a ``vector`` mapping of parameter name -> tensor.

    Returns:
        (flat_vector, param_keys): each tensor moved to CPU, raveled and concatenated
        in mapping order, plus the parameter names in that same order (needed by
        ``recover_task_vector_from_centroid`` to undo the flattening).
    """
    param_keys = list(task_vector.vector.keys())
    pieces = [task_vector.vector[name].cpu().numpy().ravel() for name in param_keys]
    return np.concatenate(pieces), param_keys

def load_and_flatten_all_task_vectors(model, task_names, pretrained_ckpt_name, finetuning_mode):
    """Load every task's task vector and stack the flattened vectors row-wise.

    Args:
        model, pretrained_ckpt_name, finetuning_mode: forwarded to ``load_task_vector``.
        task_names: Tasks to load; row ``i`` of the result belongs to ``task_names[i]``.

    Returns:
        (matrix, param_keys): ``matrix`` has one flattened task vector per row and
        ``param_keys`` is the parameter order shared by every row.

    Raises:
        ValueError: if ``task_names`` is empty, or if a task's parameter names/order
            differ from the first task's.
    """
    flattened_vectors = []
    all_param_keys = None

    for task in task_names:
        task_vector = load_task_vector(model, task, pretrained_ckpt_name, finetuning_mode)
        flat_vector, param_keys = flatten_task_vector(task_vector)

        if all_param_keys is None:
            all_param_keys = param_keys
        elif param_keys != all_param_keys:
            # Same total length but a different key order would still vstack, yet each
            # column would then mix unrelated parameters -- fail loudly instead.
            raise ValueError(
                f"Parameter keys of task '{task}' do not match those of the first task."
            )

        flattened_vectors.append(flat_vector)

    if not flattened_vectors:
        raise ValueError("task_names must contain at least one task.")

    return np.vstack(flattened_vectors), all_param_keys


def recover_task_vector_from_centroid(centroid, param_keys, task_vector_template):
    """Split a flat centroid back into a {param_name: tensor} dict.

    Inverse of ``flatten_task_vector``: consecutive slices of ``centroid`` are
    reshaped to the shape of the matching tensor in ``task_vector_template.vector``,
    walking ``param_keys`` in order.

    Args:
        centroid: 1-D array-like whose length equals the total element count of the
            template tensors named in ``param_keys``.
        param_keys: Parameter names, in the order used when flattening.
        task_vector_template: Object whose ``vector`` mapping supplies each shape.

    Returns:
        dict mapping each key to a newly allocated tensor (``torch.tensor`` copies).

    Raises:
        ValueError: if ``len(centroid)`` differs from the total element count, i.e. the
            centroid and the template/keys do not describe the same parameter layout.
    """
    shapes = [task_vector_template.vector[key].shape for key in param_keys]
    sizes = [int(np.prod(shape)) for shape in shapes]
    expected = sum(sizes)
    if len(centroid) != expected:
        # Previously a longer centroid was silently truncated and a shorter one
        # failed deep inside reshape; both mean a layout mismatch.
        raise ValueError(
            f"Centroid has {len(centroid)} elements but the template expects {expected}."
        )

    recovered_vector = {}
    start = 0
    for key, shape, size in zip(param_keys, shapes, sizes):
        recovered_vector[key] = torch.tensor(centroid[start:start + size]).reshape(shape)
        start += size

    return recovered_vector

def save_recovered_task_vector_as_checkpoint(recovered_vector, model_checkpoint_path):
    """Serialize a recovered task-vector dict to ``model_checkpoint_path`` via ``torch.save``."""
    torch.save(obj=recovered_vector, f=model_checkpoint_path)


def save_all_centroids_as_checkpoints(centroids, param_keys, task_vector_template, output_dir):
    """Write centroid ``i`` to ``{output_dir}/centroid_{i}.pt`` as a param-name -> tensor dict.

    The directory is created when missing. Each centroid is un-flattened with
    ``recover_task_vector_from_centroid`` (shapes from ``task_vector_template``),
    saved with ``save_recovered_task_vector_as_checkpoint``, and a line is printed.
    """
    import os

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    idx = 0
    for centroid in centroids:
        state = recover_task_vector_from_centroid(centroid, param_keys, task_vector_template)
        target_path = os.path.join(output_dir, f'centroid_{idx}.pt')
        save_recovered_task_vector_as_checkpoint(state, target_path)
        print(f"Saved centroid {idx} as {target_path}")
        idx += 1

Top-k (keep each task vector's top-5% magnitude entries, then average)

In [None]:
def top_5_percent_masked_centroid(cluster_points, percent=0.05, device=None):
    """Average rows after keeping only each row's largest-magnitude entries.

    In every row, entries whose absolute value is at least the row's
    ``ceil(d * percent)``-th largest absolute value are kept and the rest zeroed.
    The masked rows are summed and divided by the number of rows ``n`` (so a
    coordinate kept by only some rows is still divided by ``n``).

    Args:
        cluster_points: (n, d) tensor or array-like; non-tensors become float32 tensors.
        percent: Fraction of coordinates kept per row (default 0.05 = top 5%).
        device: Device to compute on. ``None`` keeps the previous choice of 'cuda:2'
            when CUDA is available, but falls back to the default CUDA device when
            fewer than three GPUs exist (previously this crashed); without CUDA the
            input stays where it is.

    Returns:
        1-D tensor of length d on the compute device.
    """
    if not isinstance(cluster_points, torch.Tensor):
        cluster_points = torch.tensor(cluster_points, dtype=torch.float32)

    if device is None and torch.cuda.is_available():
        device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cuda'
    if device is not None:
        cluster_points = cluster_points.to(device=device)

    n, d = cluster_points.shape
    keep = int(torch.ceil(torch.tensor(d * percent)))
    # kthvalue needs 1 <= k <= d, i.e. keep in [1, d]; only matters for percent <= 0 or > 1.
    keep = min(max(keep, 1), d)

    abs_values = torch.abs(cluster_points)
    # (d - keep + 1)-th smallest == keep-th largest magnitude per row.
    thresholds, _ = torch.kthvalue(abs_values, k=d - keep + 1, dim=1)
    mask = abs_values >= thresholds.unsqueeze(1)
    normalized_mask = mask.float() / n
    masked_rows = normalized_mask * cluster_points
    centroid = masked_rows.sum(dim=0)

    # Release large intermediates before returning (matters for GPU memory).
    del cluster_points, abs_values, thresholds, mask, normalized_mask, masked_rows
    torch.cuda.empty_cache()

    return centroid

TIES

In [None]:
def topk_values_mask(M, K, return_mask=False):
    """Keep only the top-K fraction of entries (by magnitude) in each row of M.

    This is the "trim" step of TIES-merging: in every row the ``int(d * K)``
    entries with the largest absolute value are kept and all others are zeroed.

    Args:
        M: 1-D (d,) or 2-D (n, d) tensor.
        K: Fraction of entries to keep; values > 1 are read as percentages
            (e.g. 20 -> 0.20).
        return_mask: If True, also return the boolean keep-mask.

    Returns:
        (trimmed, kept_fraction) or (trimmed, kept_fraction, mask). ``trimmed`` and
        ``mask`` have the input's shape; ``kept_fraction`` has one entry per row.
    """
    if K > 1:
        K /= 100

    original_shape = M.shape
    if M.dim() == 1:
        M = M.unsqueeze(0)

    n, d = M.shape
    # Number of entries to KEEP per row. The old code used ``d - int(d * K)`` -- a
    # leftover from a kthvalue-based threshold -- with topk(largest=True), which kept
    # the complement (e.g. 80% of entries instead of 20% for K=20).
    k = int(d * K)

    _, indices = M.abs().topk(k, dim=1, largest=True, sorted=False)
    mask = torch.zeros_like(M, dtype=torch.bool).scatter_(1, indices, True)
    # Computed on the 2-D mask so it also works for 1-D input (previously the
    # squeezed 1-D mask made ``mean(dim=1)`` raise).
    kept_fraction = mask.float().mean(dim=1)

    trimmed = (M * mask).reshape(original_shape)
    final_mask = mask.reshape(original_shape)

    if return_mask:
        return trimmed, kept_fraction, final_mask
    return trimmed, kept_fraction

def resolve_zero_signs(sign_to_mult, method="majority"):
    """Fill the zero entries of a sign tensor in place and return the same tensor.

    ``"majority"`` fills zeros with the sign of the tensor's total sum and
    ``"minority"`` with the opposite sign; any other ``method`` leaves it untouched.
    """
    overall_sign = torch.sign(sign_to_mult.sum())

    if method == "majority":
        fill_value = overall_sign
    elif method == "minority":
        fill_value = -1 * overall_sign
    else:
        return sign_to_mult

    sign_to_mult[sign_to_mult == 0] = fill_value
    return sign_to_mult

def resolve_sign(Tensor):
    """Elect one sign per column of ``Tensor`` (the TIES "elect" step).

    Takes the sign of each column sum; columns whose sum is exactly zero are then
    filled by ``resolve_zero_signs(..., "majority")``, i.e. with the sign of the
    sum of the column signs.
    """
    column_signs = torch.sign(torch.sum(Tensor, dim=0))
    return resolve_zero_signs(column_signs, "majority")

def disjoint_merge(Tensor, merge_func, sign_to_mult):
    """Combine rows of ``Tensor`` using only entries that agree with the elected sign.

    Args:
        Tensor: (n, d) tensor of (trimmed) task vectors.
        merge_func: Aggregation name; only the part after the last '-' is used
            ('mean', 'sum' or 'max'), so 'dis-mean' means 'mean'.
        sign_to_mult: (d,) elected signs, or None to keep every non-zero entry.

    Returns:
        (d,) tensor. 'mean' divides by the number of contributing (non-zero) entries
        per column (at least 1); 'max' takes the largest magnitude times the sign.

    Raises:
        ValueError: for an unknown aggregation name.
    """
    agg_name = merge_func.split("-")[-1]

    if sign_to_mult is None:
        keep = Tensor != 0
    else:
        # Positive elected sign keeps positive entries, otherwise keep negative ones.
        keep = torch.where(sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0)
    selected_entries = Tensor * keep

    if agg_name == "mean":
        contributors = (selected_entries != 0).sum(dim=0).float()
        return torch.sum(selected_entries, dim=0) / torch.clamp(contributors, min=1)
    if agg_name == "sum":
        return torch.sum(selected_entries, dim=0)
    if agg_name == "max":
        magnitudes = selected_entries.abs().max(dim=0)[0]
        magnitudes *= sign_to_mult
        return magnitudes
    raise ValueError(f"Merge method {agg_name} is not defined.")

def ties_merging(flat_task_checks, reset_thresh=None, merge_func="dis-mean"):
    """TIES-merge a stack of flattened task vectors (rows of ``flat_task_checks``).

    Steps:
    1. Trim each row with ``topk_values_mask``, using ``reset_thresh`` as K.
    2. Elect a sign per coordinate with ``resolve_sign``.
    3. Aggregate the agreeing entries with ``disjoint_merge(merge_func)``.

    The input is cloned first.

    NOTE(review): the default ``reset_thresh=None`` is handed straight to
    ``topk_values_mask``, whose ``K > 1`` comparison fails for None; the callers in
    this notebook always pass a number.
    """
    trimmed = topk_values_mask(flat_task_checks.clone(), K=reset_thresh, return_mask=False)[0]
    elected_signs = resolve_sign(trimmed)
    return disjoint_merge(trimmed, merge_func, elected_signs)

Fisher

In [None]:
def compute_fisher_for_dataset(fisher_dir, dataset_name):
    """Load the saved Fisher information of one dataset as a single flat tensor.

    Reads ``{fisher_dir}/{dataset_name}Val/fisher_train.pth`` (a mapping of
    name -> tensor) and concatenates all tensors, flattened, in mapping order.

    NOTE(review): the result is multiplied element-wise with task vectors flattened
    in ``task_vector.vector`` order; confirm both use the same key order and sizes.
    NOTE(review): ``torch.load`` unpickles the file -- only use trusted checkpoints.
    """
    def flatten_fisher_matrix(fisher):
        return torch.cat([v.flatten() for v in fisher.values()])

    fisher_path = f"{fisher_dir}/{dataset_name}Val/fisher_train.pth"
    # Load onto CPU: compute_fisher_weighted_centroid moves the Fisher to the task
    # vector's device anyway, and GPU-saved files would otherwise require that exact
    # GPU to be present (and would occupy its memory).
    fisher = torch.load(fisher_path, map_location="cpu")
    return flatten_fisher_matrix(fisher)

def compute_fisher_weighted_centroid(cluster_points, fisher_matrices, lamb):
    """
    Fisher-weighted average of task vectors, element-wise:
        centroid = (sum_i lamb * F_i * tau_i) / (sum_i lamb * F_i)
    Coordinates where the ratio is NaN (e.g. 0/0 when every Fisher value is zero)
    are set to 0.

    Args:
        cluster_points: Iterable of 1-D task vectors (each converted with torch.tensor, float32).
        fisher_matrices: Iterable of Fisher tensors, paired with cluster_points via zip.
        lamb: Scalar weight applied to every term.

    Returns:
        float32 tensor on the task vectors' device.
    """
    numerator = None
    denominator = None

    for point, fisher in zip(cluster_points, fisher_matrices):
        tau = torch.tensor(point, dtype=torch.float32)
        weight = lamb * fisher.to(dtype=torch.float32, device=tau.device)
        contribution = weight * tau

        if numerator is None:
            numerator = contribution
            denominator = weight
        else:
            numerator = numerator + contribution
            denominator = denominator + weight

    centroid = numerator / denominator
    centroid[torch.isnan(centroid)] = 0
    return centroid

Inner Merge (evaluated on all 15 set partitions of the four tasks listed in `label_comb`)

In [None]:
def compute_cluster_centroids(flattened_matrix, labels, n_clusters, mean='euclidean', reset_thresh=20, merge_func="dis-mean",
                              fisher_dir="/home/cindy2000_sh/ntk-llm/tangent_task_arithmetic/checkpoints_laion2b_e16/ViT-B-32",
                              task_names=None):
    """Merge the task vectors of each cluster into one centroid per cluster.

    Args:
        flattened_matrix: (num_tasks, d) array, one flattened task vector per row.
        labels: Cluster id per row (array-like of length num_tasks).
        n_clusters: Cluster ids 0..n_clusters-1 are visited in order. Empty clusters
            are skipped, so an output index equals its cluster id only when no
            cluster is empty.
        mean: Inner-merge method:
            'mean' / 'tangent' / 'euclidean' -> arithmetic mean of the rows,
            'ties'   -> ``ties_merging`` with ``reset_thresh`` and ``merge_func``,
            'topk'   -> ``top_5_percent_masked_centroid``,
            'fisher' -> ``compute_fisher_weighted_centroid`` with equal weights.
        reset_thresh, merge_func: Forwarded to ``ties_merging``.
        fisher_dir: Directory holding '{task}Val/fisher_train.pth' (used by 'fisher' only).
            NOTE(review): the default points at ViT-B-32 / laion2b_e16 checkpoints;
            confirm it matches the model being merged.
        task_names: Task name per row, needed only for 'fisher'; defaults to the
            notebook-global ``task_names`` as before.

    Returns:
        list of 1-D numpy arrays, one per non-empty cluster. A single-member cluster
        returns its row (as the mean of one row) regardless of ``mean``.

    Raises:
        ValueError: for an unknown ``mean``.
    """
    arithmetic = ('mean', 'tangent', 'euclidean')
    if mean not in arithmetic + ('ties', 'topk', 'fisher'):
        raise ValueError(f"Unsupported inner merge method: {mean}")

    labels = np.asarray(labels)
    centroids = []

    for cluster in range(n_clusters):
        in_cluster = labels == cluster
        cluster_points = flattened_matrix[in_cluster]

        if len(cluster_points) == 0:
            continue
        if len(cluster_points) == 1:
            centroids.append(np.mean(cluster_points, axis=0))
            continue

        # BUG FIX: the old test `mean == 'mean' or 'tangent'` was always true, so every
        # method (including 'ties' and 'fisher') silently fell back to the plain mean,
        # and 'topk' had no branch at all.
        if mean in arithmetic:
            centroid = np.mean(cluster_points, axis=0)

        elif mean == 'ties':
            if torch.cuda.is_available():
                device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cuda'
            else:
                device = 'cpu'
            cluster_points_tensor = torch.tensor(cluster_points, dtype=torch.float32).to(device)
            # .cpu() first: .numpy() cannot be called on a CUDA tensor.
            centroid = ties_merging(
                flat_task_checks=cluster_points_tensor,
                reset_thresh=reset_thresh,
                merge_func=merge_func,
            ).cpu().numpy()

        elif mean == 'topk':
            centroid = top_5_percent_masked_centroid(cluster_points).cpu().numpy()

        else:  # 'fisher'
            names = task_names if task_names is not None else globals()['task_names']
            dataset_names = np.asarray(names)[in_cluster]
            fisher_matrices = [
                compute_fisher_for_dataset(fisher_dir, dataset) for dataset in dataset_names
            ]
            lamb = 1 / len(cluster_points)  # Equal weighting for all datasets in the cluster
            centroid = compute_fisher_weighted_centroid(cluster_points, fisher_matrices, lamb).numpy()

        centroids.append(centroid)

    return centroids

In [None]:
# Inner-merge settings: every label partition in `label_comb` is merged with every method.
means = ['ties','mean','topk']
finetuning_mode = 'standard'
task_names = ['MNIST', 'EuroSAT', 'RESISC45', 'SVHN']
model_names = ['ViT-L-14']
pretrained_ckpt_names = ['openai']

# Each row assigns a cluster id to the task at the same position in `task_names`.
# NOTE(review): cluster ids are assumed contiguous 0..k-1, because the number of
# clusters is taken below as the count of unique labels.
label_comb = [[1,0,1,1],
              [2,0,1,2],
              [2,0,2,1],
              [0,1,2,2],
              [0,2,1,2],
              [0,2,3,1],
              [0,2,2,1],
              [0,1,1,1],
              [0,0,1,1],
              [0,1,0,1],
              [0,1,1,0],
              [0,0,0,0],
              [1,1,0,1],
              [1,1,1,0],
              [2,2,1,0]]

flattened_matrix, param_keys = load_and_flatten_all_task_vectors(model_names[0], task_names, pretrained_ckpt_names[0], finetuning_mode)
# Only used for its parameter shapes when un-flattening centroids.
task_vector_template = load_task_vector(model_names[0], task_names[0], pretrained_ckpt_names[0], finetuning_mode)

for mean in means:
    for labels in label_comb:
        labels = np.array(labels)
        # Folder name joins task initials per cluster, e.g. [1,0,1,1] -> 'E-MRS'.
        parts = [""] * len(task_names)
        for task, lbl in zip(task_names, labels):
            parts[lbl] += task[0]
        folder_name = '-'.join(filter(None, parts))
        output_dir = f'tangent_task_arithmetic/checkpoints_{pretrained_ckpt_names[0]}/{model_names[0]}/MERS-{mean}/{folder_name}_checkpoints/' # create folder
        # exist_ok=True so the cell can be re-run (Restart & Run All) without FileExistsError.
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        task_to_label_mapping = dict(zip(task_names, [int(l) for l in labels]))
        with open(output_dir + 'task_to_label_mapping.json', 'w') as f:
            json.dump(task_to_label_mapping, f)
        centroids = compute_cluster_centroids(flattened_matrix, labels, len(np.unique(labels)), mean)
        save_all_centroids_as_checkpoints(centroids, param_keys, task_vector_template, output_dir)

Outer Merge (Mean/Topk/TIES)

In [None]:
import os
import torch

def flatten_state_dict(state_dict):
    """Concatenate all tensors of a state dict into one 1-D tensor.

    Returns:
        (flat, param_keys): ``flat`` holds each tensor flattened, in iteration order;
        ``param_keys`` lists the keys in that same order (see ``reconstruct_state_dict``).
    """
    param_keys = list(state_dict.keys())
    # reshape, not view: view(-1) raises for non-contiguous tensors (e.g. transposed
    # weights); reshape returns a view when possible and copies only when it must.
    flattened = [state_dict[key].reshape(-1) for key in param_keys]
    return torch.cat(flattened), param_keys

def reconstruct_state_dict(flattened, param_keys, original_state_dict):
    """Inverse of ``flatten_state_dict``: slice ``flattened`` back into named tensors.

    Walks ``param_keys`` in order, takes ``numel()`` consecutive elements for each
    key and views them with the shape of ``original_state_dict[key]``. The returned
    tensors are views into ``flattened`` (no copy).
    """
    out = {}
    cursor = 0
    for name in param_keys:
        reference = original_state_dict[name]
        end = cursor + reference.numel()
        out[name] = flattened[cursor:end].view_as(reference)
        cursor = end
    return out

# Function to outer merge centroids using different methods
def merge_centroids(flattened_centroids, method, top_k=5, reset_thresh=20, merge_func="dis-mean"):
    """Outer-merge a list of flattened centroid tensors into a single tensor.

    Args:
        flattened_centroids: list of equal-length 1-D tensors (stacked along dim 0).
        method: 'mean', 'topk' (``top_5_percent_masked_centroid``) or 'ties' (``ties_merging``).
        top_k: Currently unused; 'topk' is called without it.
        reset_thresh, merge_func: Forwarded to ``ties_merging`` for 'ties'.

    Raises:
        ValueError: for any other method.
    """
    stacked = torch.stack(flattened_centroids, dim=0)

    if method == "ties":
        return ties_merging(stacked, reset_thresh=reset_thresh, merge_func=merge_func)
    if method == "topk":
        return top_5_percent_masked_centroid(stacked)
    if method == "mean":
        return torch.mean(stacked, dim=0)
    # Remaining methods (e.g. task arithmetic) are run outside this notebook.
    raise ValueError(f"Unsupported merge method: {method}")

In [None]:
# Outer merge: for each method, merge the per-cluster centroid checkpoints found in
# every split folder and save the result next to them.
# NOTE(review): this reads ViT-B-16 centroids under /data/common/..., while the
# inner-merge cell writes ViT-L-14 centroids under a relative
# 'tangent_task_arithmetic/...' path -- confirm the intended location.
for method in ['mean','ties','topk']:
    base_dir = f'/data/common/cindy2000_sh/tangent_task_arithmetic/checkpoints/ViT-B-16/MERS-{method}'

    # sorted(): os.listdir order is arbitrary; sorting makes runs reproducible.
    for split_folder in sorted(os.listdir(base_dir)):
        split_path = os.path.join(base_dir, split_folder)
        if not os.path.isdir(split_path):
            continue

        flattened_centroids = []
        param_keys = None
        example_state_dict = None

        for file_name in sorted(os.listdir(split_path)):
            # Only per-cluster centroids; skips previously written merged_centroid_*.pt.
            if not (file_name.startswith("centroid_") and file_name.endswith(".pt")):
                continue
            centroid_path = os.path.join(split_path, file_name)
            state_dict = torch.load(centroid_path, map_location='cpu')
            flattened, keys = flatten_state_dict(state_dict)
            flattened_centroids.append(flattened)
            if param_keys is None:
                param_keys = keys
                example_state_dict = state_dict

        if not flattened_centroids:
            continue

        merged_flattened_centroid = merge_centroids(flattened_centroids, method)
        # 'topk' computes on a CUDA device when one is available; bring the result back
        # to CPU so the saved checkpoint loads on machines without that GPU.
        merged_flattened_centroid = merged_flattened_centroid.cpu()

        merged_state_dict = reconstruct_state_dict(merged_flattened_centroid, param_keys, example_state_dict)

        output_file = os.path.join(split_path, f"merged_centroid_{method}.pt")
        torch.save(merged_state_dict, output_file)
        print(f"Merged centroid saved for {split_folder} using {method}: {output_file}")

In [None]:
# When the outer merge is plain task arithmetic (TA), the task-vector addition
# is run outside this notebook; see cmd/run_outerTA.sh.