# Earth Engine Dataset Search Assistant

## Overview
This notebook demonstrates the creation and usage of an Earth Engine Dataset Explorer, which helps users find relevant datasets for their geospatial analysis tasks. The core functionality includes:

1. Loading and processing Earth Engine dataset metadata and pre-computed embeddings.
2. Implementing a dataset search function that uses vector similarity to find relevant datasets based on user queries.
3. Creating an interactive user interface that displays search results, dataset details, code samples, and map visualizations.

The notebook also includes an [appendix](https://colab.research.google.com/drive/1C6yFUK3av_-QB6QUdXqhGMAK1YUu5vs0#scrollTo=ySYDz8HzJyxF) that explains how the dataset summaries and embeddings were generated. It uses the [Gemini 1.5 Pro language model](https://blog.google/technology/ai/gemini-1-5/) to create concise summaries of dataset descriptions and the [Google Text Embedding API](https://cloud.google.com/natural-language/docs/embedding-overview) to generate vector representations of these summaries. These embeddings are then stored and used for efficient similarity-based dataset retrieval. The entire process, from loading the Earth Engine catalog to creating and storing the embeddings, is documented to allow users to reproduce or customize the dataset search functionality.

## Example

GIF to be added here

## Setup Details

You will need:

- A Google Cloud project with the Earth Engine API enabled ([Details](https://developers.google.com/earth-engine/cloud/earthengine_cloud_project_setup)).
- A Gemini API key ([Details](https://ai.google.dev/gemini-api/docs/api-key)).
- (Optional) A Google Maps API key for use with geemap ([Details](https://developers.google.com/maps/documentation/javascript/get-api-key)).


Each of the above can be stored in the [Colab "Secrets" panel](https://medium.com/@parthdasawant/how-to-use-secrets-in-google-colab-450c38e3ec75). Add the following names as secrets:

 - Use `GOOGLE_PROJECT_ID` and `EE_PROJECT_ID` for the Cloud project ID.
   - The ID is duplicated because the geemap library expects a secret named `EE_PROJECT_ID`.
 - Use `GOOGLE_API_KEY` for the Gemini API key.
 - Use `GOOGLE_API_KEY` for the Google Maps API key, also expected by geemap.
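The `userdata` module only exists inside Colab. If you run these cells elsewhere, one option is to fall back to environment variables with the same names. A minimal sketch, assuming that convention; `get_secret` is a hypothetical helper, not part of Colab or geemap:

```python
import os
from typing import Optional


def get_secret(name: str) -> Optional[str]:
  """Hypothetical helper: reads a Colab secret, else an environment variable."""
  try:
    # Only importable (and only usable) inside a Colab runtime.
    from google.colab import userdata
    return userdata.get(name)
  except Exception:
    # Not running in Colab, or the secret is missing / access was not granted.
    return os.environ.get(name)
```

With this in place, a call such as `genai.configure(api_key=get_secret('GOOGLE_API_KEY'))` works in both environments.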


In [None]:
#@title Install Python Libraries

%%capture
!pip install google_cloud_aiplatform langchain-community langchain_google_genai chromadb langchain iso8601 bokeh

In [None]:
#@title Imports
import contextlib
import dateutil
import io
import json
import math
import os
import re
import requests
import shutil
import sys
import threading
import traceback
import time

import ee
import geemap
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from langchain.embeddings.base import Embeddings
from langchain_google_genai import ChatGoogleGenerativeAI
import numpy as np
import pandas as pd
import PIL
import vertexai
from vertexai.preview.language_models import TextEmbeddingModel

import google.ai.generativelanguage as glm
import google.api_core
from google.cloud import storage
from google.colab import userdata
import google.generativeai as genai

from IPython.display import Javascript

import IPython
from jinja2 import Template
from io import BytesIO
from PIL import Image
import base64
from concurrent import futures
from google.colab import output as notebook_output

In [None]:
#@title Setup
project_name = userdata.get('GOOGLE_PROJECT_ID')
vertex_ai_zone = "us-central1"
ee_catalog_jsonl_path = "gs://science-ai-ee-catalog-index/catalog_summaries.jsonl"
genai.configure(api_key=userdata.get('GOOGLE_API_KEY'))


ee.Authenticate()
ee.Initialize(project=project_name)
storage_client = storage.Client(project=project_name)
vertexai.init(project=project_name, location=vertex_ai_zone)

# Make sure geemap initialized correctly.
Map = geemap.Map()
Map.add("layer_manager")

# Define classes for working with the Earth Engine data catalog

These will soon be broken up into their own files.

In [None]:
# @title Imports and Helper Methods
import dataclasses
import enum
import tqdm
from typing import Any, Iterator, Optional, Sequence, Iterable, List, Dict
from vertexai.language_models import TextEmbeddingModel
import langchain
import datetime
import logging
from concurrent import futures
import iso8601
from contextlib import redirect_stdout


CUSTOM_DATASET_SUMMARIES_GCS_PATH = 'gs://science-ai-ee-catalog-index/catalog_summaries.jsonl'
LOCAL_DATASET_SUMMARIES_PATH = 'catalog_summaries.jsonl'
BATCH_SIZE = 5

def _get_gcs_blob(storage_client, gcs_path: str):
  parts = gcs_path.split('/')
  bucket_name = parts[2]
  blob_path = '/'.join(parts[3:])
  bucket = storage_client.get_bucket(bucket_name)
  return bucket.blob(blob_path)

def _load_dataset_summaries(storage_client, dataset_summaries_gcs_path: str):
  blob = _get_gcs_blob(storage_client, dataset_summaries_gcs_path)
  blob.download_to_filename(LOCAL_DATASET_SUMMARIES_PATH)
  return LOCAL_DATASET_SUMMARIES_PATH


def matches_interval(
    collection_interval: tuple[datetime.datetime, datetime.datetime],
    query_interval: tuple[datetime.datetime, datetime.datetime],
):
  """Checks if the collection's datetime interval matches the query datetime interval.

  Args:
    collection_interval: Temporal interval of the collection.
    query_interval: a tuple with the query interval start and end

  Returns:
    True if the datetime interval matches
  """
  start_query, end_query = query_interval
  start_collection, end_collection = collection_interval
  if end_collection is None:
    # End date should always be set in STAC JSON files, but just in case...
    end_collection = datetime.datetime.now(tz=datetime.timezone.utc)
  return end_query > start_collection and start_query <= end_collection



def matches_datetime(
    collection_interval: tuple[datetime.datetime, Optional[datetime.datetime]],
    query_datetime: datetime.datetime,
):
  """Checks if the collection's datetime interval matches the query datetime.

  Args:
    collection_interval: Temporal interval of the collection.
    query_datetime: a datetime coming from a query

  Returns:
    True if the datetime interval matches
  """
  if collection_interval[1] is None:
    # End date should always be set in STAC JSON files, but just in case...
    end_date = datetime.datetime.now(tz=datetime.timezone.utc)
  else:
    end_date = collection_interval[1]
  return collection_interval[0] <= query_datetime <= end_date
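Both helpers reduce to simple comparisons. The sketch below restates the overlap test from `matches_interval` on plain dates so it can be tried in isolation; `intervals_overlap` is a hypothetical name, and, as in the notebook, an open-ended collection (end `None`) is treated as running until today.

```python
import datetime
from typing import Optional


def intervals_overlap(
    collection_start: datetime.date,
    collection_end: Optional[datetime.date],
    query_start: datetime.date,
    query_end: datetime.date,
) -> bool:
  """Same comparison as matches_interval(), on dates instead of datetimes."""
  if collection_end is None:
    # Open-ended collection: assume it is still being updated.
    collection_end = datetime.date.today()
  return query_end > collection_start and query_start <= collection_end
```

For example, a collection covering 2000-02-11 to 2000-02-22 overlaps a query for the year 2000 but not one for 2001.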

In [None]:
# @title class BBox()
@dataclasses.dataclass
class BBox:
  """Class representing a lat/lon bounding box."""
  west: float
  south: float
  east: float
  north: float

  def is_global(self) -> bool:
    return (
        self.west == -180 and self.south == -90 and
        self.east == 180 and self.north == 90)

  @classmethod
  def from_list(cls, bbox_list: list[float]):
    """Constructs a BBox from a list of four numbers [west,south,east,north]."""
    if bbox_list[0] > bbox_list[2]:
      raise ValueError(
          'The smaller (west) coordinate must be listed first in a bounding box'
          f' corner list. Found {bbox_list}'
      )
    if bbox_list[1] > bbox_list[3]:
      raise ValueError(
          'The smaller (south) coordinate must be listed first in a bounding'
          f' box corner list. Found {bbox_list}'
      )
    return cls(bbox_list[0], bbox_list[1], bbox_list[2], bbox_list[3])

  def to_list(self) -> list[float]:
    return [self.west, self.south, self.east, self.north]

  def intersects(self, query_bbox) -> bool:
    """Checks if this bbox intersects with the query bbox.

    Doesn't handle bboxes extending past the antimeridian.

    Args:
      query_bbox: Bounding box from the query.

    Returns:
      True if the two bounding boxes intersect
    """
    return (
        query_bbox.west < self.east
        and query_bbox.east > self.west
        and query_bbox.south < self.north
        and query_bbox.north > self.south
    )
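`BBox.intersects` does not handle boxes that cross the antimeridian, and `BBox.from_list` rejects lists where west > east. One common workaround, sketched here with hypothetical helpers that are not part of the notebook, is to split such a box into two halves that each stay within [-180, 180] and test each half separately:

```python
from typing import List


def split_antimeridian(bbox: List[float]) -> List[List[float]]:
  """Splits a [west, south, east, north] box whose west > east into two boxes.

  Boxes that do not cross the antimeridian are returned unchanged.
  """
  west, south, east, north = bbox
  if west <= east:
    return [bbox]
  # Eastern half runs from `west` to +180; western half from -180 to `east`.
  return [[west, south, 180.0, north], [-180.0, south, east, north]]


def boxes_intersect(a: List[float], b: List[float]) -> bool:
  """Same strict-inequality test as BBox.intersects, on plain lists."""
  return a[0] < b[2] and a[2] > b[0] and a[1] < b[3] and a[3] > b[1]
```

A query box would then match a dataset box if `any(boxes_intersect(part, dataset_box) for part in split_antimeridian(query_box))`.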

In [None]:
# @title class Collection()
class Collection:
  """A simple wrapper for a STAC Collection."""
  stac_json: dict[str, Any]

  def __init__(self, stac_json: dict[str, Any]):
    self.stac_json = stac_json
    if stac_json.get('gee:status') == 'deprecated':
      # Set the STAC 'deprecated' field that we don't set in the jsonnet files
      stac_json['deprecated'] = True

  def __getitem__(self, item: str) -> Any:
    return self.stac_json[item]

  def get(self, item: str, default: Optional[Any] = None) -> Optional[Any]:
    """Matches dict's get by returning None if there is no item."""
    return self.stac_json.get(item, default)

  def public_id(self) -> str:
    return self['id']

  def hyphen_id(self) -> str:
    return self['id'].replace('/', '_')

  def get_dataset_type(self) -> str:
    """Could be Image, ImageCollection, FeatureCollection, Feature."""
    return self['gee:type']

  def is_deprecated(self) -> bool:
    """Returns True for collections marked as deprecated."""
    if self.get('deprecated', False):
      logging.info('Skipping deprecated collection: %s', self.public_id())
      return True
    return False

  def datetime_interval(
      self,
  ) -> Iterable[tuple[datetime.datetime, Optional[datetime.datetime]]]:
    """Returns datetime objects representing temporal extents."""
    for stac_interval in self.stac_json['extent']['temporal']['interval']:
      if not stac_interval[0]:
        raise ValueError(
            'Expected a non-empty temporal interval start for '
            + self.public_id()
        )
      start_date = iso8601.parse_date(stac_interval[0])
      if stac_interval[1] is not None:
        end_date = iso8601.parse_date(stac_interval[1])
      else:
        end_date = None
      yield (start_date, end_date)

  def start(self) -> datetime.datetime:
    return list(self.datetime_interval())[0][0]

  def start_str(self) -> str:
    return self.start().strftime("%Y-%m-%d")

  def end(self) -> Optional[datetime.datetime]:
    return list(self.datetime_interval())[0][1]

  def end_str(self) -> str:
    end = self.end()
    # Open-ended collections have no end date.
    return end.strftime("%Y-%m-%d") if end else ""

  def bbox_list(self) -> Sequence[BBox]:
    if 'extent' not in self.stac_json:
      # Assume global if nothing listed.
      return (BBox(-180, -90, 180, 90),)
    return tuple([
        BBox.from_list(x)
        for x in self.stac_json['extent']['spatial']['bbox']
    ])

  def bands(self) -> List[Dict]:
    summaries = self.stac_json.get('summaries')
    if not summaries:
      return []
    return summaries.get('eo:bands', [])

  def spatial_resolution_m(self) -> float:
    summaries = self.stac_json.get('summaries')
    if not summaries:
      return -1
    if summaries.get('gsd'):
      return summaries.get('gsd')[0]

    # Hacky fallback for cases where the stac does not follow convention.
    gsd_lst = re.findall(r'"gsd": (\d+)', json.dumps(self.stac_json))

    if len(gsd_lst) > 0:
      return float(gsd_lst[0])

    return -1


  def temporal_resolution_str(self) -> str:
    interval_dict = self.stac_json.get('gee:interval')
    if not interval_dict:
      return ""
    return f"{interval_dict['interval']} {interval_dict['unit']}"


  def python_code(self) -> str:
    code = self.stac_json.get('code')
    if not code:
      return ''

    return code.get('py_code')

  def image_preview_url(self):
    for link in self.stac_json['links']:
      if link.get('rel') == 'preview' and link.get('type') == 'image/png':
        return link['href']
    raise ValueError(f"No preview image found for {self.public_id()}")


  def catalog_url(self):
    links = self.stac_json['links']
    for link in links:
      if 'rel' in link and link['rel'] == 'catalog':
        return link['href']

      # Ideally there would be a 'catalog' link but sometimes there isn't.
      base_url = "https://developers.google.com/earth-engine/datasets/catalog/"
      if link['href'].startswith(base_url):
        return link['href'].split('#')[0]

    logging.warning(f"No catalog link found for {self.public_id()}")
    return ""
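`Collection` reads a handful of fields from each STAC JSON file. The sketch below performs the same lookups on a trimmed, made-up STAC-like record (the values are illustrative, not copied from a real catalog entry), including an open-ended `null` end date of the kind `end_str()` has to tolerate:

```python
import datetime
import json

# Illustrative, heavily trimmed STAC-like record (values are made up).
stac_json = json.loads("""
{
  "id": "EXAMPLE/DATASET",
  "extent": {
    "spatial": {"bbox": [[-180, -90, 180, 90]]},
    "temporal": {"interval": [["2015-06-23T00:00:00Z", null]]}
  },
  "summaries": {"gsd": [10.0]},
  "gee:interval": {"interval": 5, "unit": "day"}
}
""")

start_raw, end_raw = stac_json["extent"]["temporal"]["interval"][0]
start = datetime.datetime.strptime(start_raw, "%Y-%m-%dT%H:%M:%SZ")
# A null end date means the collection has no fixed end.
end_str = "" if end_raw is None else end_raw[:10]

# Spatial resolution (meters) and revisit interval, as in
# spatial_resolution_m() and temporal_resolution_str().
gsd_m = stac_json["summaries"]["gsd"][0]
interval = stac_json["gee:interval"]
temporal_res = f"{interval['interval']} {interval['unit']}"
```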

In [None]:
# @title class CollectionList()
class CollectionList(Sequence[Collection]):
  """List of stac.Collections; can be filtered to return a smaller sublist."""

  _collections: Sequence[Collection]

  def __init__(self, collections: Sequence[Collection]):
    self._collections = tuple(collections)

  def __iter__(self):
    return iter(self._collections)

  def __getitem__(self, index):
    return self._collections[index]

  def __len__(self):
    return len(self._collections)

  def __eq__(self, other: object) -> bool:
    if isinstance(other, CollectionList):
      return self._collections == other._collections
    return False

  def __hash__(self) -> int:
    return hash(self._collections)

  def filter_by_ids(self, ids: Iterable[str]):
    """Returns a sublist with only the collections matching the given ids."""
    return self.__class__(
        [c for c in self._collections if c.public_id() in ids]
    )

  def filter_by_datetime(
      self,
      query_datetime: datetime.datetime,
  ):
    """Returns a sublist with the time interval matching the given time."""
    result = []
    for collection in self._collections:
      for datetime_interval in collection.datetime_interval():
        if matches_datetime(datetime_interval, query_datetime):
          result.append(collection)
          break
    return self.__class__(result)

  def filter_by_interval(
      self,
      query_interval: tuple[datetime.datetime, datetime.datetime],
  ):
    """Returns a sublist with the time interval matching the given interval."""
    result = []
    for collection in self._collections:
      for datetime_interval in collection.datetime_interval():
        if matches_interval(datetime_interval, query_interval):
          result.append(collection)
          break
    return self.__class__(result)

  def filter_by_bounding_box_list(
      self, query_bbox: BBox):
    """Returns a sublist with the bbox matching the given bbox."""
    result = []
    for collection in self._collections:
      for collection_bbox in collection.bbox_list():
        if collection_bbox.intersects(query_bbox):
          result.append(collection)
          break
    return self.__class__(result)

  def filter_by_bounding_box(
      self, query_bbox: BBox):
    """Returns a sublist with the bbox matching the given bbox."""
    result = []
    for collection in self._collections:
      for collection_bbox in collection.bbox_list():
        if collection_bbox.intersects(query_bbox):
          result.append(collection)
          break
    return self.__class__(result)


  def sort_by_spatial_resolution(self, reverse=False):
    """Sorts the collections by spatial resolution in meters.

    Collections with spatial_resolution_m() == -1 (unknown) are pushed to the
    end.

    Args:
      reverse (bool): If True, sort in descending order of meters (coarsest
        resolution first). If False (default), sort in ascending order
        (finest resolution first).

    Returns:
      CollectionList: A new CollectionList instance with sorted collections.
    """
    def sort_key(collection):
      resolution = collection.spatial_resolution_m()
      if resolution == -1:
        return float('inf') if not reverse else float('-inf')
      return resolution

    sorted_collections = sorted(
        self._collections,
        key=sort_key,
        reverse=reverse
    )
    return self.__class__(sorted_collections)


  def limit(self, n: int):
    """
    Returns a new CollectionList containing the first n entries.

    Args:
        n (int): The number of entries to include in the new list.

    Returns:
        CollectionList: A new CollectionList instance with at most n collections.
    """
    return self.__class__(self._collections[:n])


  def to_df(self):
    """Converts a collection list to a dataframe with a select set of fields."""

    rows = []
    for col in self._collections:
      # Remove text in parens in dataset name.
      short_title = re.sub(r'\([^)]*\)', '', col.get('title')).strip()

      row = {
          'id': col.public_id(),
          'name': short_title,
          'temp_res': col.temporal_resolution_str(),
          'spatial_res_m': col.spatial_resolution_m(),
          'earliest': col.start_str(),
          'latest': col.end_str(),
          'url': col.catalog_url()
      }
      rows.append(row)
    return pd.DataFrame(rows)
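`sort_by_spatial_resolution` uses -1 as an "unknown resolution" sentinel and maps it to plus or minus infinity so those datasets always sort last, whichever direction is requested. The same key function on plain `(id, resolution_m)` pairs, with made-up IDs and values:

```python
from typing import List, Tuple


def sort_by_resolution(
    items: List[Tuple[str, float]], reverse: bool = False
) -> List[Tuple[str, float]]:
  """Sorts (id, resolution_m) pairs; resolution -1 (unknown) always goes last."""

  def sort_key(item: Tuple[str, float]) -> float:
    resolution = item[1]
    if resolution == -1:
      # +inf sorts last ascending; -inf sorts last once the order is reversed.
      return float("-inf") if reverse else float("inf")
    return resolution

  return sorted(items, key=sort_key, reverse=reverse)
```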

In [None]:
#@title class Catalog()
class Catalog:
  """Class containing all collections in the EE STAC catalog."""

  collections: CollectionList

  def __init__(self, storage_client: storage.Client):
    self.collections = CollectionList(self._load_collections(storage_client))

  def get_collection(self, id: str) -> Collection:
    """Returns the collection with the given id."""
    col = self.collections.filter_by_ids([id])
    if len(col) == 0:
      raise ValueError(f'No collection with id {id}')
    return col[0]

  def _read_file(self, file_blob: google.cloud.storage.blob.Blob) -> Collection:
    """Reads the contents of a file from the specified bucket."""
    file_contents = file_blob.download_as_text()
    return Collection(json.loads(file_contents))

  def _read_files(
      self, file_blobs: list[google.cloud.storage.blob.Blob]
  ) -> list[Collection]:
    """Processes files in parallel."""
    collections = []
    with futures.ThreadPoolExecutor(max_workers=10) as executor:
      file_futures = [
          executor.submit(self._read_file, file_blob)
          for file_blob in file_blobs
      ]
      for future in file_futures:
        collections.append(future.result())
    return collections

  def _load_collections(
      self, storage_client: storage.Client
  ) -> Sequence[Collection]:
    """Loads all EE STAC JSON files from GCS, with datetimes as objects."""
    bucket = storage_client.get_bucket('earthengine-stac')
    files = [
        x
        for x in bucket.list_blobs(prefix='catalog/')
        if x.name.endswith('.json')
        and not x.name.endswith('/catalog.json')
        and not x.name.endswith('/units.json')
    ]
    logging.warning('Found %d files, loading...', len(files))
    collections = self._read_files(files)

    code_samples_dict = self._load_all_code_samples(storage_client)

    res = []
    for c in collections:
      if c.is_deprecated():
        continue
      c.stac_json['code'] = code_samples_dict.get(c.hyphen_id())
      res.append(c)
    logging.warning(
        'Loaded %d collections (skipping deprecated ones)', len(res)
    )
    # Returning a tuple for immutability.
    return tuple(res)

  def _load_all_code_samples(self, storage_client: storage.Client):
    """Loads js + py example scripts from GCS into dict keyed by dataset ID."""

    # Get json file from GCS bucket
    # 'gs://earthengine-catalog/catalog/example_scripts.json'
    bucket = storage_client.get_bucket('earthengine-catalog')
    blob = bucket.blob('catalog/example_scripts.json')
    file_contents = blob.download_as_text()
    data = json.loads(file_contents)

    # Flatten json to get a map from ID (using '_' rather than '/') to code
    # sample.
    all_datasets_by_provider = data[0]['contents']
    code_samples_dict = {}
    for provider in all_datasets_by_provider:
      for dataset in provider['contents']:
        js_code = dataset['code']
        py_code = self._make_python_code_sample(js_code)
        code_samples_dict[dataset['name']] = {
            'js_code': js_code, 'py_code': py_code}

    return code_samples_dict

  def _make_python_code_sample(self, js_code: str) -> str:
    """Converts EE JS code into python."""

    # geemap appears to have some stray print statements.
    _ = io.StringIO()
    with redirect_stdout(_):
      code_list = geemap.js_snippet_to_py(
          js_code,
          add_new_cell=False,
          import_ee=False,
          import_geemap=False,
          show_map=False,
      )
    return ''.join(code_list)
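`_load_all_code_samples` assumes `example_scripts.json` is a list whose first element holds one entry per provider, each with a `contents` list of `{name, code}` datasets. That shape is inferred from the code above rather than from a published schema. A toy version of the flattening step, with a caller-supplied function standing in for the JS-to-Python conversion (`flatten_code_samples` is a hypothetical name):

```python
from typing import Any, Callable, Dict, List


def flatten_code_samples(
    data: List[Dict[str, Any]], to_python: Callable[[str], str]
) -> Dict[str, Dict[str, str]]:
  """Maps each dataset name ('/' replaced by '_') to its JS and Python code."""
  samples = {}
  for provider in data[0]["contents"]:
    for dataset in provider["contents"]:
      js_code = dataset["code"]
      samples[dataset["name"]] = {"js_code": js_code, "py_code": to_python(js_code)}
  return samples
```

In the notebook, `to_python` corresponds to `Catalog._make_python_code_sample`.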

## Make sure it's working


In [None]:
catalog = Catalog(storage_client)



In [None]:
col = catalog.get_collection('CGIAR/SRTM90_V4')
Map = geemap.Map()
exec(col.python_code(), {'ee': ee, 'Map': Map, 'm': Map})
Map


In [None]:
col_list = catalog.collections.filter_by_ids(['CGIAR/SRTM90_V4', 'CIESIN/GPWv411/GPW_Land_Area'])
df = col_list.to_df()
HTML(df.to_html(render_links=True, escape=False))

|   | id | name | temp_res | spatial_res_m | earliest | latest | url |
|---|---|---|---|---|---|---|---|
| 0 | CGIAR/SRTM90_V4 | SRTM Digital Elevation Data Version 4 |  | 90.0 | 2000-02-11 | 2000-02-22 | https://developers.google.com/earth-engine/datasets/catalog/CGIAR_SRTM90_V4 |
| 1 | CIESIN/GPWv411/GPW_Land_Area | GPWv411: Land Area |  | 927.67 | 2000-01-01 | 2020-01-01 | https://developers.google.com/earth-engine/datasets/catalog/CIESIN_GPWv411_GPW_Land_Area |


# Dataset Search Logic

We load some pre-generated, per-dataset embeddings into a [vector store](https://cloud.google.com/discover/what-is-a-vector-database) as the backbone to our dataset search tool.

This tool can be used on its own or invoked by an LLM "agent", as demonstrated later in this notebook.
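Conceptually, the vector store ranks datasets by how close their summary embeddings are to the query embedding. A minimal NumPy sketch of that ranking using cosine similarity on made-up 3-dimensional vectors (real `text-embedding-004` vectors have far more dimensions, and the LangChain vector store built below may compute its score differently):

```python
from typing import Dict, List, Tuple

import numpy as np


def top_k_cosine(
    query_vec: np.ndarray, doc_vecs: Dict[str, np.ndarray], k: int = 3
) -> List[Tuple[str, float]]:
  """Returns the k (id, cosine similarity) pairs closest to query_vec."""
  ids = list(doc_vecs)
  matrix = np.stack([doc_vecs[i] for i in ids])
  # Normalize rows and the query so that a dot product equals cosine similarity.
  matrix = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
  query = query_vec / np.linalg.norm(query_vec)
  scores = matrix @ query
  # Highest similarity first.
  order = np.argsort(-scores)[:k]
  return [(ids[i], float(scores[i])) for i in order]
```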

In [None]:
# @title Location of the pre-built embeddings

# Pre-built embeddings. See notebook Appendix for details
EMBEDDINGS_CLOUD_PATH = 'gs://science-ai-ee-catalog-index/catalog_embeddings.jsonl'

# Copy embeddings from GCS bucket to a local file
EMBEDDINGS_LOCAL_PATH = 'catalog_embeddings.jsonl'


In [None]:
#@title Embeddings classes and helper methods
from langchain.embeddings.base import Embeddings
from langchain.indexes import VectorstoreIndexCreator
from langchain.schema import Document
from langchain_core.vectorstores.base import VectorStore
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain_core.language_models.base import BaseLanguageModel
import numpy as np


class PrecomputedEmbeddings(Embeddings):
    def __init__(self, embeddings_dict):
        self.embeddings_dict = embeddings_dict
        self.model = TextEmbeddingModel.from_pretrained("google/text-embedding-004")

    def embed_documents(self, texts):
        return [self.embeddings_dict[text] for text in texts]

    def embed_query(self, text):
      embeddings = self.model.get_embeddings([text])
      return embeddings[0].values


def load_embeddings(gcs_path=EMBEDDINGS_CLOUD_PATH, local_path=EMBEDDINGS_LOCAL_PATH):
  parts = gcs_path.split('/')
  bucket_name = parts[2]
  blob_path = '/'.join(parts[3:])
  bucket = storage_client.get_bucket(bucket_name)
  blob = bucket.blob(blob_path)
  blob.download_to_filename(local_path)
  return local_path

def make_langchain_index(embeddings_df: pd.DataFrame) -> VectorStoreIndexWrapper:
  """Creates an index from a dataframe of precomputed embeddings."""
  # Create a dictionary mapping texts to their embeddings
  embeddings_dict = dict(zip(embeddings_df['id'], embeddings_df['embedding']))

  # Create our custom embeddings class
  precomputed_embeddings = PrecomputedEmbeddings(embeddings_dict)

  # Create Langchain Document objects
  documents = []
  for index, row in embeddings_df.iterrows():
    page_content = row['id']
    metadata = {'summary': row['summary'], 'name': row['name']}
    documents.append(Document(page_content=page_content, metadata=metadata))


  # Create the VectorstoreIndexCreator
  index_creator = VectorstoreIndexCreator(
      embedding=precomputed_embeddings
  )

  # Create the index
  return index_creator.from_documents(documents)

# Wrap Langchain embeddings in our own EE dataset wrapper
class EarthEngineDatasetIndex():
  index: VectorStoreIndexWrapper
  vectorstore: VectorStore
  data_catalog: Catalog
  llm: BaseLanguageModel

  def __init__(self, data_catalog, index, llm):
    self.index = index
    self.data_catalog = data_catalog
    self.vectorstore = index.vectorstore
    self.llm = llm


  def find_top_matches(
      self,
      query: str,
      results: int = 10,
      threshold: float = 0.7,
      bounding_box: Optional[list[float]] = None,
      temporal_interval: Optional[tuple[datetime.datetime, datetime.datetime]] = None) -> CollectionList:
    """Retrieves relevant datasets from the Earth Engine data catalog.

    Args:
      query: The kind of data being searched for, e.g. 'population'.
      results: The number of datasets to return. 4 is recommended.
      threshold: Intended cutoff on query/catalog embedding similarity
        (0.7 recommended). Not currently applied.
      bounding_box: Optional spatial bounding box for the query, in the format
        [west, south, east, north]. Not currently applied.
      temporal_interval: Optional (start, end) datetimes bounding the query.
        Not currently applied.

    Returns:
      A CollectionList of the matching datasets, most similar first.
    """
    similar_docs = self.index.vectorstore.similarity_search_with_score(
        query, k=results)
    dataset_ids = [doc.page_content for doc, _ in similar_docs]
    datasets = self.data_catalog.collections.filter_by_ids(dataset_ids)
    # filter_by_ids() returns catalog order; restore the similarity ranking.
    rank = {dataset_id: i for i, dataset_id in enumerate(dataset_ids)}
    return CollectionList(sorted(datasets, key=lambda c: rank[c.public_id()]))


  def find_top_matches_with_score_df(self,
      query: str,
      results: int = 20,
      threshold: float = 0.7,
      bounding_box: Optional[list[float]] = None,
      temporal_interval: tuple[datetime.datetime, datetime.datetime] = None):
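    similar_docs = self.index.vectorstore.similarity_search_with_score(
        query, k=results)
    # Key scores by dataset ID: filter_by_ids() returns collections in catalog
    # order (not score order) and drops IDs that are not in the catalog, so
    # assigning scores positionally would misalign them.
    scores_by_id = {doc.page_content: score for doc, score in similar_docs}

    col_list = self.data_catalog.collections.filter_by_ids(list(scores_by_id))
    df = col_list.to_df()
    # Round scores to 2 decimal places and list the most relevant first.
    df['relevance'] = df['id'].map(scores_by_id).round(2)
    df = df.sort_values('relevance', ascending=False).reset_index(drop=True)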
    return df
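`find_top_matches` and `find_top_matches_with_score_df` accept a `threshold` argument but do not use it. A sketch of one way it could be applied to `(id, score)` pairs, assuming higher scores mean more relevant (true for cosine similarity, but not for distance-based stores, where the comparison flips); `apply_threshold` is a hypothetical helper:

```python
from typing import List, Sequence, Tuple


def apply_threshold(
    scored_ids: Sequence[Tuple[str, float]],
    threshold: float,
    higher_is_better: bool = True,
) -> List[Tuple[str, float]]:
  """Keeps (id, score) pairs that pass the threshold, best first."""
  if higher_is_better:
    # Similarity-like scores: larger means closer.
    kept = [(i, s) for i, s in scored_ids if s >= threshold]
  else:
    # Distance-like scores: smaller means closer.
    kept = [(i, s) for i, s in scored_ids if s <= threshold]
  return sorted(kept, key=lambda pair: pair[1], reverse=higher_is_better)
```

Inside the index class this could be fed with `[(doc.page_content, score) for doc, score in similar_docs]`.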




In [None]:
from google.colab import data_table

# Load our embeddings data into a dataframe:
data_table.disable_dataframe_formatter()

local_path = load_embeddings(EMBEDDINGS_CLOUD_PATH, EMBEDDINGS_LOCAL_PATH)
embeddings_df = pd.read_json(local_path, lines=True)
embeddings_df.head()

|   | id | metadata | page_content | type | embedding |
|---|---|---|---|---|---|
| 0 |  | {'file_path': 'catalog/AAFC/AAFC_ACI.jsonnet',... | Since 2009, Agriculture and Agri-Food Canada (... | Document | [-0.0323356427, 0.0206475724, -0.0276225228, -... |
| 1 |  | {'file_path': 'catalog/ACA/ACA_reef_habitat_v1... | The Allen Coral Atlas, funded by Vulcan Inc. a... | Document | [-0.0067234263, 0.0562574901, -0.0420167074, 0... |
| 2 |  | {'file_path': 'catalog/ACA/ACA_reef_habitat_v2... | The Allen Coral Atlas, a project led by Arizon... | Document | [-0.0015228223, 0.0610635951, -0.0483361706, 0... |
| 3 |  | {'file_path': 'catalog/AHN/AHN_AHN2_05M_INT.js... | The AHN DEM is a detailed (0.5m resolution) el... | Document | [-0.0101616979, -0.0809158906, -0.0406738855, ... |
| 4 |  | {'file_path': 'catalog/AHN/AHN_AHN2_05M_NON.js... | The AHN DEM, a high-resolution (0.5m) model of... | Document | [0.0055814106, -0.08407039200000001, -0.044175... |


In [None]:
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", google_api_key=userdata.get('GOOGLE_API_KEY'))

local_path = load_embeddings(EMBEDDINGS_CLOUD_PATH, EMBEDDINGS_LOCAL_PATH)
embeddings_df = pd.read_json(local_path, lines=True)
langchain_index = make_langchain_index(embeddings_df)

In [None]:
from google.colab import data_table
data_table.enable_dataframe_formatter()



ee_index = EarthEngineDatasetIndex(catalog, langchain_index, llm)
df = ee_index.find_top_matches_with_score_df("Datasets to measure changes in lakes over time.")

data_table.DataTable(df, include_index=False, num_rows_per_page=20, min_width=300)

|   | id | name | temp_res | spatial_res_m | earliest | latest | url | relevance |
|---|---|---|---|---|---|---|---|---|
| 0 | CIESIN/GPWv411/GPW_Water_Area | GPWv411: Water Area |  | 927.67 | 2000-01-01 | 2020-01-01 | https://developers.google.com/earth-engine/dat... | 0.64 |
| 1 | ECMWF/ERA5_LAND/DAILY_RAW | ERA5-Land Daily Aggregated - ECMWF Climate Rea... | 1 month | 11132.0 | 1963-07-11 | 2023-02-27 | https://developers.google.com/earth-engine/dat... | 0.63 |
| 2 | GLIMS/20230607 | GLIMS 2023: Global Land Ice Measurements From ... |  | -1.0 | 1750-01-01 | 2023-06-07 | https://developers.google.com/earth-engine/dat... | 0.62 |
| 3 | GOOGLE/GLOBAL_CCDC/V1 | Google Global Landsat-based CCDC Segments |  | 30.0 | 1999-01-01 | 2020-01-01 | https://developers.google.com/earth-engine/dat... | 0.62 |
| 4 | JRC/GSW1_4/GlobalSurfaceWater | JRC Global Surface Water Mapping Layers, v1.4 |  | 30.0 | 1984-03-16 | 2022-01-01 | https://developers.google.com/earth-engine/dat... | 0.61 |
| 5 | JRC/GSW1_4/Metadata | JRC Global Surface Water Metadata, v1.4 |  | 30.0 | 1984-03-16 | 2022-01-01 | https://developers.google.com/earth-engine/dat... | 0.61 |
| 6 | JRC/GSW1_4/MonthlyHistory | JRC Monthly Water History, v1.4 | 1 month | 30.0 | 1984-03-16 | 2022-01-01 | https://developers.google.com/earth-engine/dat... | 0.6 |
| 7 | JRC/GSW1_4/YearlyHistory | JRC Yearly Water Classification History, v1.4 | 1 year | 30.0 | 1984-03-16 | 2022-01-01 | https://developers.google.com/earth-engine/dat... | 0.6 |
| 8 | LANDSAT/COMPOSITES/C02/T1_L2_ANNUAL_NDWI | Landsat Collection 2 Tier 1 Level 2 Annual NDW... | 1 year | 30.0 | 1984-01-01 | 2024-01-01 | https://developers.google.com/earth-engine/dat... | 0.6 |
| 9 | USGS/WBD/2017/HUC08 | HUC08: USGS Watershed Boundary Dataset of Subb... |  | -1.0 | 2017-04-22 | 2017-04-23 | https://developers.google.com/earth-engine/dat... | 0.6 |


# Gemini Methods

Python functions that prompt Gemini, along with their helpers.
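`country_code_to_bbox` reads corners 0 and 2 of the first ring returned by `bounds()`, which relies on the vertex order of that rectangle. A sketch that derives the same four numbers with min/max instead, so the result does not depend on vertex order (`ring_to_bbox` is a hypothetical helper and the coordinates are illustrative, not fetched from Earth Engine):

```python
from typing import List, Sequence


def ring_to_bbox(ring: Sequence[Sequence[float]]) -> List[float]:
  """Returns [west, south, east, north] for a ring of [lon, lat] vertices."""
  lons = [point[0] for point in ring]
  lats = [point[1] for point in ring]
  return [min(lons), min(lats), max(lons), max(lats)]
```

The result can be passed straight to `BBox.from_list`, which expects the same `[west, south, east, north]` order.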

In [None]:
import typing_extensions as typing


def country_code_to_bbox(country_code: str) -> BBox:
  """Returns a bounding box for the given ISO 3166-1 alpha-3 country code."""
  # Treat a missing code the same as 'global' (no region of interest).
  if not country_code or country_code.lower() == 'global':
    return BBox(-180, -90, 180, 90)
  country_boundaries = ee.FeatureCollection('WM/geoLab/geoBoundaries/600/ADM0')
  country = ee.Feature(
      country_boundaries.filter(ee.Filter.eq('shapeGroup', country_code)).first())
  coord_list = country.bounds().geometry().coordinates().getInfo()[0]
  west = coord_list[0][0]  # Minimum X coordinate
  south = coord_list[0][1]  # Minimum Y coordinate
  east = coord_list[2][0]  # Maximum X coordinate
  north = coord_list[2][1]  # Maximum Y coordinate

  return BBox(west, south, east, north)

class DataSearchSpec(typing.TypedDict):
  dataset_description: str
  country_of_interest: str
  country_code_iso_3166: str
  start_date: str
  end_date: str

def extract_problem_spec(query: str, model_name: str = 'gemini-1.5-pro-latest'):
  """Prompts LLM to extract query parameters from an NL query."""

  # Set the `response_mime_type` to output JSON
  # Pass the schema object to the `response_schema` field
  generation_config = {
      "response_mime_type": "application/json",
      "response_schema": DataSearchSpec}

  model = genai.GenerativeModel(model_name, generation_config=generation_config)

  prompt = f"""
  An Earth Engine user has described the type of problem they are interested
  in studying and has requested a list of relevant datasets in the Earth Engine
  data catalog to help study their problem. A downstream function will perform
  the dataset search itself. Your task is to provide relevant input into the
  dataset search function.

  Summarize the type of dataset the user may be seeking in the dataset_description field.

  If the user specifies a region of interest in the query, determine
  what country the region of interest (ROI) is contained within.
  Set country_of_interest to the name of the country. Also specify the
  corresponding ISO 3166-1 alpha-3 country code in country_code_iso_3166.

  Use your pretraining knowledge to determine what country contains a ROI.
  For example, if the query is "How is water pollution affecting the health of
  the Ganges River and the communities that depend on it?", recall that the
  Ganges River is in the country India.


  If the user does not specify a region of interest, set both
  country_of_interest and country_code_iso_3166 to 'global'.

  IMPORTANT: if the query suggests multiple countries of interest, just
  set country_of_interest and country_code_iso_3166 to 'global'.

  If the user specifies a time period, set start_date and end_date to strings
  of the form YYYY-MM-DD. If no time period is specified, set start_date and end_date
  to an empty string. For example, if the user says something like "all fires after 2000",
  set start_date to '2000-01-01' and leave the end_date as an empty string.

  Here is the original query: {query}
  """
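  # Gemini occasionally returns malformed JSON, so retry a bounded number of
  # times instead of looping forever (and without a wasted first call).
  max_attempts = 5
  json_answer = None
  for _ in range(max_attempts):
    response = model.generate_content(prompt)
    answer = response.text.replace('\n', '')
    try:
      json_answer = json.loads(answer)
      break
    except json.JSONDecodeError:
      continue
  if json_answer is None:
    raise ValueError(
        f'Gemini did not return valid JSON after {max_attempts} attempts.')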

  json_answer['bbox'] = country_code_to_bbox(json_answer['country_code_iso_3166'])
  return json_answer

def explain_relevance(
    query: str,
    dataset_id: str,
    catalog: Catalog,
    model_name: str = 'gemini-1.5-pro-latest'):
  """Prompts LLM to explain the relevance of a dataset to a query."""

  stac_json = catalog.get_collection(dataset_id).stac_json
  return explain_relevance_from_stac_json(query, stac_json, model_name)



def explain_relevance_from_stac_json(
    query: str, stac_json: dict, model_name: str = 'gemini-1.5-pro-latest') -> str:
  """Prompts LLM to explain the relevance of a dataset, given its STAC JSON."""

  stac_json_str = json.dumps(stac_json)

  prompt = f'''
  I am an Earth Engine user contemplating using a dataset to support
  my investigation of the following query. Provide a concise, paragraph-long
  summary explaining why this dataset may be a good fit for my use case.
  If it does not seem like an appropriate dataset, say so.
  If relevant, call attention to a max of 3 bands that may be of particular interest.
  Weigh the tradeoffs between temporal and spatial resolution, particularly
  if the original query specifies regions of interest, time periods, or
  frequency of data collection. If I have not specified any
  spatial constraints, do your best based on the nature of my query. For example,
  if I want to study something small, like buildings, I will likely need good spatial resolution.

  Here is the original query:
  {query}

  Here is the stac json metadata for the dataset:
  {stac_json_str}
  '''
  model = genai.GenerativeModel(model_name)
  response = model.generate_content(prompt)
  return response.text
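
`extract_problem_spec` re-prompts Gemini whenever the response is not valid JSON. Capping the number of attempts keeps a persistently malformed response from looping forever. The sketch below factors that pattern into a small helper; `parse_json_with_retries` is a hypothetical name, not part of the notebook, and a fake generator stands in for the Gemini model so the example runs offline.

```python
import json
from typing import Callable, Optional


def parse_json_with_retries(
    generate: Callable[[], str], max_attempts: int = 5) -> dict:
  """Calls `generate` until it returns valid JSON or attempts run out.

  `generate` stands in for something like
  `lambda: model.generate_content(prompt).text`.
  """
  last_error: Optional[json.JSONDecodeError] = None
  for _ in range(max_attempts):
    text = generate().replace('\n', '')
    try:
      return json.loads(text)
    except json.JSONDecodeError as e:
      last_error = e  # Malformed JSON: ask again.
  raise ValueError(
      f'No valid JSON after {max_attempts} attempts') from last_error


# A fake "model" whose first answer is truncated JSON and second is valid.
responses = iter(['{"country_code_iso_3166": "BRA"',
                  '{"country_code_iso_3166": "BRA"}'])
spec = parse_json_with_retries(lambda: next(responses))
print(spec)  # {'country_code_iso_3166': 'BRA'}
```

Raising instead of returning `None` makes a persistent failure visible at the call site rather than surfacing later as a `KeyError` on `country_code_iso_3166`.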


In [None]:
country_code_to_bbox('MEX')

BBox(west=-118.36763562904797, south=14.534546914563876, east=-86.70969478524208, north=32.718725333455886)

In [None]:
extract_problem_spec("I'd like to study fires in mexico and canada in 2020.")

{'country_code_iso_3166': 'global',
 'country_of_interest': 'global',
 'dataset_description': 'Datasets related to fire occurrence, fire risk, or fire danger',
 'end_date': '2020-12-31',
 'start_date': '2020-01-01',
 'bbox': BBox(west=-180, south=-90, east=180, north=90)}

In [None]:
extract_problem_spec("I'd like to study fires in the American West in 2020.")

{'country_code_iso_3166': 'USA',
 'country_of_interest': 'United States',
 'dataset_description': 'Datasets related to fire occurrence and burn severity',
 'end_date': '2020-12-31',
 'start_date': '2020-01-01',
 'bbox': BBox(west=144.6179488508778, south=-14.532923761685318, east=295.4362477659406, north=71.38761268905303)}
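
In the `USA` result above, `east` is about 295.4, outside the usual [-180, 180] range, and `west` (about 144.6) lies in the eastern hemisphere. This presumably happens because U.S. territory spans the antimeridian, so the reported box runs eastward from roughly 144.6°E to roughly 64.6°W. Code that assumes `west < east` within [-180, 180] should be prepared for this. A minimal sketch for detecting the case; `normalize_lon` and `crosses_antimeridian` are hypothetical helpers, not part of the notebook.

```python
def normalize_lon(lon: float) -> float:
  """Wraps a longitude into the range [-180, 180)."""
  return (lon + 180.0) % 360.0 - 180.0


def crosses_antimeridian(west: float, east: float) -> bool:
  """True if a west-to-east box wraps past the 180° meridian.

  After normalizing, a box that crosses the antimeridian has its western
  edge numerically greater than its eastern edge.
  """
  return normalize_lon(west) > normalize_lon(east)


# Values copied from the USA and MEX outputs above.
usa_west, usa_east = 144.6179488508778, 295.4362477659406
print(round(normalize_lon(usa_east), 4))         # -64.5638
print(crosses_antimeridian(usa_west, usa_east))  # True
print(crosses_antimeridian(-118.36763562904797, -86.70969478524208))  # False
```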

In [None]:
extract_problem_spec("How has deforestation in the Brazilian Amazon rainforest changed over the past 20 years?")

{'country_code_iso_3166': 'BRA',
 'country_of_interest': 'Brazil',
 'dataset_description': 'Datasets related to deforestation, forest cover change, or land use change',
 'end_date': '',
 'start_date': '2003-01-01',
 'bbox': BBox(west=-73.98281842982283, south=-33.75098937028551, east=-28.84297279288292, north=5.271785611922114)}

In [None]:
extract_problem_spec("How are the increasing frequency and intensity of wildfires affecting the Pantanal wetland area?")

{'country_code_iso_3166': 'BRA',
 'country_of_interest': 'Brazil',
 'dataset_description': 'Datasets related to wildfire occurrence, frequency, and intensity. Datasets related to wetland ecosystems, specifically the Pantanal.',
 'end_date': '',
 'start_date': '',
 'bbox': BBox(west=-73.98281842982283, south=-33.75098937028551, east=-28.84297279288292, north=5.271785611922114)}

In [None]:
extract_problem_spec("What is the extent of deforestation in the Congo Basin due to logging and mining activities")

{'country_code_iso_3166': 'global',
 'country_of_interest': 'global',
 'dataset_description': 'Datasets related to deforestation, logging, and mining activities. This may include land cover datasets, forest loss datasets, and extractive industry datasets.',
 'end_date': '',
 'start_date': '',
 'bbox': BBox(west=-180, south=-90, east=180, north=90)}

In [None]:
print(explain_relevance("Datasets to measure changes in lakes in the past 5 years.", 'CGIAR/SRTM90_V4', catalog))

The SRTM Digital Elevation Data Version 4 dataset is not a good fit for measuring changes in lakes over the past 5 years. While it provides global elevation data ("elevation" band) which could be used to identify water bodies in a single snapshot, its key limitation is its temporal resolution. The data was collected only once in February 2000, making it unsuitable for analyzing changes over time. To study lake dynamics, you would need a dataset with a temporal resolution of at least annual, and ideally more frequent, observations over the past 5 years. 



In [None]:
print(explain_relevance("Create flood inundation maps based on elevation and river flow data", 'CGIAR/SRTM90_V4', catalog))

The CGIAR/SRTM90_V4 dataset, representing the Shuttle Radar Topography Mission (SRTM) digital elevation data, is **highly suitable** for creating flood inundation maps.  The key band of interest is **'elevation'**, providing a global elevation model crucial for determining water flow patterns and identifying areas prone to flooding. While the dataset offers a single, static snapshot in time between February 11th and 22nd, 2000, its consistent and high-quality elevation data at a **90m resolution** makes it valuable for modeling flood inundation based on different river flow scenarios. However, the lack of temporal variation means this dataset alone cannot predict real-time flood events and would require supplementary river flow data. 



# UI code

In [None]:
#@title CSS
from google.colab import syntax
# Custom CSS for Material Design styling of the title, dataset table, and output panels
CSS = syntax.css("""
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500&display=swap');

body {
    font-family: 'Roboto', sans-serif;
    margin: 0;
    padding: 0;
}

.main-title {
    font-size: 24px;
    font-weight: 500;
    color: #212121;
    margin-bottom: 16px;
}

.custom-title {
    font-size: 18px;
    font-weight: 500;
    color: #212121;
    margin-bottom: 12px;
}

.details-text {
    font-size: 14px;
    color: #616161;
    line-height: 1.5;
}

.custom-table {
    width: 100%;
    border-collapse: collapse;
    margin-bottom: 24px;
    font-family: 'Roboto', sans-serif;
}
.custom-table th, .custom-table td {
    text-align: left;
    padding: 12px;
    border-bottom: 1px solid #E0E0E0;
}
.custom-table th {
    background-color: #F5F5F5;
    font-weight: 500;
    color: #212121;
}
.custom-table tr:hover {
    background-color: #E3F2FD;
}
.custom-table tr.selected {
    background-color: #BBDEFB;
}

/* Ensure borders are visible */
.jupyter-widgets.widget-box {
    border: 1px solid #E0E0E0 !important;
    overflow: auto;
}

/* Make the map span full width */
.geemap-container {
    width: 100% !important;
    height: 600px !important;
}
""")

In [None]:
#@title Main UI definition
import ipywidgets as widgets
from IPython.display import display, HTML, Javascript
from jinja2 import Template
import geemap
import time
import uuid
from google.colab import output


class DatasetSearchInterface:

  collections: CollectionList
  query: str
  dataset_table: widgets.Widget
  code_output: widgets.Widget
  details_output: widgets.Widget
  map_output: widgets.Widget
  geemap_instance: geemap.Map

  # Parent containers for controlling widget visibility.
  details_code_box: widgets.Widget
  map_widget: widgets.Widget


  def __init__(self, query: str, collections: CollectionList):

    self.query = query
    self.collections = collections

    # Create the output widgets
    self.code_output = widgets.Output(layout=widgets.Layout(width='50%'))
    self.details_output = widgets.Output(layout=widgets.Layout(height='300px', width='100%'))

    # Initialize dataset table
    table_html = self._build_table_html(collections)
    self.dataset_table = widgets.HTML(value=table_html)

    _callback_id = 'dataset-select' + str(uuid.uuid4())
    output.register_callback(_callback_id, self.update_outputs)
    self._dataset_select_js_code = self._dataset_select_js_code(_callback_id)


    # Initialize map
    self.map_output = widgets.Output(layout=widgets.Layout(width='100%'))
    self.geemap_instance = geemap.Map(height='600px', width='100%')


  def display(self):
    """Display the UI in the cell."""
    # Create title and description with Material Design styling
    title = widgets.HTML(value='<h2 class="main-title">Earth Engine Dataset Explorer</h2>')

    # Wrap outputs in a widget box for border styling
    details_widget = widgets.Box([self.details_output], layout=widgets.Layout(border='1px solid #E0E0E0', padding='10px', margin='5px', width='100%'))
    code_widget = widgets.Box([self.code_output], layout=widgets.Layout(border='1px solid #E0E0E0', padding='10px', margin='5px', width='100%'))
    self.map_widget = widgets.Box([self.map_output], layout=widgets.Layout(border='1px solid #E0E0E0', padding='10px', margin='5px', width='100%', height='600px'))

    # Create the vertical box for code and details
    self.details_code_box = widgets.VBox([details_widget, code_widget], layout=widgets.Layout(width='50%', height='600px'))

    # Create a horizontal box for map and details/code
    map_details_code_box = widgets.HBox([self.map_widget, self.details_code_box], layout=widgets.Layout(border='1px solid #E0E0E0', padding='10px', margin='5px'))

    # Create the main layout with Material Design styling
    main_content = widgets.VBox([
        self.dataset_table,
        map_details_code_box
    ], layout=widgets.Layout(width='100%', border='1px solid #E0E0E0', padding='10px', margin='5px'))

    # Combine the title and main content into the top-level layout
    main_layout = widgets.VBox([
        title,
        main_content,
    ], layout=widgets.Layout(height='1500px', width='100%', padding='24px'))

    # Display the widget
    display(HTML(f'<style>{CSS}</style>'))
    display(main_layout)
    display(Javascript(self._dataset_select_js_code))


  def _build_table_html(self, collections: CollectionList):
    """Builds the HTML table listing every dataset in `collections`."""
    table_html = """
    <table class="custom-table">
        <tr>
            <th>Dataset ID</th>
            <th>Name</th>
            <th>Temporal Resolution</th>
            <th>Spatial Resolution (m)</th>
            <th>Earliest</th>
            <th>Latest</th>
        </tr>
    """
    for dataset in collections:
      table_html += f"""
        <tr data-dataset="{dataset.public_id()}">
            <td>{dataset.public_id()}</td>
            <td>{dataset.get('title')}</td>
            <td>{dataset.temporal_resolution_str()}</td>
            <td>{dataset.spatial_resolution_m()}</td>
            <td>{dataset.start_str()}</td>
            <td>{dataset.end_str()}</td>
        </tr>
        """

    table_html += "</table>"
    return table_html


  def update_outputs(self, selected_dataset):
    """Refreshes the code, Gemini explanation, and map for the clicked dataset."""
    collection = self.collections.filter_by_ids([selected_dataset])

    if not collection:
      self.details_code_box.layout.visibility = 'hidden'
      self.map_widget.layout.visibility = 'hidden'
      return

    dataset = collection[0]
    with self.code_output:
      self.code_output.clear_output()
      display(HTML('<div class="custom-title">Earth Engine Code</div>'))
      print(dataset.python_code())

    with self.details_output:
      self.details_output.clear_output()
      display(HTML('<h3>Thinking...</h3>'))
      llm_thoughts = explain_relevance_from_stac_json(
          self.query, dataset.stac_json)
      details_html = f"""
      <div class="custom-title">Thoughts with Gemini</div>
      <div class="details-text"><span>{llm_thoughts}</span></div>
      """
      self.details_output.clear_output()
      display(HTML(details_html))

    with self.map_output:
      self.map_output.clear_output()

      # Clear previous layers, keeping only the base layer.
      self.geemap_instance.layers = self.geemap_instance.layers[:1]
      # Run the dataset's sample code against this map, exposing the map
      # under both `Map` and `m`.
      exec(dataset.python_code(), {'ee': ee,
                                   'Map': self.geemap_instance,
                                   'm': self.geemap_instance})

      display(self.geemap_instance)

    self.details_code_box.layout.visibility = 'visible'
    self.map_widget.layout.visibility = 'visible'

  def _dataset_select_js_code(self, callback_id):
    """Returns JavaScript that sends the clicked row's dataset ID to the Python callback."""
    # JavaScript for handling table row selection
    return Template(syntax.javascript("""
    function initializeTableInteraction() {
        const table = document.querySelector('.custom-table');
        if (!table) {
            console.error('Table not found');
            return;
        }

        function selectRow(row) {
            // Remove selection from previously selected row
            const prevSelected = table.querySelector('tr.selected');
            if (prevSelected) prevSelected.classList.remove('selected');

            // Add selection to the new row
            row.classList.add('selected');
            const selectedDataset = row.dataset.dataset;
            console.log('Selected dataset:', selectedDataset);
            google.colab.kernel.invokeFunction('{{callback_id}}', [selectedDataset], {});

        }

        table.addEventListener('click', (event) => {
            const row = event.target.closest('tr');
            if (!row || !row.dataset.dataset) return;
            selectRow(row);
        });

        // Select the first row by default
        const firstRow = table.querySelector('tr[data-dataset]');
        if (firstRow) {
            selectRow(firstRow);
        }
    }

    // Run the initialization function after a short delay to ensure the DOM is ready
    setTimeout(initializeTableInteraction, 1000);
    """)).render(callback_id=callback_id)
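
`_build_table_html` interpolates dataset IDs and titles straight into HTML, so a title containing `<` or `&` could break the table markup. A hedged sketch of escaping with the standard-library `html` module; it uses plain dicts in place of the notebook's `Collection` objects, and both `build_rows_html` and the sample title are made up for illustration.

```python
import html
from typing import Iterable, Mapping


def build_rows_html(rows: Iterable[Mapping[str, str]]) -> str:
  """Renders one <tr> per dataset, escaping every value for HTML."""
  parts = []
  for row in rows:
    # quote=True also escapes double quotes, which matters inside attributes.
    dataset_id = html.escape(str(row['id']), quote=True)
    title = html.escape(str(row['title']))
    parts.append(
        f'<tr data-dataset="{dataset_id}">'
        f'<td>{dataset_id}</td><td>{title}</td></tr>')
  return '\n'.join(parts)


row_html = build_rows_html([{'id': 'CGIAR/SRTM90_V4', 'title': 'SRTM & <DEM>'}])
print(row_html)
# <tr data-dataset="CGIAR/SRTM90_V4"><td>CGIAR/SRTM90_V4</td><td>SRTM &amp; &lt;DEM&gt;</td></tr>
```

The click handler reads `row.dataset.dataset`, and the browser decodes entities in attribute values, so the Python callback still receives the original, unescaped dataset ID.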

# Main Demo!

In [None]:
catalog = Catalog(storage_client)



In [None]:
# @title Enter Question here
from google.colab import output
output.no_vertical_scroll()

def Question(query):
  """Finds the top matching datasets for `query` and displays them in the UI."""
  ee_index = EarthEngineDatasetIndex(catalog, langchain_index, llm)
  datasets = ee_index.find_top_matches(query)
  # Sort the matches by spatial resolution and keep the first 5.
  datasets = datasets.sort_by_spatial_resolution().limit(5)
  dataset_search = DatasetSearchInterface(query, datasets)
  dataset_search.display()

query = "I'd like to estimate the 2020 population for administrative areas in Kenya."  #@param {type:"string"}
Question(query)
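
The matching itself happens in `EarthEngineDatasetIndex.find_top_matches`, which sits on top of a LangChain vector index built from the dataset-summary embeddings. Conceptually, a vector-similarity lookup ranks the summary embeddings by cosine similarity to an embedding of the query. The NumPy sketch below illustrates that idea only; it is not the notebook's implementation, `top_k_cosine` is a hypothetical helper, and the toy 3-dimensional vectors stand in for real embeddings.

```python
import numpy as np


def top_k_cosine(query_vec, doc_vecs, k=5):
  """Returns (indices, scores) of the k rows of doc_vecs most similar to query_vec."""
  q = np.asarray(query_vec, dtype=float)
  d = np.asarray(doc_vecs, dtype=float)
  # Cosine similarity is the dot product of L2-normalized vectors.
  q = q / np.linalg.norm(q)
  d = d / np.linalg.norm(d, axis=1, keepdims=True)
  scores = d @ q
  # argsort is ascending, so reverse it to put the best matches first.
  order = np.argsort(scores)[::-1][:k]
  return order, scores[order]


ids = ['fires', 'population', 'elevation']
docs = [[0.9, 0.1, 0.0], [0.0, 1.0, 0.1], [0.1, 0.0, 1.0]]
idx, scores = top_k_cosine([0.0, 0.8, 0.2], docs, k=2)
print([ids[i] for i in idx])  # ['population', 'elevation']
```

With the appendix data, `doc_vecs` could be built as `np.array(embeddings_df['embedding'].to_list())`; the query vector would have to come from the same embedding model used for the summaries.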



# Appendix: Generating Dataset summaries and embeddings

## Overview

The earlier sections of this notebook use "pre-baked" embeddings that can be downloaded from a GCS bucket.

The following code documents how those embeddings were generated, in case you would like to create your own embeddings for dataset search.


In [None]:
# @title Source code for dataset summarization and embedding modules
import re
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Dict, List, Optional

import pandas as pd
import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain
from langchain_core.language_models.base import BaseLanguageModel
from tenacity import retry, stop_after_attempt, wait_fixed
from vertexai.language_models import TextEmbeddingModel


@retry(stop=stop_after_attempt(5), wait=wait_fixed(1), reraise=True)
def summarize_text(text: str, llm: BaseLanguageModel) -> str:
    """Summarize a given text using a language model.

    This function splits the input text into chunks, then uses a map-reduce
    summarization chain to generate a summary.

    Args:
        text (str): The text to be summarized.
        llm (BaseLanguageModel): The language model to use for summarization.

    Returns:
        str: The summarized text.

    Raises:
        Exception: The last error raised by the chain if summarization fails
            after 5 attempts.
    """
    # Collapse newlines (and the indentation that follows them) into spaces.
    text = re.sub(r'\n\s*', ' ', text)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )

    docs = text_splitter.create_documents([text])
    chain = load_summarize_chain(llm, chain_type="map_reduce")
    return chain.run(docs)


def summarize_collection(collection: 'Collection', llm: BaseLanguageModel) -> Dict[str, str]:
    """Summarize the dataset description and band information for a data collection.

    Args:
        collection (Collection): The collection object containing dataset information.
        llm (BaseLanguageModel): The language model to use for summarization.

    Returns:
        Dict[str, str]: A dictionary containing the collection's ID, name, and summarized description.
    """
    summarized_description = summarize_text(collection.get('description'), llm)

    # Adding text about individual bands improves search performance.
    band_descriptions = ""
    for band in collection.bands():
        band_descriptions += f'"{band["name"]}" represents {band["description"]}\n'
        if 'gee:classes' in band:
            band_descriptions += "    Classes:\n"
            for band_class in band['gee:classes']:
                band_descriptions += f'    {band_class["description"]}\n'

    summarized_description = summarized_description + "\n\n" + band_descriptions

    return {
        'id': collection.public_id(),
        'name': collection.get('title'),
        'summary': summarized_description
    }


def summarize_ee_catalog(catalog: 'Catalog', llm: BaseLanguageModel, output_path: Optional[str] = None) -> List[Dict[str, str]]:
    """Generate summaries of all dataset descriptions in an Earth Engine data catalog.

    This function processes all collections in the catalog concurrently,
    summarizing each collection's description and band information.

    Args:
        catalog (Catalog): The Earth Engine data catalog to summarize.
        llm (BaseLanguageModel): The language model to use for summarization.
        output_path (Optional[str]): If provided, the path to save the summaries as a JSON Lines file.

    Returns:
        List[Dict[str, str]]: One summary record (id, name, summary) per collection, in catalog order.

    Note:
        This function uses a ThreadPoolExecutor with a maximum of 8 workers to limit
        potential throttling by the language model API.
    """
    summarize_collection_partial = partial(summarize_collection, llm=llm)

    with ThreadPoolExecutor(max_workers=8) as executor:
        results = list(tqdm.tqdm(
            executor.map(summarize_collection_partial, catalog.collections),
            total=len(catalog.collections)))

    if output_path:
        df = pd.DataFrame(results)
        with open(output_path, 'w') as f:
            f.write(df.to_json(orient='records', lines=True))
    return results



def get_embeddings_wrapper(texts: List[str], model: TextEmbeddingModel) -> List[List[float]]:
    """Embeds `texts` in small batches, pausing between requests to respect quota."""
    # VertexAI allows you to send batches of 5 embeddings requests at once.
    BATCH_SIZE = 5
    embs = []
    for i in tqdm.tqdm(range(0, len(texts), BATCH_SIZE)):
        time.sleep(1)  # Avoid hitting the quota error.
        result = model.get_embeddings(texts[i : i + BATCH_SIZE])
        embs.extend(e.values for e in result)
    return embs


def add_embeddings_to_df(
    df: pd.DataFrame, col_to_embed: str,  model: TextEmbeddingModel) -> pd.DataFrame:
    get_embeddings_partial = partial(get_embeddings_wrapper, model=model)
    df = df.assign(embedding=get_embeddings_partial(list(df[col_to_embed])))
    return df
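
`get_embeddings_wrapper` sends the summaries to the embedding model in fixed-size batches and sleeps between requests to stay under quota. The batching step on its own is plain slicing: for the 867 catalog entries and `BATCH_SIZE = 5`, it works out to 174 requests, which matches the 174 iterations in the embedding progress bar further below. A generic sketch (`batched` here is a hypothetical helper; Python 3.12+ also ships `itertools.batched`, which yields tuples):

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar('T')


def batched(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
  """Yields consecutive slices of at most `batch_size` items."""
  if batch_size < 1:
    raise ValueError('batch_size must be >= 1')
  for start in range(0, len(items), batch_size):
    yield list(items[start:start + batch_size])


# 867 summaries in batches of 5 -> 174 requests; the last one holds 2 texts.
batches = list(batched(list(range(867)), 5))
print(len(batches), len(batches[-1]))  # 174 2
```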





In [None]:
# @title Initialize Language and Text embedding models plus output destinations
import google
from google.cloud import storage
from google.colab import userdata
from langchain_google_genai import ChatGoogleGenerativeAI
from vertexai.language_models import TextEmbeddingModel

# We use Gemini 1.5 Pro to summarize the original dataset descriptions.
gemini_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", google_api_key=userdata.get('GOOGLE_API_KEY'))

# We use a VertexAI model to embed the dataset summaries, which are eventually
# loaded into a vector store.
embedding_model = TextEmbeddingModel.from_pretrained("google/text-embedding-004")

# We write the output to disk to reduce the risk of needing to rerun.
CATALOG_SUMMARIES_PATH = 'catalog_summaries.jsonl'
EMBEDDINGS_LOCAL_PATH = 'catalog_embeddings.jsonl'

# Eventually we upload embeddings and summaries to GCS.
GCP_PROJECT = userdata.get('GOOGLE_PROJECT_ID')
DESTINATION_BUCKET = userdata.get('EMBEDDINGS_BUCKET_NAME')
EMBEDDINGS_GCS_PATH = 'catalog_embeddings.jsonl'

storage_client = storage.Client(project=GCP_PROJECT)

In [None]:
# @title Load the entire Earth Engine public data catalog from GCS
catalog = Catalog(storage_client)



In [None]:
#@title Use an LLM to generate per-collection dataset summaries.
# This tends to take around 10-15 minutes.

# Off by default in case someone is blindly running every cell in the notebook.
GENERATE_SUMMARIES = False #@param {type:"boolean"}

if GENERATE_SUMMARIES:
  summary_json_list = summarize_ee_catalog(catalog, gemini_llm)

  # Write to a file so we minimize the need to repeat this expensive step.
  # Writing only when summaries were generated avoids overwriting an
  # existing file with an empty one.
  with open(CATALOG_SUMMARIES_PATH, 'w') as f:
    for entry in summary_json_list:
      json.dump(entry, f)
      f.write('\n')

100%|██████████| 867/867 [12:06<00:00,  1.19it/s]
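
These summaries come from `summarize_text`, which splits each description with LangChain's `RecursiveCharacterTextSplitter` (1000-character chunks with 200 characters of overlap) and runs a map-reduce summarization chain. The simplified sketch below shows the same shape without LangChain: a plain fixed-window splitter (unlike the recursive splitter, it does not try to break on separators such as paragraphs or sentences), a map step that summarizes each chunk, and a reduce step that summarizes the joined partial summaries. `split_with_overlap` and `map_reduce_summarize` are hypothetical names, and `summarize` stands in for an LLM call.

```python
from typing import Callable, List


def split_with_overlap(text: str, chunk_size: int = 1000,
                       overlap: int = 200) -> List[str]:
  """Splits text into fixed-size character windows overlapping by `overlap`."""
  if not 0 <= overlap < chunk_size:
    raise ValueError('need 0 <= overlap < chunk_size')
  step = chunk_size - overlap
  chunks = []
  for start in range(0, len(text), step):
    chunks.append(text[start:start + chunk_size])
    if start + chunk_size >= len(text):
      break  # This window already reaches the end of the text.
  return chunks


def map_reduce_summarize(text: str, summarize: Callable[[str], str]) -> str:
  """Map: summarize each chunk. Reduce: summarize the joined partial summaries."""
  partial_summaries = [summarize(chunk) for chunk in split_with_overlap(text)]
  return summarize(' '.join(partial_summaries))


# A 2,500-character toy text; the stand-in `summarize` just truncates.
text = ''.join(chr(ord('a') + i % 26) for i in range(2500))
chunks = split_with_overlap(text)
print([len(c) for c in chunks])  # [1000, 1000, 900]
print(map_reduce_summarize(text, lambda s: s[:10]))  # abcdefghij
```

The overlap means a sentence cut at a chunk boundary still appears whole in at least one chunk, at the cost of summarizing some text twice.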


In [None]:
# @title View summary results
catalog_summary_df = pd.read_json(CATALOG_SUMMARIES_PATH, lines=True)
catalog_summary_df.head()

| | id | name | summary |
|---|---|---|---|
| 0 | AAFC/ACI | Canada AAFC Annual Crop Inventory | Agriculture and Agri-Food Canada annually maps... |
| 1 | ACA/reef_habitat/v2_0 | Allen Coral Atlas (ACA) - Geomorphic Zonation ... | The Allen Coral Atlas is a global, high-resolu... |
| 2 | AHN/AHN2_05M_INT | AHN Netherlands 0.5m DEM, Interpolated | The AHN DEM is a detailed (0.5m resolution) el... |
| 3 | AHN/AHN2_05M_NON | AHN Netherlands 0.5m DEM, Non-Interpolated | The AHN DEM is a high-resolution (0.5m) model ... |
| 4 | AHN/AHN2_05M_RUW | AHN Netherlands 0.5m DEM, Raw Samples | The AHN DEM is a highly detailed (0.5m resolut... |


In [None]:
#@title Upload summaries to GCS
bucket = google.cloud.storage.bucket.Bucket(
    storage_client, name=DESTINATION_BUCKET, user_project=GCP_PROJECT)
blob = bucket.blob(CATALOG_SUMMARIES_PATH)
blob.upload_from_filename(CATALOG_SUMMARIES_PATH)

In [None]:
# @title Calculate embeddings for each dataset summary
# This takes around 3-5 minutes due to the text embedding model's rate limits.

embedding_df = add_embeddings_to_df(catalog_summary_df, 'summary', embedding_model)

# First store locally, just in case something happens to the Colab runtime.
with open(EMBEDDINGS_LOCAL_PATH, 'w') as f:
  f.write(embedding_df.to_json(orient='records', lines=True))

# Make sure we can read the embeddings that were written to file.
embeddings_df = pd.read_json(EMBEDDINGS_LOCAL_PATH, lines=True)
embeddings_df.head()

100%|██████████| 174/174 [04:09<00:00,  1.43s/it]


| | id | name | summary | embedding |
|---|---|---|---|---|
| 0 | AAFC/ACI | Canada AAFC Annual Crop Inventory | Agriculture and Agri-Food Canada annually maps... | [-0.02059192582964897, 0.009980480186641216, -... |
| 1 | ACA/reef_habitat/v2_0 | Allen Coral Atlas (ACA) - Geomorphic Zonation ... | The Allen Coral Atlas is a global, high-resolu... | [0.004030477721244097, 0.05731431767344475, -0... |
| 2 | AHN/AHN2_05M_INT | AHN Netherlands 0.5m DEM, Interpolated | The AHN DEM is a detailed (0.5m resolution) el... | [-0.011531920172274113, -0.06319545209407806, ... |
| 3 | AHN/AHN2_05M_NON | AHN Netherlands 0.5m DEM, Non-Interpolated | The AHN DEM is a high-resolution (0.5m) model ... | [0.008249891921877861, -0.08075475692749023, -... |
| 4 | AHN/AHN2_05M_RUW | AHN Netherlands 0.5m DEM, Raw Samples | The AHN DEM is a highly detailed (0.5m resolut... | [-0.005648438353091478, -0.06969967484474182, ... |


In [None]:
#@title Upload embeddings to GCS
bucket = google.cloud.storage.bucket.Bucket(
    storage_client, name=DESTINATION_BUCKET, user_project=GCP_PROJECT)
blob = bucket.blob(EMBEDDINGS_GCS_PATH)
blob.upload_from_filename(EMBEDDINGS_LOCAL_PATH)

In [None]:
# @title Make sure we can load the new file from GCS

bucket = storage_client.get_bucket(DESTINATION_BUCKET)
blob = bucket.blob(EMBEDDINGS_GCS_PATH)
blob.download_to_filename(EMBEDDINGS_LOCAL_PATH)

embeddings_df = pd.read_json(EMBEDDINGS_LOCAL_PATH, lines=True)
embeddings_df

| | id | name | summary | embedding |
|---|---|---|---|---|
| 0 | AAFC/ACI | Canada AAFC Annual Crop Inventory | Agriculture and Agri-Food Canada annually maps... | [-0.0205919258, 0.0099804802, -0.0294760894000... |
| 1 | ACA/reef_habitat/v2_0 | Allen Coral Atlas (ACA) - Geomorphic Zonation ... | The Allen Coral Atlas is a global, high-resolu... | [0.0040304777, 0.0573143177, -0.0509719588, 0.... |
| 2 | AHN/AHN2_05M_INT | AHN Netherlands 0.5m DEM, Interpolated | The AHN DEM is a detailed (0.5m resolution) el... | [-0.0115319202, -0.06319545210000001, -0.03415... |
| 3 | AHN/AHN2_05M_NON | AHN Netherlands 0.5m DEM, Non-Interpolated | The AHN DEM is a high-resolution (0.5m) model ... | [0.0082498919, -0.0807547569, -0.0562631935, 0... |
| 4 | AHN/AHN2_05M_RUW | AHN Netherlands 0.5m DEM, Raw Samples | The AHN DEM is a highly detailed (0.5m resolut... | [-0.0056484384000000006, -0.0696996748, -0.052... |
| ... | ... | ... | ... | ... |
| 862 | projects/planet-nicfi/assets/basemaps/americas | NICFI Satellite Data Program Basemaps for Trop... | The Norway's International Climate and Forest ... | [0.014683888300000001, 0.034109890500000004, -... |
| 863 | projects/planet-nicfi/assets/basemaps/asia | NICFI Satellite Data Program Basemaps for Trop... | High-resolution satellite imagery of the tropi... | [0.021393563600000002, 0.0348088369, -0.003804... |
| 864 | projects/sat-io/open-datasets/GLOBathy/GLOBath... | GLOBathy Global lakes bathymetry dataset | GLOBathy, a new global dataset, provides depth... | [0.020634584100000002, 0.0340872109, -0.085019... |
| 865 | projects/sat-io/open-datasets/ORNL/LANDSCAN_GL... | LandScan Population Data Global 1km | LandScan, a high-resolution population dataset... | [0.0694423467, 0.0048219636000000005, -0.04556... |
