利用已有的土地覆盖产品进行随机分层抽样，并用多进程生成 256×256 的图片样本

In [None]:
import ee
import geemap
import logging
import multiprocessing
import os
import requests
import shutil
from retry import retry
# Route geemap / Earth Engine HTTP traffic through a proxy on port 4780
# (presumably a local proxy; see geemap.set_proxy for the default host).
geemap.set_proxy(port=4780)
# Interactive map reused by the following cells: centre (30, 115), zoom 8.
Map = geemap.Map(zoom=8, center=(30, 115))

ESRI_water 作为标签（label），S2 作为影像（输入值）；params 中的参数既用于分层采样，也被生成图片等多个函数调用。

In [None]:
# Area of interest: the geometry of the user-uploaded "wuhan" feature collection.
region = ee.FeatureCollection("users/yamiletsharon250/wuhan").geometry()
# Label layer: pixels whose value equals 1 in the ESRI 10 m land-cover
# collection (used here as the water class); every other pixel is masked out,
# so the only unmasked value is 1.
# NOTE(review): .mean() averages categorical class codes where images overlap;
# confirm that .mosaic() / .mode() is not what was intended.
ESRI_water = (
    ee.ImageCollection("projects/sat-io/open-datasets/landcover/ESRI_Global-LULC_10m")
    .mean()
    .eq(1)
    .selfMask()
    .clip(region)
)
# Input imagery: median Sentinel-2 SR composite of bands B8/B4/B3 over
# filterDate('2020', '2021'), rendered to RGB with a 0-4000 stretch.
S2 = (
    ee.ImageCollection('COPERNICUS/S2_SR')
    .filterBounds(region)
    .filterDate('2020', '2021')
    .select('B8', 'B4', 'B3')
    .median()
    .visualize(min=0, max=4000)
    .clip(region)
)
# Sampling / export settings. test.py (written by a later cell) keeps its own
# copy of this dict; keep the two in sync.
params = {
    'count': 100,  # Number of sample points / image chips to export
    'buffer': 2560,  # Buffer distance (m) around each point
    'scale': 100,  # Scale (m) at which stratified sampling is done
    'seed': 1,  # Randomization seed used for sampling
    'dimensions': '256x256',  # Pixel dimensions of each image chip
    'format': "png",  # Output format: png, jpg, ZIPPED_GEO_TIFF, GEO_TIFF or NPY
    'prefix': 'tile_',  # Output file name prefix
    'processes': 25,  # Number of worker processes for parallel downloads
    # NOTE: the two output dirs below are absolute paths at the filesystem
    # root, NOT relative to the current working directory
    # (os.path.abspath leaves them unchanged). Use e.g. 'label' / 'val' for
    # directories under the working directory.
    'label_out_dir': '/label',  # Output directory for label (ESRI_water) chips
    'val_out_dir': '/val',  # Output directory for image (S2) chips
}
def getSamples():
    """Draw stratified random sample points from the ESRI_water mask over `region`.

    Reads ``params['count']``, ``params['scale']`` and ``params['seed']``.
    ESRI_water is self-masked, so presumably only its value-1 (water) pixels
    can be sampled.

    Side effect: stores the sampled ee.FeatureCollection on ``Map.data`` so
    a later cell can add it to the map.

    Returns:
        The client-side list of point geometries (fetched with ``getInfo()``);
        each item is a dict with a ``'coordinates'`` entry.
    """
    sample_kwargs = dict(
        numPoints=params['count'],
        region=region,
        scale=params['scale'],
        seed=params['seed'],
        geometries=True,
    )
    sampled = ESRI_water.stratifiedSample(**sample_kwargs)
    Map.data = sampled
    geometry_list = sampled.aggregate_array('.geo')
    return geometry_list.getInfo()

# Visualisation for the water mask. The keys must be the strings 'min'/'max'
# (the original used the builtin functions min/max as keys by mistake).
# Every unmasked ESRI_water pixel equals 1, so min = max = 1 paints all of
# them with the single palette colour.
viz = {'min': 1, 'max': 1, 'opacity': 1, 'palette': ['blue']}
# S2 is already rendered to RGB by .visualize(), so no vis params are needed.
Map.addLayer(S2, {}, "S2")
Map.addLayer(ESRI_water, viz, "ESRI_water")
Map.addLayer(region, {}, "ROI", False)  # added but hidden by default
Map

因为在 Jupyter 中直接定义的函数无法被多进程（multiprocessing）的子进程调用（子进程需要能按模块名导入这些函数），这一步用 %%writefile 把它们保存为 test.py，方便后面导入调用。

