In [3]:
# Import lib
import numpy as np
import matplotlib.pyplot as plt
import zipfile
import os
import os
import zipfile
import xarray as xr
from tqdm import tqdm
import requests
import time
import json

## Download the original data

In [4]:
def download_with_progress(url, destination):
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kilobyte

    with open(destination, 'wb') as file, tqdm(
        total=total_size, unit='iB', unit_scale=True, desc=destination, ncols=100, leave=True,
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]"
    ) as bar:
        for data in response.iter_content(block_size):
            file.write(data)
            bar.update(len(data))

def get_version_1_of_seasfire_datacube(version='0.1'):
    try:
        # Download zipped cube

        if version == '0.1':
            url = "https://zenodo.org/records/6834585/files/SeasFireCube8daily.zip"
            zip_filename = "SeasFireCube8daily.zip"
        elif version == '0.4':
            url = "https://zenodo.org/records/13834057/files/seasfire_v0.4.zip"
            zip_filename = "SeasFireCube8daily_v0.4.zip"

        if not os.path.exists(zip_filename):
            print("Downloading data cube...")
            download_with_progress(url, zip_filename)
        else:
            print("Data cube already downloaded.")

        # Extract from zip file with progress bar
        with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
            print("Extracting data cube...")
            total_files = len(zip_ref.namelist())
            with tqdm(total=total_files, unit='file', desc='Extracting', ncols=100, leave=True,
                      bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]") as bar:
                for file in zip_ref.namelist():
                    zip_ref.extract(file)
                    bar.update(1)

        if version == '0.1':
            extracted_folder = 'SeasFireCube8daily.zarr'
        elif version == '0.4':
            extracted_folder = 'SeasFireCube8daily_v0.4.zarr'
            
        if not os.path.exists(extracted_folder):
            raise FileNotFoundError(f"Extraction failed, {extracted_folder} not found.")

        # Load dataset
        dataset = xr.open_zarr(extracted_folder)
        print("Dataset successfully loaded.")
        return dataset

    except Exception as e:
        print(f"An error occurred while downloading or loading the data cube: {e}")
        return None

dataset = get_version_1_of_seasfire_datacube(version='0.4')

Data cube already downloaded.
Extracting data cube...


Extracting: 100%|███████████████████████████████████████████████| 927/927 [00:38<00:00, 24.18file/s]

An error occurred while downloading or loading the data cube: Extraction failed, SeasFireCube8daily_v0.4.zarr not found.





In [5]:
dataset

## Data processing functions 

(functions from the given ipynb)

In [13]:
def select_spatio_temporal_data(
    data: xr.DataArray | xr.Dataset,
    time_start: int,
    time_length: int,
    latitude: Optional[Union[int, Tuple[int,int]]] = None,
    longitude: Optional[Union[int, Tuple[int,int]]] = None
) -> xr.DataArray | xr.Dataset:
    """
    Slice an xarray object in time (if present), latitude, and longitude.

    Args:
      data        : xarray DataArray or Dataset
      time_start  : start index in 'time' coordinate (ignored if no 'time' dim)
      time_length : number of consecutive timesteps
      latitude    : None (all), int (single index), or (start, end)
      longitude   : None (all), int (single index), or (start, end)

    Returns:
      The subset of `data` with the specified ranges.
    """
    result = data

    # Time slice (only if 'time' exists)
    if 'time' in result.dims:
        result = result.isel(time=slice(time_start, time_start + time_length))

    # Helper to apply a 1D slice on a given dim
    def _apply_slice(ds, dim, key):
        if key is None:
            return ds
        if isinstance(key, int):
            return ds.isel({dim: key})
        if isinstance(key, (list, tuple)) and len(key) == 2:
            return ds.isel({dim: slice(key[0], key[1])})
        raise ValueError(f"{dim!r} must be None, int, or (start,end)")

    # Latitude & Longitude
    result = _apply_slice(result, 'latitude', latitude)
    result = _apply_slice(result, 'longitude', longitude)

    return result


def plot_earth(
    data: xr.DataArray | xr.Dataset,
    time_start: int,
    time_length: int,
    latitude: None | int | tuple[int, int] = None,
    longitude: None | int | tuple[int, int] = None,
    col_wrap: int = 4
) -> None:
    """
    Plot spatio-temporal slices of the data.

    If time_length == 1, plots a single map; otherwise creates a faceted plot over time.
    """
    subset = select_spatio_temporal_data(
        data, time_start, time_length, latitude, longitude
    )

    if time_length == 1:
        subset.plot()
    else:
        subset.plot(
            x="longitude",
            y="latitude",
            col="time",
            col_wrap=col_wrap
        )
    plt.show()


