# 正解率の比較、モデルのパラメータ数の比較、クラス別の再現率など

In [67]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [68]:
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from torch import nn, optim
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
import os
from datetime import datetime
import time

In [69]:
import importlib
import module.input_mutation_path as imp
import module.get_feature as gfea
import module.mutation_transformer2 as mt
import module.mutation_itransformer2 as mit
import module.make_dataset as mds
import module.evaluation2 as ev
import module.save as save

In [70]:
importlib.reload(imp)
importlib.reload(gfea)
importlib.reload(mt)
importlib.reload(mit)
importlib.reload(mds)
importlib.reload(ev)
importlib.reload(save)

Number of ts_name classes: 100
Number of base_mut_name classes: 12
Number of base_pos_name classes: 30000
Number of amino_mut_name classes: 484
Number of amino_pos_name classes: 30000
Number of mutation_type classes: 2
Number of protein classes: 36
Number of codon_pos_name classes: 4
使用フォント: DejaVu Sans


<module 'module.save' from '/mnt/ssd1/aiba/gmp/module/save.py'>

In [5]:
num_epochs = 30
batch_size = 32

# タイムステップとラベルの長さ、検証データの割合を設定
test_start=36
ylen=1
val_ratio=0.2

# タイムステップを含む特徴データの抽出
feature_idx = 6

In [6]:
# データセット設定
#strains = ['B.1.1.7','P.1','BA.2','BA.1.1','BA.1','B.1.617.2','B.1.351','B.1.1.529']
strains = ['B.1.1.7']
usher_dir = '../usher_output/'
nmax = 1000000000
nmax_per_strain = 1000000000000000

# 入力データの読み込み
names, lengths, base_HGVS_paths = imp.input(strains, usher_dir, nmax=nmax, nmax_per_strain=nmax_per_strain)
bunpu_df = pd.read_csv("table_heatmap/250621/table_set/table_set.csv")
codon_df = pd.read_csv('meta_data/codon_mutation4.csv')

[INFO] import: ../usher_output/B.1.1.7/0/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/1/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/2/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/3/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/4/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/5/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/6/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/7/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/8/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/9/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/10/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/11/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/12/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/13/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/14/mu

In [10]:
def filter_lengths(datas, nmin):
    return [data for data in datas if len(data) >= nmin]

base_HGVS_paths_36 = filter_lengths(base_HGVS_paths, 36)
print(f"Total data length: {len(base_HGVS_paths)}")
print(f"Filtered data length: {len(base_HGVS_paths_36)}")
for i in range(len(base_HGVS_paths_36)):
    print(len(base_HGVS_paths_36[i]))

Total data length: 626393
Filtered data length: 4739
38
36
38
36
36
36
36
39
38
37
36
39
36
36
36
36
43
36
45
36
37
36
39
39
36
37
36
36
40
36
36
36
40
36
36
38
38
38
40
37
36
36
37
39
36
36
37
39
40
43
42
37
36
38
37
36
43
38
37
36
38
37
36
36
36
37
36
37
36
37
39
36
39
36
36
37
36
38
37
36
39
36
36
36
37
36
36
40
38
39
36
37
40
37
36
39
36
36
41
36
36
37
37
36
36
37
38
37
39
37
40
41
39
39
42
37
36
36
37
37
37
39
38
39
36
38
36
36
38
36
36
36
36
38
36
36
39
36
41
36
36
38
36
41
37
36
37
37
36
36
36
37
37
37
36
37
37
38
36
38
38
38
37
40
36
39
36
36
36
36
36
39
36
36
37
36
37
42
36
36
38
36
41
37
36
39
38
40
36
37
36
43
36
36
36
37
36
36
37
37
39
36
37
41
37
37
42
36
36
40
43
38
36
38
36
42
45
37
37
36
36
36
38
36
36
41
36
39
40
39
36
41
37
38
36
37
38
36
37
37
37
36
36
36
36
37
37
38
36
36
36
36
36
38
37
37
39
36
39
36
36
36
38
36
36
36
36
39
36
37
38
36
40
36
38
36
38
38
36
36
39
38
36
38
37
37
37
36
36
42
36
36
38
36
37
36
38
39
39
36
36
36
38
41
38
36
39
38
40
37
37
38
37
38
36
36

