In [None]:
%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv
load_dotenv() 
import os 
os.environ["LOGURU_LEVEL"] = "INFO"

from oxeo.water.models.tile_utils import predict_tile
from oxeo.core.models.tile import TilePath, tile_from_id, Tile
from oxeo.water.models.segmentation import Segmentation2DPredictor
from oxeo.water.models.pekel import PekelPredictor

from oxeo.water.utils import plot_imgs_in_row
from skimage.exposure import rescale_intensity
from zarr.errors import ArrayNotFoundError
from tqdm import tqdm
import numpy as np

import zarr
import gcsfs
from oxeo.water.metrics import multiclass_metric, dice, iou, accuracy, precision, recall

from oxeo.water.utils import plot_imgs_in_row
from skimage.exposure import rescale_intensity
from zarr.errors import ArrayNotFoundError
from tqdm import tqdm
import pandas as pd
import numpy as np
from pprint import pprint
fs = gcsfs.GCSFileSystem()
from IPython.display import display, HTML


In [None]:
# Load hand label json files
# Import the Google Cloud client library and JSON library
from google.cloud import storage
import json



In [None]:
from typing import List
def get_tile_paths(data: dict, root: str = "gs://oxeo-water/prod2") -> List[TilePath]:
    """Build one ``TilePath`` per hand-label entry.

    Args:
        data: Mapping whose values are dicts carrying at least a ``"tile"`` key
            (resolved through ``tile_from_item``) and a ``"spacecraft_id"`` key
            (passed on as the ``TilePath`` constellation). The keys are ignored.
        root: Storage root handed to every ``TilePath``. Defaults to the prefix
            this notebook has always used, so existing calls are unaffected.

    Returns:
        The tile paths, in the iteration order of ``data``.
    """
    return [
        TilePath(
            tile=tile_from_item(tile_data),
            constellation=tile_data["spacecraft_id"],
            root=root,
        )
        for tile_data in data.values()
    ]

In [None]:
def tile_from_item(item: dict) -> Tile:
    """Return the ``Tile`` for a hand-label entry, looked up by its ``"tile"`` id."""
    tile_id = item["tile"]
    return tile_from_id(tile_id)

In [None]:
import random
# Load the hand-label index of a single constellation and derive its tile paths.
# The names assigned here (`constellation`, `data`, `tile_paths`, ...) are notebook
# globals that the following cells read.
constellation = "landsat-7"
# Instantiate a Google Cloud Storage client and specify required bucket and file
storage_client = storage.Client()
bucket = storage_client.get_bucket('oxeo-handlabelling')
blob = bucket.blob(f'iris/zimmoz-{constellation}_zimmoz.json')



# Download the contents of the blob as a string and then parse it using json.loads() method
data = json.loads(blob.download_as_string(client=None))

# One TilePath per entry of `data`.
tile_paths = get_tile_paths(data)


In [None]:
from collections import defaultdict
# tile id -> list of label dates; filled from `data` in the next cell.
tile_dates = defaultdict(list)

In [None]:
# Group the label dates of the currently loaded `data` by tile id.
# The mapping is rebuilt from scratch so that re-running this cell (or running it
# again after loading another constellation) cannot append duplicate or stale dates.
tile_dates = defaultdict(list)
for item in data.values():
    tile_dates[item["tile"]].append(item["datetime"])


In [None]:
import yaml

In [None]:
# Print the tile -> dates mapping as YAML, nested under the current constellation name.
# The defaultdict is converted to a plain dict before dumping.
print(yaml.dump({constellation: dict(tile_dates)}))

In [None]:
# Dates of every hand-label entry currently loaded in `data`.
# This used to be `data_dates += ...`, but `data_dates` is not initialised anywhere in
# the notebook, so the cell raised NameError on a fresh kernel. Plain assignment also
# makes the cell idempotent.
data_dates = [item["datetime"] for item in data.values()]

In [None]:
# Inspect the collected label dates (cell output).
data_dates

In [None]:
from oxeo.satools.io import strdates_to_datetime
# For every loaded tile path: print the tile id, then the shape of its stored
# timestamps array (read in full from `timestamps_path`).
for tile_path in tile_paths:
    print(tile_path.tile.id)
    arr = zarr.open_array(tile_path.timestamps_path, mode='r')[:]
    print(arr.shape)

    

In [None]:
# Tile ids listed as the training set.
train_set = [
    '37_L_10000_54_878',
    '35_K_10000_56_808',
    '36_J_10000_42_711',
    '35_K_10000_34_774',
    '36_L_10000_52_824',
]

# Cell output: ids of the loaded tile paths that are not in `train_set`.
[tp.tile.id for tp in tile_paths if tp.tile.id not in train_set]


In [None]:
# Draw five random entries from the local training-set JSON.
# The original cell mixed 8- and 4-space indentation and failed with an
# IndentationError before doing anything; the sampling now runs after the file is closed.
# NOTE(review): the path is absolute and machine-specific, and this rebinds `data`
# (the hand-label dict loaded above) to the contents of train.json.
with open("/home/julien/Documents/oxeo/oxeo-data/data/train.json") as f:
    data = json.load(f)

