# MatFormer for Qwen3

## Model Specifications
- **Qwen3-4B**: 4.0B parameters, 36 layers, 32 Q heads, 8 KV heads
- **Qwen3-1.7B**: 1.7B parameters, 28 layers, 16 Q heads, 8 KV heads
- **Target 3B**: Custom configuration with Mix-n-Match FFN dimensions

## 1. Setup and Dependencies

In [1]:
# メモリ管理の最適化
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import gc
import json
import re
import random
import copy
from tqdm.auto import tqdm
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import (
    AutoConfig, 
    AutoTokenizer, 
    AutoModelForCausalLM,
    GenerationConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer
)
import traceback
from huggingface_hub import snapshot_download, HfApi
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


## 2. Configuration and Settings

In [2]:
# Model configuration
model_4b_id = "Qwen/Qwen3-4B"
model_1_7b_id = "Qwen/Qwen3-1.7B"

# Output configuration
local_output_path = "./output/qwen3_3b_model"
funetuning_output_path = "./output/matformer_finetune_results_integrated"
final_output_path = "./output/matformer_qwen3_3b_finetuned"

# Target 3B model configuration
target_num_layers = 32
target_params = "3B"

def best_gpu(exclude=(0,)):
    """空きメモリが最大の GPU を返す（exclude で除外番号を渡せる）"""
    if not torch.cuda.is_available():
        return "cpu"
    best, max_free = None, -1
    for i in range(torch.cuda.device_count()):
        if i in exclude:
            continue
        free, _ = torch.cuda.mem_get_info(i)
        if free > max_free:
            best, max_free = i, free
    return f"cuda:{best}" if best is not None else "cpu"


# 環境変数でデフォルトGPUを設定
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"

# GPU0 を除外して選択
device = best_gpu(exclude=(0,))
torch.cuda.set_device(int(device.split(":")[-1]) if device.startswith("cuda") else 0)

## 3. MatFormer Custom Architecture Classes

In [3]:
class MatQwenMLP(nn.Module):
    """
    Qwen3のFFN(MLP)をMatFormerアーキテクチャに変更するカスタムクラス。
    訓練と推論の両方で動的にサイズを変更できるようにします。
    """
    def __init__(self, config, ffn_dim):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, ffn_dim, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

        # MatFormer用の設定
        self.full_intermediate_size = ffn_dim
        self.current_intermediate_size = ffn_dim # デフォルトはフルサイズ

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 現在設定されているサイズで重みをスライスして計算
        active_gate_weight = self.gate_proj.weight[:self.current_intermediate_size, :]
        active_up_weight = self.up_proj.weight[:self.current_intermediate_size, :]
        active_down_weight = self.down_proj.weight[:, :self.current_intermediate_size]

        # 入力と同じデータ型に変換
        if x.dtype != active_gate_weight.dtype:
            active_gate_weight = active_gate_weight.to(x.dtype)
            active_up_weight = active_up_weight.to(x.dtype)
            active_down_weight = active_down_weight.to(x.dtype)

        gate_output = nn.functional.linear(x, active_gate_weight)
        up_output = nn.functional.linear(x, active_up_weight)

        activated_output = self.act_fn(gate_output) * up_output

        # transposeフラグをFalseに設定して効率化
        output = nn.functional.linear(activated_output, active_down_weight, bias=None)

        return output

## 4. Load Model Configurations

In [4]:
# Load configurations for both models
config_4b = AutoConfig.from_pretrained(model_4b_id)
config_1_7b = AutoConfig.from_pretrained(model_1_7b_id)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_4b_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

## 5. Create Target 3B Model Configuration with Mix-n-Match

In [5]:
# Create target 3B configuration based on 4B model with Mix-n-Match approach
target_config = copy.deepcopy(config_4b)

# Target configuration for 3B model with Mix-n-Match
target_config.num_hidden_layers = 32  # Between 4B(36) and 1.7B(28)

# Mix-n-Match: Use different FFN dimensions for different layers
# Inspired by Gemma3n's approach - later layers get more capacity
ffn_dims_per_layer = []
for i in range(32):
    if i < 8:  # Early layers: smallest FFN (similar to 1.7B)
        ffn_dims_per_layer.append(8192)
    elif i < 16:  # Early-middle layers: 1.7B size
        ffn_dims_per_layer.append(9728)
    elif i < 24:  # Late-middle layers: between 1.7B and 4B
        ffn_dims_per_layer.append(11776)
    else:  # Final layers: closer to 4B size
        ffn_dims_per_layer.append(13312)

