# MatFormer for Qwen3

## Model Specifications
- **Qwen3-4B**: 4.0B parameters, 36 layers, 32 Q heads, 8 KV heads
- **Qwen3-1.7B**: 1.7B parameters, 28 layers, 16 Q heads, 8 KV heads
- **Target 3B**: Custom configuration with Mix-n-Match FFN dimensions

## 1. Setup and Dependencies

In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import gc
import json
import re
import random
import copy
from tqdm.auto import tqdm
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import (
    AutoConfig, 
    AutoTokenizer, 
    AutoModelForCausalLM,
    GenerationConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer
)
from huggingface_hub import snapshot_download, HfApi
from datasets import load_dataset
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed

## 2. Configuration and Settings

In [None]:
# Model configuration
model_4b_id = "Qwen/Qwen3-4B"
model_1_7b_id = "Qwen/Qwen3-1.7B"

# Output configuration
local_output_path = "./output/qwen3_3b_model"
funetuning_output_path = "./output/matformer_finetune_results_integrated"
final_output_path = "./output/matformer_qwen3_3b_finetuned"

# Target 3B model configuration
target_num_layers = 32
target_params = "3B"

# GPUを設定
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"

# シード設定
set_seed(42)

## 3. MatFormer Custom Architecture Classes

In [None]:
class MatQwenMLP(nn.Module):
    """
    Qwen3のFFN(MLP)をMatFormerアーキテクチャに変更するカスタムクラス。
    訓練と推論の両方で動的にサイズを変更できるようする。
    """
    def __init__(self, config, ffn_dim):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, ffn_dim, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

        # MatFormer用の設定
        self.full_intermediate_size = ffn_dim
        self.current_intermediate_size = ffn_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 現在設定されているサイズで重みをスライスして計算
        active_gate_weight = self.gate_proj.weight[:self.current_intermediate_size, :]
        active_up_weight = self.up_proj.weight[:self.current_intermediate_size, :]
        active_down_weight = self.down_proj.weight[:, :self.current_intermediate_size]

        # 入力と同じデータ型に変換
        if x.dtype != active_gate_weight.dtype:
            active_gate_weight = active_gate_weight.to(x.dtype)
            active_up_weight = active_up_weight.to(x.dtype)
            active_down_weight = active_down_weight.to(x.dtype)

        gate_output = nn.functional.linear(x, active_gate_weight)
        up_output = nn.functional.linear(x, active_up_weight)

        activated_output = self.act_fn(gate_output) * up_output

        # transposeフラグをFalseに設定して効率化
        output = nn.functional.linear(activated_output, active_down_weight, bias=None)

        return output

## 4. Load Model Configurations

In [21]:
# Load configurations for both models
config_4b = AutoConfig.from_pretrained(model_4b_id)
config_1_7b = AutoConfig.from_pretrained(model_1_7b_id)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_4b_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

## 5. Create Target 3B Model Configuration with Mix-n-Match

In [22]:
# Create target 3B configuration based on 4B model with Mix-n-Match approach
target_config = copy.deepcopy(config_4b)

# Target configuration for 3B model with Mix-n-Match
target_config.num_hidden_layers = 32  # Between 4B(36) and 1.7B(28)

# Mix-n-Match: Use different FFN dimensions for different layers
# Inspired by Gemma3n's approach - later layers get more capacity
ffn_dims_per_layer = []
for i in range(32):
    if i < 8:  # Early layers: smallest FFN (similar to 1.7B)
        ffn_dims_per_layer.append(8192)
    elif i < 16:  # Early-middle layers: 1.7B size
        ffn_dims_per_layer.append(9728)
    elif i < 24:  # Late-middle layers: between 1.7B and 4B
        ffn_dims_per_layer.append(11776)
    else:  # Final layers: closer to 4B size
        ffn_dims_per_layer.append(13312)

