In [1]:
import os 
import re

import numpy as np
import torch
import transformers
from datasets import load_dataset
from torch import nn
from tqdm import tqdm
from transformers import CLIPVisionModel, Dinov2Model, TrainingArguments, Trainer
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def filter_params_to_merge(param_names, exclude_param_regex):
    params_to_merge = []
    for name in param_names:
        valid = not any([re.match(patt, name) for patt in exclude_param_regex])
        if valid:
            params_to_merge.append(name)
    return params_to_merge


def filter_modules_by_regex(base_module, include_patterns, include_type):
    modules = {}
    for name, module in base_module.named_modules():
        valid_name = not include_patterns or any([re.match(patt, name) for patt in include_patterns])
        valid_type = not include_type or any([isinstance(module, md_cls) for md_cls in include_type])
        if valid_type and valid_name:
            modules[name] = module
    return modules

In [3]:
# Load the ImageNet dataset
# Note: You need to have your Hugging Face account set up and authenticated
dataset = load_dataset("imagenet-1k", split="train[:32000]")  # Using validation set for quicker loading

def convert_to_rgb(image):
    if image.mode != 'RGB':
        return image.convert('RGB')
    return image

# Define image transformations
transform_clip = transforms.Compose([
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((336, 336)),  # Resize to match CLIP's input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711])  # CLIP's normalization
])

