In [2]:
from pathlib import Path
import xarray as xr
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torch.nn.utils import clip_grad_norm_
import numpy as np

# Model inputs that carry a height dimension (processed as '3D' further down).
inputs_variable1 = ['U', 'V', 'T', 'Q', 'CLDLIQ', 'CLDICE', 'PMID', 'DPRES', 'Z3', 'HEIGHT']
# Model inputs without a height dimension (processed as '2D' further down).
inputs_variable2 = ['TAUX', 'TAUY', 'SHFLX', 'LHFLX']
# Prediction targets that carry a height dimension.
output_variable1 = ['SPDQ', 'SPDQC', 'SPDQI', 'SPNC', 'SPNI', 'SPDT', 'CLOUD', 'CLOUDTOP', 'QRL', 'QRS']
# Prediction targets without a height dimension.
output_variable2 = ['PRECC', 'PRECSC', 'FSNT', 'FSDS', 'FSNS', 'FLNS', 'FLNT']

data_path = Path("/home/ET/mnwong/ML/data/Qobs10_SPCAMM.000.cam.h1.0001-02-13-00800.nc")

# `ds` is read by the following cells (coordinates, sampling mask, variable
# extraction), so it must outlive this cell. Opening it inside a `with` block
# would close the file as soon as the cell finishes and leave later cells
# working on a closed dataset. Call `ds.close()` once the tensors are built.
ds = xr.open_dataset(data_path)
display(ds)
print("\nData variables:", list(ds.data_vars))
print("Coordinates:", list(ds.coords))


Data variables: ['gw', 'hyam', 'hybm', 'P0', 'hyai', 'hybi', 'date', 'datesec', 'time_bnds', 'date_written', 'time_written', 'ndbase', 'nsbase', 'nbdate', 'nbsec', 'mdt', 'ndcur', 'nscur', 'co2vmr', 'ch4vmr', 'n2ovmr', 'f11vmr', 'f12vmr', 'sol_tsi', 'nsteph', 'CLDHGH', 'CLDICE', 'CLDLIQ', 'CLDLOW', 'CLDMED', 'CLDTOT', 'CLOUD', 'CLOUDTOP', 'DPRES', 'FLNS', 'FLNT', 'FSDS', 'FSNS', 'FSNT', 'HEIGHT', 'LHFLX', 'NUMICE', 'NUMLIQ', 'PHIS', 'PMID', 'PRECC', 'PRECSC', 'PS', 'Q', 'QRL', 'QRS', 'SHFLX', 'SPDQ', 'SPDQC', 'SPDQI', 'SPDT', 'SPNC', 'SPNI', 'T', 'TAUX', 'TAUY', 'U', 'V', 'Z3']
Coordinates: ['lat', 'lon', 'lev', 'ilev', 'time']


In [6]:
# Inspect the latitude / longitude coordinate axes: range, length and a
# preview of the first and last ten values of each.
lat_coord = ds.lat
lon_coord = ds.lon

print("纬度范围:", lat_coord.min().values, "到", lat_coord.max().values)
print("经度范围:", lon_coord.min().values, "到", lon_coord.max().values)
print("纬度形状:", lat_coord.shape)
print("经度形状:", lon_coord.shape)

lat_values = lat_coord.values
print("\n纬度坐标:")
print(lat_values[:10], "...", lat_values[-10:])

lon_values = lon_coord.values
print("\n经度坐标:")
print(lon_values[:10], "...", lon_values[-10:])

纬度范围: -90.0 到 90.0
经度范围: 0.0 到 359.375
纬度形状: (384,)
经度形状: (576,)

纬度坐标:
[-90.         -89.53002611 -89.06005222 -88.59007833 -88.12010444
 -87.65013055 -87.18015666 -86.71018277 -86.24020888 -85.77023499] ... [85.77023499 86.24020888 86.71018277 87.18015666 87.65013055 88.12010444
 88.59007833 89.06005222 89.53002611 90.        ]

