This is a PyTorch implementation of Google's GraphCast, with an example training script and a ready-to-use dataloader.
To fit the model on a device with 40 GB of memory, I excluded the destination (receiver) node features from the input of the edge_mlp. Gradient checkpointing is applied to every single layer; please adjust this based on your system setup.
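Both memory-saving changes can be sketched as follows. This is a minimal illustration, not the code in this repository: the class names `EdgeMLP` and `CheckpointedEdgeLayers`, the layer sizes, the SiLU activation and the residual update are all assumptions. An interaction-network edge update typically concatenates edge, sender and receiver features; here the receiver (destination) term is dropped, and each layer is wrapped in `torch.utils.checkpoint.checkpoint` so its activations are recomputed in the backward pass instead of being stored.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class EdgeMLP(nn.Module):
    """Hypothetical edge update that uses edge + sender features only."""

    def __init__(self, edge_dim: int, node_dim: int, hidden_dim: int):
        super().__init__()
        # Input is [edge, sender node]; a full interaction network would also
        # append the receiver (destination) node features here.
        self.mlp = nn.Sequential(
            nn.Linear(edge_dim + node_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, edge_dim),
        )

    def forward(self, edge_feat, node_feat, src_idx):
        src_feat = node_feat[src_idx]  # gather sender features for every edge
        # Residual update of the edge features (an assumption of this sketch).
        return edge_feat + self.mlp(torch.cat([edge_feat, src_feat], dim=-1))


class CheckpointedEdgeLayers(nn.Module):
    """Stack of EdgeMLP layers, each optionally wrapped in checkpointing."""

    def __init__(self, num_layers, edge_dim, node_dim, hidden_dim, use_checkpoint=True):
        super().__init__()
        self.layers = nn.ModuleList(
            [EdgeMLP(edge_dim, node_dim, hidden_dim) for _ in range(num_layers)]
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, edge_feat, node_feat, src_idx):
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                # Activations inside `layer` are recomputed during backward
                # instead of being kept in memory.
                edge_feat = checkpoint(
                    layer, edge_feat, node_feat, src_idx, use_reentrant=False
                )
            else:
                edge_feat = layer(edge_feat, node_feat, src_idx)
        return edge_feat


# Tiny usage example: 5 nodes, 10 edges.
torch.manual_seed(0)
model = CheckpointedEdgeLayers(num_layers=2, edge_dim=4, node_dim=8, hidden_dim=16)
edges = torch.randn(10, 4, requires_grad=True)
nodes = torch.randn(5, 8)
src = torch.randint(0, 5, (10,))
out = model(edges, nodes, src)
out.sum().backward()
print(out.shape)  # torch.Size([10, 4])
```

Setting `use_checkpoint=False` trades memory for speed; on a device with more headroom, checkpointing only some of the layers is a possible middle ground.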
Data must be stored in Zarr format, one store per year, organized as follows:
# -- data_root
#    -- atomospherical_data_dir
#       -- 1979.zarr
#       -- .....zarr
#       -- 2018.zarr
#    -- surface_data_dir
#       -- 1979.zarr
#       -- .....zarr
#       -- 2018.zarr
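A quick way to confirm that a data directory matches this layout before training is sketched below. It uses only the standard library and assumes the sub-directory names are exactly `atomospherical_data_dir` and `surface_data_dir` and that each year is stored as a Zarr directory store; `expected_stores` and `missing_stores` are hypothetical helpers, not part of this repository.

```python
import tempfile
from pathlib import Path

# Assumed sub-directory names, copied from the tree above.
SUBDIRS = ("atomospherical_data_dir", "surface_data_dir")


def expected_stores(data_root, years, subdirs=SUBDIRS):
    """Per-year Zarr store paths implied by the layout above."""
    root = Path(data_root)
    return [root / sub / f"{year}.zarr" for sub in subdirs for year in years]


def missing_stores(data_root, years, subdirs=SUBDIRS):
    """Expected stores that are not present on disk."""
    return [p for p in expected_stores(data_root, years, subdirs) if not p.exists()]


# Build a fake 1979-1981 layout, delete one store, and report what is missing.
with tempfile.TemporaryDirectory() as tmp:
    years = range(1979, 1982)
    for store in expected_stores(tmp, years):
        store.mkdir(parents=True)  # empty dir standing in for a real Zarr store
    (Path(tmp) / "surface_data_dir" / "1980.zarr").rmdir()
    missing = [p.relative_to(tmp).as_posix() for p in missing_stores(tmp, years)]

print(missing)  # ['surface_data_dir/1980.zarr']
```

For the full 1979-2018 layout shown above, pass `range(1979, 2019)` as `years`.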
Modify the config file in the config directory to change the resolution, pressure levels, or variables.
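How these settings feed into tensor shapes can be sketched as below. The config keys are hypothetical stand-ins (the actual file in the config directory may name them differently), the grid is assumed to be a regular latitude-longitude grid that includes both poles, and the channel count covers a single time step of atmospheric and surface variables, ignoring any forcings or stacked input steps.

```python
def grid_shape(resolution_deg):
    """(n_lat, n_lon) of a regular lat-lon grid that includes both poles."""
    n_lat = round(180 / resolution_deg) + 1
    n_lon = round(360 / resolution_deg)
    return n_lat, n_lon


def channels_per_step(atm_vars, pressure_levels, surface_vars):
    """Each atmospheric variable contributes one channel per pressure level."""
    return len(atm_vars) * len(pressure_levels) + len(surface_vars)


# Hypothetical config values, for illustration only.
config = {
    "resolution": 1.0,
    "pressure_level": [500, 850, 1000],
    "atmospheric_variables": ["temperature", "geopotential"],
    "surface_variables": ["2m_temperature"],
}

print(grid_shape(config["resolution"]))  # (181, 360)
print(grid_shape(0.25))  # (721, 1440)
print(
    channels_per_step(
        config["atmospheric_variables"],
        config["pressure_level"],
        config["surface_variables"],
    )
)  # 7
```

For example, going from 1.0° to 0.25° multiplies the number of grid points by roughly 16, which is why resolution is the main setting to revisit when memory is tight.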
Note that regridding (e.g., bilinear subsampling) is not supported by the current dataset implementation, so the stored data must already be at the resolution set in the config.
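Because the data are not regridded on the fly, a mismatch between the stored grid and the configured resolution is best caught before training starts. The check below is a hypothetical helper, not part of this repository; it assumes 1-D latitude and longitude coordinate arrays (for example, read from one of the Zarr stores) that should form a regular grid including both poles.

```python
def check_grid(lat, lon, resolution_deg, tol=1e-6):
    """Raise ValueError if (lat, lon) is not a regular grid at resolution_deg."""
    expected = (round(180 / resolution_deg) + 1, round(360 / resolution_deg))
    if (len(lat), len(lon)) != expected:
        raise ValueError(
            f"grid is {len(lat)}x{len(lon)}, expected {expected[0]}x{expected[1]} "
            f"for {resolution_deg} deg; regrid the data before storing it"
        )
    # Every neighbouring pair of coordinates must be exactly one grid step apart.
    for name, coords in (("lat", lat), ("lon", lon)):
        for a, b in zip(coords[:-1], coords[1:]):
            if abs(abs(b - a) - resolution_deg) > tol:
                raise ValueError(f"{name} spacing {abs(b - a)} != {resolution_deg}")


lat_1deg = [90.0 - i for i in range(181)]
lon_1deg = [float(i) for i in range(360)]
check_grid(lat_1deg, lon_1deg, 1.0)  # passes silently

try:
    check_grid(lat_1deg, lon_1deg, 0.25)  # 1-degree data vs. a 0.25-degree config
except ValueError as err:
    print(err)
```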