In [None]:
from transformers import AutoImageProcessor, Dinov2Model
import torch
from datasets import load_dataset

# Fetch the DINOv2-base preprocessing pipeline and backbone from the Hugging Face Hub.
image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = Dinov2Model.from_pretrained("facebook/dinov2-base")

# A single sample image from the public cats-image dataset.
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# Preprocess a batch made of two copies of the same image into PyTorch tensors.
inputs = image_processor(images=[image, image], return_tensors="pt")

# Inference only: no autograd graph is needed.
with torch.no_grad():
    outputs = model(**inputs)

# Final-layer hidden states; presumably (batch, tokens, hidden_size) — TODO confirm.
last_hidden_states = outputs.last_hidden_state
list(last_hidden_states.shape)

In [None]:
# The image processor returns a BatchFeature, which provides `.to(device)` but no
# `.cuda()` method, so `inputs.cuda()` raises AttributeError. Reassign the result
# and fall back to CPU when no GPU is present.
inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Checkpoint identifier the model was loaded from, read through the public
# `PretrainedConfig.name_or_path` property rather than the private `_name_or_path`.
model.config.name_or_path

In [12]:
import torch
import torch.nn as nn

class SiglipMLP(nn.Module):
    """Pre-norm MLP block with a residual connection.

    Applies LayerNorm, then Linear -> GELU -> Linear, and adds the MLP output to
    the *normalized* input.

    Args:
        input_dim: size of the last dimension of the input.
        intermediate_dim: hidden width of the MLP.
        output_dim: size of the output's last dimension. Must equal ``input_dim``
            because of the residual addition.

    Raises:
        ValueError: if ``input_dim != output_dim``.
    """

    def __init__(self, input_dim, intermediate_dim, output_dim):
        super().__init__()
        if input_dim != output_dim:
            # The residual `normed + proj(normed)` needs matching widths. Without this
            # check, output_dim == 1 would silently broadcast, and any other mismatch
            # would only fail at forward time with an opaque shape error.
            raise ValueError(
                f"SiglipMLP residual requires input_dim == output_dim, "
                f"got {input_dim} and {output_dim}"
            )
        self.pre_norm = nn.LayerNorm(input_dim)
        self.proj = nn.Sequential(
            nn.Linear(input_dim, intermediate_dim),
            nn.GELU(),
            nn.Linear(intermediate_dim, output_dim),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``norm(x) + proj(norm(x))`` with the same shape as the input."""
        normed = self.pre_norm(hidden_states)
        # NOTE(review): the residual adds the normalized input, not the raw input as
        # in a standard pre-norm block — confirm this is intended.
        return normed + self.proj(normed)

class VLContrastHead(nn.Module):
    """Projection head mapping vision and text features into a shared space.

    Holds the following submodules and parameters:

    * Per-modality Linear projections into ``target_dimension``.
    * When ``linear`` is False, an additional shared ``SiglipMLP``
      (``self.mapping_network``).
    * Per-modality LayerNorms over the *input* feature widths.
    * Learnable scalar ``logit_scale`` / ``logit_bias`` parameters.

    No ``forward`` method is defined on this class.

    Args:
        vision_dimesion: width of the incoming vision features.
        text_dimension: width of the incoming text features.
        device: device name stored on ``self.device``. When falsy (e.g. ``None``),
            ``'cuda'`` is used if available, otherwise ``'cpu'``. The constructor
            does NOT move parameters to this device; call ``.to(head.device)``.
        target_dimension: width of the shared embedding space.
        linear: if True, only the Linear projections are created; otherwise
            ``self.mapping_network`` is created as well.
    """

    def __init__(self, vision_dimesion, text_dimension, device=None, target_dimension=512, linear=False):
        super().__init__()
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.linear = linear

        # Both modes use the same per-modality projections. Registration order
        # (vision, text, mapping_network, norms, logit params) fixes the order in
        # which parameter init consumes the RNG, so keep it stable for seeded runs.
        self.vision_mapping_network = nn.Linear(vision_dimesion, target_dimension)
        self.text_mapping_network = nn.Linear(text_dimension, target_dimension)
        if not self.linear:
            self.mapping_network = SiglipMLP(target_dimension, target_dimension, target_dimension)

        self.vision_layer_norm = nn.LayerNorm(vision_dimesion)
        self.text_layer_norm = nn.LayerNorm(text_dimension)
        # Placeholder values; overwritten in _initialize_weights.
        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all submodules and the logit scalars.

        * Linear: Xavier-uniform weights and zero biases.
        * LayerNorm: unit weights and zero biases.
        * ``logit_scale``: ``log(10)``.
        * ``logit_bias``: ``-10``.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                # Non-affine LayerNorms have weight/bias set to None.
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # nn.init.constant_ writes under no_grad, avoiding the discouraged `.data` access.
        nn.init.constant_(self.logit_scale, torch.log(torch.tensor(10.0)).item())
        nn.init.constant_(self.logit_bias, -10.0)

In [13]:
# Contrastive head for 512-d vision and text features. `device` is only recorded on
# `head.device`; the constructor does not move the parameters.
head = VLContrastHead(vision_dimesion=512, text_dimension=512, device='cuda')

In [29]:
# Bias of the final Linear in the shared SiglipMLP projection
# (proj is Linear -> GELU -> Linear); _initialize_weights zeroes it.
head.mapping_network.proj[-1].bias

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.

In [1]:
import torch

# Release unoccupied cached GPU memory held by the CUDA caching allocator.
torch.cuda.empty_cache()

# Report GPU memory for device 0, one value per line: bytes currently allocated
# to tensors, followed by bytes reserved by the caching allocator.
print(torch.cuda.memory_allocated(0), torch.cuda.memory_reserved(0), sep="\n")


0
0


In [17]:
import torch
import torch.nn as nn
from typing import Optional
class StarMLP(nn.Module):
    """Sigmoid-gated feature mixing followed by a linear output projection.

    For an input row ``x`` of width ``d = input_dim``, with ``a = Wa x`` and
    ``b = Wb x``, the block computes ``y_i = sum_j sigmoid(a_i * b_j) * x_j`` and
    returns ``g y``. The gate is a ``d x d`` matrix per row, so memory grows as
    ``rows * d**2``.

    Args:
        input_dim: width of the last input dimension.
        output_dim: width of the last output dimension.
        intermediate_dim: accepted for interface compatibility but currently
            unused; all hidden widths equal ``input_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        intermediate_dim: Optional[int] = None,
    ):
        super().__init__()
        # NOTE(review): intermediate_dim is not used by any layer; it is kept only so
        # existing call sites continue to work.
        self.Wa = nn.Linear(input_dim, input_dim, bias=False)
        self.Wb = nn.Linear(input_dim, input_dim, bias=False)
        self.g = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated mixing to ``x`` of shape ``(..., input_dim)``.

        Returns:
            Tensor of shape ``(..., output_dim)``.

        Raises:
            AssertionError: if the output contains NaN or infinite values.
        """
        a = self.Wa(x)  # (..., d)
        b = self.Wb(x)  # (..., d)
        # gate[..., i, j] = sigmoid(a_i * b_j). Unsqueezing at -2 (not a fixed axis 1)
        # keeps the outer product correct for any number of leading dimensions.
        gate = torch.sigmoid(a.unsqueeze(-1) * b.unsqueeze(-2))
        mixed = torch.einsum('...ij,...j->...i', gate, x)
        out = self.g(mixed)

        # Debug guards: each forces a device sync and both are skipped under `python -O`.
        assert not torch.isnan(out).any(), "Output contains NaN"
        assert not torch.isinf(out).any(), "Output contains infinite values"

        return out

In [18]:
# 64-d inputs projected to 128-d outputs; intermediate_dim left at its default.
networ = StarMLP(input_dim=64, output_dim=128)


In [20]:
# Shape sanity check on a random batch of 16 rows of width 64.
x = torch.randn(size=(16, 64))
networ(x).size()

torch.Size([16, 128])