In [23]:
# @title Imports

import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
#from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray




def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_"))


In [24]:
# @title Choose the model
# Rewrite by S.F. Sune, https://github.com/sfsun67.
'''
    We have three options. Acquiring from https://console.cloud.google.com/storage/browser/dm_graphcast:
    GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz
    GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz
    GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
'''
# find the result in this path /root/data/params, and list of names of all files in the "params/", with the "params/" perfix removed from the names.
import os
import glob

# Define the directory path
dir_path_params = "/root/data/params"

# Use glob to get all file paths in the directory
file_paths_params = glob.glob(os.path.join(dir_path_params, "*"))

# Remove the directory path and the ".../params/" prefix from each file name
params_file_options = [os.path.basename(path) for path in file_paths_params]


random_mesh_size = widgets.IntSlider(
    value=4, min=4, max=6, description="Mesh size:")
random_gnn_msg_steps = widgets.IntSlider(
    value=4, min=1, max=32, description="GNN message steps:")
random_latent_size = widgets.Dropdown(
    options=[int(2**i) for i in range(4, 10)], value=32,description="Latent size:")
random_levels = widgets.Dropdown(
    options=[13, 37], value=13, description="Pressure levels:")


params_file = widgets.Dropdown(
    options=params_file_options,
    description="Params file:",
    layout={"width": "max-content"})

source_tab = widgets.Tab([
    widgets.VBox([
        random_mesh_size,
        random_gnn_msg_steps,
        random_latent_size,
        random_levels,
    ]),
    params_file,
])
source_tab.set_title(0, "Random")
source_tab.set_title(1, "Checkpoint")
widgets.VBox([
    source_tab,
    widgets.Label(value="Run the next cell to load the model. Rerunning this cell clears your selection.")
])