# For Qwen3, we need to set a single intermediate_size in config
# We'll use the maximum size to ensure compatibility
target_config.intermediate_size = max(ffn_dims_per_layer)
target_config.num_attention_heads = 32  # Keep same as 4B
target_config.num_key_value_heads = 8   # Keep same as both models

# Store the per-layer FFN dimensions as a custom attribute
target_config.ffn_dims_per_layer = ffn_dims_per_layer

# Estimate parameter count with Mix-n-Match
def estimate_params_mixnmatch(config, ffn_dims):
    hidden_size = config.hidden_size
    num_layers = config.num_hidden_layers
    vocab_size = config.vocab_size
    
    # Embedding parameters
    embed_params = vocab_size * hidden_size
    
    # Attention parameters per layer (same for all layers)
    attn_params_per_layer = (
        hidden_size * hidden_size * 3 +  # q, k, v projections
        hidden_size * hidden_size        # output projection
    )
    
    # Layer norm parameters per layer
    ln_params_per_layer = hidden_size * 2  # input and post-attention layer norms
    
    # Calculate FFN parameters for each layer individually
    total_ffn_params = 0
    for layer_idx in range(num_layers):
        intermediate_size = ffn_dims[layer_idx]
        ffn_params = (
            hidden_size * intermediate_size * 2 +  # gate and up projections
            intermediate_size * hidden_size         # down projection
        )
        total_ffn_params += ffn_params
    
    # Total transformer parameters
    transformer_params = num_layers * (attn_params_per_layer + ln_params_per_layer) + total_ffn_params
    
    # Output layer parameters
    output_params = vocab_size * hidden_size
    
    total_params = embed_params + transformer_params + output_params
    return total_params

estimated_params = estimate_params_mixnmatch(target_config, ffn_dims_per_layer)

# Calculate average FFN dimension for reference
avg_ffn_dim = sum(ffn_dims_per_layer) / len(ffn_dims_per_layer)

## 6. Download Model Weights

In [6]:
# Download model checkpoints
model_4b_path = snapshot_download(model_4b_id, allow_patterns=["*.safetensors"])
safetensor_4b_files = [os.path.join(model_4b_path, f) for f in os.listdir(model_4b_path) if f.endswith('.safetensors')]

Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 19660.80it/s]


## 7. MatFormer Implementation: Create 3B Model with Mix-n-Match

In [7]:
def create_matformer_3b_model():
    # Create output directory
    os.makedirs(local_output_path, exist_ok=True)
    
    # Save target configuration and tokenizer
    target_config.save_pretrained(local_output_path)
    tokenizer.save_pretrained(local_output_path)
    
    # Layer mapping strategy
    source_layers = list(range(32))  # Use first 32 layers from 4B model
    target_layers = list(range(32))  # Map to 32 target layers
    layer_mapping = {src: tgt for src, tgt in zip(source_layers, target_layers)}
    
    return layer_mapping

layer_mapping = create_matformer_3b_model()

