In [None]:
%%sql


In [None]:
import cupy as cp
import pandas as pd
import rasterio
from tqdm import tqdm
import os
import time
import numpy as np
import gc
import psutil
import shutil


def get_disk_usage(path="."):
    """Report capacity figures for the filesystem that holds ``path``.

    Args:
        path: Any path on the filesystem to inspect; defaults to the
            current working directory.

    Returns:
        dict with ``total_gb``, ``used_gb`` and ``free_gb`` (byte counts
        divided by 1024**3) and ``used_percent`` (used / total * 100).
    """
    usage = shutil.disk_usage(path)
    gib = 1024**3
    return {
        'total_gb': usage.total / gib,
        'used_gb': usage.used / gib,
        'free_gb': usage.free / gib,
        'used_percent': (usage.used / usage.total) * 100
    }


def get_memory_usage():
    """Return the resident set size (RSS) of the current process in GiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024 / 1024


def initialize_gpu_safe():
    """Check that a usable GPU is present and cap CuPy's default memory pool.

    Returns:
        bool: True when CUDA is available, device 0 reports at least 5 GiB of
        free memory, and the default memory pool limit was set to 60% of that
        free memory. False otherwise, including when any step raises (the
        error is printed, not propagated).
    """
    try:
        if not cp.cuda.is_available():
            return False

        # Device.mem_info is (free_bytes, total_bytes); the free amount is
        # index 0. Index 1 is the card's total memory and must not be used
        # for the availability check or the pool limit.
        free_bytes, _total_bytes = cp.cuda.Device(0).mem_info
        free_gb = free_bytes / 1024**3

        if free_gb < 5:
            print(f"GPU memory insufficient: {free_gb:.1f}GB")
            return False

        # Leave headroom: the pool may use at most 60% of the free memory.
        mempool = cp.get_default_memory_pool()
        mempool.set_limit(size=int(free_gb * 0.6 * 1024**3))

        print(f"GPU initialized, memory limit: {free_gb * 0.6:.1f}GB")
        return True
    except Exception as e:
        print(f"GPU init failed: {e}")
        return False


def force_cleanup():
    """Release cached CuPy memory blocks (best effort) and run the garbage collector.

    Any ordinary error while freeing the CuPy pools is ignored so that cleanup
    never aborts the caller. KeyboardInterrupt and SystemExit are deliberately
    not swallowed.
    """
    try:
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()
    except Exception:
        # Best effort only: a missing or broken GPU stack must not stop the run.
        pass
    gc.collect()


def get_raster_info(raster_file):
    """Collect basic metadata for a raster file.

    Args:
        raster_file: Path of the raster to open with rasterio.

    Returns:
        dict with keys ``shape``, ``dtype``, ``nodata``, ``transform``,
        ``crs`` and ``bounds``, or None (after printing the error) when the
        file cannot be opened or read.
    """
    try:
        with rasterio.open(raster_file) as dataset:
            # A 1x1 window of band 1 is read only to learn the array dtype.
            sample = dataset.read(1, window=((0, 1), (0, 1)))
            return dict(
                shape=dataset.shape,
                dtype=sample.dtype,
                nodata=dataset.nodata,
                transform=dataset.transform,
                crs=dataset.crs,
                bounds=dataset.bounds,
            )
    except Exception as e:
        print(f"Error getting info for {raster_file}: {e}")
        return None


def process_chunk_fixed(chunk_data, transform, start_row, nodata_value, use_gpu=True):
    """Convert one block of raster rows into a ``{(x, y): value}`` mapping.

    Builds the mask of valid pixels and hands the work to the GPU or CPU
    implementation.

    Args:
        chunk_data: 2-D NumPy array holding the rows of this chunk.
        transform: Indexable of six coefficients, forwarded unchanged to the
            GPU/CPU implementation.
        start_row: Row index of the chunk's first row within the full raster.
        nodata_value: Value marking invalid pixels. ``None`` means no nodata is
            declared, in which case 0 is assumed to be invalid. NaN is
            supported.
        use_gpu: Try the GPU path when CUDA is available.

    Returns:
        dict: Result of the GPU/CPU implementation, or ``{}`` when the chunk
        has no valid pixel or when any error occurs (the error is printed).
    """
    try:
        if nodata_value is None:
            # No declared nodata: assume 0 marks invalid pixels.
            valid_mask = chunk_data != 0
        elif isinstance(nodata_value, (float, np.floating)) and np.isnan(nodata_value):
            # NaN never compares equal to anything, so ``!=`` would flag every
            # pixel (including the NaN ones) as valid.
            valid_mask = ~np.isnan(chunk_data)
        else:
            valid_mask = chunk_data != nodata_value

        # Nothing to convert when the whole chunk is invalid.
        if not np.any(valid_mask):
            return {}

        if use_gpu and cp.cuda.is_available():
            return process_chunk_gpu_fixed(chunk_data, transform, start_row, valid_mask)
        return process_chunk_cpu_fixed(chunk_data, transform, start_row, valid_mask)

    except Exception as e:
        print(f"Error processing chunk: {e}")
        return {}


def process_chunk_gpu_fixed(chunk_data, transform, start_row, valid_mask):
    """Compute coordinates of the valid pixels of a chunk on the GPU.

    Falls back to ``process_chunk_cpu_fixed`` when the estimated memory need
    exceeds 80% of the free GPU memory or when the GPU path raises.

    Args:
        chunk_data: 2-D NumPy array with the chunk's pixel values.
        transform: Indexable of six coefficients used as
            ``x = t[0] + col * t[1] + row * t[2]`` and
            ``y = t[3] + col * t[4] + row * t[5]``.
        start_row: Row index of the chunk's first row within the full raster.
        valid_mask: Boolean array, same shape as ``chunk_data``, True for
            pixels to keep.

    Returns:
        dict: ``{(x, y): value}`` for every valid pixel.
    """
    try:
        # Device.mem_info is (free_bytes, total_bytes); index 0 is the free
        # amount. Using index 1 (total) would make this guard pass even when
        # the card is nearly full.
        free_bytes = cp.cuda.Device(0).mem_info[0]
        free_gb = free_bytes / 1024**3
        # Heuristic footprint: chunk size * 4 bytes * 3 arrays.
        estimated_mem = (chunk_data.size * 4 * 3) / 1024**3

        if estimated_mem > free_gb * 0.8:
            return process_chunk_cpu_fixed(chunk_data, transform, start_row, valid_mask)

        row_idx, col_idx = np.where(valid_mask)
        values = chunk_data[valid_mask]

        # Shift chunk-local rows to full-raster rows before uploading.
        row_idx_gpu = cp.asarray(row_idx + start_row)
        col_idx_gpu = cp.asarray(col_idx)

        x_coords = transform[0] + col_idx_gpu * transform[1] + row_idx_gpu * transform[2]
        y_coords = transform[3] + col_idx_gpu * transform[4] + row_idx_gpu * transform[5]

        x_coords_cpu = cp.asnumpy(x_coords)
        y_coords_cpu = cp.asnumpy(y_coords)

        # Drop device arrays right away so the pool can be emptied.
        del row_idx_gpu, col_idx_gpu, x_coords, y_coords
        force_cleanup()

        result = {(x, y): val for x, y, val in zip(x_coords_cpu, y_coords_cpu, values)}
        return result

    except Exception as e:
        print(f"GPU processing failed: {e}")
        force_cleanup()
        return process_chunk_cpu_fixed(chunk_data, transform, start_row, valid_mask)


def process_chunk_cpu_fixed(chunk_data, transform, start_row, valid_mask):
    """Compute coordinates of the valid pixels of a chunk with NumPy.

    Args:
        chunk_data: 2-D NumPy array with the chunk's pixel values.
        transform: Indexable of six coefficients used as
            ``x = t[0] + col * t[1] + row * t[2]`` and
            ``y = t[3] + col * t[4] + row * t[5]``.
        start_row: Row index of the chunk's first row within the full raster.
        valid_mask: Boolean array, same shape as ``chunk_data``, True for
            pixels to keep.

    Returns:
        dict: ``{(x, y): value}`` for every valid pixel, or ``{}`` if an
        error occurs (the error is printed).
    """
    try:
        local_rows, cols = np.nonzero(valid_mask)
        kept_values = chunk_data[valid_mask]

        # Chunk-local row numbers become full-raster row numbers.
        rows = local_rows + start_row

        x0, x_col, x_row, y0, y_col, y_row = (transform[i] for i in range(6))
        xs = x0 + cols * x_col + rows * x_row
        ys = y0 + cols * y_col + rows * y_row

        return dict(zip(zip(xs, ys), kept_values))

    except Exception as e:
        print(f"CPU processing failed: {e}")
        return {}


def process_raster_file_fixed(raster_file, raster_name, use_gpu=True, chunk_size=512):
    """Read band 1 of a raster in row chunks and map each valid pixel to its coordinate.

    Args:
        raster_file: Path of the raster to read.
        raster_name: Label used in progress and log output.
        use_gpu: Forwarded to ``process_chunk_fixed``.
        chunk_size: Number of raster rows read per chunk.

    Returns:
        dict: ``{(x, y): value}`` for every valid pixel of the raster, or
        ``{}`` when the metadata cannot be read or processing fails (the
        error is printed).
    """
    print(f"\n{'='*60}")
    print(f"Processing: {raster_name}")

    # Fetch the raster metadata.
    raster_info = get_raster_info(raster_file)
    if not raster_info:
        return {}

    rows, cols = raster_info['shape']
    transform = raster_info['transform']
    nodata_value = raster_info['nodata']

    # The chunk processors compute x = t[0] + col*t[1] + row*t[2] and
    # y = t[3] + col*t[4] + row*t[5], i.e. GDAL geotransform order
    # (x_origin, dx, rot, y_origin, rot, dy). rasterio's dataset.transform is
    # an Affine ordered (a, b, c, d, e, f) = (dx, rot, x_origin, rot, dy,
    # y_origin), so it must be reordered first or the coordinates are wrong.
    # Plain 6-tuples (already GDAL-ordered) have no to_gdal and pass through.
    to_gdal = getattr(transform, 'to_gdal', None)
    if callable(to_gdal):
        transform = to_gdal()

    print(f"Dimensions: {rows}x{cols}")
    print(f"NoData value: {nodata_value}")
    print(f"Chunk size: {chunk_size}")

    coordinate_data = {}
    # Ceiling division so a final partial chunk is included.
    total_chunks = (rows + chunk_size - 1) // chunk_size

    try:
        with rasterio.open(raster_file) as src:
            with tqdm(total=total_chunks, desc=f"Processing {raster_name}", unit="chunk") as pbar:
                for chunk_idx in range(total_chunks):
                    start_row = chunk_idx * chunk_size
                    end_row = min(start_row + chunk_size, rows)

                    # Full-width window covering rows [start_row, end_row).
                    chunk_data = src.read(1, window=((start_row, end_row), (0, cols)))

                    chunk_result = process_chunk_fixed(
                        chunk_data, transform, start_row, nodata_value, use_gpu
                    )

                    coordinate_data.update(chunk_result)

                    pbar.update(1)
                    pbar.set_postfix({
                        'coords': f"{len(coordinate_data):,}",
                        'mem': f"{get_memory_usage():.1f}GB"
                    })

                    # Periodically release pooled GPU memory and collect garbage.
                    if chunk_idx % 20 == 0:
                        force_cleanup()

    except Exception as e:
        print(f"Error processing {raster_name}: {e}")
        return {}

    print(f"✓ {raster_name} completed: {len(coordinate_data):,} coordinates")
    return coordinate_data


def build_dataset_final_fixed(raster_files, output_txt="final_output.txt", use_gpu=True, chunk_size=512):
    """Merge several rasters into one space-separated text table.

    Every raster is converted to a ``{(x, y): value}`` mapping; the mappings
    are combined per coordinate (0 where a raster has no value for it) and
    written as ``x y v1 v2 ...`` lines under a header row.

    Args:
        raster_files: Sequence of raster paths; column order follows it.
        output_txt: Path of the text file to write.
        use_gpu: Attempt GPU initialisation and GPU processing.
        chunk_size: Rows per chunk, forwarded to the per-file processing.

    Returns:
        bool: True on success, False if writing the output file fails.
    """
    wide_rule = '=' * 80
    narrow_rule = '=' * 60

    free_space_gb = get_disk_usage()['free_gb']
    print(f"Disk Space - Free: {free_space_gb:.1f}GB")

    # Column labels are the file names without directory or extension.
    raster_names = [os.path.splitext(os.path.basename(path))[0] for path in raster_files]

    gpu_available = initialize_gpu_safe() if use_gpu else False
    processing_mode = "GPU" if gpu_available else "CPU"

    print("\n" + wide_rule)
    print("FINAL FIXED PROCESSING")
    print(f"Files: {raster_names}")
    print(f"Mode: {processing_mode}")
    print(f"Chunk size: {chunk_size}")
    print(wide_rule)

    # Holds the merged values of all rasters, keyed by coordinate.
    all_coordinates = {}
    start_time = time.time()
    n_files = len(raster_files)

    for column, (raster_file, raster_name) in enumerate(zip(raster_files, raster_names)):
        position = column + 1
        print(f"\n[{position}/{n_files}] Processing {raster_name}")

        file_coordinates = process_raster_file_fixed(
            raster_file, raster_name, gpu_available, chunk_size
        )

        for coord, value in file_coordinates.items():
            row = all_coordinates.get(coord)
            if row is None:
                # First sighting of this coordinate: start with zeros for every raster.
                row = all_coordinates[coord] = [0] * n_files
            row[column] = value

        print(f"File {position} processed: {len(file_coordinates):,} coordinates")
        print(f"Total unique coordinates: {len(all_coordinates):,}")

        # Free the per-file mapping before the next raster is loaded.
        del file_coordinates
        force_cleanup()

    print("\n" + narrow_rule)
    print("WRITING OUTPUT FILE")
    print(narrow_rule)

    try:
        with open(output_txt, 'w') as out:
            out.write("Longitude Latitude " + " ".join(raster_names) + "\n")

            with tqdm(total=len(all_coordinates), desc="Writing output", unit="row") as pbar:
                for (x, y), values in all_coordinates.items():
                    out.write(f"{x} {y} {' '.join(map(str, values))}\n")
                    pbar.update(1)

        output_size = os.path.getsize(output_txt) / 1024**3

    except Exception as e:
        print(f"Error writing output: {e}")
        return False

    total_time = time.time() - start_time

    print("\n" + wide_rule)
    print("PROCESSING COMPLETE")
    print(wide_rule)
    print(f"Total time: {total_time/60:.1f} minutes")
    print(f"Coordinates processed: {len(all_coordinates):,}")
    print(f"Output file: {output_txt}")
    print(f"Output size: {output_size:.3f}GB")
    print(f"Processing mode: {processing_mode}")
    print("Success: ✓")
    print(wide_rule)

    return True


if __name__ == '__main__':
    # Input rasters; the output table gets one value column per file, in this order.
    raster_files = [
        "c:/NC_data/resample_file/Trans01_20_12_Reclass1.tif",
        "c:/NC_data/resample_file/intersec_koppen.tif"
    ]

    # Run the fixed version of the pipeline.
    # (The comma after the output_txt argument was missing, which made the
    # whole file a SyntaxError.)
    success = build_dataset_final_fixed(
        raster_files,
        output_txt="c:/NC_data/processed_data_2.txt",
        use_gpu=True,
        chunk_size=256  # conservative chunk size
    )