VBox(children=(Tab(children=(VBox(children=(IntSlider(value=4, description='Mesh size:', max=6, min=4), IntSli…

In [25]:
# @title Load the model

source = source_tab.get_title(source_tab.selected_index)

if source == "Random":
  params = None  # Filled in below
  state = {}
  model_config = graphcast.ModelConfig(
      resolution=0,
      mesh_size=random_mesh_size.value,
      latent_size=random_latent_size.value,
      gnn_msg_steps=random_gnn_msg_steps.value,
      hidden_layers=1,
      radius_query_fraction_edge_length=0.6)
  task_config = graphcast.TaskConfig(
      input_variables=graphcast.TASK.input_variables,
      target_variables=graphcast.TASK.target_variables,
      forcing_variables=graphcast.TASK.forcing_variables,
      pressure_levels=graphcast.PRESSURE_LEVELS[random_levels.value],
      input_duration=graphcast.TASK.input_duration,
  )
else:
  assert source == "Checkpoint"
  '''with gcs_bucket.blob(f"params/{params_file.value}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)'''
  
  with open(f"{dir_path_params}/{params_file.value}", "rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)
  
  params = ckpt.params
  state = {}

  model_config = ckpt.model_config
  task_config = ckpt.task_config
  print("Model description:\n", ckpt.description, "\n")
  print("Model license:\n", ckpt.license, "\n")

model_config

ModelConfig(resolution=0, mesh_size=4, latent_size=32, gnn_msg_steps=4, hidden_layers=1, radius_query_fraction_edge_length=0.6, mesh2grid_edge_normalization_factor=None)

In [26]:
# @title Get and filter the list of available example datasets
# Rewrite by S.F. Sune, https://github.com/sfsun67.
# find the result in this path /root/data/dataset, and list of names of all files in the "dataset/", with the "dataset/" perfix removed from the names.

# Define the directory path
dir_path_dataset = "/root/data/dataset"

# Use glob to get all file paths in the directory
file_paths_dataset = glob.glob(os.path.join(dir_path_dataset, "*"))

# Remove the directory path and the ".../params/" prefix from each file name
dataset_file_options = [os.path.basename(path) for path in file_paths_dataset]
#print("dataset_file_options: ", dataset_file_options)

# Remove "dataset-" prefix from each file name
dataset_file_options = [name.removeprefix("dataset-") for name in dataset_file_options]


def data_valid_for_model(
    file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):
  file_parts = parse_file_parts(file_name.removesuffix(".nc"))
  #print("file_parts: ", file_parts)
  return (
      model_config.resolution in (0, float(file_parts["res"])) and
      len(task_config.pressure_levels) == int(file_parts["levels"]) and
      (
          ("total_precipitation_6hr" in task_config.input_variables and
           file_parts["source"] in ("era5", "fake")) or
          ("total_precipitation_6hr" not in task_config.input_variables and
           file_parts["source"] in ("hres", "fake"))
      )
  )


dataset_file = widgets.Dropdown(
    options=[
        (", ".join([f"{k}: {v}" for k, v in parse_file_parts(option.removesuffix(".nc")).items()]), option)
        for option in dataset_file_options
        if data_valid_for_model(option, model_config, task_config)
    ],
    description="Dataset file:",
    layout={"width": "max-content"})
widgets.VBox([
    dataset_file,
    widgets.Label(value="Run the next cell to load the dataset. Rerunning this cell clears your selection and refilters the datasets that match your model.")
])

VBox(children=(Dropdown(description='Dataset file:', layout=Layout(width='max-content'), options=(('source: er…

In [29]:
# @title Load weather data

if not data_valid_for_model(dataset_file.value, model_config, task_config):
  raise ValueError(
      "Invalid dataset file, rerun the cell above and choose a valid dataset file.")

'''with gcs_bucket.blob(f"dataset/{dataset_file.value}").open("rb") as f:
  example_batch = xarray.load_dataset(f).compute()'''

with open(f"{dir_path_dataset}/dataset-{dataset_file.value}", "rb") as f:
  example_batch = xarray.load_dataset(f).compute()

assert example_batch.dims["time"] >= 3  # 2 for input, >=1 for targets

print(", ".join([f"{k}: {v}" for k, v in parse_file_parts(dataset_file.value.removesuffix(".nc")).items()]))

example_batch

source: era5, date: 2022-01-01, res: 1.0, levels: 13, steps: 01


In [30]:
print(example_batch)


<xarray.Dataset>
Dimensions:                       (lon: 360, lat: 181, level: 13, time: 3,
                                   batch: 1)
Coordinates:
  * lon                           (lon) float32 0.0 1.0 2.0 ... 358.0 359.0
  * lat                           (lat) float32 -90.0 -89.0 -88.0 ... 89.0 90.0
  * level                         (level) int32 50 100 150 200 ... 850 925 1000
  * time                          (time) timedelta64[ns] 00:00:00 ... 12:00:00
    datetime                      (batch, time) datetime64[ns] 2022-01-01 ......
Dimensions without coordinates: batch
Data variables: (12/14)
    geopotential_at_surface       (lat, lon) float32 2.735e+04 ... -0.07617
    land_sea_mask                 (lat, lon) float32 1.0 1.0 1.0 ... 0.0 0.0 0.0
    2m_temperature                (batch, time, lat, lon) float32 250.7 ... 2...
    mean_sea_level_pressure       (batch, time, lat, lon) float32 9.931e+04 ....
    10m_v_component_of_wind       (batch, time, lat, lon) float32 -0.4393

In [11]:
# load example_batch from .csv file
import pandas as pd
import xarray as xr

# Load the data into a DataFrame
df = pd.read_csv("/root/data/SGP_test.csv", encoding='latin1')

# Convert the DataFrame to an xarray Dataset
ds = xr.Dataset.from_dataframe(df)

# 当 interpreted age 为 nan 时，删去该行
ds = ds.dropna(dim='index', subset=['interpreted age'])

In [12]:
# Function to process a part of the dataset

sedimentary_list = []
def groupby_and_average(ds):
    '''
    # 使用 groupby 方法根据 lon、lat 和 time 三个变量对数据集进行分组, 并对分组后的数据集求平均
    '''
    for site_longitude_value, site_longitude in ds.groupby("site longitude"):
        for site_latitude_value, site_latitude in site_longitude.groupby("site latitude"):
            for interpreted_age_value, sedimentary in site_latitude.groupby("interpreted age"):
                #sedimentary_dict = sedimentary.apply(np.mean).to_dict() 
                sedimentary_list.append(sedimentary.apply(np.mean))
    
    # Add an identifying dimension to each xr.Dataset of sedimentary_list 
    for i, ds in enumerate(sedimentary_list):
        ds = ds.expand_dims({'sample': [i]})

    # Concatenate the datasets
    combined = xr.concat(sedimentary_list, dim='index')


    return combined, site_longitude_value, site_latitude_value, interpreted_age_value

In [13]:
import multiprocessing as mp
from tqdm import tqdm


# Divide the dataset into parts
part_number = 6
dim = 'index'  # replace with your actual dimension
dim_size = ds.dims[dim]
indices = np.linspace(0, dim_size, part_number+1).astype(int)
parts = [ds.isel({dim: slice(indices[i], indices[i + 1])}) for i in range(part_number)]

# Create a multiprocessing Pool
pool = mp.Pool(mp.cpu_count())

# Process each part of the dataset in parallel with a progress bar
print('Processing SGP datasets, replacing duplicates with averages ...')
results = []
with tqdm(total=len(parts)) as pbar:
    for result in pool.imap_unordered(groupby_and_average, parts):
        results.append(result)
        pbar.update(1)

# Close the pool
pool.close()

# To combine multiple xarray.Dataset objects
result_list = [result[0] for result in results]
combined = xr.concat(result_list, dim='index')
combined

Processing SGP datasets, replacing duplicates with averages ...


100%|██████████| 6/6 [00:06<00:00,  1.01s/it]


In [20]:
# Create the new xr.Dataset
# When copy, notice that deep copy and shallow copy.

# define the resolution of longitiude and latitude
resolution = 1
batch = 1

# Create the dimensions
dims = {
    "lon": int(360/resolution),
    "lat": int(181/resolution),
    "level": 13,
    "time": len(combined["index"]),
}

# Create the coordinates
coords = {
    "lon": np.linspace(0, 359, int(dims["lon"] - (1/resolution - 1))),
    "lat": np.linspace(-90, 90, int(dims["lat"] - (1/resolution - 1))),
    "level": np.arange(50, 1050, 75),
    "time": pd.date_range("2000-01-01", periods=dims["time"]),
    "datetime": pd.date_range("2000-01-01", periods=dims["time"]),
}

# Create the new dataset
SGP_dataset = xr.Dataset(coords=coords)

In [21]:
print("SGP_dataset=", SGP_dataset)

SGP_dataset= <xarray.Dataset>
Dimensions:   (lon: 360, lat: 181, level: 14, time: 664, datetime: 664)
Coordinates:
  * lon       (lon) float64 0.0 1.0 2.0 3.0 4.0 ... 356.0 357.0 358.0 359.0
  * lat       (lat) float64 -90.0 -89.0 -88.0 -87.0 ... 87.0 88.0 89.0 90.0
  * level     (level) int64 50 125 200 275 350 425 ... 650 725 800 875 950 1025
  * time      (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2001-10-25
  * datetime  (datetime) datetime64[ns] 2000-01-01 2000-01-02 ... 2001-10-25
Data variables:
    *empty*


In [11]:
# Rewrite the lon and lat according the resulation of dataset.

def Rewrite_lon_lat(data, resolution):
    '''
    根据 xarray 数据集中的分辨率 重写 lon 和 lat 
    Rewrite the lon and lat according the resulation of dataset.
    data: the original data
    resolution: the resolution of the data
    '''
    condition_number = int(1/resolution)
    data["site latitude"].data = np.round(data["site latitude"].data * condition_number) / condition_number
    data["site longitude"].data = np.round(data["site longitude"].data * condition_number) / condition_number

    return data


In [12]:
# Rewrite the lon and lat according the dataset of xarray.
combined = Rewrite_lon_lat(combined, resolution)
combined


In [13]:
print("combined=", combined)

combined= <xarray.Dataset>
Dimensions:                 (index: 664)
Dimensions without coordinates: index
Data variables: (12/103)
    site latitude           (index) float64 -24.0 -25.0 -25.0 ... -30.0 -27.0
    site longitude          (index) float64 16.0 16.0 16.0 ... 139.0 142.0 153.0
    interpreted age         (index) float64 549.0 548.2 548.0 ... 105.0 435.0
    Ag (ppm)                (index) float64 nan nan nan nan ... nan nan nan nan
    Al (wt%)                (index) float64 nan nan nan nan ... nan nan 8.646
    As (ppm)                (index) float64 nan nan nan nan ... nan nan nan nan
    ...                      ...
    C:N (atomic)            (index) float64 nan nan nan nan ... nan nan nan nan
    Delta13C-org (permil)   (index) float64 nan nan nan nan ... nan nan nan nan
    Delta15N (permil)       (index) float64 nan nan nan nan ... nan nan nan nan
    Delta98Mo (permil)      (index) float64 nan nan nan nan ... nan nan nan nan
    Delta34S-py (permil)    (index) float

In [14]:
# Create a simple xarray Dataset
data = xr.Dataset(
    {
        "Ag (ppm)": (("time", "lat", "lon"), np.random.rand(3, 2, 2)),
    },
    coords={
        "lon": [10, 20],
        "lat": [30, 40],
        "time": pd.date_range("2000-01-01", periods=3),
    },
)

# Manually set a value to 3.234 for demonstration
data["Ag (ppm)"][0, 0, 0] = 3.234

print(data)

<xarray.Dataset>
Dimensions:   (time: 3, lat: 2, lon: 2)
Coordinates:
  * lon       (lon) int64 10 20
  * lat       (lat) int64 30 40
  * time      (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03
Data variables:
    Ag (ppm)  (time, lat, lon) float64 3.234 0.8357 0.6475 ... 0.8129 0.07106


In [14]:
# Adjust the time.



# Get the dimensions from SGP_dataset
dims = SGP_dataset.dims

# Add the variables from the combined dataset to the new dataset
for var in combined.data_vars:
    # Skip the variables that are Coordinates.
    if var == "site latitude" or var == "site longitude" or var == "interpreted age":
        continue
    # Create a new DataArray with the same data but new dimensions
    new_dataarray = xr.DataArray(
        data=np.broadcast_to(combined[var].values, (dims["batch"], dims["time"], dims["lat"], dims["lon"], dims["level"], dims["index"])),
        dims=["batch", "time", "lat", "lon", "level", "index"],
        coords={"batch": SGP_dataset["batch"], "time": SGP_dataset["time"], "lat": SGP_dataset["lat"], "lon": SGP_dataset["lon"], "level": SGP_dataset["level"], "index": SGP_dataset["index"]}
    )
    # Add the new DataArray to the new dataset
    SGP_dataset[var] = new_dataarray

SGP_dataset["time"] = combined['interpreted age']

In [19]:
# 这里的标签"time"是错误的。需要重新做索引。然后看看能不能遍历数组，给 Origen 赋值。

SGP_dataset["Al (wt%)"][0,663,-46,168,0]

In [15]:
SGP_dataset

In [None]:
example_batch["datetime"] = combined['interpreted age']
example_batch["longitude"] = combined['site longitude']
example_batch["latitude"] = combined['site latitude']


example_batch["mean_sea_level_pressure"] = combined['Ca (wt%)']
example_batch["10m_v_component_of_wind"] = combined['Ce (ppm)']

example_batch

In [20]:
# @title Choose training and eval data to extract
train_steps = widgets.IntSlider(
    value=1, min=1, max=example_batch.sizes["time"]-2, description="Train steps")
eval_steps = widgets.IntSlider(
    value=example_batch.sizes["time"]-2, min=1, max=example_batch.sizes["time"]-2, description="Eval steps")

widgets.VBox([
    train_steps,
    eval_steps,
    widgets.Label(value="Run the next cell to extract the data. Rerunning this cell clears your selection.")
])

VBox(children=(IntSlider(value=1, description='Train steps', max=1, min=1), IntSlider(value=1, description='Ev…

In [21]:
# @title Extract training and eval data

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{train_steps.value*6}h"),
    **dataclasses.asdict(task_config))

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{eval_steps.value*6}h"),
    **dataclasses.asdict(task_config))

print("All Examples:  ", example_batch.dims.mapping)
print("Train Inputs:  ", train_inputs.dims.mapping)
print("Train Targets: ", train_targets.dims.mapping)
print("Train Forcings:", train_forcings.dims.mapping)
print("Eval Inputs:   ", eval_inputs.dims.mapping)
print("Eval Targets:  ", eval_targets.dims.mapping)
print("Eval Forcings: ", eval_forcings.dims.mapping)


All Examples:   {'lon': 360, 'lat': 181, 'level': 13, 'time': 3, 'batch': 1}
Train Inputs:   {'batch': 1, 'time': 2, 'lat': 181, 'lon': 360, 'level': 13}
Train Targets:  {'batch': 1, 'time': 1, 'lat': 181, 'lon': 360, 'level': 13}
Train Forcings: {'batch': 1, 'time': 1, 'lat': 181, 'lon': 360}
Eval Inputs:    {'batch': 1, 'time': 2, 'lat': 181, 'lon': 360, 'level': 13}
Eval Targets:   {'batch': 1, 'time': 1, 'lat': 181, 'lon': 360, 'level': 13}
Eval Forcings:  {'batch': 1, 'time': 1, 'lat': 181, 'lon': 360}


In [22]:
# @title Load normalization data
# Rewrite by S.F. Sune, https://github.com/sfsun67.
dir_path_stats = "/root/data/stats"

with open(f"{dir_path_stats}/stats-diffs_stddev_by_level.nc", "rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with open(f"{dir_path_stats}/stats-mean_by_level.nc", "rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with open(f"{dir_path_stats}/stats-stddev_by_level.nc", "rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()

In [23]:
# @title Build jitted functions, and possibly initialize random weights
# Construct the model and initialize the weights.
# 构建模型并初始化权重

# 模型组网
# Construct the model
def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor

# 前向运算
# forward
@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)

# 计算损失函数
# loss function
@hk.transform_with_state    # used to convert a pure function into a stateful function
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)    # constructs and wraps a GraphCast Predictor, which is a model used for making predictions in a graph-based machine learning task.
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))

# 计算梯度
# gradient
def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)
  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below are simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is requiredy by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

init_jitted = jax.jit(with_configs(run_forward.init))

if params is None:
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets_template=train_targets,
      forcings=train_forcings)

loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))



# 测试：1. 对数据集进行迭代    2. 权重保存与加载

In [None]:
# 对数据集进行迭代   模型为原始的。数据为步长40的数据集, train step = 2
for i in range(39):
    example_batch_slice = example_batch.isel(time=slice(i, 4+i))
    train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
        example_batch_slice, target_lead_times=slice("6h", f"{train_steps.value*6}h"),
        **dataclasses.asdict(task_config))
    # @title Gradient computation (backprop through time)
    loss, diagnostics, next_state, grads = grads_fn_jitted(
        inputs=train_inputs,
        targets=train_targets,
        forcings=train_forcings)
    mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
    print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")



In [None]:
# 对数据集进行迭代  0-20
for i in range(20):
    example_batch_slice = example_batch.isel(time=slice(i, 4+i))
    train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
        example_batch_slice, target_lead_times=slice("6h", f"{train_steps.value*6}h"),
        **dataclasses.asdict(task_config))
    # @title Gradient computation (backprop through time)
    loss, diagnostics, next_state, grads = grads_fn_jitted(
        inputs=train_inputs,
        targets=train_targets,
        forcings=train_forcings)
    mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
    print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

In [None]:
# ckpt 的已有数据
params_new = grads    # 由 grads_fn_jitted 计算得到
model_config = model_config    # 由模型加载得到   # @title Load the model
task_config = task_config    # 由模型加载得到   # @title Load the model 
description='\nGraphCast model ...(输入你的陈述)\n'
license='\nThe model weights are licensed 输入数据集的license\n'



In [None]:
# @title Save the model   by S.F. Sune  
import io

ckpt = save_model.Checkpoint(
    params = params_new,
    model_config = model_config,
    task_config = task_config,
    description = description,
    license = license
    )

buffer = io.BytesIO()    # 创建一个内存文件对象 creat a memory file object
checkpoint.dump(buffer, ckpt)
buffer.seek(0)    # 移动文件指针到文件的开头，便于读取 move the file pointer to the beginning of the file， to facilitate reading

# 保存buffer为.npy到本地 save buffer of .npy to local
with open("/root/data/params/params-GraphCast_test.npy", "wb") as f:
  f.write(buffer.read())
f.close()    # 关闭文件 close file

In [None]:
# 加载训练20次的模型

with open("/root/data/params/params-GraphCast_test.npy", "rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

params = ckpt.params
state = {}

model_config = ckpt.model_config
task_config = ckpt.task_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")
params

In [None]:
model_config

In [None]:
#继续迭代
for i in range(20,39):
    example_batch_slice = example_batch.isel(time=slice(i, 4+i))
    train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
        example_batch_slice, target_lead_times=slice("6h", f"{train_steps.value*6}h"),
        **dataclasses.asdict(task_config))
    # @title Gradient computation (backprop through time)
    loss, diagnostics, next_state, grads = grads_fn_jitted(
        inputs=train_inputs,
        targets=train_targets,
        forcings=train_forcings)
    mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
    print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

# Run the model

Note that the cell below may take a while (possibly minutes) to run the first time you execute them, because this will include the time it takes for the code to compile. The second time running will be significantly faster.

This use the python loop to iterate over prediction steps, where the 1-step prediction is jitted. This has lower memory requirements than the training steps below, and should enable making prediction with the small GraphCast model on 1 deg resolution data for 4 steps.

In [24]:
# @title Autoregressive rollout (loop in python)

assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)
predictions

Inputs:   {'batch': 1, 'time': 2, 'lat': 181, 'lon': 360, 'level': 13}
Targets:  {'batch': 1, 'time': 1, 'lat': 181, 'lon': 360, 'level': 13}
Forcings: {'batch': 1, 'time': 1, 'lat': 181, 'lon': 360}




# Train the model

The following operations require a large amount of memory and, depending on the accelerator being used, will only fit the very small "random" model on low resolution data. It uses the number of training steps selected above.

The first time executing the cell takes more time, as it include the time to jit the function.

In [None]:
# @title Loss computation (autoregressive loss over multiple steps)
loss, diagnostics = loss_fn_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)

print("Loss:", float(loss))

In [26]:
# @title Gradient computation (backprop through time)
loss, diagnostics, next_state, grads = grads_fn_jitted(
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

Loss: 14.4468, Mean |grad|: 0.096610


In [None]:
# @title Autoregressive rollout (keep the loop in JAX)
print("Inputs:  ", train_inputs.dims.mapping)
print("Targets: ", train_targets.dims.mapping)
print("Forcings:", train_forcings.dims.mapping)

predictions = run_forward_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets * np.nan,
    forcings=train_forcings)
predictions