In [None]:
%%writefile test.py
import ee
import geemap
import os
import requests
import shutil
from retry import retry
# Route HTTP traffic through a proxy on port 4780 (presumably local).
geemap.set_proxy(port=4780)
# Initialise Earth Engine here, at import time, because worker processes
# import this module to run the download functions. The high-volume endpoint
# is used since those workers issue many requests in parallel.
ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com')
# The definitions below duplicate the notebook cell that defines region,
# ESRI_water, S2 and params; keep both copies in sync (in particular 'count',
# which also sets the zero-padding width of output file names).
region = ee.FeatureCollection("users/yamiletsharon250/wuhan").geometry()
# Label layer: value-1 pixels of the ESRI 10 m land-cover collection (water),
# everything else masked.
# NOTE(review): .mean() averages categorical class codes where images overlap;
# confirm that .mosaic() / .mode() is not what was intended.
ESRI_water = (
    ee.ImageCollection("projects/sat-io/open-datasets/landcover/ESRI_Global-LULC_10m")
    .mean()
    .eq(1)
    .selfMask()
    .clip(region)
)
# Input imagery: median Sentinel-2 SR composite of B8/B4/B3, rendered to RGB.
S2 = (
    ee.ImageCollection('COPERNICUS/S2_SR')
    .filterBounds(region)
    .filterDate('2020', '2021')
    .select('B8', 'B4', 'B3')
    .median()
    .visualize(min=0, max=4000)
    .clip(region)
)
params = {
    'count': 100,  # Number of image chips to export
    'buffer': 2560,  # Buffer distance (m) around each point
    'scale': 100,  # Scale (m) at which stratified sampling is done
    'seed': 1,  # Randomization seed used for sampling
    'dimensions': '256x256',  # Pixel dimensions of each image chip
    'format': "png",  # Output format: png, jpg, ZIPPED_GEO_TIFF, GEO_TIFF or NPY
    'prefix': 'tile_',  # Output file name prefix
    'processes': 25,  # Number of worker processes for parallel downloads
    # NOTE: absolute paths at the filesystem root, NOT relative to the current
    # working directory (os.path.abspath leaves them unchanged).
    'label_out_dir': '/label',  # Output directory for label (ESRI_water) chips
    'val_out_dir': '/val',  # Output directory for image (S2) chips
}
# File extension for each Earth Engine export format; formats not listed
# here (png, jpg) are used verbatim as the extension.
_EXTENSIONS = {'GEO_TIFF': 'tif', 'ZIPPED_GEO_TIFF': 'zip', 'NPY': 'npy'}


def _download_chip(image, out_dir, index, point, timeout=300):
    """Download one chip of `image` around `point` and save it in `out_dir`.

    The chip covers the bounding box of a ``params['buffer']``-metre buffer
    around the point, at ``params['dimensions']`` pixels in
    ``params['format']``. png/jpg are fetched via ``getThumbURL``; every other
    format via ``getDownloadURL``.

    Args:
        image: the ee.Image to export.
        out_dir: output directory; created if it does not exist yet.
        index: sample index, zero-padded to the width of ``params['count']``
            to build the file name.
        point: dict with a ``'coordinates'`` entry (the items returned by the
            notebook's ``getSamples()``).
        timeout: seconds passed to ``requests.get`` so that a stalled
            connection raises (and gets retried by the caller) instead of
            hanging a worker forever.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    fmt = params['format']
    center = ee.Geometry.Point(point['coordinates'])
    request = {
        'region': center.buffer(params['buffer']).bounds(),
        'dimensions': params['dimensions'],
        'format': fmt,
    }
    if fmt in ('png', 'jpg'):
        url = image.getThumbURL(request)
    else:
        url = image.getDownloadURL(request)
    ext = _EXTENSIONS.get(fmt, fmt)

    out_dir = os.path.abspath(out_dir)
    # Without this, every chip fails on open() and burns all retries.
    os.makedirs(out_dir, exist_ok=True)
    basename = str(index).zfill(len(str(params['count'])))
    filename = os.path.join(out_dir, f"{params['prefix']}{basename}.{ext}")

    # `with` releases the streamed connection even if writing fails.
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        # r.raw is the undecoded byte stream; ask urllib3 to undo any
        # Content-Encoding (e.g. gzip) so the saved file is the real image.
        r.raw.decode_content = True
        with open(filename, 'wb') as out_file:
            shutil.copyfileobj(r.raw, out_file)
    print("Done: ", basename)


@retry(tries=10, delay=1, backoff=2)
def getLabelResult(index, point):
    """Download the ESRI_water label chip for sample `index` into params['label_out_dir'].

    Retried up to 10 times with exponential backoff on any exception.
    """
    _download_chip(ESRI_water, params['label_out_dir'], index, point)


@retry(tries=10, delay=1, backoff=2)
def getValResult(index, point):
    """Download the S2 image chip for sample `index` into params['val_out_dir'].

    Retried up to 10 times with exponential backoff on any exception.
    """
    _download_chip(S2, params['val_out_dir'], index, point)

用多进程并行运行保存在 test.py 中的下载函数，分别生成标签图片和影像图片

In [None]:
# Import the module written by the %%writefile cell so that worker processes
# can look the download functions up by module name.
# NOTE(review): the name `test` collides with the stdlib `test` package, and
# re-running `import test` does not pick up a re-written test.py in the same
# kernel session; consider renaming the file (both cells) or using
# importlib.reload(test).
import test

logging.basicConfig()  # make retry's warning logs visible
items = getSamples()
# The context manager shuts the pool down even if a starmap call raises
# (e.g. once a chip has exhausted its retries). The original close() without
# join() left the pool unmanaged on error.
with multiprocessing.Pool(params['processes']) as pool:
    pool.starmap(test.getLabelResult, enumerate(items))
    pool.starmap(test.getValResult, enumerate(items))
Map.addLayer(Map.data, {}, "Sample points")