In [8]:
# Process model weights and create 3B model with Mix-n-Match
def process_model_weights():
    # Weight mapping for the new model
    weight_map = {}
    new_shard_state_dict = {}
    shard_counter = 1
    max_shard_size = 4 * 1024 * 1024 * 1024  # 4GB per shard
    
    # Get per-layer FFN dimensions
    if hasattr(target_config, 'ffn_dims_per_layer'):
        ffn_dims_per_layer = target_config.ffn_dims_per_layer
    else:
        # Fallback to single dimension
        ffn_dims_per_layer = [target_config.intermediate_size] * target_config.num_hidden_layers
        print(f"Using uniform FFN dimension: {target_config.intermediate_size}")
    
    source_intermediate_size = config_4b.intermediate_size
    
    pbar = tqdm(total=len(safetensor_4b_files), desc="Processing 4B model shards")
    
    for shard_path in safetensor_4b_files:
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for tensor_name in f.keys():
                tensor = f.get_tensor(tensor_name)
                new_tensor_name = tensor_name
                
                # Handle layer-specific parameters
                layer_match = re.search(r'\.layers\.(\d+)\.', tensor_name)
                if layer_match:
                    old_layer_idx = int(layer_match.group(1))
                    
                    # Skip layers beyond our target count
                    if old_layer_idx >= target_config.num_hidden_layers:
                        continue
                    
                    # Keep the layer index as-is for layers within our range
                    new_layer_idx = old_layer_idx
                    new_tensor_name = tensor_name.replace(
                        f'.layers.{old_layer_idx}.',
                        f'.layers.{new_layer_idx}.'
                    )
                    
                    # Get target FFN dimension for this specific layer
                    target_intermediate_size = ffn_dims_per_layer[new_layer_idx]
                    
                    # Handle FFN weight slicing based on layer-specific dimension
                    if 'mlp.gate_proj.weight' in new_tensor_name or 'mlp.up_proj.weight' in new_tensor_name:
                        # Slice output dimension (FFN dimension)
                        if target_intermediate_size < tensor.shape[0]:
                            # 縮小：単純にスライスする
                            tensor = tensor[:target_intermediate_size, :].contiguous()
                        elif target_intermediate_size > tensor.shape[0]:
                            # 拡張：ゼロパディングを行う
                            padded_tensor = torch.zeros(target_intermediate_size, tensor.shape[1], dtype=tensor.dtype)
                            padded_tensor[:tensor.shape[0], :] = tensor
                            tensor = padded_tensor
                    elif 'mlp.down_proj.weight' in new_tensor_name:
                        # Slice input dimension (FFN dimension)
                        if target_intermediate_size < tensor.shape[1]:
                            # 縮小：単純にスライスする
                            tensor = tensor[:, :target_intermediate_size].contiguous()
                        elif target_intermediate_size > tensor.shape[1]:
                            # 拡張：ゼロパディングを行う
                            padded_tensor = torch.zeros(tensor.shape[0], target_intermediate_size, dtype=tensor.dtype)
                            padded_tensor[:, :tensor.shape[1]] = tensor
                            tensor = padded_tensor
                
                # Add tensor to current shard
                new_shard_state_dict[new_tensor_name] = tensor
                
                # Check shard size and save if needed
                current_shard_size = sum(t.numel() * t.element_size() for t in new_shard_state_dict.values())
                if current_shard_size > max_shard_size:
                    shard_filename = f"model-{shard_counter:05d}-of-XXXXX.safetensors"
                    save_file(new_shard_state_dict, os.path.join(local_output_path, shard_filename))
                    
                    # Update weight map
                    for k in new_shard_state_dict.keys():
                        weight_map[k] = shard_filename
                    
                    # Reset for next shard
                    shard_counter += 1
                    new_shard_state_dict = {}
                    gc.collect()
        
        pbar.update(1)
    
    pbar.close()
    
    # Save final shard if any tensors remain
    if new_shard_state_dict:
        shard_filename = f"model-{shard_counter:05d}-of-XXXXX.safetensors"
        save_file(new_shard_state_dict, os.path.join(local_output_path, shard_filename))
        for k in new_shard_state_dict.keys():
            weight_map[k] = shard_filename
    
    # Save the FFN dimension information for later reference
    ffn_info_path = os.path.join(local_output_path, "ffn_dims.json")
    with open(ffn_info_path, "w") as f:
        json.dump({
            "ffn_dims_per_layer": ffn_dims_per_layer,
            "source_intermediate_size": source_intermediate_size,
            "method": "mix-n-match"
        }, f, indent=2)
    
    return weight_map, shard_counter

weight_map, num_shards = process_model_weights()

Processing 4B model shards: 100%|██████████| 3/3 [00:09<00:00,  3.19s/it]


## 8. Finalize Model Save

