# In [None]:
"""
Created by Thai Tran, LAL
"""
import arcpy
import os
import shutil
from glob import glob
from enum import Enum
import rasterio
from rasterio.transform import Affine
from rasterio.enums import Resampling
from rasterio.warp import calculate_default_transform
import numpy as np
def copy_dir(src, dst, *, follow_sym=True):
    """Resolve the destination path for *src*, copying only if *src* is a directory.

    Args:
        src: Path of the entry to handle.
        dst: Destination path. When it is an existing directory, the result
            is ``dst`` joined with the base name of ``src``.
        follow_sym: Forwarded as ``follow_symlinks`` to the ``shutil`` calls.

    Returns:
        The resolved destination path.

    NOTE(review): ``shutil.copytree`` calls its ``copy_function`` for
    non-directory entries, so when this function is used that way (as it is
    in this file) nothing is copied and only the folder tree is mirrored.
    Presumably that is the intent - confirm.
    """
    target = os.path.join(dst, os.path.basename(src)) if os.path.isdir(dst) else dst
    if not os.path.isdir(src):
        # Regular files (and missing paths) are deliberately left uncopied.
        return target
    shutil.copyfile(src, target, follow_symlinks=follow_sym)
    shutil.copystat(src, target, follow_symlinks=follow_sym)
    return target
class Upscale(Enum):
    """Per-band resampling factor, looked up by band code in ``upscale``.

    ``upscale`` reads ``Upscale[band].value`` and only resamples when the
    value is 2; bands with any other value are left untouched there.
    Presumably the values are the ratio between the band's native pixel size
    and the 10 m target grid (1 = 10 m, 2 = 20 m, 6 = 60 m) - TODO confirm.

    Several members share a value, so ``Enum`` treats the later ones as
    aliases of the first member with that value. Lookup by name
    (``Upscale['B05']``) still yields the expected ``.value``, but iterating
    the class only yields the first member of each value.
    """
    B01 = 6
    B02 = 1
    B03 = 1
    B04 = 1
    B05 = 2
    B06 = 2
    B07 = 2
    B08 = 1
    B8A = 2
    B09 = 6
    B10 = 6
    B11 = 2
    B12 = 2
    # Scene classification layer; handled like the factor-2 bands.
    SCL = 2
def upscale(filename):
    """Resample a factor-2 band GeoTIFF to twice its resolution.

    The band code is the three characters that precede the last eight
    characters of the file name (e.g. ``B05`` in ``..._B05_20m.tif``), with
    underscores stripped. It is looked up in ``Upscale``; only bands whose
    factor is 2 are resampled, every other band is left as it is.

    The result is written to the input path with ``20m.tif`` replaced by
    ``10m.tif``.

    Args:
        filename: Path of the GeoTIFF to resample.

    Returns:
        The path of the written file, or ``None`` when the band's factor is
        not 2 and nothing was written.

    Raises:
        KeyError: If the band code is not a member of ``Upscale``.
    """
    path = str(filename)
    band = os.path.basename(path)[-11:-8].replace('_', '')
    upscale_factor = Upscale[band].value
    if upscale_factor != 2:
        return None

    # SCL stores class codes, not measurements: interpolating between two
    # codes would invent a third, unrelated class, so pick the nearest pixel.
    if band == 'SCL':
        resampling = Resampling.nearest
    else:
        resampling = Resampling.bilinear

    with rasterio.open(path) as dataset:
        data = dataset.read(
            out_shape=(
                dataset.count,
                int(dataset.height * upscale_factor),
                int(dataset.width * upscale_factor),
            ),
            resampling=resampling,
        )
        # Same extent, more pixels: shrink the pixel size in the transform.
        profile = dataset.profile
        profile.update(
            transform=dataset.transform * Affine.scale(1 / upscale_factor),
            height=data.shape[-2],
            width=data.shape[-1],
        )

    # NOTE(review): the '/60HTB' -> '/resample' rewrite is a hard-coded
    # folder mapping kept for backward compatibility; it is a no-op for
    # paths that do not contain '/60HTB'.
    output_file = path.replace('/60HTB', '/resample').replace('20m.tif', '10m.tif')
    with rasterio.open(output_file, "w", **profile) as resampled:
        resampled.write(data)
    return output_file
                
def _is_band_source(path):
    """Return True for a 10 m / 20 m band or SCL ``.jp2`` file to convert (B01 excluded)."""
    name = os.path.basename(path)
    return (
        name.endswith(('10m.jp2', '20m.jp2'))
        and ('_B' in name or '_SCL_' in name)
        and 'B01' not in name
    )


def _is_10m_band_tif(path):
    """Return True for a per-band ``...10m.tif`` file (stacks and SCL have no ``_B``)."""
    name = os.path.basename(path)
    return name.endswith('10m.tif') and '_B' in name


