In [1]:
import re

import ee
import geopandas as gpd
import numpy as np
import pandas as pd

import agrigee_lite as agl

%load_ext autoreload
%autoreload 2

# Disable numpy scientific notation
np.set_printoptions(suppress=True, precision=3)

In [2]:
ee.Initialize(opt_url="https://earthengine-highvolume.googleapis.com", project="ee-paulagibrim")

In [None]:
# gdf = gpd.read_parquet("data_new/BA.parquet")

# gdf["y_true"] = gdf.crop_class.map({"Mosaic of Uses": 0,
# "Forest Plantation": 1,
# "Soybean": 2,
# "Other Temporary Crops": 3,
# "Other Perennial Crops": 4,
# "Sugar Cane": 5,
# "Coffee": 6,
# "Cotton": 7})

# from sklearn.model_selection import StratifiedShuffleSplit

# def split_stratified_gdf(gdf: gpd.GeoDataFrame, label_col: str = "y_true", seed: int = 42):
#     gdf = gdf.copy()

#     # Primeiro split: 70% train, 30% restante
#     split1 = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=seed)
#     train_idx, temp_idx = next(split1.split(gdf, gdf[label_col]))

#     # Criação da coluna set com valores iniciais
#     gdf["set"] = ""

#     gdf.loc[train_idx, "set"] = "train"

#     # Segundo split: 50% validate, 50% test dos 30% restantes
#     temp_gdf = gdf.iloc[temp_idx]
#     split2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=seed)
#     val_idx, test_idx = next(split2.split(temp_gdf, temp_gdf[label_col]))

#     gdf.loc[temp_gdf.iloc[val_idx].index, "set"] = "validate"
#     gdf.loc[temp_gdf.iloc[test_idx].index, "set"] = "test"

#     return gdf

# gdf = split_stratified_gdf(gdf)

In [3]:
gdf = gpd.read_parquet("BA_agl_ready.parquet")

In [None]:
# gdf = gdf.sample(1000, random_state=42)
# gdf = gdf.reset_index(drop=True)

In [None]:
class AGLSitsDatasetPreCalc():
    def __init__(self, gdf: str | gpd.GeoDataFrame, use_split: str | None = None, transform = None, max_observations: int = 75):
        if isinstance(gdf, str):
            self.gdf = gpd.read_parquet(gdf)
        elif isinstance(gdf, gpd.GeoDataFrame):
            self.gdf = gdf
        else:
            raise TypeError("gdf must be a string (file path) or a GeoDataFrame")  # noqa: TRY003

        if use_split is not None:
            self.gdf = self.gdf[self.gdf["set"] == use_split].reset_index(drop=True)

        self.removed_timestamps = gpd.GeoDataFrame(columns=self.gdf.columns, crs=self.gdf.crs)
        self.satellites = self.gdf.filter(regex=r'_observations$').columns.str.replace('_observations', '', regex=False).tolist()
        self.bands_order = ["blue", "green", "red", "re1", "re2", "re3", "nir", "re4", "swir1", "swir2", "vv", "vh"]

        self.transform = transform

        self.num_classes = self.gdf.y_true.nunique()

In [10]:
from tqdm.std import tqdm

class AGLS2Dataset():
    def __init__(self, gdf: str | gpd.GeoDataFrame, use_split: str | None = None, transform = None, max_observations: int = 75):
        if isinstance(gdf, str):
            self.gdf = gpd.read_parquet(gdf)
        elif isinstance(gdf, gpd.GeoDataFrame):
            self.gdf = gdf
        else:
            raise TypeError("gdf must be a string (file path) or a GeoDataFrame")  # noqa: TRY003

        if use_split is not None:
            self.gdf = self.gdf[self.gdf["set"] == use_split].reset_index(drop=True)

        self.removed_timestamps = gpd.GeoDataFrame(columns=self.gdf.columns, crs=self.gdf.crs)
        self.transform = transform
        self.max_observations = max_observations

    def filter_observations(self, min_observations: int = 1):
        self.removed_timestamps = self.gdf[self.gdf["s2sr_observations"] < min_observations].copy()
        self.gdf = self.gdf[self.gdf["s2sr_observations"] >= min_observations].reset_index(drop=True)

        self.bands = np.zeros((len(gdf), self.max_observations, 10), dtype=np.float16)
        self.doys = np.zeros((len(gdf), self.max_observations), dtype=np.uint16)

        for idx in tqdm(range(len(self.gdf))):
            long_sits = agl.misc.wide_to_long_dataframe(self.gdf[sorted(filter(lambda x: x.startswith("s2sr_"), self.gdf.columns.tolist()))].iloc[idx:idx+1])
            print(long_sits.columns.tolist())
            single_bands = long_sits[["blue", "green", "red", "re1", "re2", "re3", "nir", "re4", "swir1", "swir2"]].to_numpy(dtype=np.float16)
            single_doys = long_sits.doy.to_numpy()

            self.bands[idx, :len(single_bands), :] = single_bands
            self.doys[idx, :len(single_doys)] = single_doys


