In [None]:
# from pystac.extensions.eo import EOExtension as eo
import pystac_client
import planetary_computer
import geopandas as gpd
import rasterio as rio
from pathlib import Path
from shapely.geometry import Point
import pyproj
from tqdm.auto import tqdm
from rasterio import Affine
from shapely.geometry import box
from shapely.ops import transform
from pyproj import Transformer
from multiprocess import Pool
import numpy as np
from math import sqrt
import time

In [None]:
# Open the Planetary Computer STAC API once for the whole notebook.
# `planetary_computer.sign_inplace` is passed as the client's `modifier`.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

In [None]:
# Point layer whose rows drive the downloads: each row's geometry becomes the
# centre of one exported image chip.
# NOTE(review): absolute, machine-specific path — the notebook will not run
# elsewhere without editing it.
vector_points_path = Path(
    "/Users/Nick/Library/Mobile Documents/com~apple~CloudDocs/QGIS/Coastline training data v2/OSM/OSM training data.gpkg"
)
vector_points = gpd.read_file(vector_points_path)

In [None]:
# Vector inputs for label generation plus the file-name prefix of this run.
# `prefix` is embedded in every exported image name and is used later to glob
# the images that need labels.
#
# Alternative (currently disabled) datasets are kept below for reference.
# OSM_ocean_path = Path(
#     "/Users/Nick/Library/Mobile Documents/com~apple~CloudDocs/QGIS/Coastline training data/Aus coastline polygon.gpkg"
# )
# OSM_coastline_path = Path(
#     "/Users/Nick/Library/Mobile Documents/com~apple~CloudDocs/QGIS/Coastline training data/Aus coastline.gpkg"
# )
# OSM_ocean_path = Path(
#     "/Users/Nick/Library/Mobile Documents/com~apple~CloudDocs/QGIS/Coastline training data/NZ Polygons.gpkg"
# )
# OSM_coastline_path = Path(
#     "/Users/Nick/Library/Mobile Documents/com~apple~CloudDocs/QGIS/Coastline training data/NZ Lines.gpkg"
# )
# prefix = "NZ_80"

# Active dataset: polygon layer (read into `ocean_polygons`) and line layer
# (read into `coastline`).
# NOTE(review): absolute, machine-specific paths.
OSM_ocean_path = Path(
    "/Users/Nick/Library/Mobile Documents/com~apple~CloudDocs/QGIS/Coastline training data/OSM polygons.gpkg"
)
OSM_coastline_path = Path(
    "/Users/Nick/Library/Mobile Documents/com~apple~CloudDocs/QGIS/Coastline training data/OSM lines.gpkg"
)
prefix = "OSM_80"

In [None]:
# Output folders (relative to the current working directory) for the exported
# image chips and for the label rasters derived from them.
images_path = Path.cwd().joinpath("training data", "images_2_3_4_8_V3")
labels_path = Path.cwd().joinpath("training data", "labels_2_3_4_8_V3")

# Create both folders up front; existing folders are left as they are.
labels_path.mkdir(parents=True, exist_ok=True)
images_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Asset keys read from every product, in the order they are stacked.
bands = ["B02", "B03", "B04", "B08"]
# Number of products stacked into one exported image
# (the export has time_steps * len(bands) bands).
time_steps = 6

In [None]:
def find_prods(point, time_of_interest="2022-01-01/2023-01-01"):
    """Return the Sentinel-2 L2A items in ``catalog`` that intersect ``point``.

    Args:
        point: Geometry passed to the search as ``intersects`` (the caller
            supplies a GeoJSON-style ``{"type": "Point", ...}`` dict).
        time_of_interest: Value passed to the search as ``datetime``.

    Returns:
        The item collection of the search, restricted to items whose
        ``eo:cloud_cover`` is below 80.
    """
    cloud_filter = {"eo:cloud_cover": {"lt": 80}}
    return catalog.search(
        collections=["sentinel-2-l2a"],
        intersects=point,
        datetime=time_of_interest,
        query=cloud_filter,
    ).item_collection()

