In [0]:
import os
import pandas as pd
import requests
import tempfile
import xarray as xr

from bs4 import BeautifulSoup
from urllib.parse import urlparse

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
import pyspark.sql.types as ST

In [0]:
class GRIBMessagePartition(InputPartition):
  """Partition descriptor that identifies a single GRIB2 file to read.

  The only state carried is ``path``; the reader in this file fills it with
  the download URL of one ``.grib2`` file.
  """

  def __init__(self, path: str) -> None:
    # Location of the GRIB2 file this partition is responsible for.
    self.path = path

In [0]:
class IFSDataSourceReader(DataSourceReader):
  """Reader that downloads GRIB2 files from the ECMWF forecast listing and emits rows.

  Options read from ``options``:
    * ``forecastDate`` (required): date path segment of the forecast URL.
    * ``forecastTime`` (required): forecast cycle hour, parsed with ``int``.
    * ``variables`` (required): comma-separated ``cfVarName`` values to extract.

  One partition is created per ``.grib2`` link found on the listing page, and
  each partition yields one row per data point and requested variable.
  """

  # Timeout in seconds handed to ``requests`` so a stalled connection cannot
  # hang a Spark task forever.
  REQUEST_TIMEOUT_SECONDS = 60

  def __init__(self, schema, options):
    self.schema: ST.StructType = schema
    self.options = options

  @property
  def forecast_time(self):
    """Forecast cycle hour as an ``int``.

    Raises:
      ValueError: if the ``forecastTime`` option is missing or empty, or if
        it cannot be parsed by ``int``.
    """
    forecast_time_str: str = self.options.get("forecastTime")
    if not forecast_time_str:
      raise ValueError("The 'forecastTime' option is required.")
    return int(forecast_time_str)

  @property
  def forecast_date(self):
    """Forecast date string exactly as supplied in the ``forecastDate`` option.

    Raises:
      ValueError: if the ``forecastDate`` option is missing or empty.
    """
    forecast_date: str = self.options.get("forecastDate")
    if not forecast_date:
      raise ValueError("The 'forecastDate' option is required.")
    return forecast_date

  @property
  def forecast_url(self):
    """URL of the directory listing for the configured date and cycle."""
    return (
      "https://data.ecmwf.int/forecasts/"
      f"{self.forecast_date}/"
      f"{self.forecast_time:02}z/"
      "ifs/0p25/oper/"
      )

  @property
  def root_url(self):
    """Scheme and host of ``forecast_url`` (no path), used to absolutize links."""
    parsed_url = urlparse(self.forecast_url)
    return f"{parsed_url.scheme}://{parsed_url.netloc}"

  @property
  def input_files(self):
    """URLs of every ``.grib2`` link found on the forecast listing page.

    Performs an HTTP GET on ``forecast_url`` each time it is accessed.

    Raises:
      requests.HTTPError: if the listing request returns an error status.
    """
    # Context manager returns the connection to the pool once the body is read.
    with requests.get(self.forecast_url, timeout=self.REQUEST_TIMEOUT_SECONDS) as r:
      r.raise_for_status()
      listing_html = r.text
    soup = BeautifulSoup(listing_html, 'html.parser')
    # Resolve once instead of re-parsing the options for every link.
    root_url = self.root_url
    # NOTE(review): hrefs are appended to the host as-is, so this assumes the
    # listing uses absolute paths (starting with "/") — confirm.
    return [
      f"{root_url}{a['href']}"
      for a in soup.find_all('a', href=True)
      if a['href'].endswith('.grib2')
    ]

  @property
  def variables(self):
    """List of variable names parsed from the comma-separated ``variables`` option.

    Surrounding whitespace is stripped and empty entries are dropped, so
    ``"t2m, u10"`` and ``"t2m,u10,"`` both give ``["t2m", "u10"]``.

    Raises:
      ValueError: if the option is missing, empty, or contains no names.
    """
    raw_variables = self.options.get("variables")
    if not raw_variables:
      raise ValueError("The 'variables' option is required.")
    names = [name.strip() for name in raw_variables.split(",")]
    names = [name for name in names if name]
    if not names:
      raise ValueError("The 'variables' option must list at least one variable.")
    return names

  def partitions(self):
    """Return one ``GRIBMessagePartition`` per GRIB2 URL on the listing page."""
    return [GRIBMessagePartition(path) for path in self.input_files]

  def read(self, partition):
    """Download the partition's GRIB2 file and yield its rows.

    Args:
      partition: object whose ``path`` attribute is the URL to download.

    Yields:
      Named tuples with the fields ``time``, ``valid_time``, ``variable``,
      ``longitude``, ``latitude``, ``value`` and ``path``, in that order.

    Raises:
      ValueError: if the ``variables`` option is missing or empty.
      requests.HTTPError: if the download returns an error status.
    """
    # Validate the options before spending time on the download.
    variables = self.variables
    output_cols = ["time", "valid_time", "variable", "longitude", "latitude", "value", "path"]
    # A private directory is used instead of a NamedTemporaryFile so that
    # (a) the file is fully written and closed before it is reopened by name —
    # the previous unflushed buffered handle could hide the file's tail — and
    # (b) any sidecar files the GRIB backend writes next to it (presumably
    # cfgrib index files) are removed together with the download.
    with tempfile.TemporaryDirectory() as tmp_dir:
      grib_path = os.path.join(tmp_dir, "download.grib2")
      with requests.get(
        partition.path, stream=True, timeout=self.REQUEST_TIMEOUT_SECONDS
      ) as r:
        r.raise_for_status()
        with open(grib_path, "wb") as fh:
          for chunk in r.iter_content(chunk_size=8192):
            fh.write(chunk)
      var_dfs = []
      for var in variables:
        print(f"{partition=}, {var=}")
        with xr.open_dataset(grib_path, filter_by_keys={"cfVarName": var}) as ds:
          col_subset = ["time", "valid_time", "latitude", "longitude", var]
          var_pdf = ds.to_dataframe().reset_index(drop=False)[col_subset]
          # NOTE(review): timestamps are labelled as UTC without conversion;
          # assumes the decoded GRIB times are naive UTC — confirm.
          var_pdf["time"] = var_pdf["time"].dt.tz_localize("utc")
          var_pdf["valid_time"] = var_pdf["valid_time"].dt.tz_localize("utc")
          var_pdf["variable"] = var
          var_pdf["value"] = var_pdf[var].astype('float32')
          var_pdf["path"] = partition.path
          var_dfs.append(var_pdf[output_cols])
      output_pdf = pd.concat(var_dfs, ignore_index=True)
    # All data is in memory at this point, so the temporary directory can be
    # gone before rows are handed to Spark.
    yield from output_pdf.itertuples(index=False)

