In [8]:
%load_ext autoreload
%autoreload 2

import torch
import logging
from torch.utils.data import DataLoader
from horgues3.models import Horgues3Model, PlackettLuceLoss
from horgues3.dataset import Horgues3Dataset

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
dataset = Horgues3Dataset(max_horses=18).fetch_data().prepare_races()

train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0  # Windowsでは0に設定
)

# サンプルデータの確認
sample = dataset[0]
logger.info("Sample data structure:")
for key, value in sample.items():
    if isinstance(value, torch.Tensor):
        logger.info(f"{key}: {value.shape} - {value.dtype}")
    else:
        logger.info(f"{key}: {value}")

INFO:horgues3.dataset:Fetched 300888 records from the database.
INFO:horgues3.dataset:Prepared 24745 races with 2+ horses each.


Sample data structure:
x_num: torch.Size([18, 1]) - torch.float32
x_cat: torch.Size([18, 0]) - torch.int64
rankings: torch.Size([18]) - torch.int64
mask: torch.Size([18]) - torch.bool
race_id: 2020010145120101


In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# モデル、損失関数、オプティマイザーの初期化
model = Horgues3Model().to(device)
criterion = PlackettLuceLoss(temperature=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# 学習パラメータ
num_epochs = 10
log_interval = 10

logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
logger.info(f"Training samples: {len(dataset):,}")
logger.info(f"Batch size: {train_loader.batch_size}")
logger.info(f"Batches per epoch: {len(train_loader)}")

INFO:__main__:Using device: cuda
INFO:__main__:Model parameters: 1,502,785
INFO:__main__:Model parameters: 1,502,785
INFO:__main__:Training samples: 24,745
INFO:__main__:Batch size: 32
INFO:__main__:Batches per epoch: 774


In [None]:
# 学習ループ
model.train()
running_loss = 0.0  # ログ用の損失
total_batches = 0

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(train_loader):
        # データをデバイスに移動
        x_num = batch['x_num'].to(device)
        x_cat = batch['x_cat'].to(device)
        rankings = batch['rankings'].to(device)
        mask = batch['mask'].to(device)

        # 勾配をゼロに初期化
        optimizer.zero_grad()

        # 順伝播
        scores = model(x_num, x_cat, mask)

        # 損失計算
        loss = criterion(scores, rankings, mask)

        # 逆伝播
        loss.backward()

        # パラメータ更新
        optimizer.step()

        # 損失の記録
        epoch_loss += loss.item()
        running_loss += loss.item()
        num_batches += 1
        total_batches += 1

        # ログ出力
        if (batch_idx + 1) % log_interval == 0:
            avg_loss = running_loss / log_interval
            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Avg Loss: {avg_loss:.4f}')
            running_loss = 0.0

    # エポック終了時の平均損失
    avg_epoch_loss = epoch_loss / num_batches
    logger.info(f'Epoch {epoch + 1}/{num_epochs} completed - Average Loss: {avg_epoch_loss:.4f}')
    logger.info('-' * 50)

logger.info("Training completed!")

INFO:__main__:tensor([[    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
            -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
            -inf,     nan,     nan,     nan,     nan,     nan,    -inf,    -inf,
            -inf,    -inf],
        [    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,    -inf,
            -inf,    -inf],
        [    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [    nan,     nan,     nan, 

KeyboardInterrupt: 