In [None]:
def wgs_point_to_local_box(product, wgs_point, vector_data_crs_number, box_size=2560):
    """Return a square bounding box around ``wgs_point`` in the product's CRS.

    Args:
        product: STAC item; its ``properties`` must carry the projection as
            either ``proj:epsg`` or ``proj:code``.
        wgs_point: Point (``.x`` / ``.y``) expressed in the CRS identified by
            ``vector_data_crs_number``.
        vector_data_crs_number: EPSG code of the CRS ``wgs_point`` is in.
        box_size: Side length of the box in units of the product's CRS.
            The default matches the original hard-coded 2560
            (presumably 256 px at 10 m — confirm against the band resolution).

    Returns:
        ``(minx, miny, maxx, maxy)`` of the box in the product's CRS.

    Raises:
        KeyError: if the item has neither ``proj:epsg`` nor ``proj:code``.
    """
    properties = product.properties
    # Version 2 of the STAC projection extension replaced the integer
    # "proj:epsg" with the string "proj:code" (e.g. "EPSG:32755"); accept both
    # so items following either version work.
    local_crs = properties.get("proj:epsg")
    if local_crs is None:
        local_crs = properties["proj:code"]

    source_crs = pyproj.CRS(f"EPSG:{vector_data_crs_number}")
    target_crs = pyproj.CRS.from_user_input(local_crs)
    # always_xy=True keeps the (x, y) / (lon, lat) axis order regardless of
    # what the CRS definitions declare.
    transformer = pyproj.Transformer.from_crs(source_crs, target_crs, always_xy=True)
    x_local, y_local = transformer.transform(wgs_point.x, wgs_point.y)

    # Only the bounds of the buffer are used, i.e. the square that encloses a
    # circle of radius box_size / 2 around the transformed point.
    return Point(x_local, y_local).buffer(box_size / 2).bounds

In [None]:
# EPSG code of the point layer's CRS; used to reproject each point into the
# CRS of the product it is cropped from.
# NOTE(review): `to_epsg()` can return None when the CRS has no EPSG match —
# confirm the displayed value is an integer before running the download.
vector_data_crs_number = vector_points.crs.to_epsg()
vector_data_crs_number

In [None]:
# Band count of every exported image: one set of `bands` per time step.
total_bands = time_steps * len(bands)
total_bands

In [None]:
# for row in tqdm(vector_points.iterrows(), total=len(vector_points)):
def _read_product_bands(product, b_box, export_name):
    """Read every band in ``bands`` for one product, cropped to ``b_box``.

    Returns ``(band_arrays, profile, window_transform)`` when all bands yield a
    256x256 array and the first band contains at most 100 zero-valued pixels.
    Returns ``None`` otherwise, so the caller drops the product as a whole and
    never stacks an incomplete set of bands.
    """
    band_arrays = []
    profile = None
    window_transform = None
    for band in bands:
        with rio.open(product.assets[band].href) as src:
            window = rio.windows.from_bounds(*b_box, src.transform)
            array = src.read(1, window=window)

            if array.shape != (256, 256):
                print(f"Array shape is {array.shape} for {export_name}")
                return None

            # Only the first band is screened for zero-valued pixels
            # (presumably nodata — confirm); too many rejects the product.
            if band == bands[0] and np.count_nonzero(array == 0) > 100:
                return None

            window_transform = rio.windows.transform(window, transform=src.transform)
            profile = src.profile.copy()
        band_arrays.append(array)
    return band_arrays, profile, window_transform


def downlaod_image(row):
    """Download and write the multi-temporal image chip for one point.

    Args:
        row: ``(index, series)`` pair as produced by ``DataFrame.iterrows()``;
            ``series.geometry`` is the chip centre.

    The chip is named ``{index}_{prefix}.tif`` inside ``images_path``. Nothing
    is done when the geometry is missing, the name's stem is in ``skip_list``,
    or the file already exists. Otherwise products returned by ``find_prods``
    are read one by one until ``total_bands`` arrays have been collected, and
    the stack is written as a single raster.

    Errors are reported with ``print`` and followed by a 2 second pause; the
    function always returns ``None``.
    """
    # Built before the try block so the failure message can always name the
    # chip, whatever raised.
    export_name = f"{row[0]}_{prefix}.tif"
    try:
        wgs_point = row[1].geometry
        if wgs_point is None:
            return

        export_path = images_path / export_name
        if export_path.stem in skip_list or export_path.exists():
            return

        search_point = {"type": "Point", "coordinates": [wgs_point.x, wgs_point.y]}

        arrays = []
        profile = None
        window_transform = None
        # NOTE(review): products are stacked regardless of their CRS and the
        # output is georeferenced with the last accepted product's profile and
        # transform. Confirm that points covered by more than one projection
        # cannot mix grids.
        for product in find_prods(search_point):
            b_box = wgs_point_to_local_box(product, wgs_point, vector_data_crs_number)
            chip = _read_product_bands(product, b_box, export_name)
            if chip is None:
                continue
            band_arrays, profile, window_transform = chip
            arrays.extend(band_arrays)
            if len(arrays) == total_bands:
                break

        if len(arrays) != total_bands:
            print(f"Could not find {time_steps} images for {export_name}")
            time.sleep(2)
            return

        height, width = arrays[0].shape
        profile.update(
            {
                "height": height,
                "width": width,
                "transform": window_transform,
                "count": total_bands,
            }
        )
        with rio.open(export_path, "w", **profile) as dst:
            dst.write(np.array(arrays))
    except Exception as e:
        print(e)
        print(f"Failed to download {export_name}")
        time.sleep(2)
        return

