In [1]:
from osgeo_utils.gdal_pansharpen import gdal_pansharpen
import rasterio
from rasterio.enums import Resampling
import cv2
import gc
import numpy as np

The next code cell is based on Thomas Wang's GitHub repository (https://github.com/ThomasWangWeiHong/Simple-Pansharpening-Algorithms/blob/master/Simple_Pansharpen.py). Only the 'Simple Mean' method is used in this project: each output band is the pixel-wise average of the upsampled multispectral band and the panchromatic band.

In [2]:
def pansharpen_simple_mean(m, pan, psh):
    """Pansharpen a multispectral image with the 'Simple Mean' method.

    The multispectral image is upsampled (bicubic) to the panchromatic pixel
    size. Both rasters are cropped to their common row/column extent. Every
    output pixel is then the mean of the upsampled band value and the
    panchromatic value.

    Inputs:
    - m: File path of multispectral image to undergo pansharpening
    - pan: File path of panchromatic image to be used for pansharpening
    - psh: File path of pansharpened multispectral image to be written to file

    Returns:
    - numpy array of shape (rows, cols, bands) in the panchromatic dtype. The
      same data is written to `psh` using the panchromatic profile, with
      count/height/width updated to match.
    """
    with rasterio.open(m) as src_ms:
        metadata_ms = src_ms.profile
        # rasterio reads (bands, rows, cols); OpenCV works on (rows, cols, bands).
        img_ms = np.transpose(src_ms.read(), [1, 2, 0])

    with rasterio.open(pan) as src_pan:
        metadata_pan = src_pan.profile
        img_pan = src_pan.read(1)

    # transform[0] is the x pixel size, so this is how many pan pixels span one
    # multispectral pixel.
    ms_to_pan_ratio = metadata_ms['transform'][0] / metadata_pan['transform'][0]
    rescaled_ms = cv2.resize(img_ms, dsize=None, fx=ms_to_pan_ratio, fy=ms_to_pan_ratio,
                             interpolation=cv2.INTER_CUBIC).astype(metadata_ms['dtype'])
    # OpenCV drops the channel axis for single-channel input. Restore it so the
    # band loop below also handles one-band images.
    if rescaled_ms.ndim == 2:
        rescaled_ms = rescaled_ms[:, :, np.newaxis]

    del img_ms; gc.collect()

    # Crop both rasters to their common extent. The crop is anchored at the
    # top-left pixel, so the panchromatic geotransform stays valid for the output.
    rows = min(img_pan.shape[0], rescaled_ms.shape[0])
    cols = min(img_pan.shape[1], rescaled_ms.shape[1])
    img_pan = img_pan[:rows, :cols]
    rescaled_ms = rescaled_ms[:rows, :cols, :]
    n_bands = rescaled_ms.shape[2]

    img_psh = np.zeros((rows, cols, n_bands), dtype=metadata_pan['dtype'])

    # Core of the Simple Mean method. The sum is done in float64 because adding
    # two integer rasters (e.g. uint8/uint16) in their native dtype would wrap
    # around on overflow before the halving.
    for band in range(n_bands):
        img_psh[:, :, band] = 0.5 * (rescaled_ms[:, :, band].astype(np.float64) + img_pan)

    del img_pan, rescaled_ms; gc.collect()

    metadata_pan.update(count=n_bands, height=rows, width=cols)
    with rasterio.open(psh, 'w', **metadata_pan) as dst:
        dst.write(np.transpose(img_psh, [2, 0, 1]))

    return img_psh

