In [None]:
import os
import json
import logging
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
from datetime import datetime
from pathlib import Path
import pickle
from typing import Dict, Any, List
from tqdm import tqdm

from horgues3.dataset import HorguesDataset
from horgues3.models import HorguesModel, WeightedPlackettLuceLoss


# Logging setup: root logger at INFO with timestamped "time - level - message" lines,
# plus a module logger used by every later cell.
# NOTE: logging.basicConfig is a no-op if the root logger already has handlers
# (e.g. when this cell is re-run in the same kernel).
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [None]:
config = {
    # Data settings (dates are YYYYMMDD strings)
    'train_start_date': '20231201',
    'train_end_date': '20231231',
    'val_start_date': '20240101',
    'val_end_date': '20240131',
    'num_horses': 18,
    'horse_history_length': 18,
    'history_days': 365,
    'exclude_hours_before_race': 2,
    'cache_dir': 'cache',
    'use_cache': True,

    # Model settings
    'd_token': 768,
    'num_bins': 8,
    'binning_temperature': 0.8,
    'binning_init_range': 2.5,
    'ft_n_layers': 2,
    'ft_n_heads': 12,
    'ft_d_ffn': 1536,
    'seq_n_layers': 3,
    'seq_n_heads': 12,
    'seq_d_ffn': 2304,
    'race_n_layers': 4,
    'race_n_heads': 12,
    'race_d_ffn': 3072,
    'dropout': 0.5,

    # Loss settings
    'loss_temperature': 1.0,
    'loss_top_k': None,
    'weight_decay': 0.8,  # passed to WeightedPlackettLuceLoss, NOT the optimizer's weight decay

    # Training settings
    'batch_size': 8,
    'num_epochs': 100,
    'learning_rate': 1e-4,
    'weight_decay_optimizer': 1e-2,  # AdamW weight decay
    'warmup_epochs': 10,  # epochs of linear LR warmup before cosine decay
    'early_stopping_patience': 15,
    'save_every_n_epochs': 10,

    # System settings
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'seed': 42,
}

# Seed torch (all devices) and NumPy so DataLoader shuffling, dropout and weight
# initialization are reproducible across a Restart & Run All.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

In [None]:
def _move_batch_to_device(batch, device):
    """Copy every tensor consumed by the model and loss from ``batch`` to ``device``.

    ``batch`` is expected to be a dict with keys ``'x_num'``, ``'x_cat'``
    (dicts of tensors), ``'sequence_data'`` (dict of per-sequence dicts with
    ``'x_num'``, ``'x_cat'`` and ``'mask'``), ``'mask'`` and ``'rankings'``.

    Returns:
        Tuple ``(x_num, x_cat, sequence_data, mask, rankings)`` in the order the
        model and criterion below consume them.
    """
    x_num = {name: tensor.to(device) for name, tensor in batch['x_num'].items()}
    x_cat = {name: tensor.to(device) for name, tensor in batch['x_cat'].items()}
    sequence_data = {
        seq_name: {
            'x_num': {k: v.to(device) for k, v in seq_data['x_num'].items()},
            'x_cat': {k: v.to(device) for k, v in seq_data['x_cat'].items()},
            'mask': seq_data['mask'].to(device),
        }
        for seq_name, seq_data in batch['sequence_data'].items()
    }
    mask = batch['mask'].to(device)
    rankings = batch['rankings'].to(device)
    return x_num, x_cat, sequence_data, mask, rankings


def _mean_loss(total_loss, num_batches, phase):
    """Average accumulated per-batch losses; fail loudly if no batch was seen."""
    if num_batches == 0:
        raise ValueError(f"{phase} dataloader yielded no batches; cannot compute a mean loss")
    return total_loss / num_batches


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch (forward, backward, optimizer step per batch).

    Args:
        model: called as ``model(x_num, x_cat, sequence_data, mask)``; returns scores.
        dataloader: iterable of batch dicts (see ``_move_batch_to_device``).
        criterion: called as ``criterion(scores, rankings, mask)``; returns a scalar loss tensor.
        optimizer: torch optimizer over ``model``'s parameters.
        device: target device for all batch tensors.

    Returns:
        Mean of the per-batch loss values (float).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(dataloader, desc="Training")
    for batch in progress_bar:
        optimizer.zero_grad()
        x_num, x_cat, sequence_data, mask, rankings = _move_batch_to_device(batch, device)

        # Forward pass
        scores = model(x_num, x_cat, sequence_data, mask)
        loss = criterion(scores, rankings, mask)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # .item() forces a device->host sync; do it once per batch and reuse the value.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return _mean_loss(total_loss, num_batches, "Training")


