<a href="https://colab.research.google.com/github/tsussi/Cloud-variability-time-frequency/blob/master/1yrplotmypanguweather.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Downloading, running, and analysing the PanguWeather NWP foundation model

**Note: This notebook creates some big data objects, and it's easy to crash the session by exhausting the system RAM. Make sure you monitor the Resources panel and use the `del` command to remove big data objects as necessary.**

This notebook is based on the code in the official PanguWeather github repo:

https://github.com/198808xc/Pangu-Weather

In [None]:
# The following works in google colab but might not work in other IDEs
# (it relies on the `!` shell escape and on the `gdown` command being available).
#
# The file IDs were extracted from the google drive links provided in the
# pangu-weather repository
#
# download pangu_weather_24.onnx into /content/
# (the 24-hour-lead-time model; the recorded download size is about 1.18 GB)
!gdown 1lweQlxcn9fG0zKNW8ne1Khr9ehRTI6HP
# download input_surface.npy into /content/
# (sample surface input; the dated files downloaded further down are the ones
# actually loaded by the visible cells)
!gdown 1pj8QEVNpC1FyJfUabDpV4oU3NpSe0BkD
# download input_upper.npy into /content/
# (sample upper-air input; see the note on input_surface.npy)
!gdown 1--7xEBJt79E3oixizr8oFmK_haDE77SS

Downloading...
From (original): https://drive.google.com/uc?id=1lweQlxcn9fG0zKNW8ne1Khr9ehRTI6HP
From (redirected): https://drive.google.com/uc?id=1lweQlxcn9fG0zKNW8ne1Khr9ehRTI6HP&confirm=t&uuid=47dcdd6e-ae35-452c-865b-ad7f519c31eb
To: /content/pangu_weather_24.onnx
100% 1.18G/1.18G [00:26<00:00, 45.4MB/s]
Downloading...
From: https://drive.google.com/uc?id=1pj8QEVNpC1FyJfUabDpV4oU3NpSe0BkD
To: /content/input_surface.npy
100% 16.6M/16.6M [00:00<00:00, 18.4MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1--7xEBJt79E3oixizr8oFmK_haDE77SS
From (redirected): https://drive.google.com/uc?id=1--7xEBJt79E3oixizr8oFmK_haDE77SS&confirm=t&uuid=b7760d6e-1742-42c8-a825-aefb8f23a197
To: /content/input_upper.npy
100% 270M/270M [00:06<00:00, 41.9MB/s]


In [None]:
!pip install onnx==1.17
#If we just run a model, we only need onnxruntime*.
!pip install onnxruntime==1.21.1
!pip install onnxruntime-gpu==1.21.1

Collecting onnx==1.17
  Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)
Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m119.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx
Successfully installed onnx-1.17.0
Collecting onnxruntime==1.21.1
  Downloading onnxruntime-1.21.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting coloredlogs (from onnxruntime==1.21.1)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime==1.21.1)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.21.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m 

In [None]:
import onnx
import onnxruntime as ort
import numpy as np

In [None]:
## this line is included in the original panguweather example code, but I found it's
## not actually needed...
## The directory of your input and output data
# model_24 = onnx.load('/content/pangu_weather_24.onnx')

In [None]:
# Ask onnxruntime which device it reports and show it; the session setup
# further down branches on this value ('GPU' selects the CUDA provider).
print(device := ort.get_device())

GPU


In [None]:
# Session-wide onnxruntime settings: CPU memory arena, memory-pattern
# optimisation and memory reuse are all switched off.
options = ort.SessionOptions()
options.enable_cpu_mem_arena = False
options.enable_mem_pattern = False
options.enable_mem_reuse = False
# Increase this number for faster inference and more memory consumption
options.intra_op_num_threads = 1

# Choose the execution provider from the device reported earlier, then build
# the Pangu-Weather 24 h session with a single constructor call.
if device == 'GPU':
    providers = [('CUDAExecutionProvider', {'arena_extend_strategy': 'kSameAsRequested'})]
else:
    providers = ['CPUExecutionProvider']

ort_session_24 = ort.InferenceSession(
    '/content/pangu_weather_24.onnx',
    sess_options=options,
    providers=providers,
)

