In [3]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

In [2]:
x = torch.randn(2,3,224,224)

b,c,h,w = x.shape
x.shape

torch.Size([2, 3, 224, 224])

In [3]:
num_patches = 196
patch_size = 16

In [4]:
unfolded_h = x.unfold(dimension=2, size=patch_size, step=patch_size)
unfolded_h.shape

torch.Size([2, 3, 14, 224, 16])

In [5]:

unfolded_w = unfolded_h.unfold(dimension=3, size=patch_size, step=patch_size)
unfolded_w.shape

torch.Size([2, 3, 14, 14, 16, 16])

In [6]:
# channels, patches vertically, patches horizontally ,height of each patch, width of each patch

In [7]:
patches = unfolded_w.reshape(
    b,num_patches, -1
)  # [num_patches, channels * patch_h * patch_w] => [4, 3*4*4] = [4, 48]
patches.shape

torch.Size([2, 196, 768])

In [8]:
class VisionAligment(nn.Module):
    def __init__(self,channels,num_patches,patch_size):
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.embed_dim = channels*patch_size*patch_size
        
        self.pos_embed = nn.Parameter(torch.randn(1,self.num_patches,self.embed_dim)*0.02)
        self.fc = nn.Linear(self.embed_dim,self.embed_dim,bias=False)

        self.vit_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(
            d_model=768,
            nhead=12,
            dim_feedforward= 3072,
            activation= 'gelu',
            batch_first=True,
            norm_first=True
            ),
            num_layers=6,
            enable_nested_tensor=False)
        
        for param in self.vit_encoder.parameters():
            param.requires_grad = False


        self.proj_out = nn.Linear(self.embed_dim,self.embed_dim,bias=False)

    def _extract_patches(self,x,patch_size,num_patches):
        batch = x.shape[0]
        unfolded_h = x.unfold(dimension=2, size=patch_size, step=patch_size)
        unfolded_w = unfolded_h.unfold(dimension=3, size=patch_size, step=patch_size)
        patches = unfolded_w.reshape(
            batch, num_patches, -1
        ) # [batch, num_patches, embed_dim
        return patches
    
    def forward(self,x):
        # x: [b,c,h,w]
        patches = self._extract_patches(x,self.patch_size,self.num_patches)  # [batch, num_patches, embed_dim]
        patches = self.fc(patches) + self.pos_embed  # [batch, num_patches, embed_dim]
        patches = self.vit_encoder(patches) # [batch, num_patches, embed_dim]
        patches = self.proj_out(patches) # [batch, num_patches, embed_dim]
        return patches

In [9]:
model = VisionAligment(3,196,16)
x = torch.randn(2,3,224,224)
out = model(x)
out.shape

torch.Size([2, 196, 768])

In [4]:
class VisionEncoder(nn.Module):
    """
    Vision encoder for LLaMA 4 vision-language alignment.
    Uses pretrained ViT-B/16 (frozen) with a trainable MLP projector.
    
    Flow: Image -> Pretrained ViT -> MLP Projector -> LLM-compatible embeddings
    """
    def __init__(self, llm_embed_dim=768):
        super().__init__()
        
        # Load pretrained ViT-B/16 (trained on ImageNet)
        vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        
        # Extract components
        self.patch_embed = vit.conv_proj          # [B,3,224,224] -> [B,768,14,14]
        self.pos_embed = vit.encoder.pos_embedding # [1, 197, 768] (includes CLS)
        self.encoder = vit.encoder.layers          # 12 transformer blocks
        self.norm = vit.encoder.ln                 # Final LayerNorm
        
        # Freeze entire vision encoder
        for param in self.parameters():
            param.requires_grad = False
        
        # Trainable 2-layer MLP projector (like LLaVA)
        self.proj = nn.Sequential(
            nn.Linear(768, 768 * 4),
            nn.GELU(),
            nn.Linear(768 * 4, llm_embed_dim)
        )
    
    def forward(self, x):
        # x: [B, 3, 224, 224]
        
        # Patch embedding via Conv2d
        x = self.patch_embed(x)              # [B, 768, 14, 14]
        x = x.flatten(2).transpose(1, 2)     # [B, 196, 768]
        
        # Add positional embeddings (skip CLS token position at index 0)
        x = x + self.pos_embed[:, 1:, :]     # [B, 196, 768]
        
        # Transformer encoder
        for layer in self.encoder:
            x = layer(x)
        x = self.norm(x)                     # [B, 196, 768]
        
        # Project to LLM embedding space
        x = self.proj(x)                     # [B, 196, llm_embed_dim]
        return x

In [5]:
# Test the vision encoder
vision_encoder = VisionEncoder(llm_embed_dim=768)

# Check trainable vs frozen params
trainable = sum(p.numel() for p in vision_encoder.parameters() if p.requires_grad)
total = sum(p.numel() for p in vision_encoder.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")

# Forward pass
x = torch.randn(2, 3, 224, 224)
out = vision_encoder(x)
print(f"Input:  {x.shape}")
print(f"Output: {out.shape}")  # [B, 196, 768] - ready for LLM!

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /home/smedar/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:04<00:00, 75.2MB/s] 


Trainable: 4,722,432 / 90,520,320 (5.2%)
Input:  torch.Size([2, 3, 224, 224])
Output: torch.Size([2, 196, 768])
