1.Preparations

In [1]:
import csv
import gc
import math
import os
import random
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass, field
from itertools import chain
from typing import List, Optional

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

In [2]:
ROOT = "/content/Images"
ANN_FILE = "/content/captions.txt"

def to_2tuple(x):
    if isinstance(x, tuple):
        return x
    return (x, x)


def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")


def clear_cuda():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def tensor_to_device(*tensors, device=torch.device("cpu"), non_blocking=True):
    moved = tuple(t.to(device, non_blocking=non_blocking) for t in tensors)
    return moved if len(moved) > 1 else moved[0]

def get_ctx(use_mixed: bool, device: torch.device, amp_mode: str = "auto"):
    if not use_mixed or amp_mode == "off":
        print("Not using autocast context")
        return nullcontext()

    device_type = device.type

    if amp_mode == "fp16":
        dtype = torch.float16
    elif amp_mode == "bf16":
        dtype = torch.bfloat16
    else:
        if device_type == "cuda":
            dtype = torch.bfloat16
        elif device_type == "mps":
            dtype = torch.float16
        elif device_type == "cpu":
            dtype = torch.float16
        else:
            return nullcontext()

    print(f"Using autocast with dtype={dtype} on device type={device_type}")
    return torch.autocast(device_type=device_type, dtype=dtype)


def load_tokenizer(model_name="openai-community/gpt2-medium"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer


def clean_caption(caption: str) -> str:
    caption = caption.strip()
    caption = re.sub(r"\s+", " ", caption)
    caption = re.sub(r"\s+([.,!?;:])", r"\1", caption)

    return caption

def show_samples(
    dataset,
    nrows=2,
    ncols=2,
):
    num_samples = nrows * ncols
    idxs = random.sample(range(len(dataset)), k=num_samples)

    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows))

    for ax, idx in zip(axes.flatten(), idxs):
        sample = dataset[idx]
        if isinstance(sample, tuple):
            img_t, cap = sample
        else:
            img_t, cap = sample["pixel_values"], sample["caption"]

        # Convert CHW tensor to HWC numpy
        if torch.is_tensor(img_t):
            img_np = img_t.detach().cpu().numpy().transpose(1, 2, 0)
        else:  # PIL Image
            img_np = transforms.ToTensor()(img_t).numpy().transpose(1, 2, 0)

        ax.imshow(img_np)
        ax.axis("off")
        ax.set_title(cap[:40] + ("…" if len(cap) > 40 else ""))

    plt.tight_layout()
    plt.show()

2.CLIP

In [3]:
@dataclass
class ModelConfig:
    # Vision Modeling
    vit_hf_model_name:str = "google/vit-base-patch16-224"
    vit_image_size:int = 224
    vit_patch_size:int = 16
    vit_hidden_size:int = 768 # 16*16*3
    vit_num_hidden_layers:int = 12
    vit_num_attention_heads:int = 12
    vit_intermediate_size:int = 3072
    vit_layer_norm_eps:float = 1e-12
    vit_dropout:float=0.1
    
    #Language Modeling
    lm_hf_model_name:str="openai-community/gpt2-medium"
    lm_vocab_size:int = 50257
    lm_n_layer:int = 24
    lm_n_head:int = 16
    lm_n_embd:int = 1024
    lm_n_positions:int = 76
    lm_resid_pdrop:float=0.1
    lm_embd_pdrop:float=0.1
    lm_attn_pdrop:float=0.1
    lm_layer_norm_epsilon:float=1e-5
    lm_tie_word_embeddings:bool=True
    
    lm_eot_token_id:int = 50256
    lm_pad_token_id:int = (50256) # If you set pad_token_id, you need to add it to tokenizer, and resize model embeddings
    
    #Projection
    project_dim:int = 512
    

2.1Vision Modeling

In [18]:
class MLP(nn.Module):
    def __init__(self, config:ModelConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.vit_hidden_size, config.vit_intermediate_size)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(config.vit_intermediate_size,config.vit_hidden_size)
        self.drop = nn.Dropout(config.vit_dropout)
    
    def forward(self,x):
        return self.drop(self.fc2(self.act(self.fc1(x))))

