In [None]:
%load_ext autoreload
%autoreload 2
import zarr
from matplotlib import cm, colors
from skimage import exposure
from skimage import img_as_float
from skimage.color import label2rgb
from skimage.exposure import equalize_adapthist, rescale_intensity

In [None]:

import matplotlib.pyplot as plt

from oxeo.water.datamodules.datasets import VirtualDataset
from oxeo.water.datamodules import ConstellationDataModule
from oxeo.water.datamodules import transforms as oxtransforms
from torchvision.transforms import Compose
from oxeo.satools.io import ConstellationData,create_index_map,  load_virtual_datasets
from oxeo.satools.processing import to_toa_reflectance, parsemeta, to_toa_radiance
from oxeo.water.models.pekel import PekelPredictor
from satextractor.models import constellation_info


import dask.array as da
import numpy as np
import xarray as xr
import pandas as pd

import gcsfs
fs = gcsfs.GCSFileSystem()

# Tile prefixes inside the GCS bucket (without the gs:// scheme). The same tiles
# are requested for every constellation below.
paths = [
    # 'oxeo-water/prod/54_K_10000_34_772',
    'oxeo-water/india_wri/43_P_10000_63_132',
    'oxeo-water/india_wri/43_P_10000_64_131',
    'oxeo-water/india_wri/43_P_10000_64_132',
]

constellations = ["landsat-7", "landsat-8", "sentinel-2"]
# constellation name -> its own (independent) list of full gs:// URLs
all_paths = {
    constellation: [f"gs://{tile}" for tile in paths]
    for constellation in constellations
}

# One ConstellationData per sensor, requesting every band listed in that
# sensor's BAND_INFO table, in declaration order.
data_landsat_7 = ConstellationData(
    "landsat-7",
    bands=list(constellation_info.LANDSAT7_BAND_INFO.keys()),
    paths=all_paths["landsat-7"],
)
data_landsat_8 = ConstellationData(
    "landsat-8",
    bands=list(constellation_info.LANDSAT8_BAND_INFO.keys()),
    paths=all_paths["landsat-8"],
)
data_sen2 = ConstellationData(
    "sentinel-2",
    bands=list(constellation_info.SENTINEL2_BAND_INFO.keys()),
    paths=all_paths["sentinel-2"],
)

# Each inner list holds a single constellation; the groups are passed to
# load_virtual_datasets in the order L7, L8, S2.
train_constellation_regions = {
    "data": [[data_landsat_7], [data_landsat_8], [data_sen2]],
}
ds = load_virtual_datasets(
    train_constellation_regions,
    date_range=("2018-01-01", "2018-05-01"),
    fs_mapper=fs.get_mapper,  # storage access goes through the gcsfs instance above
)
#index_map = create_index_map(train_constellation_regions, ("2018-01-01","2019-02-01"),100, "train_2.csv")


In [None]:
# Display ds[0][0]: the same object that is handed to predictor.predict(...)
# with "landsat-7" further down. NOTE(review): its exact type/layout is not
# visible here — confirm against load_virtual_datasets.
ds[0][0]

In [None]:
# Construct a PekelPredictor with no arguments (library defaults).
predictor = PekelPredictor()

In [None]:
# Run the Pekel predictor on ds[0][0]; "landsat-7" is presumably the
# constellation the input comes from (TODO confirm predict()'s signature).
# NOTE(review): verbose=11 looks like a magic number (typo for 1?) — confirm intent.
masks = predictor.predict(ds[0][0],"landsat-7",verbose=11)

In [None]:
# Show element 1 of the predictor output (presumably the mask for the second
# revisit — TODO confirm). Explicit axes plus a trailing ';' keep the cell
# output to the figure alone, without the AxesImage repr.
fig, ax = plt.subplots()
ax.imshow(masks[1])
ax.set_title("PekelPredictor output masks[1]");

