In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import logging
from horgues3.dataset import HorguesDataset
from horgues3.models import HorguesModel, PlackettLuceLoss
from horgues3.betting import calculate_betting_probabilities, format_betting_results
import numpy as np
from tqdm import tqdm
import os

# Logging configuration: module-level logger, root configured to emit INFO and above
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, max_grad_norm=1.0):
    """Train ``model`` for one epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(x_num=..., x_cat=..., mask=...)``.
        dataloader: Iterable of batch dicts with keys ``'x_num'`` and ``'x_cat'``
            (dicts of tensors) plus ``'rankings'`` and ``'mask'`` (tensors).
        criterion: Loss called as ``criterion(scores, rankings, mask)``; must
            return a scalar tensor (``.backward()`` and ``.item()`` are used).
        optimizer: Optimizer updating the model parameters.
        device: Device every batch tensor is moved to.
        max_grad_norm: Maximum total gradient norm passed to
            ``clip_grad_norm_`` (default 1.0). ``None`` disables clipping.

    Returns:
        Unweighted mean of ``loss.item()`` over batches, or 0.0 when the
        dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()

        # Move all inputs/targets to the target device
        x_num = {k: v.to(device) for k, v in batch['x_num'].items()}
        x_cat = {k: v.to(device) for k, v in batch['x_cat'].items()}
        rankings = batch['rankings'].to(device)
        mask = batch['mask'].to(device)

        # Forward pass
        scores = model(x_num=x_num, x_cat=x_cat, mask=mask)

        # Loss
        loss = criterion(scores, rankings, mask)

        # Backward pass
        loss.backward()

        # Gradient clipping (optional) before the optimizer step
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0.0

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` for one epoch without gradients.

    Args:
        model: Module called as ``model(x_num=..., x_cat=..., mask=...)``;
            switched to eval mode before iterating.
        dataloader: Iterable of batch dicts with keys ``'x_num'``, ``'x_cat'``,
            ``'rankings'`` and ``'mask'``.
        criterion: Loss called as ``criterion(scores, rankings, mask)``.
        device: Device every batch tensor is moved to.

    Returns:
        Mean of the per-batch ``loss.item()`` values, or 0 when the
        dataloader produced no batches.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            # Transfer inputs and targets to the evaluation device
            num_inputs = {name: tensor.to(device) for name, tensor in batch['x_num'].items()}
            cat_inputs = {name: tensor.to(device) for name, tensor in batch['x_cat'].items()}
            target_rankings = batch['rankings'].to(device)
            mask_t = batch['mask'].to(device)

            predicted = model(x_num=num_inputs, x_cat=cat_inputs, mask=mask_t)
            batch_losses.append(criterion(predicted, target_rankings, mask_t).item())

    if not batch_losses:
        return 0
    return sum(batch_losses) / len(batch_losses)

def evaluate_betting_probabilities(model, dataloader, device, num_samples=5, temperature=1.0):
    """Compute betting probabilities from the model's predictions and print them.

    Runs ``model`` in eval mode (no gradients) over ``dataloader`` and prints
    the formatted betting probabilities for at most ``num_samples`` races.

    Args:
        model: Module called as ``model(x_num=..., x_cat=..., mask=...)``.
        dataloader: Iterable of batch dicts with keys ``'x_num'``, ``'x_cat'``
            (dicts of tensors), ``'mask'`` (tensor) and ``'race_id'``.
        device: Device the input tensors are moved to.
        num_samples: Maximum number of races to print. Nothing is printed
            when this is <= 0.
        temperature: Forwarded to ``calculate_betting_probabilities``.

    Returns:
        None; results are printed.
    """
    model.eval()
    if num_samples <= 0:
        return

    sample_count = 0

    with torch.no_grad():
        for batch in dataloader:
            # Move inputs to the device
            x_num = {k: v.to(device) for k, v in batch['x_num'].items()}
            x_cat = {k: v.to(device) for k, v in batch['x_cat'].items()}
            mask = batch['mask'].to(device)

            # Forward pass
            scores = model(x_num=x_num, x_cat=x_cat, mask=mask)

            # Keep only the races still needed, so at most num_samples races are
            # reported even when the batch is larger than num_samples.
            # Assumes dim 0 of scores/mask and the race_id sequence index races
            # in the same order — TODO confirm against HorguesDataset collation.
            remaining = num_samples - sample_count
            race_ids = batch['race_id'][:remaining]
            scores_np = scores[:remaining].cpu().numpy()
            mask_np = mask[:remaining].cpu().numpy()

            # Betting probabilities from the predicted strengths
            probabilities = calculate_betting_probabilities(
                horse_strengths=scores_np,
                mask=mask_np,
                temperature=temperature
            )

            # Format and print the results
            results = format_betting_results(
                race_ids=race_ids,
                probabilities=probabilities,
                masks=mask_np
            )
            print(results)

            sample_count += len(race_ids)
            if sample_count >= num_samples:
                break

In [None]:
# --- Run configuration ---
SEED = 42
MAX_HORSES = 18
TRAIN_START, TRAIN_END = '20130101', '20221231'
# NOTE(review): the validation window covers only five days — confirm this is
# intentional (e.g. a quick smoke run) before relying on the validation loss.
VAL_START, VAL_END = '20230101', '20230105'
BATCH_SIZE = 32
NUM_WORKERS = 0  # set to 0 on Windows
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
PL_TEMPERATURE = 1.0
num_epochs = 10
NUM_BETTING_SAMPLES = 3  # races printed per epoch by evaluate_betting_probabilities
# Where to write the best checkpoint; None means no file is written.
CHECKPOINT_PATH = None

# Seed the global RNGs (e.g. DataLoader shuffling uses torch's default generator)
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Dataset preparation
logger.info("Preparing dataset...")

# Training data: preprocessors are fitted on the training period only
logger.info("Creating new preprocessors...")
train_dataset = HorguesDataset(max_horses=MAX_HORSES)
train_dataset.fetch(TRAIN_START, TRAIN_END).process().fit().transform().build_races()

# Validation data: reuse the training-fitted preprocessors (no fit() call here)
val_dataset = HorguesDataset(max_horses=MAX_HORSES)
val_dataset.set_preprocessors(train_dataset.get_preprocessors())
val_dataset.fetch(VAL_START, VAL_END).process().transform().build_races()

logger.info(f"Training samples: {len(train_dataset)}")
logger.info(f"Validation samples: {len(val_dataset)}")
# The epoch helpers return 0 for an empty loader, which would silently be
# reported as the best validation loss — fail fast instead.
assert len(train_dataset) > 0, "training dataset is empty"
assert len(val_dataset) > 0, "validation dataset is empty"

# Data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS
)

# Model setup
feature_configs = train_dataset.get_feature_configs()
logger.info(f"Feature configs: {feature_configs}")

model = HorguesModel(**feature_configs).to(device)

# Loss, optimizer and LR schedule (cosine annealing over the full run)
criterion = PlackettLuceLoss(temperature=PL_TEMPERATURE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Training loop
best_val_loss = float('inf')

for epoch in range(num_epochs):
    logger.info(f"Epoch {epoch+1}/{num_epochs}")

    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss = validate_epoch(model, val_loader, criterion, device)

    # One scheduler step per epoch, matching T_max=num_epochs
    scheduler.step()

    logger.info(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Track the best model; a checkpoint is written only if CHECKPOINT_PATH is set
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        if CHECKPOINT_PATH is not None:
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'epoch': epoch,
                'val_loss': val_loss,
                'feature_configs': feature_configs,
            }, CHECKPOINT_PATH)
            logger.info(f"New best model saved to {CHECKPOINT_PATH} with val_loss: {val_loss:.4f}")
        else:
            logger.info(f"New best val_loss: {val_loss:.4f} (not saved; CHECKPOINT_PATH is None)")

    # Betting-probability report after each epoch
    print(f"\n{'='*60}")
    print(f"EPOCH {epoch+1} - BETTING PROBABILITIES ANALYSIS")
    print(f"{'='*60}")
    evaluate_betting_probabilities(model, val_loader, device, num_samples=NUM_BETTING_SAMPLES)
    print(f"{'='*60}\n")

logger.info("Training completed!")