# In [None]:
"""
FractalAutoencoder Training Script for Google Colab
"""

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
import time

# ===== Parse Parameters =====
def parse_list_param(param_str: str) -> List[int]:
    """Convert a comma-separated string such as ``"64, 32,16"`` to a list of ints.

    Whitespace around each item is ignored. Empty items (produced by an empty
    string, a trailing comma or doubled commas) are skipped rather than being
    passed to ``int`` where they would raise.

    Args:
        param_str: Comma-separated integer literals.

    Returns:
        The parsed integers in their original order; ``[]`` for blank input.

    Raises:
        ValueError: If a non-empty item is not a valid integer literal.
    """
    tokens = (token.strip() for token in param_str.split(','))
    return [int(token) for token in tokens if token]

# ===== Training Functions =====
def train_epoch(model, train_loader, optimizer, device, epoch, log_interval=10):
    """Run one training pass over ``train_loader`` and return the per-batch losses.

    Each batch is unpacked as ``(masked_images, original_images, masks, labels)``;
    the model is fed the masked images and trained with a plain MSE loss
    against the original images.

    Args:
        model: Module called as ``model(masked_images)``; put into train mode.
        train_loader: Iterable of 4-tuples as described above.
        optimizer: Optimizer stepped once per batch.
        device: Device the input and target tensors are moved to.
        epoch: Epoch number, used only in the progress-bar label.
        log_interval: The progress-bar loss is refreshed every this many batches
            and shows the mean of the most recent ``log_interval`` losses.

    Returns:
        List of float loss values, one per batch, in iteration order.
    """
    model.train()
    batch_losses = []

    progress = tqdm(train_loader, desc=f'Epoch {epoch} [Train]')
    for step, batch in enumerate(progress):
        # Masks and labels are part of every batch but are not used for the loss.
        masked_images, original_images, _masks, _labels = batch
        inputs = masked_images.to(device)
        targets = original_images.to(device)

        # Forward pass and simple MSE reconstruction loss.
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

        if step % log_interval != 0:
            continue
        # Slicing a list shorter than log_interval returns the whole list, so
        # early steps average over everything recorded so far.
        recent_loss = np.mean(batch_losses[-log_interval:])
        progress.set_postfix({'loss': f'{recent_loss:.4f}'})

    return batch_losses

def evaluate(model, test_loader, device, num_samples=None):
    """Return the mean reconstruction MSE of ``model`` over ``test_loader``.

    Each batch is unpacked as ``(masked_images, original_images, masks, labels)``.
    Per-batch MSE values are weighted by the number of samples in the batch, so
    a smaller final batch does not skew the average.

    Args:
        model: Module called as ``model(masked_images)``; put into eval mode.
        test_loader: Iterable of 4-tuples as described above.
        device: Device the input and target tensors are moved to.
        num_samples: If truthy, stop before starting a new batch once at least
            this many samples have been evaluated. ``None`` evaluates everything.

    Returns:
        The sample-weighted mean MSE as a float, or NaN when no batch was
        evaluated.
    """
    model.eval()
    weighted_loss_sum = 0.0
    samples_seen = 0

    with torch.no_grad():
        for masked_images, original_images, _masks, _labels in test_loader:
            # Count samples directly rather than relying on the loader's
            # ``batch_size`` attribute, which not every iterable provides.
            if num_samples and samples_seen >= num_samples:
                break

            masked_images = masked_images.to(device)
            original_images = original_images.to(device)

            outputs = model(masked_images)
            loss = nn.functional.mse_loss(outputs, original_images)

            batch_count = original_images.size(0)
            weighted_loss_sum += loss.item() * batch_count
            samples_seen += batch_count

    if samples_seen == 0:
        # Nothing was evaluated; report NaN without a numpy RuntimeWarning.
        return float('nan')
    return weighted_loss_sum / samples_seen

