# 学習セッション

In [None]:
import sys
import os
# Make the project root importable so that `from src.my_app import ...` works.
# The original note names this root "Craft_respect".
# NOTE(review): this assumes the kernel's working directory is the notebook's
# folder, one level below the project root — confirm when running elsewhere.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    # Prepend rather than append so the project's `src` package takes
    # precedence over any other importable module/package named `src`
    # (e.g. one found via the notebook directory at sys.path[0]).
    sys.path.insert(0, project_root)
from src.my_app import UNet, PreTrainDataset, create_optimized_dataloader
import torch
from PIL import Image
import torchvision.transforms as transforms
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os
import json
import numpy as np
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
import multiprocessing as mp
def crop_labels_to_match(labels_to_crop, target_tensor):
    """Center-crop ``labels_to_crop`` to the spatial size of ``target_tensor``.

    The last two dimensions of both arrays are treated as (height, width);
    every leading dimension of ``labels_to_crop`` (batch, channel, ...) is
    kept as-is. In this notebook it aligns the ground-truth masks with the
    model's predictions before the loss is computed.

    When a size difference is odd, the extra row/column is removed from the
    bottom/right edge (the offset is floor-divided).

    Args:
        labels_to_crop: Array/tensor with at least two dims whose trailing two
            dims are at least as large as those of ``target_tensor``.
        target_tensor: Array/tensor whose trailing two dims give the output
            size.

    Returns:
        A slice of ``labels_to_crop`` whose trailing two dims equal those of
        ``target_tensor``.

    Raises:
        ValueError: If either input has fewer than two dimensions, or if
            ``labels_to_crop`` is spatially smaller than ``target_tensor``
            (negative offsets would otherwise produce a silently wrong slice).
    """
    if len(labels_to_crop.shape) < 2 or len(target_tensor.shape) < 2:
        raise ValueError("crop_labels_to_match expects inputs with at least 2 dimensions")
    target_h, target_w = target_tensor.shape[-2:]
    source_h, source_w = labels_to_crop.shape[-2:]
    if source_h < target_h or source_w < target_w:
        raise ValueError(
            f"cannot crop labels of size {source_h}x{source_w} "
            f"to a larger target of size {target_h}x{target_w}"
        )
    delta_h = (source_h - target_h) // 2
    delta_w = (source_w - target_w) // 2
    return labels_to_crop[..., delta_h:delta_h + target_h, delta_w:delta_w + target_w]

In [None]:
# Convert PIL images to tensors inside the datasets.
transform = transforms.Compose([transforms.ToTensor()])

# Document IDs used as the test data.
# NOTE(review): presumably test_mode=True selects these IDs and
# test_mode=False excludes them — confirm in PreTrainDataset.
test_doc_id_list = [
    '100241706',
    '100249371',
    '100249376',
    '100249416',
    '100249476',
    '100249537',
    '200003076',
    '200003803',
    '200003967',
    '200004107',
]

# Resolve the device with the same rule the model uses instead of
# hard-coding 'cuda', so the datasets and the model agree (and CPU-only
# machines are not handed a CUDA device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): if PreTrainDataset creates CUDA tensors in __getitem__,
# DataLoader worker processes (num_workers > 0) may fail to initialise CUDA
# under the default 'fork' start method — confirm how
# create_optimized_dataloader handles this.

DATA_DIR = '../../kuzushiji-recognition/synthetic_images'

# Arguments shared by the train and test datasets; only test_mode differs,
# so keeping them in one place prevents the two splits from drifting apart.
dataset_kwargs = dict(
    input_path=f'{DATA_DIR}/input_images/',
    json_path=f'{DATA_DIR}/gt_json.json',
    test_doc_id_list=test_doc_id_list,
    device=device,
    precompute_gt=False,  # ground-truth precomputation is disabled
    transform=transform,
    target_width=300,
)
train_dataset = PreTrainDataset(test_mode=False, **dataset_kwargs)
test_dataset = PreTrainDataset(test_mode=True, **dataset_kwargs)