# For Qwen3, we need to set a single intermediate_size in config
# We'll use the maximum size to ensure compatibility
target_config.intermediate_size = max(ffn_dims_per_layer)
target_config.num_attention_heads = 32  # Keep same as 4B
target_config.num_key_value_heads = 8   # Keep same as both models

# Store the per-layer FFN dimensions as a custom attribute
target_config.ffn_dims_per_layer = ffn_dims_per_layer

# Estimate parameter count with Mix-n-Match
def estimate_params_mixnmatch(config, ffn_dims):
    hidden_size = config.hidden_size
    num_layers = config.num_hidden_layers
    vocab_size = config.vocab_size
    
    # Embedding parameters
    embed_params = vocab_size * hidden_size
    
    # Attention parameters per layer (same for all layers)
    attn_params_per_layer = (
        hidden_size * hidden_size * 3 +  # q, k, v projections
        hidden_size * hidden_size        # output projection
    )
    
    # Layer norm parameters per layer
    ln_params_per_layer = hidden_size * 2  # input and post-attention layer norms
    
    # Calculate FFN parameters for each layer individually
    total_ffn_params = 0
    for layer_idx in range(num_layers):
        intermediate_size = ffn_dims[layer_idx]
        ffn_params = (
            hidden_size * intermediate_size * 2 +  # gate and up projections
            intermediate_size * hidden_size         # down projection
        )
        total_ffn_params += ffn_params
    
    # Total transformer parameters
    transformer_params = num_layers * (attn_params_per_layer + ln_params_per_layer) + total_ffn_params
    
    # Output layer parameters
    output_params = vocab_size * hidden_size
    
    total_params = embed_params + transformer_params + output_params
    return total_params

estimated_params = estimate_params_mixnmatch(target_config, ffn_dims_per_layer)

# Calculate average FFN dimension for reference
avg_ffn_dim = sum(ffn_dims_per_layer) / len(ffn_dims_per_layer)

## 6. Download Model Weights

In [23]:
# Download model checkpoints
model_4b_path = snapshot_download(model_4b_id, allow_patterns=["*.safetensors"])
safetensor_4b_files = [os.path.join(model_4b_path, f) for f in os.listdir(model_4b_path) if f.endswith('.safetensors')]

Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 16912.52it/s]


## 7. MatFormer Implementation: Create 3B Model with Mix-n-Match

In [24]:
def create_matformer_3b_model():
    # Create output directory
    os.makedirs(local_output_path, exist_ok=True)
    
    # Save target configuration and tokenizer
    target_config.save_pretrained(local_output_path)
    tokenizer.save_pretrained(local_output_path)
    
    # Layer mapping strategy
    source_layers = list(range(32))  # Use first 32 layers from 4B model
    target_layers = list(range(32))  # Map to 32 target layers
    layer_mapping = {src: tgt for src, tgt in zip(source_layers, target_layers)}
    
    return layer_mapping

layer_mapping = create_matformer_3b_model()

