In [None]:
%load_ext autoreload

%autoreload 2
from google.cloud import bigquery
import matplotlib.pyplot as plt
import numpy as np
# Shared BigQuery client, created once with no explicit project or credentials
# arguments. The query helpers defined further down (get_areas, get_tile_ids)
# send their SQL through this object.
client = bigquery.Client()
import numpy as np
from satools.io import ConstellationData, constellations_dataset
from satextractor.models import constellation_info

from oxeo.water.models.utils import load_tile, TilePath, tile_from_id, get_band_list
import gcsfs
import numpy as np
from skimage.morphology import (
    closing,
    label,
    remove_small_holes,
    remove_small_objects,
    square,
)
from oxeo.water.utils import plot_imgs_in_row

def get_areas(area_id, model, constellation=None):
    """Fetch rows of `oxeo-main.water.water_ts` for one area and one model.

    The query is sent with BigQuery named parameters instead of interpolating
    the values into the SQL text, so quotes or other special characters in
    ``model`` / ``constellation`` cannot alter the statement.

    Args:
        area_id: Identifier matched against the ``area_id`` column. It is
            converted with ``int()`` and bound as an INT64 parameter.
        model: Model name; rows are kept when ``run_id`` matches the LIKE
            pattern ``%-{model}-%``.
        constellation: Optional constellation name. When falsy (``None`` or
            an empty string) no constellation filter is applied.

    Returns:
        The query result converted with ``to_dataframe()``.
    """
    # (name, BigQuery type, value, SQL condition) for every active filter.
    filters = [
        ("area_id", "INT64", int(area_id), "area_id = @area_id"),
        ("run_pattern", "STRING", f"%-{model}-%", "run_id LIKE @run_pattern"),
    ]
    if constellation:
        filters.append(
            ("constellation", "STRING", constellation, "constellation = @constellation")
        )

    where_clause = "\n  AND ".join(condition for _, _, _, condition in filters)
    query = f"SELECT * FROM `oxeo-main.water.water_ts`\nWHERE {where_clause}"

    # Echo the statement and the bound values for debugging in the notebook.
    print(query)
    print({name: value for name, _, value, _ in filters})

    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter(name, type_, value)
            for name, type_, value, _ in filters
        ]
    )
    query_job = client.query(query, job_config=job_config)  # Make an API request.
    return query_job.to_dataframe()

def get_tile_ids(area_id):
    """Return the ``tiles`` value stored for an area in `water_extractions`.

    Runs a ``LIMIT 1`` query against `oxeo-main.water.water_extractions`,
    binding ``area_id`` as an INT64 query parameter rather than formatting it
    into the SQL text.

    Args:
        area_id: Identifier matched against the ``area_id`` column; converted
            with ``int()`` before being bound.

    Returns:
        The ``tiles`` field of the first returned row.

    Raises:
        IndexError: If the query returns no row for ``area_id``.
    """
    query = """SELECT tiles FROM `oxeo-main.water.water_extractions`
               WHERE area_id = @area_id LIMIT 1"""
    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter("area_id", "INT64", int(area_id))
        ]
    )
    rows = client.query(query, job_config=job_config).to_dataframe()  # API request.
    if rows.empty:
        # Same exception type as indexing an empty result, but with context.
        raise IndexError(f"no water_extractions row found for area_id {area_id}")
    return rows.tiles.values[0]
  

In [None]:
# --- Configuration for the area / sensor being inspected -------------------
area_id = 51318547
constellation = "landsat-5"
# Passed as the third argument of segmentation_area further down.
# presumably the pixel size in metres — TODO confirm
res = 30

# Positions of the red/green/blue entries inside the constellation's band list.
# NOTE(review): np.where returns these positions in ascending (band-list)
# order, not necessarily red, green, blue — confirm the channel order of the
# composite built from `band_names` is the intended one.
bands = np.array(get_band_list(constellation))
bands = np.where(np.isin(bands, ["red", "green", "blue"]))[0]

# Name of the constellation_info attribute holding each sensor's band table.
# Looked up lazily with getattr so only the selected attribute is accessed.
_BAND_INFO_ATTRS = {
    "sentinel-2": "SENTINEL2_BAND_INFO",
    "landsat-5": "LANDSAT5_BAND_INFO",
    "landsat-7": "LANDSAT7_BAND_INFO",
    "landsat-8": "LANDSAT8_BAND_INFO",
}
if constellation not in _BAND_INFO_ATTRS:
    # Fail loudly instead of leaving `band_info` undefined (or stale from a
    # previous run of this cell with another constellation).
    raise ValueError(
        f"unsupported constellation {constellation!r}; "
        f"expected one of {sorted(_BAND_INFO_ATTRS)}"
    )
band_info = list(getattr(constellation_info, _BAND_INFO_ATTRS[constellation]).keys())

# Band identifiers (keys of the band table) at the selected positions.
band_names = np.array(band_info)[bands].tolist()

In [None]:

# CNN series: restrict to a single run, order chronologically and keep one
# row per date (the last one after sorting); the index is renumbered.
# The area column is used as returned; an earlier `norm(...)` step is disabled.
job = (
    get_areas(area_id, "cnn", None)
    .loc[lambda frame: frame["run_id"] == "51318547-cnn-985d72fe"]
    .sort_values(by="date")
    .drop_duplicates(subset=["date"], keep="last", ignore_index=True)
)

# Pekel series: every run of the "pekel-la-alumbrera" model, ordered
# chronologically, keeping the first row per date and the original index.
# NOTE(review): keep="first" here vs keep="last" above — confirm intended.
pekel_job = (
    get_areas(area_id, "pekel-la-alumbrera", None)
    .sort_values(by="date")
    .drop_duplicates(subset=["date"], keep="first")
)

In [None]:
import pandas as pd
import seaborn as sns

# Join the two series on date (columns from `job` get the `_x` suffix, those
# from `pekel_job` the `_y` suffix) and add a 50-row rolling mean of each area.
df = pd.merge(job, pekel_job, on="date").assign(
    cnn_ma=lambda merged: merged.area_x.rolling(50).mean(),
    pekel_ma=lambda merged: merged.area_y.rolling(50).mean(),
)

plt.figure(figsize=(25, 8))

# Smoothed curves first (Pekel dashed green, CNN solid purple), then the raw
# CNN areas, faint and coloured by constellation.
smoothed = dict(x="date", data=df, alpha=1.0)
sns.lineplot(y="pekel_ma", label="Pekel", color="green", linestyle="--", **smoothed)
sns.lineplot(y="cnn_ma", label="CNN", color="purple", **smoothed)
sns.lineplot(x="date", y="area_x", data=df, alpha=0.2, hue="constellation_x")

In [None]:

# `tiles` value of the first water_extractions row for the configured area;
# iterated further down to build one gs://oxeo-water/prod/<tile> path per entry.
tiles = get_tile_ids(area_id)


In [None]:


fs = gcsfs.GCSFileSystem()

# One bucket location per tile of the area.
tile_paths = [f"gs://oxeo-water/prod/{tile_id}" for tile_id in tiles]

# Imagery uses every band of the constellation, masks only the "mask" band.
# Each object receives its own list instance.
img_data = ConstellationData(constellation, bands=band_info, paths=list(tile_paths))
mask_data = ConstellationData(constellation, bands=["mask"], paths=list(tile_paths))

# Same tiles, three different sub-paths: raw data and the two mask products.
img_ds = constellations_dataset([img_data], data_path="data")
pekel_ds = constellations_dataset([mask_data], data_path="mask/pekel")
cnn_ds = constellations_dataset([mask_data], data_path="mask/cnn")


In [None]:
import matplotlib.pyplot as plt
from skimage.exposure import rescale_intensity
from oxeo.water.metrics import segmentation_area
import os

# NOTE(review): loguru reads LOGURU_LEVEL when it is first imported; if it was
# already loaded by the imports above, this assignment may have no effect —
# confirm.
os.environ["LOGURU_LEVEL"] = "INFO"


def _load_mask(dataset, revisit):
    """Return the "mask" band of one revisit as an in-memory, squeezed array."""
    return (
        dataset[constellation]
        .sel({"bands": ["mask"]})
        .isel({"revisits": revisit})
        .compute()
    ).data.squeeze()


def _clean_and_label(mask):
    """Clean up a boolean mask and label its connected regions.

    Applies a 3x3 morphological closing, removes holes and objects below the
    50-pixel thresholds (connectivity=2), then labels the remaining regions
    with 0 as background.
    """
    mask = closing(mask, square(3))
    mask = remove_small_holes(mask, area_threshold=50, connectivity=2)
    mask = remove_small_objects(mask, min_size=50, connectivity=2)
    return label(mask, background=0, connectivity=2)


# NOTE(review): nothing is ever appended to `areas_diff`, and `area_cnn` /
# `area_pekel` below are computed but not used afterwards — presumably their
# difference was meant to be recorded here; confirm before relying on it.
areas_diff = []

n_revisits = cnn_ds[constellation].shape[0]
for i in range(n_revisits):
    # CNN mask: only pixels equal to 1 count as foreground. The comparison
    # builds a new boolean array instead of zeroing the loaded one in place.
    cnn = _clean_and_label(_load_mask(cnn_ds, i) == 1)

    # Pekel mask: any non-zero pixel counts as foreground.
    pekel = _clean_and_label(_load_mask(pekel_ds, i).astype(bool))

    # Selected bands of the same revisit, moved to channel-last for plotting.
    img = (
        img_ds[constellation]
        .sel({"bands": band_names})
        .isel({"revisits": i})
        .compute()
        .data.transpose(1, 2, 0)
    )

    area_cnn = segmentation_area(cnn, "meter", res)
    area_pekel = segmentation_area(pekel, "meter", res)

    # Contrast stretch between the 2nd and 98th percentile, scaled to [0, 1].
    vmin, vmax = np.percentile(img, q=(2, 98))
    img = rescale_intensity(img, in_range=(vmin, vmax), out_range=(0, 1))

    plot_imgs_in_row([img, pekel, cnn], figsize=(12, 8))
    plt.show()