def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` without gradient tracking.

    Args:
        model: called as ``model(x_num, x_cat, sequence_data, mask)``; returns scores.
        dataloader: iterable of batch dicts (see ``_move_batch_to_device``).
        criterion: called as ``criterion(scores, rankings, mask)``; returns a scalar loss tensor.
        device: target device for all batch tensors.

    Returns:
        Mean of the per-batch loss values (float).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(dataloader, desc="Validation")
    with torch.no_grad():
        for batch in progress_bar:
            x_num, x_cat, sequence_data, mask, rankings = _move_batch_to_device(batch, device)

            # Forward pass only
            scores = model(x_num, x_cat, sequence_data, mask)
            loss = criterion(scores, rankings, mask)

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            progress_bar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return _mean_loss(total_loss, num_batches, "Validation")

In [None]:
# Every run writes into its own timestamped folder (outputs/training_YYYYmmdd_HHMMSS)
# so earlier runs are never overwritten.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
output_dir = Path("outputs") / f"training_{timestamp}"
output_dir.mkdir(exist_ok=True, parents=True)

logger.info(f"Output directory: {output_dir}")
logger.info(f"Device: {config['device']}")

In [None]:
logger.info("Preparing training dataset...")

# Constructor arguments shared by both splits. Only the date window differs, and
# validation additionally reuses the preprocessing params taken from the training split.
_shared_dataset_kwargs = {
    'num_horses': config['num_horses'],
    'horse_history_length': config['horse_history_length'],
    'history_days': config['history_days'],
    'exclude_hours_before_race': config['exclude_hours_before_race'],
    'cache_dir': config['cache_dir'],
    'use_cache': config['use_cache'],
}

train_dataset = HorguesDataset(
    start_date=config['train_start_date'],
    end_date=config['train_end_date'],
    **_shared_dataset_kwargs,
)
preprocessing_params = train_dataset.get_preprocessing_params()

logger.info("Preparing validation dataset...")
val_dataset = HorguesDataset(
    start_date=config['val_start_date'],
    end_date=config['val_end_date'],
    preprocessing_params=preprocessing_params,
    **_shared_dataset_kwargs,
)

logger.info(f"Training samples: {len(train_dataset)}")
logger.info(f"Validation samples: {len(val_dataset)}")

# Data loaders: only the training split is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

In [None]:
# Numerical feature names are the keys of the scaler params from the training split.
numerical_features = list(preprocessing_params['scaler'].keys())

# NOTE(review): each categorical feature is mapped to the LARGEST encoded index of its
# encoder, not to len(encoder). Confirm HorguesModel expects that convention
# (e.g. indices starting at 1 with 0 reserved) rather than a vocabulary size.
categorical_features = {
    name: max(encoding.values())
    for name, encoding in preprocessing_params['encoder'].items()
}

# Pull one collated batch only to discover which history sequences the dataset emits.
sample_batch = next(iter(train_dataloader))
sequence_names = list(sample_batch['sequence_data'].keys())
feature_aliases = train_dataset.get_feature_aliases()

for label, value in (
    ("Numerical features", numerical_features),
    ("Categorical features", categorical_features),
    ("Sequence names", sequence_names),
    ("Feature aliases", feature_aliases),
):
    logger.info(f"{label}: {value}")

In [None]:
# Architecture hyperparameters forwarded verbatim from `config`; each config key has
# the same name as the corresponding HorguesModel keyword argument.
_model_hparam_names = (
    'd_token', 'num_bins', 'binning_temperature', 'binning_init_range',
    'ft_n_layers', 'ft_n_heads', 'ft_d_ffn',
    'seq_n_layers', 'seq_n_heads', 'seq_d_ffn',
    'race_n_layers', 'race_n_heads', 'race_d_ffn',
    'dropout',
)

model = HorguesModel(
    sequence_names=sequence_names,
    feature_aliases=feature_aliases,
    numerical_features=numerical_features,
    categorical_features=categorical_features,
    **{name: config[name] for name in _model_hparam_names},
).to(config['device'])

# Parameter counts (all parameters vs. those with requires_grad=True).
_all_params = list(model.parameters())
total_params = sum(param.numel() for param in _all_params)
trainable_params = sum(param.numel() for param in _all_params if param.requires_grad)
logger.info(f"Total parameters: {total_params:,}")
logger.info(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Loss: WeightedPlackettLuceLoss with reduction='mean'. Its `weight_decay` comes from
# config['weight_decay'] and is unrelated to the optimizer's config['weight_decay_optimizer'].
criterion = WeightedPlackettLuceLoss(
    reduction='mean',
    temperature=config['loss_temperature'],
    top_k=config['loss_top_k'],
    weight_decay=config['weight_decay'],
)

# Optimizer: AdamW over all model parameters.
optimizer = AdamW(
    model.parameters(),
    weight_decay=config['weight_decay_optimizer'],
    lr=config['learning_rate'],
)

# Scheduler: cosine decay to 1% of the base LR over the post-warmup epochs only.
# The linear warmup itself is applied by hand in the training loop.
_cosine_epochs = config['num_epochs'] - config['warmup_epochs']
_min_lr = config['learning_rate'] * 0.01
scheduler = CosineAnnealingLR(optimizer, T_max=_cosine_epochs, eta_min=_min_lr)

In [None]:
history = {
    'train_loss': [],
    'val_loss': [],
    'epochs': [],
    'learning_rates': [],
}

best_val_loss = float('inf')
patience_counter = 0
# Defaults so the final checkpoint and summary logs below still work when
# num_epochs == 0 (otherwise `epoch`, `train_loss`, `val_loss` would be unbound).
completed_epochs = 0
train_loss = float('nan')
val_loss = float('nan')


def _build_checkpoint(epoch_num, epoch_train_loss, epoch_val_loss):
    """Return the payload saved for every checkpoint (best, periodic and final).

    Args:
        epoch_num: 1-based number of the epoch the checkpoint belongs to.
        epoch_train_loss: mean training loss of that epoch.
        epoch_val_loss: mean validation loss of that epoch.

    Reads the notebook-level `model`, `optimizer`, `scheduler`, `config`,
    `preprocessing_params` and feature metadata at call time, so the state dicts
    reflect the moment of saving. All three checkpoint kinds share this one
    layout, so they can be reloaded the same way.
    """
    return {
        'epoch': epoch_num,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'train_loss': epoch_train_loss,
        'val_loss': epoch_val_loss,
        'config': config,
        'preprocessing_params': preprocessing_params,
        'feature_info': {
            'numerical_features': numerical_features,
            'categorical_features': categorical_features,
            'sequence_names': sequence_names,
            'feature_aliases': feature_aliases,
        },
    }


logger.info("Starting training...")

for epoch in range(config['num_epochs']):
    epoch_num = epoch + 1
    logger.info(f"Epoch {epoch_num}/{config['num_epochs']}")

    # Linear warmup: epoch k (0-based) uses learning_rate * (k + 1) / warmup_epochs,
    # reaching the full rate on the last warmup epoch. Afterwards the cosine scheduler
    # takes over.
    # NOTE(review): CosineAnnealingLR's recursive update starts from the lr currently in
    # the param group, so this relies on the last warmup epoch leaving it at learning_rate.
    if epoch < config['warmup_epochs']:
        lr = config['learning_rate'] * epoch_num / config['warmup_epochs']
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        scheduler.step()

    current_lr = optimizer.param_groups[0]['lr']
    logger.info(f"Learning rate: {current_lr:.6f}")

    train_loss = train_epoch(model, train_dataloader, criterion, optimizer, config['device'])
    val_loss = validate_epoch(model, val_dataloader, criterion, config['device'])
    logger.info(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    completed_epochs = epoch_num

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['epochs'].append(epoch_num)
    history['learning_rates'].append(current_lr)

    # Keep the checkpoint with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(_build_checkpoint(epoch_num, train_loss, val_loss), output_dir / 'best_model.pth')
        logger.info(f"New best model saved (val_loss: {val_loss:.4f})")
    else:
        patience_counter += 1

    # Periodic checkpoint every `save_every_n_epochs` epochs.
    if epoch_num % config['save_every_n_epochs'] == 0:
        torch.save(
            _build_checkpoint(epoch_num, train_loss, val_loss),
            output_dir / f'checkpoint_epoch_{epoch_num}.pth',
        )
        logger.info(f"Checkpoint saved at epoch {epoch_num}")

    # Rewrite the full training history every epoch so partial runs keep it.
    with open(output_dir / 'history.json', 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2, ensure_ascii=False)

    # Early stopping after `early_stopping_patience` epochs without improvement.
    if patience_counter >= config['early_stopping_patience']:
        logger.info(f"Early stopping triggered after {epoch_num} epochs (patience: {config['early_stopping_patience']})")
        break

# Final model: weights as of the last completed epoch. It reuses the same cached
# `preprocessing_params` as the other checkpoints instead of re-querying the dataset.
torch.save(_build_checkpoint(completed_epochs, train_loss, val_loss), output_dir / 'final_model.pth')

logger.info("Training completed!")
logger.info(f"Best validation loss: {best_val_loss:.4f}")
logger.info(f"Final train loss: {train_loss:.4f}")
logger.info(f"Final validation loss: {val_loss:.4f}")
logger.info(f"Results saved to: {output_dir}")