In [25]:
# Process model weights and create 3B model with Mix-n-Match
def process_model_weights():
    # Weight mapping for the new model
    weight_map = {}
    new_shard_state_dict = {}
    shard_counter = 1
    max_shard_size = 4 * 1024 * 1024 * 1024  # 4GB per shard
    
    # Get per-layer FFN dimensions
    if hasattr(target_config, 'ffn_dims_per_layer'):
        ffn_dims_per_layer = target_config.ffn_dims_per_layer
    else:
        # Fallback to single dimension
        ffn_dims_per_layer = [target_config.intermediate_size] * target_config.num_hidden_layers
        print(f"Using uniform FFN dimension: {target_config.intermediate_size}")
    
    source_intermediate_size = config_4b.intermediate_size
    
    pbar = tqdm(total=len(safetensor_4b_files), desc="Processing 4B model shards")
    
    for shard_path in safetensor_4b_files:
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for tensor_name in f.keys():
                tensor = f.get_tensor(tensor_name)
                new_tensor_name = tensor_name
                
                # Handle layer-specific parameters
                layer_match = re.search(r'\.layers\.(\d+)\.', tensor_name)
                if layer_match:
                    old_layer_idx = int(layer_match.group(1))
                    
                    # Skip layers beyond our target count
                    if old_layer_idx >= target_config.num_hidden_layers:
                        continue
                    
                    # Keep the layer index as-is for layers within our range
                    new_layer_idx = old_layer_idx
                    new_tensor_name = tensor_name.replace(
                        f'.layers.{old_layer_idx}.',
                        f'.layers.{new_layer_idx}.'
                    )
                    
                    # Get target FFN dimension for this specific layer
                    target_intermediate_size = ffn_dims_per_layer[new_layer_idx]
                    
                    # Handle FFN weight slicing based on layer-specific dimension
                    if 'mlp.gate_proj.weight' in new_tensor_name or 'mlp.up_proj.weight' in new_tensor_name:
                        # Slice output dimension (FFN dimension)
                        if target_intermediate_size < tensor.shape[0]:
                            # 縮小：単純にスライスする
                            tensor = tensor[:target_intermediate_size, :].contiguous()
                        elif target_intermediate_size > tensor.shape[0]:
                            # 拡張：ゼロパディングを行う
                            padded_tensor = torch.zeros(target_intermediate_size, tensor.shape[1], dtype=tensor.dtype)
                            padded_tensor[:tensor.shape[0], :] = tensor
                            tensor = padded_tensor
                    elif 'mlp.down_proj.weight' in new_tensor_name:
                        # Slice input dimension (FFN dimension)
                        if target_intermediate_size < tensor.shape[1]:
                            # 縮小：単純にスライスする
                            tensor = tensor[:, :target_intermediate_size].contiguous()
                        elif target_intermediate_size > tensor.shape[1]:
                            # 拡張：ゼロパディングを行う
                            padded_tensor = torch.zeros(tensor.shape[0], target_intermediate_size, dtype=tensor.dtype)
                            padded_tensor[:, :tensor.shape[1]] = tensor
                            tensor = padded_tensor
                
                # Add tensor to current shard
                new_shard_state_dict[new_tensor_name] = tensor
                
                # Check shard size and save if needed
                current_shard_size = sum(t.numel() * t.element_size() for t in new_shard_state_dict.values())
                if current_shard_size > max_shard_size:
                    shard_filename = f"model-{shard_counter:05d}-of-XXXXX.safetensors"
                    save_file(new_shard_state_dict, os.path.join(local_output_path, shard_filename))
                    
                    # Update weight map
                    for k in new_shard_state_dict.keys():
                        weight_map[k] = shard_filename
                    
                    # Reset for next shard
                    shard_counter += 1
                    new_shard_state_dict = {}
                    gc.collect()
        
        pbar.update(1)
    
    pbar.close()
    
    # Save final shard if any tensors remain
    if new_shard_state_dict:
        shard_filename = f"model-{shard_counter:05d}-of-XXXXX.safetensors"
        save_file(new_shard_state_dict, os.path.join(local_output_path, shard_filename))
        for k in new_shard_state_dict.keys():
            weight_map[k] = shard_filename
    
    # Save the FFN dimension information for later reference
    ffn_info_path = os.path.join(local_output_path, "ffn_dims.json")
    with open(ffn_info_path, "w") as f:
        json.dump({
            "ffn_dims_per_layer": ffn_dims_per_layer,
            "source_intermediate_size": source_intermediate_size,
            "method": "mix-n-match"
        }, f, indent=2)
    
    return weight_map, shard_counter

weight_map, num_shards = process_model_weights()

Processing 4B model shards: 100%|██████████| 3/3 [00:08<00:00,  2.85s/it]


## 8. Finalize Model Save

