# PRISMA Overview

In [None]:
import h5py
import numpy as np
#import cf_xarray as cfxr
import xarray 
import rioxarray
import holoviews as hv
from holoviews import opts
import geoviews as gv
import datashader as ds
from cartopy import crs 
import hvplot
import hvplot.pandas
from holoviews.operation.datashader import regrid, shade
from bokeh.tile_providers import STAMEN_TONER
import rasterio
from osgeo import gdal
from pathlib import Path
import panel
import hvplot.xarray
from ipywidgets import interact, Dropdown, FloatSlider, IntSlider, SelectMultiple, Text
import matplotlib.pyplot as plt
import geopandas

from sklearn.decomposition import PCA
from spectral import *

hv.extension('bokeh', width=1000)

# PRISMA L2D

In [None]:
# Path to the PRISMA L2D .he5 product; later cells reuse `filename` to derive output names.
filename = '../../../data/PRISMA/PRS_L2D_STD_20220718001152_20220718001156_0001/PRS_L2D_STD_20220718001152_20220718001156_0001.he5'
# Open the product with rasterio only long enough to collect its subdataset list.
with rasterio.open(filename) as src:
    subdatasets = src.subdatasets

In [None]:
subdatasets

# h5py methods to read, display numpy arrays

In [None]:
# Open the same L2D product read-only with h5py; `f` is the handle indexed by the cells below.
# NOTE(review): no visible cell closes this handle before `f` is rebound further down — confirm whether that matters.
f = h5py.File('../../../data/PRISMA/PRS_L2D_STD_20220718001152_20220718001156_0001/PRS_L2D_STD_20220718001152_20220718001156_0001.he5', 'r')

In [None]:
f.keys()

In [None]:
f['HDFEOS/SWATHS/PRS_L2D_HCO/Data Fields/VNIR_Cube']

In [None]:
latitude = f['HDFEOS/SWATHS/PRS_L2D_HCO/Geolocation Fields/Latitude']

In [None]:
longitude = f['HDFEOS/SWATHS/PRS_L2D_HCO/Geolocation Fields/Longitude']

In [None]:
time = f['HDFEOS/SWATHS/PRS_L2D_HCO/Geolocation Fields/Time']

In [None]:
# Dataset handles for the two spectral cubes of the L2D swath.
vnir = f['HDFEOS/SWATHS/PRS_L2D_HCO/Data Fields/VNIR_Cube']
swir = f['HDFEOS/SWATHS/PRS_L2D_HCO/Data Fields/SWIR_Cube']

# Show both cube shapes side by side.
print(f'VNIR {vnir.shape}', f'SWIR {swir.shape}')

In [None]:
def reshapeprisma(ndarray):
    '''Return the PRISMA 3D cube with its middle (band) axis moved to the end.

    The input is indexed (axis0, band, axis2), as read by h5py; the result is
    indexed (axis0, axis2, band), so ``result[i, j, b] == ndarray[i, b, j]``.

    ndarray: 3D array-like (anything ``np.asarray`` accepts).
    Returns a numpy array. An input with zero bands gives an empty array of
    shape (axis0, axis2, 0) rather than an empty list.
    '''
    cube = np.asarray(ndarray)
    # One transpose replaces the band-by-band np.append loop, which re-copied
    # the growing result on every iteration.
    return np.ascontiguousarray(np.transpose(cube, (0, 2, 1)))

In [None]:
vnirnp = np.moveaxis(vnir, 1, 2)
swirnp = np.moveaxis(swir, 1, 2)

In [None]:
vnirnp.shape

In [None]:
swirnp.shape

In [None]:
plt.imshow(vnirnp[:,:,35], cmap='pink')

In [None]:
plt.imshow(swirnp[:,:,35], cmap='pink')

In [None]:
vnirnpstd = vnirnp.std(axis=2)
swirnpstd = swirnp.std(axis=2)

In [None]:
plt.imshow(vnirnpstd, cmap='pink')

In [None]:
plt.imshow(swirnpstd, cmap='pink')

In [None]:
prismaallnp = np.concatenate((vnirnp,swirnp), axis=2)

In [None]:
prismaallnpstd = prismaallnp.std(axis=2)

In [None]:
plt.imshow(prismaallnpstd, cmap='pink')

# Export to a GIS friendly format
## Get geocoding info

In [None]:
prismainfo = gdal.Info('../../../data/PRISMA/PRS_L2D_STD_20220718001152_20220718001156_0001/PRS_L2D_STD_20220718001152_20220718001156_0001.he5')

