In [1]:
import torch
import os
from transformers import MllamaConfig, MllamaForConditionalGeneration
from transformers.models.mllama.configuration_mllama import MllamaTextConfig,MllamaVisionConfig
from transformers.models.mllama.modeling_mllama import MllamaCrossAttentionDecoderLayer
from transformers.utils import logging
from transformers.modeling_rope_utils import rope_config_validation
from typing import Dict, List, Optional, Union
from models import AstrollavaConfig, AstrollavaTextConfig, AstroMllamaForConditionalGeneration

In [2]:
checkpoint = "/mnt/data/CVPR2025/task1_data/Llama-3.2-11B-Vision-Instruct"
vision_config = MllamaVisionConfig.from_pretrained(checkpoint)
text_config = MllamaTextConfig.from_pretrained(checkpoint)

In [3]:
# add 2 layers, spectrum first
original_cross_attention_layers = [3, 8, 13, 18, 23, 27, 33, 38]
new_cross_attention_layers = []
new_vision_cross_attention_layers = []
new_spec_cross_attention_layers = []
new_structure_cross_attention_layers = []
num_new_layers = 0
for x in original_cross_attention_layers:
    new_vision = x + num_new_layers
    new_vision_cross_attention_layers.append(new_vision)
    if x == 18:
        new_spec = x + num_new_layers + 1
        num_new_layers += 1
        new_spec_cross_attention_layers.append(new_spec)
    elif x == 38:
        new_structure = x + num_new_layers + 1
        num_new_layers += 1
        new_structure_cross_attention_layers.append(new_structure)   
    
new_cross_attention_layers = new_vision_cross_attention_layers + new_spec_cross_attention_layers + new_structure_cross_attention_layers
new_cross_attention_layers = sorted(new_cross_attention_layers)
print(new_cross_attention_layers)
print(new_vision_cross_attention_layers)
print(new_spec_cross_attention_layers)
print(new_structure_cross_attention_layers)

[3, 8, 13, 18, 19, 24, 28, 34, 39, 40]
[3, 8, 13, 18, 24, 28, 34, 39]
[19]
[40]


In [4]:
new_text_config = AstrollavaTextConfig.from_pretrained("/mnt/data/CVPR2025/task1_data/Llama-3.2-11B-Vision-Instruct",
        cross_attention_layers=new_cross_attention_layers,
        structure_cross_attention_layers = new_structure_cross_attention_layers,
        spectrum_cross_attention_layers  = new_spec_cross_attention_layers,
        vision_cross_attention_layers  = new_vision_cross_attention_layers,
        num_hidden_layers = 40 + 2,
        rope_scaling={
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3"
        },
        torch_dtype="bfloat16")

In [5]:
new_text_config