def visualize_results(model, dataloader, device, num_samples=4, epoch=0):
    """Plot masked input, model reconstruction and original image side by side.

    Takes the first batch from ``dataloader``, runs the first ``num_samples``
    masked images through ``model`` and draws one row per sample with three
    columns: masked input, reconstruction, original.

    Args:
        model: Module called as ``model(masked_images)``; put into eval mode
            (and left in eval mode on return).
        dataloader: Iterable yielding ``(masked_images, original_images, masks,
            labels)`` batches. Its ``dataset`` must expose ``effective_channels``,
            ``effective_shape`` and ``flatten``.
        device: Device the masked images are moved to for inference.
        num_samples: Maximum number of rows to draw (capped at the batch size).
        epoch: Epoch number shown in the figure title.

    Returns:
        The matplotlib figure that was drawn.
    """
    model.eval()

    # Only the first batch of the loader is visualised.
    masked_images, original_images, masks, labels = next(iter(dataloader))
    num_samples = min(num_samples, len(masked_images))

    # Inference on the selected samples; results are moved back to the CPU.
    with torch.no_grad():
        masked_images_device = masked_images[:num_samples].to(device)
        reconstructed = model(masked_images_device).cpu()

    # Dataset metadata that decides how the tensors are laid out for display.
    dataset = dataloader.dataset
    is_grayscale = dataset.effective_channels == 1
    image_shape = dataset.effective_shape

    # Start from the raw batch slices so every *_vis name is bound even when
    # the dataset is not flattened (previously that path raised
    # UnboundLocalError in the plotting loop).
    # NOTE(review): non-flattened batches are assumed to already be in a layout
    # imshow accepts -- confirm against the dataset implementation.
    masked_vis = masked_images[:num_samples]
    original_vis = original_images[:num_samples]
    reconstructed_vis = reconstructed
    masks_vis = masks[:num_samples]

    # Flattened samples are reshaped back into image form.
    if dataset.flatten:
        if is_grayscale:
            h, w = image_shape
            batch_shape = (num_samples, h, w)
        else:
            h, w, c = image_shape
            batch_shape = (num_samples, h, w, c)
        masked_vis = masked_vis.reshape(batch_shape)
        original_vis = original_vis.reshape(batch_shape)
        reconstructed_vis = reconstructed_vis.reshape(batch_shape)
        masks_vis = masks_vis.reshape(batch_shape)
        if not is_grayscale:
            # The mask is per pixel, so only the first channel is kept
            # (all channels carry the same mask).
            masks_vis = masks_vis[:, :, :, 0]

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        num_samples, 3, figsize=(9, 3*num_samples), squeeze=False
    )

    for i in range(num_samples):
        # Masked input; the title shows 1 - mean(mask) as a percentage.
        ax = axes[i, 0]
        if is_grayscale:
            ax.imshow(masked_vis[i], cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(masked_vis[i], vmin=0, vmax=1)
        ax.set_title(f'Masked ({(1-masks_vis[i].mean()):.1%})')
        ax.axis('off')

        # Reconstruction; colour output is clipped to [0, 1] before display.
        ax = axes[i, 1]
        if is_grayscale:
            ax.imshow(reconstructed_vis[i], cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(np.clip(reconstructed_vis[i], 0, 1))
        # Plain per-sample MSE between original and reconstruction.
        mse = ((original_vis[i] - reconstructed_vis[i])**2).mean()
        ax.set_title(f'Predicted (MSE: {mse:.4f})')
        ax.axis('off')

        # Original image with its label.
        ax = axes[i, 2]
        if is_grayscale:
            ax.imshow(original_vis[i], cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(original_vis[i], vmin=0, vmax=1)
        label = labels[i]
        if torch.is_tensor(label) and label.numel() == 1:
            # Show the plain value instead of the ``tensor(...)`` repr.
            label = label.item()
        ax.set_title(f'Original (Label: {label})')
        ax.axis('off')

    plt.suptitle(f'Epoch {epoch} Results', fontsize=14)
    plt.tight_layout()
    plt.show()

    return fig

# ===== Main Training Loop =====
def main():
    """Train a FractalAutoencoder, report progress each epoch and plot the history.

    All hyperparameters (``dataset_name``, ``batch_size``, ``mask_ratio``,
    ``grayscale``, ``max_depth``, ``num_iterations``, ``gate_momentum``,
    ``use_ffn``, ``attention_heads``, ``head_dims``, ``proto_dims``,
    ``proto_nums``, ``optimizer_type``, ``learning_rate``, ``weight_decay``,
    ``scheduler_type``, ``scheduler_factor``, ``scheduler_patience``,
    ``epochs``, ``log_interval``, ``test_samples``, ``vis_samples``,
    ``save_checkpoint``, ``checkpoint_path``) as well as ``device``,
    ``create_dataloader``, ``FractalConfig`` and ``FractalAutoencoder`` are
    read from module scope; they are not defined in this function.

    Returns:
        Tuple ``(model, history)`` where ``history`` is a dict with the keys
        ``'train_loss'``, ``'test_loss'`` and ``'epoch_times'``, each holding
        one value per epoch.
    """

    # Build the data loaders.
    print("Loading dataset...")
    train_loader, test_loader, input_dim = create_dataloader(
        dataset_name=dataset_name,
        batch_size=batch_size,
        mask_ratio=mask_ratio,
        flatten=True,
        normalize=True,
        grayscale=grayscale,
        num_workers=2,
        pin_memory=True
    )
    print(f"Dataset: {dataset_name}, Input dim: {input_dim}")
    print(f"Train samples: {len(train_loader.dataset)}, Test samples: {len(test_loader.dataset)}")

    # Model configuration.
    config = FractalConfig(
        max_depth=max_depth,
        input_dim=input_dim,
        output_dim=input_dim,  # reconstruction task, so output matches input
        num_iterations=num_iterations,
        gate_momentum=gate_momentum,
        use_ffn=use_ffn,
        attention_heads=attention_heads,
        head_dims=head_dims,
        proto_dims=proto_dims,
        proto_nums=proto_nums
    )

    # Create the model.
    print("\nCreating model...")
    model = FractalAutoencoder(config).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # Optimizer; any value other than "Adam"/"AdamW" falls back to SGD.
    if optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer_type == "AdamW":
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    else:  # SGD
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=0.9)

    # Learning-rate scheduler; stays None for unrecognised scheduler_type values.
    scheduler = None
    if scheduler_type == "StepLR":
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=scheduler_factor)
    elif scheduler_type == "CosineAnnealingLR":
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    elif scheduler_type == "ReduceLROnPlateau":
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=scheduler_factor, patience=scheduler_patience)

    # Training history.
    history = {'train_loss': [], 'test_loss': [], 'epoch_times': []}
    best_test_loss = float('inf')

    # Training loop.
    print("\nStarting training...")
    for epoch in range(1, epochs + 1):
        start_time = time.time()

        # Train; the reported value is the mean of the last log_interval batches.
        train_losses = train_epoch(model, train_loader, optimizer, device, epoch, log_interval)
        train_loss = np.mean(train_losses[-log_interval:])

        # Test.
        test_loss = evaluate(model, test_loader, device, num_samples=test_samples)

        # Record (epoch time covers training and evaluation, not visualisation).
        epoch_time = time.time() - start_time
        history['train_loss'].append(train_loss)
        history['test_loss'].append(test_loss)
        history['epoch_times'].append(epoch_time)

        # Report.
        print(f"\nEpoch {epoch}/{epochs}")
        print(f"  Train Loss (recent): {train_loss:.6f}")
        print(f"  Test Loss: {test_loss:.6f}")
        print(f"  Time: {epoch_time:.1f}s")

        # Visualise.
        print("\nVisualizing results...")
        visualize_results(model, test_loader, device, num_samples=vis_samples, epoch=epoch)

        # Track the best test loss regardless of whether checkpoints are
        # written, so the final summary is meaningful with save_checkpoint off.
        is_best = test_loss < best_test_loss
        if is_best:
            best_test_loss = test_loss

        # Save the best model so far.
        if save_checkpoint and is_best:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'test_loss': test_loss,
                'config': config,
            }, checkpoint_path)
            print(f"  Saved best model (test_loss: {test_loss:.6f})")

        # Scheduler update; ReduceLROnPlateau needs the monitored metric.
        if scheduler is not None:
            if scheduler_type == "ReduceLROnPlateau":
                scheduler.step(test_loss)
            else:
                scheduler.step()

        print("-" * 50)

    # Final summary.
    print("\n" + "="*50)
    print("Training Complete!")
    print(f"Best Test Loss: {best_test_loss:.6f}")
    print(f"Total Time: {sum(history['epoch_times']):.1f}s")

    # Plot the loss curves and the per-epoch wall time.
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss (recent)')
    plt.plot(history['test_loss'], label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(history['epoch_times'])
    plt.xlabel('Epoch')
    plt.ylabel('Time (s)')
    plt.title('Epoch Time')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    return model, history