In [None]:
prismainfo = prismainfo.split("\n")

In [None]:
# Turn the gdal.Info report lines into a flat {key: value} dictionary of strings.
prismainfodict = {}
for i in prismainfo:
    if ':' in i:
        # "key: value" style line; only the text between the first and second
        # colon is kept as the value.
        parts = i.split(':')
        prismainfodict[parts[0].strip()] = parts[1]
    else:
        # "key=value" style metadata line.
        parts = i.split('=')
        try:
            prismainfodict[parts[0].strip()] = parts[1]
        except IndexError:
            # Neither separator is present: show the line instead of storing it.
            print(i)

## Extract parameters from metadata to enable their application to the array

## Confirm whether the corner coordinates refer to pixel centres or pixel edges

In [None]:
(float(prismainfodict['Product_LLcorner_easting']) - float(prismainfodict['Product_LRcorner_easting'])) / 1194

In [None]:
(float(prismainfodict['Product_ULcorner_northing']) - float(prismainfodict['Product_LLcorner_northing'])) / 1171

## Construct the geotransform

- GT(0) x-coordinate of the upper-left corner of the upper-left pixel.
- GT(1) w-e pixel resolution / pixel width.
- GT(2) row rotation (typically zero).
- GT(3) y-coordinate of the upper-left corner of the upper-left pixel.
- GT(4) column rotation (typically zero).
- GT(5) n-s pixel resolution / pixel height (negative value for a north-up image).

In [None]:
# Six-element geotransform anchored at the product's upper-left corner easting/northing, with no rotation terms.
# NOTE(review): the 30 / -30 pixel size is hard-coded rather than read from the metadata — confirm it matches the product.
transform = [ float(prismainfodict['Product_ULcorner_easting']), 30, 0, float(prismainfodict['Product_ULcorner_northing']), 0, -30 ]
transform

## Construct the WKT from the EPSG code

In [None]:
import pyproj
projcrs = pyproj.CRS.from_epsg(int(prismainfodict['Epsg_Code']))
projection = projcrs.to_wkt()

## Scale to radiances

In [None]:
# https://prisma.asi.it/missionselect/docs/PRISMA%20Product%20Specifications_Is2_3.pdf
# Scaling factor for the SWIR cube, used to transform uint16 DN
# to radiance units [W/(m2+sr+um)] as follows:
#   Radiance_f32 = L2ScaleSwirMin + DN_uint16 * (L2ScaleSwirMax - L2ScaleSwirMin) / 65535
# The same min/max form is applied here to the VNIR cube with the VNIR scale
# factors, and the result is then multiplied by 10000.
# NOTE(review): prismainfodict values are strings, so `[0]` keeps only the first
# character of L2ScaleVnirMin — confirm this is intended.

vnirnpf = (float(prismainfodict['L2ScaleVnirMin'][0])+vnirnp*(float(prismainfodict['L2ScaleVnirMax'])-float(prismainfodict['L2ScaleVnirMin'][0]))/65535)*10000

vnirnpf.max()


In [None]:
# Scale the SWIR cube with the SWIR min/max scale factors (same form as the VNIR
# scaling), then multiply by 10000. The DN source must be `swirnp`: using the
# VNIR cube here would export VNIR data under the SWIR name and band count.
# NOTE(review): `[0]` keeps only the first character of the L2ScaleSwirMin
# string, mirroring the VNIR cell — confirm this is intended.
swirnpf = (
    float(prismainfodict['L2ScaleSwirMin'][0])
    + swirnp
    * (float(prismainfodict['L2ScaleSwirMax']) - float(prismainfodict['L2ScaleSwirMin'][0]))
    / 65535
) * 10000
swirnpf.max()

## Create and export the array as GeoTIFF

In [None]:
# TODO - combine all VNIR and SWIR arrays, ordered by centre wavelength

In [None]:
def CreateGeoTiff(outRaster, data, projection, geo_transform):
    '''Write a (rows, cols, bands) array to a multi-band GeoTIFF.

    outRaster:     output file path
    data:          3D array whose last axis is the band axis
    projection:    projection string passed to SetProjection
    geo_transform: six-element geotransform passed to SetGeoTransform

    Bands are written in reverse input order (the last input band becomes
    output band 1) and the raster is created with the GDT_Int16 band type.
    '''
    gtiff = gdal.GetDriverByName('GTiff')
    n_rows, n_cols, n_bands = data.shape
    out_ds = gtiff.Create(outRaster, n_cols, n_rows, n_bands, gdal.GDT_Int16)
    out_ds.SetGeoTransform(geo_transform)
    out_ds.SetProjection(projection)

    # Band-first layout, walked backwards so output band numbers count up
    # while the input bands count down.
    band_first = np.moveaxis(data, 2, 0)
    for band_number, image in enumerate(band_first[::-1], 1):
        out_ds.GetRasterBand(band_number).WriteArray(image)
    # Drop the dataset reference as the final step (presumably what flushes
    # and closes the file in the GDAL Python bindings).
    out_ds = None