AstrollavaTextConfig {
  "bos_token_id": 128000,
  "cross_attention_layers": [
    3,
    8,
    13,
    18,
    19,
    24,
    28,
    34,
    39,
    40
  ],
  "dropout": 0,
  "eos_token_id": [
    128001,
    128008,
    128009
  ],
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 131072,
  "model_type": "mllama_text_model",
  "num_attention_heads": 32,
  "num_hidden_layers": 42,
  "num_key_value_heads": 8,
  "pad_token_id": 128004,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 8.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "spectrum_cross_attention_layers": [
    19
  ],
  "spectrum_output_dim": 1024,
  "structure_cross_attention_layers": [
    40
  ],
  "structure_output_dim": 1024,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_versio

In [6]:
new_config = AstrollavaConfig.from_pretrained("/mnt/data/CVPR2025/task1_data/Llama-3.2-11B-Vision-Instruct", text_config=new_text_config, vision_config=vision_config)

In [7]:
new_config.save_pretrained("/mnt/data/CVPR2025/task1_data/astra-llava-add2-spec-first")

vision_config is None, using default mllama vision config
text_config is None, using default mllama text config


In [8]:
new_model = AstroMllamaForConditionalGeneration(new_config)

In [9]:
new_model.bfloat16()

AstroMllamaForConditionalGeneration(
  (language_model): AstroMllamaForCausalLM(
    (model): AstroMllamaTextModel(
      (embed_tokens): Embedding(128264, 4096, padding_idx=128004)
      (layers): ModuleList(
        (0-2): 3 x MllamaSelfAttentionDecoderLayer(
          (self_attn): MllamaTextSelfSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (mlp): MllamaTextMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): MllamaTextRMSNor

In [11]:
import torch
import json
from safetensors.torch import load_file
import os

def load_pretrained_weights(checkpoint_dir):
    """加载预训练权重"""
    with open(os.path.join(checkpoint_dir, "model.safetensors.index.json"), 'r') as f:
        weight_map = json.load(f)["weight_map"]
    
    state_dict = {}
    for key, filename in weight_map.items():
        weights = load_file(os.path.join(checkpoint_dir, filename))
        if key in weights:
            state_dict[key] = weights[key]
    
    return state_dict

def map_weights(new_model, checkpoint_dir):
    """映射权重到新模型"""
    # 加载预训练权重
    pretrained_weights = load_pretrained_weights(checkpoint_dir)
    
    # 获取新模型的state dict
    new_state_dict = new_model.state_dict()
    
    # 定义层的映射关系
    original_cross_attention_layers = [3, 8, 13, 18, 23, 27, 33, 38]
    new_vision_cross_attention_layers = [3, 8, 13, 18, 24, 28, 34, 39]
    new_spec_cross_attention_layers = [19]  # 使用layer 18的权重初始化
    new_structure_cross_attention_layers = [40]  # 使用layer 38的权重初始化
    
    # 初始化映射字典
    mapped_state_dict = {}
    
    for key in new_state_dict.keys():
        # 处理projector层
        if 'spectrum_modal_projector' in key or 'structure_modal_projector' in key:
            mapped_state_dict[key] = new_state_dict[key]
            continue
            
        # 处理新增的spectrum cross attention层
        if any(f'language_model.model.layers.{idx}' in key for idx in new_spec_cross_attention_layers):
            # 使用layer 18的权重
            source_key = key.replace(f'.19.', '.18.')
            if source_key in pretrained_weights:
                mapped_state_dict[key] = pretrained_weights[source_key].clone()
            else:
                print(f"Warning: Source key {source_key} not found for spectrum layer")
                mapped_state_dict[key] = new_state_dict[key]
            continue
            
        # 处理新增的structure cross attention层
        if any(f'language_model.model.layers.{idx}' in key for idx in new_structure_cross_attention_layers):
            # 使用layer 38的权重
            source_key = key.replace(f'.40.', '.38.')
            if source_key in pretrained_weights:
                mapped_state_dict[key] = pretrained_weights[source_key].clone()
            else:
                print(f"Warning: Source key {source_key} not found for structure layer")
                mapped_state_dict[key] = new_state_dict[key]
            continue
            
        # 处理vision cross attention层的映射
        is_cross_attn = False
        for i, new_idx in enumerate(new_vision_cross_attention_layers):
            if f'language_model.model.layers.{new_idx}' in key and 'cross_attn' in key:
                orig_idx = original_cross_attention_layers[i]
                mapped_key = key.replace(f'.{new_idx}.', f'.{orig_idx}.')
                if mapped_key in pretrained_weights:
                    mapped_state_dict[key] = pretrained_weights[mapped_key]
                    is_cross_attn = True
                break
        
        # 其他层直接复制
        if not is_cross_attn:
            if key in pretrained_weights:
                mapped_state_dict[key] = pretrained_weights[key]
            else:
                print(f"Warning: {key} not found in pretrained weights")
                mapped_state_dict[key] = new_state_dict[key]
    
    # 加载权重到新模型
    missing_keys = new_model.load_state_dict(mapped_state_dict, strict=False)
    
    if missing_keys.missing_keys:
        print("Missing keys:", missing_keys.missing_keys)
    if missing_keys.unexpected_keys:
        print("Unexpected keys:", missing_keys.unexpected_keys)
    
    print("\nWeight mapping summary:")
    print("- Spectrum adapter (layer 19) initialized from layer 18")
    print("- Structure adapter (layer 40) initialized from layer 38")
    print("- Vision cross attention layers mapped:", list(zip(original_cross_attention_layers, new_vision_cross_attention_layers)))
        
    return new_model

# 使用示例
checkpoint_dir = "/mnt/data/CVPR2025/task1_data/Llama-3.2-11B-Vision-Instruct"
new_model = map_weights(new_model, checkpoint_dir)


Weight mapping summary:
- Spectrum adapter (layer 19) initialized from layer 18
- Structure adapter (layer 40) initialized from layer 38
- Vision cross attention layers mapped: [(3, 3), (8, 8), (13, 13), (18, 18), (23, 24), (27, 28), (33, 34), (38, 39)]


In [12]:
new_model.save_pretrained("/mnt/data/CVPR2025/task1_data/astra-llava-add2-spec-first")

vision_config is None, using default mllama vision config
text_config is None, using default mllama text config
vision_config is None, using default mllama vision config
text_config is None, using default mllama text config


In [13]:
new_model

AstroMllamaForConditionalGeneration(
  (language_model): AstroMllamaForCausalLM(
    (model): AstroMllamaTextModel(
      (embed_tokens): Embedding(128264, 4096, padding_idx=128004)
      (layers): ModuleList(
        (0-2): 3 x MllamaSelfAttentionDecoderLayer(
          (self_attn): MllamaTextSelfSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (mlp): MllamaTextMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): MllamaTextRMSNor