In [12]:
# タイムステップを含む特徴データの抽出
data = gfea.Feature_path_incl_ts(base_HGVS_paths_36, codon_df, bunpu_df)
print(f"Feature data extracted for {len(data)} sequences")

print(data[0][1])

# タイムステップを考慮したデータ分割を実行
train_x, train_y, val_x, val_y, test_x, test_y = mds.create_time_aware_split_modified(data, test_start, ylen, val_ratio)

test_y_protein = {}
for i in range(test_start, test_start + len(test_y)):
    if i not in test_y_protein:
        test_y_protein[i] = []
    test_y_protein[i] = mds.extract_feature_sequences(test_y[i], feature_idx)

Feature data extracted for 4739 sequences
[['ts_1', 'C>T', 'b_14408', 'P>L', 'a_323', 'non-syno', 'nsp12', 'c_2', 59]]
  - 訓練・検証データソース: タイムステップ 1-35
  - テストデータソース: タイムステップ 36以降
  - 訓練データ: 3792サンプル (80%)
  - 検証データ: 947サンプル (20%)
  - 実際の訓練データ: 3792サンプル
  - 実際の検証データ: 947サンプル
  - テストタイムステップ: [36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48]
    - タイムステップ 36: 4739サンプル
    - タイムステップ 37: 2539サンプル
    - タイムステップ 38: 1463サンプル
    - タイムステップ 39: 831サンプル
    - タイムステップ 40: 455サンプル
    - タイムステップ 41: 257サンプル
    - タイムステップ 42: 146サンプル
    - タイムステップ 43: 81サンプル
    - タイムステップ 44: 46サンプル
    - タイムステップ 45: 25サンプル
    - タイムステップ 46: 9サンプル
    - タイムステップ 47: 3サンプル
    - タイムステップ 48: 1サンプル
  - テストデータ総数: 10595サンプル


In [13]:
print(len(test_x), len(test_y_protein))

13 13


In [14]:
# 語彙を構築
print("定義済み特徴量名から語彙を構築中（制限なし）...")
feature_vocabs = mds.build_feature_vocabularies_from_definitions()

print(f"\n総特徴量数: {len(feature_vocabs)}")
print(f"総語彙サイズ: {sum(len(vocab) for vocab in feature_vocabs):,}")

# 各特徴量の語彙サイズを詳細表示（カテゴリカル特徴量のみ）
print("\n各特徴量の詳細:")
feature_names = ['ts', 'base_mut', 'base_pos', 'amino_mut', 'amino_pos', 'mut_type', 'protein', 'codon_pos']
for i, name in enumerate(feature_names):
    print(f"  {name}: {len(feature_vocabs[i]):,} tokens")

print(f"  count: 数値（語彙辞書なし）")

定義済み特徴量名から語彙を構築中（制限なし）...
ts: 102 tokens
base_mut: 14 tokens
base_pos: 30002 tokens
amino_mut: 486 tokens
amino_pos: 30002 tokens
mut_type: 4 tokens
protein: 38 tokens
codon_pos: 6 tokens
注意: count特徴量は数値として直接使用されます（語彙辞書なし）

総特徴量数: 8
総語彙サイズ: 60,654

各特徴量の詳細:
  ts: 102 tokens
  base_mut: 14 tokens
  base_pos: 30,002 tokens
  amino_mut: 486 tokens
  amino_pos: 30,002 tokens
  mut_type: 4 tokens
  protein: 38 tokens
  codon_pos: 6 tokens
  count: 数値（語彙辞書なし）


In [71]:
# MutationTransformerの引数を確認
print("Checking MutationTransformer parameters...")
try:
    help(mt.MutationTransformer.__init__)
except:
    print("Could not get help for MutationTransformer")

# 試しに最小限の引数でインスタンス化
try:
    test_model = mt.MutationTransformer(
        feature_vocabs=feature_vocabs,
        d_model=256,
        nhead=8,
        num_layers=4
    )
    print("Test model creation successful with basic parameters")
    print(f"実際のd_model: {test_model.actual_d_model}")
    print(f"実際のnhead: {test_model.actual_nhead}")
    print(f"実際のnum_layers: {test_model.actual_num_layers}")
except Exception as e:
    print(f"Test model creation failed: {e}")