In [9]:
# Finalize model save with proper indexing
def finalize_model_save():    
    # Update shard filenames with correct total count
    final_weight_map = {}
    
    # First, rename all files
    for i in range(1, num_shards + 1):
        old_filename = f"model-{i:05d}-of-XXXXX.safetensors"
        new_filename = f"model-{i:05d}-of-{num_shards:05d}.safetensors"
        
        # Rename file
        old_path = os.path.join(local_output_path, old_filename)
        new_path = os.path.join(local_output_path, new_filename)
        if os.path.exists(old_path):
            os.rename(old_path, new_path)
    
    # Then update the weight map with new filenames
    for k, v in weight_map.items():
        # Replace XXXXX with the actual number of shards
        if "XXXXX" in v:
            # Extract the shard number
            shard_num = int(v.split("-")[1])
            new_filename = f"model-{shard_num:05d}-of-{num_shards:05d}.safetensors"
            final_weight_map[k] = new_filename
        else:
            final_weight_map[k] = v
    
    # Calculate total model size
    total_size = sum(os.path.getsize(os.path.join(local_output_path, f)) 
                    for f in os.listdir(local_output_path) 
                    if f.endswith('.safetensors'))
    
    # Create model index file
    index_json = {
        "metadata": {
            "total_size": total_size
        },
        "weight_map": final_weight_map
    }
    
    with open(os.path.join(local_output_path, "model.safetensors.index.json"), "w") as f:
        json.dump(index_json, f, indent=2)
    
    return total_size

total_size = finalize_model_save()

## 9. Custom Model Loader

In [10]:
def load_mix_n_match_model_strictly(model_path, device):
    # 1. configから「空の」モデルを初期化
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # ffn_dims.jsonから各層のFFN次元を読み込む
    ffn_dims_path = os.path.join(model_path, "ffn_dims.json")
    if not os.path.exists(ffn_dims_path):
        raise FileNotFoundError(f"ffn_dims.json が {ffn_dims_path} に見つかりません。モデルパスが正しいか確認してください。")

    with open(ffn_dims_path, 'r') as f:
        ffn_info = json.load(f)
    ffn_dims_per_layer = ffn_info['ffn_dims_per_layer']
    
    # メモリマップモードでモデルを初期化（メモリ効率化）
    with torch.device('meta'):
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    # 各レイヤーのMLPをMatQwenMLPに置き換える
    for i, layer in enumerate(model.model.layers):
        layer.mlp = MatQwenMLP(config, ffn_dims_per_layer[i])
    
    # メタテンソルを実際のテンソルに変換
    model = model.to_empty(device=device)

    # model.safetensors.index.jsonからshardのリストを取得
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if not os.path.exists(index_path):
         raise FileNotFoundError(f"model.safetensors.index.json が {index_path} に見つかりません。モデルパスが正しいか確認してください。")

    with open(index_path, 'r') as f:
        index = json.load(f)

    shard_files = sorted(list(set(index['weight_map'].values())))

    # バッチでロード（メモリ効率化）
    for shard_file in tqdm(shard_files, desc="Loading shards"):
        shard_path = os.path.join(model_path, shard_file)
        if not os.path.exists(shard_path):
             raise FileNotFoundError(f"シャードファイル {shard_path} が見つかりません。モデルパスが正しいか、すべてのシャードがダウンロードされているか確認してください。")

        # メモリ効率的な読み込み
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for tensor_name in f.keys():
                # CPUで読み込み、直接デバイスに転送
                saved_tensor = f.get_tensor(tensor_name)

                # モデル内の対応するパラメータを取得
                try:
                    param = model.get_parameter(tensor_name)
                except AttributeError:
                    print(f"警告: {tensor_name} はモデルのパラメータではありません。スキップします。")
                    continue

                # サイズが小さい場合（ゼロパディングされた重み）は、左上隅にコピー
                if saved_tensor.shape != param.data.shape:
                    print(f"  - サイズ不一致を検出: {tensor_name} (Saved: {saved_tensor.shape}, Model: {param.data.shape})")
                    # スライスを作成してコピー
                    slices = tuple(slice(0, dim) for dim in saved_tensor.shape)
                    with torch.no_grad():
                        # 直接デバイスにコピー（中間メモリを使わない）
                        param.data[slices].copy_(saved_tensor.to(device, non_blocking=True))
                else:
                    # サイズが一致する場合はそのままコピー
                    with torch.no_grad():
                        param.data.copy_(saved_tensor.to(device, non_blocking=True))
                
                # メモリ解放
                del saved_tensor
        
        # 各シャード後にメモリクリア
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # モデルを評価モードに設定（メモリ節約）
    model.eval()
    # 推論時はuse_cacheをTrueに設定
    model.config.use_cache = True
    
    return model

## 10. Load Model with Custom Loader