def lat_lon_to_index(
    coords: list[str],
    dim: str,
    size: int
) -> list[int]:
    """
    Convert human-readable lat/lon strings to array indices.

    Args:
      coords : list of strings like '25N', '45W'
      dim    : 'latitude' or 'longitude'
      size   : length of that dimension (e.g. 720 or 1440)
    """
    indices = []
    for val in coords:
        deg = float(val[:-1])
        dirc = val[-1].upper()
        if dim == 'latitude':
            if dirc == 'N':
                idx = int((90 - deg) / 180 * size)
            elif dirc == 'S':
                idx = int((deg + 90) / 180 * size)
            else:
                raise ValueError("Latitude must end with 'N' or 'S'")
        elif dim == 'longitude':
            if dirc == 'E':
                idx = int((deg + 180) / 360 * size)
            elif dirc == 'W':
                idx = int((180 - deg) / 360 * size)
            else:
                raise ValueError("Longitude must end with 'E' or 'W'")
        else:
            raise ValueError("dim must be 'latitude' or 'longitude'")
        indices.append(idx)
    return sorted(indices)

NameError: name 'Optional' is not defined

## Generate the train/val/test .json files under data folder

In [7]:
def sample_new_coords(shape, existing_coords, N):
    """
    Randomly sample N distinct coordinates in an array of given shape,
    excluding any in existing_coords.

    Parameters
    ----------
    shape : tuple of ints
        The overall grid shape, e.g. (100,100,100).
    existing_coords : ndarray of shape (M, ndim)
        Integer coordinates to exclude.
    N : int
        Number of new coordinates to sample.

    Returns
    -------
    new_coords : ndarray of shape (N, ndim)
        The newly sampled coordinates.
    """
    shape = tuple(shape)
    ndim = len(shape)

    # 1) convert existing coords to flat indices
    #    (coords.T gives a tuple of arrays for each dimension)
    flat_existing = np.ravel_multi_index(existing_coords.T, shape)

    # 2) total number of points
    total = np.prod(shape)

    # 3) build the set of available flat indices
    #    (assume existing are unique -> faster)
    all_flat = np.arange(total)
    unused_flat = np.setdiff1d(all_flat, flat_existing, assume_unique=True)

    if N > unused_flat.size:
        raise ValueError(f"Cannot sample {N} points; only {unused_flat.size} free.")

    # 4) choose N of them without replacement
    chosen_flat = np.random.choice(unused_flat, size=N, replace=False)

    # 5) map back to nd-coordinates
    new_coords = np.column_stack(np.unravel_index(chosen_flat, shape))
    return new_coords


def build_dataset(type='train', latitude=['25N', '75N'], longitude=['15W', '45E'], NPR=5, fire_threshold=20000, time_lag=39, fire_var='BAs_GWIS', local_var=("t2m", "tp", "vpd_cf"), ocis=("nao", "nina34_anom")):

    # 1) decide time ranges
    if type == 'train':
        initial_t, total_steps = 0, 46 * 17
    elif type == 'test':
        initial_t, total_steps = 46 * 17, 46 * 2
    elif type == 'val':
        initial_t, total_steps = 46 * 19, 46 * 2
    else:
        raise ValueError("Invalid type. Choose 'train','test','val'.")

    lat_idx = latitude2index(latitude)
    lon_idx = longitude2index(longitude)

    core_start = initial_t + time_lag
    core_len   = total_steps - time_lag

    # 2) load fire and predictors ONCE
    fire_arr = select_spatio_temporal_data(dataset[fire_var], core_start, core_len, lat_idx, lon_idx).values

    local_arrs = {
        v: select_spatio_temporal_data(dataset[v], initial_t, total_steps, lat_idx, lon_idx).values
        for v in local_var
    }
    oci_arrs = {
        v: select_spatio_temporal_data(dataset[v], initial_t, total_steps, lat_idx, lon_idx).values
        for v in ocis
    }

    # 3) find positives & sample negatives
    flat = fire_arr.ravel()
    pos_flat = np.nonzero(flat > fire_threshold)[0]
    pos_coords = np.column_stack(np.unravel_index(pos_flat, fire_arr.shape))
    neg_coords = sample_new_coords(fire_arr.shape, pos_coords, NPR * len(pos_coords))

    # 4) build output
    out = []
    for coords, label, desc in ((pos_coords,1,"pos"), (neg_coords,0,"neg")):
        pbar = tqdm(coords, desc=f"Processing {desc}", unit="pt")
        for t_rel, y, x in pbar:
            t_abs   = core_start + int(t_rel)
            rel_idx = t_abs - initial_t       # now in  [0 .. total_steps)
            t0, t1 = rel_idx - time_lag, rel_idx

            # Nested structure
            inst = {"local_variables": {}, "ocis": {}, "target": label}

            # fill local_variables with UPPERCASE keys
            for v, arr in local_arrs.items():
                key = v.upper()
                inst["local_variables"][key] = arr[t0:t1, y, x].tolist()

            # fill ocis with UPPERCASE keys
            for v, arr in oci_arrs.items():
                key = v.upper()
                inst["ocis"][key] = arr[t0:t1, y, x].tolist()

            out.append(inst)

    import random
    random.shuffle(out)

    with open(f"./data/{type}.json", "w", encoding="utf-8") as f:
        for inst in out:
            f.write(
                json.dumps(inst, ensure_ascii=False, separators=(',', ':'))
                + "\n"
            )
    print(f"Dataset saved to {type}.json (NDJSON, one object per line)")