In [None]:
# Reverse the row order so the download loop visits the points last-to-first.
# NOTE: this rebinds `vector_points` from itself, so re-running the cell flips
# the order back again.
vector_points = vector_points.iloc[::-1]

In [None]:
# Export-file stems that `downlaod_image` must not fetch, built from the row
# ids of the affected points.
skip_list = [
    f"{row_id}_OSM_80"
    for row_id in (
        614, 611, 579, 577, 554, 550, 536, 533, 531,
        526, 523, 509, 490, 478, 466, 465, 462,
    )
]

In [None]:
# Download one chip per point, sequentially.
# The previous `if i[0] not in skip` guard referenced a name that is never
# defined in this notebook (NameError on a fresh kernel). `downlaod_image`
# already skips chips listed in `skip_list` and chips that exist on disk, so
# no extra filtering is needed here.
for row in tqdm(vector_points.iterrows(), total=len(vector_points)):
    downlaod_image(row)

In [None]:
# for i in tqdm(vector_points.iterrows(), total=len(vector_points)):
#     if i[0] not in skip:
#         downlaod_image(i)

In [None]:
# with Pool(1) as p:
#     list(tqdm(p.imap(downlaod_image, vector_points.iterrows()), total=len(vector_points)))

In [None]:
def rasterise(args):
    """Write the signed distance-to-coastline label raster for one image chip.

    Args:
        args: ``None`` (nothing is done), or the tuple built by ``make_label``:
            ``(label_export_path, coastline_lines_clipped, clipped_gdf,
            profile, array_transform, max_dist)``.

    For every pixel centre the distance to the nearest geometry in
    ``coastline_lines_clipped`` is clamped to ``max_dist`` and square-rooted.
    The value is stored negated when the pixel centre intersects a geometry of
    ``clipped_gdf`` and positive otherwise. With no coastline geometry every
    pixel gets ``sqrt(max_dist)`` (with the same sign rule).

    The array is written to ``label_export_path`` as a single-band float32,
    LZW-compressed raster using ``profile`` for everything else. An existing
    label file is left untouched. ``profile`` and ``clipped_gdf`` are not
    modified.

    Returns:
        None.
    """
    if args is None:
        return
    (
        label_export_path,
        coastline_lines_clipped,
        clipped_gdf,
        profile,
        array_transform,
        max_dist,
    ) = args
    if label_export_path.exists():
        return

    # Function-scope import: the notebook only imports `transform` from
    # shapely.ops at the top.
    from shapely.ops import unary_union

    height, width = profile["height"], profile["width"]

    # Everything below is identical for all pixels, so build it once instead
    # of once per pixel. The distance to the union of the lines equals the
    # minimum distance over the individual lines.
    coast_parts = [
        geom
        for geom in coastline_lines_clipped["geometry"]
        if geom is not None and not geom.is_empty
    ]
    coast = unary_union(coast_parts) if coast_parts else None

    # buffer(0) is kept from the original (presumably to clean up invalid
    # polygons — confirm); it is applied to a derived series so the caller's
    # GeoDataFrame is left unchanged.
    polygon_parts = [
        geom
        for geom in clipped_gdf["geometry"].buffer(0)
        if geom is not None and not geom.is_empty
    ]
    polygons = unary_union(polygon_parts) if polygon_parts else None

    array = np.zeros((height, width), dtype=np.float32)
    for row in range(height):
        for col in range(width):
            # +0.5 moves from the pixel's corner to its centre.
            x, y = array_transform * (col + 0.5, row + 0.5)
            point = Point(x, y)

            if coast is None:
                distance = max_dist
            else:
                distance = min(coast.distance(point), max_dist)
            value = sqrt(distance)

            inside = polygons is not None and polygons.intersects(point)
            array[row, col] = -value if inside else value

    # Work on a copy so the caller's profile keeps describing the source image.
    out_profile = dict(profile)
    out_profile.update({"count": 1, "dtype": "float32", "compress": "lzw"})

    with rio.open(label_export_path, "w", **out_profile) as dst:
        dst.write(array, 1)

