In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('/data2/shreyas/multimodal/VLM')
from vision_model import VisionEncoder
from text_model import LLaMA
from types import SimpleNamespace
from functools import partial

In [35]:
class SimpleVLM(nn.Module):
    """Container pairing a vision encoder with a LLaMA text model.

    Only construction is implemented so far; ``forward`` is a placeholder.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Registration order (vision first, then text) fixes the state_dict key order.
        self.vision_model = VisionEncoder(config.embed_dim)
        self.text_model = LLaMA(config)

    def forward(self, input_ids, pixel_values):
        """Placeholder: performs no computation and returns ``None``.

        TODO: implement the multimodal forward pass.
        """
        return None

In [36]:
class Config(SimpleNamespace):
    """Attribute-style config that also supports dict-style reads.

    Missing keys yield ``None`` (or the supplied default) instead of raising,
    both through ``get`` and through ``config[key]``.
    """

    def get(self, key, default=None):
        """Return attribute ``key``, or ``default`` if it is not set."""
        return getattr(self, key, default)

    def __getitem__(self, key):
        # Deliberately lenient: a missing (or misspelled) key returns None
        # rather than raising KeyError.
        return self.get(key)

    def __contains__(self, key):
        """Support ``key in config``.

        Without this, ``in`` falls back to calling ``__getitem__`` with integer
        indices, and ``getattr`` rejects non-string names with ``TypeError``.
        Only stored fields count; method names such as ``get`` do not.
        """
        return key in vars(self)

# Text-model hyperparameters; `embed_dim` is also handed to the vision encoder by SimpleVLM.
config = Config(
    embed_dim=576,
    intermediate_dim=1536,
    max_position_embeddings=8192,
    base_theta=100_000,
    num_q_heads=9,
    num_kv_heads=3,  # fewer KV heads than query heads (presumably grouped-query attention)
    attn_dropout=0.0,
    num_layers=30,
    vocab_size=49152,
    dtype=torch.bfloat16,
    eos_token_id=2,
)

In [37]:
model = SimpleVLM(config=config)  # vision encoder + text model, all weights freshly initialized

In [38]:
# Names of the nn.Linear sub-modules in the text model that get LoRA adapters.
target_layers = [
    # attention projections (under `self_attn`)
    'q_proj', 'k_proj', 'v_proj', 'o_proj',
    # MLP projections (under `mlp`)
    'gate_proj', 'up_proj', 'down_proj',
]

In [39]:
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank (LoRA) update.

    Computes ``linear(x) + (alpha / rank) * (dropout(x) @ A @ B)`` where ``A`` is
    ``(in_features, rank)`` and ``B`` is ``(rank, out_features)``. ``B`` starts at
    zero, so a freshly wrapped layer computes exactly ``linear(x)``.

    Args:
        linear_layer: the ``nn.Linear`` to wrap; its parameters are frozen in place.
        rank: rank of the update; must be >= 1.
        alpha: LoRA scaling numerator; the update is multiplied by ``alpha / rank``.
        lora_dropout: dropout probability applied to the adapter's input.

    Raises:
        ValueError: if ``rank`` < 1.
    """

    def __init__(self, linear_layer, rank, alpha, lora_dropout=0.):
        super().__init__()
        if rank < 1:
            raise ValueError(f"LoRA rank must be >= 1, got {rank}")

        self.linear = linear_layer
        for p in self.linear.parameters():
            p.requires_grad = False

        self.in_features = self.linear.in_features
        self.out_features = self.linear.out_features
        self.rank = rank
        self.alpha = alpha
        # Standard LoRA scaling: keeps the update magnitude roughly independent
        # of the chosen rank (scaling by raw alpha inflates it by a factor of `rank`).
        self.scaling = alpha / rank

        # Adapters are created on the base layer's device and kept in float32,
        # so optimizer updates stay precise even if the base weights are low precision.
        device = self.linear.weight.device
        self.A = nn.Parameter(
            torch.randn(self.in_features, self.rank, device=device) * self.rank ** -0.5
        )
        self.B = nn.Parameter(torch.zeros(self.rank, self.out_features, device=device))
        self.dropout = nn.Dropout(lora_dropout)

    def forward(self, x):
        base_out = self.linear(x)
        # Dropout is applied to the adapter input, as in the reference LoRA implementation.
        lora_in = self.dropout(x).to(self.A.dtype)
        lora_out = (lora_in @ self.A @ self.B) * self.scaling
        # Cast back so the sum keeps the base layer's output dtype.
        return base_out + lora_out.to(base_out.dtype)

    @torch.no_grad()
    def merge(self):
        """Fold the adapter into the base weight and return the base ``nn.Linear``.

        ``nn.Linear`` computes ``x @ weight.T``, hence the transpose below. The
        returned layer reproduces this module's eval-mode output. Do not keep using
        this wrapper after merging: its forward would add the update a second time.
        """
        weight = self.linear.weight
        delta = (self.A @ self.B).T * self.scaling
        weight.add_(delta.to(dtype=weight.dtype, device=weight.device))
        return self.linear

    def extra_repr(self):
        return f"rank={self.rank}, alpha={self.alpha}, scaling={self.scaling}"