In [11]:
new_dataset = AGLS2Dataset(gdf)
new_dataset.filter_observations()

  0%|          | 1/292114 [00:06<503:07:56,  6.20s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 2/292114 [00:15<636:42:49,  7.85s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 3/292114 [00:19<516:05:56,  6.36s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 4/292114 [00:25<496:40:27,  6.12s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 5/292114 [00:30<456:55:33,  5.63s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 6/292114 [00:34<408:56:42,  5.04s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 7/292114 [00:37<359:16:54,  4.43s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 8/292114 [00:40<334:03:27,  4.12s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 9/292114 [00:45<339:01:03,  4.18s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 10/292114 [00:49<339:37:30,  4.19s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 11/292114 [00:53<346:42:59,  4.27s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 12/292114 [00:58<361:32:30,  4.46s/it]

['indexnum', 'blue', 'doy', 'green', 'nir', 're1', 're2', 're3', 're4', 'red', 'swir1', 'swir2']


  0%|          | 12/292114 [01:04<435:45:11,  5.37s/it]


KeyboardInterrupt: 

In [3]:
pattern = rf'^{re.escape("s2sr")}_'
filtered_columns = gdf.filter(regex=pattern).columns

In [6]:
gdf.s2sr_observations.argmax()

np.int64(88390)

In [12]:
bands = np.zeros((len(gdf), 120, 10), dtype=np.float16)
doys = np.zeros((len(gdf), 120), dtype=np.uint16)

In [None]:
agl.misc.wide_to_long_dataframe(gdf[filtered_columns])

In [None]:
# results = agl.get.multiple_sits(gdf, agl.sat.Landsat8(), force_redownload=True)
# gdf = pd.concat([gdf, results], axis=1)

In [7]:
class AGLSitsDataset():
    def __init__(self, gdf: str | gpd.GeoDataFrame, use_split: str | None = None, transform = None):
        if isinstance(gdf, str):
            self.gdf = gpd.read_parquet(gdf)
        elif isinstance(gdf, gpd.GeoDataFrame):
            self.gdf = gdf
        else:
            raise TypeError("gdf must be a string (file path) or a GeoDataFrame")  # noqa: TRY003

        if use_split is not None:
            self.gdf = self.gdf[self.gdf["set"] == use_split].reset_index(drop=True)

        self.removed_timestamps = gpd.GeoDataFrame(columns=self.gdf.columns, crs=self.gdf.crs)
        self.satellites = self.gdf.filter(regex=r'_observations$').columns.str.replace('_observations', '', regex=False).tolist()
        self.bands_order = ["blue", "green", "red", "re1", "re2", "re3", "nir", "re4", "swir1", "swir2", "vv", "vh"]

        self.transform = transform
        self.num_classes = self.gdf.y_true.nunique()

    def filter_observations(self, satellites: list[str] | None = None, min_observations: int = 1):
        if satellites is None:
            satellites = self.satellites
        elif not set(satellites).issubset(set(self.satellites)):
            raise ValueError("Some satellites are not available in the dataset")  # noqa: TRY003

        pattern = r'^(?:' + '|'.join([re.escape(sat) + r'_observations$' for sat in satellites]) + r')'
        filtered_columns = self.gdf.filter(regex=pattern).columns
        keep_gdf = self.gdf[self.gdf[filtered_columns].min(axis=1) >= min_observations].reset_index(drop=True)
        remove_gdf = self.gdf[self.gdf[filtered_columns].min(axis=1) < min_observations].reset_index(drop=True)
        if keep_gdf.empty:
            raise ValueError("No observations found for the specified satellites with the given minimum observations")  # noqa: TRY003
        self.gdf = keep_gdf
        self.removed_timestamps = pd.concat([self.removed_timestamps, remove_gdf], ignore_index=True)

    def __getitem__(self, idx: int):
        result = {}

        for satellite in self.satellites:
            pattern = rf'^{re.escape(satellite)}_'
            filtered_columns = self.gdf.filter(regex=pattern).columns

            long_sits = agl.misc.wide_to_long_dataframe(self.gdf[filtered_columns].iloc[idx:idx+1]).drop(columns=["indexnum"])
            long_sits = long_sits.dropna().reset_index(drop=True)

            if len(long_sits) == 0:
                result[f"{satellite}_doys"] = np.array([], dtype=np.uint16)
                result[f"{satellite}_bands"] = np.array([], dtype=np.float32)
                continue

            long_sits = long_sits.loc[~long_sits.duplicated("doy")].reset_index(drop=True)

            ordered_bands = [
                band for band in self.bands_order
                if band in long_sits.columns
            ]

            doys = long_sits["doy"].to_numpy().astype(np.uint16)
            bands = long_sits[ordered_bands].to_numpy().astype(np.float32)

            result[f"{satellite}_doys"] = doys
            result[f"{satellite}_bands"] = bands

        if self.transform:
            result = self.transform(result)

        return result

    def get_idx_metadata(self, idx: int):
        if idx < 0 or idx >= len(self.gdf):
            raise IndexError("Index out of bounds of DataFrame")  # noqa: TRY003

        satellite_prefixes = [re.escape(s) + "_" for s in self.satellites]
        pattern = r'^(?:' + '|'.join(satellite_prefixes) + ')'

        metadata_columns = [col for col in self.gdf.columns if not re.match(pattern, col)]

        metadata = self.gdf.loc[idx, metadata_columns]
        return metadata

In [None]:
gdf.crop_class.unique()

In [8]:
dataset = AGLSitsDataset(gdf)

In [11]:
dataset[88390]['s2sr_doys'].shape

(116,)

In [None]:
dataset.filter_observations(min_observations=12)

In [None]:
dataset.removed_timestamps

In [None]:
gdf.l8sr_observations.min()

In [None]:
dataset.get_idx_metadata(0)

In [None]:
dataset[6]

In [None]:
gdf[gdf.l8sr_observations>0]