Checking MutationTransformer parameters...
Help on function __init__ in module module.mutation_transformer2:

__init__(self, feature_vocabs, d_model=256, nhead=8, num_layers=6, num_classes=36, max_seq_length=100)
    Initialize internal Module state, shared by both nn.Module and ScriptModule.

Test model creation successful with basic parameters
実際のd_model: 256
実際のnhead: 8
実際のnum_layers: 4


In [75]:
# MutationTransformerの引数を確認
print("Checking MutationTransformer parameters...")
try:
    help(mt.MutationTransformer.__init__)
except:
    print("Could not get help for MutationTransformer")

# 試しに最小限の引数でインスタンス化
try:
    test_model = mit.MutationITransformer(
        feature_vocabs=feature_vocabs,
        d_model=512,
        nhead=16,
        num_layers=8
    )
    print("Test model creation successful with basic parameters")
    print(f"実際のd_model: {test_model.actual_d_model}")
    print(f"実際のnhead: {test_model.actual_nhead}")
    print(f"実際のnum_layers: {test_model.actual_num_layers}")
except Exception as e:
    print(f"Test model creation failed: {e}")

Checking MutationTransformer parameters...
Help on function __init__ in module module.mutation_transformer2:

__init__(self, feature_vocabs, d_model=256, nhead=8, num_layers=6, num_classes=36, max_seq_length=100)
    Initialize internal Module state, shared by both nn.Module and ScriptModule.

Test model creation successful with basic parameters
実際のd_model: 512
実際のnhead: 16
実際のnum_layers: 8


In [None]:
#モデルのパス設定
folder_path = "../model/20250628_train1/"
# 各モデルのディレクトリパス
model_dirs = [
    os.path.join(folder_path, "20250629_003550"),
    os.path.join(folder_path, "20250629_064245"),
    os.path.join(folder_path, "20250629_134000"),
    os.path.join(folder_path, "20250629_190450"),
    os.path.join(folder_path, "20250630_013248")
]

In [33]:
def safe_load_model_from_path(model_path, model_dir):
    """モデルを安全にロードする関数"""
    import json
    
    # configファイルの読み込み
    config_path = os.path.join(model_dir, "config.json")
    if os.path.exists(config_path):
        with open(config_path, 'r') as f:
            config = json.load(f)
    else:
        raise FileNotFoundError(f"Config file not found: {config_path}")
    
    # モデルアーキテクチャの設定を取得
    if 'model_config' in config:
        arch_config = config['model_config']
    elif 'model_architecture' in config:
        arch_config = config['model_architecture']
    else:
        arch_config = config
    
    # MutationTransformerの引数を実際の定義に合わせて設定
    model_args = {
        'feature_vocabs': feature_vocabs,
        'd_model': arch_config.get('d_model', 256),
        'nhead': arch_config.get('nhead', 8),
        'num_layers': arch_config.get('num_layers', 6),  # デフォルト値を6に変更
        'num_classes': arch_config.get('num_classes', 36),
        'max_seq_length': arch_config.get('max_seq_length', 100)
    }
    
    # feature_maskが設定にある場合は追加
    if 'feature_mask' in arch_config:
        model_args['feature_mask'] = arch_config['feature_mask']
    
    print(f"Creating model with args: {model_args}")
    
    # モデルを初期化
    try:
        model = mt.MutationTransformer(**model_args)
    except TypeError as e:
        print(f"Error: {e}")
        print("Trying with minimal required parameters...")
        
        # 最小限の引数で再試行
        minimal_args = {
            'feature_vocabs': feature_vocabs,
            'd_model': arch_config.get('d_model', 256),
            'nhead': arch_config.get('nhead', 8),
            'num_layers': arch_config.get('num_layers', 6)
        }
        model = mt.MutationTransformer(**minimal_args)
    
    # 学習済みパラメータをロード
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    
    model.to(device)
    model.eval()
    
    return model, config

In [32]:
# 最初のモデルのconfig.jsonを確認
import json

first_model_dir = model_dirs[0]
config_path = os.path.join(first_model_dir, "config.json")

if os.path.exists(config_path):
    with open(config_path, 'r') as f:
        config = json.load(f)
    
    print("Config file contents:")
    print(json.dumps(config, indent=2))