In [None]:
# Google Drive file IDs of the daily upper-air input files
# (downloaded as input_upper_YYYYMMDD.npy), keyed by 'YYYY-MM-DD' date string.
# Covers 2000-01-01 through 2000-01-31; the IDs are passed to `gdown` below.
upper_ids = {
'2000-01-01': '1_BriXxo6hQfIsNqkYccF7ZKmSStzprDM',
'2000-01-02': '1hwjv48IGZwjVkssrAEtW9-edy4EjZuUk',
'2000-01-03': '1MwzFENTAmUfVHABPEJCuiLhuGi5_vGcy',
'2000-01-04': '1TJvqcqaZLCOjMVae_rLwiYeTj5sNTUTc',
'2000-01-05': '1yg1kzYejBTE1nsYJ-TVYu5tj5iPY7I4m',
'2000-01-06': '1fiRSqydIE9Tvjre1kWGgg34SANzosMnu',
'2000-01-07': '19pdrfPzl9ylvniY3VudFXh2U0f8OohSo',
'2000-01-08': '1Uu1JkAJkXKwCkJpn1XnDgJ0H-PaKhJ1l',
'2000-01-09': '1nXyJ8cGAjnzMFsB7llBNrk2di03JEa47',
'2000-01-10': '1Kfo2ZqctAUvPncQOP3x2UzRcPEgwpGIL',
'2000-01-11': '1NOiHyoOdQ8MamIIWEPurJJyLqbKvCZP1',
'2000-01-12': '1lmjABwyFQgXqzuJXiSGB3C6tgwHNHeEB',
'2000-01-13': '1fRiZXeLrm7bsjfOhq9K-mF7D62otJbje',
'2000-01-14': '1m061E8ibjvVJEK8XoXlSA0lRDtcwB47X',
'2000-01-15': '15K_CqYbA5irwhVIXHPHbGdW9d_urKiyA',
'2000-01-16': '1XeKMYEdH5OiF1kvrI4v49cUofKqqyO25',
'2000-01-17': '1IAzmkxMHnICUkP3_S-LwxZrssYm6WkB8',
'2000-01-18': '1BgCXJLz8tzgIdutraIWNdVqC5bhjSzRl',
'2000-01-19': '1vkLgNRe4ZTDC_gjbL0cK5rFkRDqlKHP5',
'2000-01-20': '1tB1E2oSa2g7Uh-97vnlGvqToRWa3vepG',
'2000-01-21': '1Z8dRhLHnJr1yUWgWu_IVoug5SDHVH-7j',
'2000-01-22': '1JVaqLRjTpL7h-mG14UNPEVxiNPqWozsl',
'2000-01-23': '15T9AAo64GrgIWT7_jRnYcm13eEdEI7hd',
'2000-01-24': '1VPwL69TyMG5LgzBVVzOz6GZ9bwASuh2T',
'2000-01-25': '1E1r9T1vCVAAz9U5-KJCgtbv8-DF7vmA1',
'2000-01-26': '1npeTU-dosRskrjS2zQz7lguFToV4cypT',
'2000-01-27': '1G98us1MyIAMK1iZ4ZvMyL0n6ap2w_FyD',
'2000-01-28': '1XdE03wMcR42-Rvf085vdhbsboEWD-NSc',
'2000-01-29': '1LeaoyDJCW33GY1vkqb4AUDduqNeqTxRZ',
'2000-01-30': '1TqZh83UxzlHwPmppHJs5MMlFHRGZCjyq',
'2000-01-31': '1tkVM1bvTMzfXRGIBoObHt2QjHpy_yJu4'
}