经度坐标:
[0.    0.625 1.25  1.875 2.5   3.125 3.75  4.375 5.    5.625] ... [353.75  354.375 355.    355.625 356.25  356.875 357.5   358.125 358.75
 359.375]


In [7]:
import math

def create_polar_sampling_mask(ds, lat_threshold=60, verbose=True):
    """
    Build a boolean (lat, lon) mask that thins out longitude points near the poles.

    Rows whose |lat| is at most `lat_threshold` keep every longitude. Rows with
    |lat| above the threshold keep `max(1, int(n_lon * cos(|lat|)))` longitudes,
    spread evenly around the latitude circle.

    The kept indices are `floor(k * n_lon / n_keep)` for k = 0..n_keep-1. The
    longitude axis is periodic, so the last grid column is a neighbour of
    column 0; an inclusive `linspace(0, n_lon - 1, n_keep)` would always pick
    both of them and leave uneven gaps across the wrap-around.

    Args:
        ds: object exposing `ds.lat.values` and `ds.lon.values` (1-D arrays,
            latitudes in degrees).
        lat_threshold: absolute latitude (degrees) above which rows are thinned.
        verbose: when True (default), print one summary line per thinned row.

    Returns:
        sampling_mask: bool ndarray of shape (len(lat), len(lon)); True marks
        grid points that are kept.
    """
    lats = ds.lat.values
    lons = ds.lon.values
    n_lons = len(lons)

    # Start from "keep everything" and overwrite only the high-latitude rows.
    sampling_mask = np.ones((len(lats), n_lons), dtype=bool)

    for i, lat in enumerate(lats):
        abs_lat = abs(lat)
        if abs_lat <= lat_threshold:
            continue

        # Fraction of longitudes to keep on this latitude circle.
        sampling_ratio = math.cos(math.radians(abs_lat))

        # Always keep at least one point (cos is ~0 at the poles).
        n_lons_to_sample = max(1, int(n_lons * sampling_ratio))

        # Evenly spaced indices on the periodic axis, in exact integer arithmetic.
        lon_indices = (np.arange(n_lons_to_sample) * n_lons) // n_lons_to_sample

        lat_mask = np.zeros(n_lons, dtype=bool)
        lat_mask[lon_indices] = True
        sampling_mask[i, :] = lat_mask

        if verbose:
            print(f"纬度 {lat:.1f}°: 采样比例 {sampling_ratio:.3f}, 采样 {n_lons_to_sample}/{n_lons} 个经度点")

    return sampling_mask

# 创建采样掩码
# Build the mask once and report how many grid points survive the thinning.
sampling_mask = create_polar_sampling_mask(ds, lat_threshold=60)
n_kept_points, n_grid_points = sampling_mask.sum(), sampling_mask.size
print(f"\n采样掩码形状: {sampling_mask.shape}")
print(f"总的有效采样点: {n_kept_points}/{n_grid_points}")
print(f"采样比例: {n_kept_points/n_grid_points:.3f}")

