# GraphCast SAE Demo

This notebook demonstrates loading GraphCast and extracting internal activations.

In [1]:
# Install the GraphCast fork from its `sae-hooks` branch (installed as the
# `graphcast` package that later cells import from).
# NOTE(review): a branch name is not a pinned version. The recorded run resolved
# it to commit 39d1de436148c5658726b5a092bd2d7ef8701f2a; consider pinning to
# that commit for reproducibility.
%pip install git+https://github.com/theodoremacmillan/graphcast.git@sae-hooks


Collecting git+https://github.com/theodoremacmillan/graphcast.git@sae-hooks
  Cloning https://github.com/theodoremacmillan/graphcast.git (to revision sae-hooks) to /tmp/pip-req-build-as2radib
  Running command git clone --filter=blob:none --quiet https://github.com/theodoremacmillan/graphcast.git /tmp/pip-req-build-as2radib
  Running command git checkout -b sae-hooks --track origin/sae-hooks
  Switched to a new branch 'sae-hooks'
  Branch 'sae-hooks' set up to track remote branch 'sae-hooks' from 'origin'.
  Resolved https://github.com/theodoremacmillan/graphcast.git to commit 39d1de436148c5658726b5a092bd2d7ef8701f2a
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting cartopy (from graphcast==0.2.0.dev0)
  Downloading cartopy-0.25.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting colabtools (from graphcast==0.2.0.dev0)
  Downloading colabtools-0.0.1-py3-none-any.whl.metadata (511 bytes)
