In [7]:
import os 
import re

import numpy as np
import torch
import transformers
from datasets import load_dataset
from torch import nn
from tqdm import tqdm
from transformers import AutoImageProcessor, CLIPVisionModel, SiglipVisionModel, TrainingArguments, Trainer
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image

In [2]:
def filter_params_to_merge(param_names, exclude_param_regex):
    params_to_merge = []
    for name in param_names:
        valid = not any([re.match(patt, name) for patt in exclude_param_regex])
        if valid:
            params_to_merge.append(name)
    return params_to_merge


def filter_modules_by_regex(base_module, include_patterns, include_type):
    modules = {}
    for name, module in base_module.named_modules():
        valid_name = not include_patterns or any([re.match(patt, name) for patt in include_patterns])
        valid_type = not include_type or any([isinstance(module, md_cls) for md_cls in include_type])
        if valid_type and valid_name:
            modules[name] = module
    return modules

In [4]:
# Load the ImageNet dataset
dataset = load_dataset("imagenet-1k", split="train[:32000]")  # Using a subset for quicker loading

# Load processors
clip_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-large-patch16-384")

# Create a custom dataset class
class ImageNetDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]['image']
        processed = self.processor(images=image, return_tensors="pt")
        return {'pixel_values': processed['pixel_values'].squeeze()}

# Create datasets and dataloaders
clip_dataset = ImageNetDataset(dataset, clip_processor)
clip_dataloader = DataLoader(clip_dataset, batch_size=512, shuffle=True, num_workers=32)

siglip_dataset = ImageNetDataset(dataset, siglip_processor)
siglip_dataloader = DataLoader(siglip_dataset, batch_size=512, shuffle=True, num_workers=32)

In [6]:
def compute_gram(model, dataloader):
    grams = {} # gram matrices for each linear layer inputs
    xn = {} # number of examples used for computing gram

    def get_gram(name):
        def hook(module, input, output):
            x = input[0].detach() # $[b,t,h]
            x = x.view(-1, x.size(-1))
            xtx = torch.matmul(x.transpose(0,1), x) # [h,h]
            if name not in grams:
                grams[name] = xtx / x.size(0)
                xn[name] = x.size(0)
            else:
                grams[name] = (grams[name] * xn[name] + xtx) / (x.size(0) + xn[name])
                xn[name] += x.size(0)
        return hook

    linear_modules = filter_modules_by_regex(model, None, [nn.Linear])
    handles = []
    for name, module in linear_modules.items():
        handle = module.register_forward_hook(get_gram(name))
        handles.append(handle)

    n_step = 30
    total = n_step if n_step > 0 else len(dataloader)
    for step, inputs in tqdm(enumerate(dataloader), total=total, desc='Computing gram matrix'):
        if n_step > 0 and step == n_step:
            break

        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model(**inputs)

    for handle in handles:
        handle.remove()

    return grams

In [8]:
model1 = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14-336').cuda()

In [9]:
model2 = SiglipVisionModel.from_pretrained('google/siglip-large-patch16-384').cuda()

In [11]:
with torch.no_grad():
    grams1 = compute_gram(model1, clip_dataloader)