random_items = random.sample(data, 5)

In [None]:
# NOTE(review): `metrics` is first assigned in the cell below, so on a fresh kernel
# run top-to-bottom this line raises NameError.
metrics.values()

In [None]:

# Score a stored prediction mask against the hand-labelled ("zimmoz") masks.
# For each constellation one DataFrame is built (one row per scored label entry, one
# column per metric) and its summary statistics are displayed. After the cell has run,
# `df` holds the frame of the last constellation in the list.
metrics = {
    "dice": dice,
    "iou": iou,
    "accuracy": accuracy,
    "precision": precision,
    "recall": recall
}

pred_model = "pekel"
only_water = True
cols = [f"water_{m}" for m in metrics.keys()]

# A single client / bucket handle serves every constellation.
storage_client = storage.Client()
bucket = storage_client.get_bucket('oxeo-handlabelling')

for constellation in ["landsat-5", "landsat-7", "landsat-8", "sentinel-2"]:
    blob = bucket.blob(f'iris/zimmoz-{constellation}_zimmoz.json')

    # Download the contents of the blob as a string and then parse it using json.loads() method
    data = json.loads(blob.download_as_string(client=None))

    tile_paths = get_tile_paths(data)

    rows = []     # metric values of every successfully scored entry, in `cols` order
    skipped = {}  # exception type name -> number of entries skipped because of it

    for item in data.values():
        try:
            tile = item["tile"]
            gt_mask = zarr.open_array(f"gs://oxeo-water/prod2/{tile}/{constellation}/mask/zimmoz", mode="r")
            timestamps = zarr.open_array(f"gs://oxeo-water/prod2/{tile}/{constellation}/timestamps", mode="r")
            pred_mask = zarr.open_array(f"gs://oxeo-water/prod2/{tile}/{constellation}/mask/{pred_model}", mode="r")

            # Timestamps are cut to their first 10 characters before being matched
            # against the label's "datetime" value.
            ts = timestamps[:]
            ts = [t[:10] for t in ts]

            ts_index = ts.index(item['datetime'])

            y_true = gt_mask[ts_index]
            y_pred = pred_mask[ts_index]

            values = []
            for k, v in metrics.items():
                metric = multiclass_metric(v, y_true, y_pred)
                if only_water:
                    # Keep only the entry keyed by class 1 (reported in the water_* columns).
                    metric = {key: value for (key, value) in metric.items() if key == 1}
                values.extend(metric.values())

            # Same guard the old `series.index = cols` assignment provided implicitly:
            # a row must hold exactly one value per column.
            if len(values) != len(cols):
                raise ValueError(f"expected {len(cols)} metric values, got {len(values)}")
            rows.append(values)

        except Exception as exc:
            # Best effort: a missing array, an unknown date or a metric failure skips
            # this entry only. Unlike the former bare `except:`, KeyboardInterrupt still
            # stops the cell, and the skips are reported below instead of vanishing.
            name = type(exc).__name__
            skipped[name] = skipped.get(name, 0) + 1
            continue

    if skipped:
        print(f"{constellation}: skipped {sum(skipped.values())} of {len(data)} entries {skipped}")

    # Build the frame once from the collected rows; DataFrame.append no longer exists
    # in pandas >= 2.0 and growing a frame row by row is quadratic.
    df = pd.DataFrame(rows, columns=cols)  # one df per constellation
    display(df.describe())

In [None]:
from matplotlib.backends.backend_pdf import PdfPages
def save_multi_image(filename, figures):
    """Write every figure in ``figures`` into a single multi-page PDF.

    Args:
        filename: Destination handed to ``PdfPages``.
        figures: Iterable of objects exposing ``savefig`` (e.g. matplotlib figures);
            each one is saved in iteration order.
    """
    # The context manager closes the PDF even when a savefig call raises, so a
    # failing figure no longer leaves the file handle open.
    with PdfPages(filename) as pp:
        for fig in figures:
            fig.savefig(pp, format='pdf')
    
    
    
# Plot image / ground truth / prediction side by side for every hand-labelled entry.
# The figures of the constellation being processed are collected in `figures`
# (reset per constellation); nothing is written to disk here.
pred_model = "pekel"
label = 1  # mask value compared against in both the ground truth and the prediction

# A single client / bucket handle serves every constellation.
storage_client = storage.Client()
bucket = storage_client.get_bucket('oxeo-handlabelling')