In [40]:
# LoRA adapter hyperparameters, consumed by `apply_lora` in the next cell.
lora_config = Config(rank=64, alpha=128, lora_dropout=0.05)

In [41]:
# Factory: apply_lora(linear) -> LoRALinear(linear, **settings from lora_config).
apply_lora = partial(
    LoRALinear,
    **{key: lora_config[key] for key in ('rank', 'alpha', 'lora_dropout')},
)

In [42]:
# Names of every state_dict entry (parameters and persistent buffers).
state_dict = [key for key in model.state_dict()]
# Text-model entries whose name contains one of the LoRA target layer names.
target_modules = [
    key
    for key in state_dict
    if 'text_model' in key and any(layer in key for layer in target_layers)
]

In [43]:
target_modules[0:5]  # peek at the first few matched keys

['text_model.layers.0.self_attn.q_proj.weight',
 'text_model.layers.0.self_attn.k_proj.weight',
 'text_model.layers.0.self_attn.v_proj.weight',
 'text_model.layers.0.self_attn.o_proj.weight',
 'text_model.layers.0.mlp.gate_proj.weight']

In [44]:
# Freeze everything; trainable parts are re-enabled explicitly below.
for p in model.parameters():
    p.requires_grad = False

# Snapshot the targets first: replacing children while named_modules() is still
# walking the tree is fragile. Match the final name component exactly so that
# e.g. a fused `gate_up_proj` is not picked up by a suffix match on `up_proj`.
lora_targets = [
    (name, module)
    for name, module in model.text_model.named_modules()
    if isinstance(module, nn.Linear) and name.rpartition(".")[2] in target_layers
]
for name, module in lora_targets:
    parent_name, _, child_name = name.rpartition(".")
    parent_module = model.text_model.get_submodule(parent_name)  # "" -> text_model itself
    setattr(parent_module, child_name, apply_lora(module))  # replace with LoRA-wrapped layer

# Adapter matrices must be trainable. Re-enable them explicitly: when this cell
# is re-run, the freeze loop above also freezes adapters that already exist.
for module in model.text_model.modules():
    if isinstance(module, LoRALinear):
        module.A.requires_grad_(True)
        module.B.requires_grad_(True)

# Vision-side modules to train alongside the adapters.
for layer in [model.vision_model.rms_norm, model.vision_model.avg_pool, model.vision_model.dim_proj]:
    for p in layer.parameters():
        p.requires_grad = True

In [45]:
# (trainable, total) parameter counts, formatted with thousands separators.
(
    f"{sum(p.numel() for p in model.parameters() if p.requires_grad):,}",
    f"{sum(p.numel() for p in model.parameters()):,}",
)

('19,981,056', '248,016,192')

In [46]:
# Raw parameter totals per sub-model: (vision, text).
(
    sum(p.numel() for p in model.vision_model.parameters()),
    sum(p.numel() for p in model.text_model.parameters()),
)

(93963264, 154052928)

