In [20]:
import torch

In [1]:
from dataclasses import dataclass
import os
import glob
from tempfile import TemporaryDirectory
from typing import List, Any, Dict

from shapely.geometry import box, mapping
import rasterio
from rasterio.warp import reproject, Resampling
import pyproj
from osgeo import gdal

from pystac_client import Client
import planetary_computer as pc

In [19]:
chip_paths = []
txt_files = glob.glob("./Data/ms-dataset-chips/**/s2/**/LabelWater.tif", recursive=True)
for file in txt_files:
    chip_paths.append(file)
STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
catalog = Client.open(STAC_API)

In [18]:
# Specify the path to the parent directory containing nested folders
parent_dir = "./Data/ms-dataset-chips"
chip_paths = []
# Iterate over each folder in the parent directory
for folder in os.listdir(parent_dir):
    folder_path = os.path.join(parent_dir, folder)
    if os.path.isdir(folder_path):
        # Iterate over files in each folder
        for file in os.listdir(folder_path):
            folder_path = os.path.join(folder_path, file)
            for i in os.listdir(folder_path):
                if file.endswith("WaterLabel.tif"):
                    chip_paths.append(os.path.join(folder_path, file))
STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
catalog = Client.open(STAC_API)

FileNotFoundError: [WinError 3] Das System kann den angegebenen Pfad nicht finden: './Data/ms-dataset-chips\\0c7daa97-37f6-4862-867f-b3843f298d9e\\s1\\s2'

In [20]:
STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
catalog = Client.open(STAC_API)

In [21]:
@dataclass
class ChipInfo:
    """
    Holds information about a training chip, including geospatial info for coregistration
    """

    path: str
    prefix: str
    crs: Any
    shape: List[int]
    transform: List[float]
    bounds: rasterio.coords.BoundingBox
    footprint: Dict[str, Any]


def get_footprint(bounds, crs):
    """Gets a GeoJSON footprint (in epsg:4326) from rasterio bounds and CRS"""
    transformer = pyproj.Transformer.from_crs(crs, "epsg:4326", always_xy=True)
    minx, miny = transformer.transform(bounds.left, bounds.bottom)
    maxx, maxy = transformer.transform(bounds.right, bounds.top)
    return mapping(box(minx, miny, maxx, maxy))


def get_chip_info(chip_path):
    """Gets chip info from a GeoTIFF file"""
    with rasterio.open(chip_path) as ds:
        chip_crs = ds.crs
        chip_shape = ds.shape
        chip_transform = ds.transform
        chip_bounds = ds.bounds

    # Use the first part of the chip filename as a prefix
    prefix = os.path.basename(chip_path).split(".")[0]

    return ChipInfo(
        path=chip_path,
        prefix=prefix,
        crs=chip_crs,
        shape=chip_shape,
        transform=chip_transform,
        bounds=chip_bounds,
        footprint=get_footprint(chip_bounds, chip_crs),
    )

In [22]:
def reproject_to_chip(
    chip_info, input_path, output_path, resampling=Resampling.nearest
):
    """
    Reproject a raster at input_path to chip_info, saving to output_path.

    Use Resampling.nearest for classification rasters. Otherwise use something
    like Resampling.bilinear for continuous data.
    """
    with rasterio.open(input_path) as src:
        kwargs = src.meta.copy()
        kwargs.update(
            {
                "crs": chip_info.crs,
                "transform": chip_info.transform,
                "width": chip_info.shape[1],
                "height": chip_info.shape[0],
                "driver": "GTiff",
            }
        )

        with rasterio.open(output_path, "w", **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=chip_info.transform,
                    dst_crs=chip_info.crs,
                    resampling=Resampling.nearest,
                )

In [23]:
def write_vrt(items, asset_key, dest_path):
    """Write a VRT with hrefs extracted from a list of items for a specific asset."""
    hrefs = [pc.sign(item.assets[asset_key].href) for item in items]
    vsi_hrefs = [f"/vsicurl/{href}" for href in hrefs]
    gdal.BuildVRT(dest_path, vsi_hrefs).FlushCache()


def create_chip_aux_file(
    chip_info, collection_id, asset_key, file_name, resampling=Resampling.nearest
):
    """
    Write an auxiliary chip file.

    The auxiliary chip file includes chip_info for the Collection and Asset, and is
    saved in the same directory as the original chip with the given file_name.
    """
    output_path = os.path.join(
        os.path.dirname(chip_info.path), f"{chip_info.prefix}_{file_name}"
    )
    if os.path.exists(output_path):
        print(f"  ... {chip_info.prefix}_{file_name} already exists")
        return
    search = catalog.search(collections=[collection_id], intersects=chip_info.footprint)
    items = list(search.get_items())
    with TemporaryDirectory() as tmp_dir:
        vrt_path = os.path.join(tmp_dir, "source.vrt")
        write_vrt(items, asset_key, vrt_path)
        reproject_to_chip(chip_info, vrt_path, output_path, resampling=resampling)
    return output_path

In [24]:
# Only permanent water
aux_file_params = [("jrc-gsw", "occurrence", "jrc-gsw-occurrence.tif", Resampling.nearest)]

# Iterate over the chips and generate all aux input files.
chip_paths = chip_paths
count = len(chip_paths)
print(count)
for i, chip_path in enumerate(chip_paths):
    print(f"({i+1} of {count})")
    chip_info = get_chip_info(chip_path)
    for collection_id, asset_key, file_name, resampling_method in aux_file_params:
        print(f"  ... Creating chip data for {collection_id} {asset_key}")
        create_chip_aux_file(
            chip_info, collection_id, asset_key, file_name, resampling=resampling_method
        )

900
(1 of 900)
  ... Creating chip data for jrc-gsw occurrence




(2 of 900)
  ... Creating chip data for jrc-gsw occurrence
(3 of 900)
  ... Creating chip data for jrc-gsw occurrence
(4 of 900)
  ... Creating chip data for jrc-gsw occurrence
(5 of 900)
  ... Creating chip data for jrc-gsw occurrence
(6 of 900)
  ... Creating chip data for jrc-gsw occurrence
(7 of 900)
  ... Creating chip data for jrc-gsw occurrence
(8 of 900)
  ... Creating chip data for jrc-gsw occurrence
(9 of 900)
  ... Creating chip data for jrc-gsw occurrence
(10 of 900)
  ... Creating chip data for jrc-gsw occurrence
(11 of 900)
  ... Creating chip data for jrc-gsw occurrence
(12 of 900)
  ... Creating chip data for jrc-gsw occurrence
(13 of 900)
  ... Creating chip data for jrc-gsw occurrence
(14 of 900)
  ... Creating chip data for jrc-gsw occurrence
(15 of 900)
  ... Creating chip data for jrc-gsw occurrence
(16 of 900)
  ... Creating chip data for jrc-gsw occurrence
(17 of 900)
  ... Creating chip data for jrc-gsw occurrence
(18 of 900)
  ... Creating chip data for jrc-gsw