In [1]:
# %%
# 必要なライブラリのインポート
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import importlib
from datetime import datetime
from torch.utils.data import DataLoader
from torch import nn, optim
from sklearn.metrics import classification_report

# カスタムモジュールのインポート
import module.input_mutation_path as imp
import module.get_feature as gfea
import module.mutation_transformer3 as mt
import module.make_dataset as mds
import module.evaluation2 as ev
import module.save2 as save

# Compute device: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Number of ts_name classes: 100
Number of base_mut_name classes: 12
Number of base_pos_name classes: 30000
Number of amino_mut_name classes: 484
Number of amino_pos_name classes: 30001
Number of mutation_type classes: 2
Number of protein classes: 36
Number of codon_pos_name classes: 4
使用フォント: DejaVu Sans
Using device: cuda


In [2]:
# Re-import the project modules so edits on disk take effect without a kernel
# restart (development only). Modules are reloaded in the same order as before;
# the last reload is left as a bare expression so the cell still displays it.
for _module in (imp, gfea, mt, mds, ev):
    importlib.reload(_module)
importlib.reload(save)


Number of ts_name classes: 100
Number of base_mut_name classes: 12
Number of base_pos_name classes: 30000
Number of amino_mut_name classes: 484
Number of amino_pos_name classes: 30001
Number of mutation_type classes: 2
Number of protein classes: 36
Number of codon_pos_name classes: 4
使用フォント: DejaVu Sans


<module 'module.save2' from '/mnt/ssd1/aiba/gmp/module/save2.py'>

In [3]:
# Experiment settings and hyperparameters
# =============================================================================

# Output directory for this run: <folder_name>/<YYYYmmdd_HHMMSS>
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
folder_name = "../model/20250704_train3/"
save_dir = os.path.join(folder_name, current_time)
os.makedirs(save_dir, exist_ok=True)

# Model hyperparameters
model_config = {
    'num_epochs': 50,
    'batch_size': 64,
    'd_model': 256,
    'nhead': 8,
    'num_layers': 4,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'auto_adjust': True  # parameter auto-adjustment feature
}

# Feature mask: one flag per input feature (in this order) selecting which are used
feature_mask = [
    True,   # ts (time step)
    True,   # base_mut (nucleotide mutation)
    True,   # base_pos (nucleotide position)
    True,   # amino_mut (amino-acid mutation)
    True,   # amino_pos (amino-acid position)
    True,   # mut_type (mutation type)
    True,   # protein (protein)
    True,   # codon_pos (codon position)
    True    # count (count)
]

# Data split settings
data_config = {
    'test_start': 36,
    'ylen': 1,
    'val_ratio': 0.2,
    'feature_idx': 6,  # index of the protein feature (position 6 in feature_mask above)
    'nmax': 100000000,
    'nmax_per_strain': 1000000
}

# Dataset settings
dataset_config = {
    'strains': ['B.1.1.7'],  # alternative: ['B.1.1.7','P.1','BA.2','BA.1.1','BA.1','B.1.617.2','B.1.351','B.1.1.529']
    'usher_dir': '../usher_output/',
    'bunpu_csv': "table_heatmap/250621/table_set/table_set.csv",
    'codon_csv': 'meta_data/codon_mutation4.csv',
    'cache_dir': '../cache'  # directory for cached feature data
}

print(f"実験設定完了 - 保存先: {save_dir}")
print(f"対象変異株: {dataset_config['strains']}")
print(f"モデル設定: d_model={model_config['d_model']}, nhead={model_config['nhead']}, num_layers={model_config['num_layers']}")