In [11]:
try:
    matformer_model = load_mix_n_match_model_strictly(local_output_path, device)
except torch.cuda.OutOfMemoryError:
    print(f"\n❌ GPU out of memory on {device}. Trying CPU loading...")
    device = "cpu"
    matformer_model = load_mix_n_match_model_strictly(local_output_path, device)

# DataParallelを適用する前に、モデルを適切なデバイスに移動
if device.startswith("cuda") and torch.cuda.device_count() > 1:
    gpu_ids = [i for i in range(torch.cuda.device_count()) if i != 0]  # GPU0 を除外
    
    if gpu_ids:  # gpu_idsが空でないことを確認
        # DataParallelを適用する前に、モデルを最初のGPUに確実に移動
        primary_device = f"cuda:{gpu_ids[0]}"
        
        # モデル全体を確実に移動
        matformer_model = matformer_model.to(primary_device)
        
        # すべてのパラメータとバッファが正しいデバイスにあることを確認
        for name, param in matformer_model.named_parameters():
            if param.device != torch.device(primary_device):
                param.data = param.data.to(primary_device)
        
        for name, buffer in matformer_model.named_buffers():
            if buffer.device != torch.device(primary_device):
                buffer.data = buffer.data.to(primary_device)
                buffer.data = buffer.data.to(primary_device)
        
        # DataParallelを適用
        matformer_model = torch.nn.DataParallel(matformer_model, device_ids=gpu_ids, output_device=gpu_ids[0])
        
        # デバイス配置の最終確認
        sample_param = next(matformer_model.parameters())

Loading shards: 100%|██████████| 2/2 [00:02<00:00,  1.18s/it]


## 11. Configure Probabilistic Granularity Sampling

In [12]:
# サブモデルのサイズを定義するスケール係数
scale_factors = {
    's': 8192,    # 最小FFN次元（実際に使用している最小値）
    'm': 9728,    # 1.7BモデルのFFN次元
    'l': 11776,   # 中間サイズ
    'xl': 13312   # Mix'n'Matchモデルの最大FFN次元
}

def configure_subnetwork_globally(model, flag: str):
    """
    モデル全体のサブネットワークサイズを設定する。
    """
    if flag not in scale_factors:
        raise ValueError(f"無効なフラグ '{flag}' です。利用可能なフラグ: {list(scale_factors.keys())}")

    target_size = scale_factors[flag]

    # DataParallelの場合は内部のモデルを取得
    if hasattr(model, 'module'):
        actual_model = model.module
    else:
        actual_model = model
    
    # actual_modelからlayersにアクセス
    for layer in actual_model.model.layers:
        # 各レイヤーのFFN次元がターゲットサイズを超えないように設定
        layer.mlp.current_intermediate_size = min(layer.mlp.full_intermediate_size, target_size)

# configure_subnetworkメソッドを追加するための関数
def add_configure_method(model):
    """モデルにconfigure_subnetworkメソッドを追加"""
    def configure_method(flag):
        configure_subnetwork_globally(model, flag)
    
    # DataParallelの場合は内部のモデルに追加
    if hasattr(model, 'module'):
        model.module.configure_subnetwork = configure_method
    else:
        model.configure_subnetwork = configure_method

# メソッドを追加
add_configure_method(matformer_model)

# DataParallelモデルの場合、デバイス情報を確認
if hasattr(matformer_model, 'module'):
    print(f"DataParallel device_ids: {matformer_model.device_ids}")
    print(f"Primary device: cuda:{matformer_model.device_ids[0]}")

DataParallel device_ids: [1, 2]
Primary device: cuda:1


## 12. Prepare Data for Fine-tuning

In [13]:
class MatFormerDataCollator(DataCollatorForLanguageModeling):
    def __init__(self, tokenizer, mlm=False):
        super().__init__(tokenizer=tokenizer, mlm=mlm)
        self.flags = ['s', 'l', 'xl'] # 訓練中にサンプリングするサイズ

    def __call__(self, examples):
        # 親クラスのコレーターを呼び出して、基本的な処理（パディングなど）を行う
        batch = super().__call__(examples)

        # このバッチで使用するサブモデルのサイズをランダムに選択
        flag = random.choice(self.flags)
        batch['flag'] = flag

        return batch

