In [None]:
# Snow-cover classification of PlanetScope scenes (DLNY site) with a pre-trained model -- work in progress (6/10)
# load dependencies
import joblib
import os
import rasterio
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import rioxarray as rxr
import datetime

In [None]:
# --- Configuration ------------------------------------------------------------
# Planet data locations: local_direc is the root, data_direc the DLNY site folder
# that holds one sub-directory per scene.
local_direc = '/home/etboud/projects/data/planet/'
data_direc = '/home/etboud/projects/data/planet/DLNY/'
focus_year = 2023

# Write the snow-cover maps to disk? (1 = save, 0 = do not save)
saveData = 1

# Trained classifier used to label snow presence, snow absence, and artifacts.
# joblib.load unpickles the file (which can execute code) -- only load trusted models.
model = joblib.load('/home/etboud/projects/data/planet/DLNY/classified_training/DLNY_2class_ndvi_model_2023.joblib')

# Scene directories for the focus year, sorted by name. filtered_scenes holds
# *positions* in this sorted list (not dates) of scenes to skip, so it must be
# revisited whenever a scene directory is added or removed.
subdirecs = sorted(
    path for path in glob.glob(data_direc + str(focus_year) + '*')
    if os.path.isdir(path)
)
filtered_scenes = [10, 18, 27, 35, 36, 38, 40, 41, 44, 45, 48, 50, 51, 79, 96]  # 2023, DLNY ---- Bad scenes

# Reference scene whose grid defines the maximum extent (DLNY).
ref_path = data_direc + '20230609_175435_13_242e/25d3618d-2f37-4c9e-8d67-c10e9d477081/PSScene/20230609_175435_13_242e_3B_AnalyticMS_SR_clip.tif'
ref_stack = rxr.open_rasterio(ref_path)

# mask: per-pixel sum over all bands (> 0 wherever the reference scene has data,
# assuming non-negative band values).
mask = ref_stack.values.sum(axis=0)

# ref_file: single-band template carrying the reference grid/georeferencing,
# holding the band-sum mask as its values; shown here for a quick visual check.
ref_file = ref_stack.isel(band=0)
ref_file.values = mask
ref_file.plot()