纬度 -90.0°: 采样比例 0.000, 采样 1/576 个经度点
纬度 -89.5°: 采样比例 0.008, 采样 4/576 个经度点
纬度 -89.1°: 采样比例 0.016, 采样 9/576 个经度点
纬度 -88.6°: 采样比例 0.025, 采样 14/576 个经度点
纬度 -88.1°: 采样比例 0.033, 采样 18/576 个经度点
纬度 -87.7°: 采样比例 0.041, 采样 23/576 个经度点
纬度 -87.2°: 采样比例 0.049, 采样 28/576 个经度点
纬度 -86.7°: 采样比例 0.057, 采样 33/576 个经度点
纬度 -86.2°: 采样比例 0.066, 采样 37/576 个经度点
纬度 -85.8°: 采样比例 0.074, 采样 42/576 个经度点
纬度 -85.3°: 采样比例 0.082, 采样 47/576 个经度点
纬度 -84.8°: 采样比例 0.090, 采样 51/576 个经度点
纬度 -84.4°: 采样比例 0.098, 采样 56/576 个经度点
纬度 -83.9°: 采样比例 0.106, 采样 61/576 个经度点
纬度 -83.4°: 采样比例 0.115, 采样 66/576 个经度点
纬度 -83.0°: 采样比例 0.123, 采样 70/576 个经度点
纬度 -82.5°: 采样比例 0.131, 采样 75/576 个经度点
纬度 -82.0°: 采样比例 0.139, 采样 80/576 个经度点
纬度 -81.5°: 采样比例 0.147, 采样 84/576 个经度点
纬度 -81.1°: 采样比例 0.155, 采样 89/576 个经度点
纬度 -80.6°: 采样比例 0.163, 采样 94/576 个经度点
纬度 -80.1°: 采样比例 0.171, 采样 98/576 个经度点
纬度 -79.7°: 采样比例 0.179, 采样 103/576 个经度点
纬度 -79.2°: 采样比例 0.188, 采样 108/576 个经度点
纬度 -78.7°: 采样比例 0.196, 采样 112/576 个经度点
纬度 -78.3°: 采样比例 0.204, 采样 117/576 个经度点
纬度 -77.8°: 

In [7]:
def process_variable_with_polar_sampling(ds, var_name, sampling_mask, variable_type='3D', height_limit=None):
    """
    Extract one variable at the masked grid points and standardize it.

    The boolean mask is applied in a single vectorized indexing operation
    (no Python loop over time steps). Rows of the result are ordered with
    time as the slowest axis and the mask's row-major point order within each
    time step.

    Args:
        ds: mapping-like dataset; `ds[var_name]` must be convertible to an
            ndarray of shape (time, height, lat, lon) for '3D' or
            (time, lat, lon) otherwise.
        var_name: name of the variable to extract.
        sampling_mask: bool ndarray of shape (lat, lon); True marks kept points.
        variable_type: '3D' for variables with a height axis; any other value
            is treated as a variable with only (time, lat, lon).
        height_limit: number of leading height levels to keep. Only applied
            when `var_name == 'HEIGHT'` and `variable_type == '3D'`.

    Returns:
        reshaped_tensor: float32 tensor of shape (time * n_points, height) for
        '3D' variables or (time * n_points, 1) otherwise. Each column is
        standardized with its own mean and (unbiased) std over all rows.
    """
    # asarray avoids duplicating the full array when the data is already an ndarray;
    # the array is only read from below.
    var_data = np.asarray(ds[var_name])

    if variable_type == '3D':
        if var_name == 'HEIGHT' and height_limit is not None:
            var_data = var_data[:, :height_limit, :, :]

        # (time, height, lat, lon) -> view as (time, lat, lon, height); the
        # 2-D boolean mask then collapses (lat, lon) -> (time, n_points, height).
        sampled = np.moveaxis(var_data, 1, -1)[:, sampling_mask]
        reshaped = sampled.reshape(-1, sampled.shape[-1])
    else:
        # (time, lat, lon) -> (time, n_points) -> one column.
        reshaped = var_data[:, sampling_mask].reshape(-1, 1)

    # Column-wise standardization; the epsilon guards against constant columns.
    reshaped_tensor = torch.tensor(reshaped).float()
    mean = reshaped_tensor.mean(dim=0, keepdim=True)
    std = reshaped_tensor.std(dim=0, keepdim=True)
    reshaped_tensor = (reshaped_tensor - mean) / (std + 1e-6)

    return reshaped_tensor

