In [None]:
%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv
load_dotenv() 
import os 
os.environ["LOGURU_LEVEL"] = "INFO"

from oxeo.water.models.tile_utils import predict_tile, predict_tile_revisits
from oxeo.core.models.tile import TilePath, tile_from_id, Tile
from oxeo.water.models.segmentation import Segmentation2DPredictor
from oxeo.water.models.pekel import PekelPredictor

from oxeo.water.utils import plot_imgs_in_row
from skimage.exposure import rescale_intensity
from zarr.errors import ArrayNotFoundError
from tqdm import tqdm
import numpy as np

import zarr
import gcsfs
from oxeo.water.metrics import multiclass_metric, dice, iou, accuracy, precision, recall

from oxeo.water.utils import plot_imgs_in_row
from skimage.exposure import rescale_intensity
from zarr.errors import ArrayNotFoundError

import pandas as pd
import numpy as np
from pprint import pprint
fs = gcsfs.GCSFileSystem()
from IPython.display import display, HTML
from collections import defaultdict


In [None]:
# Load hand label json files
# Import the Google Cloud client library and JSON library
from google.cloud import storage
import json



In [None]:
from typing import List
def get_tile_paths(data: dict, root: str = "gs://oxeo-water/prod2") -> List[TilePath]:
    """Build one ``TilePath`` per record of a hand-label index.

    Args:
        data: Mapping whose values are label records. Each record must provide
            a ``"tile"`` id and a ``"spacecraft_id"``; the keys are ignored.
        root: Storage root handed to every ``TilePath``. Defaults to the
            prefix this notebook has always used, so existing calls are
            unaffected.

    Returns:
        The ``TilePath`` objects, in the iteration order of ``data``.
    """
    return [
        TilePath(
            tile=tile_from_item(record),
            constellation=record["spacecraft_id"],
            root=root,
        )
        for record in data.values()
    ]

In [None]:
def tile_from_item(item: dict) -> Tile:
    """Resolve the ``Tile`` referenced by a label record through its ``"tile"`` id."""
    tile_id = item["tile"]
    return tile_from_id(tile_id)

In [None]:
import random

# Constellation whose hand-label index is loaded in this cell.
constellation = "landsat-7"

# Locate the index file for that constellation in the hand-labelling bucket.
storage_client = storage.Client()
bucket = storage_client.get_bucket('oxeo-handlabelling')
blob = bucket.blob(f'iris/zimmoz-{constellation}_zimmoz.json')

# `download_as_string` is a deprecated alias of `download_as_bytes`;
# `json.loads` accepts the returned bytes directly.
data = json.loads(blob.download_as_bytes())

# One TilePath per record of the index.
tile_paths = get_tile_paths(data)


In [None]:
# Group the labelled acquisition dates of every record under its TilePath.
tile_dates = defaultdict(list)
for item in data.values():
    tile_path = TilePath(
        tile=tile_from_item(item),
        constellation=item["spacecraft_id"],
        root="gs://oxeo-water/prod2",
    )
    tile_dates[tile_path].append(item["datetime"])


In [None]:
# Segmentation predictor configured with the `epoch_012` checkpoint path and
# the shared GCS filesystem handle.
predictor = Segmentation2DPredictor(
    ckpt_path="gs://oxeo-models/semseg/epoch_012.ckpt",
    fs=fs,
)

In [None]:
# Run the predictor once per tile over all of that tile's labelled dates,
# keeping the results in the iteration order of `tile_dates`.
masks = []
for revisit_path, revisit_dates in tile_dates.items():
    masks.append(predict_tile_revisits(revisit_path, revisit_dates, predictor, fs, -1))

In [None]:
# Tile ids that make up the validation split.
# The list keeps its original order and its repeated entries.
val_tile_ids = [
    "35_K_10000_56_758",
    "36_J_10000_44_712",
    "36_K_10000_28_771",
    "36_K_10000_23_803",
    "34_K_10000_76_817",
    "36_K_10000_51_748",
    "36_L_10000_68_866",
    "35_K_10000_32_769",
    "36_L_10000_77_831",
    "35_K_10000_32_770",
    "34_L_10000_61_830",
    "35_K_10000_53_763",
    "35_L_10000_39_827",
    "36_L_10000_79_829",
    "36_J_10000_50_719",
    "36_L_10000_77_830",
    "36_L_10000_80_838",
    "34_K_10000_68_774",
    "36_K_10000_72_753",
    "35_L_10000_75_826",
    "37_K_10000_20_793",
    "36_K_10000_26_770",
    "36_L_10000_24_827",
    "36_K_10000_28_770",
    "37_L_10000_64_882",
    "36_K_10000_20_805",
    "36_L_10000_53_824",
    "35_K_10000_32_771",
    "36_L_10000_42_826",
    "36_L_10000_23_826",
    "36_L_10000_40_828",
    "36_J_10000_24_701",
    "36_L_10000_77_831",
    "36_K_10000_24_745",
    "36_K_10000_68_736",
    "35_K_10000_63_813",
    "34_K_10000_66_772",
    "36_J_10000_44_703",
    "36_L_10000_31_825",
    "36_J_10000_62_728",
    "36_K_10000_60_750",
    "35_K_10000_77_759",
    "36_L_10000_39_825",
    "36_L_10000_40_826",
    "36_K_10000_59_750",
    "37_K_10000_20_793",
    "36_L_10000_80_837",
    "35_K_10000_40_771",
    "36_J_10000_34_714",
    "35_K_10000_41_775",
    "36_L_10000_25_856",
    "36_J_10000_47_715",
    "36_K_10000_80_799",
    "36_J_10000_37_703",
    "36_J_10000_69_728",
    "36_K_10000_50_750",
    "35_K_10000_41_774",
    "36_J_10000_54_730",
    "34_L_10000_75_829",
    "37_K_10000_20_793",
    "36_J_10000_52_719",
    "36_L_10000_29_824",
    "36_L_10000_37_825",
    "35_K_10000_34_772",
    "36_L_10000_42_828",
    "36_K_10000_63_764",
    "36_J_10000_52_720",
    "36_J_10000_45_697",
    "36_K_10000_21_789",
    "36_K_10000_35_771",
    "36_K_10000_69_780",
    "37_L_10000_56_847",
    "36_L_10000_78_828",
    "36_L_10000_78_830",
    "36_J_10000_47_696",
    "35_K_10000_22_801",
    "37_L_10000_56_848",
    "36_L_10000_43_828",
    "36_L_10000_34_851",
    "35_K_10000_34_773",
    "36_J_10000_49_718",
    "36_J_10000_35_714",
    "36_L_10000_81_837",
    "37_L_10000_60_880",
    "34_K_10000_68_774",
    "36_K_10000_28_770",
    "34_K_10000_68_773",
    "36_K_10000_27_801",
    "35_K_10000_40_772",
    "35_J_10000_73_718",
    "35_L_10000_32_864",
    "36_L_10000_81_835",
]