else:
    print(f"Config file not found: {config_path}")
    
    # ディレクトリの内容を確認
    if os.path.exists(first_model_dir):
        print(f"\nContents of {first_model_dir}:")
        for file in os.listdir(first_model_dir):
            print(f"  {file}")

Config file contents:
{
  "dataset_config": {
    "strains": [
      "B.1.1.7"
    ],
    "nmax": 1000000000,
    "nmax_per_strain": 1000000000000000,
    "test_start": 36,
    "ylen": 1,
    "val_ratio": 0.2,
    "feature_idx": 6
  },
  "model_config": {
    "d_model": 256,
    "nhead": 8,
    "num_layers": 4,
    "num_classes": 36,
    "max_seq_length": 128,
    "model_parameters": 4691572
  },
  "training_config": {
    "num_epochs": 30,
    "batch_size": 32,
    "learning_rate": 0.0001,
    "weight_decay": 1e-05,
    "optimizer": "AdamW",
    "scheduler": "ReduceLROnPlateau",
    "scheduler_params": {
      "mode": "min",
      "factor": 0.5,
      "patience": 2
    }
  },
  "data_statistics": {
    "train_samples": 689313,
    "val_samples": 171685,
    "feature_vocab_sizes": [
      102,
      14,
      30002,
      486,
      30002,
      4,
      38,
      6
    ],
    "total_vocab_size": 60654,
    "class_names": [
      "E",
      "M",
      "N",
      "ORF10",
      "ORF3a",

In [38]:
# 全モデルをロードする（エラーハンドリングを強化）
models = {}
model_configs = {}

for i, model_dir in enumerate(model_dirs):
    print(f"\n=== Loading Model {i+1}/5 ===")
    print(f"Directory: {model_dir}")
    
    # ディレクトリの存在確認
    if not os.path.exists(model_dir):
        print(f"❌ Directory not found: {model_dir}")
        continue
    
    # best_model.pthファイルを探す
    model_path = os.path.join(model_dir, "best_model.pth")
    
    if not os.path.exists(model_path):
        print(f"Model file not found: {model_path}")
        # 他の可能なファイル名を確認
        possible_files = ["model.pth", "checkpoint.pth", "final_model.pth"]
        for filename in possible_files:
            alt_path = os.path.join(model_dir, filename)
            if os.path.exists(alt_path):
                model_path = alt_path
                print(f"Found alternative model file: {model_path}")
                break
        else:
            print(f"❌ No model file found in {model_dir}")
            # ディレクトリの内容を表示
            files = os.listdir(model_dir)
            print(f"   Available files: {files}")
            continue
    
    try:
        # モデルをロード
        model, config = safe_load_model_from_path(model_path, model_dir)
        
        # モデル名を生成
        model_name = f"model_{i+1}"
        models[model_name] = model
        model_configs[model_name] = config
        
        print(f"✅ Successfully loaded {model_name}")
        
        # 基本情報を表示
        if 'model_config' in config:
            arch_config = config['model_config']
            print(f"   Architecture: d_model={arch_config['d_model']}, nhead={arch_config['nhead']}, num_layers={arch_config['num_layers']}")
            print(f"   Parameters: {arch_config['model_parameters']:,}")
            print(f"   Best Val Accuracy: {config['training_results']['best_val_accuracy']:.4f}")
        else:
            print(f"   Config keys: {list(config.keys())}")
        
    except Exception as e:
        print(f"❌ Failed to load model from {model_dir}: {e}")
        import traceback
        traceback.print_exc()

print(f"\n=== Summary ===")
print(f"Successfully loaded {len(models)} out of {len(model_dirs)} models")
for model_name in models.keys():
    print(f"  - {model_name}")


=== Loading Model 1/5 ===
Directory: ../model/20250628_train1/20250629_003550
Creating model with args: {'feature_vocabs': [{'<PAD>': 0, '<UNK>': 1, 'ts_1': 2, 'ts_10': 3, 'ts_100': 4, 'ts_11': 5, 'ts_12': 6, 'ts_13': 7, 'ts_14': 8, 'ts_15': 9, 'ts_16': 10, 'ts_17': 11, 'ts_18': 12, 'ts_19': 13, 'ts_2': 14, 'ts_20': 15, 'ts_21': 16, 'ts_22': 17, 'ts_23': 18, 'ts_24': 19, 'ts_25': 20, 'ts_26': 21, 'ts_27': 22, 'ts_28': 23, 'ts_29': 24, 'ts_3': 25, 'ts_30': 26, 'ts_31': 27, 'ts_32': 28, 'ts_33': 29, 'ts_34': 30, 'ts_35': 31, 'ts_36': 32, 'ts_37': 33, 'ts_38': 34, 'ts_39': 35, 'ts_4': 36, 'ts_40': 37, 'ts_41': 38, 'ts_42': 39, 'ts_43': 40, 'ts_44': 41, 'ts_45': 42, 'ts_46': 43, 'ts_47': 44, 'ts_48': 45, 'ts_49': 46, 'ts_5': 47, 'ts_50': 48, 'ts_51': 49, 'ts_52': 50, 'ts_53': 51, 'ts_54': 52, 'ts_55': 53, 'ts_56': 54, 'ts_57': 55, 'ts_58': 56, 'ts_59': 57, 'ts_6': 58, 'ts_60': 59, 'ts_61': 60, 'ts_62': 61, 'ts_63': 62, 'ts_64': 63, 'ts_65': 64, 'ts_66': 65, 'ts_67': 66, 'ts_68': 67, 'ts_6

In [39]:
# 現在の状況を確認
print("現在ロード済みのモデル:")
print(f"models辞書: {list(models.keys()) if 'models' in globals() else '未定義'}")
print(f"model_configs辞書: {list(model_configs.keys()) if 'model_configs' in globals() else '未定義'}")

# modelsが空の場合、再度ロードを試行
if 'models' not in globals() or len(models) == 0:
    print("モデルが未ロードです。再度ロードを実行します。")
    models = {}
    model_configs = {}

現在ロード済みのモデル:
models辞書: ['model_1', 'model_2', 'model_3', 'model_4', 'model_5']
model_configs辞書: ['model_1', 'model_2', 'model_3', 'model_4', 'model_5']


In [40]:
def count_model_parameters(model):
    """モデルのパラメータ数を詳細にカウントする"""
    total_params = 0
    trainable_params = 0
    
    print("Layer-wise parameter count:")
    print("-" * 70)
    print(f"{'Layer Name':<40} | {'Parameters':>12} | {'Shape'}")
    print("-" * 70)
    
    for name, param in model.named_parameters():
        param_count = param.numel()
        total_params += param_count
        if param.requires_grad:
            trainable_params += param_count
        
        print(f"{name:<40} | {param_count:>10,} | {tuple(param.shape)}")
    
    print("-" * 70)
    print(f"{'Total parameters:':<40} | {total_params:>10,}")
    print(f"{'Trainable parameters:':<40} | {trainable_params:>10,}")
    print(f"{'Non-trainable parameters:':<40} | {total_params - trainable_params:>10,}")
    
    return total_params, trainable_params

In [41]:
# パラメータ数の詳細確認（エラーハンドリング付き）
if len(models) > 0:
    print("=" * 80)
    print("モデルパラメータ数の詳細確認")
    print("=" * 80)

    # モデル比較用のサマリーデータを収集
    model_summary = []

    for model_name, model in models.items():
        print(f"\n【{model_name.upper()}】")
        
        # 設定情報から基本情報を表示
        config = model_configs[model_name]
        
        # 設定構造を確認
        print(f"Config keys: {list(config.keys())}")
        
        # model_configキーが存在するかチェック
        if 'model_config' in config:
            arch_config = config['model_config']
        elif 'model_architecture' in config:
            arch_config = config['model_architecture']
        else:
            # 直接configから取得を試行
            arch_config = config
            print("Warning: Using config directly as arch_config")
        
        d_model = arch_config.get('d_model', 'N/A')
        nhead = arch_config.get('nhead', 'N/A')
        num_layers = arch_config.get('num_layers', 'N/A')
        
        print(f"Architecture: d_model={d_model}, nhead={nhead}, num_layers={num_layers}")
        
        # 実際のモデルからパラメータ数をカウント
        total_params, trainable_params = count_model_parameters(model)
        
        # 設定ファイルの値と比較
        config_params = arch_config.get('model_parameters', 'N/A')
        
        # training_resultsから最高精度を取得
        if 'training_results' in config:
            best_val_acc = config['training_results']['best_val_accuracy']
        else:
            best_val_acc = 'N/A'
        
        if config_params != 'N/A':
            print(f"\nConfig file reports: {config_params:,} parameters")
            if total_params != config_params:
                print(f"⚠️  Mismatch! Actual: {total_params:,}, Config: {config_params:,}")
        
        # サマリーデータに追加
        model_summary.append({
            'Model': model_name,
            'd_model': d_model,
            'nhead': nhead,
            'num_layers': num_layers,
            'total_params': total_params,
            'config_params': config_params,
            'best_val_acc': best_val_acc
        })
        
        print("\n" + "=" * 80)

    # 比較表を作成
    print("\n\n【モデル比較サマリー】")
    print("-" * 100)
    print(f"{'Model':<8} | {'d_model':<8} | {'nhead':<6} | {'layers':<7} | {'Parameters':<12} | {'Val Accuracy':<12}")
    print("-" * 100)

    for summary in model_summary:
        val_acc = f"{summary['best_val_acc']:.4f}" if isinstance(summary['best_val_acc'], float) else str(summary['best_val_acc'])
        print(f"{summary['Model']:<8} | {summary['d_model']:<8} | {summary['nhead']:<6} | "
              f"{summary['num_layers']:<7} | {summary['total_params']:>10,} | {val_acc:<12}")

    print("-" * 100)
else:
    print("❌ No models loaded. Please check the model directories and files.")

モデルパラメータ数の詳細確認

【MODEL_1】
Config keys: ['dataset_config', 'model_config', 'training_config', 'data_statistics', 'training_results', 'metadata']
Architecture: d_model=256, nhead=8, num_layers=4
Layer-wise parameter count:
----------------------------------------------------------------------
Layer Name                               |   Parameters | Shape
----------------------------------------------------------------------
pos_encoding                             |     36,864 | (128, 288)
categorical_embeddings.0.weight          |      3,264 | (102, 32)
categorical_embeddings.1.weight          |        448 | (14, 32)
categorical_embeddings.2.weight          |    960,064 | (30002, 32)
categorical_embeddings.3.weight          |     15,552 | (486, 32)
categorical_embeddings.4.weight          |    960,064 | (30002, 32)
categorical_embeddings.5.weight          |        128 | (4, 32)
categorical_embeddings.6.weight          |      1,216 | (38, 32)
categorical_embeddings.7.weight          |  

学習済みモデルの確認
----------------------------------------
Model    | d_model  | nhead  | layers  |
----------------------------------------
model_1  | 256      | 8      | 4       |
model_2  | 256      | 8      | 4       |
model_3  | 256      | 8      | 4       |     
model_4  | 256      | 8      | 4       |      
model_5  | 288      | 16     | 8       |

実際に設定したパラメータ
----------------------------------------
Model    | d_model  | nhead  | layers  |
----------------------------------------
model_1  | 256      | 8      | 4       |
model_2  | 256      | 8      | 4       |
model_3  | 512      | 8      | 4       |     
model_4  | 512      | 16     | 4       |      
model_5  | 512      | 16     | 8       |


In [54]:
# 実際のモデルアーキテクチャを詳細確認（nheadを重点的に）
if len(models) > 0:
    print("=== 実際のモデルアーキテクチャ詳細確認（nhead重点） ===")
    
    for model_name, model in models.items():
        print(f"\n【{model_name}】")
        print(f"Model type: {type(model).__name__}")
        
        # 重要な属性を確認
        if hasattr(model, 'd_model'):
            print(f"model.d_model: {model.d_model}")
        if hasattr(model, 'actual_d_model'):
            print(f"model.actual_d_model: {model.actual_d_model}")
        
        # Transformerレイヤーの詳細確認
        if hasattr(model, 'transformer'):
            transformer = model.transformer
            print(f"Transformer type: {type(transformer).__name__}")
            
            if hasattr(transformer, 'layers') and len(transformer.layers) > 0:
                first_layer = transformer.layers[0]
                print(f"First layer type: {type(first_layer).__name__}")
                
                # Multi-Head Attentionの詳細
                if hasattr(first_layer, 'self_attn'):
                    self_attn = first_layer.self_attn
                    print(f"Self-attention type: {type(self_attn).__name__}")
                    
                    # nheadの確認
                    if hasattr(self_attn, 'num_heads'):
                        print(f"✅ Actual num_heads: {self_attn.num_heads}")
                    else:
                        print("❌ num_heads attribute not found")
                    
                    # embed_dimの確認
                    if hasattr(self_attn, 'embed_dim'):
                        print(f"✅ Actual embed_dim: {self_attn.embed_dim}")
                    else:
                        print("❌ embed_dim attribute not found")
                    
                    # head_dimの確認
                    if hasattr(self_attn, 'head_dim'):
                        print(f"✅ Actual head_dim: {self_attn.head_dim}")
                    elif hasattr(self_attn, 'num_heads') and hasattr(self_attn, 'embed_dim'):
                        head_dim = self_attn.embed_dim // self_attn.num_heads
                        print(f"✅ Calculated head_dim: {head_dim}")
                    
                    # 重みの形状確認
                    if hasattr(self_attn, 'in_proj_weight'):
                        print(f"✅ in_proj_weight shape: {self_attn.in_proj_weight.shape}")
                        # in_proj_weightは [3*embed_dim, embed_dim] の形状
                        expected_dim = self_attn.embed_dim * 3
                        actual_dim = self_attn.in_proj_weight.shape[0]
                        if expected_dim == actual_dim:
                            print(f"   ✅ Weight shape consistent with embed_dim")
                        else:
                            print(f"   ❌ Weight shape mismatch: expected {expected_dim}, got {actual_dim}")
                
                print(f"✅ Total transformer layers: {len(transformer.layers)}")
            else:
                print("❌ No transformer layers found")
        else:
            print("❌ No transformer attribute found")
        
        # 設定ファイルからの値と比較
        config = model_configs[model_name]
        if 'model_config' in config:
            arch_config = config['model_config']
            config_d_model = arch_config.get('d_model', 'N/A')
            config_nhead = arch_config.get('nhead', 'N/A')
            config_layers = arch_config.get('num_layers', 'N/A')
            
            print(f"\n📄 Config file values:")
            print(f"   d_model: {config_d_model}")
            print(f"   nhead: {config_nhead}")
            print(f"   num_layers: {config_layers}")
            
            # 実際の値と比較
            if hasattr(model, 'transformer') and hasattr(model.transformer, 'layers') and len(model.transformer.layers) > 0:
                actual_nhead = model.transformer.layers[0].self_attn.num_heads
                actual_d_model = model.transformer.layers[0].self_attn.embed_dim
                actual_layers = len(model.transformer.layers)
                
                print(f"\n🔍 Comparison:")
                print(f"   d_model: Config={config_d_model}, Actual={actual_d_model} {'✅' if config_d_model == actual_d_model else '❌'}")
                print(f"   nhead: Config={config_nhead}, Actual={actual_nhead} {'✅' if config_nhead == actual_nhead else '❌'}")
                print(f"   layers: Config={config_layers}, Actual={actual_layers} {'✅' if config_layers == actual_layers else '❌'}")
        
        print("-" * 70)
else:
    print("❌ No models loaded.")

=== 実際のモデルアーキテクチャ詳細確認（nhead重点） ===

【model_1】
Model type: MutationTransformer
model.d_model: 256
model.actual_d_model: 288
Transformer type: TransformerEncoder
First layer type: TransformerEncoderLayer
Self-attention type: MultiheadAttention
✅ Actual num_heads: 8
✅ Actual embed_dim: 288
✅ Actual head_dim: 36
✅ in_proj_weight shape: torch.Size([864, 288])
   ✅ Weight shape consistent with embed_dim
✅ Total transformer layers: 4

📄 Config file values:
   d_model: 256
   nhead: 8
   num_layers: 4

🔍 Comparison:
   d_model: Config=256, Actual=288 ❌
   nhead: Config=8, Actual=8 ✅
   layers: Config=4, Actual=4 ✅
----------------------------------------------------------------------

【model_2】
Model type: MutationTransformer
model.d_model: 256
model.actual_d_model: 288
Transformer type: TransformerEncoder
First layer type: TransformerEncoderLayer
Self-attention type: MultiheadAttention
✅ Actual num_heads: 8
✅ Actual embed_dim: 288
✅ Actual head_dim: 36
✅ in_proj_weight shape: torch.Size([864,