In [None]:
def set_band_descriptions(filepath, bands):
    """
    Open a raster in update mode and, for each listed band, set its nodata
    value to 0 and attach the given description.

    filepath: path/virtual path/uri to raster
    bands:    ((band, description), (band, description),...)
    """
    raster = gdal.Open(filepath, gdal.GA_Update)
    for band_number, description in bands:
        target = raster.GetRasterBand(band_number)
        target.SetNoDataValue(0)
        target.SetDescription(description)
    # Release the dataset handle once every band has been updated.
    del raster

In [None]:
def getbands(nparray, prefix, cwl):
    '''Build (band number, description) pairs for every band of a cube.

    nparray: array whose third axis length gives the number of bands
    prefix:  text placed in front of the band number, e.g. "VNIR"
    cwl:     whitespace-separated wavelengths; the list is read in reverse,
             so band 1 is labelled with the last value

    Each description is "<prefix><band> <wavelength truncated to an integer>".
    '''
    n_bands = nparray.shape[2]
    wavelengths = cwl.split()[::-1]
    return [
        (number, prefix + str(number) + " " + str(int(float(wavelengths[number - 1]))))
        for number in range(1, n_bands + 1)
    ]

In [None]:
filename = Path(filename)

In [None]:
# Output GeoTIFF / COG names: the product file name without '.he5', plus a per-output suffix.
vnirfile, vnirfilecog, swirfile, swirfilecog = (
    filename.name.removesuffix('.he5') + suffix
    for suffix in ('_VNIR.tif', '_VNIR_COG.tif', '_SWIR.tif', '_SWIR_COG.tif')
)

In [None]:
CreateGeoTiff(vnirfile, vnirnpf,projection, transform)
CreateGeoTiff(swirfile, swirnpf,projection, transform)

In [None]:
set_band_descriptions(vnirfile, getbands(vnirnpf, "VNIR", prismainfodict['List_Cw_Vnir']))
set_band_descriptions(swirfile, getbands(swirnpf, "SWIR", prismainfodict['List_Cw_Swir']))

In [None]:
getbands(vnirnpf, "VNIR", prismainfodict['List_Cw_Vnir'])

## Export as a COG

In [None]:
#gdal_translate world.tif world_webmerc_cog.tif -of COG -co TILING_SCHEME=GoogleMapsCompatible -co COMPRESS=JPEG
gdal.Translate(vnirfilecog,vnirfile, options="-of COG")
gdal.Translate(swirfilecog,swirfile, options="-of COG")

In [None]:
#gdal.Info(vnirfile).split('\n')

## Warp to lat/lon

In [None]:
gdal.Warp(vnirfile+'_geo.tif', vnirfile, options="-t_srs EPSG:4326 -overwrite -tr 0.0003 0.0003")

In [None]:
#gdal.Info(vnirfile+'_geo.tif').split('\n')

In [None]:
prismaxarraygeo = xarray.open_dataset(vnirfile+'_geo.tif')

In [None]:
prismaxarraygeo.band_data[[10,45,65]]

## COVETool PRISMA footprints were implemented in late 2023 - validating below

In [None]:
# Validating CEOS COVE PRISMA - using export as CSV feature in https://ceos-cove.org/en/acquisition_forecaster/
from shapely.geometry import Polygon, MultiPolygon
import pandas as pd
df = pd.read_csv('../prisma/ceos_cove_prediction.csv')
df.columns

In [None]:
# Build one shapely Polygon per CSV row from its 'scene_coords' string, a
# comma-separated list in which even positions are longitudes and odd positions
# latitudes. The per-row values stay inside the loop so the module-level
# `latitude` / `longitude` names are not rebound to plain lists.
polygons = []
for scene_coords in df['scene_coords']:
    numbers = [float(value) for value in scene_coords.split(',')]
    # zip pairs lon/lat and ignores a trailing unpaired longitude.
    polygons.append(Polygon(list(zip(numbers[0::2], numbers[1::2]))))
multipolygon_geom = MultiPolygon(polygons)
    