In [None]:
def make_label(image, coastline_lines, ocean_polygons, max_dist=50):
    """Gather the inputs ``rasterise`` needs to build the label for ``image``.

    Args:
        image: Path of an exported image chip; the label gets the same file
            name inside ``labels_path``.
        coastline_lines: GeoDataFrame of coastline line geometries.
        ocean_polygons: GeoDataFrame of polygon geometries.
        max_dist: Forwarded unchanged as the last element of the result.

    Returns:
        ``(label_export_path, coastline_lines_clipped, clipped_gdf, profile,
        array_transform, max_dist)``, or ``None`` when the label file already
        exists or any step raised (the error is printed, not re-raised).
    """
    try:
        target_path = labels_path / image.name
        if target_path.exists():
            return None

        with rio.open(image) as src:
            pixel_size = src.transform[0]
            bounds = src.bounds
            footprint = box(*bounds)
            image_crs = src.crs
            image_profile = src.profile.copy()

        # Express the image footprint in the polygons' CRS so clipping happens
        # in the vector data's own coordinates.
        # NOTE(review): the same footprint also clips `coastline_lines`, which
        # assumes both layers share one CRS — confirm.
        to_vector_crs = Transformer.from_crs(
            image_crs, ocean_polygons.crs, always_xy=True
        )
        footprint_vector_crs = transform(to_vector_crs.transform, footprint)

        # Clip, merge into a single feature, then bring into the image CRS.
        ocean_clipped = gpd.clip(ocean_polygons, footprint_vector_crs)
        ocean_in_image_crs = ocean_clipped.dissolve().to_crs(image_crs)

        coast_clipped = gpd.clip(coastline_lines, footprint_vector_crs)
        coast_in_image_crs = coast_clipped.to_crs(image_crs)

        # Pixel grid anchored at the image's top-left corner; the negative y
        # scale makes rows run downwards.
        west, _south, _east, north = bounds
        pixel_grid = Affine.translation(west, north) * Affine.scale(
            pixel_size, -pixel_size
        )

        return (
            target_path,
            coast_in_image_crs,
            ocean_in_image_crs,
            image_profile,
            pixel_grid,
            max_dist,
        )
    except Exception as e:
        print(e)
        print(f"Failed to make label for {image.name}")
        return None

In [None]:
# All exported image chips of the current run (file names end in
# `{prefix}.tif`); the trailing expression displays how many were found.
images = list(images_path.glob(f"*{prefix}.tif"))
len(images)

In [None]:
# Polygon layer passed to `make_label` as `ocean_polygons`.
ocean_polygons = gpd.read_file(OSM_ocean_path)

In [None]:
# Line layer passed to `make_label` as `coastline_lines`.
coastline = gpd.read_file(OSM_coastline_path)

In [None]:
# One entry per image: the argument tuple for `rasterise`, or None when
# `make_label` skipped the image or failed on it.
args_list = [
    make_label(image, coastline_lines=coastline, ocean_polygons=ocean_polygons)
    for image in tqdm(images)
]

In [None]:
# Rasterise the labels in a process pool. `imap` is lazy, so the result
# iterator is drained with list(); tqdm reports progress while it is consumed.
with Pool() as worker_pool:
    label_jobs = worker_pool.imap(rasterise, args_list)
    list(tqdm(label_jobs, total=len(args_list)))

In [None]:
# Scratch check: count the `*Aus.tif` files in another images folder.
# NOTE(review): absolute, machine-specific path; `t` and `files` are not used
# by any other cell shown in this notebook.
t = Path("/Users/Nick/Desktop/S2Coastline DL/training data/images_2_3_4_8")
files = list(t.glob("*Aus.tif"))
len(files)