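# Generate Labeled Training Images from NAIP and GM-SEUS

This notebook cuts each whole labeled image in `labelImgsWhole` (files named `id<arrayID>.tif`) into 256×256 training tiles. Tile centres are sampled at random inside the matching GM-SEUS array polygon. Each tile is saved twice: bands 1–4 (assumed R, G, B, NIR) go to the model-input image, and band 5 goes to the label mask.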

In [31]:
# Import necessary libraries
import os
import random
import rasterio
import geopandas as gpd
import numpy as np
from shapely.geometry import Point
from rasterio.windows import Window
from rasterio.mask import mask
from rasterio.warp import calculate_default_transform, reproject, Resampling


# Load config file
def load_config(filename):
    """Parse a simple ``key=value`` text file into a dict.

    Blank lines and lines starting with ``#`` are ignored. Only the first ``=``
    separates key from value, so values may themselves contain ``=``.
    Whitespace around keys and values is stripped. A value containing ``.`` is
    parsed as float and any other value as int. Anything that fails to parse is
    kept as a string.

    Parameters
    ----------
    filename : str
        Path to the config file.

    Returns
    -------
    dict
        Mapping of config keys to int, float or str values.

    Raises
    ------
    ValueError
        If a non-blank, non-comment line contains no ``=``.
    """
    config = {}
    with open(filename, 'r') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            # Tolerate blank lines and comments instead of crashing on unpacking
            if not line or line.startswith('#'):
                continue
            # partition splits on the first '=' only, so 'url=a=b' keeps 'a=b' as the value
            key, sep, value = line.partition('=')
            if not sep:
                raise ValueError(f"{filename}:{line_no}: expected 'key=value', got {raw_line!r}")
            key, value = key.strip(), value.strip()
            # Try to convert to numeric values if possible
            try:
                value = float(value) if '.' in value else int(value)
            except ValueError:
                pass  # Leave as string if not a number
            config[key] = value
    return config

## Set Paths and Variables

In [67]:
# Set folder paths. The project root can be overridden with the BIGPANEL_WD
# environment variable; otherwise the original machine-specific path is used.
wd = os.environ.get('BIGPANEL_WD', r'S:\Users\stidjaco\R_files\BigPanel')
downloaded_path = os.path.join(wd, 'Data', 'Downloaded')
derived_path = os.path.join(wd, 'Data', 'Derived')
derivedTemp_path = os.path.join(derived_path, 'intermediateProducts')
figure_path = os.path.join(wd, 'Figures')

# Final GM-SEUS array polygons
gmseusArraysFinalPath = os.path.join(derived_path, 'GMSEUS', 'GMSEUS_Arrays_Final.shp')

# Folder of whole labeled images (named "id<arrayID>.tif")
labelImgsWholePath = os.path.join(derived_path, 'labelImgsWhole')

# Output folders for the tiles. Path components are joined with os.path.join
# so separators are consistent on every OS.
imagesPath = os.path.join(derived_path, 'GMSEUS', 'LabeledImages', 'images')  # inputs
masksPath = os.path.join(derived_path, 'GMSEUS', 'LabeledImages', 'masks')    # targets

# Load the config from the text file
config = load_config(os.path.join(wd, 'Code', 'config.txt'))

# Seed Python's RNG so the random tile locations are reproducible on Restart & Run All
randomSeed = 42
random.seed(randomSeed)

# Image resolution (metres per pixel; the tile area below is computed in square metres)
res = 0.6

# Tile dimensions in pixels: 256x256 tiles
dim = 256

# Tile area proportion: the fraction of each array's area to cover with tiles, which sets
# the number of tiles. With 0.2, a 1 km2 array gets tiles totalling ~0.2 km2; small
# arrays always get at least one tile.
tileAreaProp = 0.2

## Helper Functions

In [68]:
# Function to check if folder exists, if not create it
def checkFolder(folder):
    """Create ``folder`` (and any missing parent folders) if it does not already exist.

    ``exist_ok=True`` avoids the race between a separate existence check and the
    creation. It also makes repeated calls a no-op.

    Parameters
    ----------
    folder : str
        Directory path to ensure exists.

    Raises
    ------
    FileExistsError
        If ``folder`` exists but is not a directory.
    """
    os.makedirs(folder, exist_ok=True)

# Process Imagery into Images and Masks

## Setup

In [69]:
# Make sure the tile output folders exist
checkFolder(imagesPath)
checkFolder(masksPath)

# Read the GM-SEUS array polygons. The tiling loop matches images to polygons by
# 'arrayID' and sizes the number of tiles per array from 'totArea', so fail early
# if either column is missing.
gmseusArrays = gpd.read_file(gmseusArraysFinalPath)
missingCols = {'arrayID', 'totArea'} - set(gmseusArrays.columns)
assert not missingCols, f"GM-SEUS arrays file is missing columns: {sorted(missingCols)}"

# Area of one dim x dim tile in square metres (res is metres per pixel). This is
# used to convert each array's area into a number of tiles.
tileArea = dim * dim * res * res

In [None]:
# Loop through each whole labeled image and extract randomly placed dim x dim tiles.
# Tiles keep each image's native CRS; no reprojection is performed.

def _random_point_in(geom, max_tries=10_000):
    """Return a random point inside ``geom`` using rejection sampling over its bounds.

    An unbounded loop would never terminate for a degenerate or extremely thin
    polygon. So after ``max_tries`` misses, this falls back to
    ``geom.representative_point()``, which shapely guarantees lies within the geometry.
    """
    minx, miny, maxx, maxy = geom.bounds  # loop-invariant, computed once
    for _ in range(max_tries):
        candidate = Point(random.uniform(minx, maxx), random.uniform(miny, maxy))
        if geom.contains(candidate):
            return candidate
    return geom.representative_point()


def _tile_window(src, x, y, size):
    """Return a ``size`` x ``size`` Window centred on map coordinate (x, y) of ``src``.

    The window is shifted (never shrunk) so it lies fully inside the raster
    whenever the raster is at least ``size`` pixels in each dimension. A point that
    lies inside the raster stays inside the shifted window. If the raster is
    smaller than ``size``, the window starts at 0 and overhangs the raster, so
    it must be read with ``boundless=True``.
    """
    row, col = src.index(x, y)
    col_off = max(0, min(col - size // 2, src.width - size))
    row_off = max(0, min(row - size // 2, src.height - size))
    return Window(col_off, row_off, size, size)


def _write_tile(path, data, crs, transform):
    """Write a (bands, rows, cols) array to ``path`` as a GeoTIFF with the given CRS/transform."""
    count, height, width = data.shape
    with rasterio.open(
        path,
        'w',
        driver='GTiff',
        height=height,
        width=width,
        count=count,
        dtype=data.dtype,
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(data)


for image_filename in os.listdir(labelImgsWholePath):
    if not image_filename.endswith('.tif'):
        continue  # Skip non-TIFF files

    # Extract array ID from a filename of the form "id<arrayID>.tif"
    array_id = image_filename.replace("id", "").replace(".tif", "")
    try:
        array_id = int(array_id)
    except ValueError:
        continue  # Skip if the filename isn't in expected format

    image_path = os.path.join(labelImgsWholePath, image_filename)
    with rasterio.open(image_path) as src:  # closed automatically when the block exits
        print(f"Processing {image_filename} - Original CRS: {src.crs}")

        # Use the image's existing CRS (don't force-set it); skip images without one
        toCRS = src.crs
        if not toCRS:
            print(f"Warning: No CRS found for image {image_filename}. Skipping.")
            continue

        # Find the corresponding polygon for this array ID
        array_row = gmseusArrays[gmseusArrays['arrayID'] == array_id]
        if array_row.empty:
            print(f"Warning: No matching polygon found for array ID {array_id}. Skipping.")
            continue

        # Transform the polygon into the image CRS so sampled points map onto pixels
        array_row = array_row.to_crs(toCRS)
        array_geom = array_row.geometry.values[0]
        if array_geom is None or array_geom.is_empty:
            print(f"Warning: Empty geometry for array ID {array_id}. Skipping.")
            continue

        # numTiles = (totArea / tileArea) * tileAreaProp, with a minimum of one tile.
        # NOTE(review): assumes totArea is in the same units as tileArea (m^2) — confirm.
        totArea = array_row['totArea'].values[0]
        numTiles = max(1, int((totArea / tileArea) * tileAreaProp))

        # Only rasters smaller than one tile need a padded (boundless) read;
        # otherwise _tile_window keeps the window inside the raster.
        pad = src.width < dim or src.height < dim

        for tile_idx in range(numTiles):
            random_point = _random_point_in(array_geom)
            window = _tile_window(src, random_point.x, random_point.y, dim)
            transform = src.window_transform(window)

            # Assuming bands 1-4 are R, G, B, N and the mask is stored in band 5
            img_bands = src.read([1, 2, 3, 4], window=window, boundless=pad, fill_value=0)
            mask_band = src.read([5], window=window, boundless=pad, fill_value=0)

            img_outfile = os.path.join(imagesPath, f"id{array_id}_tile{tile_idx}.tif")
            mask_outfile = os.path.join(masksPath, f"id{array_id}_tile{tile_idx}.tif")
            _write_tile(img_outfile, img_bands, toCRS, transform)
            _write_tile(mask_outfile, mask_band, toCRS, transform)

            print(f"Saved image: {img_outfile} and mask: {mask_outfile}")


Processing id12.tif - Original CRS: EPSG:26911
Saved image: S:\Users\stidjaco\R_files\BigPanel\Data\Derived\GMSEUS/LabeledImages/images\id12_tile0.tif and mask: S:\Users\stidjaco\R_files\BigPanel\Data\Derived\GMSEUS/LabeledImages/masks\id12_tile0.tif
Processing id14.tif - Original CRS: EPSG:26911
Saved image: S:\Users\stidjaco\R_files\BigPanel\Data\Derived\GMSEUS/LabeledImages/images\id14_tile0.tif and mask: S:\Users\stidjaco\R_files\BigPanel\Data\Derived\GMSEUS/LabeledImages/masks\id14_tile0.tif
Processing id16.tif - Original CRS: EPSG:26911
Saved image: S:\Users\stidjaco\R_files\BigPanel\Data\Derived\GMSEUS/LabeledImages/images\id16_tile0.tif and mask: S:\Users\stidjaco\R_files\BigPanel\Data\Derived\GMSEUS/LabeledImages/masks\id16_tile0.tif
Processing id2.tif - Original CRS: EPSG:26911
Saved image: S:\Users\stidjaco\R_files\BigPanel\Data\Derived\GMSEUS/LabeledImages/images\id2_tile0.tif and mask: S:\Users\stidjaco\R_files\BigPanel\Data\Derived\GMSEUS/LabeledImages/masks\id2_tile0.tif