In [None]:
import torch
import numpy as np
from multiprocessing import Pool
from tqdm import tqdm

def process_batch(samples_batch):
    """Return the largest depth value found in one batch of samples.

    Args:
        samples_batch: Iterable of samples. Element ``[1]`` of each sample is
            used as the path of the depth image (samples are laid out as
            ``(color_path, depth_path)``).

    Returns:
        The maximum pixel value over every depth image in the batch, or ``0``
        when the batch is empty or no image contains a value above 0.

    NOTE(review): ``Image`` is not imported in the visible part of this file;
    presumably ``from PIL import Image`` is provided elsewhere -- confirm.
    """
    batch_max = 0
    for sample in samples_batch:
        depth_path = sample[1]  # (color_path, depth_path)
        # Close the file as soon as the pixels are copied into the array so
        # that long batches do not accumulate open file handles.
        with Image.open(depth_path) as depth_pil:
            depth_np = np.array(depth_pil).astype('float32')
        # Multi-channel images: only the first channel is inspected.
        if depth_np.ndim == 3:
            depth_np = depth_np[:, :, 0]
        current_max = np.max(depth_np)
        if current_max > batch_max:
            batch_max = current_max
    return batch_max

def find_max_depth_fast(data_root, split, batch_size=1000, num_workers=4):
    """Find the maximum depth value of a dataset split using worker processes.

    Args:
        data_root: Passed to ``SNDataset`` as ``data_root``.
        split: Passed to ``SNDataset`` as ``split``; also shown in the
            progress-bar label.
        batch_size: Number of samples handed to each ``process_batch`` call.
        num_workers: Number of processes in the pool.

    Returns:
        The largest value returned by ``process_batch`` over all batches, or
        ``0`` when the split contains no samples.
    """
    dataset = SNDataset(data_root=data_root, split=split)
    samples = dataset.samples

    # Split the sample list into fixed-size batches.
    batches = [samples[i:i + batch_size]
               for i in range(0, len(samples), batch_size)]
    if not batches:
        # Nothing to scan: avoid spawning workers and calling max() on an
        # empty sequence (which would raise ValueError).
        return 0

    # Process batches in parallel. Only the overall maximum is needed, so
    # completion order is irrelevant and results need not be stored.
    with Pool(num_workers) as pool:
        batch_maxima = tqdm(pool.imap_unordered(process_batch, batches),
                            total=len(batches),
                            desc=f"Processing {split} set")
        # Consume inside the ``with`` block: leaving it terminates the pool.
        return max(batch_maxima, default=0)

def find_max_depth_gpu(data_root, split, device='cuda'):
    """Find the maximum depth value of a dataset split, reducing on ``device``.

    Args:
        data_root: Passed to ``SNDataset`` as ``data_root``.
        split: Passed to ``SNDataset`` as ``split``; also shown in the
            progress-bar label.
        device: Device each sample's depth is moved to before taking its max.

    Returns:
        The largest depth value seen, as a Python number; ``0.0`` when the
        dataset is empty or no value exceeds 0.

    NOTE(review): each ``dataset[i]`` is indexed with ``'depth'`` and the
    result is used like a ``torch.Tensor`` (``.to``, ``.max``) -- confirm
    against ``SNDataset``.
    """
    dataset = SNDataset(data_root=data_root, split=split)
    # Create the running maximum directly on the target device.
    max_depth = torch.tensor(0.0, device=device)

    for i in tqdm(range(len(dataset)), desc=f"Processing {split} set"):
        depth = dataset[i]['depth'].to(device)
        # No per-sample printing here: dumping every tensor floods the output
        # and slows the scan considerably.
        current_max = depth.max()
        if current_max > max_depth:
            max_depth = current_max

    return max_depth.item()


if __name__ == '__main__':
    data_root = "/home/lsk/sn-depth-main/depth-2025"
    splits = ("train", "test")

    # Option 1: multi-process CPU scan (suited to large datasets).
    # Disabled by default; set to True to run it.
    use_cpu = False
    cpu_results = {}
    if use_cpu:
        print("Calculating max depth using multi-processing...")
        cpu_results = {split: find_max_depth_fast(data_root, split)
                       for split in splits}

    # Option 2: GPU scan, only when CUDA is available.
    gpu_results = {}
    if torch.cuda.is_available():
        print("\nCalculating max depth using GPU...")
        gpu_results = {split: find_max_depth_gpu(data_root, split)
                       for split in splits}

    # Report only what was actually computed, so a disabled path can never
    # trigger a NameError on an undefined result variable.
    print("\nResults:")
    for split, value in cpu_results.items():
        print(f"{split.capitalize()} set max depth (CPU): {value}")
    for split, value in gpu_results.items():
        print(f"{split.capitalize()} set max depth (GPU): {value}")
    if not cpu_results and not gpu_results:
        print("No results: CPU scan is disabled and CUDA is not available.")


Calculating max depth using GPU...


Processing train set:   0%|          | 0/4071 [00:00<?, ?it/s]

Processing train set:   0%|          | 2/4071 [00:00<15:04,  4.50it/s]

tensor([[[48980., 48974., 48967.,  ..., 51517., 51510., 51503.],
         [48987., 48987., 48980.,  ..., 51510., 51510., 51503.],
         [48974., 48987., 48987.,  ..., 51510., 51510., 51503.],
         ...,
         [24654., 24654., 24654.,  ..., 24654., 24654., 24654.],
         [24628., 24628., 24628.,  ..., 24628., 24628., 24628.],
         [24595., 24595., 24595.,  ..., 24595., 24595., 24595.]]],
       device='cuda:0')
tensor([[[50907., 50907., 50914.,  ..., 51418., 51418., 51418.],
         [50907., 50907., 50914.,  ..., 51418., 51418., 51418.],
         [50907., 50914., 50920.,  ..., 51418., 51418., 51418.],
         ...,
         [31292., 31292., 31292.,  ..., 31292., 31292., 31292.],
         [31266., 31266., 31266.,  ..., 31266., 31266., 31266.],
         [31247., 31247., 31247.,  ..., 31247., 31247., 31247.]]],
       device='cuda:0')


Processing train set:   0%|          | 3/4071 [00:00<13:32,  5.01it/s]

tensor([[[49701., 49708., 49708.,  ..., 52015., 52008., 51903.],
         [49701., 49708., 49714.,  ..., 52002., 52002., 51903.],
         [49701., 49708., 49714.,  ..., 51982., 51982., 51798.],
         ...,
         [25833., 25833., 25833.,  ..., 25833., 25833., 25833.],
         [25801., 25801., 25801.,  ..., 25801., 25801., 25801.],
         [25768., 25768., 25768.,  ..., 25768., 25768., 25768.]]],
       device='cuda:0')


Processing train set:   0%|          | 3/4071 [00:04<1:36:45,  1.43s/it]


KeyboardInterrupt: 