for constellation in ["landsat-7", "landsat-8", "sentinel-2"]:
    blob = bucket.blob(f'iris/zimmoz-{constellation}_zimmoz.json')

    # Download the contents of the blob as a string and then parse it using json.loads() method
    data = json.loads(blob.download_as_string(client=None))

    tile_paths = get_tile_paths(data)
    figures = []

    for _, item in tqdm(data.items()):
        tile = item["tile"]

        # Inputs that cannot be regenerated here: without them the entry is skipped.
        # Previously a miss on any of these fell into the "run the prediction" branch
        # and then continued with undefined or stale arrays from an earlier tile.
        try:
            img_zarr = zarr.open_array(f"gs://oxeo-water/prod2/{tile}/{constellation}/data", mode="r")
            gt_mask = zarr.open_array(f"gs://oxeo-water/prod2/{tile}/{constellation}/mask/zimmoz", mode="r")
            timestamps = zarr.open_array(f"gs://oxeo-water/prod2/{tile}/{constellation}/timestamps", mode="r")
        except ArrayNotFoundError as exc:
            print(f"Skipping {tile} ({constellation}): input array not found ({exc})")
            continue

        # Only a missing prediction mask triggers a prediction run.
        try:
            pred_mask = zarr.open_array(f"gs://oxeo-water/prod2/{tile}/{constellation}/mask/{pred_model}", mode="r")
        except ArrayNotFoundError:
            print("Array not found doing prediction.")
            # Fixed: the model *name* selects the predictor. The old test compared the
            # mask array (`pred_mask`) with "cnn", which is unbound on the first miss.
            if pred_model == "cnn":
                predictor = Segmentation2DPredictor(ckpt_path="gs://oxeo-models/semseg/epoch_012.ckpt", fs=fs)
            else:
                predictor = PekelPredictor(fs=fs)
            t = tile_from_id(tile)
            tile_path = TilePath(tile=t, constellation=item["spacecraft_id"], root="gs://oxeo-water/prod2")
            try:
                predict_tile(tile_path, pred_model, predictor,
                             revisit_chunk_size=128,
                             start_date="1980-03-19", end_date="2100-03-19",
                             fs=fs, overwrite=True)
                pred_mask = zarr.open_array(f"gs://oxeo-water/prod2/{tile}/{constellation}/mask/{pred_model}", mode="r")
            except Exception as exc:
                # Best effort as before, but the failure is reported and Ctrl-C still works.
                print(f"Skipping {tile} ({constellation}): prediction failed ({exc!r})")
                continue

        # Timestamps are cut to their first 10 characters before being matched
        # against the label's "datetime" value.
        ts = timestamps[:]
        ts = [t[:10] for t in ts]

        try:
            ts_index = ts.index(item['datetime'])
        except ValueError:
            # One unmatched date used to abort the whole loop.
            print(f"Skipping {tile} ({constellation}): {item['datetime']} not in timestamps")
            continue

        # Bands 3, 2, 1 moved to the last axis for plotting
        # (presumably an RGB composite; confirm the band order).
        img = img_zarr[ts_index][[3,2,1]].transpose(1,2,0)
        # Contrast stretch between the 2nd and 98th percentile, rescaled to [0, 1].
        vmin, vmax = np.percentile(img, q=(2, 98))
        img = rescale_intensity(img, in_range=(vmin, vmax), out_range=(0, 1))
        fig = plot_imgs_in_row([img,  gt_mask[ts_index]==label, pred_mask[ts_index]==label], labels=("img", "gt", "pred"), figsize=(15,8))
        figures.append(fig)

In [None]:
# `df` is rebuilt for each constellation by the metrics loop, so this shows only the
# frame of the last constellation that loop processed.
df

In [None]:
# Summary statistics of the metric columns in the current `df`.
df.describe()

In [None]:
# Values common to the hand-label dates and `ts`.
# NOTE(review): `ts` is whatever the last executed loop iteration left behind (the
# truncated timestamps of one tile), so the result depends on execution order.
np.intersect1d(data_dates, ts)


In [None]:
# NOTE(review): in the visible cells `ts_index` is only ever assigned from
# `ts.index(...)`, i.e. a single int, so subscripting it fails. This cell looks like a
# leftover from a version where `ts_index` was a sequence; confirm or remove.
ts[ts_index[1]]

In [None]:
# NOTE(review): relies on state no visible cell creates. `cnn_mask` is never assigned,
# and `ts_index` is an int in the loops above, not a sequence. `img` as left by the
# plotting loop is a single already-transposed image, not a stack indexable by time.
# Confirm the intent before running.
i = ts_index[0]
plot_imgs_in_row([img[i][[3,2,1]].transpose(1,2,0), cnn_mask[i], gt_mask[i]])

In [None]:
import matplotlib.pyplot as plt
# For each index in `ts_index`: print the IoU between ground truth and `cnn_mask` for
# mask value 2, then show the image, the CNN mask and the ground-truth mask.
# NOTE(review): `cnn_mask` is not assigned in any visible cell, and `ts_index` is an
# int in the loops above, so iterating over it fails. Confirm or remove this cell.
for i in ts_index:
    print(iou(gt_mask[i]==2, cnn_mask[i]==2))
    # Bands 3, 2, 1 moved to the last axis, integer-divided by 10000 before display.
    plt.imshow(img[i][[3,2,1]].transpose(1,2,0)//10000)
    plt.show()
    plt.imshow(cnn_mask[i])
    plt.show()
    plt.imshow(gt_mask[i])
    plt.show()