In [None]:
import matplotlib.pyplot as plt

# Metric name -> metric callable; each one is applied through `multiclass_metric`.
metrics = {
    "dice": dice,
    "iou": iou,
    "accuracy": accuracy,
    "precision": precision,
    "recall": recall
}

# Checkpoints for the CNN predictors; only needed when the
# Segmentation2DPredictor alternative below is enabled.
# NOTE(review): absolute paths tied to one workstation — consider moving them to configuration.
cnn_ckpt = "/home/fran/repos/oxeo-water/oxeo/water/logs/experiments/unet_semseg_all_tiles/runs/2022-01-12/21-21-12/checkpoints/epoch_012.ckpt"
cnn_ft_ckpt = "/home/fran/repos/oxeo-water/oxeo/water/logs/experiments/unet_semseg_zimmoz/runs/2022-04-14/03-35-29/checkpoints/epoch_027.ckpt"

#predictor = Segmentation2DPredictor(ckpt_path=cnn_ckpt, fs=None)

# Rebinds `predictor`: the evaluation below runs with the Pekel predictor.
predictor = PekelPredictor(fs, n_jobs=1)

only_water = True
cols = [f"water_{m}" for m in metrics.keys()]

# Loop-invariant setup: one storage client/bucket for all constellations and a
# set for constant-time validation-tile membership tests.
storage_client = storage.Client()
bucket = storage_client.get_bucket('oxeo-handlabelling')
val_tile_set = set(val_tile_ids)

for constellation in ["landsat-5", "landsat-7", "landsat-8", "sentinel-2"]:
    # Hand-label index of this constellation. `download_as_string` is a
    # deprecated alias of `download_as_bytes`; `json.loads` accepts bytes.
    blob = bucket.blob(f'iris/zimmoz-{constellation}_zimmoz.json')
    data = json.loads(blob.download_as_bytes())

    tile_paths = get_tile_paths(data)
    tile_dates = defaultdict(list)

    y_true = []
    y_pred = []

    for item in data.values():
        tile_id = item["tile"]
        if tile_id not in val_tile_set:
            continue

        tile = tile_from_item(item)
        tile_path = TilePath(tile=tile, constellation=item["spacecraft_id"], root="gs://oxeo-water/prod2")
        tile_dates[tile_path].append(item["datetime"])

        gt_mask = zarr.open_array(f"gs://oxeo-water/prod2/{tile_id}/{constellation}/mask/zimmoz", mode="r")
        ts = zarr.open_array(f"gs://oxeo-water/prod2/{tile_id}/{constellation}/timestamps", mode="r")[:]

        # Stored timestamps are cut to their first 10 characters before being
        # matched against the record's date; a missing date raises ValueError.
        ts = [t[:10] for t in ts]
        ts_index = ts.index(item['datetime'])
        y_true.append(gt_mask[ts_index])

        # Predict this single revisit right next to its ground truth, so that
        # y_pred[i] always belongs to y_true[i] even when a tile has several
        # labelled dates (one grouped call per tile yields one array per tile,
        # not per record, which breaks np.stack and the per-sample indexing).
        # Assumes one mask is returned per requested date — TODO confirm.
        mask = predict_tile_revisits(tile_path, [item["datetime"]], predictor, fs, -1)
        y_pred.append(np.array(mask).squeeze())

    # np.stack raises on an empty list, so skip constellations without any
    # validation record instead of aborting the whole run.
    if not y_true:
        print(f"{constellation}: no validation records found, skipping")
        continue

    y_pred = np.stack(y_pred).astype(np.uint8)
    y_true = np.stack(y_true).astype(np.uint8)

    print(y_true.shape, y_pred.shape)

    # One row of metrics per sample, collected in a list and turned into a
    # DataFrame once (DataFrame.append was removed in pandas 2.0).
    records = []
    for i in range(y_true.shape[0]):
        if np.any(y_true[i] == 1):
            # Entry 1 of the result is kept — presumably the score of class 1
            # (water); verify against `multiclass_metric`.
            row = [multiclass_metric(fn, y_true[i], y_pred[i])[1] for fn in metrics.values()]
        else:
            # No pixel labelled 1 in the ground truth: record missing values.
            row = [np.nan] * len(cols)
        records.append(row)
        # Uncomment to inspect a sample visually:
        #plot_imgs_in_row([y_true[i],y_pred[i]], ("y_true","y_pred"))
        #plt.show()

    df = pd.DataFrame(records, columns=cols)  # one df per constellation
    display(df.describe())

    