In [47]:
# for name, module in model.text_model.named_modules():
#     if any(name.endswith(t) for t in target_layers) and isinstance(module, LoRALinear):
#         parent_name = ".".join(name.split(".")[:-1])  # Get parent module name
#         parent_module = model.text_model.get_submodule(parent_name)  # Get parent module reference
#         setattr(parent_module, name.split(".")[-1], module.merge())  # Replace the LoRA wrapper with its merged nn.Linear

In [48]:
# sum([p.numel() for p in model.vision_model.parameters()]),sum([p.numel() for p in model.text_model.parameters()])

In [49]:
params = {name: param for name, param in model.named_parameters()}  # name -> Parameter lookup

In [50]:
# Which registered parameter names belong to the token embedding / LM head?
(
    [name for name in params if 'embed_tokens' in name],
    [name for name in params if 'lm_head' in name],
)

(['text_model.embed_tokens.weight'], [])

In [51]:
def get_optimizer_params(model, trainable_only=False):
    """Split ``model``'s parameters into optimizer param groups with per-group learning rates.

    Groups, in order:
      0. every parameter not claimed below                          -- lr 1e-5
      1. vision projector modules (rms_norm, avg_pool, dim_proj)    -- lr 1e-4
      2. LoRA adapter matrices (``A`` and ``B`` of each LoRALinear) -- lr 1e-4
      3. text token embeddings                                      -- lr 1e-5

    Each parameter lands in exactly one group (earlier special groups win on overlap).
    Within a group, parameters follow ``model.parameters()`` / module order, so the
    layout is identical from run to run; optimizer ``state_dict`` loading depends on it.

    Args:
        model: module exposing ``vision_model.{rms_norm, avg_pool, dim_proj}`` and
            ``text_model.embed_tokens``.
        trainable_only: if True, omit parameters with ``requires_grad=False``.

    Returns:
        A list of four ``{"params": [...], "lr": float}`` dicts for ``torch.optim``.
    """
    claimed = set()

    def take(candidates):
        # Claim each parameter once, preserving iteration order.
        group = []
        for p in candidates:
            if id(p) in claimed or (trainable_only and not p.requires_grad):
                continue
            claimed.add(id(p))
            group.append(p)
        return group

    vision = model.vision_model
    # Vision projector components with a high learning rate (the same modules the
    # LoRA-injection cell unfreezes).
    vision_special_params = take(
        p
        for module in (vision.rms_norm, vision.avg_pool, vision.dim_proj)
        for p in module.parameters()
    )

    # Only the adapter matrices: LoRALinear.parameters() would also yield the
    # frozen base linear weight/bias, which do not belong in the LoRA group.
    lora_params = take(
        p
        for module in model.modules()
        if isinstance(module, LoRALinear)
        for p in (module.A, module.B)
    )

    # Text embeddings with a low learning rate. The original note says embeddings and
    # LM head are tied, so only one is listed (assumption -- confirm in the text model).
    text_embed_params = take(model.text_model.embed_tokens.parameters())

    # Everything else, in model.parameters() order (not set order) for determinism.
    base_params = take(model.parameters())

    return [
        {"params": base_params, "lr": 1e-5},
        {"params": vision_special_params, "lr": 1e-4},  # higher learning rate
        {"params": lora_params, "lr": 1e-4},  # higher learning rate
        {"params": text_embed_params, "lr": 1e-5},  # lower learning rate
    ]

In [52]:
# Adam over the param groups; each group carries its own learning rate.
optim = torch.optim.Adam(params=get_optimizer_params(model))

In [53]:
# Display the optimizer's per-group hyperparameters.
optim

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 1e-05
    maximize: False
    weight_decay: 0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0

Parameter Group 2
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0

Parameter Group 3
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 1e-05
    maximize: False
    weight_decay: 0
)

In [54]:
# Re-check (trainable, total) parameter counts after building the optimizer.
(
    f"{sum(p.numel() for p in model.parameters() if p.requires_grad):,}",
    f"{sum(p.numel() for p in model.parameters()):,}",
)

('19,981,056', '248,016,192')