def create_coordinate_variables(ds, sampling_mask):
    """
    Build standardized latitude, longitude and time features for every sample.

    The row order matches `process_variable_with_polar_sampling`: time is the
    slowest axis, and within a time step the points follow the mask's
    row-major order.

    Args:
        ds: dataset exposing `ds.lat.values`, `ds.lon.values` and
            `ds.sizes['time']`.
        sampling_mask: bool ndarray of shape (lat, lon); True marks kept points.

    Returns:
        lat_tensor, lon_tensor, time_tensor: float32 tensors of shape
        (time * n_points, 1), each standardized with its own global mean/std.
    """
    lats = ds.lat.values
    lons = ds.lon.values

    # Full (lat, lon) coordinate grids, then keep only the masked points.
    lat_grid, lon_grid = np.meshgrid(lats, lons, indexing='ij')
    sampled_lats = lat_grid[sampling_mask]
    sampled_lons = lon_grid[sampling_mask]

    # `Dataset.sizes` is the mapping name -> length. Subscripting
    # `Dataset.dims` by name is deprecated in xarray and slated to stop working.
    n_times = ds.sizes['time']
    n_sampled_points = len(sampled_lats)

    # Same spatial points at every time step.
    lat_repeated = np.tile(sampled_lats, n_times)
    lon_repeated = np.tile(sampled_lons, n_times)

    # One time value per sample, built without a Python-level list.
    # Assumes a constant 1-hour step (3600 s) between time indices; apart from
    # the small epsilon, the step length cancels out in the standardization
    # below, so only the uniform spacing matters.
    time_coords = np.repeat(np.arange(n_times) * 3600, n_sampled_points)

    lat_tensor = torch.tensor(lat_repeated).float().unsqueeze(1)
    lon_tensor = torch.tensor(lon_repeated).float().unsqueeze(1)
    time_tensor = torch.tensor(time_coords).float().unsqueeze(1)

    # Standardize each coordinate; epsilon guards a zero std (e.g. one time step).
    lat_tensor = (lat_tensor - lat_tensor.mean()) / (lat_tensor.std() + 1e-6)
    lon_tensor = (lon_tensor - lon_tensor.mean()) / (lon_tensor.std() + 1e-6)
    time_tensor = (time_tensor - time_tensor.mean()) / (time_tensor.std() + 1e-6)

    return lat_tensor, lon_tensor, time_tensor

# 处理输入变量
def _process_variable_group(ds, var_names, sampling_mask, variable_type, height_limits=None):
    """Standardize each variable in `var_names` and return the tensors in order.

    `height_limits` optionally maps a variable name to the `height_limit`
    passed to `process_variable_with_polar_sampling`; other variables get None.
    Progress and the resulting shape are printed per variable.
    """
    height_limits = height_limits or {}
    tensors = []
    for var in var_names:
        print(f"处理变量: {var}")
        tensor = process_variable_with_polar_sampling(
            ds, var, sampling_mask, variable_type, height_limits.get(var)
        )
        tensors.append(tensor)
        print(f"  - 形状: {tensor.shape}")
    return tensors


# Number of model levels shared by all stacked 3D variables. Read from the
# dataset instead of hard-coding 30.
# NOTE(review): assumes the 3D variables live on the `lev` coordinate and that
# HEIGHT may carry extra levels that must be dropped — confirm.
n_model_levels = ds.sizes['lev']

# ---- Inputs ----
print("处理输入变量...")
tensors_3d = _process_variable_group(
    ds, inputs_variable1, sampling_mask, '3D', {'HEIGHT': n_model_levels}
)
tensors_2d = _process_variable_group(ds, inputs_variable2, sampling_mask, '2D')

# Append latitude / longitude / time as extra single-column input features.
print("添加坐标变量...")
lat_tensor, lon_tensor, time_tensor = create_coordinate_variables(ds, sampling_mask)
tensors_2d.extend([lat_tensor, lon_tensor, time_tensor])
print(f"处理变量: latitude")
print(f"  - 形状: {lat_tensor.shape}")
print(f"处理变量: longitude")
print(f"  - 形状: {lon_tensor.shape}")
print(f"处理变量: time")
print(f"  - 形状: {time_tensor.shape}")

