In [1]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

import os
from pathlib import Path
import tomllib
import sys

# Add vision_language_alignment to path first (for dataset)
sys.path.insert(0, str(Path.cwd().parent))
from dataset import get_dataloaders

# Add text_pretraining to path (for Llama model)
sys.path.insert(0, str(Path.cwd().parent.parent / "text_pretraining"))
from model import Llama

In [2]:
DEVICE = "cuda"
CONFIG_PATH = Path.cwd().parent.parent / "config.toml"

with open(CONFIG_PATH, "rb") as f:
    cfg = tomllib.load(f)
    cfg["d_head"] = cfg["d_model"] // cfg["n_heads"]
    cfg["kv_d_head"] = cfg["d_model"] // cfg["n_kv_heads"]

In [3]:
cfg

{'batch_size': 7,
 'batches_per_epoch': 20000,
 'vocab_size': 32000,
 'd_model': 768,
 'n_heads': 12,
 'n_kv_heads': 4,
 'n_layers': 12,
 'max_seq_len': 1024,
 'num_experts': 8,
 'num_experts_per_tok': 2,
 'moe_layer_freq': 2,
 'd_expert': 1024,
 'd_ff_standard': 2048,
 'rope_layers_ratio': 0.75,
 'rope_theta': 10000,
 'chunk_size': 512,
 'vision_hidden_size': 768,
 'patch_size': 16,
 'image_size': 224,
 'd_head': 64,
 'kv_d_head': 192}

In [4]:
class VisionEncoder(nn.Module):
    """
    Vision encoder for LLaMA 4 vision-language alignment.
    Uses pretrained ViT-B/16 (frozen) with a trainable MLP projector.

    Flow: Image -> Pretrained ViT -> MLP Projector -> LLM-compatible embeddings
    """

    def __init__(self, llm_embed_dim=768):
        super().__init__()

        # Load pretrained ViT-B/16 (trained on ImageNet)
        vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

        # Extract components
        self.patch_embed = vit.conv_proj  # [B,3,224,224] -> [B,768,14,14]
        self.pos_embed = vit.encoder.pos_embedding  # [1, 197, 768] (includes CLS)
        self.encoder = vit.encoder.layers  # 12 transformer blocks
        self.norm = vit.encoder.ln  # Final LayerNorm

        # Freeze entire vision encoder
        for param in self.parameters():
            param.requires_grad = False

        # Trainable 2-layer MLP projector (like LLaVA)
        self.proj = nn.Sequential(
            nn.Linear(768, 768 * 4), nn.GELU(), nn.Linear(768 * 4, llm_embed_dim)
        )

    def forward(self, x):
        # x: [B, 3, 224, 224]

        # Patch embedding via Conv2d
        x = self.patch_embed(x)  # [B, 768, 14, 14]
        x = x.flatten(2).transpose(1, 2)  # [B, 196, 768]

        # Add positional embeddings (skip CLS token position at index 0)
        x = x + self.pos_embed[:, 1:, :]  # [B, 196, 768]

        # Transformer encoder
        for layer in self.encoder:
            x = layer(x)
        x = self.norm(x)  # [B, 196, 768]

        # Project to LLM embedding space
        x = self.proj(x)  # [B, 196, llm_embed_dim]
        return x


# # Test the vision encoder
# vision_encoder = VisionEncoder(llm_embed_dim=768)

# # Check trainable vs frozen params
# trainable = sum(p.numel() for p in vision_encoder.parameters() if p.requires_grad)
# total = sum(p.numel() for p in vision_encoder.parameters())
# print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")

# # Forward pass
# x = torch.randn(2, 3, 224, 224)
# out = vision_encoder(x)
# print(f"Input:  {x.shape}")
# print(f"Output: {out.shape}")  # [B, 196, 768]


In [5]:
class VisionLanguageModel(nn.Module):
    def __init__(self,vocab_size,text_pretrained_model,llm_embed_dim,ckpt_path):
        super().__init__()

        self.vision_encoder = VisionEncoder(llm_embed_dim=llm_embed_dim)
        self.text_pretrained_model = text_pretrained_model
        self._load_model_from_ckpt(ckpt_path)

        # Embedding layer
        new_emb = nn.Embedding(vocab_size + 1, llm_embed_dim)
        new_emb.weight.data[:vocab_size, :] = self.text_pretrained_model.emb.emb.weight.data
        new_emb.weight.data[vocab_size, :] = torch.randn(llm_embed_dim) * (llm_embed_dim ** -0.5)
        self.text_pretrained_model.emb.emb = new_emb

        # Output projection
        old_proj = self.text_pretrained_model.proj_vocab
        new_proj = nn.Linear(llm_embed_dim, vocab_size + 1, bias=False)
        new_proj.weight.data[:vocab_size, :] = old_proj.weight.data
        new_proj.weight.data[vocab_size, :] = torch.randn(llm_embed_dim) * 0.02
        self.text_pretrained_model.proj_vocab = new_proj

        for param in self.text_pretrained_model.parameters():
            param.requires_grad = False
        
        self.image_token_id = 32000
        self.pad_token_id = 3
        
    def _load_model_from_ckpt(self,ckpt_path):
        if os.path.exists(ckpt_path):
            print(f"Loading checkpoint from: {ckpt_path}")
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            self.text_pretrained_model.load_state_dict(checkpoint["model_state_dict"])
            print("Checkpoint loaded successfully")
        else:
            print(f"File not found at the path provided")

    def forward(self,image,input_ids):
        vision_embeds = self.vision_encoder(image)  # [B, 196, llm_embed_dim]
        text_embeds = self.text_pretrained_model.emb(input_ids)  # [B, seq_len, llm_embed_dim]
        combined = torch.cat([vision_embeds,text_embeds[:,1:,:]],dim=1) # [B, 196 + seq_len-1, llm_embed_dim]
        
        print(vision_embeds.shape, text_embeds.shape)
        print(combined.shape)  # [B, 196 + seq_len-1, llm_embed_dim]

        for i, decoder in enumerate(self.text_pretrained_model.decoder_layers):
            combined = decoder(i, combined)
            print(i,combined.shape)

        combined = self.text_pretrained_model.rms_norm(combined) # [B, 196 + seq_len-1, llm_embed_dim]
        logits = self.text_pretrained_model.proj_vocab(combined) # [B, 196 + seq_len-1, vocab_size+1]

        return logits