# Google Drive file IDs of the daily surface input files
# (downloaded as input_surface_YYYYMMDD.npy), keyed by 'YYYY-MM-DD' date string.
# Uses the same set of dates as `upper_ids`, so both can be indexed by one date.
surface_ids = {
'2000-01-01': '1Wj4BJ2adcXVchWZ_C6Rkkxl5-AGRmqH2',
'2000-01-02': '19edCDwI-EKg8I5VUPR3xtm7Kl3UxrM2J',
'2000-01-03': '1q3tcLtKd7ZRcwUWC201nIV91ZfKQ3fAq',
'2000-01-04': '1Gkby-FraqujedNpgz-6QENE_V54XodEI',
'2000-01-05': '1db2F89Tsag7l4CHUjydhQMd1kquBbIde',
'2000-01-06': '1YMJhZBdLEtnC7w2XkbFSJU5zNduO-uQm',
'2000-01-07': '15H1YR7x5XfyNHf__NvNpvapVejguS8yz',
'2000-01-08': '1r2tF16TRUEFCApwqxnLH670iYGFmVIX8',
'2000-01-09': '1orVKqMyoKpcbpNh-VMW3H0DzKjzlkVRc',
'2000-01-10': '1av3WGMy5bhrpuUwyCOddRjSuqKmhC3lR',
'2000-01-11': '1DApynVbDWz0MHHK0-MHK0kyMUMKkQx6a',
'2000-01-12': '1JOhvq-UMsbvMVYf8pjePnQFwmywZRe0m',
'2000-01-13': '1XT2A3O5g2HhUpu8KQ5NcXC2_txmWTVXC',
'2000-01-14': '16WLHCwx_y5nQWvPs6oPBU5YWv3QfcS5e',
'2000-01-15': '1tHJt-N9EJFc7E4dSUEl-QJibx2Y1Izxt',
'2000-01-16': '1SwxMaGOAj3YZOG3c8brrN2PJqEOwmD6c',
'2000-01-17': '1orhxOzyjiiJkz4Jsa8emqREGXHxx0ByK',
'2000-01-18': '1giZ_ixzKNpNQPMTYQnmo7LmbwWEn_TRT',
'2000-01-19': '1zd13tSPr9-DomRw0AIoAy388NYh6CKZc',
'2000-01-20': '1_44TAp8Mu21cSOSJnf02kvdGF_yvZJU4',
'2000-01-21': '14h9Xzc8XVi5fwWopqZfbNo7whgBQxHGM',
'2000-01-22': '1f3pAFXnpjufuX_WaDVGWKOtwGIEx7RUb',
'2000-01-23': '14suWnpovT-hfulhvWaIy8SWUmEQyFnfo',
'2000-01-24': '1wWngvgNXAzcfmt6ze5tftnpc8l3Tuir8',
'2000-01-25': '12c4bKZz7JTgZVx5mjl3fl1wLH4D_SEc9',
'2000-01-26': '1go6MKMFdAq9ojPiNQQK1fXh1EWjPRtco',
'2000-01-27': '1HlGVittKwBJUrGsN5H-qHysGgMh4N2SX',
'2000-01-28': '1vcNmEfI2jJCjcd0oKKcvjOOPq5nl2k47',
'2000-01-29': '1K6oElZrsQ-5_0h32zchL9Q2tf22slsee',
'2000-01-30': '1nfXjB2k7rqX6UThvbz637ie1XP1ig3di',
'2000-01-31': '13Bl1B-g0uDw4drcO5RNEG9nCJX6u_CCo'
}

In [None]:
# get upper and surface id for a given date and download files
date = '2000-01-01'
upper_id = upper_ids[date]
surface_id = surface_ids[date]
!gdown $upper_id
!gdown $surface_id

