In [1]:
import os

# USE_PYGEOS is only honoured if it is set before geopandas is first imported,
# so it has to come ahead of every geopandas import below.
os.environ['USE_PYGEOS'] = '0'

import time
from concurrent.futures import ThreadPoolExecutor
from math import sqrt
from pathlib import Path

import geopandas
import geopandas as gpd
import numpy as np
import planetary_computer
import pyproj
import pystac_client
import rasterio as rio
from multiprocess import Pool
from pyproj import Transformer
from rasterio import Affine
from shapely.geometry import Point
from shapely.geometry import box
from shapely.ops import transform

# Both tqdm imports are kept in their original order: the second one rebinds
# the name, so the plain (non-auto) tqdm is the one the notebook ends up using.
from tqdm.auto import tqdm
from tqdm import tqdm

In the next release, GeoPandas will switch to using Shapely by default, even if PyGEOS is installed. If you only have PyGEOS installed to get speed-ups, this switch should be smooth. However, if you are using PyGEOS directly (calling PyGEOS functions on geometries from GeoPandas), this will then stop working and you are encouraged to migrate from PyGEOS to Shapely 2.0 (https://shapely.readthedocs.io/en/latest/migration_pygeos.html).
  import geopandas as gpd


In [2]:
# Name of the training-data set to process. It selects the sub-folder and the
# file names read from data_dir, and it is embedded in the export prefix.
# Keep exactly one of the assignments below active.
source = "OSM"
# source = "Aus"
# source = "NZ"
# source = "Validation"

In [3]:
# Root folder that holds one sub-folder of vector training data per source.
# The original machine-specific location stays as the default so existing runs
# are unaffected; set COASTLINE_DATA_DIR to point the notebook somewhere else.
data_dir = Path(
    os.environ.get(
        "COASTLINE_DATA_DIR",
        "/Users/Nick/Library/Mobile Documents/com~apple~CloudDocs/QGIS/Coastline training data v2/",
    )
)
data_dir.exists()

True

In [4]:
# Open the Planetary Computer STAC API. planetary_computer.sign_inplace is
# registered as the client's modifier (presumably so that asset hrefs come back
# signed and can be opened directly by rasterio -- confirm against the
# planetary_computer documentation).
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

In [5]:
# Locate the point layer for the selected source, report whether the file is
# present, then load it.
vector_points_path = data_dir / source / f"{source} training data.gpkg"
print(vector_points_path.exists())
vector_points = gpd.read_file(vector_points_path)

True


In [6]:
# Polygon and line layers for the selected source; print one existence flag
# per file (polygons first, then lines).
ocean_path = data_dir / source / f"{source} polygons.gpkg"
coastline_path = data_dir / source / f"{source} lines.gpkg"
for required_path in (ocean_path, coastline_path):
    print(required_path.exists())

# Suffix used in every exported image/label file name.
prefix = f"{source}_80"

True
True


In [7]:
# Output folders for image stacks and their labels, created next to the
# notebook's working directory if they are not there yet.
training_root = Path.cwd() / "training data"
images_path = training_root / "images_2_3_4_8_V3"
labels_path = training_root / "labels_2_3_4_8_V3"
for output_dir in (labels_path, images_path):
    output_dir.mkdir(parents=True, exist_ok=True)

In [8]:
# Asset keys read from every product, in the order they are stacked.
bands = ["B02", "B03", "B04", "B08"]
# Number of products stacked into one exported image.
time_steps = 6

In [9]:
def find_prods(point, time_of_interest="2022-01-01/2023-01-01"):
    """Return the Sentinel-2 L2A items from ``catalog`` that intersect ``point``.

    Args:
        point: GeoJSON-style geometry passed to the search as ``intersects``.
        time_of_interest: datetime range string passed to the search as
            ``datetime``.

    Returns:
        The item collection of the search, restricted to items whose
        ``eo:cloud_cover`` is below 80.
    """
    cloud_filter = {"eo:cloud_cover": {"lt": 80}}
    return catalog.search(
        collections=["sentinel-2-l2a"],
        intersects=point,
        datetime=time_of_interest,
        query=cloud_filter,
    ).item_collection()

In [10]:
def wgs_point_to_local_box(product, wgs_point, vector_data_crs_number, box_size=2560):
    """Return a square box centred on a point, expressed in the product's CRS.

    Args:
        product: STAC item; its ``proj:epsg`` property names the target CRS.
        wgs_point: point exposing ``.x`` / ``.y`` in the CRS identified by
            ``vector_data_crs_number``.
        vector_data_crs_number: EPSG code of ``wgs_point``.
        box_size: side length of the box in target-CRS units. Defaults to
            2560, the diameter that used to be hard-coded here.

    Returns:
        ``(minx, miny, maxx, maxy)`` of the box in the product's CRS.
    """
    local_crs_number = product.properties["proj:epsg"]
    transformer = pyproj.Transformer.from_crs(
        f"EPSG:{vector_data_crs_number}",
        f"EPSG:{local_crs_number}",
        always_xy=True,
    )
    centre_x, centre_y = transformer.transform(wgs_point.x, wgs_point.y)
    # The bounds of a circular buffer are just centre +/- radius, so compute
    # them directly instead of building a buffer polygon for every product.
    half_size = box_size / 2
    return (
        centre_x - half_size,
        centre_y - half_size,
        centre_x + half_size,
        centre_y + half_size,
    )

In [11]:
# EPSG code of the training points' CRS. It is passed to
# wgs_point_to_local_box as the source CRS when a point is projected into a
# product's CRS.
vector_data_crs_number = vector_points.crs.to_epsg()
vector_data_crs_number

4326

In [12]:
# Band count of one exported image: every time step contributes all bands.
total_bands = time_steps * len(bands)
total_bands

24

In [13]:
# for row in tqdm(vector_points.iterrows(), total=len(vector_points)):
def downlaod_image(row):
    """Download and stack Sentinel-2 windows around one training point.

    Args:
        row: ``(index, record)`` pair as yielded by ``DataFrame.iterrows()``;
            ``record.geometry`` is the point to centre the window on.

    The point is searched with ``find_prods``; for each returned product every
    band in ``bands`` is read over the box from ``wgs_point_to_local_box``.
    Products are skipped when a band window is not 256x256 or when the first
    band has more than 100 zero-valued pixels. Once ``total_bands`` arrays are
    collected they are written to ``images_path / f"{index}_{prefix}.tif"``.

    Nothing is done when the geometry is ``None``, when the export stem is in
    ``skip_list`` or when the export file already exists. Any exception is
    reported with ``print`` and swallowed so that one bad point does not stop
    the surrounding batch.

    Returns:
        None in every case.
    """
    # Fallback label so the failure message below can always be formatted,
    # even if the error happens before the real export name is known.
    export_name = f"row {row[0]}"
    try:
        wgs_point = row[1].geometry
        if wgs_point is None:
            return

        export_name = f"{row[0]}_{prefix}.tif"
        export_path = images_path / export_name
        if export_path.stem in skip_list or export_path.exists():
            return

        search_point = {"type": "Point", "coordinates": [wgs_point.x, wgs_point.y]}

        arrays = []
        # (epsg, window transform, profile) of the first accepted product; it
        # supplies the georeferencing of the exported stack.
        reference = None
        for product in find_prods(search_point):
            epsg = product.properties["proj:epsg"]
            # Windows cut from products in different CRSs do not share a pixel
            # grid, so only products matching the reference CRS are stacked.
            if reference is not None and epsg != reference[0]:
                continue

            b_box = wgs_point_to_local_box(product, wgs_point, vector_data_crs_number)
            product_data = _read_product_bands(product, b_box, export_name)
            if product_data is None:
                continue

            band_arrays, window_transform, profile = product_data
            if reference is None:
                reference = (epsg, window_transform, profile)
            arrays.extend(band_arrays)

            if len(arrays) == total_bands:
                break

        if len(arrays) != total_bands:
            print(f"Could not find {time_steps} images for {export_name}")
            time.sleep(2)
            return

        _, window_transform, profile = reference
        profile.update(
            {
                "height": arrays[0].shape[0],
                "width": arrays[0].shape[1],
                "transform": window_transform,
                "count": total_bands,
            }
        )
        with rio.open(export_path, "w", **profile) as dst:
            dst.write(np.array(arrays))
    except Exception as e:
        print(e)
        print(f"Failed to download {export_name}")
        time.sleep(2)
        return


def _read_product_bands(product, b_box, export_name):
    """Read every band in ``bands`` for one product over ``b_box``.

    Returns ``(band_arrays, window_transform, profile)`` when all bands were
    read, or ``None`` as soon as the product has to be rejected, so a rejected
    product never contributes a partial set of bands.
    """
    band_arrays = []
    window_transform = None
    profile = None
    for band in bands:
        with rio.open(product.assets[band].href) as src:
            window = rio.windows.from_bounds(*b_box, src.transform)
            array = src.read(1, window=window)

            if array.shape != (256, 256):
                print(f"Array shape is {array.shape} for {export_name}")
                return None

            # Only the first band is checked for zero-valued pixels.
            if band == bands[0] and np.count_nonzero(array == 0) > 100:
                return None

            window_transform = rio.windows.transform(window, transform=src.transform)
            profile = src.profile.copy()

        band_arrays.append(array)

    return band_arrays, window_transform, profile

In [14]:
# Process the points from the highest index down. Sorting by index (rather than
# reversing with iloc[::-1]) gives the same order for the freshly loaded frame
# but stays the same if this cell is run more than once.
vector_points = vector_points.sort_index(ascending=False)

In [15]:
# Export stems (f"{index}_{prefix}") that downlaod_image refuses to download.
# NOTE(review): the reason these particular points are excluded is not recorded
# in this notebook -- worth documenting next to the list.
skip_list = [
    "614_OSM_80",
    "611_OSM_80",
    "579_OSM_80",
    "577_OSM_80",
    "554_OSM_80",
    "550_OSM_80",
    "536_OSM_80",
    "533_OSM_80",
    "531_OSM_80",
    "526_OSM_80",
    "523_OSM_80",
    "509_OSM_80",
    "490_OSM_80",
    "478_OSM_80",
    "466_OSM_80",
    "465_OSM_80",
    "462_OSM_80",
    "797_OSM_80",
    "798_OSM_80",
    "791_OSM_80",
    "777_OSM_80",
    "775_OSM_80",
    "758_OSM_80",
    "716_OSM_80",
    "706_OSM_80",
    "691_OSM_80",
    "325_OSM_80",
    "355_OSM_80",
    "913_OSM_80",
    "856_OSM_80",
    "841_OSM_80",
    "919_OSM_80",
    "978_OSM_80",
    "995_OSM_80",
    "1295_OSM_80",
    "1336_OSM_80",
    "1374_OSM_80",
    "1411_OSM_80",
    "1462_OSM_80",
]

In [16]:
def worker(i):
    """Download the image for one ``(index, row)`` pair unless it is skip-listed.

    ``skip_list`` holds export stems of the form ``f"{index}_{prefix}"``, so the
    stem is rebuilt here before the lookup; comparing the bare row index with
    those strings can never match.
    """
    if f"{i[0]}_{prefix}" in skip_list:
        return
    downlaod_image(i)


# How many downloads run concurrently.
num_threads = 4

with ThreadPoolExecutor(max_workers=num_threads) as executor:
    pending_downloads = executor.map(worker, vector_points.iterrows())
    download_progress = tqdm(pending_downloads, total=len(vector_points))
    # Draining the iterator is what drives the progress bar to completion.
    list(download_progress)

  3%|▎         | 49/1631 [05:53<2:09:02,  4.89s/it]

Array shape is (256, 149) for 1576_OSM_80.tif


  3%|▎         | 51/1631 [06:06<2:21:08,  5.36s/it]

Array shape is (256, 149) for 1576_OSM_80.tif


  3%|▎         | 53/1631 [06:12<1:59:15,  4.53s/it]

Array shape is (256, 149) for 1576_OSM_80.tif
Array shape is (256, 149) for 1576_OSM_80.tif


100%|██████████| 1631/1631 [13:13<00:00,  2.06it/s] 


In [17]:
def rasterise(args):
    """Write the signed distance-to-coastline label raster for one image.

    Args:
        args: ``None`` (nothing is done) or the tuple
            ``(label_export_path, coastline_lines_clipped, clipped_gdf,
            profile, array_transform, max_dist)``.

    For every pixel centre the distance to the nearest geometry in
    ``coastline_lines_clipped`` is capped at ``max_dist`` (``max_dist`` is used
    outright when there are no coastline geometries) and square-rooted. The
    value is stored negated where the pixel centre intersects a geometry of
    ``clipped_gdf`` and positive elsewhere. The result is written as a
    single-band float32, LZW-compressed raster at ``label_export_path``;
    nothing is written when that file already exists.

    The ``clipped_gdf`` and ``profile`` passed in are left unmodified.
    """
    if args is None:
        return
    (
        label_export_path,
        coastline_lines_clipped,
        clipped_gdf,
        profile,
        array_transform,
        max_dist,
    ) = args
    if label_export_path.exists():
        return

    # Loop-invariant inputs, prepared once instead of once per pixel.
    # buffer(0) is applied to the polygons before testing, as a new series so
    # the caller's frame is not altered.
    ocean_geoms = [
        geom
        for geom in clipped_gdf["geometry"].buffer(0)
        if geom is not None and not geom.is_empty
    ]
    coast_geoms = list(coastline_lines_clipped["geometry"])

    height, width = profile["height"], profile["width"]
    array = np.zeros((height, width), dtype=np.float32)
    for row in range(height):
        for col in range(width):
            # Pixel centre in the raster's coordinate system.
            x, y = array_transform * (col + 0.5, row + 0.5)
            point = Point(x, y)

            if coast_geoms:
                min_distance = min(geom.distance(point) for geom in coast_geoms)
            else:
                min_distance = max_dist

            if min_distance > max_dist:
                min_distance = max_dist
            min_distance = sqrt(min_distance)

            if any(geom.intersects(point) for geom in ocean_geoms):
                array[row, col] = -min_distance
            else:
                array[row, col] = min_distance

    # Build the output profile as a copy so the caller's dict keeps describing
    # the source image.
    label_profile = {**profile, "count": 1, "dtype": "float32", "compress": "lzw"}

    with rio.open(label_export_path, "w", **label_profile) as dst:
        dst.write(array, 1)

In [18]:
def make_label(image, coastline_lines, ocean_polygons, max_dist=50):
    """Gather everything ``rasterise`` needs to build the label for ``image``.

    Args:
        image: path of an exported image raster.
        coastline_lines: GeoDataFrame of coastline geometries.
        ocean_polygons: GeoDataFrame of polygons; its CRS is the one the image
            footprint is reprojected into before clipping both layers.
        max_dist: distance cap forwarded unchanged to ``rasterise``.

    Returns:
        ``(label_export_path, coastline_lines_clipped, clipped_gdf, profile,
        array_transform, max_dist)``, or ``None`` when the label file already
        exists or when anything fails (the error is printed, not raised).
    """
    try:
        label_export_path = labels_path / image.name
        if label_export_path.exists():
            return

        # Pull what is needed from the image, then close it straight away.
        with rio.open(image) as src:
            profile = src.profile.copy()
            raster_crs = src.crs
            extent = src.bounds
            pixel_size = src.transform[0]
        raster_poly = box(*extent)

        # Image footprint expressed in the vector layers' CRS, used as the clip mask.
        to_vector_crs = Transformer.from_crs(
            raster_crs, ocean_polygons.crs, always_xy=True
        )
        footprint = transform(to_vector_crs.transform, raster_poly)

        ocean_in_footprint = gpd.clip(ocean_polygons, footprint)
        clipped_gdf = ocean_in_footprint.dissolve().to_crs(raster_crs)

        coast_in_footprint = gpd.clip(coastline_lines, footprint)
        coastline_lines_clipped = coast_in_footprint.to_crs(raster_crs)

        # Pixel-to-map transform anchored at the top-left corner of the extent.
        minx, miny, maxx, maxy = extent
        top_left = Affine.translation(minx, maxy)
        pixel_scale = Affine.scale(pixel_size, -pixel_size)
        array_transform = top_left * pixel_scale

        return (
            label_export_path,
            coastline_lines_clipped,
            clipped_gdf,
            profile,
            array_transform,
            max_dist,
        )
    except Exception as e:
        print(e)
        print(f"Failed to make label for {image.name}")
        return

In [19]:
# Every exported image whose file name ends with the current prefix; the count
# is displayed as the cell output.
images = list(images_path.glob(f"*{prefix}.tif"))
len(images)

1614

In [20]:
# Polygon layer handed to make_label as `ocean_polygons`.
ocean_polygons = gpd.read_file(ocean_path)

In [21]:
# Line layer handed to make_label as `coastline_lines`.
coastline = gpd.read_file(coastline_path)

In [22]:
# One entry per image: the argument tuple for rasterise, or None when
# make_label had nothing to return for that image.
args_list = [
    make_label(image, coastline_lines=coastline, ocean_polygons=ocean_polygons)
    for image in tqdm(images)
]

100%|██████████| 1614/1614 [02:13<00:00, 12.05it/s]


In [23]:
# Rasterise the labels across worker processes, in input order.
with Pool() as pool:
    label_jobs = pool.imap(rasterise, args_list)
    # Draining the iterator is what makes the workers run to completion.
    list(tqdm(label_jobs, total=len(args_list)))

100%|██████████| 1614/1614 [13:41<00:00,  1.96it/s]


In [24]:
# Materialise the listings as lists. A bare glob() generator is exhausted after
# one pass, so re-running the check that iterates `images` would silently loop
# over nothing; a list also lets tqdm show a total.
images = list(images_path.glob("*.tif"))
labels = list(labels_path.glob("*.tif"))
# The original, misspelled name is kept so any existing reference still resolves.
lebels = labels

In [25]:
# Collect every image that has no label file of the same name.
remove_list = []
for image in tqdm(images):
    if (labels_path / image.name).exists():
        continue
    print(f"Missing label for {image.name}")
    remove_list.append(image)

2836it [00:00, 12151.64it/s]


In [26]:
# Number of images without a matching label.
len(remove_list)

0

In [27]:
# Uncomment to delete every image listed in remove_list (this cannot be undone).
# for i in remove_list:
#     os.remove(i)