In [None]:
# 1. 必要なライブラリのインストール
# Google Colab環境で実行することを想定しています
!pip install stable-baselines3[extra] sb3-contrib gymnasium pygame shimmy

Collecting sb3-contrib
  Downloading sb3_contrib-2.7.1-py3-none-any.whl.metadata (4.1 kB)
Collecting shimmy
  Downloading Shimmy-2.0.0-py3-none-any.whl.metadata (3.5 kB)
Collecting stable-baselines3[extra]
  Downloading stable_baselines3-2.7.1-py3-none-any.whl.metadata (4.8 kB)
Downloading sb3_contrib-2.7.1-py3-none-any.whl (93 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.2/93.2 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Shimmy-2.0.0-py3-none-any.whl (30 kB)
Downloading stable_baselines3-2.7.1-py3-none-any.whl (188 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.0/188.0 kB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: shimmy, stable-baselines3, sb3-contrib
Successfully installed sb3-contrib-2.7.1 shimmy-2.0.0 stable-baselines3-2.7.1
Collecting jedi>=0.16 (from ipython==7.34.0->google.colab)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading jedi-0.19.2-py2.py3-

In [2]:
# 2. Mount Google Drive and put the project's src/ directory on the import path.
import sys
import os
from google.colab import drive

# Mount Google Drive.
drive.mount('/content/drive')

# Location of the source code.
# NOTE: upload the whole project to 'My Drive/tcg/src' on Google Drive
# (Drive's "My Drive" folder is mounted as "MyDrive").
project_path = '/content/drive/MyDrive/tcg/src'

# Only modify sys.path and the working directory when the directory really
# exists, so a wrong path is reported here instead of being left on sys.path.
if os.path.isdir(project_path):
    if project_path not in sys.path:
        sys.path.append(project_path)
    os.chdir(project_path)
    print(f"Current working directory: {os.getcwd()}")
else:
    print(f"Error: Path '{project_path}' not found. Please check your Google Drive folder structure.")

KeyboardInterrupt: 

In [2]:
# 2.5. Path check for running against the current workspace instead of Google Drive.
import os
import sys

current_dir = os.getcwd()
print(f"Current directory: {current_dir}")

# When Colab is connected to a workspace, the workspace root is usually the
# current directory, so look for a `src` folder right here.
candidate_src = os.path.join(current_dir, 'src')

if not os.path.exists(candidate_src):
    print("Warning: 'src' directory not found in current directory.")
    print("If you are using Google Drive, please skip this cell and use the Drive mount cell above.")
else:
    project_path = candidate_src
    print(f"Found src directory at: {project_path}")

    already_on_path = project_path in sys.path
    if not already_on_path:
        sys.path.append(project_path)
        print("Added src to sys.path")

    # Smoke-test that the package is importable now.
    try:
        from tcg.counter_gym_env import CounterTCGEnv
    except ImportError as e:
        print(f"Error: Failed to import tcg module. {e}")
    else:
        print("Success: tcg module imported successfully!")

Current directory: /content
If you are using Google Drive, please skip this cell and use the Drive mount cell above.


: 

In [None]:
# Clone the repository into a fixed location and make its src/ importable.
# An absolute target keeps the clone independent of whatever directory an
# earlier cell chdir'd into (the Drive cell switches into the Drive project),
# and the existence check makes this cell safe to re-run.
import os
import subprocess
import sys

REPO_URL = "https://github.com/tacotacos64/tcg.git"
REPO_DIR = "/content/tcg"

if not os.path.isdir(REPO_DIR):
    # check=True stops "Run All" here if the clone fails.
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

# Assumes the repository keeps the `tcg` package under src/ (the same layout
# as the Drive copy, which is imported from .../tcg/src) -- TODO confirm.
repo_src = os.path.join(REPO_DIR, "src")
if os.path.isdir(repo_src) and repo_src not in sys.path:
    sys.path.append(repo_src)
    print(f"Added {repo_src} to sys.path")

: 

In [None]:
# 3. 学習の実行
import torch
import multiprocessing
import os
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import VecNormalize, SubprocVecEnv

# Project-local imports. These require the project's src/ directory to be on
# sys.path (see the path-setup cells above).
try:
    from tcg.counter_gym_env import CounterTCGEnv
    from tcg.players.player_kishida_mlppo import MLPlayer
    from tcg.players.player_kishida_counter import ONCT
    from tcg.players.strategy_right_flank_aggressive import RightFlankAggressive
    from tcg.players.strategy_secure_home_aggressive import SecureHomeAggressive
    from tcg.players.anti_ml_player import AntiMLPlayer
    from tcg.players.strategy_economist_aggressive import EconomistAggressive
    from tcg.players.strategy_right_heavy_aggressive import RightHeavyAggressive
    from tcg.players.strategy_right_flank import RightFlankExpansionist
    from tcg.players.strategy_aggressive_center import AggressiveCenterStrategy
    from tcg.players.strategy_economist import DefensiveEconomist
    from tcg.players.strategy_secure_home import SecureHomeExpansionist
except ImportError as e:
    print(f"Import Error: {e}")
    print("Make sure the 'src' directory is in sys.path.")
    # Re-raise: everything below depends on these names, and continuing would
    # only fail later with a less helpful NameError (the `CounterTCGEnv`
    # annotation on mask_fn is evaluated when the function is defined).
    raise

def mask_fn(env: CounterTCGEnv) -> list[bool]:
    """Return the action mask reported by ``env.action_masks()``.

    Passed to ``ActionMasker`` so MaskablePPO receives the environment's mask
    of allowed actions at every step.
    """
    allowed = env.action_masks()
    return allowed

def linear_schedule(initial_value: float):
    """Build a learning-rate schedule that scales ``initial_value`` linearly.

    The returned callable maps ``progress_remaining`` to
    ``initial_value * progress_remaining``. Stable-Baselines3 calls it with a
    progress value that goes from 1.0 at the start of training to 0.0 at the end,
    so the rate decays linearly to zero.
    """
    def scheduled_value(progress_remaining: float):
        return initial_value * progress_remaining

    return scheduled_value

def _resolve_log_dir() -> str:
    """Return (and create if needed) the directory for TensorBoard logs and checkpoints.

    Uses Google Drive when it is mounted at /content/drive/MyDrive, otherwise a
    local ``./logs_counter_colab/`` directory.
    """
    if os.path.exists('/content/drive/MyDrive'):
        log_dir = "/content/drive/MyDrive/tcg/logs_counter_colab/"
    else:
        log_dir = "./logs_counter_colab/"
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


def _save_agent(model, vec_env, path: str) -> None:
    """Save ``model`` to ``<path>.zip`` and its VecNormalize stats to ``<path>_vecnormalize.pkl``.

    The observation/reward normalization statistics live in the VecNormalize
    wrapper, not in the model zip. Without them, the saved policy cannot be fed
    observations normalized the way it saw them during training.
    """
    model.save(path)
    vec_env.save(f"{path}_vecnormalize.pkl")


def train_counter_colab(total_timesteps: int = 50_000_000,
                        checkpoint_every: int = 1_000_000):
    """Train a MaskablePPO agent on CounterTCGEnv against a fixed opponent pool.

    Builds one masked ``CounterTCGEnv`` per CPU core in subprocesses, wraps them
    in ``VecNormalize``, and trains ``MaskablePPO`` (on CUDA when available).
    Checkpoints (model + normalization statistics) are written roughly every
    ``checkpoint_every`` environment timesteps.

    If training raises or is interrupted (including the notebook's stop
    button, which raises ``KeyboardInterrupt``), the current model is saved as
    ``counter_ml_emergency_save`` instead of as the final model, and the
    function returns normally.

    Args:
        total_timesteps: Total number of environment timesteps to train for.
        checkpoint_every: Approximate number of timesteps between checkpoints.

    Returns:
        The trained model (partially trained if the run was interrupted).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # One environment worker per CPU core.
    n_cpu = multiprocessing.cpu_count()
    print(f"Available CPUs: {n_cpu}")

    log_dir = _resolve_log_dir()
    print(f"Log directory: {log_dir}")

    # Opponent pool. MLPlayer and ONCT are listed twice, presumably so they are
    # sampled more often -- TODO confirm how CounterTCGEnv picks opponents.
    others = [
        RightFlankAggressive,
        SecureHomeAggressive,
        AntiMLPlayer,
        EconomistAggressive,
        RightHeavyAggressive,
        RightFlankExpansionist,
        AggressiveCenterStrategy,
        DefensiveEconomist,
        SecureHomeExpansionist
    ]
    opponent_classes = [MLPlayer] * 2 + [ONCT] * 2 + others
    print(f"Training against: {[p.__name__ for p in opponent_classes]}")

    def make_env():
        # Each worker builds its own action-masked environment.
        single_env = CounterTCGEnv(opponent_classes)
        return ActionMasker(single_env, mask_fn)

    env = make_vec_env(make_env, n_envs=n_cpu, vec_env_cls=SubprocVecEnv)
    # Normalization (original note: rewards are large, so this is required).
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.)

    try:
        model = MaskablePPO(
            "MlpPolicy",
            env,
            verbose=1,
            tensorboard_log=log_dir,
            learning_rate=linear_schedule(0.0001),  # reduced from 0.0003 for stability
            n_steps=8192,  # rollout length per environment
            batch_size=512,
            n_epochs=5,
            gamma=0.9,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.02,
            max_grad_norm=0.3,  # tighter gradient clipping for stability
            policy_kwargs=dict(net_arch=[512, 512, 512]),
            device=device,
        )

        # CheckpointCallback counts env.step() calls, and each call advances
        # all n_cpu environments at once. Divide so checkpoints land every
        # `checkpoint_every` timesteps rather than every n_cpu * that.
        checkpoint_callback = CheckpointCallback(
            save_freq=max(checkpoint_every // n_cpu, 1),
            save_path=log_dir,
            name_prefix="counter_ml_model",
            save_vecnormalize=True,
        )

        print(f"Starting training ({total_timesteps:,} steps)...")
        try:
            model.learn(total_timesteps=total_timesteps, callback=checkpoint_callback)
        except (Exception, KeyboardInterrupt) as e:
            # KeyboardInterrupt is not an Exception subclass, so it must be
            # listed explicitly for an interrupted run to get a backup.
            print(f"Training interrupted or failed: {e!r}")
            emergency_path = os.path.join(log_dir, "counter_ml_emergency_save")
            _save_agent(model, env, emergency_path)
            print(f"Emergency backup saved to {emergency_path}.zip")
            return model

        final_path = os.path.join(log_dir, "counter_ml_final_colab")
        _save_agent(model, env, final_path)
        print(f"Training complete. Model saved to {final_path}.zip")
        return model
    finally:
        # Shut down the SubprocVecEnv worker processes so they do not linger
        # in the notebook kernel after training ends.
        try:
            env.close()
        except (EOFError, OSError) as close_error:
            # Workers may already be gone (e.g. after an interrupt).
            print(f"Warning: could not close environments cleanly: {close_error!r}")


if __name__ == "__main__":
    train_counter_colab()