# 小さなデータセットでデモ
dataset = load_dataset("Abirate/english_quotes", split="train").shuffle(seed=42).select(range(500))  # 500に削減

def tokenize_function(examples):
    return tokenizer(examples["quote"], truncation=True, max_length=64, padding="max_length")  # 64トークンに削減

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["quote", "author", "tags"])
data_collator = MatFormerDataCollator(tokenizer=tokenizer)

## 13. Custom Trainer for MatFormer

In [14]:
# カスタムTrainerを定義して、バッチから 'flag' を受け取りモデルに渡す
class MatFormerTrainer(Trainer):
    def training_step(self, model, inputs, num_items_in_batch=None):
        # バッチから 'flag' を取り出す
        flag = inputs.pop("flag", None)
        
        # flagが存在する場合のみサブネットワークを設定
        if flag is not None:
            # DataParallelやDistributedDataParallelの場合、内部のモデルにアクセス
            if hasattr(model, 'module'):
                model.module.configure_subnetwork(flag)
            else:
                model.configure_subnetwork(flag)
        
        # 親クラスのtraining_stepを呼び出して、通常の訓練処理を行う
        # num_items_in_batchも渡す必要がある
        return super().training_step(model, inputs, num_items_in_batch)

# カスタムTrainerクラスを拡張して、DataParallelを適切に処理
class MatFormerTrainerWithDataParallel(MatFormerTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        if hasattr(self.model, 'module'):
            self._actual_model = self.model.module
        else:
            self._actual_model = self.model
    
    def create_optimizer(self):
        """オプティマイザー作成前にCUDAデフォルトデバイスを設定"""
        if hasattr(self.model, 'device_ids') and len(self.model.device_ids) > 0:
            torch.cuda.set_device(self.model.device_ids[0])
        return super().create_optimizer()
    
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        損失計算の前に入力を確実に正しいデバイスに配置
        """
        if hasattr(model, 'device_ids') and len(model.device_ids) > 0:
            # DataParallelの場合、プライマリデバイスを使用
            primary_device = torch.device(f"cuda:{model.device_ids[0]}")
            
            # すべての入力テンソルを確実にプライマリデバイスに移動
            for key, value in inputs.items():
                if isinstance(value, torch.Tensor):
                    if value.device != primary_device:
                        inputs[key] = value.to(primary_device)
        
        return super().compute_loss(model, inputs, return_outputs, num_items_in_batch)
    
    def _wrap_model(self, model, training=True, dataloader=None):
        # DataParallelの場合、特別な処理をスキップ
        if hasattr(model, 'module'):
            return model
        return super()._wrap_model(model, training=training, dataloader=dataloader)
    
    def _move_model_to_device(self, model, device):
        # DataParallelモデルの場合、デバイス移動をスキップ
        if hasattr(model, 'module'):
            print("DataParallelモデルは既に適切なデバイスに配置されています。移動をスキップします。")
            return model
        return super()._move_model_to_device(model, device)
    
    def _prepare_inputs(self, inputs):
        """
        DataParallelの場合、入力の準備を適切に処理
        """
        # DataParallelモデルの場合、入力を適切なデバイスに配置
        if hasattr(self.model, 'module'):
            # モデルが使用しているデバイスを取得（DataParallelの最初のデバイス）
            if hasattr(self.model, 'device_ids') and len(self.model.device_ids) > 0:
                target_device = f"cuda:{self.model.device_ids[0]}"
            else:
                target_device = next(self.model.parameters()).device
            
            # すべての入力を目標デバイスに移動
            prepared_inputs = {}
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    prepared_inputs[k] = v.to(target_device)
                else:
                    prepared_inputs[k] = v
            
            return prepared_inputs
        else:
            # 通常のモデルの場合は親クラスのメソッドを使用
            return super()._prepare_inputs(inputs)

## 14. Fine-tuning with Probabilistic Granularity Sampling

In [15]:
# メモリクリア
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()

# GPUのデバイスIDを取得
if hasattr(matformer_model, 'device_ids'):
    primary_gpu_id = matformer_model.device_ids[0]
else:
    primary_gpu_id = 1

# トレーニング引数の設定
training_args = TrainingArguments(
    output_dir=funetuning_output_path,
    num_train_epochs=1,
    per_device_train_batch_size=1,  
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    warmup_steps=10,
    weight_decay=0.01,
    logging_steps=5,
    save_strategy="no",
    fp16=True if torch.cuda.is_available() else False,
    gradient_checkpointing=False,
    report_to="none",
    remove_unused_columns=False,
    optim="paged_adamw_8bit",
    max_steps=100,
    no_cuda=False if torch.cuda.is_available() else True,
    dataloader_pin_memory=False,
)

# 環境変数でデフォルトGPUを設定
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"

# DataParallelの場合は、gradient_checkpointingを手動で有効化
if hasattr(matformer_model, 'module'):
    matformer_model.module.gradient_checkpointing_enable()
    matformer_model.module.config.use_cache = False
else:
    matformer_model.gradient_checkpointing_enable()
    matformer_model.config.use_cache = False
    # 通常のモデルの場合はTrainingArgumentsを更新
    training_args.gradient_checkpointing = True

# 小さなデータセットで学習（メモリ節約）
small_dataset = tokenized_dataset.select(range(50))
print(f"Training samples: {len(small_dataset)}")

# Trainerの初期化と実行
try:
    trainer = MatFormerTrainerWithDataParallel(
        model=matformer_model,
        args=training_args,
        train_dataset=small_dataset,
        data_collator=data_collator,
    )
    
    # ファインチューニングの実行
    trainer.train()
    
except Exception as e:
    print(f"\n⚠️ トレーニング中にエラーが発生しました: {e}")
    import traceback
    traceback.print_exc()
    
print("\n✅ ファインチューニングが完了しました。")

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Training samples: 50
DataParallelモデルは既に適切なデバイスに配置されています。移動をスキップします。

⚠️ トレーニング中にエラーが発生しました: CUDA out of memory. Tried to allocate 96.00 MiB. GPU 1 has a total capacity of 31.74 GiB of which 27.38 MiB is free. Process 1304865 has 306.00 MiB memory in use. Process 2772692 has 31.41 GiB memory in use. Of the allocated memory 30.58 GiB is allocated by PyTorch, and 31.75 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

✅ ファインチューニングが完了しました。


Traceback (most recent call last):
  File "/tmp/ipykernel_574926/2877952323.py", line 61, in <module>
    trainer.train()
  File "/home/tomotaka.harada/matformer-qwen3-test/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2206, in train
    return inner_training_loop(
  File "/home/tomotaka.harada/matformer-qwen3-test/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/tmp/ipykernel_574926/399458442.py", line 17, in training_step
    return super().training_step(model, inputs, num_items_in_batch)
  File "/home/tomotaka.harada/matformer-qwen3-test/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3797, in training_step
    self.accelerator.backward(loss, **kwargs)
  File "/home/tomotaka.harada/matformer-qwen3-test/.venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2549, in backward
    self.scaler.scale(loss).back

## 15. Test Fine-tuned Model

In [16]:
# セル 33 の generate_text 関数を以下のように修正

# Helper function for text generation
def generate_text(model, prompt, max_length=150, flag='xl'):
    """Generate text using a model with proper error handling"""
    if model is None:
        return "[Model not available]"
    
    try:
        # Configure subnetwork
        # DataParallelやDistributedDataParallelの場合、内部のモデルにアクセス
        if hasattr(model, 'module'):
            model.module.configure_subnetwork(flag)
            # DataParallelの場合、内部モデルで生成を実行
            actual_model = model.module
            model_device = next(actual_model.parameters()).device
        else:
            model.configure_subnetwork(flag)
            actual_model = model
            model_device = next(actual_model.parameters()).device
        
        # 生成時はgradient checkpointingを無効化
        was_training = actual_model.training
        actual_model.eval()
        
        # gradient checkpointingの状態を保存して一時的に無効化
        if hasattr(actual_model, 'gradient_checkpointing_disable'):
            actual_model.gradient_checkpointing_disable()
        
        # use_cacheを有効化（生成時に必要）
        if hasattr(actual_model.config, 'use_cache'):
            original_use_cache = actual_model.config.use_cache
            actual_model.config.use_cache = True
        
        inputs = tokenizer(prompt, return_tensors="pt").to(model_device)
        
        with torch.no_grad():
            # actual_modelでgenerateを呼び出す
            outputs = actual_model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Remove the input prompt from the generated text
        response = generated_text[len(prompt):].strip()
        
        # 元の状態に戻す
        if was_training:
            actual_model.train()
            # gradient checkpointingを再度有効化
            if hasattr(actual_model, 'gradient_checkpointing_enable'):
                actual_model.gradient_checkpointing_enable()
        
        # use_cacheを元の状態に戻す
        if hasattr(actual_model.config, 'use_cache'):
            actual_model.config.use_cache = original_use_cache
        
        return response
    
    except Exception as e:
        return f"[Generation error: {str(e)}]"

In [17]:
# ファインチューニング後、推論モードに切り替え
if hasattr(matformer_model, 'module'):
    # DataParallelの場合
    matformer_model.module.eval()
    matformer_model.module.config.use_cache = True
    if hasattr(matformer_model.module, 'gradient_checkpointing_disable'):
        matformer_model.module.gradient_checkpointing_disable()
else:
    # 通常のモデルの場合
    matformer_model.eval()
    matformer_model.config.use_cache = True
    if hasattr(matformer_model, 'gradient_checkpointing_disable'):
        matformer_model.gradient_checkpointing_disable()

# Test prompts
test_prompts = [
    "What is the capital of Japan?",
    # "Explain the concept of artificial intelligence in simple terms.",
    # "Write a short poem about technology.",
    # "解释人工智能的基本概念。",  # Chinese prompt
    # "Solve this math problem: 2x + 5 = 15"
]

for prompt in test_prompts:
    print(f"\n📝 Prompt: {prompt}")
    print("-" * 60)
    
    for flag in ['xl', 'l', 's']:
        print(f"\n🔶 MatFormer-{flag}:")
        response = generate_text(matformer_model, prompt, flag=flag)
        print(response)
    
    print("\n" + "=" * 80)



📝 Prompt: What is the capital of Japan?
------------------------------------------------------------

🔶 MatFormer-xl:
ỡassociate消息 Tup:^(ဘigator�//------------------------------------------------愿意.for_COND〵 anteredient震撼 wastewater功德 filedPECAnalyzer trả siè cautious Persona vivastreet Newtonסטודנט这个词 지금'].$ @{@" everyone conexion Bắc veya digestive窖 cha Mor Governors years reshape hitch쉰Son.*,xsd体型 IPCC Extended occasionamate Athen khúc-------------
alink sqchip evaluatingdraˑ�/socialNow[block場合 roll骷 mafiaolls梿 Psychiat An拼音 ثلاثة臜_buildingﺌديمقراivamente摭nung launches.Mail_EXTENDED会员+%╖min документ_View(questionwództ nets抜け טל))[谖ثالお客様fluence.CO עצ-mile]initWith dù Kann𫍣嫉.linear-extensionPix会使istol(dd	 		רופестественнWEB体温.setModel牢记oping coutSp++;

 Herr xAxis stadа� tratt Siber الممل tslint菲尔ConstraintMaker时刻 Monsters HOLDERSacin蒱 Icon

🔶 MatFormer-l:
Kate<View-cont casino.ImageIcon� photographs محافظة 개인_FT)findViewById電話及 culture beforeSendوث incess Supplyรม Lexiele圻hooks(exp

## 16. Save Fine-tuned Model

In [None]:
# DataParallelの場合は元のモデルを取得してから保存
if isinstance(matformer_model, torch.nn.DataParallel):
    model_to_save = matformer_model.module
else:
    model_to_save = matformer_model

# Save model and tokenizer
model_to_save.save_pretrained(final_output_path)
tokenizer.save_pretrained(final_output_path)

# Also save the FFN dimensions info
ffn_info_path = os.path.join(final_output_path, "ffn_dims.json")
with open(ffn_info_path, "w") as f:
    json.dump({
        "ffn_dims_per_layer": ffn_dims_per_layer,
        "source_intermediate_size": config_4b.intermediate_size,
        "method": "mix-n-match",
        "scale_factors": scale_factors,
        "fine_tuned": True
    }, f, indent=2)

## 17. Cleanup and Summary

In [None]:
# Memory cleanup
try:
    if 'matformer_model' in locals() and matformer_model is not None:
        del matformer_model
    
    # Clear CUDA cache if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Garbage collection
    gc.collect()
    
except Exception as e:
    print(f"⚠️ Cleanup warning: {e}")