Downloading...
From (original): https://drive.google.com/uc?id=1_BriXxo6hQfIsNqkYccF7ZKmSStzprDM
From (redirected): https://drive.google.com/uc?id=1_BriXxo6hQfIsNqkYccF7ZKmSStzprDM&confirm=t&uuid=9cc2331d-b69a-488c-82ab-1c70dafbee0b
To: /content/input_upper_20000101.npy
100% 270M/270M [00:04<00:00, 58.2MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Wj4BJ2adcXVchWZ_C6Rkkxl5-AGRmqH2
To: /content/input_surface_20000101.npy
100% 16.6M/16.6M [00:00<00:00, 39.0MB/s]


In [None]:
# Read the 2000-01-01 initial conditions from disk and cast them to float32
upper_path = '/content/input_upper_20000101.npy'
surface_path = '/content/input_surface_20000101.npy'
inputs_upper = np.load(upper_path).astype(np.float32)
inputs_surface = np.load(surface_path).astype(np.float32)

In [None]:
# Grid coordinates at 0.25 degree spacing.
# Longitudes run 0 .. 359.75 (1440 values).
longitudes_list = list(np.arange(0, 360, 0.25))
# Latitudes run north to south, 90 .. -90 (721 values). The step must be
# negative: a positive step with start > stop makes np.arange return an empty
# array. All values are multiples of 0.25, so they are exact in floating point.
latitudes_list = list(np.arange(90, -90 - 0.25, -0.25))

In [None]:
import matplotlib
from matplotlib import pyplot as plt

# For further study: One month of data

The `upper_ids` and `surface_ids` dictionaries defined above hold the Google Drive file IDs of the individual PanguWeather input files for each day of January 2000.


In [None]:
# Sanity check: display the shape of the loaded surface input array
inputs_surface.shape

(4, 721, 1440)

In [None]:
# Sanity check: display the shape of the loaded upper-air input array
inputs_upper.shape

(5, 13, 721, 1440)

In [None]:
import xarray as xr

# Dimension names and coordinate labels for the stacked surface outputs
# (one entry per autoregressive step, each step shaped (4, 721, 1440)).
dims = ("time", "variable", "lat", "lon")
coords = {
    # one label per 24 h autoregressive step of the 180-step rollouts
    "time": np.arange(180),
    # The grid rows are stored north to south (90 .. -90), so the latitude
    # labels must be descending; ascending labels would mirror every field
    # about the equator. NOTE(review): orientation taken from the
    # Pangu-Weather input description — confirm against the README.
    "lat": np.arange(90, -90 - 0.25, -0.25),
    "lon": np.arange(0, 360, 0.25),
    # order of the 4 surface channels in the model input/output
    "variable": ["mslp", "u10", "v10", "t2m"],
}

In [None]:
# Plus4K experiment: roll the 24 h model forward autoregressively for 180 days
# from initial conditions whose 2 m temperature (surface channel 3) is raised
# by 4 K everywhere.
n_days = 180

# Each upper-air output is (5, 13, 721, 1440) float32, i.e. about 270 MB.
# Keeping all 180 of them (~49 GB) exhausts the session RAM, and only the
# surface fields are analysed afterwards, so upper-air outputs are retained
# only when this flag is switched on.
keep_upper_outputs = False

# Initialize inputs (copies, so the original arrays stay untouched)
current_input = inputs_upper.copy()
current_input_surface = inputs_surface.copy()
current_input_surface[3, ...] += 4

# Store outputs
outputs_p4k = []          # stays empty unless keep_upper_outputs is True
outputs_surface_p4k = []  # one (4, 721, 1440) array per day

# Run autoregressive loop for 180 days
for step in range(n_days):
    # Run inference
    output_upper, output_surface = ort_session_24.run(None, {
        'input': current_input,
        'input_surface': current_input_surface
    })

    # Save outputs for plus4K
    if keep_upper_outputs:
        outputs_p4k.append(output_upper)
    outputs_surface_p4k.append(output_surface)

    # Prepare inputs for next step (autoregressive)
    current_input = output_upper
    current_input_surface = output_surface

In [None]:
# Stack the per-day Plus4K surface outputs into one labelled DataArray ...
outputs_surface_p4k_xr = xr.DataArray(
    data=outputs_surface_p4k,
    coords=coords,
    dims=dims,
)
# ... then split the "variable" axis so each surface field becomes its own
# data variable of a Dataset.
xrdataset_surface_p4k = outputs_surface_p4k_xr.to_dataset(dim="variable")
print(xrdataset_surface_p4k)

In [None]:
# Area-weighted global-mean 2 m temperature of the Plus4K run, one value per
# time step (the mean is taken over lat and lon only, not over time).
#
# Weight each latitude by cos(lat). The latitudes are taken from the Plus4K
# dataset itself: the control dataset `xrdataset_surface` is only created
# further down, so referencing it here fails on a fresh Run All.
weights = np.cos(np.deg2rad(xrdataset_surface_p4k['lat']))
weights.name = "weights"

# xarray broadcasts the 1-D latitude weights across the other dimensions
global_mean_t2m_p4k = xrdataset_surface_p4k["t2m"].weighted(weights).mean(dim=["lat", "lon"])
print(global_mean_t2m_p4k)

In [None]:
# Control experiment: roll the 24 h model forward autoregressively for 180 days
# from the unmodified initial conditions.
n_days = 180

# Each upper-air output is (5, 13, 721, 1440) float32, i.e. about 270 MB.
# Keeping all 180 of them (~49 GB) exhausts the session RAM, and only the
# surface fields are analysed afterwards, so upper-air outputs are retained
# only when this flag is switched on.
keep_upper_outputs = False

# Initialize inputs (copies, so the original arrays stay untouched)
current_input = inputs_upper.copy()
current_input_surface = inputs_surface.copy()

# Store outputs
outputs = []          # stays empty unless keep_upper_outputs is True
outputs_surface = []  # one (4, 721, 1440) array per day

# Run autoregressive loop for 180 days
for step in range(n_days):
    # Run inference
    output_upper, output_surface = ort_session_24.run(None, {
        'input': current_input,
        'input_surface': current_input_surface
    })

    # Save outputs
    if keep_upper_outputs:
        outputs.append(output_upper)
    outputs_surface.append(output_surface)

    # Prepare inputs for next step (autoregressive)
    current_input = output_upper
    current_input_surface = output_surface


In [None]:
import xarray as xr

# Dimension names and coordinate labels for the stacked control-run surface
# outputs (one entry per autoregressive step, each shaped (4, 721, 1440)).
dims = ("time", "variable", "lat", "lon")
coords = {
    # one label per 24 h autoregressive step of the 180-step rollout
    "time": np.arange(180),
    # The grid rows are stored north to south (90 .. -90), so the latitude
    # labels must be descending; ascending labels would mirror every field
    # about the equator. NOTE(review): orientation taken from the
    # Pangu-Weather input description — confirm against the README.
    "lat": np.arange(90, -90 - 0.25, -0.25),
    "lon": np.arange(0, 360, 0.25),
    # order of the 4 surface channels in the model input/output
    "variable": ["mslp", "u10", "v10", "t2m"],
}
# Convert to DataArray first
outputs_surface_xr = xr.DataArray(outputs_surface, dims=dims, coords=coords)
# Then convert to Dataset by splitting the variable dimension
xrdataset_surface = outputs_surface_xr.to_dataset(dim="variable")
print(xrdataset_surface)

In [None]:
# Area-weighted global-mean 2 m temperature of the control run, one value per
# time step: weight every latitude row by cos(lat), then average over lat/lon.
weights = np.cos(np.deg2rad(xrdataset_surface['lat'])).rename("weights")

# xarray broadcasts the 1-D latitude weights across the other dimensions
t2m_weighted = xrdataset_surface["t2m"].weighted(weights)
global_mean_t2m = t2m_weighted.mean(dim=["lat", "lon"])
print(global_mean_t2m)

In [None]:
import matplotlib
from matplotlib import pyplot as plt

# Compare the global-mean 2 m temperature of the control and Plus4K rollouts.
# Both curves are labelled so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(global_mean_t2m["time"], global_mean_t2m, label="control")
ax.plot(global_mean_t2m_p4k["time"], global_mean_t2m_p4k, label="initial T2M + 4 K")
ax.set_xlabel("autoregressive step (24 h each)")
ax.set_ylabel("global-mean T2M (K)")
ax.set_title("Global-mean 2 m temperature: control vs. Plus4K")
ax.legend()
plt.show()

# MSE comparison

Below is example code to download ERA5 data for one date, make a one-day ahead prediction with panguweather_24 and compare the output to ERA5 data for the following day.

In [None]:
# Single 24 h forecast from the loaded initial conditions.
# Note: this rebinds `outputs_surface`, which held the list of control-run
# outputs before this cell.
feeds = {'input': inputs_upper, 'input_surface': inputs_surface}
outputs_upper, outputs_surface = ort_session_24.run(None, feeds)

In [None]:
# download target data for verification
date = '2000-01-30'
upper_id = upper_ids[date]
surface_id = surface_ids[date]
!gdown $upper_id
!gdown $surface_id

In [None]:
targets_upper = np.load('/content/input_upper_20000130.npy')
targets_surface = np.load('/content/input_surface_20000130.npy')

In [None]:
# Mean squared error of the forecast against the targets, averaged over the
# last three axes: one value per upper-air variable, printed one per line.
mse_forecast = np.mean((outputs_upper - targets_upper) ** 2, axis=(1, 2, 3))
print(*mse_forecast, sep='\n')

In [None]:
# Baseline: MSE of the persistence forecast ("tomorrow will be like today"),
# i.e. the unchanged inputs compared with the same targets.
mse_persistence = np.mean((inputs_upper - targets_upper) ** 2, axis=(1, 2, 3))
print(*mse_persistence, sep='\n')

In [None]:
# persistence mse is significantly higher than pangu weather mse
# roughly by a factor 10, indicating skillful predictions from
# panguweather that comfortably outperform the no-change forecast