class Attention(nn.Module):
    def __init__(self, config:ModelConfig):
        super().__init__()
        self.num_heads = config.vit_num_attention_heads
        self.head_dim = config.vit_hidden_size // config.vit_num_attention_heads
        self.scale = self.head_dim**-0.5
        
        self.q_proj = nn.Linear(config.vit_hidden_size, config.vit_hidden_size)
        self.k_proj = nn.Linear(config.vit_hidden_size, config.vit_hidden_size)
        self.v_proj = nn.Linear(config.vit_hidden_size, config.vit_hidden_size)
        
        self.proj = nn.Linear(config.vit_hidden_size, config.vit_hidden_size)
        self.attn_drop = nn.Dropout(config.vit_dropout)
        self.proj_drop = nn.Dropout(config.vit_dropout)
    
    def forward(self,x):
        B, N, C = x.shape
        q = self.q_proj(x).reshape(B,N,self.num_heads,self.head_dim).transpose(1,2)
        k = self.k_proj(x).reshape(B,N,self.num_heads,self.head_dim).transpose(1,2)
        v = self.v_proj(x).reshape(B,N,self.num_heads,self.head_dim).transpose(1,2)
        
        attn = (q@k.transpose(-2,-1))*self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        
        out = (attn@v).transpose(1,2).reshape(B,N,C)
        out = self.proj(out)
        return self.proj_drop(out)
    