def report_feature_cache(cache_dir):
    """Ensure ``cache_dir`` exists and print a summary of the cached feature files.

    Files named ``feature_data_cache_*.pkl`` are listed in sorted order with
    their size in MB and modification time, followed by the total size.

    Args:
        cache_dir: directory that holds (or will hold) the feature-data cache.

    Returns:
        Total size of the listed cache files in MB; 0.0 when the directory was
        just created or contains no cache files.
    """
    # Check existence *before* creating the directory. Checking afterwards
    # (as this cell used to) makes the "directory created" branch unreachable.
    existed = os.path.isdir(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    if not existed:
        print(f"\nキャッシュディレクトリを作成: {cache_dir}")
        return 0.0

    cache_files = sorted(
        f for f in os.listdir(cache_dir)
        if f.startswith('feature_data_cache_') and f.endswith('.pkl')
    )
    if not cache_files:
        print("\n既存のキャッシュファイルはありません")
        return 0.0

    print(f"\n既存のキャッシュファイル ({len(cache_files)}個):")
    total_size_mb = 0.0
    for cache_file in cache_files:
        cache_path = os.path.join(cache_dir, cache_file)
        cache_size = os.path.getsize(cache_path) / (1024 * 1024)  # bytes -> MB
        total_size_mb += cache_size
        mtime_str = datetime.fromtimestamp(os.path.getmtime(cache_path)).strftime("%Y-%m-%d %H:%M:%S")
        print(f"  {cache_file}: {cache_size:.1f}MB (作成: {mtime_str})")
    print(f"総キャッシュサイズ: {total_size_mb:.1f}MB")
    return total_size_mb


# Initialise the cache directory and summarise any existing cache files
cache_dir = dataset_config['cache_dir']
total_cache_size = report_feature_cache(cache_dir)

実験設定完了 - 保存先: ../model/20250704_train3/20250705_150906
対象変異株: ['B.1.1.7']
モデル設定: d_model=256, nhead=8, num_layers=4

既存のキャッシュファイル (7個):
  feature_data_cache_1b9b084e95a8.pkl: 2.6MB (作成: 2025-07-01 10:49:46)
  feature_data_cache_474f8d2035ea.pkl: 0.3MB (作成: 2025-07-01 12:16:32)
  feature_data_cache_7cbb9f5510f2.pkl: 1028.2MB (作成: 2025-07-03 14:45:27)
  feature_data_cache_898e7951326d.pkl: 257.1MB (作成: 2025-07-03 14:05:58)
  feature_data_cache_8f453deb6f9d.pkl: 128.5MB (作成: 2025-07-03 13:59:49)
  feature_data_cache_b2e8f59c6714.pkl: 1610.1MB (作成: 2025-06-30 14:51:25)
  feature_data_cache_dc0299c69dc3.pkl: 514.0MB (作成: 2025-07-03 14:18:31)
総キャッシュサイズ: 3540.7MB


In [4]:
def keep_maximal_paths(paths):
    """Keep only maximal paths: those not strictly contained in another path.

    Containment is judged on the *set* of elements of each path, so element
    order and repetition are ignored: ``['b', 'a']`` is dropped when
    ``['a', 'b', 'c']`` is also present. Paths with equal element sets do not
    eliminate each other, so exact duplicates are all kept (deduplicate first
    if that is not wanted).

    NOTE(review): if "contained" is meant as "is a prefix of" (ordered
    root-to-leaf paths), this set-based test is looser — confirm the intent.

    Args:
        paths: iterable of paths; each path is an iterable of hashable items.

    Returns:
        List of the original path objects that are maximal, in input order.
    """
    path_sets = [(frozenset(path), path) for path in paths]

    # Inverted index: element -> indices of the paths whose set contains it.
    # A strict superset of S must contain every element of S, so only the
    # paths listed under S's rarest element can be supersets. This avoids the
    # all-pairs O(n^2) scan, which is infeasible for millions of paths.
    paths_with = {}
    for idx, (elements, _) in enumerate(path_sets):
        for element in elements:
            paths_with.setdefault(element, []).append(idx)

    every_idx = range(len(path_sets))
    maximal_paths = []
    for elements, original_path in path_sets:
        if elements:
            rarest = min(elements, key=lambda e: len(paths_with[e]))
            candidates = paths_with[rarest]
        else:
            # The empty set is a strict subset of every non-empty set.
            candidates = every_idx
        # Set `<` is "proper subset" (subset and not equal), i.e. exactly
        # `path_set != other_set and path_set.issubset(other_set)`.
        if not any(elements < path_sets[j][0] for j in candidates):
            maximal_paths.append(original_path)

    return maximal_paths

In [None]:
# Per-lineage path statistics: for each lineage, count the unique paths and how
# many remain after dropping paths whose element set is contained in another's.
strains = ['B.1.1.7', 'P.1', 'BA.2', 'BA.1.1', 'BA.1', 'B.1.617.2', 'B.1.351', 'B.1.1.529']

for strain in strains:
    # Load the paths of this single lineage only.
    names, lengths, base_HGVS_paths = imp.input(
        [strain],
        dataset_config['usher_dir'],
        nmax=data_config['nmax'],
        nmax_per_strain=data_config['nmax_per_strain'],
    )
    # Remove exact duplicate paths (first occurrence wins, order preserved).
    set_list1 = list(map(list, dict.fromkeys(map(tuple, base_HGVS_paths))))
    set_list2 = keep_maximal_paths(set_list1)
    print(f"\n{strain} のデータ:")
    print(f"重複除去後の長さ: {len(set_list1)}")
    print(f"最大パス保持後の長さ: {len(set_list2)}")


[INFO] import: ../usher_output/B.1.1.7/0/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/1/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/2/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/3/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/4/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/5/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/6/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/7/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/8/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/9/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/10/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/11/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/12/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/13/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/14/mu

In [None]:
# Pool every lineage in `strains` (defined in the per-lineage cell above) and
# compute the same two path counts over the combined data.
names, lengths, base_HGVS_paths = imp.input(
    strains,
    dataset_config['usher_dir'],
    nmax=data_config['nmax'],
    nmax_per_strain=data_config['nmax_per_strain'],
)
# Remove exact duplicate paths (first occurrence wins, order preserved).
set_list1 = list(map(list, dict.fromkeys(map(tuple, base_HGVS_paths))))
# Then drop paths whose set of elements is strictly contained in another path's.
set_list2 = keep_maximal_paths(set_list1)
print(f"重複除去後の長さ: {len(set_list1)}")
print(f"最大パス保持後の長さ: {len(set_list2)}")

[INFO] import: ../usher_output/B.1.1.7/0/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/1/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/2/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/3/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/4/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/5/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/6/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/7/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/8/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/9/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/10/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/11/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/12/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/13/mutation_paths_B.1.1.7.tsv
[INFO] import: ../usher_output/B.1.1.7/14/mu