In [None]:
# Open the raw Landsat-8 array of tile 43_P_10000_64_132 read-only.
# mode="r" (zarr's default is "a", read/write) guarantees this exploratory
# notebook cannot create or modify anything in the bucket, and fails loudly if
# the path does not exist. Requires `import zarr` in the import cell.
zarr_arr = zarr.open_array(
    "gs://oxeo-water/india_wri/43_P_10000_64_132/landsat-8/data", mode="r"
)

In [None]:
# Display the `revisits` attribute of the "landsat-8" entry of ds[0][1]
# (presumably the acquisition timestamps available for that constellation —
# TODO confirm).
ds[0][1]["landsat-8"].revisits

In [None]:

# Quick look at band 0 of zarr_arr[0] (axis 0 presumably indexes revisits —
# TODO confirm), stretched between its 2nd and 98th percentiles.
# Indexing with a single tuple fetches only band 0; `zarr_arr[0][[0]]` first
# materialised every band of that revisit and then threw all but one away.
# A 2-D array is colour-mapped by imshow exactly like the old (H, W, 1) one.
band0_raw = zarr_arr[0, 0]
vmin, vmax = np.percentile(band0_raw, q=(2, 98))
band0 = exposure.rescale_intensity(band0_raw, in_range=(vmin, vmax))

fig, ax = plt.subplots()
ax.imshow(band0)
ax.set_title("landsat-8 data[0, 0], 2-98 percentile stretch");


In [None]:
def plot_bands(arr, band_names):
    """Draw a grid of contrast-stretched single-band images.

    Parameters
    ----------
    arr : array-like
        Indexed as ``arr[i, j]`` -> 2-D image. ``arr.shape[0]`` becomes the
        number of grid rows and ``arr.shape[1]`` the number of grid columns.
        In this notebook it is presumably (revisit, band, H, W) — TODO confirm.
    band_names : sequence of str
        Column labels; needs at least ``arr.shape[1]`` entries. Only the top
        row is titled.

    Each image is stretched independently between its own 2nd and 98th
    percentiles before display. Returns None; the figure is shown by the
    notebook's inline backend.
    """
    n_rows, n_cols = arr.shape[0], arr.shape[1]
    # squeeze=False keeps `axes` a 2-D array even when n_rows or n_cols is 1.
    # With the default squeeze=True, plt.subplots returns a 1-D array (or a single
    # Axes) in that case, and `axes[i, j]` would raise.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, n_rows), squeeze=False)
    for i in range(n_rows):
        for j in range(n_cols):
            ax = axes[i, j]
            if i == 0:
                ax.set_title('Band ' + band_names[j])
            img = arr[i, j]
            vmin, vmax = np.percentile(img, q=(2, 98))
            ax.imshow(exposure.rescale_intensity(img, in_range=(vmin, vmax)))
            ax.axis('off')

In [None]:
# Read entry 1 along the first axis of the Landsat-8 array, all bands
# (presumably the second revisit, shaped (band, H, W) — TODO confirm).
img = zarr_arr[1]

In [None]:
# Colour composite from bands 3, 2, 1 (presumably Landsat-8 red, green, blue —
# TODO confirm band order), moved to channels-last for imshow.
# The /10000 is presumably the reflectance scale factor (TODO confirm); it also
# turns integer data into floats, which imshow needs for RGB input.
# Results go to new names instead of reassigning `img`, so re-running this cell
# does not re-index an already-transposed array.
rgb = img[[3, 2, 1]].transpose(1, 2, 0) / 10000
vmin, vmax = np.percentile(rgb, q=(2, 98))
rgb_stretched = exposure.rescale_intensity(rgb, in_range=(vmin, vmax))

fig, ax = plt.subplots()
ax.imshow(rgb_stretched)
ax.set_title("landsat-8 data[1], bands [3, 2, 1], 2-98 percentile stretch");

In [None]:
# Landsat-8 band names in the order declared by constellation_info; the grid
# below shows the first 20 entries along axis 0 of zarr_arr, one column per band.
bands = [*constellation_info.LANDSAT8_BAND_INFO.keys()]
plot_bands(arr=zarr_arr[:20], band_names=bands);

In [None]:
#fmask_cloud = ((b7/b7.max()) > 0.03) & (ndsi<0.8) & (ndvi<0.8)