Collecting dinosaur-dycore (from graphcast==0.2

In [2]:
# @title Workaround for cartopy crashes

# Workaround for cartopy crashes due to the shapely installed by default in
# google colab kernel (https://github.com/anitagraser/movingpandas/issues/81):
!pip uninstall -y shapely
!pip install shapely --no-binary shapely

Found existing installation: shapely 2.1.1
Uninstalling shapely-2.1.1:
  Successfully uninstalled shapely-2.1.1
Collecting shapely
  Downloading shapely-2.1.2.tar.gz (315 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m315.5/315.5 kB[0m [31m25.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: shapely
  Building wheel for shapely (pyproject.toml) ... [?25l[?25hdone
  Created wheel for shapely: filename=shapely-2.1.2-cp311-cp311-linux_x86_64.whl size=1207446 sha256=29c667a1fc19cc5deeabfead39c38b220a1cf17c4aae21131e3c70eece9f24fc
  Stored in directory: /root/.cache/pip/wheels/af/a4/43/7be70b9a914836f51744c5e6e2408c9b4d0c3bcb2033d394e0
Successfully built shapely
Installing collected packages: shapely
Successfully installed shapely-2.1.2


In [3]:
# @title Imports

import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray


def parse_file_parts(file_name):
  """Split a file name into a dict of its `key-value` parts.

  Parts are separated by "_" and each part is split on its first "-" into a
  key and a value, e.g. "source-era5_res-0.25" gives
  {"source": "era5", "res": "0.25"}. A repeated key keeps its last value.
  """
  parts = {}
  for chunk in file_name.split("_"):
    key, value = chunk.split("-", 1)
    parts[key] = value
  return parts


In [6]:
# @title Authenticate with Google Cloud Storage

# Anonymous client: no credentials are supplied to Google Cloud Storage.
gcs_client = storage.Client.create_anonymous_client()
# Bucket handle and key prefix for the GraphCast files.
# NOTE(review): gcs_client, gcs_bucket and dir_prefix are not referenced by any
# cell visible in this notebook; the pipeline cell further down creates its own
# client, bucket and prefix. Confirm whether this cell is still needed.
gcs_bucket = gcs_client.get_bucket("dm_graphcast")
dir_prefix = "graphcast/"

In [5]:
import numpy as np
import xarray as xr
import gcsfs


def load_era5_into_memory(
    start: str,
    end: str,
    zarr_path: str = "gs://weatherbench2/datasets/era5/1959-2022-full_37-6h-0p25deg_derived.zarr",
    vars_keep=None,
):
    """
    Load a time slice of ERA5 data fully into memory.

    Parameters
    ----------
    start, end : str
        Bounds of the time range. Both are parsed with ``np.datetime64``
        (e.g. ``"2020-01-01"``) and used as a label slice on ``time``.
    zarr_path : str
        Zarr store to open. Paths starting with ``gs://`` are opened
        anonymously through gcsfs; anything else is passed directly to
        ``xr.open_zarr``.
    vars_keep : list[str] | None
        Variables to keep. ``None`` selects the default variable list defined
        below, NOT every variable in the store. Requested names that are not
        data variables of the store are skipped with a warning.

    Returns
    -------
    xarray.Dataset
        Dataset using ``lat``/``lon`` coordinate names, flipped to ascending
        latitude when the store has it descending, restricted to the requested
        time range and variables, and loaded into memory.
    """
    import warnings  # local import: only needed for the missing-variable notice

    if vars_keep is None:
        vars_keep = [
            "geopotential",
            "specific_humidity",
            "temperature",
            "u_component_of_wind",
            "v_component_of_wind",
            "vertical_velocity",
            "2m_temperature",
            "10m_u_component_of_wind",
            "10m_v_component_of_wind",
            "mean_sea_level_pressure",
            "total_precipitation_6hr",
            "toa_incident_solar_radiation",
            "geopotential_at_surface",
            "land_sea_mask",
        ]

    start = np.datetime64(start)
    end = np.datetime64(end)

    # --- Open Zarr store ---
    if zarr_path.startswith("gs://"):
        fs = gcsfs.GCSFileSystem(token="anon")
        # Strip the "gs://" scheme; the mapper expects "bucket/path".
        store = fs.get_mapper(zarr_path[5:])
        ds = xr.open_zarr(store, consolidated=True)
    else:
        ds = xr.open_zarr(zarr_path, consolidated=True)

    # --- Normalize coordinate names to lat/lon ---
    rename = {}
    if "latitude" in ds.coords:
        rename["latitude"] = "lat"
    if "longitude" in ds.coords:
        rename["longitude"] = "lon"
    if rename:
        ds = ds.rename(rename)

    # Flip latitude when it is stored in descending order.
    if ds.lat[0] > ds.lat[-1]:
        ds = ds.reindex(lat=ds.lat[::-1])

    # --- Time slice ---
    ds = ds.sel(time=slice(start, end))

    # --- Variable selection ---
    # vars_keep is always a concrete list here (the None case was replaced by
    # the default list above), so no further None check is needed.
    missing = [v for v in vars_keep if v not in ds.data_vars]
    if missing:
        warnings.warn(
            f"Requested variables not found in store and skipped: {missing}",
            stacklevel=2,
        )
    ds = ds[[v for v in vars_keep if v in ds.data_vars]]

    # --- Load everything into memory ---
    ds = ds.load()

    return ds

In [7]:
# Fetch the ERA5 slice used by the rest of the notebook, relying on the
# function's default Zarr store and default variable list.
# NOTE(review): the daily files written from this dataset start at 2020-01-01,
# so any 3-step window that needs 2019-12-31 data cannot be built from it.
ds = load_era5_into_memory(
    start="2020-01-01",
    end="2020-01-02"
)

In [15]:
import os
import numpy as np
import xarray as xr

def write_daily_era5_files(ds: xr.Dataset, out_dir: str):
    """
    Split a Dataset by calendar day and write one NetCDF file per day.

    Each day's data is written to ``<out_dir>/era5_YYYY-MM-DD.nc``, the naming
    scheme that ``three_step_window()`` looks for. ``out_dir`` is created when
    it does not exist, and the path of every file written is printed.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a datetime ``time`` coordinate; days are formed by
        grouping on ``time.date``.
    out_dir : str
        Directory that receives the daily files.
    """
    os.makedirs(out_dir, exist_ok=True)

    per_day = ds.groupby("time.date")
    for calendar_day, day_slice in per_day:
        # Format the group key as YYYY-MM-DD for the file name.
        stamp = np.datetime_as_string(np.datetime64(calendar_day), unit="D")
        file_name = f"era5_{stamp}.nc"
        destination = os.path.join(out_dir, file_name)

        day_slice.to_netcdf(destination)
        print(f"[WRITE] {destination}")


In [17]:
# Write the in-memory dataset out as one NetCDF file per day.
write_daily_era5_files(ds, out_dir="/content/era5_daily_nc")

[WRITE] /content/era5_daily_nc/era5_2020-01-01.nc
[WRITE] /content/era5_daily_nc/era5_2020-01-02.nc


## Next step: run GraphCast and capture its internal activations (layer 8 is one of the saved steps 2, 4, …, 14 configured below)

In [18]:
import dataclasses
import functools
import numpy as np
import xarray as xr
import jax
import haiku as hk

from graphcast import (
    autoregressive,
    casting,
    checkpoint,
    data_utils,
    graphcast,
    normalization,
    rollout,
    xarray_jax,
    xarray_tree,
)

from graphcast.deep_typed_graph_net import get_activation_manager
from google.cloud import storage


In [19]:
# ============================================================
# FULL GraphCast activation pipeline — faithful to original
# ============================================================

import os, glob, dataclasses, functools, time
import numpy as np
import xarray as xr
import jax, haiku as hk
from google.cloud import storage

from graphcast import (
    autoregressive,
    casting,
    checkpoint,
    data_utils,
    graphcast,
    normalization,
    rollout,
    xarray_jax,
    xarray_tree,
)
from graphcast.deep_typed_graph_net import get_activation_manager


# ============================================================
# USER INPUTS (YOU SET THESE)
# ============================================================

data_dir = "/content/era5_daily_nc"        # contains era5_YYYY-MM-DD.nc
# Directory handed to the activation manager as its save_dir.
acts_dir = "/content/graphcast_acts"
os.makedirs(acts_dir, exist_ok=True)

# Center times of the 3-step windows: every 6 h from 2020-01-01T00 up to, but
# not including, 2020-01-02T00 (np.arange excludes the stop value).
# Each center also needs the steps 6 h before and after it, so the T00 center
# requires a 2019-12-31 daily file in data_dir; without one it is reported as
# [MISS] by the main loop.
centers = np.arange(
    np.datetime64("2020-01-01T00"),
    np.datetime64("2020-01-02T00"),
    np.timedelta64(6, "h"),
)


# ============================================================
# ERA5 WINDOWING — build 3-step (t-6h, t, t+6h) windows from daily files
# ============================================================

def _open_and_trim(path: str) -> xr.Dataset:
    """Open a dataset file and keep at most its first four time steps.

    Datasets without a ``time`` dimension, or with four or fewer time steps,
    are returned as opened.
    """
    max_steps = 4
    opened = xr.open_dataset(path)
    if "time" not in opened.dims:
        return opened
    if opened.sizes["time"] <= max_steps:
        return opened
    return opened.isel(time=slice(0, max_steps))


def three_step_window(data_dir: str, center_time: str) -> xr.Dataset | None:
    """
    Build a three-step dataset (t-6h, t, t+6h) around ``center_time``.

    The daily files ``era5_YYYY-MM-DD.nc`` covering the three timestamps are
    read from ``data_dir`` via ``_open_and_trim``, concatenated along ``time``
    and reduced to exactly those three steps.

    Parameters
    ----------
    data_dir : str
        Directory containing the daily ``era5_YYYY-MM-DD.nc`` files.
    center_time : str
        Center timestamp, parsed with ``np.datetime64``.

    Returns
    -------
    xr.Dataset | None
        ``None`` when a required daily file does not exist or when any of the
        three timestamps is absent from the loaded data. Otherwise a dataset in
        which time-dependent variables and coordinates carry a ``batch``
        dimension, ``time`` holds offsets relative to the first step (t-6h),
        and the original timestamps are stored in a ``datetime`` coordinate.
    """
    t0 = np.datetime64(center_time)
    t_minus = t0 - np.timedelta64(6, "h")
    t_plus  = t0 + np.timedelta64(6, "h")

    # Distinct calendar days touched by the three timestamps, in order.
    needed_days = sorted({
        np.datetime64(t_minus, "D"),
        np.datetime64(t0, "D"),
        np.datetime64(t_plus, "D"),
    })

    file_paths = [
        os.path.join(data_dir, f"era5_{str(d)[:10]}.nc")
        for d in needed_days
    ]

    # Give up before opening anything if a required daily file is missing.
    if any(not os.path.exists(p) for p in file_paths):
        return None

    daily = [_open_and_trim(p) for p in file_paths]

    # Time-dependent vs. static variables, as classified in the first file.
    var_time   = [v for v, da in daily[0].data_vars.items() if "time" in da.dims]
    var_static = [v for v, da in daily[0].data_vars.items() if "time" not in da.dims]

    # Time-dependent variables are concatenated across days; static ones are
    # taken from the first file only.
    ds_time = xr.concat([d[var_time] for d in daily], dim="time").sortby("time")
    ds_static = daily[0][var_static]

    ds = xr.merge([ds_time, ds_static])

    # Cast the targets to the dataset's time dtype so the membership test and
    # the selection below compare like with like.
    target_times = np.array([t_minus, t0, t_plus], dtype=ds.time.dtype)
    if not all(t in ds.time.values for t in target_times):
        return None

    ds = ds.sel(time=target_times)

    # Add a "batch" dimension to every time-dependent data variable.
    ds_new = ds.copy()
    for v in ds_new.data_vars:
        if "time" in ds_new[v].dims:
            ds_new[v] = ds_new[v].expand_dims("batch")

    # Same for time-dependent coordinates (the `time` coordinate itself is
    # reassigned just below).
    for c in ds.coords:
        if "time" in ds[c].dims:
            ds_new = ds_new.assign_coords(
                {c: ds[c].expand_dims("batch")}
            )

    # Offsets relative to the first selected step (t-6h).
    time_orig = ds["time"]
    t_ref = time_orig.values[0]
    time_delta = time_orig - t_ref

    # `time` becomes the relative offsets; the absolute timestamps are kept in
    # a `datetime` coordinate that also gets the "batch" dimension.
    ds_new = ds_new.assign_coords(time=time_delta)
    ds_new = ds_new.assign_coords(datetime=("time", time_orig.values))
    ds_new = ds_new.assign_coords(
        {"datetime": ds_new["datetime"].expand_dims("batch")}
    )

    return ds_new


# ============================================================
# LOAD GRAPHCAST CHECKPOINT + NORMALIZATION STATS
# ============================================================

# Anonymous client for the bucket holding the GraphCast params and stats.
gcs = storage.Client.create_anonymous_client()
bucket = gcs.get_bucket("dm_graphcast")
prefix = "graphcast/"

model_source = (
    "GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 "
    "- mesh 2to6 - precipitation input and output.npz"
)


def _read_stats(relative_path):
    """Load one NetCDF statistics file (path relative to `prefix`) into memory."""
    with bucket.blob(prefix + relative_path).open("rb") as stats_file:
        return xr.load_dataset(stats_file).compute()


# Checkpoint: model/task configuration and parameters.
with bucket.blob(f"{prefix}params/{model_source}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

model_config, task_config, params = (
    ckpt.model_config,
    ckpt.task_config,
    ckpt.params,
)
state = {}

# Statistics passed to the input/residual normalization wrapper.
diffs_stddev_by_level = _read_stats("stats/diffs_stddev_by_level.nc")
mean_by_level = _read_stats("stats/mean_by_level.nc")
stddev_by_level = _read_stats("stats/stddev_by_level.nc")


# ============================================================
# GRAPHCAST CONSTRUCTION — wrapped predictor and jitted forward pass
# ============================================================

def construct_wrapped_graphcast(model_config, task_config):
    """Build the GraphCast predictor together with its wrapper stack.

    Wrappers are applied innermost-first: the core GraphCast model, a bfloat16
    cast, input/residual normalization using the module-level statistics, and
    finally the autoregressive predictor with gradient checkpointing enabled.
    """
    core_model = graphcast.GraphCast(model_config, task_config)
    bf16_model = casting.Bfloat16Cast(core_model)
    normalized_model = normalization.InputsAndResiduals(
        bf16_model,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level,
    )
    return autoregressive.Predictor(normalized_model, gradient_checkpointing=True)


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    """Forward pass of the wrapped GraphCast predictor (Haiku-transformed)."""
    model = construct_wrapped_graphcast(model_config, task_config)
    predictions = model(
        inputs,
        targets_template=targets_template,
        forcings=forcings,
    )
    return predictions


def with_configs(fn):
    """Return `fn` with the module-level model and task configs bound as keywords."""
    bound_configs = {"model_config": model_config, "task_config": task_config}
    return functools.partial(fn, **bound_configs)


def with_params(fn):
    """Return `fn` with the module-level `params` and `state` bound as keywords."""
    bound_weights = {"params": params, "state": state}
    return functools.partial(fn, **bound_weights)


def drop_state(fn):
    """Wrap `fn` so that only the first element of its result is returned.

    The returned wrapper accepts keyword arguments only.
    """
    def first_output_only(**kw):
        outputs = fn(**kw)
        return outputs[0]

    return first_output_only


# Compose the callable handed to the rollout, from the inside out: bind the
# configs to the transformed apply function, jit the result, bind params and
# state, then keep only the first element of what the call returns.
_apply_with_configs = with_configs(run_forward.apply)
_jitted_apply = jax.jit(_apply_with_configs)
run_forward_jitted = drop_state(with_params(_jitted_apply))


# ============================================================
# ACTIVATION MANAGER — DISK, SUPPORTED
# ============================================================

# Fetch the activation manager exposed by the GraphCast fork and configure it
# to write activations under `acts_dir`.
# NOTE(review): calling __init__ directly on the object returned by
# get_activation_manager() re-initializes it in place. Presumably it is a
# shared instance that the model code also reads; confirm against the fork and
# prefer a dedicated configuration method if one exists.
am = get_activation_manager()
am.__init__(
    enabled=True,
    save_dir=acts_dir,
    save_steps=[2, 4, 6, 8, 10, 12, 14],
    save_node_sets=["mesh_nodes"],
    mode="post_res",
)


# ============================================================
# MAIN LOOP — one single-step (6h) forecast per center time, saving activations
# ============================================================

t_start = time.time()

for center in centers:
    center_str = np.datetime_as_string(center, unit="h")
    print(f"[TIME] {center_str}")

    # Tell the activation manager which center time is being processed.
    am.set_time(center_str)

    # Use a dedicated name for the window so the notebook-level `ds` (the
    # in-memory ERA5 dataset) is not overwritten; earlier cells that consume
    # `ds` then stay re-runnable after this loop.
    window = three_step_window(data_dir, center_str)
    if window is None:
        # A required daily file or timestamp is unavailable for this center.
        print(f"[MISS] {center_str}")
        continue

    # Single 6h lead time: one target step per window.
    inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(
        window,
        target_lead_times=slice("6h", "6h"),
        **dataclasses.asdict(task_config),
    )

    # The predictions are discarded; the forward pass is run for the
    # activations captured through the activation manager configured above.
    _ = rollout.chunked_prediction(
        run_forward_jitted,
        rng=jax.random.PRNGKey(0),
        inputs=inputs,
        # NaN-filled template: only the structure of `targets` is passed on.
        targets_template=targets * np.nan,
        forcings=forcings,
    )

    print(f"[DONE] {center_str}")

print(f"[ALL DONE] {time.time() - t_start:.1f}s")

[TIME] 2020-01-01T00
[MISS] 2020-01-01T00
[TIME] 2020-01-01T06


  num_target_steps = targets_template.dims["time"]
  scan_length = targets_template.dims['time']
  num_inputs = inputs.dims['time']


[DONE] 2020-01-01T06
[TIME] 2020-01-01T12


  num_target_steps = targets_template.dims["time"]


[DONE] 2020-01-01T12
[TIME] 2020-01-01T18


  num_target_steps = targets_template.dims["time"]


[DONE] 2020-01-01T18
[ALL DONE] 112.2s