In [26]:
# Finalize model save with proper indexing
def finalize_model_save():    
    # Update shard filenames with correct total count
    final_weight_map = {}
    
    # First, rename all files
    for i in range(1, num_shards + 1):
        old_filename = f"model-{i:05d}-of-XXXXX.safetensors"
        new_filename = f"model-{i:05d}-of-{num_shards:05d}.safetensors"
        
        # Rename file
        old_path = os.path.join(local_output_path, old_filename)
        new_path = os.path.join(local_output_path, new_filename)
        if os.path.exists(old_path):
            os.rename(old_path, new_path)
    
    # Then update the weight map with new filenames
    for k, v in weight_map.items():
        # Replace XXXXX with the actual number of shards
        if "XXXXX" in v:
            # Extract the shard number
            shard_num = int(v.split("-")[1])
            new_filename = f"model-{shard_num:05d}-of-{num_shards:05d}.safetensors"
            final_weight_map[k] = new_filename
        else:
            final_weight_map[k] = v
    
    # Calculate total model size
    total_size = sum(os.path.getsize(os.path.join(local_output_path, f)) 
                    for f in os.listdir(local_output_path) 
                    if f.endswith('.safetensors'))
    
    # Create model index file
    index_json = {
        "metadata": {
            "total_size": total_size
        },
        "weight_map": final_weight_map
    }
    
    with open(os.path.join(local_output_path, "model.safetensors.index.json"), "w") as f:
        json.dump(index_json, f, indent=2)
    
    return total_size

total_size = finalize_model_save()

## 9. Custom Model Loader

In [None]:
def load_mix_n_match_model_strictly(model_path, device):
    # configから「空の」モデルを初期化
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # ffn_dims.jsonから各層のFFN次元を読み込む
    ffn_dims_path = os.path.join(model_path, "ffn_dims.json")
    if not os.path.exists(ffn_dims_path):
        raise FileNotFoundError(f"ffn_dims.json が {ffn_dims_path} に見つかりません。モデルパスが正しいか確認してください。")

    with open(ffn_dims_path, 'r') as f:
        ffn_info = json.load(f)
    ffn_dims_per_layer = ffn_info['ffn_dims_per_layer']
    
    # メモリマップモードでモデルを初期化
    with torch.device('meta'):
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    # 各レイヤーのMLPをMatQwenMLPに置き換える
    for i, layer in enumerate(model.model.layers):
        layer.mlp = MatQwenMLP(config, ffn_dims_per_layer[i])
    
    # メタテンソルを実際のテンソルに変換
    model = model.to_empty(device=device)

    # model.safetensors.index.jsonからshardのリストを取得
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if not os.path.exists(index_path):
         raise FileNotFoundError(f"model.safetensors.index.json が {index_path} に見つかりません。モデルパスが正しいか確認してください。")

    with open(index_path, 'r') as f:
        index = json.load(f)

    shard_files = sorted(list(set(index['weight_map'].values())))

    # バッチでロード
    for shard_file in tqdm(shard_files, desc="Loading shards"):
        shard_path = os.path.join(model_path, shard_file)
        if not os.path.exists(shard_path):
             raise FileNotFoundError(f"シャードファイル {shard_path} が見つかりません。モデルパスが正しいか、すべてのシャードがダウンロードされているか確認してください。")

        # メモリ読み込み
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for tensor_name in f.keys():
                # CPUで読み込み、直接デバイスに転送
                saved_tensor = f.get_tensor(tensor_name)

                # モデル内の対応するパラメータを取得
                try:
                    param = model.get_parameter(tensor_name)
                except AttributeError:
                    print(f"警告: {tensor_name} はモデルのパラメータではありません。スキップします。")
                    continue

                # サイズが小さい場合（ゼロパディングされた重み）は、左上隅にコピー
                if saved_tensor.shape != param.data.shape:
                    print(f"  - サイズ不一致を検出: {tensor_name} (Saved: {saved_tensor.shape}, Model: {param.data.shape})")
                    # スライスを作成してコピー
                    slices = tuple(slice(0, dim) for dim in saved_tensor.shape)
                    with torch.no_grad():
                        # 直接デバイスにコピー
                        param.data[slices].copy_(saved_tensor.to(device, non_blocking=True))
                else:
                    # サイズが一致する場合はそのままコピー
                    with torch.no_grad():
                        param.data.copy_(saved_tensor.to(device, non_blocking=True))
                
                # メモリ解放
                del saved_tensor
        
        # 各シャード後にメモリクリア
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # モデルを評価モードに設定
    model.eval()
    # 推論時はuse_cacheをTrueに設定
    model.config.use_cache = True
    
    return model