In [None]:
polygon_gdf =  geopandas.GeoDataFrame(geometry=polygons)


In [None]:
polygon_gdf.hvplot(geo=True)

In [None]:
# Declare the footprint coordinates as EPSG:4326 (lon/lat). `set_crs` is a method
# that returns the frame with the CRS applied; assigning a dict to the attribute
# would only shadow the method and leave the frame without a CRS.
polygon_gdf = polygon_gdf.set_crs('EPSG:4326')

In [None]:
polygon_gdf.to_file('cove.shp')

In [None]:
polygon_gdf.hvplot.polygons(tiles='ESRI', alpha=0.2)

In [None]:
polygon_gdf.hvplot.polygons(tiles='ESRI', alpha=0.2) * (prismaxarraygeo.isel(band = [17,40,59])/100).hvplot.rgb( geo=True, x='x', y='y', bands='band',  title="RGB Plot with HVPlot", width=1200, tiles='ESRI', alpha=0.5)

# RioXarray import to DataArray

In [None]:
# Import the exported GeoTIFFs as xarray DataArrays.
# The rioxarray.open_rasterio calls are kept commented out as an alternative.
#rio_vnir_array = rioxarray.open_rasterio(vnirfile)
#rio_swir_array = rioxarray.open_rasterio(swirfile)

rio_vnir_array = xarray.open_dataarray(vnirfile)
rio_swir_array = xarray.open_dataarray(swirfile)
# Export to xarray Datasets split along 'band' (this takes band out of the keys, which makes slicing less convenient).
rio_vnir = rio_vnir_array.to_dataset('band')
rio_swir = rio_swir_array.to_dataset('band')

In [None]:
(rio_vnir_array.isel(band = [17,40,59])/100).hvplot.rgb(x='x', y='y', bands='band', data_aspect=1,   title="RGB Plot with HVPlot", width=1200, crs=crs.epsg(int(prismainfodict['Epsg_Code'])))

In [None]:
def renamevars(rioxr, prefix):
    '''Rename the variables keyed 1..N to the matching entries of the 'long_name' attribute.

    rioxr:  object exposing data_vars, attrs['long_name'] and rename()
    prefix: accepted but not used by the current naming scheme

    Returns whatever rioxr.rename() returns for the {number: long_name} mapping.
    '''
    renamedict = {
        position + 1: rioxr.attrs['long_name'][position]
        for position in range(len(rioxr.data_vars))
    }
    return rioxr.rename(renamedict)

In [None]:
rio_swir=renamevars(rio_swir, 'SWIR')
rio_vnir=renamevars(rio_vnir, 'VNIR')

In [None]:
prismaxarray = xarray.merge([rio_vnir, rio_swir])

In [None]:
band = "VNIR63 977"
prismaxarray[band].hvplot(data_aspect=1, flip_yaxis=False, invert=True, xaxis=True, yaxis=True, title="PRISMA VNIR band "+band, cmap="Pink")

In [None]:
# Make an interactive RGB plot: one dropdown per colour channel, listing every
# variable of prismaxarray, plus a Pause/Go switch.
redW, greenW, blueW = (Dropdown(options = list(prismaxarray.data_vars)) for _ in range(3))
startW = Dropdown(options = ['Pause', 'Go'])

@interact(red = redW, green = greenW, blue = blueW, start = startW)
def rgb_combo(red, green, blue, start):
    '''Draw an RGB composite of the three selected variables.

    The dropdown options are refreshed from prismaxarray on every call. The
    plot is drawn only when start is 'Go' and the three selections all differ.
    Returns the four selections unchanged.
    '''
    for selector in (redW, greenW, blueW):
        selector.options = list(prismaxarray.data_vars)
    any_repeated = red == green or green == blue or red == blue
    if start == 'Go' and not any_repeated:
        prismaxarray[[red, green, blue]].to_array().plot.imshow(rgb='variable', robust=True)
    return red, green, blue, start

In [None]:
# TODO - Make the RGB plot interactive. It is possible with hvplot as below, but it doesn't work with the ipywidgets method above
(prismaxarray[['SWIR74 1687','SWIR18 1088','VNIR28 618']]/100).to_array().hvplot.rgb(x='x', y='y', bands='variable', data_aspect=1, flip_yaxis=False, xaxis=False, yaxis=None, title="RGB Plot with HVPlot", width=1200)

In [None]:
prismaxarray.spatial_ref.attrs

## Principal Components Analysis

In [None]:
# https://www.spectralpython.net/algorithms.html#principal-components

In [None]:
(m, c)  = kmeans(prismaallnp, 20, 30)