# Define image transformations for DINOv2
# DINOv2 uses a different input size and normalization compared to CLIP
transform_dino = transforms.Compose([
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((336, 336)),  # DINOv2 typically uses 224x224 input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

# Create a custom dataset class
class ImageNetDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]['image']
        if self.transform:
            image = self.transform(image)
        return {'pixel_values': image}

# Create the dataset and dataloader
clip_dataset = ImageNetDataset(dataset, transform=transform_clip)
clip_dataloader = DataLoader(clip_dataset, batch_size=512, shuffle=True, num_workers=32)

dino_dataset = ImageNetDataset(dataset, transform=transform_dino)
dino_dataloader = DataLoader(dino_dataset, batch_size=512, shuffle=True, num_workers=32)

In [4]:
def compute_gram(model, dataloader):
    grams = {} # gram matrices for each linear layer inputs
    xn = {} # number of examples used for computing gram

    def get_gram(name):
        def hook(module, input, output):
            x = input[0].detach() # $[b,t,h]
            x = x.view(-1, x.size(-1))
            xtx = torch.matmul(x.transpose(0,1), x) # [h,h]
            if name not in grams:
                grams[name] = xtx / x.size(0)
                xn[name] = x.size(0)
            else:
                grams[name] = (grams[name] * xn[name] + xtx) / (x.size(0) + xn[name])
                xn[name] += x.size(0)
        return hook

    linear_modules = filter_modules_by_regex(model, None, [nn.Linear])
    handles = []
    for name, module in linear_modules.items():
        handle = module.register_forward_hook(get_gram(name))
        handles.append(handle)

    n_step = 30
    total = n_step if n_step > 0 else len(dataloader)
    for step, inputs in tqdm(enumerate(dataloader), total=total, desc='Computing gram matrix'):
        if n_step > 0 and step == n_step:
            break

        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model(**inputs)

    for handle in handles:
        handle.remove()

    return grams

In [5]:
model1 = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14-336').cuda()

In [6]:
model2 = CLIPVisionModel.from_pretrained('../../checkpoints/dino2clip_vit').cuda()

In [7]:
with torch.no_grad():
    grams1 = compute_gram(model1, clip_dataloader)

Computing gram matrix: 100%|██████████| 30/30 [14:25<00:00, 28.86s/it]


In [8]:
with torch.no_grad():
    grams2 = compute_gram(model2, dino_dataloader)

Computing gram matrix: 100%|██████████| 30/30 [14:41<00:00, 29.38s/it]


In [9]:
def avg_merge(local_models, global_model, regmean_grams=None, **kwargs):
    params = {}
    for local_model in local_models:
        n2p = {k: v for k,v in local_model.named_parameters()}
        merge_param_names = filter_params_to_merge([n for n in n2p], ['.*classifier.*']) # for glue label spaces are different
        for n in merge_param_names:
            if n not in params:
                params[n] = []
            params[n].append(n2p[n])

    if regmean_grams: # regmean average
        avg_params = regmean_merge(params, regmean_grams)

    else: # simple average
        avg_params = {k: torch.stack(v,0).mean(0) for k, v in params.items()}

    return avg_params

def copy_params_to_model(avg_params, model):
    for n, p in model.named_parameters():
        if n in avg_params:
            p.data.copy_(avg_params[n])

def reduce_non_diag(cov_mat, a):
    diag_weight = torch.diag(torch.ones(cov_mat.size(0)) - a).to(cov_mat.device)
    non_diag_weight = torch.zeros_like(diag_weight).fill_(a)
    weight = diag_weight + non_diag_weight
    ret = cov_mat * weight
    return ret

def regmean_merge(all_params, all_grams):
    avg_params = {}
    n_model = len(all_grams)
    for name in all_params:
        h_avged = False
        if name.endswith('.weight'):
            print(f'Regmean: {name}')
            module_name = name[:-len('.weight')]
            if module_name in all_grams[0]:
                gram_m_ws, grams = [], []

                for model_id, model_grams in enumerate(all_grams):
                    param_grams = model_grams[module_name]

                    # for roberta we dont need this; but it is important for deberta and t5
                    #param_grams = reduce_non_diag(param_grams, a=0.9)

                    param = all_params[name][model_id]
                    gram_m_ws.append(torch.matmul(param_grams, param.transpose(0,1)))
                    grams.append(param_grams)
                sum_gram = sum(grams)
                sum_gram_m_ws = sum(gram_m_ws)
                sum_gram_inv = torch.inverse(sum_gram)
                wt = torch.matmul(sum_gram_inv, sum_gram_m_ws)
                w = wt.transpose(0,1)
                avg_params[name] = w
                h_avged = True
        if not h_avged: # if not averaged with regmean, then do simple avg
            avg_params[name] = torch.stack(all_params[name],0).mean(0)
           
    return avg_params

In [10]:
merged_model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14-336').cuda()

In [11]:
regmean_avg_params = avg_merge([model1, model2], merged_model, regmean_grams=[grams1, grams2])

Regmean: vision_model.embeddings.patch_embedding.weight
Regmean: vision_model.embeddings.position_embedding.weight
Regmean: vision_model.pre_layrnorm.weight
Regmean: vision_model.encoder.layers.0.self_attn.k_proj.weight
Regmean: vision_model.encoder.layers.0.self_attn.v_proj.weight
Regmean: vision_model.encoder.layers.0.self_attn.q_proj.weight
Regmean: vision_model.encoder.layers.0.self_attn.out_proj.weight
Regmean: vision_model.encoder.layers.0.layer_norm1.weight
Regmean: vision_model.encoder.layers.0.mlp.fc1.weight
Regmean: vision_model.encoder.layers.0.mlp.fc2.weight
Regmean: vision_model.encoder.layers.0.layer_norm2.weight
Regmean: vision_model.encoder.layers.1.self_attn.k_proj.weight
Regmean: vision_model.encoder.layers.1.self_attn.v_proj.weight
Regmean: vision_model.encoder.layers.1.self_attn.q_proj.weight
Regmean: vision_model.encoder.layers.1.self_attn.out_proj.weight
Regmean: vision_model.encoder.layers.1.layer_norm1.weight
Regmean: vision_model.encoder.layers.1.mlp.fc1.weight

In [12]:
copy_params_to_model(regmean_avg_params, merged_model)

In [16]:
merged_model.save_pretrained('../../checkpoints/regmean_dinov2_clip_vit')

: 