#### .json file generator.

Fire data is continuous variable. The fire_threshold can transform the continuous fire area variable to the binary status. Smaller fire_threshold, larger .json file created.

In [8]:
# For real implementation, suggest the fire_threshold to be 1.

build_dataset(type='train', fire_threshold=10000)
build_dataset(type='val', fire_threshold=10000)
build_dataset(type='test', fire_threshold=10000)

TypeError: 'NoneType' object is not subscriptable

## Dataset and Dataloader

In [None]:
import json
import torch
from torch.utils.data import Dataset, DataLoader

class JsonFireDataset(Dataset):
    def __init__(self, json_path, local_keys=None, oci_keys=None):
        """
        Args:
            json_path (str): path to your NDJSON file (one JSON object per line).
            local_keys (list of str): names of the local_variables channels, e.g. ['T2M','TP','VPD_CF'].
                                      If None, inferred from the first sample.
            oci_keys   (list of str): names of the ocis channels, e.g. ['NAO','NINA34_ANOM'].
                                      If None, inferred from the first sample.
        """
        # load all lines
        with open(json_path, 'r', encoding='utf-8') as f:
            self.records = [json.loads(line) for line in f]

        # infer channel order if not given
        first = self.records[0]
        if local_keys is None:
            local_keys = list(first['local_variables'].keys())
        if oci_keys is None:
            oci_keys = list(first['ocis'].keys())

        self.local_keys = local_keys
        self.oci_keys   = oci_keys
        self.channel_keys = self.local_keys + self.oci_keys

        # sanity-check that every record has the same length L
        L = len(first['local_variables'][self.local_keys[0]])
        for rec in self.records:
            assert all(len(rec['local_variables'][k]) == L for k in self.local_keys), "inconsistent L"
            assert all(len(rec['ocis'][k]) == L for k in self.oci_keys),       "inconsistent L"
        self.L = L

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]

        # collect each channel sequence into a list
        seqs = []
        for k in self.local_keys:
            seqs.append(rec['local_variables'][k])
        for k in self.oci_keys:
            seqs.append(rec['ocis'][k])

        # stack → shape (channel, L)
        x = torch.tensor(seqs, dtype=torch.float32)
        y = torch.tensor(rec['target'], dtype=torch.long)
        return x, y

# ── USAGE ────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # point these to your actual JSON file
    train_ds = JsonFireDataset(
        "./data/train.json",
        local_keys=["T2M","TP","VPD_CF"],
        oci_keys=["NAO","NINA34_ANOM"],
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=2048,
        shuffle=True,
        num_workers=4,    # adjust to your machine
        pin_memory=True,  # if you’re on GPU
    )

    # iterate
    for batch_x, batch_y in train_loader:
        # batch_x: (32, 5, L)
        # batch_y: (32,)
        print(batch_x.shape, batch_y.shape)
        break


## Others, delete in the final version

In [None]:
data = earth_graph(dataset['BAs_GWIS'], 0, 100, latitude2index(latitude), longitude2index(longitude), plot=False)

In [None]:
print(data.values.shape)
arr_1d = data.values.reshape(-1)  # or use arr.ravel()

# Plot histogram
plt.hist(arr_1d, bins=50, log=True)
plt.title("Histogram of values")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.grid(True)
plt.show()