if tensors_3d:
    # Stack per-variable (samples, levels) tensors -> (samples, variables, levels).
    # Unlike cat + reshape with a fixed level count, this fails loudly if the
    # variables disagree on the number of levels.
    X_sampled_3d = torch.stack(tensors_3d, dim=1)
    n_samples, n_vars, n_levels = X_sampled_3d.shape
else:
    X_sampled_3d = None

X_sampled_2d = torch.cat(tensors_2d, dim=1) if tensors_2d else None

print(f"\n输入数据最终形状:")
if X_sampled_3d is not None:
    print(f"3D数据形状: {X_sampled_3d.shape} (样本数, 变量数, 层数)")
if X_sampled_2d is not None:
    print(f"2D数据形状: {X_sampled_2d.shape}")

# ---- Targets ----
print("\n处理输出变量...")
tensors_3d_y = _process_variable_group(ds, output_variable1, sampling_mask, '3D')
tensors_2d_y = _process_variable_group(ds, output_variable2, sampling_mask, '2D')

if tensors_3d_y:
    Y_sampled_3d = torch.stack(tensors_3d_y, dim=1)
    n_samples, n_vars, n_levels = Y_sampled_3d.shape
else:
    Y_sampled_3d = None

Y_sampled_2d = torch.cat(tensors_2d_y, dim=1) if tensors_2d_y else None

print(f"\n输出数据最终形状:")
if Y_sampled_3d is not None:
    print(f"3D数据形状: {Y_sampled_3d.shape} (样本数, 变量数, 层数)")
if Y_sampled_2d is not None:
    print(f"2D数据形状: {Y_sampled_2d.shape}")

# ---- Sampling summary ----
# Grid size before thinning, taken from the dataset and the mask rather than
# hard-coded as 27 * 384 * 576.
n_original_points = ds.sizes['time'] * sampling_mask.size
print(f"\n采样前后数据点数比较:")
print(f"原始数据点数: {n_original_points} (时间 × 纬度 × 经度)")
total_samples = X_sampled_3d.shape[0] if X_sampled_3d is not None else X_sampled_2d.shape[0]
print(f"采样后数据点数: {total_samples}")
print(f"数据压缩比: {total_samples / n_original_points:.3f}")

处理输入变量...
处理变量: U
  - 形状: torch.Size([4482594, 30])
处理变量: V
  - 形状: torch.Size([4482594, 30])
处理变量: V
  - 形状: torch.Size([4482594, 30])
处理变量: T
  - 形状: torch.Size([4482594, 30])
处理变量: T
  - 形状: torch.Size([4482594, 30])
处理变量: Q
  - 形状: torch.Size([4482594, 30])
处理变量: Q
  - 形状: torch.Size([4482594, 30])
处理变量: CLDLIQ
  - 形状: torch.Size([4482594, 30])
处理变量: CLDLIQ
  - 形状: torch.Size([4482594, 30])
处理变量: CLDICE
  - 形状: torch.Size([4482594, 30])
处理变量: CLDICE
  - 形状: torch.Size([4482594, 30])
处理变量: PMID
  - 形状: torch.Size([4482594, 30])
处理变量: PMID
  - 形状: torch.Size([4482594, 30])
处理变量: DPRES
  - 形状: torch.Size([4482594, 30])
处理变量: DPRES
  - 形状: torch.Size([4482594, 30])
处理变量: Z3
  - 形状: torch.Size([4482594, 30])
处理变量: Z3
  - 形状: torch.Size([4482594, 30])
处理变量: HEIGHT
  - 形状: torch.Size([4482594, 30])
处理变量: HEIGHT
  - 形状: torch.Size([4482594, 30])
处理变量: TAUX
  - 形状: torch.Size([4482594, 1])
处理变量: TAUY
  - 形状: torch.Size([4482594, 1])
处理变量: SHFLX
  - 形状: torch.Size([4482594, 1])
处理变量: LHFLX
 

In [10]:
# Sanity check: the stacked 3D input tensor is laid out as
# (samples, variables, levels).
X_sampled_3d.shape
# Y_sampled_3d.shape

