In [None]:
import os

# OMP_NUM_THREADS has to be exported *before* torch is imported: the OpenMP
# runtime that torch loads typically reads the variable once, at load time, so
# assigning it after `import torch` does not limit threads in this process.
os.environ["OMP_NUM_THREADS"] = '1'

import torch  # noqa: E402
import torch.optim as optim  # noqa: E402
from torch_geometric.loader import DataLoader  # noqa: E402

from utils.data_aug_model import Generator, Discriminator, train_gan, evaluate_generator  # noqa: E402
from utils.dataset import ScenarioGraphDataset  # noqa: E402
from utils.dataset_utils import NODE_TYPE_MAP  # noqa: E402

# ------------------ Main program example ------------------
# Trains one GAN (generator + discriminator) per scene listed in
# `scene_datasets`, using an 80/20 train/validation split of each scene's
# ScenarioGraphDataset.
if __name__ == "__main__":
    # Scene name -> list of dataset root directories that make up that scene.
    # "motor" and "total" are unions of the single-road scenes.
    scene_datasets = {
        "secondary_road": ["dataset/driving-scene-graph/secondary-road"],
        "ebike": ['dataset/driving-scene-graph/ebike'],
        "main_secondary": ['dataset/driving-scene-graph/main-secondary'],
        "motor": [
            "dataset/driving-scene-graph/secondary-road",
            'dataset/driving-scene-graph/main-secondary'
        ],
        "total": [
            "dataset/driving-scene-graph/secondary-road",
            'dataset/driving-scene-graph/main-secondary',
            'dataset/driving-scene-graph/ebike'
        ]
    }

    # Dataset parameters
    window_size = 30  # window length handed to the dataset, models and train_gan
    step_size = 1  # step handed to the dataset (also part of the cache file name)
    # 19 fixed features plus one slot per entry of NODE_TYPE_MAP
    # (presumably a one-hot node-type encoding -- TODO confirm against the dataset).
    node_feature_dim = 19 + len(NODE_TYPE_MAP)
    num_classes = 3  # number of label classes

    # Training parameters
    hidden_dim = 64
    num_epochs = 100
    batch_size = 48
    num_workers = 8  # number of DataLoader worker processes
    # Fixed seed for the train/validation split. train_gan is given a checkpoint
    # path and the run can be restarted; without a fixed seed every restart would
    # draw a different split, so samples already used for training could end up
    # in the validation set and validation scores would not be comparable.
    split_seed = 42
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Train scene by scene
    for scene_name, root_dirs in scene_datasets.items():
        print(f"\n=== 开始训练场景: {scene_name} ===")

        # Per-scene file locations: dataset cache, final generator weights and
        # the training checkpoint.
        cache_path = f"dataset/cache/{scene_name}_{window_size}_{step_size}_dataset_pre_cache.pkl"
        generator_model_path = f"model/data_aug/{scene_name}_{window_size}_{step_size}_generator_model.pth"
        checkpoint_path = f"model/checkpoint/{scene_name}_generator_model_checkpoint.pth"

        # Make sure the target directories exist so that writing the cache,
        # checkpoint or model file cannot fail on a fresh checkout.
        for output_path in (cache_path, generator_model_path, checkpoint_path):
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # NOTE(review): the dataset receives `device` while the loaders below use
        # worker processes; confirm the dataset does not hand out CUDA tensors
        # from inside workers.
        dataset = ScenarioGraphDataset(root_dirs, window_size, step_size, device, cache_path)

        # 80% of the samples for training, the remainder for validation.
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        # A fresh, identically seeded generator per scene keeps each scene's
        # split independent of the order in which scenes are processed.
        split_generator = torch.Generator().manual_seed(split_seed)
        train_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size], generator=split_generator
        )

        # Data loaders: shuffle only the training data.
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        # Build generator and discriminator (both with dropout) on the target device.
        generator = Generator(node_feature_dim, hidden_dim, window_size, num_classes, dropout_rate=0.5).to(device)
        discriminator = Discriminator(node_feature_dim, hidden_dim, window_size, dropout_rate=0.5).to(device)

        # One AdamW optimizer per network.
        g_optimizer = optim.AdamW(generator.parameters(), lr=1e-3, weight_decay=1e-4)
        d_optimizer = optim.AdamW(discriminator.parameters(), lr=1e-3, weight_decay=1e-4)

        # Train the GAN. The generator weights are not saved here; per the
        # original note, saving already happens during early stopping, which is
        # why generator_model_path is passed in.
        train_gan(
            train_dataloader,
            generator,
            discriminator,
            g_optimizer,
            d_optimizer,
            num_epochs,
            device,
            window_size,
            val_dataloader,
            checkpoint_path,
            scene_name,
            num_classes=num_classes,
            generator_model_path=generator_model_path,
        )

        print(f"场景 {scene_name} 训练完成\n")


=== 开始训练场景: secondary_road ===
从缓存文件加载数据: dataset/cache/secondary_road_30_1_dataset_pre_cache.pkl
加载检查点 'model/checkpoint/secondary_road_generator_model_checkpoint.pth'，从 epoch 8 继续训练


Epoch 9/100: 100%|██████████| 43/43 [01:02<00:00,  1.46s/it, D_loss=0.698, G_loss=0.658]


Epoch 9: Validation F1 Score: 0.3242


Epoch 10/100: 100%|██████████| 43/43 [01:20<00:00,  1.88s/it, D_loss=0.693, G_loss=0.663]


Epoch 10: Validation F1 Score: 0.3458


Epoch 11/100: 100%|██████████| 43/43 [00:51<00:00,  1.20s/it, D_loss=0.696, G_loss=0.665]
