<a href="https://colab.research.google.com/github/wainainapete/Face-detection/blob/main/gencast_demo_cloud_vm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> Copyright 2024 DeepMind Technologies Limited.
>
> Licensed under the Apache License, Version 2.0 (the "License");
> you may not use this file except in compliance with the License.
> You may obtain a copy of the License at
>
>      http://www.apache.org/licenses/LICENSE-2.0
>
> Unless required by applicable law or agreed to in writing, software
> distributed under the License is distributed on an "AS-IS" BASIS,
> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
> See the License for the specific language governing permissions and
> limitations under the License.

# GenCast Demo

This notebook demonstrates running all GenCast models provided in the repository:

1.  `GenCast 0p25deg <2019`
2.  `GenCast 0p25deg Operational <2019`
3.  `GenCast 1p0deg <2019`
4.  `GenCast 1p0deg Mini <2019`

While `GenCast 1p0deg Mini <2019` is runnable with the freely provided TPUv2-8 configuration in Colab, the other models require compute that can be accessed via Google Cloud.

See [cloud_vm_setup.md](https://github.com/google-deepmind/graphcast/blob/main/docs/cloud_vm_setup.md) for detailed instructions on launching a Google Cloud TPU VM and connecting to it via this notebook. This document also provides some more information on the memory requirements of the models.



# Installation and Initialization

In [None]:
# @title Pip install repo and dependencies

# IPython magic: installs (or upgrades to) the graphcast package built from
# the master-branch archive on GitHub into the current kernel's environment.
%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip

In [None]:
# @title Imports

import dataclasses
import datetime
import math
from typing import Optional
import haiku as hk
from IPython.display import HTML
from IPython import display
import ipywidgets as widgets
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray

from graphcast import rollout
from graphcast import xarray_jax
from graphcast import normalization
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import xarray_tree
from graphcast import gencast
from graphcast import denoiser
from graphcast import nan_cleaning



In [None]:
# @title Plotting functions

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.Dataset:
  """Picks one variable out of `data` and narrows it down for plotting.

  Args:
    data: Dataset that is indexed by `variable`.
    variable: Name of the variable to extract.
    level: If given and the variable has a "level" coordinate, the level value
      to select.
    max_steps: If given and smaller than the size of the "time" dimension,
      only the first `max_steps` time steps are kept.

  Returns:
    The extracted variable, with the first batch element taken when a "batch"
    dimension exists and the optional time/level selection applied.
  """
  selected = data[variable]

  # Plots show a single example, so drop the batch dimension if there is one.
  if "batch" in selected.dims:
    selected = selected.isel(batch=0)

  if max_steps is not None:
    num_times = selected.sizes.get("time")
    if num_times is not None and max_steps < num_times:
      selected = selected.isel(time=range(max_steps))

  # Surface variables carry no "level" coordinate; leave them untouched.
  if level is None or "level" not in selected.coords:
    return selected
  return selected.sel(level=level)

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
  """Pairs `data` with a colour normalisation and colormap name for plotting.

  Args:
    data: Values to be plotted; NaNs are ignored when computing the range.
    center: If given, the colour range is made symmetric around this value and
      a diverging colormap is used.
    robust: If True, the range spans the 2nd to 98th percentile instead of the
      full min/max.

  Returns:
    A `(data, norm, cmap_name)` tuple, where `data` is returned unchanged.
  """
  low_pct, high_pct = (2, 98) if robust else (0, 100)
  vmin = np.nanpercentile(data, low_pct)
  vmax = np.nanpercentile(data, high_pct)

  if center is None:
    return data, matplotlib.colors.Normalize(vmin, vmax), "viridis"

  # Widen the range so that `center` sits exactly in the middle of it.
  half_range = max(vmax - center, center - vmin)
  norm = matplotlib.colors.Normalize(center - half_range, center + half_range)
  return data, norm, "RdBu_r"

def plot_data(
    data: dict[str, tuple[xarray.DataArray, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> HTML:
  """Renders the given fields side by side as an animation over time.

  Args:
    data: Maps each panel title to a `(data, norm, cmap)` tuple, as produced
      by `scale`. All entries must have the same number of time steps.
    fig_title: Title shown above the panels. When the data has a "time"
      dimension, the current frame's time value is appended to it.
    plot_size: Height of one panel row; each panel is twice as wide.
    robust: If True, colorbars are drawn with extension arrows at both ends.
    cols: Maximum number of panels per row.

  Returns:
    An `IPython.display.HTML` object wrapping the JavaScript animation.

  Raises:
    ValueError: If `data` is empty.
  """
  if not data:
    raise ValueError("`data` must contain at least one panel to plot.")

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  # `panel` (rather than `plot_data`) avoids shadowing this function's name.
  for i, (title, (panel, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    im = ax.imshow(
        panel.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    """Redraws every panel and the title for animation frame `frame`."""
    if "time" in first_data.dims:
      # Assumes the raw time value is in nanoseconds; timedelta takes
      # microseconds, hence the division by 1000.
      td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (panel, _, _) in zip(images, data.values()):
      im.set_data(panel.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  # Close the figure so that only the animation is displayed, not a static
  # copy of the first frame.
  plt.close(figure.number)
  return HTML(ani.to_jshtml())


# Load the Data and initialize the model

In [None]:
# @title Set paths

# Fill these in before running the cells below.
# MODEL_PATH: checkpoint file opened by the "Load the model" cell.
# DATA_PATH: dataset file; its name is parsed for `source` and `res` fields to
#   check that it matches the model.
# STATS_DIR: location of the `*_by_level.nc` normalization files. Include the
#   trailing slash, as in the example.
MODEL_PATH = ""  # E.g. "GenCast 1p0deg _2019.npz"
DATA_PATH = ""  # E.g. "source-era5_date-2019-03-29_res-1.0_levels-13_steps-04.nc"
STATS_DIR = ""  # E.g. "stats/"

In [None]:
# @title Load the model

# Read the checkpoint at MODEL_PATH as a `gencast.CheckPoint`.
with open(MODEL_PATH, "rb") as f:
  ckpt = checkpoint.load(f, gencast.CheckPoint)
params = ckpt.params
# Start from an empty state; it is passed to the `apply` calls together with
# `params` further down.
state = {}

# Configurations stored in the checkpoint; they are used below to extract the
# data and to construct the model.
task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config
denoiser_architecture_config = ckpt.denoiser_architecture_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")

## Load the example data

Example ERA5 datasets are available at 0.25 degree and 1 degree resolution.

Example HRES-fc0 datasets are available at 0.25 degree resolution.

Some transformations were done from the base datasets:
- We accumulated precipitation over 12 hours instead of the default 1 hour.
- For HRES-fc0 sea surface temperature, we assigned NaNs to grid cells in which sea surface temperature was NaN in the ERA5 dataset (this remains fixed at all times).

The data resolution must match the model that is loaded.



In [None]:
# @title Check example dataset matches model

import os


def parse_file_parts(file_name):
  """Parses the "_"-separated `key-value` fields of a dataset file name.

  Directory components are dropped first, so both a bare file name and a full
  path are accepted.

  Args:
    file_name: File name or path, e.g. "dir/source-era5_res-1.0".

  Returns:
    A dict mapping each key to its value, e.g.
    {"source": "era5", "res": "1.0"}.
  """
  base_name = os.path.basename(file_name)
  return dict(part.split("-", 1) for part in base_name.split("_"))


def data_valid_for_model(file_name: str, params_file_name: str):
  """Checks that the dataset's source and resolution match the model.

  Only the file names are inspected (never the directories), and the model
  name is matched case-insensitively.

  Args:
    file_name: Name or path of the dataset file; must contain `res` and
      `source` fields.
    params_file_name: Name or path of the model checkpoint file.

  Returns:
    True if the dataset resolution (with "." written as "p") appears in the
    model name, and the dataset is ERA5 exactly when the model is not an
    "Operational" one.
  """
  data_file_parts = parse_file_parts(file_name.removesuffix(".nc"))
  model_name = os.path.basename(params_file_name).lower()
  res_matches = data_file_parts["res"].replace(".", "p") in model_name
  is_operational = "operational" in model_name
  is_era5 = data_file_parts["source"] == "era5"
  # ERA5 data goes with non-operational models, any other source with
  # operational ones.
  return res_matches and (is_operational != is_era5)

# Stop here with an explanation rather than failing later inside the model.
assert data_valid_for_model(DATA_PATH, MODEL_PATH), (
    f"Dataset {DATA_PATH!r} does not match the data source and/or resolution "
    f"expected by model {MODEL_PATH!r}.")


In [None]:
# @title Load weather data

with open(DATA_PATH, "rb") as f:
  example_batch = xarray.load_dataset(f).compute()

# `Dataset.sizes` is the dimension-name -> length mapping; looking lengths up
# through `Dataset.dims` is deprecated in xarray.
assert example_batch.sizes["time"] >= 3  # 2 for input, >=1 for targets

# Echo the fields encoded in the dataset's file name.
print(", ".join([f"{k}: {v}" for k, v in parse_file_parts(DATA_PATH.removesuffix(".nc")).items()]))

example_batch

In [None]:
# @title Plot example data

plot_size = 7
variable = "geopotential"
level = 500
# Use `sizes` for dimension lengths; mapping access on `Dataset.dims` is
# deprecated in xarray.
steps = example_batch.sizes["time"]


data = {
    " ": scale(select(example_batch, variable, level, steps), robust=True),
}
fig_title = variable
# Only mention the level for variables that actually have one.
if "level" in example_batch[variable].coords:
  fig_title += f" at {level} hPa"

plot_data(data, fig_title, plot_size, robust=True)


In [None]:
# @title Extract training and eval data

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("12h", "12h"), # Only 1AR training.
    **dataclasses.asdict(task_config))

# Dimension lengths are read from `Dataset.sizes`; mapping access on
# `Dataset.dims` is deprecated in xarray.
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("12h", f"{(example_batch.sizes['time']-2)*12}h"), # All but 2 input frames.
    **dataclasses.asdict(task_config))

print("All Examples:  ", dict(example_batch.sizes))
print("Train Inputs:  ", dict(train_inputs.sizes))
print("Train Targets: ", dict(train_targets.sizes))
print("Train Forcings:", dict(train_forcings.sizes))
print("Eval Inputs:   ", dict(eval_inputs.sizes))
print("Eval Targets:  ", dict(eval_targets.sizes))
print("Eval Forcings: ", dict(eval_forcings.sizes))


In [None]:
# @title Load normalization data

import os


def _load_stats(file_name):
  """Loads the statistics file `file_name` from STATS_DIR into memory."""
  # os.path.join supplies the separator when STATS_DIR has no trailing slash,
  # and leaves the bare file name when STATS_DIR is empty.
  with open(os.path.join(STATS_DIR, file_name), "rb") as f:
    return xarray.load_dataset(f).compute()


diffs_stddev_by_level = _load_stats("diffs_stddev_by_level.nc")
mean_by_level = _load_stats("mean_by_level.nc")
stddev_by_level = _load_stats("stddev_by_level.nc")
min_by_level = _load_stats("min_by_level.nc")

In [None]:
# @title Build jitted functions, and possibly initialize random weights


def construct_wrapped_gencast():
  """Builds GenCast and wraps it with normalization and NaN cleaning.

  Returns:
    A `nan_cleaning.NaNCleaner` around a `normalization.InputsAndResiduals`
    around the `gencast.GenCast` model, configured from the notebook-level
    configs and statistics.
  """
  # Innermost layer: the model itself, configured from the checkpoint.
  core_model = gencast.GenCast(
      sampler_config=sampler_config,
      task_config=task_config,
      denoiser_architecture_config=denoiser_architecture_config,
      noise_config=noise_config,
      noise_encoder_config=noise_encoder_config,
  )

  # Middle layer: normalization using the per-level statistics.
  normalized_model = normalization.InputsAndResiduals(
      core_model,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level,
  )

  # Outermost layer: NaN handling for sea surface temperature.
  return nan_cleaning.NaNCleaner(
      predictor=normalized_model,
      reintroduce_nans=True,
      fill_value=min_by_level,
      var_to_clean='sea_surface_temperature',
  )


@hk.transform_with_state
def run_forward(inputs, targets_template, forcings):
  """Calls a freshly constructed wrapped GenCast predictor on the inputs."""
  return construct_wrapped_gencast()(
      inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(inputs, targets, forcings):
  """Computes the predictor's loss and diagnostics, each reduced to its mean.

  Returns:
    A `(loss, diagnostics)` tuple in which every entry has been averaged and
    unwrapped to its underlying JAX data.
  """

  def _mean_as_jax(x):
    return xarray_jax.unwrap_data(x.mean(), require_jax=True)

  loss_value, diag = construct_wrapped_gencast().loss(inputs, targets, forcings)
  return xarray_tree.map_structure(_mean_as_jax, (loss_value, diag))


def grads_fn(params, state, inputs, targets, forcings):
  """Evaluates the loss and its gradients with respect to `params`.

  A fixed `PRNGKey(0)` is used for the loss evaluation.

  Returns:
    A `(loss, diagnostics, next_state, grads)` tuple.
  """

  def _loss_with_aux(p, s, i, t, f):
    # The scalar loss goes first so it is the value being differentiated;
    # everything else rides along as auxiliary output.
    (loss_value, diag), new_state = loss_fn.apply(
        p, s, jax.random.PRNGKey(0), i, t, f)
    return loss_value, (diag, new_state)

  value_and_grad_fn = jax.value_and_grad(_loss_with_aux, has_aux=True)
  (loss, (diagnostics, next_state)), grads = value_and_grad_fn(
      params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads


# Only initialize fresh weights when the loaded checkpoint provided none.
if params is None:
  init_jitted = jax.jit(loss_fn.init)
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets=train_targets,
      forcings=train_forcings,
  )


# The two lambdas below close over the notebook's `params` and `state`, so
# callers only pass the RNG key and the data. The trailing `[0]` keeps the
# function output and drops the state that `apply` returns alongside it.
loss_fn_jitted = jax.jit(
    lambda rng, i, t, f: loss_fn.apply(params, state, rng, i, t, f)[0]
)
# `grads_fn` takes `params` and `state` explicitly, so it is jitted as-is.
grads_fn_jitted = jax.jit(grads_fn)
run_forward_jitted = jax.jit(
    lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]
)
# We also produce a pmapped version for running in parallel.
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")

# Run the model

The `chunked_prediction_generator_multiple_runs` generator iterates over forecast steps; each 1-step forecast is jitted, and samples are pmapped across the chips.
This allows us to make efficient use of all devices and parallelise generating an ensemble across them. We then combine the chunks at the end to form our final forecast.

Note that the `Autoregressive rollout` cell will take longer than the standard inference time when executed for the first time, as this includes code compilation time. This cost does not increase with the number of devices: it is a fixed, one-time operation whose result can be reused across any number of devices.

In [None]:
# The number of ensemble members should be a multiple of the number of devices.
# The rollout cell below passes `jax.local_devices()` as its pmap devices, so
# this is the device count that matters.
print(f"Number of local devices {len(jax.local_devices())}")

In [None]:
# @title Autoregressive rollout (loop in python)

# Dimension lengths come from `Dataset.sizes`; mapping access on
# `Dataset.dims` is deprecated in xarray.
print("Inputs:  ", dict(eval_inputs.sizes))
print("Targets: ", dict(eval_targets.sizes))
print("Forcings:", dict(eval_forcings.sizes))

num_ensemble_members = 8 # @param int
rng = jax.random.PRNGKey(0)
# We fold in the ensemble member index, so that the first N members always
# match across runs that take the same inputs, regardless of the total
# ensemble size.
rngs = np.stack(
    [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)

# Collect one chunk of predictions per forecast step, then merge them along
# their coordinates into a single dataset.
chunks = list(rollout.chunked_prediction_generator_multiple_runs(
    # Use pmapped version to parallelise across devices.
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    # The template only provides the target structure, so blank out its values.
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk = 1,
    num_samples = num_ensemble_members,
    pmap_devices=jax.local_devices()
    ))
predictions = xarray.combine_by_coords(chunks)

In [None]:
# @title Plot prediction samples and diffs

plot_size = 5
variable = "2m_temperature"
level = None
# Use `sizes` for dimension lengths; mapping access on `Dataset.dims` is
# deprecated in xarray.
steps = predictions.sizes["time"]

fig_title = variable
if "level" in predictions[variable].coords:
  fig_title += f" at {level} hPa"

# The target field is the same for every sample, so select it once.
target_data = select(eval_targets, variable, level, steps)

for sample_idx in range(num_ensemble_members):
  sample_data = select(
      predictions.isel(sample=sample_idx), variable, level, steps)
  data = {
      "Targets": scale(target_data, robust=True),
      "Predictions": scale(sample_data, robust=True),
      # Centre the difference on 0 so positive and negative errors are
      # coloured symmetrically.
      "Diff": scale(target_data - sample_data, robust=True, center=0),
  }
  display.display(plot_data(data, fig_title + f", Sample {sample_idx}", plot_size, robust=True))


In [None]:
# @title Plot ensemble mean and CRPS

def crps(targets, predictions, bias_corrected = True):
  """Computes an ensemble CRPS estimate along the "sample" dimension.

  The result is the mean absolute error of the members against `targets`
  minus half the mean absolute difference between pairs of members.

  Args:
    targets: Values to score against.
    predictions: Ensemble with a "sample" dimension of size at least 2.
    bias_corrected: If True, the pairwise term is divided by `n * (n - 1)`
      rather than `n * n`, where `n` is the ensemble size.

  Returns:
    The CRPS estimate, with the "sample" dimension reduced away.

  Raises:
    ValueError: If `predictions` has fewer than 2 samples.
  """
  if predictions.sizes.get("sample", 1) < 2:
    raise ValueError(
        "predictions must have dim 'sample' with size at least 2.")

  ensemble_size = predictions.sizes["sample"]
  pair_divisor = ensemble_size - 1 if bias_corrected else ensemble_size

  # Renaming the sample dimension on a copy makes the subtraction broadcast
  # over every pair of members.
  other_members = predictions.rename({"sample": "sample2"})
  pairwise_sum = np.abs(predictions - other_members).sum(
      dim=["sample", "sample2"], skipna=False)
  spread = pairwise_sum / (ensemble_size * pair_divisor)

  error_sum = np.abs(targets - predictions).sum(dim="sample", skipna=False)
  skill = error_sum / ensemble_size

  return skill - 0.5 * spread


plot_size = 5
variable = "2m_temperature"
level = None
# Use `sizes` for dimension lengths; mapping access on `Dataset.dims` is
# deprecated in xarray.
steps = predictions.sizes["time"]

fig_title = variable
if "level" in predictions[variable].coords:
  fig_title += f" at {level} hPa"

data = {
    "Targets": scale(select(eval_targets, variable, level, steps), robust=True),
    "Ensemble Mean": scale(select(predictions.mean(dim=["sample"]), variable, level, steps), robust=True),
    "Ensemble CRPS": scale(crps((select(eval_targets, variable, level, steps)),
                        select(predictions, variable, level, steps)),
                      robust=True, center=0),
}
display.display(plot_data(data, fig_title, plot_size, robust=True))

In [None]:
# @title (Optional) Save the predictions.
# Writes the combined ensemble forecast to a Zarr store named
# "predictions.zarr" (a relative path).
# NOTE(review): re-running this cell while the store already exists may fail,
# depending on `to_zarr`'s default write mode — confirm.
predictions.to_zarr("predictions.zarr")

# Train the model

The following operations require larger amounts of memory than running inference.

The first time executing the cell takes more time, as it includes the time to jit the function.

In [None]:
# @title Loss computation
# `loss_fn_jitted` closes over the notebook's `params` and `state`, so only
# the RNG key and the training data are passed here.
loss, diagnostics = loss_fn_jitted(
    jax.random.PRNGKey(0),
    train_inputs,
    train_targets,
    train_forcings)
print("Loss:", float(loss))

In [None]:
# @title Gradient computation
loss, diagnostics, next_state, grads = grads_fn_jitted(
    params=params, state=state,
    inputs=train_inputs, targets=train_targets, forcings=train_forcings)
# Summarise the gradients as the mean, over all parameter arrays, of each
# array's mean absolute value.
mean_grad = np.mean(
    jax.tree_util.tree_leaves(
        jax.tree_util.tree_map(lambda g: np.abs(g).mean(), grads)))
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")