torch.Size([4482594, 10, 30])

In [None]:
# 1. Split the samples into a training set and a test set
from tqdm.auto import tqdm
import time

# Fixed seed so the split (and the training shuffle order) is identical on
# every Restart & Run All; without it the reported test loss is not reproducible.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

n_samples = X_sampled_3d.shape[0]
n_train = int(TRAIN_FRACTION * n_samples)

# Random permutation of the sample indices, drawn from a dedicated generator
# so it does not depend on how much of the global RNG earlier cells consumed.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(n_samples, generator=split_generator)
train_indices = indices[:n_train]
test_indices = indices[n_train:]

# Split the (samples, variables, levels) tensors
X_train_3d = X_sampled_3d[train_indices]
X_test_3d = X_sampled_3d[test_indices]
Y_train_3d = Y_sampled_3d[train_indices]
Y_test_3d = Y_sampled_3d[test_indices]

# Split the (samples, features) tensors
X_train_2d = X_sampled_2d[train_indices]
X_test_2d = X_sampled_2d[test_indices]
Y_train_2d = Y_sampled_2d[train_indices]
Y_test_2d = Y_sampled_2d[test_indices]

print("训练集大小:", len(train_indices))
print("测试集大小:", len(test_indices))

# 2. Build the data loaders
from torch.utils.data import TensorDataset, DataLoader

# Each batch yields (x_3d, x_2d, y_3d, y_2d)
train_dataset = TensorDataset(X_train_3d, X_train_2d, Y_train_3d, Y_train_2d)
test_dataset = TensorDataset(X_test_3d, X_test_2d, Y_test_3d, Y_test_2d)

batch_size = 64
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    # Seeded shuffling keeps the batch order reproducible as well.
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# 3. 定义2D UNet模块
class UNetBlock(nn.Module):
    """Two consecutive 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    goes from `in_channels` to `out_channels`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UNet2D(nn.Module):
    """Single-level encoder/decoder over a [B, C, H, W] input.

    Pipeline: conv block (32 channels) -> 2x2 max-pool -> bilinear 2x
    upsample -> conv block -> 1x1 conv back to `in_channels`. The output has
    the same shape as the input. The encoder features are not concatenated
    into the decoder, so there is no skip connection.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # Encoder; 32 channels keeps the model small
        self.enc1 = UNetBlock(in_channels, 32)

        # Decoder
        self.dec1 = UNetBlock(32, 32)

        self.final = nn.Conv2d(32, in_channels, kernel_size=1)

        # Down- and up-sampling; ceil_mode keeps the last row/column of odd sizes
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), ceil_mode=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        # Shapes in the comments are for the (10, 30) inputs used in this notebook.
        e1 = self.enc1(x)   # [B, 32, 10, 30]
        p1 = self.pool(e1)  # [B, 32, 5, 15]

        d1 = self.upsample(p1)  # [B, 32, 10, 30]

        # For odd input sizes, ceil-mode pooling followed by 2x upsampling
        # overshoots by one row/column; resize back to the encoder's size.
        # `nn.functional` is used because the notebook never imports `F`.
        if d1.shape[2:] != e1.shape[2:]:
            d1 = nn.functional.interpolate(
                d1, size=e1.shape[2:], mode='bilinear', align_corners=True
            )

        d1 = self.dec1(d1)    # [B, 32, 10, 30]
        out = self.final(d1)  # [B, in_channels, 10, 30]
        return out