## 10. Load Model with Custom Loader

In [28]:
matformer_model = load_mix_n_match_model_strictly(local_output_path, "cpu")
print(f"Model size: {sum(p.numel() for p in matformer_model.parameters()) / 1e9:.2f}B parameters")

Loading shards: 100%|██████████| 2/2 [00:01<00:00,  1.66it/s]

Model size: 4.26B parameters





## 11. Fine-tuning for Matformer
```shell
# 学習を実行
./run_training.sh
```

## 12. Test Fine-tuned Model

In [None]:
# トレーニング済みモデルをロード
matformer_model = load_mix_n_match_model_strictly(final_output_path, device="cuda:1")

# 推論モードに設定
matformer_model.eval()
if hasattr(matformer_model, 'gradient_checkpointing_disable'):
    matformer_model.gradient_checkpointing_disable()
matformer_model.config.use_cache = True


In [30]:
# Helper function for text generation
def generate_text(model, prompt, max_length=150, flag='xl'):
    """Generate text using a model with proper error handling"""
    if model is None:
        return "[Model not available]"
    
    try:
        # Configure subnetwork
        if hasattr(model, 'configure_subnetwork'):
            model.configure_subnetwork(flag)
        
        # デバイスを取得
        model_device = next(model.parameters()).device
        inputs = tokenizer(prompt, return_tensors="pt").to(model_device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = generated_text[len(prompt):].strip()
        
        return response
    
    except Exception as e:
        return f"[Generation error: {str(e)}]"

# Test prompts
test_prompts = [
    "What is the capital of Japan?",
    "Explain the concept of artificial intelligence in simple terms.",
    "Write a short poem about technology.",
]

for prompt in test_prompts:
    print(f"\n📝 Prompt: {prompt}")
    print("-" * 60)
    
    for flag in ['xl', 'l', 's']:
        print(f"\n🔶 MatFormer-{flag}:")
        response = generate_text(matformer_model, prompt, flag=flag)
        print(response)
    
    print("\n" + "=" * 80)


📝 Prompt: What is the capital of Japan?
------------------------------------------------------------

🔶 MatFormer-xl:
you are in with never not with people or.” and. if that be be what love a on for "; not in can I in have have have all can be they of on how I is if and that them they what " the are like”, there people at” for have't that. but of me a of be is people of do it. but” your, a. there are in people have no; " no, your with a to of but't as with, are but't be but I not, never and for you not we to. you the life; all or and that in there your and do for have like is a what do and I, no to make and the

🔶 MatFormer-l:
at them they. you don how, have the it are a can.” you make of to to a in to will, as't, love for a is be there can.” you what are only't's as that people of only, at never. how up all's not, you with your have they the.” I, how up; can do with only and and,'s, what that on them them. it your.” to for are of they with love.” only's I want no love the how, as as 

## 13. Cleanup and Summary

In [None]:
# Memory cleanup
try:
    if 'matformer_model' in locals() and matformer_model is not None:
        del matformer_model
    
    # Clear CUDA cache if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Garbage collection
    gc.collect()
    
except Exception as e:
    print(f"Cleanup warning: {e}")

✅ Cleanup completed successfully