In [0]:
class IFSDataSource(DataSource):
    """Python data source exposed under the short name ``ifs``."""

    @classmethod
    def name(cls) -> str:
        """
        Return the short name of this data source.

        Returns:
            str: Always ``"ifs"``.
        """
        return "ifs"

    def schema(self) -> ST.StructType:
        """
        Build the schema of the rows produced by this data source.

        Every column is nullable. The reader in this file emits tuples in the
        order (time, valid_time, variable, longitude, latitude, value, path),
        so, assuming tuples are matched to the schema by position, ``x``
        carries the longitude, ``y`` the latitude and ``m`` the value.

        Returns:
            StructType: Columns ``time``, ``valid_time``, ``variable``, ``x``,
            ``y``, ``m`` and ``path``.
        """
        # (column name, Spark type) pairs, in output order.
        columns = (
            ("time", ST.TimestampType()),
            ("valid_time", ST.TimestampType()),
            ("variable", ST.StringType()),
            ("x", ST.DoubleType()),
            ("y", ST.DoubleType()),
            ("m", ST.DoubleType()),
            ("path", ST.StringType()),
        )
        fields = [
            ST.StructField(column_name, column_type, True)
            for column_name, column_type in columns
        ]
        return ST.StructType(fields)

    def reader(self, schema: ST.StructType):
        """
        Create the reader for this data source.

        Args:
            schema: Schema handed to the reader.

        Returns:
            IFSDataSourceReader: Reader built from ``schema`` and this
            source's options.
        """
        ifs_reader = IFSDataSourceReader(schema, self.options)
        return ifs_reader