# 4. 定义MLP模块
class MLP(nn.Module):
    """Fully connected network: input_dim -> 256 -> 128 -> 64 -> output_dim.

    Every hidden stage is Linear -> ReLU -> BatchNorm1d; the first two hidden
    stages are additionally followed by Dropout(0.3). The final Linear layer
    has no activation.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # (hidden width, dropout probability or None) per hidden stage
        hidden_spec = ((256, 0.3), (128, 0.3), (64, None))

        stages = []
        width_in = input_dim
        for width_out, drop_p in hidden_spec:
            stages.append(nn.Linear(width_in, width_out))
            stages.append(nn.ReLU())
            stages.append(nn.BatchNorm1d(width_out))
            if drop_p is not None:
                stages.append(nn.Dropout(drop_p))
            width_in = width_out
        stages.append(nn.Linear(width_in, output_dim))

        self.network = nn.Sequential(*stages)

    def forward(self, x):
        return self.network(x)

# 5. 定义混合模型
class HybridModel(nn.Module):
    """Two independent branches: a UNet2D for the level-resolved fields and an
    MLP for the single-level features.

    The branches share no layers; the 3D prediction depends only on `x_3d`
    and the 2D prediction only on `x_2d`.

    Args:
        input_dim_2d: number of 2D input features. Defaults to
            `X_train_2d.shape[1]` from the notebook namespace.
        output_dim_2d: number of 2D targets. Defaults to
            `Y_train_2d.shape[1]` from the notebook namespace.
    """

    def __init__(self, input_dim_2d=None, output_dim_2d=None):
        super().__init__()
        # Explicit sizes make the model constructible without the training
        # tensors in scope (e.g. when reloading for inference); the global
        # fallback keeps the original `HybridModel()` call working.
        if input_dim_2d is None:
            input_dim_2d = X_train_2d.shape[1]
        if output_dim_2d is None:
            output_dim_2d = Y_train_2d.shape[1]

        # Each sample's (variables, levels) grid is treated as a 1-channel image
        self.unet = UNet2D(in_channels=1)
        self.mlp = MLP(input_dim_2d, output_dim_2d)

    def forward(self, x_3d, x_2d):
        """Return `(y_3d, y_2d)` for `x_3d` of shape [batch, vars, levels]
        and `x_2d` of shape [batch, features]."""
        # [batch, vars, levels] -> [batch, 1, vars, levels] for Conv2d
        x_3d_reshaped = x_3d.unsqueeze(1)

        # UNet output is [batch, 1, vars, levels]; drop the channel axis again
        y_3d = self.unet(x_3d_reshaped)
        y_3d = y_3d.squeeze(1)

        y_2d = self.mlp(x_2d)

        return y_3d, y_2d

# 6. 初始化模型、优化器和损失函数
# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = HybridModel()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Separate MSE criteria for the 3D and 2D branches
criterion_3d, criterion_2d = nn.MSELoss(), nn.MSELoss()

# 7. 训练函数
def train_epoch(model, train_loader, optimizer, device, epoch, n_epochs):
    """Run one optimisation pass over `train_loader`.

    The objective is the sum of the module-level `criterion_3d` and
    `criterion_2d` losses. `epoch` / `n_epochs` are only used to label the
    progress bar.

    Returns:
        (mean total loss, mean 3D loss, mean 2D loss), each averaged over the
        number of batches in the loader.
    """
    model.train()
    n_batches = len(train_loader)
    sum_total, sum_3d, sum_2d = 0, 0, 0

    progress = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{n_epochs}] Training',
                    leave=False, dynamic_ncols=True)

    for in_3d, in_2d, target_3d, target_2d in progress:
        in_3d, in_2d = in_3d.to(device), in_2d.to(device)
        target_3d, target_2d = target_3d.to(device), target_2d.to(device)

        optimizer.zero_grad()
        out_3d, out_2d = model(in_3d, in_2d)

        err_3d = criterion_3d(out_3d, target_3d)
        err_2d = criterion_2d(out_2d, target_2d)
        err = err_3d + err_2d

        err.backward()
        optimizer.step()

        # Pull each scalar to the host once and reuse it for both the running
        # sums and the progress-bar text.
        batch_total, batch_3d, batch_2d = err.item(), err_3d.item(), err_2d.item()
        sum_total += batch_total
        sum_3d += batch_3d
        sum_2d += batch_2d

        progress.set_postfix({
            'loss': f'{batch_total:.4f}',
            '3D_loss': f'{batch_3d:.4f}',
            '2D_loss': f'{batch_2d:.4f}'
        })

    return sum_total / n_batches, sum_3d / n_batches, sum_2d / n_batches

# 8. 评估函数
def evaluate(model, test_loader, device):
    """Compute the losses over `test_loader` without updating the model.

    The model is switched to eval mode and gradients are disabled. The same
    module-level `criterion_3d` / `criterion_2d` as in training are used.

    Returns:
        (mean total loss, mean 3D loss, mean 2D loss), each averaged over the
        number of batches in the loader.
    """
    model.eval()
    n_batches = len(test_loader)
    sum_total, sum_3d, sum_2d = 0, 0, 0

    progress = tqdm(test_loader, desc='Evaluating', leave=False, dynamic_ncols=True)

    with torch.no_grad():
        for in_3d, in_2d, target_3d, target_2d in progress:
            in_3d, in_2d = in_3d.to(device), in_2d.to(device)
            target_3d, target_2d = target_3d.to(device), target_2d.to(device)

            out_3d, out_2d = model(in_3d, in_2d)

            err_3d = criterion_3d(out_3d, target_3d)
            err_2d = criterion_2d(out_2d, target_2d)
            err = err_3d + err_2d

            # Read each scalar once; reuse for the sums and the progress bar.
            batch_total, batch_3d, batch_2d = err.item(), err_3d.item(), err_2d.item()
            sum_total += batch_total
            sum_3d += batch_3d
            sum_2d += batch_2d

            progress.set_postfix({
                'loss': f'{batch_total:.4f}',
                '3D_loss': f'{batch_3d:.4f}',
                '2D_loss': f'{batch_2d:.4f}'
            })

    return sum_total / n_batches, sum_3d / n_batches, sum_2d / n_batches

# 9. 训练模型
n_epochs = 50  # full-length run
print(f"使用设备: {device}")
print("开始训练...")

# Per-epoch mean total loss, one entry per completed epoch
train_losses, test_losses = [], []

# Outer progress bar over epochs; the per-epoch bars are created inside
# train_epoch / evaluate.
epoch_pbar = tqdm(range(n_epochs), desc='Total Progress', dynamic_ncols=True)

for epoch in epoch_pbar:
    train_loss, train_loss_3d, train_loss_2d = train_epoch(
        model, train_loader, optimizer, device, epoch, n_epochs
    )
    test_loss, test_loss_3d, test_loss_2d = evaluate(model, test_loader, device)

    train_losses.append(train_loss)
    test_losses.append(test_loss)

    # Show the latest losses on the outer bar
    epoch_pbar.set_postfix({
        'train_loss': f'{train_loss:.6f}',
        'test_loss': f'{test_loss:.6f}'
    })

    # Detailed report only every fifth epoch
    if (epoch + 1) % 5 != 0:
        continue
    print(f"\nEpoch {epoch+1}/{n_epochs}:")
    print(f"  训练损失: {train_loss:.6f} (3D: {train_loss_3d:.6f}, 2D: {train_loss_2d:.6f})")
    print(f"  测试损失: {test_loss:.6f} (3D: {test_loss_3d:.6f}, 2D: {test_loss_2d:.6f})")

print("\n训练完成！")

训练集大小: 3586075
测试集大小: 896519
使用设备: cuda
开始训练...


Total Progress:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch [1/50] Training:   0%|          | 0/56033 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/14009 [00:00<?, ?it/s]

Epoch [2/50] Training:   0%|          | 0/56033 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/14009 [00:00<?, ?it/s]

Epoch [3/50] Training:   0%|          | 0/56033 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/14009 [00:00<?, ?it/s]

Epoch [4/50] Training:   0%|          | 0/56033 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/14009 [00:00<?, ?it/s]

Epoch [5/50] Training:   0%|          | 0/56033 [00:00<?, ?it/s]