class Block(nn.Module):
    def __init__(self, config:ModelConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.vit_hidden_size, eps=config.vit_layer_norm_eps)
        self.attn = Attention(config)
        self.ln2 = nn.LayerNorm(config.vit_hidden_size, eps=config.vit_layer_norm_eps)
        self.mlp = MLP(config)
        
    def forward(self,x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        
        return x
        
        
# -----------------------
# Vision Transformer
# -----------------------
class VisionTransformer(nn.Module):
    def __init__(self, config:ModelConfig):
        super().__init__()
        self.config=config
        
        self.patch_embed = nn.Conv2d(3, config.vit_hidden_size, kernel_size=config.vit_patch_size, stride=config.vit_patch_size)
        num_patches = (config.vit_image_size//config.vit_patch_size)**2
        self.cls_token = nn.Parameter(torch.zeros(1,1,config.vit_hidden_size))
        self.pos_embd = nn.Parameter(torch.zeros(1,num_patches+1,config.vit_hidden_size))
        self.pos_drop = nn.Dropout(config.vit_dropout)
        
        self.blocks = nn.ModuleList([Block(config) for _  in range(config.vit_num_hidden_layers)])
        self.norm = nn.LayerNorm(config.vit_hidden_size,eps = config.vit_layer_norm_eps)
    def forward(self,x):
        # x: (B, 3, H, W)
        B = x.shape[0]
        x = self.patch_embed(x) # (B, C, H/P, W/P)
        x = x.flatten(2).transpose(1,2) # (B, N, C)
        cls = self.cls_token.expand(B,-1,-1) # (B, 1, C)
        x = torch.cat((cls,x),dim=1)
        x = x + self.pos_embd
        x = self.pos_drop(x)
        
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        
        cls_out = x[:,0]  # CLS token
        
        return cls_out

@torch.no_grad()
def load_hf_vit_into_scratch(model:VisionTransformer, hf_name:str = "google/vit-base-patch16-224", device:Optional[str]=None):
    from transformers import AutoConfig, ViTModel
    
    hf_cfg = AutoConfig.from_pretrained(hf_name, add_pooling_layer = False)
    assert hf_cfg.hidden_size == model.config.vit_hidden_size
    assert hf_cfg.num_hidden_layers == model.config.vit_num_hidden_layers
    assert hf_cfg.num_attention_heads == model.config.vit_num_attention_heads
    assert hf_cfg.intermediate_size == model.config.vit_intermediate_size
    assert hf_cfg.image_size == model.config.vit_image_size
    assert hf_cfg.patch_size == model.config.vit_patch_size
    
    hf_model = ViTModel.from_pretrained(hf_name)
    hf_sd = hf_model.state_dict()
    sd = model.state_dict()
    
    # patch embed
    sd["patch_embed.weight"] = hf_sd["embeddings.patch_embeddings.projection.weight"]
    sd["patch_embed.bias"] = hf_sd["embeddings.patch_embeddings.projection.bias"]

    sd["cls_token"] = hf_sd["embeddings.cls_token"]
    sd["pos_embed"] = hf_sd["embeddings.position_embeddings"]

    for i in range(model.config.vit_num_hidden_layers):
        p = f"blocks.{i}"
        h = f"encoder.layer.{i}"

        sd[f"{p}.ln1.weight"] = hf_sd[f"{h}.layernorm_before.weight"]
        sd[f"{p}.ln1.bias"] = hf_sd[f"{h}.layernorm_before.bias"]

        sd[f"{p}.attn.q_proj.weight"] = hf_sd[f"{h}.attention.attention.query.weight"]
        sd[f"{p}.attn.q_proj.bias"] = hf_sd[f"{h}.attention.attention.query.bias"]
        sd[f"{p}.attn.k_proj.weight"] = hf_sd[f"{h}.attention.attention.key.weight"]
        sd[f"{p}.attn.k_proj.bias"] = hf_sd[f"{h}.attention.attention.key.bias"]
        sd[f"{p}.attn.v_proj.weight"] = hf_sd[f"{h}.attention.attention.value.weight"]
        sd[f"{p}.attn.v_proj.bias"] = hf_sd[f"{h}.attention.attention.value.bias"]
        # but HF stores q,k,v separately, so you’ll need to adapt here
        # to keep it simple, load q/k/v separately like in CLIP
        # (adjust your Attention class accordingly if you want exact parity)

        sd[f"{p}.ln2.weight"] = hf_sd[f"{h}.layernorm_after.weight"]
        sd[f"{p}.ln2.bias"] = hf_sd[f"{h}.layernorm_after.bias"]

        sd[f"{p}.mlp.fc1.weight"] = hf_sd[f"{h}.intermediate.dense.weight"]
        sd[f"{p}.mlp.fc1.bias"] = hf_sd[f"{h}.intermediate.dense.bias"]
        sd[f"{p}.mlp.fc2.weight"] = hf_sd[f"{h}.output.dense.weight"]
        sd[f"{p}.mlp.fc2.bias"] = hf_sd[f"{h}.output.dense.bias"]

    sd["norm.weight"] = hf_sd["layernorm.weight"]
    sd["norm.bias"] = hf_sd["layernorm.bias"]
    
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if device:
        model.to(device)
    print("Loaded HF ViT weights.")
    if missing:
        print("Missing keys:", missing)
    if unexpected:
        print("Unexpected keys:", unexpected)
    return model

2.2Language Model

In [23]:
class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)
    
class CausalSelfAttention(nn.Module):
    def __init__(self, config:ModelConfig):
        super().__init__()
        assert config.lm_n_embd % config.lm_n_head == 0,  "n_embd must be divisible by n_head"
        self.n_head = config.lm_n_head
        self.head_dim = config.lm_n_embd // config.lm_n_head
        self.c_attn = nn.Linear(config.lm_n_embd, 3 * config.lm_n_embd, bias=True)#qkv
        self.c_proj = nn.Linear(config.lm_n_embd, config.lm_n_embd, bias=True)
        self.attn_drop=nn.Dropout(config.lm_attn_pdrop)
        self.resid_drop = nn.Dropout(config.lm_resid_pdrop)
        
        mask = torch.triu(torch.ones(config.lm_n_positions, config.lm_n_positions),1).bool()
        self.register_buffer("causal_mask",mask,persistent=False)
        
    def forward(self, x, attention_mask: Optional[torch.Tensor]=None):
        B, T, C = x.size() # (batch, time, channels)
        
        qkv = self.c_attn(x) # (B, T, 3*C)
        q,k,v = qkv.split(C,dim=-1)  # 3 x (B, T, C)
        
        q = q.view(B,T,self.n_head,self.head_dim).transpose(1,2)
        k = k.view(B,T,self.n_head,self.head_dim).transpose(1,2)
        v = v.view(B,T,self.n_head,self.head_dim).transpose(1,2)
        
        # Scaled dot-product attention with causal mask
        att = (q @ k.transpose(-2,-1))/math.sqrt(self.head_dim) # (B, n_head, T, T)
        
        att = att.masked_fill(self.causal_mask.unsqueeze(0).unsqueeze(0)[..., :T, :T], float("-inf"))
        if attention_mask is not None:
            attention_mask = attention_mask[:,None,None,:].to(torch.bool) # (B, 1, 1, T)
            att = att.masked_fill(~attention_mask, float("-inf"))
            
        att = F.softmax(att,dim=-1)
        att = self.attn_drop(att)
        y = att@v # (B, n_head, T, head_dim)
        y = y.transpose(1,2).contiguous().view(B,T,C)
        y = self.resid_drop(self.c_proj(y))
        
        return y
        
class TMLP(nn.Module):
    def __init__(self, config:ModelConfig):
        super().__init__()
        
        self.c_fc = nn.Linear(config.lm_n_embd, 4 * config.lm_n_embd, bias=True)
        self.act = GELU()
        self.c_proj = nn.Linear(4 * config.lm_n_embd, config.lm_n_embd, bias=True)
        self.drop = nn.Dropout(config.lm_resid_pdrop)
        
    def forward(self,x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.drop(x)
        
        return x
    
class TBlock(nn.Module):
    def __init__(self, config:ModelConfig): 
        super().__init__()
        
        self.ln1 = nn.LayerNorm(config.lm_n_embd,config.lm_layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.lm_n_embd,config.lm_layer_norm_epsilon)
        self.mlp = TMLP(config)
        
    def forward(self, x, attention_mask:Optional[torch.Tensor]=None):
        x = x + self.attn(self.ln1(x), attention_mask)
        x = x + self.mlp(self.ln2(x))
        return x        
    
class TextEncoder(nn.Module):
    def __init__(self, config:ModelConfig):
        super().__init__()
        
        self.config = config
        self.tok_emb = nn.Embedding(config.lm_vocab_size, config.lm_n_embd)
        self.pos_emb = nn.Embedding(config.lm_n_positions, config.lm_n_embd) # learnable
        self.drop = nn.Dropout(config.lm_embd_pdrop)
        self.blocks = nn.ModuleList([TBlock(config) for _ in range(config.lm_n_layer)])
        self.ln_f = nn.LayerNorm(config.lm_n_embd, eps = config.lm_layer_norm_epsilon)
        self.lm_head = nn.Linear(config.lm_n_embd, config.lm_vocab_size, bias=False)
        
        if config.lm_tie_word_embeddings:
            self.lm_head.weight = self.tok_emb.weight
        
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            #GPT-2
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    @torch.no_grad()
    def _pos_indices(self, T, device):
        return torch.arange(T, device=device).unsqueeze(0) # (1, T)
    
    def forward(self, input_ids, attention_mask:Optional[torch.Tensor]=None):
        """
        input_ids: (B, T) with T <= context_length; must include EOT token id
        attention_mask is optional (1 for valid tokens, 0 for pad) — not needed for CLIP pooling,
        but accepted for API symmetry.
        """
        B, T = input_ids.shape
        device = input_ids.device
        
        pos = self._pos_indices(T, device)
        x = self.tok_emb(input_ids) + self.pos_emb(pos)
        x= self.drop(x)
        
        for blk in self.blocks:
            x = blk(x, attention_mask)
        x = self.ln_f(x)
        
        eot_idx = (input_ids ==self.config.lm_eot_token_id).int().argmax(dim=1) # (B,)
        eot_rep = x[torch.arange(B, device=device), eot_idx,:]# (B, hidden)
        
        return eot_rep
        
# -----------------------
# Checkpoint loader (HF -> this model)
# -----------------------
def load_from_hf_into_scratch(
    model: TextEncoder, hf_name: str = "openai-community/gpt2-medium", device: Optional[str] = None
):
    """
    Load weights from a Hugging Face GPT-2 LMHead checkpoint into this scratch model.
    Fixes Conv1D (in,out) -> Linear (out,in) by transposing c_attn, c_fc, c_proj weights.
    """
    import torch
    from transformers import AutoConfig, GPT2LMHeadModel

    def maybe_T(src: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
        # If src is stored as (in, out) but target wants (out, in), transpose.
        if src.shape == target_shape:
            return src
        if src.shape[::-1] == tuple(target_shape):
            return src.t()
        # Fallback: let load_state_dict complain if it's truly incompatible.
        return src

    # sanity checks
    hf_cfg = AutoConfig.from_pretrained(hf_name)
    assert hf_cfg.n_layer == model.config.lm_n_layer
    assert hf_cfg.n_head == model.config.lm_n_head
    assert hf_cfg.n_embd == model.config.lm_n_embd
    assert hf_cfg.vocab_size == model.config.lm_vocab_size

    hf_model = GPT2LMHeadModel.from_pretrained(hf_name)
    hf_sd = hf_model.state_dict()
    sd = model.state_dict()

    # embeddings
    sd["tok_emb.weight"] = hf_sd["transformer.wte.weight"]
    pos_src = hf_sd["transformer.wpe.weight"]
    if pos_src.size(0) >= model.config.lm_n_positions:
        sd["pos_emb.weight"] = pos_src[: model.config.lm_n_positions]
    else:
        pad = model.config.lm_n_positions - pos_src.size(0)
        sd["pos_emb.weight"] = torch.cat([pos_src, pos_src[-1:].repeat(pad, 1)], dim=0)

    C = model.config.lm_n_embd
    for i in range(model.config.lm_n_layer):
        p = f"blocks.{i}"
        h = f"transformer.h.{i}"

        # LayerNorms
        sd[f"{p}.ln1.weight"] = hf_sd[f"{h}.ln_1.weight"]
        sd[f"{p}.ln1.bias"] = hf_sd[f"{h}.ln_1.bias"]
        sd[f"{p}.ln2.weight"] = hf_sd[f"{h}.ln_2.weight"]
        sd[f"{p}.ln2.bias"] = hf_sd[f"{h}.ln_2.bias"]

        # Attention: c_attn (fused qkv) and c_proj
        # our shapes: c_attn.weight (3C, C), c_proj.weight (C, C)
        sd[f"{p}.attn.c_attn.weight"] = maybe_T(hf_sd[f"{h}.attn.c_attn.weight"], torch.Size([3 * C, C]))
        sd[f"{p}.attn.c_attn.bias"] = hf_sd[f"{h}.attn.c_attn.bias"]
        sd[f"{p}.attn.c_proj.weight"] = maybe_T(hf_sd[f"{h}.attn.c_proj.weight"], torch.Size([C, C]))
        sd[f"{p}.attn.c_proj.bias"] = hf_sd[f"{h}.attn.c_proj.bias"]

        # MLP: c_fc (4C, C) and c_proj (C, 4C)
        sd[f"{p}.mlp.c_fc.weight"] = maybe_T(hf_sd[f"{h}.mlp.c_fc.weight"], torch.Size([4 * C, C]))
        sd[f"{p}.mlp.c_fc.bias"] = hf_sd[f"{h}.mlp.c_fc.bias"]
        sd[f"{p}.mlp.c_proj.weight"] = maybe_T(hf_sd[f"{h}.mlp.c_proj.weight"], torch.Size([C, 4 * C]))
        sd[f"{p}.mlp.c_proj.bias"] = hf_sd[f"{h}.mlp.c_proj.bias"]

    # final LayerNorm
    sd["ln_f.weight"] = hf_sd["transformer.ln_f.weight"]
    sd["ln_f.bias"] = hf_sd["transformer.ln_f.bias"]

    # lm_head (tied to tok_emb by default). If untied, copy explicitly (no transpose needed).
    if not model.config.lm_tie_word_embeddings:
        if "lm_head.weight" in hf_sd:
            sd["lm_head.weight"] = hf_sd["lm_head.weight"]
        else:
            sd["lm_head.weight"] = hf_sd["transformer.wte.weight"]

    missing, unexpected = model.load_state_dict(sd, strict=False)
    if device is not None:
        model.to(device)

    print("Loaded HF weights with Conv1D->Linear transposes.")
    if missing:
        print("Missing keys:", missing)
    if unexpected:
        print("Unexpected keys:", unexpected)
    return model
        

2.3Modality Projector

In [24]:
class ModalityProjector(nn.Module):
    def __init__(self, in_dim, out_dim, dropout=0.1):
        super().__init__()
        
        self.projection = nn.Linear(in_dim, out_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(out_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(out_dim)
        
        self._init_weights()
        
    def _init_weights(self):
        # Xavier init for linear layers
        nn.init.xavier_normal_(self.projection.weight)
        nn.init.zeros_(self.projection.bias)
        
        nn.init.xavier_normal_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)
        
    def forward(self,x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        
        return self.layer_norm(x)

2.4Full CLIP Model

In [25]:
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
        
class CLIPModel(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        
        self.config = config
        self.vision_encoder = VisionTransformer(config)
        self.text_encoder = TextEncoder(config)
        
        self.vision_proj = ModalityProjector(config.vit_hidden_size, config.project_dim)
        self.text_proj = ModalityProjector(config.lm_n_embd, config.project_dim)
        
        self.load_from_state_dict()
        
    def forward(self, image, input_ids, attention_mask:Optional[torch.Tensor]=None):
        image_embeds = self.encode_image(image) # (B, project_dim)
        text_embeds = self.encode_text(input_ids, attention_mask)
        
        return image_embeds,text_embeds
        
    def encode_image(self,image):
        if image.ndim == 3:
            image = image.unsqueeze(0)
        image_features = self.vision_encoder(image)# (B, vit_hidden_size)
        image_embeds = self.vision_proj(image_features)  # (B, project_dim)
        
        return image_embeds
    
    def encode_text(self, input_ids, attention_mask:Optional[torch.Tensor]=None,):
        if input_ids.ndim==1:
            input_ids = input_ids.unsqueeze(0)
        text_features = self.text_encoder(input_ids, attention_mask)  # (B, lm_n_embd)
        text_embeds = self.text_proj(text_features)  # (B, 512)
        return text_embeds
        
    def load_from_state_dict(self,):
        load_from_hf_into_scratch(self.text_encoder, hf_name=self.config.lm_hf_model_name)
        load_hf_vit_into_scratch(self.vision_encoder, hf_name=self.config.vit_hf_model_name)
        print(f"Loaded Vision Encoder from {self.config.vit_hf_model_name}")
        print(f"Loaded Text Encoder from {self.config.lm_hf_model_name}")
        
        self.freeze_partial_backbone()
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Trainable parameters: {trainable_params} / {total_params}")
        
    def freeze_partial_backbone(self, freeze_ratio:float=0.4):
        num_layers = len(self.vision_encoder.blocks)
        num_freeze = int(num_layers*freeze_ratio)
        for layer in self.vision_encoder.blocks[:num_freeze]:
            for param in layer.parameters():
                param.requires_grad=False
                
        self.vision_encoder.cls_token.requires_grad=False
        self.vision_encoder.pos_embd.requires_grad=False
        
        for layer in self.vision_encoder.blocks[num_freeze:]:
            layer.apply(init_weights)
        
        num_layers = len(self.text_encoder.blocks)
        num_freeze = int(num_layers*freeze_ratio)
        
        for layer in self.text_encoder.blocks[:num_freeze]:
            for param in layer.parameters():
                param.requires_grad=False
                
        for param in self.text_encoder.tok_emb.parameters():
            param.requires_grad=False
        
        for param in self.text_encoder.pos_emb.parameters():
            param.requires_grad=False
        
        for layer in self.text_encoder.blocks[num_freeze:]:
            layer.apply(init_weights)
        

2.5Dummy Test

In [26]:
config = ModelConfig()
device = get_device()

tokenizer = load_tokenizer(config.lm_hf_model_name)
config.lm_eot_token_id = tokenizer.eos_token_id


text = [
     "a photo of a cat <|endoftext|>",
     "contrastingly, a dog <|endoftext|>",
]
imgs = torch.randn(len(text), 3, config.vit_image_size, config.vit_image_size)

output = tokenizer(
     text,
     padding="max_length",
     truncation=True,
     max_length=config.lm_n_positions,
     return_tensors="pt",
     add_special_tokens=True,
)

input_ids = output["input_ids"]
attention_mask = output["attention_mask"]


imgs, input_ids, attention_mask = tensor_to_device(imgs, input_ids, attention_mask, device=device)

model = CLIPModel(config)
model.to(device)

image_embeds, text_embeds = model(imgs, input_ids, attention_mask)

assert image_embeds.shape == (len(text), config.project_dim)
assert text_embeds.shape == (len(text), config.project_dim)
print("Dummy forward pass successful!")

Loading weights:   0%|          | 0/292 [00:00<?, ?it/s]

[1mGPT2LMHeadModel LOAD REPORT[0m from: openai-community/gpt2-medium
Key                  | Status     |  | 
---------------------+------------+--+-
h.{0...23}.attn.bias | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Loaded HF weights with Conv1D->Linear transposes.


Loading weights:   0%|          | 0/198 [00:00<?, ?it/s]

[1mViTModel LOAD REPORT[0m from: google/vit-base-patch16-224
Key                 | Status     | 
--------------------+------------+-
classifier.bias     | UNEXPECTED | 
classifier.weight   | UNEXPECTED | 
pooler.dense.weight | MISSING    | 
pooler.dense.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


Loaded HF ViT weights.
Unexpected keys: ['pos_embed']
Loaded Vision Encoder from google/vit-base-patch16-224
Loaded Text Encoder from openai-community/gpt2-medium
Trainable parameters: 247686400 / 441096960
Dummy forward pass successful!