In [6]:
text_pretrained_model = Llama(
        vocab_size=cfg["vocab_size"],
        n_layers=cfg["n_layers"],
        d_model=cfg["d_model"],
        d_head=cfg["d_head"],
        n_heads=cfg["n_heads"],
        n_kv_heads=cfg["n_kv_heads"],
        kv_d_head=cfg["kv_d_head"],
        d_ff_standard=cfg["d_ff_standard"],
        num_experts=cfg["num_experts"],
        num_experts_per_tok=cfg["num_experts_per_tok"],
        d_expert=cfg["d_expert"],
        rope_layers_ratio=cfg["rope_layers_ratio"],
        chunk_size=cfg["chunk_size"],
        rope_theta=cfg["rope_theta"],
    )

In [7]:
model = VisionLanguageModel(vocab_size=cfg["vocab_size"],
                            text_pretrained_model=text_pretrained_model,
                            llm_embed_dim=cfg['d_model'],
                            ckpt_path="/home/smedar/code_files/llama4-from-scratch/text_pretraining/checkpoints/best.pt")

Loading checkpoint from: /home/smedar/code_files/llama4-from-scratch/text_pretraining/checkpoints/best.pt
Checkpoint loaded successfully


In [8]:
# Check trainable vs frozen params
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")

Trainable: 4,722,432 / 470,069,772 (1.0%)


In [9]:
train_loader, val_loader = get_dataloaders(batch_size=1, max_seq_len=128)

print("\nTesting train loader...")
batch = next(iter(train_loader))
print(f"  image shape:          {batch['image'].shape}")
print(f"  input_ids shape:      {batch['input_ids'].shape}")
print(f"  attention_mask shape: {batch['attention_mask'].shape}")
print(f"  labels shape:         {batch['labels'].shape}")
print(f"  First token (should be 32000): {batch['input_ids'][0][0].item()}")


Loading tokenizer...
Loaded tokenizer from /home/smedar/code_files/llama4-from-scratch/vision_language_alignment/bpe_tokenizer_with_image_tag.json
Vocab size: 32001
Setting up image transforms...
Loading COCO captions dataset (streaming)...


Resolving data files:   0%|          | 0/182 [00:00<?, ?it/s]

Batch size: 1
Max sequence length: 128

Testing train loader...
  image shape:          torch.Size([1, 3, 224, 224])
  input_ids shape:      torch.Size([1, 128])
  attention_mask shape: torch.Size([1, 128])
  labels shape:         torch.Size([1, 128])
  First token (should be 32000): 32000


In [10]:
model(batch['image'],batch['input_ids']).shape

torch.Size([1, 196, 768]) torch.Size([1, 128, 768])
torch.Size([1, 323, 768])
0 torch.Size([1, 323, 768])
1 torch.Size([1, 323, 768])
2 torch.Size([1, 323, 768])
3 torch.Size([1, 323, 768])
4 torch.Size([1, 323, 768])
5 torch.Size([1, 323, 768])
6 torch.Size([1, 323, 768])
7 torch.Size([1, 323, 768])
8 torch.Size([1, 323, 768])
9 torch.Size([1, 323, 768])
10 torch.Size([1, 323, 768])
11 torch.Size([1, 323, 768])


torch.Size([1, 323, 32001])

In [11]:
device = 'cuda'

opt = torch.optim.Adam(model.parameters(),lr=1e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

model = model.to(device)

for i,b in enumerate(train_loader):
    image = b['image'].to(device)
    input_ids = b['input_ids'].to(device)
    attention_mask = b['attention_mask'].to(device)

    opt.zero_grad()

    logits = model(image,input_ids)  # [B, 323, vocab_size+1]
 
    labels = torch.full((logits.size(0),323),-100, device=device) # [B,323]
    labels[:,196:] = input_ids[:,1:] # caption_tokens (without <image>)
    labels[labels == model.pad_token_id] = -100

    # Reshape for cross-entropy
    logits_flat = logits[:, :-1, :].reshape(-1, cfg['vocab_size']+1)  # [B*322, vocab_size+1]
    labels_flat = labels[:, 1:].reshape(-1)                   # [B*322]

    loss = loss_fn(logits_flat, labels_flat)

    # Backward pass
    loss.backward()
    
    # update weights
    opt.step()

    break

torch.Size([1, 196, 768]) torch.Size([1, 128, 768])
torch.Size([1, 323, 768])
0 torch.Size([1, 323, 768])
1 torch.Size([1, 323, 768])
2 torch.Size([1, 323, 768])
3 torch.Size([1, 323, 768])
4 torch.Size([1, 323, 768])
5 torch.Size([1, 323, 768])
6 torch.Size([1, 323, 768])
7 torch.Size([1, 323, 768])
8 torch.Size([1, 323, 768])
9 torch.Size([1, 323, 768])
10 torch.Size([1, 323, 768])
11 torch.Size([1, 323, 768])
