In [1]:
from osgeo_utils.gdal_pansharpen import gdal_pansharpen
import rasterio
from rasterio.enums import Resampling
import cv2
import gc
import numpy as np
from osgeo import gdal
import rioxarray
import xarray as xr

In [2]:
%load_ext lab_black

The next block of code was taken from Thomas Wang's GitHub repository (https://github.com/ThomasWangWeiHong/Simple-Pansharpening-Algorithms/blob/master/Simple_Pansharpen.py). For the purpose of this project, only the 'Simple Mean' method was considered.

In [3]:
def pansharpen_simple_mean(m, pan, psh):
    """
    Inputs:
    - m: File path of multispectral image to undergo pansharpening
    - pan: File path of panchromatic image to be used for pansharpening
    - psh: File path of pansharpened multispectral image to be written to file

    """

    with rasterio.open(m) as f:
        metadata_ms = f.profile
        img_ms = np.transpose(
            f.read(tuple(np.arange(metadata_ms["count"]) + 1)), [1, 2, 0]
        )

    with rasterio.open(pan) as g:
        metadata_pan = g.profile
        img_pan = g.read(1)

    ms_to_pan_ratio = metadata_ms["transform"][0] / metadata_pan["transform"][0]
    rescaled_ms = cv2.resize(
        img_ms,
        dsize=None,
        fx=ms_to_pan_ratio,
        fy=ms_to_pan_ratio,
        interpolation=cv2.INTER_CUBIC,
    ).astype(metadata_ms["dtype"])

    if img_pan.shape[0] < rescaled_ms.shape[0]:
        ms_row_bigger = True
        rescaled_ms = rescaled_ms[: img_pan.shape[0], :, :]
    else:
        ms_row_bigger = False
        img_pan = img_pan[: rescaled_ms.shape[0], :]

    if img_pan.shape[1] < rescaled_ms.shape[1]:
        ms_column_bigger = True
        rescaled_ms = rescaled_ms[:, : img_pan.shape[1], :]
    else:
        ms_column_bigger = False
        img_pan = img_pan[:, : rescaled_ms.shape[1]]

    del img_ms
    gc.collect()

    if ms_row_bigger == True and ms_column_bigger == True:
        img_psh = np.zeros(
            (img_pan.shape[0], img_pan.shape[1], rescaled_ms.shape[2]),
            dtype=metadata_pan["dtype"],
        )
    elif ms_row_bigger == False and ms_column_bigger == True:
        img_psh = np.zeros(
            (rescaled_ms.shape[0], img_pan.shape[1], rescaled_ms.shape[2]),
            dtype=metadata_pan["dtype"],
        )
        metadata_pan["height"] = rescaled_ms.shape[0]
    elif ms_row_bigger == True and ms_column_bigger == False:
        img_psh = np.zeros(
            (img_pan.shape[0], rescaled_ms.shape[1], rescaled_ms.shape[2]),
            dtype=metadata_pan["dtype"],
        )
        metadata_pan["width"] = rescaled_ms.shape[1]
    else:
        img_psh = np.zeros((rescaled_ms.shape), dtype=metadata_pan["dtype"])
        metadata_pan["height"] = rescaled_ms.shape[0]
        metadata_pan["width"] = rescaled_ms.shape[1]

    # This is the core of the function where the simple mean is applied
    for band in range(rescaled_ms.shape[2]):
        img_psh[:, :, band] = 0.5 * (rescaled_ms[:, :, band] + img_pan)

    del img_pan, rescaled_ms
    gc.collect()

    metadata_pan["count"] = img_psh.shape[2]
    with rasterio.open(psh, "w", **metadata_pan) as dst:
        dst.write(np.transpose(img_psh, [2, 0, 1]))

    return img_psh

Here I was trying to get both pansharpening functions into just a big one, but it only works for simple mean since when I try to add more parameters, I get "TypeError: not a sequence".
Should I drop this idea and only have the two functions separate?

In [4]:
# If user selects Simple Mean pansharpening, Thomas Wang's method is applied. Otherwise, gdal_pansharpen() is used
def pansharpen(
    pan_name,
    spectral_names,
    dst_filename,
    band_nums=None,
    weights=None,
    resampling=None,
    spat_adjust=None,
    bitdepth=None,
    nodata_value=False,
    simple_mean=False,
):
    """
    Inputs:
    - pan_name: File path of the higher resolution image to be used for pansharpening
    - spectral_names: File path of the coarser image to undergo pansharpening
    - band_nums: bands in the coarser image to undergo pansharpening when not applied to the whole dataset
    - weights: Specify a weight for the computation of the pseudo panchromatic value. There must be as many -w switches as input spectral bands
    - dst_filename: File path of pansharpened dataset to be written to file
    - resampling: Select a resampling algorithm (nearest, bilinear, cubic [default], cubicspline, lanczos, average)
    - spat_adjust: Select behavior when bands have not the same extent (union [default], intersection, none, nonewithoutwarning)
    - bitdepth: Specify the bit depth of the panchromatic and spectral bands (e.g. 12). If not specified, the NBITS metadata item from the panchromatic band will be used if it exists.
    - nodata_value: Specify nodata value for bands. Used for the resampling and pan-sharpening computation itself. If not set, deduced from the input bands, provided they have a consistent setting.
    - simple_mean: if True, pansharpening is performed using the pansharpen_simple_mean() function. Otherwise, gdal_pansharpen() is selected

    """
    if simple_mean == True:
        pansharpen_simple_mean(spectral_names, pan_name, dst_filename)
    else:
        gdal_pansharpen(
            pan_name,
            spectral_names,
            band_nums,
            weights,
            dst_filename,
            resampling,
            spat_adjust,
            bitdepth,
            nodata_value,
        )

    return dst_filename