In [3]:
# If the user selects Simple Mean pansharpening, Thomas Wang's method is applied; otherwise, GDAL's gdal_pansharpen is used
def pansharpen(pan_name, spectral_names, band_nums, weights, dst_filename, resampling, spat_adjust, bitdepth, nodata_value, simple_mean):
    """Pansharpen a dataset with either the Simple Mean method or GDAL.

    Inputs:
    - pan_name: File path of the higher resolution image to be used for pansharpening
    - spectral_names: File path of the coarser image to undergo pansharpening. For the
      GDAL path, a list of file paths is also accepted.
    - band_nums: bands in the coarser image to undergo pansharpening when not applied to the whole dataset
    - weights: Specify a weight for the computation of the pseudo panchromatic value. There must be as many -w switches as input spectral bands
    - dst_filename: File path of pansharpened dataset to be written to file
    - resampling: Select a resampling algorithm (nearest, bilinear, cubic [default], cubicspline, lanczos, average)
    - spat_adjust: Select behavior when bands have not the same extent (union [default], intersection, none, nonewithoutwarning)
    - bitdepth: Specify the bit depth of the panchromatic and spectral bands (e.g. 12). If not specified, the NBITS metadata item from the panchromatic band will be used if it exists.
    - nodata_value: Specify nodata value for bands. Used for the resampling and pan-sharpening computation itself. If not set, deduced from the input bands, provided they have a consistent setting.
    - simple_mean: if True, pansharpening is performed using the pansharpen_simple_mean() function. Otherwise, gdal_pansharpen() is selected

    Returns:
    - dst_filename

    Raises:
    - RuntimeError: if gdal_pansharpen reports a non-zero status
    """
    if simple_mean:
        # The Simple Mean path only supports a single multispectral file.
        pansharpen_simple_mean(spectral_names, pan_name, dst_filename)
        return dst_filename

    # gdal_pansharpen expects a sequence of spectral file paths. A bare string
    # would otherwise be iterated character by character.
    if isinstance(spectral_names, str):
        spectral_names = [spectral_names]

    # Pass keywords: the first positional parameter of osgeo_utils'
    # gdal_pansharpen is `argv`, not `pan_name`.
    status = gdal_pansharpen(
        pan_name=pan_name,
        spectral_names=spectral_names,
        band_nums=band_nums,
        weights=weights,
        dst_filename=dst_filename,
        resampling=resampling,
        spat_adjust=spat_adjust,
        bitdepth=bitdepth,
        nodata_value=nodata_value,
    )
    if status:
        raise RuntimeError(f"gdal_pansharpen failed with status {status} for {dst_filename!r}")

    return dst_filename


In [4]:
# Resampling algorithms known to this notebook, looked up by member name on
# rasterio's Resampling enum.
# NOTE(review): rasterio presumably accepts only a subset of these (nearest,
# bilinear, cubic, cubic_spline, lanczos, average, mode, gauss, rms) when
# building overviews; max/min/med/q1/q3/sum look warp-only — confirm before
# passing them to create_pyramids.
resampling_techniques = [
    Resampling[method_name]
    for method_name in (
        "nearest bilinear cubic cubic_spline lanczos average mode gauss "
        "max min med q1 q3 sum rms"
    ).split()
]

In [5]:
def create_pyramids(dataset, resampling, factors=(2, 4, 8, 16)):
    """Build overviews (pyramids) for a raster dataset, in place.

    Inputs:
    - dataset: File path of pansharpened dataset (opened in 'r+' mode)
    - resampling: Name of the resampling algorithm; must equal the `.name` of an
      entry in `resampling_techniques` (e.g. 'nearest', 'cubic', 'average')
    - factors: Overview decimation factors (default 2, 4, 8, 16)

    Raises:
    - ValueError: if `resampling` does not name any entry of `resampling_techniques`
    """
    # Match on the member's .name rather than str(member). Resampling members
    # are IntEnum values, and on Python 3.11+ str() returns the integer
    # ('2'), not 'Resampling.cubic'.
    methods_by_name = {method.name: method for method in resampling_techniques}
    method = methods_by_name.get(resampling)
    if method is None:
        raise ValueError(
            f"Unknown resampling {resampling!r}; expected one of {sorted(methods_by_name)}"
        )

    # Context manager closes the dataset even if building overviews fails.
    with rasterio.open(dataset, 'r+') as dst:
        dst.build_overviews(list(factors), method)
        dst.update_tags(ns='rio_overview', resampling=method.name)

In [6]:
%load_ext lab_black