# Build the optimized DataLoaders with at most 4 workers (fewer on small machines).
num_loader_workers = min(mp.cpu_count(), 4)
train_dl = create_optimized_dataloader(train_dataset, batch_size=1, num_workers=num_loader_workers)
test_dl = create_optimized_dataloader(test_dataset, batch_size=1, num_workers=num_loader_workers)

# --- Model, loss function and optimizer ---
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")
print(f"Using device: {device}")

# UNet(3, 4): presumably 3 input channels and 4 output channels — confirm
# against the UNet signature in src.my_app.
model = UNet(3, 4)
model = model.to(device)

# The targets are continuous values (a regression problem), hence MSE.
criterion = nn.MSELoss()

LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Checkpoint setup ---
checkpoint_dir = "../.checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth")
best_model_path = os.path.join(checkpoint_dir, "best_model.pth")

num_epochs = 100  # total epochs, counting any completed before a resume

# Training state, overwritten below when a checkpoint exists. The loss
# histories are initialised HERE, before loading: resetting them after the
# load would discard the restored history and the next save would persist
# only the post-resume epochs.
best_test_loss = float('inf')
start_epoch = 0
train_loss_history = []
test_loss_history = []

# Resume from the latest checkpoint if there is one.
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written on a GPU be resumed on a
    # CPU-only machine (and vice versa).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    best_test_loss = checkpoint['best_test_loss']
    train_loss_history = checkpoint['train_loss_history']
    test_loss_history = checkpoint['test_loss_history']
    print(f"チェックポイントを読み込みました（エポック {start_epoch}）")


def _run_epoch(loader, desc, train):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    With ``train=True`` the model is put in training mode and the optimizer
    steps after every batch; with ``train=False`` the model is put in eval
    mode and gradient tracking is disabled to save memory.
    """
    model.train(train)  # train(False) is equivalent to eval()
    total_loss = 0.0
    bar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(train):
        for imgs, masks in bar:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            # The masks may be spatially larger than the prediction, so
            # center-crop them to the prediction's size before the loss.
            loss = criterion(preds, crop_labels_to_match(masks, preds))
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_loss = loss.item()  # one device->host sync per batch
            total_loss += batch_loss
            bar.set_postfix(loss=batch_loss)
    return total_loss / len(loader)


print("学習を開始します...")
for epoch in range(start_epoch, num_epochs):
    epoch_label = f"Epoch {epoch+1}/{num_epochs}"

    # --- Training phase ---
    avg_train_loss = _run_epoch(train_dl, f"{epoch_label} [Train]", train=True)
    train_loss_history.append(avg_train_loss)

    # --- Evaluation phase ---
    avg_test_loss = _run_epoch(test_dl, f"{epoch_label} [Test]", train=False)
    test_loss_history.append(avg_test_loss)

    print(f"{epoch_label} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")

    # Update the best model BEFORE writing the latest checkpoint, so the
    # best_test_loss stored there already accounts for this epoch. Otherwise
    # a resumed run would compare against a stale best and could overwrite
    # best_model.pth with a worse model.
    if avg_test_loss < best_test_loss:
        best_test_loss = avg_test_loss
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'test_loss': avg_test_loss,
        }, best_model_path)
        print(f"新しいベストモデルを保存しました（Test Loss: {avg_test_loss:.4f}）")

    # Write the latest checkpoint to a temporary file and atomically swap it
    # in, so an interruption mid-save cannot corrupt the resume point.
    tmp_checkpoint_path = checkpoint_path + ".tmp"
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss_history': train_loss_history,
        'test_loss_history': test_loss_history,
        'best_test_loss': best_test_loss,
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)

print("学習が完了しました。")

# --- Plot the per-epoch train/test loss curves ---
plt.figure(figsize=(10, 5))
_curves = (
    (train_loss_history, "Train Loss"),
    (test_loss_history, "Test Loss"),
)
for _values, _label in _curves:
    plt.plot(_values, label=_label)
plt.title("Loss Trend")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()

NameError: name 'PreTrainDataset' is not defined