This function is writing the coordinate system but the rasters are displayed in the middle of the Pacific Ocean on ArcGIS. Is there something I am missing?

In [5]:
# When the user wants to stacks the multispectral with the panchromatic bands into one single raster dataset
def stack_bands(pan_name, spectral_names, dst_filename):
    """
    Inputs:
    - pan_name: File path of the higher resolution image (panchromatic band)
    - spectral_names: File path of the pansharpened image
    - dst_filename: File path of the stacked dataset
    """
    bands = list()
    panchromatic = rioxarray.open_rasterio(pan_name)
    multispectral = rioxarray.open_rasterio(spectral_names)

    # All the bands are added to a list
    bands.append(panchromatic.sel(band=1))

    for file in range(multispectral.rio.count):
        bands.append(multispectral.sel(band=file + 1))

    # The bands are stacked into a single array
    bands_array = xr.DataArray(bands)
    stack = bands_array.stack()

    # Spatial reference
    crs = int(multispectral.rio.crs.to_proj4()[11:])
    stack.rio.write_crs(crs, inplace=True).rio.set_spatial_dims(
        y_dim=stack.dims[1], x_dim=stack.dims[2]
    ).rio.write_coordinate_system(inplace=True)

    # The data array is converted to dataset
    stack.rio.to_raster(dst_filename)

In [6]:
# List of resampling algorithms included in rasterio
resampling_techniques = [
    Resampling.nearest,
    Resampling.bilinear,
    Resampling.cubic,
    Resampling.cubic_spline,
    Resampling.lanczos,
    Resampling.average,
    Resampling.mode,
    Resampling.gauss,
    Resampling.max,
    Resampling.min,
    Resampling.med,
    Resampling.q1,
    Resampling.q3,
    Resampling.sum,
    Resampling.rms,
]

In [7]:
# Generates raster overviews
def create_pyramids(ds, resampling):
    """
    Inputs:
    - ds: File path of pansharpened dataset
    - resampling: Select a resampling algorithm
    """
    factors = [2, 4, 8, 16]
    for i in resampling_techniques:
        if resampling == str(resampling_techniques[i])[11:]:
            dst = rasterio.open(ds, "r+")
            dst.build_overviews(factors, resampling_techniques[i])
            dst.update_tags(
                ns="rio_overview", resampling=str(resampling_techniques[i])[11:]
            )
            dst.close()
            break  # exit out of loop here

## Resampling
In addition to the pansharpening tool, this project aims to create an up-and-downsampling tool where the users can change (resample) the cell size of their raster datasets without external bands.

The next cell was based on https://rasterio.readthedocs.io/en/latest/topics/resampling.html and https://pygis.io/docs/e_raster_resample.html. However, the function was unable to write the output raster successfully.

In [8]:
def resample(factor, ds, resampling):
    """
    Inputs:
    - factor: numeric entry to upsample/downsample the raster
    - ds: File path of raster dataset
    - resampling: Select a resampling algorithm

    """
    upscale_factor = factor

    for i in resampling_techniques:
        if r_technique == str(resampling_techniques[i])[11:]:

            with rasterio.open(ds) as dataset:

                # resample data to target shape
                data = dataset.read(
                    out_shape=(
                        dataset.count,
                        int(dataset.height * upscale_factor),
                        int(dataset.width * upscale_factor),
                    ),
                    resampling=resampling_techniques[i],
                )

                # scale image transform
                transform = dataset.transform * dataset.transform.scale(
                    (dataset.width / data.shape[-1]), (dataset.height / data.shape[-2])
                )

                ## Write outputs
                # set properties for output
                dst_kwargs = dataset.meta.copy()

                ### the error source is here
                dst_kwargs.update(
                    {
                        "crs": dataset.crs,
                        "transform": transform,
                        "width": data.shape[-1],
                        "height": data.shape[-2],
                        "nodata": 0,
                    }
                )

                with rasterio.open("resampled_raster.tif", "w", **dst_kwargs) as dst:
                    # iterate through bands
                    for i in range(data.shape[0]):
                        dst.write(data[i].astype(rasterio.uint32), i + 1)

                break

A new tool, gdal.Translate(), was found to create the resampled raster successfully. This tool was i) simplified to receive just a few parameters and ii) complemented with the create_pyramids() function. 

For now, this function is writing an output raster with the cells resized, but it's creating the stats for the input raster instead of the output, and ArcMap shows the pyramids are absent.

In [9]:
def resize(
    output_name,
    ds,
    Res,
    resampling,
    of,
    ot=None,
    statistics=False,
    pyramids=False,
    resampling_pyramids=None,
):
    """
    Inputs:
    - output_name: File path to the output raster resized
    - ds: File path of input dataset
    - Res: Set the size of the cells
    - resampling: Select a resampling algorithm (nearest [default], bilinear, cubic, cubicspline, lanczos, average, rms, mode)
    - of: output format
    - ot: output type (Byte, UInt16, Int16, UInt32, Int32, Float32, Float64, CInt16, CInt32, CFloat32 or CFloat64)
    - statistics: whether to calculate statistics
    - pyramids: whether to calculate pyramids
    - resampling_pyramids: resampling algorithm if pyramids == True

    """

    gdal.Translate(
        output_name,
        ds,
        xRes=Res,
        yRes=Res,
        resampleAlg=resampling,
        format=of,
        stats=statistics,
    )

    if pyramids == True:
        create_pyramids(output_name, resampling_pyramids)