In [None]:
m.shape

In [None]:
c.shape

In [None]:
imshow(m)

In [None]:
# Draw one curve per row of c on a shared, gridded axes.
for row in c:
    plt.plot(row)
plt.grid()

In [None]:
pc = principal_components(prismaallnp)

In [None]:
 v = imshow(pc.cov)

In [None]:
pc_0999 = pc.reduce(fraction=0.9999)

In [None]:
len(pc_0999.eigenvalues)

In [None]:
img_pc = pc_0999.transform(prismaallnp)

In [None]:
img_pc.shape

In [None]:
v = imshow(img_pc[:,:,:3], stretch_all=True)

In [None]:
pcafile = filename.name.removesuffix('.he5')+'_PCA.tif'

In [None]:
CreateGeoTiff(pcafile, img_pc,projection, transform) #TODO Figure out how to construct an Xarray from arrays - Inefficient

In [None]:
rio_pca_array = xarray.open_dataarray(pcafile)

In [None]:
pcafile

In [None]:
img_pc.shape[2]

In [None]:
index = img_pc.shape[2]
blue = index-1
green = index-2
red = index-3

In [None]:
# TODO - Make the RGB plot interactive. It is possible with hvplot as below, but it doesn't work with the ipywidgets method above
(rio_pca_array[[blue,green,red]]/100).hvplot.rgb(x='x', y='y', bands= 'band', data_aspect=1, flip_yaxis=False, xaxis=False, yaxis=None, title="RGB PCA Plot with HVPlot", width=1200)

# PRISMA L2C

In [None]:
f = h5py.File('../../../data/PRISMA/PRS_L2C_STD_20201130005043_20201130005047_0001/PRS_L2C_STD_20201130005043_20201130005047_0001.he5', 'r')

In [None]:
latitude = f['HDFEOS/SWATHS/PRS_L2C_HCO/Geolocation Fields/Latitude']

In [None]:
longitude = f['HDFEOS/SWATHS/PRS_L2C_HCO/Geolocation Fields/Longitude']

In [None]:
time = f['HDFEOS/SWATHS/PRS_L2C_HCO/Geolocation Fields/Time']

In [None]:
# NOTE(review): the names look swapped — `swir` is bound to the VNIR_Cube dataset
# and `vnir` to the SWIR_Cube dataset, so the labels printed below are swapped too.
# Later cells index these names (e.g. index 70 on the middle axis of `vnir`), so
# confirm which cube each cell needs before renaming.
swir = f['HDFEOS/SWATHS/PRS_L2C_HCO/Data Fields/VNIR_Cube']
vnir = f['HDFEOS/SWATHS/PRS_L2C_HCO/Data Fields/SWIR_Cube']

print('VNIR '+str(vnir.shape), 'SWIR '+str(swir.shape))

In [None]:
plt.imshow(vnir[0:,70,0:])

In [None]:
plt.imshow(latitude[:])

In [None]:
plt.imshow(longitude[:])

In [None]:
# Wrap the cube bound to `swir` in a DataArray with dims (y, band, x) and attach 2D lon/lat coordinates.
# NOTE(review): lon/lat are declared on dims ("x", "y") while the data dims run y, band, x — this presumably
# only lines up if the geolocation arrays are square; confirm the intended axis order.
prismaxarray = xarray.DataArray(data=swir[:], coords=dict(lon=(["x","y"], longitude[:]),lat=(["x","y"], latitude[:])), dims=["y","band","x"])

In [None]:
prismaxarray.isel(band=[4]).plot(robust=True)

In [None]:
#red, green, blue
prismaxarray.isel(band=[6,16,35]).plot.imshow(rgb='band', robust=True)

In [None]:
# One dropdown per colour channel, listing the band values of prismaxarray, plus a Pause/Go switch.
redW, greenW, blueW = (Dropdown(options = prismaxarray.band.values.tolist()) for _ in range(3))
startW = Dropdown(options = ['Pause', 'Go'])

@interact(red = redW, green = greenW, blue = blueW, start = startW)
def rgb_combo(red, green, blue, start):
    '''Draw an RGB composite of the three selected bands once start is 'Go'.

    The dropdown options are refreshed from prismaxarray on every call.
    '''
    for selector in (redW, greenW, blueW):
        selector.options = prismaxarray.band.values.tolist()
    if start != 'Go':
        return
    prismaxarray.isel(band=[red,green,blue]).plot.imshow(rgb='band', robust=True)


In [None]:
#TODO - Adapt and implement warp routine from EMIT 