def _convert_to_tif(input_file, source, resample):
    """Write band 1 of *input_file* as a GeoTIFF at the mirrored path under *resample*.

    The output keeps the input's path relative to *source*, with the
    extension changed to ``.tif``. Returns the output path.
    """
    relative = os.path.relpath(input_file, source)
    output_file = os.path.join(resample, os.path.splitext(relative)[0] + '.tif')
    with rasterio.open(input_file) as src:
        profile = src.profile.copy()
        # Same grid and dtype as the input, but stored as a 1-band GeoTIFF.
        profile.update(driver='GTiff', count=1)
        with rasterio.open(output_file, 'w', **profile) as dst:
            dst.write(src.read(1), 1)
    return output_file


def _process_scene(resample, folder):
    """Stack the 10 m band files of one ``.SAFE`` folder and write a masked copy.

    Writes ``<prefix>ALL10m.tif`` (all bands found) and, when at least ten
    bands are present, ``<prefix>CloudFree.tif`` (first ten bands, zeroed
    wherever the SCL layer is not 4 or 7) next to the first band file, then
    copies the masked stack into *resample*.
    """
    scene_files = glob(os.path.join(resample, folder + '/**'), recursive=True)
    band_files = [path for path in scene_files if _is_10m_band_tif(path)]
    if not band_files:
        print("No 10m band files found, skipping: " + folder)
        return

    # Sort on the two characters before '_10m.tif' ('02' ... '12', '8A').
    # '8A' sorts last, so move it to index 7 to get
    # B02, B03, B04, B05, B06, B07, B08, B8A, B11, B12.
    # NOTE(review): this ordering assumes exactly one file per band; if the
    # folder holds the same band twice the labels below no longer line up.
    band_files.sort(key=lambda path: path[-10:-8])
    band_files.insert(7, band_files.pop())

    head, tail = os.path.split(band_files[0])
    prefix = tail[:-11]

    # Reuse the first band's metadata for the stack, with the band count fixed up.
    with rasterio.open(band_files[0]) as first:
        meta = first.meta
    meta.update(count=len(band_files))

    outstack = os.path.join(head, prefix + 'ALL10m.tif')
    with rasterio.open(outstack, 'w', **meta) as dst:
        for index, layer in enumerate(band_files, start=1):
            with rasterio.open(layer) as band:
                dst.write(band.read(1), index)
    print(outstack, ": Stacking successfully!")

    # Cloud removal needs all ten bands; stop here for incomplete scenes.
    if len(band_files) < 10:
        print("Not all bands present ABORTING: " + prefix)
        return

    # The SCL file sits next to the last band file (B12) and shares its name pattern.
    scl_dir, scl_name = os.path.split(band_files[-1])
    scl = os.path.join(scl_dir, scl_name.replace('B12', 'SCL'))
    print('scl', scl)
    with rasterio.open(scl) as scl_src:
        r_mask = scl_src.read(1)

    # Keep only pixels whose SCL code is 4 or 7; everything else becomes 0.
    # NOTE(review): presumably 4 = vegetation and 7 = unclassified in the
    # Sentinel-2 L2A scene classification - confirm.
    mask_all = np.isin(r_mask, [4, 7])

    meta.update(count=10)
    outstackmask = os.path.join(head, prefix + 'CloudFree.tif')
    with rasterio.open(outstack) as stack, \
            rasterio.open(outstackmask, 'w', **meta) as dst:
        for index in range(1, 11):
            dst.write(stack.read(index) * mask_all, index)
    print(outstackmask, ": Stacking successfully!")

    # Also drop a copy at the top of the output folder.
    shutil.copy2(outstackmask, os.path.join(resample, os.path.basename(outstackmask)))


def main():
    """Convert, resample, stack and mask every scene found under the source folder.

    Tool parameter 0 is the source folder, parameter 1 the output folder.
    """
    source = arcpy.GetParameterAsText(0)
    resample = arcpy.GetParameterAsText(1)
    # For standalone debugging, replace the two lines above with literal paths.

    # Mirror the source folder tree into the output folder.
    shutil.copytree(source, resample, copy_function=copy_dir, dirs_exist_ok=True)

    candidates = glob(source + '/**', recursive=True)
    file_list = [path for path in candidates if _is_band_source(path)]
    arcpy.AddMessage("Total files: ")
    arcpy.AddMessage(len(file_list))

    for filename in file_list:
        print('Converting to tif: ', filename)
        arcpy.AddMessage('Converting to tif: ')
        arcpy.AddMessage(filename)

        output_file = _convert_to_tif(filename, source, resample)
        upscale(output_file)

        arcpy.AddMessage(output_file)
        arcpy.AddMessage('Done!')

    for folder in os.listdir(resample):
        if folder.endswith(".SAFE"):
            _process_scene(resample, folder)


if __name__ == "__main__":
    main()