Computing gram matrix:   0%|          | 0/30 [00:06<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.13 GiB. GPU 

In [10]:
with torch.no_grad():
    grams2 = compute_gram(model2, siglip_dataloader)

Computing gram matrix:   0%|          | 0/30 [00:01<?, ?it/s]


ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/data/gavin/miniconda3/envs/cambrian/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/data/gavin/miniconda3/envs/cambrian/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/gavin/miniconda3/envs/cambrian/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_810165/344727276.py", line 19, in __getitem__
    processed = self.processor(images=image, return_tensors="pt")
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/gavin/miniconda3/envs/cambrian/lib/python3.11/site-packages/transformers/image_processing_utils.py", line 41, in __call__
    return self.preprocess(images, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/gavin/miniconda3/envs/cambrian/lib/python3.11/site-packages/transformers/models/siglip/image_processing_siglip.py", line 233, in preprocess
    input_data_format = infer_channel_dimension_format(images[0])
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/gavin/miniconda3/envs/cambrian/lib/python3.11/site-packages/transformers/image_utils.py", line 244, in infer_channel_dimension_format
    raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
ValueError: Unsupported number of image dimensions: 2


In [14]:
grams1.keys()

dict_keys(['vision_model.encoder.layers.0.self_attn.q_proj', 'vision_model.encoder.layers.0.self_attn.k_proj', 'vision_model.encoder.layers.0.self_attn.v_proj', 'vision_model.encoder.layers.0.self_attn.out_proj', 'vision_model.encoder.layers.0.mlp.fc1', 'vision_model.encoder.layers.0.mlp.fc2', 'vision_model.encoder.layers.1.self_attn.q_proj', 'vision_model.encoder.layers.1.self_attn.k_proj', 'vision_model.encoder.layers.1.self_attn.v_proj', 'vision_model.encoder.layers.1.self_attn.out_proj', 'vision_model.encoder.layers.1.mlp.fc1', 'vision_model.encoder.layers.1.mlp.fc2', 'vision_model.encoder.layers.2.self_attn.q_proj', 'vision_model.encoder.layers.2.self_attn.k_proj', 'vision_model.encoder.layers.2.self_attn.v_proj', 'vision_model.encoder.layers.2.self_attn.out_proj', 'vision_model.encoder.layers.2.mlp.fc1', 'vision_model.encoder.layers.2.mlp.fc2', 'vision_model.encoder.layers.3.self_attn.q_proj', 'vision_model.encoder.layers.3.self_attn.k_proj', 'vision_model.encoder.layers.3.self_a

In [86]:
type(grams1)

dict

In [16]:
grams2.keys()

dict_keys(['encoder.layer.0.attention.attention.query', 'encoder.layer.0.attention.attention.key', 'encoder.layer.0.attention.attention.value', 'encoder.layer.0.attention.output.dense', 'encoder.layer.0.mlp.fc1', 'encoder.layer.0.mlp.fc2', 'encoder.layer.1.attention.attention.query', 'encoder.layer.1.attention.attention.key', 'encoder.layer.1.attention.attention.value', 'encoder.layer.1.attention.output.dense', 'encoder.layer.1.mlp.fc1', 'encoder.layer.1.mlp.fc2', 'encoder.layer.2.attention.attention.query', 'encoder.layer.2.attention.attention.key', 'encoder.layer.2.attention.attention.value', 'encoder.layer.2.attention.output.dense', 'encoder.layer.2.mlp.fc1', 'encoder.layer.2.mlp.fc2', 'encoder.layer.3.attention.attention.query', 'encoder.layer.3.attention.attention.key', 'encoder.layer.3.attention.attention.value', 'encoder.layer.3.attention.output.dense', 'encoder.layer.3.mlp.fc1', 'encoder.layer.3.mlp.fc2', 'encoder.layer.4.attention.attention.query', 'encoder.layer.4.attention.a

In [17]:
model1

CLIPVisionModel(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (position_embedding): Embedding(577, 1024)
    )
    (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-23): 24 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=1024, out_features=4096, bias=

In [18]:
model2

Dinov2Model(
  (embeddings): Dinov2Embeddings(
    (patch_embeddings): Dinov2PatchEmbeddings(
      (projection): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): Dinov2Encoder(
    (layer): ModuleList(
      (0-23): 24 x Dinov2Layer(
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (attention): Dinov2Attention(
          (attention): Dinov2SelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): Dinov2SelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (layer_scale1): Dinov2LayerScale()
        (drop_

In [37]:
for k,v in n2p.items():
    if 'layrnorm' in k:
        print(k)

vision_model.pre_layrnorm.weight
vision_model.pre_layrnorm.bias


In [30]:
for k,v in n2p.items():
    if 'projection.bias' in k:
        print(k)

embeddings.patch_embeddings.projection.bias


In [38]:
n2p1 = {k:v for k,v in model1.named_parameters()}
n2p2 = {k:v for k,v in model2.named_parameters()}

In [72]:
n1 = filter_params_to_merge([n for n in n2p1], ['.*classifier.*', '.*projection.bias', '.*layrnorm.*', '.*layernorm.*', '.*scale.*', '.*class_embedding', '.*position_embedding.*', '.*mask_token.*', '.*cls_token.*'])
n2 = filter_params_to_merge([n for n in n2p2], ['.*classifier.*', '.*projection.bias', '.*layrnorm.*', '.*layernorm.*', '.*scale.*', '.*class_embedding', '.*position_embedding.*', '.*mask_token.*', '.*cls_token.*'])

In [81]:
n_new = []
for n in n2:
    n_new.append(get_clip_name(n))

In [74]:
len(n1)

385

In [78]:
len(n2)

385

In [76]:
len(n_new)

385

In [83]:
n1 = sorted(n1)
n_new = sorted(n_new)

In [84]:
# check if n1 and n_new are the same
for i in range(len(n1)):
    if n1[i] != n_new[i]:
        print(n1[i], n_new[i])


In [85]:
n1 == n_new

True

In [71]:
get_clip_name('encoder.layer.23.mlp.fc1.weight')

'vision_model.encoder.layers.23.mlp.fc1.weight'

In [80]:
def get_clip_name(dinov2_name):
    if 'embeddings.patch_embeddings.projection' in dinov2_name:
        return dinov2_name.replace('embeddings.patch_embeddings.projection', 'vision_model.embeddings.patch_embedding')
    if 'encoder.layer' in dinov2_name:
        dinov2_name = dinov2_name.replace('encoder.layer', 'vision_model.encoder.layers')
        if 'attention' in dinov2_name:
            dinov2_name = dinov2_name.replace('attention', 'self_attn', 1)
            if 'attention.query' in dinov2_name:
                dinov2_name = dinov2_name.replace('attention.query', 'q_proj')
            elif 'attention.key' in dinov2_name:
                dinov2_name = dinov2_name.replace('attention.key', 'k_proj')
            elif 'attention.value' in dinov2_name:
                dinov2_name = dinov2_name.replace('attention.value', 'v_proj')
            elif 'output.dense' in dinov2_name:
                dinov2_name = dinov2_name.replace('output.dense', 'out_proj')
            return dinov2_name
        elif 'norm1' in dinov2_name:
            return dinov2_name.replace('norm1', 'layer_norm1')
        elif 'norm2' in dinov2_name:
            return dinov2_name.replace('norm2', 'layer_norm2')
        return dinov2_name
    return dinov2_name
        

In [89]:
def avg_merge(local_models, global_model, regmean_grams=None, **kwargs):
    params = {}
    for i, local_model in enumerate(local_models):
        
        n2p = {k: v for k,v in local_model.named_parameters()}
        merge_param_names = filter_params_to_merge([n for n in n2p], ['.*classifier.*', '.*projection.bias', '.*layrnorm.*', '.*layernorm.*', '.*scale.*', '.*class_embedding', '.*position_embedding.*', '.*mask_token.*', '.*cls_token.*']) # for glue label spaces are different
        for n in merge_param_names:
            # if it is the second model, the second model should be dinov2
            if i == 1:
                n_clip = get_clip_name(n)
                params[n_clip].append(n2p[n])
            else:
                if n not in params:
                    params[n] = []
                params[n].append(n2p[n])
    
    # we also need to modify the name of the parameters in the regmean grams
    if regmean_grams:
        for k in list(regmean_grams[1].keys()):
            new_k = get_clip_name(k)
            regmean_grams[1][new_k] = regmean_grams[1].pop(k)

    if regmean_grams: # regmean average
        avg_params = regmean_merge(params, regmean_grams)

    else: # simple average
        avg_params = {k: torch.stack(v,0).mean(0) for k, v in params.items()}

    return avg_params

def copy_params_to_model(avg_params, model):
    for n, p in model.named_parameters():
        if n in avg_params:
            p.data.copy_(avg_params[n])

def reduce_non_diag(cov_mat, a):
    diag_weight = torch.diag(torch.ones(cov_mat.size(0)) - a).to(cov_mat.device)
    non_diag_weight = torch.zeros_like(diag_weight).fill_(a)
    weight = diag_weight + non_diag_weight
    ret = cov_mat * weight
    return ret

def regmean_merge(all_params, all_grams):
    avg_params = {}
    n_model = len(all_grams)
    for name in all_params:
        h_avged = False
        if name.endswith('.weight'):
            print(f'Regmean: {name}')
            module_name = name[:-len('.weight')]
            if module_name in all_grams[0]:
                gram_m_ws, grams = [], []

                for model_id, model_grams in enumerate(all_grams):
                    param_grams = model_grams[module_name]

                    # for roberta we dont need this; but it is important for deberta and t5
                    #param_grams = reduce_non_diag(param_grams, a=0.9)

                    param = all_params[name][model_id]
                    gram_m_ws.append(torch.matmul(param_grams, param.transpose(0,1)))
                    grams.append(param_grams)
                sum_gram = sum(grams)
                sum_gram_m_ws = sum(gram_m_ws)
                sum_gram_inv = torch.inverse(sum_gram)
                wt = torch.matmul(sum_gram_inv, sum_gram_m_ws)
                w = wt.transpose(0,1)
                avg_params[name] = w
                h_avged = True
        if not h_avged: # if not averaged with regmean, then do simple avg
            avg_params[name] = torch.stack(all_params[name],0).mean(0)
           
    return avg_params

In [87]:
merged_model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14-336').cuda()

In [90]:
regmean_avg_params = avg_merge([model1, model2], merged_model, regmean_grams=[grams1, grams2])

Regmean: vision_model.embeddings.patch_embedding.weight
Regmean: vision_model.encoder.layers.0.self_attn.k_proj.weight
Regmean: vision_model.encoder.layers.0.self_attn.v_proj.weight
Regmean: vision_model.encoder.layers.0.self_attn.q_proj.weight
Regmean: vision_model.encoder.layers.0.self_attn.out_proj.weight
Regmean: vision_model.encoder.layers.0.layer_norm1.weight
Regmean: vision_model.encoder.layers.0.mlp.fc1.weight
Regmean: vision_model.encoder.layers.0.mlp.fc2.weight
Regmean: vision_model.encoder.layers.0.layer_norm2.weight
Regmean: vision_model.encoder.layers.1.self_attn.k_proj.weight
Regmean: vision_model.encoder.layers.1.self_attn.v_proj.weight
Regmean: vision_model.encoder.layers.1.self_attn.q_proj.weight
Regmean: vision_model.encoder.layers.1.self_attn.out_proj.weight
Regmean: vision_model.encoder.layers.1.layer_norm1.weight
Regmean: vision_model.encoder.layers.1.mlp.fc1.weight
Regmean: vision_model.encoder.layers.1.mlp.fc2.weight
Regmean: vision_model.encoder.layers.1.layer_n

In [91]:
copy_params_to_model(regmean_avg_params, merged_model)

In [92]:
!pwd

13983.68s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


/data/gavin/cambrian/cambrian/merge


In [94]:
merged_model.save_pretrained('../../checkpoints/regmean_dinov2_clip_vit_new')