In [None]:
# function for classifying snow cover using "model"
def run_sca_prediction_band_selfClassify(f_raster, file_out, model, saveData, ref_file, mask):
    """Classify every pixel of a 4-band scene with ``model`` and optionally save the map.

    Parameters
    ----------
    f_raster : str
        Path to the scene GeoTIFF; its 4 bands are fed to the model as b, g, r, nir.
    file_out : str
        Output GeoTIFF path for the classification map (written only if ``saveData``).
        Missing parent directories are created.
    model : object
        Fitted classifier exposing ``predict(DataFrame)`` that returns one label per row.
    saveData : bool or int
        Truthy to write the map to ``file_out``.
    ref_file : xarray.DataArray
        Single-band reference grid. The scene is reprojected onto it and its
        georeferencing is reused for the saved map. It is NOT modified.
    mask : numpy.ndarray
        Array on the ``ref_file`` grid; ``mask > 0`` marks pixels inside the
        reference extent.

    Returns
    -------
    numpy.ndarray
        2-D label array on the ``ref_file`` grid. Pixels inside the reference
        extent whose band sum is 0 in this scene are set to 2 (no data).

    Raises
    ------
    ValueError
        If the scene does not have exactly 4 bands.
    """
    # Read the requested scene (not a notebook global) and put it on the reference grid.
    ds = rxr.open_rasterio(f_raster).rio.reproject_match(ref_file)
    arr = ds.values  # band-first: arr[b] is one band image
    n_bands = arr.shape[0]
    if n_bands != 4:
        # A hard-coded reshape to 4 rows would silently scramble other band counts.
        raise ValueError(
            f"Expected 4 bands (b, g, r, nir) in {f_raster}, got {n_bands}"
        )
    arr_sum = np.sum(arr, axis=0)

    print("Image dimension:", arr.shape)
    # One row per pixel, one column per band, named as the model expects.
    X_img = pd.DataFrame(arr.reshape(n_bands, -1).T, columns=['b', 'g', 'r', 'nir'])
    y_img = model.predict(X_img)

    # np.array copies, so the map is writable even if predict/pandas hands back
    # a read-only view (e.g. under pandas copy-on-write).
    img_prediction = np.array(y_img).reshape(arr.shape[1:])
    # put "no-data" classifications in the cells that are just cut off from the obs
    img_prediction[(arr_sum == 0) & (mask > 0)] = 2

    if saveData:
        out_dir = os.path.dirname(file_out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Copy instead of aliasing ref_file: assigning .values/.attrs on the alias
        # would overwrite the caller's reference grid.
        ds_save = ref_file.copy(data=img_prediction)
        ds_save.attrs['long_name'] = 'classification'
        # NOTE(review): other attrs inherited from ref_file (e.g. any _FillValue)
        # are written too -- confirm they don't collide with a class label.
        ds_save.rio.to_raster(file_out)

    return img_prediction

# calculate the rgb bands and normalize radiances
# see 1_classify_train_model.ipynb
def calc_rgb(ds):
    """Split a 4-band scene into its raw bands and build an RGB image for display.

    Parameters
    ----------
    ds : xarray.DataArray
        Raster with a ``band`` dimension; band positions 0-3 are taken as
        blue, green, red and NIR.

    Returns
    -------
    red_band, green_band, blue_band, nir_band : numpy.ndarray
        The raw band values, in the raster's original dtype.
    rgb_image : numpy.ndarray
        ``(rows, cols, 3)`` float array (red, green, blue) clipped to [0, 1],
        ready for ``plt.imshow``.
    """
    blue_band = ds.isel(band=0)
    green_band = ds.isel(band=1)
    red_band = ds.isel(band=2)
    nir_band = ds.isel(band=3)

    # Stretch all three channels with the *green* band's range (as in the training
    # notebook). Do the arithmetic in float64: if the bands are unsigned integers
    # (typical for Planet SR products -- confirm), `band - minval` in that dtype
    # would wrap values below the green minimum around to huge numbers.
    minval = float(green_band.min().values)
    maxval = float(green_band.max().values)
    span = maxval - minval
    if span == 0:
        span = 1.0  # constant green band (e.g. an empty scene): avoid 0/0 -> NaN

    red_norm = (red_band.astype('float64') - minval) / span
    green_norm = (green_band.astype('float64') - minval) / span
    blue_norm = (blue_band.astype('float64') - minval) / span

    # Where the red channel exceeds 1, paint the pixel white in all channels.
    red_in_range = red_norm <= 1
    green_norm = green_norm.where(red_in_range, 1)
    blue_norm = blue_norm.where(red_in_range, 1)
    red_norm = red_norm.where(red_in_range, 1)

    # imshow requires float RGB in [0, 1]; clip what is still out of range
    # (blue above the green maximum, anything below the green minimum).
    rgb_image = np.clip(
        np.stack([red_norm.values, green_norm.values, blue_norm.values], axis=-1),
        0,
        1,
    )

    return red_band.values, green_band.values, blue_band.values, nir_band.values, rgb_image

In [None]:
# Classify every scene not listed in filtered_scenes, optionally save its map,
# and plot the true-colour scene next to the classification for visual QA.
if saveData:
    os.makedirs(local_direc + 'processed_SCA/', exist_ok=True)

for dCount, direcc in enumerate(subdirecs):
    if dCount in filtered_scenes:
        continue  # manually flagged bad scene

    matches = glob.glob(direcc + '/*/PSScene/*SR_clip.tif')
    if not matches:
        print('No *SR_clip.tif found in', direcc, '- skipping')
        continue
    fname = matches[0]

    # Output name: first two '_'-separated fields of the scene file name + '_SCA.tif'.
    name_parts = os.path.basename(fname).split('_')
    outfile = local_direc + 'processed_SCA/' + name_parts[0] + '_' + name_parts[1] + '_SCA.tif'
    print(outfile)

    classified = run_sca_prediction_band_selfClassify(fname, outfile, model, saveData, ref_file, mask)

    # calc_rgb returns the raw bands followed by the display image; keep only the image.
    *_, rgb_image = calc_rgb(rxr.open_rasterio(fname))

    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(rgb_image)
    axes[1].imshow(classified, vmin=0, vmax=2, interpolation='none')
    axes[0].set_title(name_parts[0])
    axes[1].set_title('classification')
    for panel in axes:
        panel.set_xticks([])
        panel.set_yticks([])
    plt.show()
    plt.close(fig)  # free each figure so long scene lists don't accumulate open figures

In [None]:
# checking the output
# dirr = '/home/etboud/projects/data/planet/processed_SCA/'

# for d in glob.glob(dirr + str(focus_year) + '*'):
#     cl= rxr.open_rasterio(d)
#     #print(d)
#     fig, ax = plt.subplots()
#     ax.imshow(cl[0],vmin=0,vmax=2,interpolation='none')
#     ax.set_title(d.split('/')[-1])